comparison armadillo-2.4.4/include/armadillo_bits/op_dot_meat.hpp @ 0:8b6102e2a9b0

Armadillo Library
author maxzanoni76 <max.zanoni@eecs.qmul.ac.uk>
date Wed, 11 Apr 2012 09:27:06 +0100
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:8b6102e2a9b0
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
2 // Copyright (C) 2008-2011 Conrad Sanderson
3 //
4 // This file is part of the Armadillo C++ library.
5 // It is provided without any warranty of fitness
6 // for any purpose. You can redistribute this file
7 // and/or modify it under the terms of the GNU
8 // Lesser General Public License (LGPL) as published
9 // by the Free Software Foundation, either version 3
10 // of the License or (at your option) any later version.
11 // (see http://www.opensource.org/licenses for more info)
12
13
14 //! \addtogroup op_dot
15 //! @{
16
17
18
19
20 //! for two arrays, generic version
21 template<typename eT>
22 arma_hot
23 arma_pure
24 inline
25 eT
26 op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B)
27 {
28 arma_extra_debug_sigprint();
29
30 eT val1 = eT(0);
31 eT val2 = eT(0);
32
33 uword i, j;
34
35 for(i=0, j=1; j<n_elem; i+=2, j+=2)
36 {
37 val1 += A[i] * B[i];
38 val2 += A[j] * B[j];
39 }
40
41 if(i < n_elem)
42 {
43 val1 += A[i] * B[i];
44 }
45
46 return val1 + val2;
47 }
48
49
50
51 //! for two arrays, float and double version
52 template<typename eT>
53 arma_hot
54 arma_pure
55 inline
56 typename arma_float_only<eT>::result
57 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
58 {
59 arma_extra_debug_sigprint();
60
61 if( n_elem <= (128/sizeof(eT)) )
62 {
63 return op_dot::direct_dot_arma(n_elem, A, B);
64 }
65 else
66 {
67 #if defined(ARMA_USE_ATLAS)
68 {
69 arma_extra_debug_print("atlas::cblas_dot()");
70
71 return atlas::cblas_dot(n_elem, A, B);
72 }
73 #elif defined(ARMA_USE_BLAS)
74 {
75 arma_extra_debug_print("blas::dot()");
76
77 return blas::dot(n_elem, A, B);
78 }
79 #else
80 {
81 return op_dot::direct_dot_arma(n_elem, A, B);
82 }
83 #endif
84 }
85 }
86
87
88
89 //! for two arrays, complex version
90 template<typename eT>
91 inline
92 arma_hot
93 arma_pure
94 typename arma_cx_only<eT>::result
95 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
96 {
97 #if defined(ARMA_USE_ATLAS)
98 {
99 arma_extra_debug_print("atlas::cx_cblas_dot()");
100
101 return atlas::cx_cblas_dot(n_elem, A, B);
102 }
103 #elif defined(ARMA_USE_BLAS)
104 {
105 // TODO: work out the mess with zdotu() and zdotu_sub() in BLAS
106 return op_dot::direct_dot_arma(n_elem, A, B);
107 }
108 #else
109 {
110 return op_dot::direct_dot_arma(n_elem, A, B);
111 }
112 #endif
113 }
114
115
116
117 //! for two arrays, integral version
118 template<typename eT>
119 arma_hot
120 arma_pure
121 inline
122 typename arma_integral_only<eT>::result
123 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
124 {
125 return op_dot::direct_dot_arma(n_elem, A, B);
126 }
127
128
129
130
131 //! for three arrays
132 template<typename eT>
133 arma_hot
134 arma_pure
135 inline
136 eT
137 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C)
138 {
139 arma_extra_debug_sigprint();
140
141 eT val = eT(0);
142
143 for(uword i=0; i<n_elem; ++i)
144 {
145 val += A[i] * B[i] * C[i];
146 }
147
148 return val;
149 }
150
151
152
153 template<typename T1, typename T2>
154 arma_hot
155 arma_inline
156 typename T1::elem_type
157 op_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
158 {
159 arma_extra_debug_sigprint();
160
161 if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
162 {
163 return op_dot::apply_unwrap(X,Y);
164 }
165 else
166 {
167 return op_dot::apply_proxy(X,Y);
168 }
169 }
170
171
172
173 template<typename T1, typename T2>
174 arma_hot
175 arma_inline
176 typename T1::elem_type
177 op_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
178 {
179 arma_extra_debug_sigprint();
180
181 typedef typename T1::elem_type eT;
182
183 const unwrap<T1> tmp1(X.get_ref());
184 const unwrap<T2> tmp2(Y.get_ref());
185
186 const Mat<eT>& A = tmp1.M;
187 const Mat<eT>& B = tmp2.M;
188
189 arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" );
190
191 return op_dot::direct_dot(A.n_elem, A.mem, B.mem);
192 }
193
194
195
196 template<typename T1, typename T2>
197 arma_hot
198 inline
199 typename T1::elem_type
200 op_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
201 {
202 arma_extra_debug_sigprint();
203
204 typedef typename T1::elem_type eT;
205 typedef typename Proxy<T1>::ea_type ea_type1;
206 typedef typename Proxy<T2>::ea_type ea_type2;
207
208 const Proxy<T1> A(X.get_ref());
209 const Proxy<T2> B(Y.get_ref());
210
211 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor);
212
213 if(prefer_at_accessor == false)
214 {
215 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "dot(): objects must have the same number of elements" );
216
217 const uword N = A.get_n_elem();
218 ea_type1 PA = A.get_ea();
219 ea_type2 PB = B.get_ea();
220
221 eT val1 = eT(0);
222 eT val2 = eT(0);
223
224 uword i,j;
225
226 for(i=0, j=1; j<N; i+=2, j+=2)
227 {
228 val1 += PA[i] * PB[i];
229 val2 += PA[j] * PB[j];
230 }
231
232 if(i < N)
233 {
234 val1 += PA[i] * PB[i];
235 }
236
237 return val1 + val2;
238 }
239 else
240 {
241 return op_dot::apply_unwrap(A.Q, B.Q);
242 }
243 }
244
245
246
247 //
248 // op_norm_dot
249
250
251
252 template<typename T1, typename T2>
253 arma_hot
254 inline
255 typename T1::elem_type
256 op_norm_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
257 {
258 arma_extra_debug_sigprint();
259
260 typedef typename T1::elem_type eT;
261 typedef typename Proxy<T1>::ea_type ea_type1;
262 typedef typename Proxy<T2>::ea_type ea_type2;
263
264 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor);
265
266 if(prefer_at_accessor == false)
267 {
268 const Proxy<T1> A(X.get_ref());
269 const Proxy<T2> B(Y.get_ref());
270
271 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "norm_dot(): objects must have the same number of elements" );
272
273 const uword N = A.get_n_elem();
274 ea_type1 PA = A.get_ea();
275 ea_type2 PB = B.get_ea();
276
277 eT acc1 = eT(0);
278 eT acc2 = eT(0);
279 eT acc3 = eT(0);
280
281 for(uword i=0; i<N; ++i)
282 {
283 const eT tmpA = PA[i];
284 const eT tmpB = PB[i];
285
286 acc1 += tmpA * tmpA;
287 acc2 += tmpB * tmpB;
288 acc3 += tmpA * tmpB;
289 }
290
291 return acc3 / ( std::sqrt(acc1 * acc2) );
292 }
293 else
294 {
295 return op_norm_dot::apply_unwrap(X, Y);
296 }
297 }
298
299
300
301 template<typename T1, typename T2>
302 arma_hot
303 inline
304 typename T1::elem_type
305 op_norm_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
306 {
307 arma_extra_debug_sigprint();
308
309 typedef typename T1::elem_type eT;
310
311 const unwrap<T1> tmp1(X.get_ref());
312 const unwrap<T2> tmp2(Y.get_ref());
313
314 const Mat<eT>& A = tmp1.M;
315 const Mat<eT>& B = tmp2.M;
316
317
318 arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
319
320 const uword N = A.n_elem;
321
322 const eT* A_mem = A.memptr();
323 const eT* B_mem = B.memptr();
324
325 eT acc1 = eT(0);
326 eT acc2 = eT(0);
327 eT acc3 = eT(0);
328
329 for(uword i=0; i<N; ++i)
330 {
331 const eT tmpA = A_mem[i];
332 const eT tmpB = B_mem[i];
333
334 acc1 += tmpA * tmpA;
335 acc2 += tmpB * tmpB;
336 acc3 += tmpA * tmpB;
337 }
338
339 return acc3 / ( std::sqrt(acc1 * acc2) );
340 }
341
342
343
344 //
345 // op_cdot
346
347
348
349 template<typename T1, typename T2>
350 arma_hot
351 arma_inline
352 typename T1::elem_type
353 op_cdot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
354 {
355 arma_extra_debug_sigprint();
356
357 typedef typename T1::elem_type eT;
358 typedef typename Proxy<T1>::ea_type ea_type1;
359 typedef typename Proxy<T2>::ea_type ea_type2;
360
361 const Proxy<T1> A(X.get_ref());
362 const Proxy<T2> B(Y.get_ref());
363
364 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "cdot(): objects must have the same number of elements" );
365
366 const uword N = A.get_n_elem();
367 ea_type1 PA = A.get_ea();
368 ea_type2 PB = B.get_ea();
369
370 eT val1 = eT(0);
371 eT val2 = eT(0);
372
373 uword i,j;
374 for(i=0, j=1; j<N; i+=2, j+=2)
375 {
376 val1 += std::conj(PA[i]) * PB[i];
377 val2 += std::conj(PA[j]) * PB[j];
378 }
379
380 if(i < N)
381 {
382 val1 += std::conj(PA[i]) * PB[i];
383 }
384
385 return val1 + val2;
386 }
387
388
389
390 //! @}