Chris@49
|
1 // Copyright (C) 2008-2013 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2013 Conrad Sanderson
|
Chris@49
|
3 //
|
Chris@49
|
4 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
7
|
Chris@49
|
8
|
Chris@49
|
9 //! \addtogroup op_dot
|
Chris@49
|
10 //! @{
|
Chris@49
|
11
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 //! for two arrays, generic version for non-complex values
|
Chris@49
|
15 template<typename eT>
|
Chris@49
|
16 arma_hot
|
Chris@49
|
17 arma_pure
|
Chris@49
|
18 arma_inline
|
Chris@49
|
19 typename arma_not_cx<eT>::result
|
Chris@49
|
20 op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B)
|
Chris@49
|
21 {
|
Chris@49
|
22 arma_extra_debug_sigprint();
|
Chris@49
|
23
|
Chris@49
|
24 eT val1 = eT(0);
|
Chris@49
|
25 eT val2 = eT(0);
|
Chris@49
|
26
|
Chris@49
|
27 uword i, j;
|
Chris@49
|
28
|
Chris@49
|
29 for(i=0, j=1; j<n_elem; i+=2, j+=2)
|
Chris@49
|
30 {
|
Chris@49
|
31 val1 += A[i] * B[i];
|
Chris@49
|
32 val2 += A[j] * B[j];
|
Chris@49
|
33 }
|
Chris@49
|
34
|
Chris@49
|
35 if(i < n_elem)
|
Chris@49
|
36 {
|
Chris@49
|
37 val1 += A[i] * B[i];
|
Chris@49
|
38 }
|
Chris@49
|
39
|
Chris@49
|
40 return val1 + val2;
|
Chris@49
|
41 }
|
Chris@49
|
42
|
Chris@49
|
43
|
Chris@49
|
44
|
Chris@49
|
45 //! for two arrays, generic version for complex values
|
Chris@49
|
46 template<typename eT>
|
Chris@49
|
47 arma_hot
|
Chris@49
|
48 arma_pure
|
Chris@49
|
49 inline
|
Chris@49
|
50 typename arma_cx_only<eT>::result
|
Chris@49
|
51 op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B)
|
Chris@49
|
52 {
|
Chris@49
|
53 arma_extra_debug_sigprint();
|
Chris@49
|
54
|
Chris@49
|
55 typedef typename get_pod_type<eT>::result T;
|
Chris@49
|
56
|
Chris@49
|
57 T val_real = T(0);
|
Chris@49
|
58 T val_imag = T(0);
|
Chris@49
|
59
|
Chris@49
|
60 for(uword i=0; i<n_elem; ++i)
|
Chris@49
|
61 {
|
Chris@49
|
62 const std::complex<T>& X = A[i];
|
Chris@49
|
63 const std::complex<T>& Y = B[i];
|
Chris@49
|
64
|
Chris@49
|
65 const T a = X.real();
|
Chris@49
|
66 const T b = X.imag();
|
Chris@49
|
67
|
Chris@49
|
68 const T c = Y.real();
|
Chris@49
|
69 const T d = Y.imag();
|
Chris@49
|
70
|
Chris@49
|
71 val_real += (a*c) - (b*d);
|
Chris@49
|
72 val_imag += (a*d) + (b*c);
|
Chris@49
|
73 }
|
Chris@49
|
74
|
Chris@49
|
75 return std::complex<T>(val_real, val_imag);
|
Chris@49
|
76 }
|
Chris@49
|
77
|
Chris@49
|
78
|
Chris@49
|
79
|
Chris@49
|
80 //! for two arrays, float and double version
|
Chris@49
|
81 template<typename eT>
|
Chris@49
|
82 arma_hot
|
Chris@49
|
83 arma_pure
|
Chris@49
|
84 inline
|
Chris@49
|
85 typename arma_real_only<eT>::result
|
Chris@49
|
86 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
|
Chris@49
|
87 {
|
Chris@49
|
88 arma_extra_debug_sigprint();
|
Chris@49
|
89
|
Chris@49
|
90 if( n_elem <= 32u )
|
Chris@49
|
91 {
|
Chris@49
|
92 return op_dot::direct_dot_arma(n_elem, A, B);
|
Chris@49
|
93 }
|
Chris@49
|
94 else
|
Chris@49
|
95 {
|
Chris@49
|
96 #if defined(ARMA_USE_ATLAS)
|
Chris@49
|
97 {
|
Chris@49
|
98 arma_extra_debug_print("atlas::cblas_dot()");
|
Chris@49
|
99
|
Chris@49
|
100 return atlas::cblas_dot(n_elem, A, B);
|
Chris@49
|
101 }
|
Chris@49
|
102 #elif defined(ARMA_USE_BLAS)
|
Chris@49
|
103 {
|
Chris@49
|
104 arma_extra_debug_print("blas::dot()");
|
Chris@49
|
105
|
Chris@49
|
106 return blas::dot(n_elem, A, B);
|
Chris@49
|
107 }
|
Chris@49
|
108 #else
|
Chris@49
|
109 {
|
Chris@49
|
110 return op_dot::direct_dot_arma(n_elem, A, B);
|
Chris@49
|
111 }
|
Chris@49
|
112 #endif
|
Chris@49
|
113 }
|
Chris@49
|
114 }
|
Chris@49
|
115
|
Chris@49
|
116
|
Chris@49
|
117
|
Chris@49
|
118 //! for two arrays, complex version
|
Chris@49
|
119 template<typename eT>
|
Chris@49
|
120 inline
|
Chris@49
|
121 arma_hot
|
Chris@49
|
122 arma_pure
|
Chris@49
|
123 typename arma_cx_only<eT>::result
|
Chris@49
|
124 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
|
Chris@49
|
125 {
|
Chris@49
|
126 if( n_elem <= 16u )
|
Chris@49
|
127 {
|
Chris@49
|
128 return op_dot::direct_dot_arma(n_elem, A, B);
|
Chris@49
|
129 }
|
Chris@49
|
130 else
|
Chris@49
|
131 {
|
Chris@49
|
132 #if defined(ARMA_USE_ATLAS)
|
Chris@49
|
133 {
|
Chris@49
|
134 arma_extra_debug_print("atlas::cx_cblas_dot()");
|
Chris@49
|
135
|
Chris@49
|
136 return atlas::cx_cblas_dot(n_elem, A, B);
|
Chris@49
|
137 }
|
Chris@49
|
138 #elif defined(ARMA_USE_BLAS)
|
Chris@49
|
139 {
|
Chris@49
|
140 arma_extra_debug_print("blas::dot()");
|
Chris@49
|
141
|
Chris@49
|
142 return blas::dot(n_elem, A, B);
|
Chris@49
|
143 }
|
Chris@49
|
144 #else
|
Chris@49
|
145 {
|
Chris@49
|
146 return op_dot::direct_dot_arma(n_elem, A, B);
|
Chris@49
|
147 }
|
Chris@49
|
148 #endif
|
Chris@49
|
149 }
|
Chris@49
|
150 }
|
Chris@49
|
151
|
Chris@49
|
152
|
Chris@49
|
153
|
Chris@49
|
154 //! for two arrays, integral version
|
Chris@49
|
155 template<typename eT>
|
Chris@49
|
156 arma_hot
|
Chris@49
|
157 arma_pure
|
Chris@49
|
158 inline
|
Chris@49
|
159 typename arma_integral_only<eT>::result
|
Chris@49
|
160 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
|
Chris@49
|
161 {
|
Chris@49
|
162 return op_dot::direct_dot_arma(n_elem, A, B);
|
Chris@49
|
163 }
|
Chris@49
|
164
|
Chris@49
|
165
|
Chris@49
|
166
|
Chris@49
|
167
|
Chris@49
|
168 //! for three arrays
|
Chris@49
|
169 template<typename eT>
|
Chris@49
|
170 arma_hot
|
Chris@49
|
171 arma_pure
|
Chris@49
|
172 inline
|
Chris@49
|
173 eT
|
Chris@49
|
174 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C)
|
Chris@49
|
175 {
|
Chris@49
|
176 arma_extra_debug_sigprint();
|
Chris@49
|
177
|
Chris@49
|
178 eT val = eT(0);
|
Chris@49
|
179
|
Chris@49
|
180 for(uword i=0; i<n_elem; ++i)
|
Chris@49
|
181 {
|
Chris@49
|
182 val += A[i] * B[i] * C[i];
|
Chris@49
|
183 }
|
Chris@49
|
184
|
Chris@49
|
185 return val;
|
Chris@49
|
186 }
|
Chris@49
|
187
|
Chris@49
|
188
|
Chris@49
|
189
|
Chris@49
|
190 template<typename T1, typename T2>
|
Chris@49
|
191 arma_hot
|
Chris@49
|
192 inline
|
Chris@49
|
193 typename T1::elem_type
|
Chris@49
|
194 op_dot::apply(const T1& X, const T2& Y)
|
Chris@49
|
195 {
|
Chris@49
|
196 arma_extra_debug_sigprint();
|
Chris@49
|
197
|
Chris@49
|
198 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) || (Proxy<T2>::prefer_at_accessor);
|
Chris@49
|
199
|
Chris@49
|
200 const bool do_unwrap = ((is_Mat<T1>::value == true) && (is_Mat<T2>::value == true)) || prefer_at_accessor;
|
Chris@49
|
201
|
Chris@49
|
202 if(do_unwrap == true)
|
Chris@49
|
203 {
|
Chris@49
|
204 const unwrap<T1> tmp1(X);
|
Chris@49
|
205 const unwrap<T2> tmp2(Y);
|
Chris@49
|
206
|
Chris@49
|
207 const typename unwrap<T1>::stored_type& A = tmp1.M;
|
Chris@49
|
208 const typename unwrap<T2>::stored_type& B = tmp2.M;
|
Chris@49
|
209
|
Chris@49
|
210 arma_debug_check
|
Chris@49
|
211 (
|
Chris@49
|
212 (A.n_elem != B.n_elem),
|
Chris@49
|
213 "dot(): objects must have the same number of elements"
|
Chris@49
|
214 );
|
Chris@49
|
215
|
Chris@49
|
216 return op_dot::direct_dot(A.n_elem, A.memptr(), B.memptr());
|
Chris@49
|
217 }
|
Chris@49
|
218 else
|
Chris@49
|
219 {
|
Chris@49
|
220 const Proxy<T1> PA(X);
|
Chris@49
|
221 const Proxy<T2> PB(Y);
|
Chris@49
|
222
|
Chris@49
|
223 arma_debug_check( (PA.get_n_elem() != PB.get_n_elem()), "dot(): objects must have the same number of elements" );
|
Chris@49
|
224
|
Chris@49
|
225 return op_dot::apply_proxy(PA,PB);
|
Chris@49
|
226 }
|
Chris@49
|
227 }
|
Chris@49
|
228
|
Chris@49
|
229
|
Chris@49
|
230
|
Chris@49
|
231 template<typename T1, typename T2>
|
Chris@49
|
232 arma_hot
|
Chris@49
|
233 inline
|
Chris@49
|
234 typename arma_not_cx<typename T1::elem_type>::result
|
Chris@49
|
235 op_dot::apply_proxy(const Proxy<T1>& PA, const Proxy<T2>& PB)
|
Chris@49
|
236 {
|
Chris@49
|
237 arma_extra_debug_sigprint();
|
Chris@49
|
238
|
Chris@49
|
239 typedef typename T1::elem_type eT;
|
Chris@49
|
240 typedef typename Proxy<T1>::ea_type ea_type1;
|
Chris@49
|
241 typedef typename Proxy<T2>::ea_type ea_type2;
|
Chris@49
|
242
|
Chris@49
|
243 const uword N = PA.get_n_elem();
|
Chris@49
|
244
|
Chris@49
|
245 ea_type1 A = PA.get_ea();
|
Chris@49
|
246 ea_type2 B = PB.get_ea();
|
Chris@49
|
247
|
Chris@49
|
248 eT val1 = eT(0);
|
Chris@49
|
249 eT val2 = eT(0);
|
Chris@49
|
250
|
Chris@49
|
251 uword i,j;
|
Chris@49
|
252
|
Chris@49
|
253 for(i=0, j=1; j<N; i+=2, j+=2)
|
Chris@49
|
254 {
|
Chris@49
|
255 val1 += A[i] * B[i];
|
Chris@49
|
256 val2 += A[j] * B[j];
|
Chris@49
|
257 }
|
Chris@49
|
258
|
Chris@49
|
259 if(i < N)
|
Chris@49
|
260 {
|
Chris@49
|
261 val1 += A[i] * B[i];
|
Chris@49
|
262 }
|
Chris@49
|
263
|
Chris@49
|
264 return val1 + val2;
|
Chris@49
|
265 }
|
Chris@49
|
266
|
Chris@49
|
267
|
Chris@49
|
268
|
Chris@49
|
269 template<typename T1, typename T2>
|
Chris@49
|
270 arma_hot
|
Chris@49
|
271 inline
|
Chris@49
|
272 typename arma_cx_only<typename T1::elem_type>::result
|
Chris@49
|
273 op_dot::apply_proxy(const Proxy<T1>& PA, const Proxy<T2>& PB)
|
Chris@49
|
274 {
|
Chris@49
|
275 arma_extra_debug_sigprint();
|
Chris@49
|
276
|
Chris@49
|
277 typedef typename T1::elem_type eT;
|
Chris@49
|
278 typedef typename get_pod_type<eT>::result T;
|
Chris@49
|
279
|
Chris@49
|
280 typedef typename Proxy<T1>::ea_type ea_type1;
|
Chris@49
|
281 typedef typename Proxy<T2>::ea_type ea_type2;
|
Chris@49
|
282
|
Chris@49
|
283 const uword N = PA.get_n_elem();
|
Chris@49
|
284
|
Chris@49
|
285 ea_type1 A = PA.get_ea();
|
Chris@49
|
286 ea_type2 B = PB.get_ea();
|
Chris@49
|
287
|
Chris@49
|
288 T val_real = T(0);
|
Chris@49
|
289 T val_imag = T(0);
|
Chris@49
|
290
|
Chris@49
|
291 for(uword i=0; i<N; ++i)
|
Chris@49
|
292 {
|
Chris@49
|
293 const std::complex<T> xx = A[i];
|
Chris@49
|
294 const std::complex<T> yy = B[i];
|
Chris@49
|
295
|
Chris@49
|
296 const T a = xx.real();
|
Chris@49
|
297 const T b = xx.imag();
|
Chris@49
|
298
|
Chris@49
|
299 const T c = yy.real();
|
Chris@49
|
300 const T d = yy.imag();
|
Chris@49
|
301
|
Chris@49
|
302 val_real += (a*c) - (b*d);
|
Chris@49
|
303 val_imag += (a*d) + (b*c);
|
Chris@49
|
304 }
|
Chris@49
|
305
|
Chris@49
|
306 return std::complex<T>(val_real, val_imag);
|
Chris@49
|
307 }
|
Chris@49
|
308
|
Chris@49
|
309
|
Chris@49
|
310
|
Chris@49
|
311 template<typename eT, typename TA>
|
Chris@49
|
312 arma_hot
|
Chris@49
|
313 inline
|
Chris@49
|
314 eT
|
Chris@49
|
315 op_dot::dot_and_copy_row(eT* out, const TA& A, const uword row, const eT* B_mem, const uword N)
|
Chris@49
|
316 {
|
Chris@49
|
317 eT acc1 = eT(0);
|
Chris@49
|
318 eT acc2 = eT(0);
|
Chris@49
|
319
|
Chris@49
|
320 uword i,j;
|
Chris@49
|
321 for(i=0, j=1; j < N; i+=2, j+=2)
|
Chris@49
|
322 {
|
Chris@49
|
323 const eT val_i = A.at(row, i);
|
Chris@49
|
324 const eT val_j = A.at(row, j);
|
Chris@49
|
325
|
Chris@49
|
326 out[i] = val_i;
|
Chris@49
|
327 out[j] = val_j;
|
Chris@49
|
328
|
Chris@49
|
329 acc1 += val_i * B_mem[i];
|
Chris@49
|
330 acc2 += val_j * B_mem[j];
|
Chris@49
|
331 }
|
Chris@49
|
332
|
Chris@49
|
333 if(i < N)
|
Chris@49
|
334 {
|
Chris@49
|
335 const eT val_i = A.at(row, i);
|
Chris@49
|
336
|
Chris@49
|
337 out[i] = val_i;
|
Chris@49
|
338
|
Chris@49
|
339 acc1 += val_i * B_mem[i];
|
Chris@49
|
340 }
|
Chris@49
|
341
|
Chris@49
|
342 return acc1 + acc2;
|
Chris@49
|
343 }
|
Chris@49
|
344
|
Chris@49
|
345
|
Chris@49
|
346
|
Chris@49
|
347 //
|
Chris@49
|
348 // op_norm_dot
|
Chris@49
|
349
|
Chris@49
|
350
|
Chris@49
|
351
|
Chris@49
|
352 template<typename T1, typename T2>
|
Chris@49
|
353 arma_hot
|
Chris@49
|
354 inline
|
Chris@49
|
355 typename T1::elem_type
|
Chris@49
|
356 op_norm_dot::apply(const T1& X, const T2& Y)
|
Chris@49
|
357 {
|
Chris@49
|
358 arma_extra_debug_sigprint();
|
Chris@49
|
359
|
Chris@49
|
360 typedef typename T1::elem_type eT;
|
Chris@49
|
361 typedef typename Proxy<T1>::ea_type ea_type1;
|
Chris@49
|
362 typedef typename Proxy<T2>::ea_type ea_type2;
|
Chris@49
|
363
|
Chris@49
|
364 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor);
|
Chris@49
|
365
|
Chris@49
|
366 if(prefer_at_accessor == false)
|
Chris@49
|
367 {
|
Chris@49
|
368 const Proxy<T1> PA(X);
|
Chris@49
|
369 const Proxy<T2> PB(Y);
|
Chris@49
|
370
|
Chris@49
|
371 const uword N = PA.get_n_elem();
|
Chris@49
|
372
|
Chris@49
|
373 arma_debug_check( (N != PB.get_n_elem()), "norm_dot(): objects must have the same number of elements" );
|
Chris@49
|
374
|
Chris@49
|
375 ea_type1 A = PA.get_ea();
|
Chris@49
|
376 ea_type2 B = PB.get_ea();
|
Chris@49
|
377
|
Chris@49
|
378 eT acc1 = eT(0);
|
Chris@49
|
379 eT acc2 = eT(0);
|
Chris@49
|
380 eT acc3 = eT(0);
|
Chris@49
|
381
|
Chris@49
|
382 for(uword i=0; i<N; ++i)
|
Chris@49
|
383 {
|
Chris@49
|
384 const eT tmpA = A[i];
|
Chris@49
|
385 const eT tmpB = B[i];
|
Chris@49
|
386
|
Chris@49
|
387 acc1 += tmpA * tmpA;
|
Chris@49
|
388 acc2 += tmpB * tmpB;
|
Chris@49
|
389 acc3 += tmpA * tmpB;
|
Chris@49
|
390 }
|
Chris@49
|
391
|
Chris@49
|
392 return acc3 / ( std::sqrt(acc1 * acc2) );
|
Chris@49
|
393 }
|
Chris@49
|
394 else
|
Chris@49
|
395 {
|
Chris@49
|
396 return op_norm_dot::apply_unwrap(X, Y);
|
Chris@49
|
397 }
|
Chris@49
|
398 }
|
Chris@49
|
399
|
Chris@49
|
400
|
Chris@49
|
401
|
Chris@49
|
402 template<typename T1, typename T2>
|
Chris@49
|
403 arma_hot
|
Chris@49
|
404 inline
|
Chris@49
|
405 typename T1::elem_type
|
Chris@49
|
406 op_norm_dot::apply_unwrap(const T1& X, const T2& Y)
|
Chris@49
|
407 {
|
Chris@49
|
408 arma_extra_debug_sigprint();
|
Chris@49
|
409
|
Chris@49
|
410 typedef typename T1::elem_type eT;
|
Chris@49
|
411
|
Chris@49
|
412 const unwrap<T1> tmp1(X);
|
Chris@49
|
413 const unwrap<T2> tmp2(Y);
|
Chris@49
|
414
|
Chris@49
|
415 const Mat<eT>& A = tmp1.M;
|
Chris@49
|
416 const Mat<eT>& B = tmp2.M;
|
Chris@49
|
417
|
Chris@49
|
418
|
Chris@49
|
419 arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
|
Chris@49
|
420
|
Chris@49
|
421 const uword N = A.n_elem;
|
Chris@49
|
422
|
Chris@49
|
423 const eT* A_mem = A.memptr();
|
Chris@49
|
424 const eT* B_mem = B.memptr();
|
Chris@49
|
425
|
Chris@49
|
426 eT acc1 = eT(0);
|
Chris@49
|
427 eT acc2 = eT(0);
|
Chris@49
|
428 eT acc3 = eT(0);
|
Chris@49
|
429
|
Chris@49
|
430 for(uword i=0; i<N; ++i)
|
Chris@49
|
431 {
|
Chris@49
|
432 const eT tmpA = A_mem[i];
|
Chris@49
|
433 const eT tmpB = B_mem[i];
|
Chris@49
|
434
|
Chris@49
|
435 acc1 += tmpA * tmpA;
|
Chris@49
|
436 acc2 += tmpB * tmpB;
|
Chris@49
|
437 acc3 += tmpA * tmpB;
|
Chris@49
|
438 }
|
Chris@49
|
439
|
Chris@49
|
440 return acc3 / ( std::sqrt(acc1 * acc2) );
|
Chris@49
|
441 }
|
Chris@49
|
442
|
Chris@49
|
443
|
Chris@49
|
444
|
Chris@49
|
445 //
|
Chris@49
|
446 // op_cdot
|
Chris@49
|
447
|
Chris@49
|
448
|
Chris@49
|
449
|
Chris@49
|
450 template<typename eT>
|
Chris@49
|
451 arma_hot
|
Chris@49
|
452 arma_pure
|
Chris@49
|
453 inline
|
Chris@49
|
454 eT
|
Chris@49
|
455 op_cdot::direct_cdot_arma(const uword n_elem, const eT* const A, const eT* const B)
|
Chris@49
|
456 {
|
Chris@49
|
457 arma_extra_debug_sigprint();
|
Chris@49
|
458
|
Chris@49
|
459 typedef typename get_pod_type<eT>::result T;
|
Chris@49
|
460
|
Chris@49
|
461 T val_real = T(0);
|
Chris@49
|
462 T val_imag = T(0);
|
Chris@49
|
463
|
Chris@49
|
464 for(uword i=0; i<n_elem; ++i)
|
Chris@49
|
465 {
|
Chris@49
|
466 const std::complex<T>& X = A[i];
|
Chris@49
|
467 const std::complex<T>& Y = B[i];
|
Chris@49
|
468
|
Chris@49
|
469 const T a = X.real();
|
Chris@49
|
470 const T b = X.imag();
|
Chris@49
|
471
|
Chris@49
|
472 const T c = Y.real();
|
Chris@49
|
473 const T d = Y.imag();
|
Chris@49
|
474
|
Chris@49
|
475 val_real += (a*c) + (b*d);
|
Chris@49
|
476 val_imag += (a*d) - (b*c);
|
Chris@49
|
477 }
|
Chris@49
|
478
|
Chris@49
|
479 return std::complex<T>(val_real, val_imag);
|
Chris@49
|
480 }
|
Chris@49
|
481
|
Chris@49
|
482
|
Chris@49
|
483
|
Chris@49
|
484 template<typename eT>
|
Chris@49
|
485 arma_hot
|
Chris@49
|
486 arma_pure
|
Chris@49
|
487 inline
|
Chris@49
|
488 eT
|
Chris@49
|
489 op_cdot::direct_cdot(const uword n_elem, const eT* const A, const eT* const B)
|
Chris@49
|
490 {
|
Chris@49
|
491 arma_extra_debug_sigprint();
|
Chris@49
|
492
|
Chris@49
|
493 if( n_elem <= 32u )
|
Chris@49
|
494 {
|
Chris@49
|
495 return op_cdot::direct_cdot_arma(n_elem, A, B);
|
Chris@49
|
496 }
|
Chris@49
|
497 else
|
Chris@49
|
498 {
|
Chris@49
|
499 #if defined(ARMA_USE_BLAS)
|
Chris@49
|
500 {
|
Chris@49
|
501 arma_extra_debug_print("blas::gemv()");
|
Chris@49
|
502
|
Chris@49
|
503 // using gemv() workaround due to compatibility issues with cdotc() and zdotc()
|
Chris@49
|
504
|
Chris@49
|
505 const char trans = 'C';
|
Chris@49
|
506
|
Chris@49
|
507 const blas_int m = blas_int(n_elem);
|
Chris@49
|
508 const blas_int n = 1;
|
Chris@49
|
509 //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1);
|
Chris@49
|
510 const blas_int inc = 1;
|
Chris@49
|
511
|
Chris@49
|
512 const eT alpha = eT(1);
|
Chris@49
|
513 const eT beta = eT(0);
|
Chris@49
|
514
|
Chris@49
|
515 eT result[2]; // paranoia: using two elements instead of one
|
Chris@49
|
516
|
Chris@49
|
517 //blas::gemv(&trans, &m, &n, &alpha, A, &lda, B, &inc, &beta, &result[0], &inc);
|
Chris@49
|
518 blas::gemv(&trans, &m, &n, &alpha, A, &m, B, &inc, &beta, &result[0], &inc);
|
Chris@49
|
519
|
Chris@49
|
520 return result[0];
|
Chris@49
|
521 }
|
Chris@49
|
522 #elif defined(ARMA_USE_ATLAS)
|
Chris@49
|
523 {
|
Chris@49
|
524 // TODO: use dedicated atlas functions cblas_cdotc_sub() and cblas_zdotc_sub() and retune threshold
|
Chris@49
|
525
|
Chris@49
|
526 return op_cdot::direct_cdot_arma(n_elem, A, B);
|
Chris@49
|
527 }
|
Chris@49
|
528 #else
|
Chris@49
|
529 {
|
Chris@49
|
530 return op_cdot::direct_cdot_arma(n_elem, A, B);
|
Chris@49
|
531 }
|
Chris@49
|
532 #endif
|
Chris@49
|
533 }
|
Chris@49
|
534 }
|
Chris@49
|
535
|
Chris@49
|
536
|
Chris@49
|
537
|
Chris@49
|
538 template<typename T1, typename T2>
|
Chris@49
|
539 arma_hot
|
Chris@49
|
540 inline
|
Chris@49
|
541 typename T1::elem_type
|
Chris@49
|
542 op_cdot::apply(const T1& X, const T2& Y)
|
Chris@49
|
543 {
|
Chris@49
|
544 arma_extra_debug_sigprint();
|
Chris@49
|
545
|
Chris@49
|
546 if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
|
Chris@49
|
547 {
|
Chris@49
|
548 return op_cdot::apply_unwrap(X,Y);
|
Chris@49
|
549 }
|
Chris@49
|
550 else
|
Chris@49
|
551 {
|
Chris@49
|
552 return op_cdot::apply_proxy(X,Y);
|
Chris@49
|
553 }
|
Chris@49
|
554 }
|
Chris@49
|
555
|
Chris@49
|
556
|
Chris@49
|
557
|
Chris@49
|
558 template<typename T1, typename T2>
|
Chris@49
|
559 arma_hot
|
Chris@49
|
560 inline
|
Chris@49
|
561 typename T1::elem_type
|
Chris@49
|
562 op_cdot::apply_unwrap(const T1& X, const T2& Y)
|
Chris@49
|
563 {
|
Chris@49
|
564 arma_extra_debug_sigprint();
|
Chris@49
|
565
|
Chris@49
|
566 typedef typename T1::elem_type eT;
|
Chris@49
|
567
|
Chris@49
|
568 const unwrap<T1> tmp1(X);
|
Chris@49
|
569 const unwrap<T2> tmp2(Y);
|
Chris@49
|
570
|
Chris@49
|
571 const Mat<eT>& A = tmp1.M;
|
Chris@49
|
572 const Mat<eT>& B = tmp2.M;
|
Chris@49
|
573
|
Chris@49
|
574 arma_debug_check( (A.n_elem != B.n_elem), "cdot(): objects must have the same number of elements" );
|
Chris@49
|
575
|
Chris@49
|
576 return op_cdot::direct_cdot( A.n_elem, A.mem, B.mem );
|
Chris@49
|
577 }
|
Chris@49
|
578
|
Chris@49
|
579
|
Chris@49
|
580
|
Chris@49
|
581 template<typename T1, typename T2>
|
Chris@49
|
582 arma_hot
|
Chris@49
|
583 inline
|
Chris@49
|
584 typename T1::elem_type
|
Chris@49
|
585 op_cdot::apply_proxy(const T1& X, const T2& Y)
|
Chris@49
|
586 {
|
Chris@49
|
587 arma_extra_debug_sigprint();
|
Chris@49
|
588
|
Chris@49
|
589 typedef typename T1::elem_type eT;
|
Chris@49
|
590 typedef typename get_pod_type<eT>::result T;
|
Chris@49
|
591
|
Chris@49
|
592 typedef typename Proxy<T1>::ea_type ea_type1;
|
Chris@49
|
593 typedef typename Proxy<T2>::ea_type ea_type2;
|
Chris@49
|
594
|
Chris@49
|
595 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) || (Proxy<T2>::prefer_at_accessor);
|
Chris@49
|
596
|
Chris@49
|
597 if(prefer_at_accessor == false)
|
Chris@49
|
598 {
|
Chris@49
|
599 const Proxy<T1> PA(X);
|
Chris@49
|
600 const Proxy<T2> PB(Y);
|
Chris@49
|
601
|
Chris@49
|
602 const uword N = PA.get_n_elem();
|
Chris@49
|
603
|
Chris@49
|
604 arma_debug_check( (N != PB.get_n_elem()), "cdot(): objects must have the same number of elements" );
|
Chris@49
|
605
|
Chris@49
|
606 ea_type1 A = PA.get_ea();
|
Chris@49
|
607 ea_type2 B = PB.get_ea();
|
Chris@49
|
608
|
Chris@49
|
609 T val_real = T(0);
|
Chris@49
|
610 T val_imag = T(0);
|
Chris@49
|
611
|
Chris@49
|
612 for(uword i=0; i<N; ++i)
|
Chris@49
|
613 {
|
Chris@49
|
614 const std::complex<T> AA = A[i];
|
Chris@49
|
615 const std::complex<T> BB = B[i];
|
Chris@49
|
616
|
Chris@49
|
617 const T a = AA.real();
|
Chris@49
|
618 const T b = AA.imag();
|
Chris@49
|
619
|
Chris@49
|
620 const T c = BB.real();
|
Chris@49
|
621 const T d = BB.imag();
|
Chris@49
|
622
|
Chris@49
|
623 val_real += (a*c) + (b*d);
|
Chris@49
|
624 val_imag += (a*d) - (b*c);
|
Chris@49
|
625 }
|
Chris@49
|
626
|
Chris@49
|
627 return std::complex<T>(val_real, val_imag);
|
Chris@49
|
628 }
|
Chris@49
|
629 else
|
Chris@49
|
630 {
|
Chris@49
|
631 return op_cdot::apply_unwrap( X, Y );
|
Chris@49
|
632 }
|
Chris@49
|
633 }
|
Chris@49
|
634
|
Chris@49
|
635
|
Chris@49
|
636
|
Chris@49
|
637 //! @}
|