Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-2.4.4/include/armadillo_bits/op_dot_meat.hpp @ 0:8b6102e2a9b0
Armadillo Library
author | maxzanoni76 <max.zanoni@eecs.qmul.ac.uk> |
---|---|
date | Wed, 11 Apr 2012 09:27:06 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:8b6102e2a9b0 |
---|---|
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) | |
2 // Copyright (C) 2008-2011 Conrad Sanderson | |
3 // | |
4 // This file is part of the Armadillo C++ library. | |
5 // It is provided without any warranty of fitness | |
6 // for any purpose. You can redistribute this file | |
7 // and/or modify it under the terms of the GNU | |
8 // Lesser General Public License (LGPL) as published | |
9 // by the Free Software Foundation, either version 3 | |
10 // of the License or (at your option) any later version. | |
11 // (see http://www.opensource.org/licenses for more info) | |
12 | |
13 | |
14 //! \addtogroup op_dot | |
15 //! @{ | |
16 | |
17 | |
18 | |
19 | |
20 //! for two arrays, generic version | |
21 template<typename eT> | |
22 arma_hot | |
23 arma_pure | |
24 inline | |
25 eT | |
26 op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) | |
27 { | |
28 arma_extra_debug_sigprint(); | |
29 | |
30 eT val1 = eT(0); | |
31 eT val2 = eT(0); | |
32 | |
33 uword i, j; | |
34 | |
35 for(i=0, j=1; j<n_elem; i+=2, j+=2) | |
36 { | |
37 val1 += A[i] * B[i]; | |
38 val2 += A[j] * B[j]; | |
39 } | |
40 | |
41 if(i < n_elem) | |
42 { | |
43 val1 += A[i] * B[i]; | |
44 } | |
45 | |
46 return val1 + val2; | |
47 } | |
48 | |
49 | |
50 | |
51 //! for two arrays, float and double version | |
52 template<typename eT> | |
53 arma_hot | |
54 arma_pure | |
55 inline | |
56 typename arma_float_only<eT>::result | |
57 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) | |
58 { | |
59 arma_extra_debug_sigprint(); | |
60 | |
61 if( n_elem <= (128/sizeof(eT)) ) | |
62 { | |
63 return op_dot::direct_dot_arma(n_elem, A, B); | |
64 } | |
65 else | |
66 { | |
67 #if defined(ARMA_USE_ATLAS) | |
68 { | |
69 arma_extra_debug_print("atlas::cblas_dot()"); | |
70 | |
71 return atlas::cblas_dot(n_elem, A, B); | |
72 } | |
73 #elif defined(ARMA_USE_BLAS) | |
74 { | |
75 arma_extra_debug_print("blas::dot()"); | |
76 | |
77 return blas::dot(n_elem, A, B); | |
78 } | |
79 #else | |
80 { | |
81 return op_dot::direct_dot_arma(n_elem, A, B); | |
82 } | |
83 #endif | |
84 } | |
85 } | |
86 | |
87 | |
88 | |
89 //! for two arrays, complex version | |
90 template<typename eT> | |
91 inline | |
92 arma_hot | |
93 arma_pure | |
94 typename arma_cx_only<eT>::result | |
95 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) | |
96 { | |
97 #if defined(ARMA_USE_ATLAS) | |
98 { | |
99 arma_extra_debug_print("atlas::cx_cblas_dot()"); | |
100 | |
101 return atlas::cx_cblas_dot(n_elem, A, B); | |
102 } | |
103 #elif defined(ARMA_USE_BLAS) | |
104 { | |
105 // TODO: work out the mess with zdotu() and zdotu_sub() in BLAS | |
106 return op_dot::direct_dot_arma(n_elem, A, B); | |
107 } | |
108 #else | |
109 { | |
110 return op_dot::direct_dot_arma(n_elem, A, B); | |
111 } | |
112 #endif | |
113 } | |
114 | |
115 | |
116 | |
117 //! for two arrays, integral version | |
118 template<typename eT> | |
119 arma_hot | |
120 arma_pure | |
121 inline | |
122 typename arma_integral_only<eT>::result | |
123 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) | |
124 { | |
125 return op_dot::direct_dot_arma(n_elem, A, B); | |
126 } | |
127 | |
128 | |
129 | |
130 | |
131 //! for three arrays | |
132 template<typename eT> | |
133 arma_hot | |
134 arma_pure | |
135 inline | |
136 eT | |
137 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C) | |
138 { | |
139 arma_extra_debug_sigprint(); | |
140 | |
141 eT val = eT(0); | |
142 | |
143 for(uword i=0; i<n_elem; ++i) | |
144 { | |
145 val += A[i] * B[i] * C[i]; | |
146 } | |
147 | |
148 return val; | |
149 } | |
150 | |
151 | |
152 | |
153 template<typename T1, typename T2> | |
154 arma_hot | |
155 arma_inline | |
156 typename T1::elem_type | |
157 op_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y) | |
158 { | |
159 arma_extra_debug_sigprint(); | |
160 | |
161 if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) ) | |
162 { | |
163 return op_dot::apply_unwrap(X,Y); | |
164 } | |
165 else | |
166 { | |
167 return op_dot::apply_proxy(X,Y); | |
168 } | |
169 } | |
170 | |
171 | |
172 | |
173 template<typename T1, typename T2> | |
174 arma_hot | |
175 arma_inline | |
176 typename T1::elem_type | |
177 op_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y) | |
178 { | |
179 arma_extra_debug_sigprint(); | |
180 | |
181 typedef typename T1::elem_type eT; | |
182 | |
183 const unwrap<T1> tmp1(X.get_ref()); | |
184 const unwrap<T2> tmp2(Y.get_ref()); | |
185 | |
186 const Mat<eT>& A = tmp1.M; | |
187 const Mat<eT>& B = tmp2.M; | |
188 | |
189 arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" ); | |
190 | |
191 return op_dot::direct_dot(A.n_elem, A.mem, B.mem); | |
192 } | |
193 | |
194 | |
195 | |
196 template<typename T1, typename T2> | |
197 arma_hot | |
198 inline | |
199 typename T1::elem_type | |
200 op_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y) | |
201 { | |
202 arma_extra_debug_sigprint(); | |
203 | |
204 typedef typename T1::elem_type eT; | |
205 typedef typename Proxy<T1>::ea_type ea_type1; | |
206 typedef typename Proxy<T2>::ea_type ea_type2; | |
207 | |
208 const Proxy<T1> A(X.get_ref()); | |
209 const Proxy<T2> B(Y.get_ref()); | |
210 | |
211 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor); | |
212 | |
213 if(prefer_at_accessor == false) | |
214 { | |
215 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "dot(): objects must have the same number of elements" ); | |
216 | |
217 const uword N = A.get_n_elem(); | |
218 ea_type1 PA = A.get_ea(); | |
219 ea_type2 PB = B.get_ea(); | |
220 | |
221 eT val1 = eT(0); | |
222 eT val2 = eT(0); | |
223 | |
224 uword i,j; | |
225 | |
226 for(i=0, j=1; j<N; i+=2, j+=2) | |
227 { | |
228 val1 += PA[i] * PB[i]; | |
229 val2 += PA[j] * PB[j]; | |
230 } | |
231 | |
232 if(i < N) | |
233 { | |
234 val1 += PA[i] * PB[i]; | |
235 } | |
236 | |
237 return val1 + val2; | |
238 } | |
239 else | |
240 { | |
241 return op_dot::apply_unwrap(A.Q, B.Q); | |
242 } | |
243 } | |
244 | |
245 | |
246 | |
247 // | |
248 // op_norm_dot | |
249 | |
250 | |
251 | |
252 template<typename T1, typename T2> | |
253 arma_hot | |
254 inline | |
255 typename T1::elem_type | |
256 op_norm_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y) | |
257 { | |
258 arma_extra_debug_sigprint(); | |
259 | |
260 typedef typename T1::elem_type eT; | |
261 typedef typename Proxy<T1>::ea_type ea_type1; | |
262 typedef typename Proxy<T2>::ea_type ea_type2; | |
263 | |
264 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor); | |
265 | |
266 if(prefer_at_accessor == false) | |
267 { | |
268 const Proxy<T1> A(X.get_ref()); | |
269 const Proxy<T2> B(Y.get_ref()); | |
270 | |
271 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "norm_dot(): objects must have the same number of elements" ); | |
272 | |
273 const uword N = A.get_n_elem(); | |
274 ea_type1 PA = A.get_ea(); | |
275 ea_type2 PB = B.get_ea(); | |
276 | |
277 eT acc1 = eT(0); | |
278 eT acc2 = eT(0); | |
279 eT acc3 = eT(0); | |
280 | |
281 for(uword i=0; i<N; ++i) | |
282 { | |
283 const eT tmpA = PA[i]; | |
284 const eT tmpB = PB[i]; | |
285 | |
286 acc1 += tmpA * tmpA; | |
287 acc2 += tmpB * tmpB; | |
288 acc3 += tmpA * tmpB; | |
289 } | |
290 | |
291 return acc3 / ( std::sqrt(acc1 * acc2) ); | |
292 } | |
293 else | |
294 { | |
295 return op_norm_dot::apply_unwrap(X, Y); | |
296 } | |
297 } | |
298 | |
299 | |
300 | |
301 template<typename T1, typename T2> | |
302 arma_hot | |
303 inline | |
304 typename T1::elem_type | |
305 op_norm_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y) | |
306 { | |
307 arma_extra_debug_sigprint(); | |
308 | |
309 typedef typename T1::elem_type eT; | |
310 | |
311 const unwrap<T1> tmp1(X.get_ref()); | |
312 const unwrap<T2> tmp2(Y.get_ref()); | |
313 | |
314 const Mat<eT>& A = tmp1.M; | |
315 const Mat<eT>& B = tmp2.M; | |
316 | |
317 | |
318 arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" ); | |
319 | |
320 const uword N = A.n_elem; | |
321 | |
322 const eT* A_mem = A.memptr(); | |
323 const eT* B_mem = B.memptr(); | |
324 | |
325 eT acc1 = eT(0); | |
326 eT acc2 = eT(0); | |
327 eT acc3 = eT(0); | |
328 | |
329 for(uword i=0; i<N; ++i) | |
330 { | |
331 const eT tmpA = A_mem[i]; | |
332 const eT tmpB = B_mem[i]; | |
333 | |
334 acc1 += tmpA * tmpA; | |
335 acc2 += tmpB * tmpB; | |
336 acc3 += tmpA * tmpB; | |
337 } | |
338 | |
339 return acc3 / ( std::sqrt(acc1 * acc2) ); | |
340 } | |
341 | |
342 | |
343 | |
344 // | |
345 // op_cdot | |
346 | |
347 | |
348 | |
349 template<typename T1, typename T2> | |
350 arma_hot | |
351 arma_inline | |
352 typename T1::elem_type | |
353 op_cdot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y) | |
354 { | |
355 arma_extra_debug_sigprint(); | |
356 | |
357 typedef typename T1::elem_type eT; | |
358 typedef typename Proxy<T1>::ea_type ea_type1; | |
359 typedef typename Proxy<T2>::ea_type ea_type2; | |
360 | |
361 const Proxy<T1> A(X.get_ref()); | |
362 const Proxy<T2> B(Y.get_ref()); | |
363 | |
364 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "cdot(): objects must have the same number of elements" ); | |
365 | |
366 const uword N = A.get_n_elem(); | |
367 ea_type1 PA = A.get_ea(); | |
368 ea_type2 PB = B.get_ea(); | |
369 | |
370 eT val1 = eT(0); | |
371 eT val2 = eT(0); | |
372 | |
373 uword i,j; | |
374 for(i=0, j=1; j<N; i+=2, j+=2) | |
375 { | |
376 val1 += std::conj(PA[i]) * PB[i]; | |
377 val2 += std::conj(PA[j]) * PB[j]; | |
378 } | |
379 | |
380 if(i < N) | |
381 { | |
382 val1 += std::conj(PA[i]) * PB[i]; | |
383 } | |
384 | |
385 return val1 + val2; | |
386 } | |
387 | |
388 | |
389 | |
390 //! @} |