max@0
|
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2008-2011 Conrad Sanderson
|
max@0
|
3 //
|
max@0
|
4 // This file is part of the Armadillo C++ library.
|
max@0
|
5 // It is provided without any warranty of fitness
|
max@0
|
6 // for any purpose. You can redistribute this file
|
max@0
|
7 // and/or modify it under the terms of the GNU
|
max@0
|
8 // Lesser General Public License (LGPL) as published
|
max@0
|
9 // by the Free Software Foundation, either version 3
|
max@0
|
10 // of the License or (at your option) any later version.
|
max@0
|
11 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
12
|
max@0
|
13
|
max@0
|
14 //! \addtogroup op_dot
|
max@0
|
15 //! @{
|
max@0
|
16
|
max@0
|
17
|
max@0
|
18
|
max@0
|
19
|
max@0
|
20 //! for two arrays, generic version
|
max@0
|
21 template<typename eT>
|
max@0
|
22 arma_hot
|
max@0
|
23 arma_pure
|
max@0
|
24 inline
|
max@0
|
25 eT
|
max@0
|
26 op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B)
|
max@0
|
27 {
|
max@0
|
28 arma_extra_debug_sigprint();
|
max@0
|
29
|
max@0
|
30 eT val1 = eT(0);
|
max@0
|
31 eT val2 = eT(0);
|
max@0
|
32
|
max@0
|
33 uword i, j;
|
max@0
|
34
|
max@0
|
35 for(i=0, j=1; j<n_elem; i+=2, j+=2)
|
max@0
|
36 {
|
max@0
|
37 val1 += A[i] * B[i];
|
max@0
|
38 val2 += A[j] * B[j];
|
max@0
|
39 }
|
max@0
|
40
|
max@0
|
41 if(i < n_elem)
|
max@0
|
42 {
|
max@0
|
43 val1 += A[i] * B[i];
|
max@0
|
44 }
|
max@0
|
45
|
max@0
|
46 return val1 + val2;
|
max@0
|
47 }
|
max@0
|
48
|
max@0
|
49
|
max@0
|
50
|
max@0
|
51 //! for two arrays, float and double version
|
max@0
|
52 template<typename eT>
|
max@0
|
53 arma_hot
|
max@0
|
54 arma_pure
|
max@0
|
55 inline
|
max@0
|
56 typename arma_float_only<eT>::result
|
max@0
|
57 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
|
max@0
|
58 {
|
max@0
|
59 arma_extra_debug_sigprint();
|
max@0
|
60
|
max@0
|
61 if( n_elem <= (128/sizeof(eT)) )
|
max@0
|
62 {
|
max@0
|
63 return op_dot::direct_dot_arma(n_elem, A, B);
|
max@0
|
64 }
|
max@0
|
65 else
|
max@0
|
66 {
|
max@0
|
67 #if defined(ARMA_USE_ATLAS)
|
max@0
|
68 {
|
max@0
|
69 arma_extra_debug_print("atlas::cblas_dot()");
|
max@0
|
70
|
max@0
|
71 return atlas::cblas_dot(n_elem, A, B);
|
max@0
|
72 }
|
max@0
|
73 #elif defined(ARMA_USE_BLAS)
|
max@0
|
74 {
|
max@0
|
75 arma_extra_debug_print("blas::dot()");
|
max@0
|
76
|
max@0
|
77 return blas::dot(n_elem, A, B);
|
max@0
|
78 }
|
max@0
|
79 #else
|
max@0
|
80 {
|
max@0
|
81 return op_dot::direct_dot_arma(n_elem, A, B);
|
max@0
|
82 }
|
max@0
|
83 #endif
|
max@0
|
84 }
|
max@0
|
85 }
|
max@0
|
86
|
max@0
|
87
|
max@0
|
88
|
max@0
|
89 //! for two arrays, complex version
|
max@0
|
90 template<typename eT>
|
max@0
|
91 inline
|
max@0
|
92 arma_hot
|
max@0
|
93 arma_pure
|
max@0
|
94 typename arma_cx_only<eT>::result
|
max@0
|
95 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
|
max@0
|
96 {
|
max@0
|
97 #if defined(ARMA_USE_ATLAS)
|
max@0
|
98 {
|
max@0
|
99 arma_extra_debug_print("atlas::cx_cblas_dot()");
|
max@0
|
100
|
max@0
|
101 return atlas::cx_cblas_dot(n_elem, A, B);
|
max@0
|
102 }
|
max@0
|
103 #elif defined(ARMA_USE_BLAS)
|
max@0
|
104 {
|
max@0
|
105 // TODO: work out the mess with zdotu() and zdotu_sub() in BLAS
|
max@0
|
106 return op_dot::direct_dot_arma(n_elem, A, B);
|
max@0
|
107 }
|
max@0
|
108 #else
|
max@0
|
109 {
|
max@0
|
110 return op_dot::direct_dot_arma(n_elem, A, B);
|
max@0
|
111 }
|
max@0
|
112 #endif
|
max@0
|
113 }
|
max@0
|
114
|
max@0
|
115
|
max@0
|
116
|
max@0
|
117 //! for two arrays, integral version
|
max@0
|
118 template<typename eT>
|
max@0
|
119 arma_hot
|
max@0
|
120 arma_pure
|
max@0
|
121 inline
|
max@0
|
122 typename arma_integral_only<eT>::result
|
max@0
|
123 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
|
max@0
|
124 {
|
max@0
|
125 return op_dot::direct_dot_arma(n_elem, A, B);
|
max@0
|
126 }
|
max@0
|
127
|
max@0
|
128
|
max@0
|
129
|
max@0
|
130
|
max@0
|
131 //! for three arrays
|
max@0
|
132 template<typename eT>
|
max@0
|
133 arma_hot
|
max@0
|
134 arma_pure
|
max@0
|
135 inline
|
max@0
|
136 eT
|
max@0
|
137 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C)
|
max@0
|
138 {
|
max@0
|
139 arma_extra_debug_sigprint();
|
max@0
|
140
|
max@0
|
141 eT val = eT(0);
|
max@0
|
142
|
max@0
|
143 for(uword i=0; i<n_elem; ++i)
|
max@0
|
144 {
|
max@0
|
145 val += A[i] * B[i] * C[i];
|
max@0
|
146 }
|
max@0
|
147
|
max@0
|
148 return val;
|
max@0
|
149 }
|
max@0
|
150
|
max@0
|
151
|
max@0
|
152
|
max@0
|
153 template<typename T1, typename T2>
|
max@0
|
154 arma_hot
|
max@0
|
155 arma_inline
|
max@0
|
156 typename T1::elem_type
|
max@0
|
157 op_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
|
max@0
|
158 {
|
max@0
|
159 arma_extra_debug_sigprint();
|
max@0
|
160
|
max@0
|
161 if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
|
max@0
|
162 {
|
max@0
|
163 return op_dot::apply_unwrap(X,Y);
|
max@0
|
164 }
|
max@0
|
165 else
|
max@0
|
166 {
|
max@0
|
167 return op_dot::apply_proxy(X,Y);
|
max@0
|
168 }
|
max@0
|
169 }
|
max@0
|
170
|
max@0
|
171
|
max@0
|
172
|
max@0
|
173 template<typename T1, typename T2>
|
max@0
|
174 arma_hot
|
max@0
|
175 arma_inline
|
max@0
|
176 typename T1::elem_type
|
max@0
|
177 op_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
|
max@0
|
178 {
|
max@0
|
179 arma_extra_debug_sigprint();
|
max@0
|
180
|
max@0
|
181 typedef typename T1::elem_type eT;
|
max@0
|
182
|
max@0
|
183 const unwrap<T1> tmp1(X.get_ref());
|
max@0
|
184 const unwrap<T2> tmp2(Y.get_ref());
|
max@0
|
185
|
max@0
|
186 const Mat<eT>& A = tmp1.M;
|
max@0
|
187 const Mat<eT>& B = tmp2.M;
|
max@0
|
188
|
max@0
|
189 arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" );
|
max@0
|
190
|
max@0
|
191 return op_dot::direct_dot(A.n_elem, A.mem, B.mem);
|
max@0
|
192 }
|
max@0
|
193
|
max@0
|
194
|
max@0
|
195
|
max@0
|
196 template<typename T1, typename T2>
|
max@0
|
197 arma_hot
|
max@0
|
198 inline
|
max@0
|
199 typename T1::elem_type
|
max@0
|
200 op_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
|
max@0
|
201 {
|
max@0
|
202 arma_extra_debug_sigprint();
|
max@0
|
203
|
max@0
|
204 typedef typename T1::elem_type eT;
|
max@0
|
205 typedef typename Proxy<T1>::ea_type ea_type1;
|
max@0
|
206 typedef typename Proxy<T2>::ea_type ea_type2;
|
max@0
|
207
|
max@0
|
208 const Proxy<T1> A(X.get_ref());
|
max@0
|
209 const Proxy<T2> B(Y.get_ref());
|
max@0
|
210
|
max@0
|
211 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor);
|
max@0
|
212
|
max@0
|
213 if(prefer_at_accessor == false)
|
max@0
|
214 {
|
max@0
|
215 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "dot(): objects must have the same number of elements" );
|
max@0
|
216
|
max@0
|
217 const uword N = A.get_n_elem();
|
max@0
|
218 ea_type1 PA = A.get_ea();
|
max@0
|
219 ea_type2 PB = B.get_ea();
|
max@0
|
220
|
max@0
|
221 eT val1 = eT(0);
|
max@0
|
222 eT val2 = eT(0);
|
max@0
|
223
|
max@0
|
224 uword i,j;
|
max@0
|
225
|
max@0
|
226 for(i=0, j=1; j<N; i+=2, j+=2)
|
max@0
|
227 {
|
max@0
|
228 val1 += PA[i] * PB[i];
|
max@0
|
229 val2 += PA[j] * PB[j];
|
max@0
|
230 }
|
max@0
|
231
|
max@0
|
232 if(i < N)
|
max@0
|
233 {
|
max@0
|
234 val1 += PA[i] * PB[i];
|
max@0
|
235 }
|
max@0
|
236
|
max@0
|
237 return val1 + val2;
|
max@0
|
238 }
|
max@0
|
239 else
|
max@0
|
240 {
|
max@0
|
241 return op_dot::apply_unwrap(A.Q, B.Q);
|
max@0
|
242 }
|
max@0
|
243 }
|
max@0
|
244
|
max@0
|
245
|
max@0
|
246
|
max@0
|
247 //
|
max@0
|
248 // op_norm_dot
|
max@0
|
249
|
max@0
|
250
|
max@0
|
251
|
max@0
|
252 template<typename T1, typename T2>
|
max@0
|
253 arma_hot
|
max@0
|
254 inline
|
max@0
|
255 typename T1::elem_type
|
max@0
|
256 op_norm_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
|
max@0
|
257 {
|
max@0
|
258 arma_extra_debug_sigprint();
|
max@0
|
259
|
max@0
|
260 typedef typename T1::elem_type eT;
|
max@0
|
261 typedef typename Proxy<T1>::ea_type ea_type1;
|
max@0
|
262 typedef typename Proxy<T2>::ea_type ea_type2;
|
max@0
|
263
|
max@0
|
264 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor);
|
max@0
|
265
|
max@0
|
266 if(prefer_at_accessor == false)
|
max@0
|
267 {
|
max@0
|
268 const Proxy<T1> A(X.get_ref());
|
max@0
|
269 const Proxy<T2> B(Y.get_ref());
|
max@0
|
270
|
max@0
|
271 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "norm_dot(): objects must have the same number of elements" );
|
max@0
|
272
|
max@0
|
273 const uword N = A.get_n_elem();
|
max@0
|
274 ea_type1 PA = A.get_ea();
|
max@0
|
275 ea_type2 PB = B.get_ea();
|
max@0
|
276
|
max@0
|
277 eT acc1 = eT(0);
|
max@0
|
278 eT acc2 = eT(0);
|
max@0
|
279 eT acc3 = eT(0);
|
max@0
|
280
|
max@0
|
281 for(uword i=0; i<N; ++i)
|
max@0
|
282 {
|
max@0
|
283 const eT tmpA = PA[i];
|
max@0
|
284 const eT tmpB = PB[i];
|
max@0
|
285
|
max@0
|
286 acc1 += tmpA * tmpA;
|
max@0
|
287 acc2 += tmpB * tmpB;
|
max@0
|
288 acc3 += tmpA * tmpB;
|
max@0
|
289 }
|
max@0
|
290
|
max@0
|
291 return acc3 / ( std::sqrt(acc1 * acc2) );
|
max@0
|
292 }
|
max@0
|
293 else
|
max@0
|
294 {
|
max@0
|
295 return op_norm_dot::apply_unwrap(X, Y);
|
max@0
|
296 }
|
max@0
|
297 }
|
max@0
|
298
|
max@0
|
299
|
max@0
|
300
|
max@0
|
301 template<typename T1, typename T2>
|
max@0
|
302 arma_hot
|
max@0
|
303 inline
|
max@0
|
304 typename T1::elem_type
|
max@0
|
305 op_norm_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
|
max@0
|
306 {
|
max@0
|
307 arma_extra_debug_sigprint();
|
max@0
|
308
|
max@0
|
309 typedef typename T1::elem_type eT;
|
max@0
|
310
|
max@0
|
311 const unwrap<T1> tmp1(X.get_ref());
|
max@0
|
312 const unwrap<T2> tmp2(Y.get_ref());
|
max@0
|
313
|
max@0
|
314 const Mat<eT>& A = tmp1.M;
|
max@0
|
315 const Mat<eT>& B = tmp2.M;
|
max@0
|
316
|
max@0
|
317
|
max@0
|
318 arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
|
max@0
|
319
|
max@0
|
320 const uword N = A.n_elem;
|
max@0
|
321
|
max@0
|
322 const eT* A_mem = A.memptr();
|
max@0
|
323 const eT* B_mem = B.memptr();
|
max@0
|
324
|
max@0
|
325 eT acc1 = eT(0);
|
max@0
|
326 eT acc2 = eT(0);
|
max@0
|
327 eT acc3 = eT(0);
|
max@0
|
328
|
max@0
|
329 for(uword i=0; i<N; ++i)
|
max@0
|
330 {
|
max@0
|
331 const eT tmpA = A_mem[i];
|
max@0
|
332 const eT tmpB = B_mem[i];
|
max@0
|
333
|
max@0
|
334 acc1 += tmpA * tmpA;
|
max@0
|
335 acc2 += tmpB * tmpB;
|
max@0
|
336 acc3 += tmpA * tmpB;
|
max@0
|
337 }
|
max@0
|
338
|
max@0
|
339 return acc3 / ( std::sqrt(acc1 * acc2) );
|
max@0
|
340 }
|
max@0
|
341
|
max@0
|
342
|
max@0
|
343
|
max@0
|
344 //
|
max@0
|
345 // op_cdot
|
max@0
|
346
|
max@0
|
347
|
max@0
|
348
|
max@0
|
349 template<typename T1, typename T2>
|
max@0
|
350 arma_hot
|
max@0
|
351 arma_inline
|
max@0
|
352 typename T1::elem_type
|
max@0
|
353 op_cdot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
|
max@0
|
354 {
|
max@0
|
355 arma_extra_debug_sigprint();
|
max@0
|
356
|
max@0
|
357 typedef typename T1::elem_type eT;
|
max@0
|
358 typedef typename Proxy<T1>::ea_type ea_type1;
|
max@0
|
359 typedef typename Proxy<T2>::ea_type ea_type2;
|
max@0
|
360
|
max@0
|
361 const Proxy<T1> A(X.get_ref());
|
max@0
|
362 const Proxy<T2> B(Y.get_ref());
|
max@0
|
363
|
max@0
|
364 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "cdot(): objects must have the same number of elements" );
|
max@0
|
365
|
max@0
|
366 const uword N = A.get_n_elem();
|
max@0
|
367 ea_type1 PA = A.get_ea();
|
max@0
|
368 ea_type2 PB = B.get_ea();
|
max@0
|
369
|
max@0
|
370 eT val1 = eT(0);
|
max@0
|
371 eT val2 = eT(0);
|
max@0
|
372
|
max@0
|
373 uword i,j;
|
max@0
|
374 for(i=0, j=1; j<N; i+=2, j+=2)
|
max@0
|
375 {
|
max@0
|
376 val1 += std::conj(PA[i]) * PB[i];
|
max@0
|
377 val2 += std::conj(PA[j]) * PB[j];
|
max@0
|
378 }
|
max@0
|
379
|
max@0
|
380 if(i < N)
|
max@0
|
381 {
|
max@0
|
382 val1 += std::conj(PA[i]) * PB[i];
|
max@0
|
383 }
|
max@0
|
384
|
max@0
|
385 return val1 + val2;
|
max@0
|
386 }
|
max@0
|
387
|
max@0
|
388
|
max@0
|
389
|
max@0
|
390 //! @}
|