annotate armadillo-2.4.4/include/armadillo_bits/op_dot_meat.hpp @ 5:79b343f3e4b8

In this version the problem of letters assigned to each segment has been solved.
author maxzanoni76 <max.zanoni@eecs.qmul.ac.uk>
date Wed, 11 Apr 2012 13:48:13 +0100
parents 8b6102e2a9b0
children
rev   line source
max@0 1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
max@0 2 // Copyright (C) 2008-2011 Conrad Sanderson
max@0 3 //
max@0 4 // This file is part of the Armadillo C++ library.
max@0 5 // It is provided without any warranty of fitness
max@0 6 // for any purpose. You can redistribute this file
max@0 7 // and/or modify it under the terms of the GNU
max@0 8 // Lesser General Public License (LGPL) as published
max@0 9 // by the Free Software Foundation, either version 3
max@0 10 // of the License or (at your option) any later version.
max@0 11 // (see http://www.opensource.org/licenses for more info)
max@0 12
max@0 13
max@0 14 //! \addtogroup op_dot
max@0 15 //! @{
max@0 16
max@0 17
max@0 18
max@0 19 //! Computes the sum over k in [0, n_elem) of A[k] * B[k] using a plain loop.
max@0 20 //! for two arrays, generic version
max@0 21 template<typename eT>
max@0 22 arma_hot
max@0 23 arma_pure
max@0 24 inline
max@0 25 eT
max@0 26 op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B)
max@0 27 {
max@0 28 arma_extra_debug_sigprint();
max@0 29 // two separate accumulators: val1 collects products at even indices, val2 at odd indices
max@0 30 eT val1 = eT(0);
max@0 31 eT val2 = eT(0);
max@0 32
max@0 33 uword i, j;
max@0 34 // consume elements in pairs (i, j = i+1); the loop ends once j reaches n_elem
max@0 35 for(i=0, j=1; j<n_elem; i+=2, j+=2)
max@0 36 {
max@0 37 val1 += A[i] * B[i];
max@0 38 val2 += A[j] * B[j];
max@0 39 }
max@0 40 // for odd n_elem one element (index i) is left over after the paired loop
max@0 41 if(i < n_elem)
max@0 42 {
max@0 43 val1 += A[i] * B[i];
max@0 44 }
max@0 45 // yields eT(0) + eT(0) when n_elem is 0, since neither the loop nor the tail runs
max@0 46 return val1 + val2;
max@0 47 }
max@0 48
max@0 49
max@0 50 //! Chooses between the built-in loop and an external library routine, based on n_elem and build macros.
max@0 51 //! for two arrays, float and double version
max@0 52 template<typename eT>
max@0 53 arma_hot
max@0 54 arma_pure
max@0 55 inline
max@0 56 typename arma_float_only<eT>::result
max@0 57 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
max@0 58 {
max@0 59 arma_extra_debug_sigprint();
max@0 60 // inputs with at most 128/sizeof(eT) elements always use the built-in loop
max@0 61 if( n_elem <= (128/sizeof(eT)) )
max@0 62 {
max@0 63 return op_dot::direct_dot_arma(n_elem, A, B);
max@0 64 }
max@0 65 else
max@0 66 { // larger inputs: backend fixed at compile time, tried in the order ATLAS, BLAS, built-in loop
max@0 67 #if defined(ARMA_USE_ATLAS)
max@0 68 {
max@0 69 arma_extra_debug_print("atlas::cblas_dot()");
max@0 70
max@0 71 return atlas::cblas_dot(n_elem, A, B);
max@0 72 }
max@0 73 #elif defined(ARMA_USE_BLAS)
max@0 74 {
max@0 75 arma_extra_debug_print("blas::dot()");
max@0 76
max@0 77 return blas::dot(n_elem, A, B);
max@0 78 }
max@0 79 #else
max@0 80 { // neither ARMA_USE_ATLAS nor ARMA_USE_BLAS is defined
max@0 81 return op_dot::direct_dot_arma(n_elem, A, B);
max@0 82 }
max@0 83 #endif
max@0 84 }
max@0 85 }
max@0 86
max@0 87
max@0 88 //! Calls atlas::cx_cblas_dot() when ARMA_USE_ATLAS is defined; every other build uses direct_dot_arma().
max@0 89 //! for two arrays, complex version
max@0 90 template<typename eT>
max@0 91 inline
max@0 92 arma_hot
max@0 93 arma_pure
max@0 94 typename arma_cx_only<eT>::result
max@0 95 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
max@0 96 {
max@0 97 #if defined(ARMA_USE_ATLAS)
max@0 98 {
max@0 99 arma_extra_debug_print("atlas::cx_cblas_dot()");
max@0 100
max@0 101 return atlas::cx_cblas_dot(n_elem, A, B);
max@0 102 }
max@0 103 #elif defined(ARMA_USE_BLAS)
max@0 104 { // BLAS is deliberately not called for this overload; the built-in loop is used instead
max@0 105 // TODO: work out the mess with zdotu() and zdotu_sub() in BLAS
max@0 106 return op_dot::direct_dot_arma(n_elem, A, B);
max@0 107 }
max@0 108 #else
max@0 109 {
max@0 110 return op_dot::direct_dot_arma(n_elem, A, B);
max@0 111 }
max@0 112 #endif
max@0 113 }
max@0 114
max@0 115
max@0 116 //! Always forwards to direct_dot_arma(); no external library routine is used by this overload.
max@0 117 //! for two arrays, integral version
max@0 118 template<typename eT>
max@0 119 arma_hot
max@0 120 arma_pure
max@0 121 inline
max@0 122 typename arma_integral_only<eT>::result
max@0 123 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B)
max@0 124 {
max@0 125 return op_dot::direct_dot_arma(n_elem, A, B);
max@0 126 }
max@0 127
max@0 128
max@0 129
max@0 130 //! Computes the sum over i in [0, n_elem) of A[i] * B[i] * C[i].
max@0 131 //! for three arrays
max@0 132 template<typename eT>
max@0 133 arma_hot
max@0 134 arma_pure
max@0 135 inline
max@0 136 eT
max@0 137 op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C)
max@0 138 {
max@0 139 arma_extra_debug_sigprint();
max@0 140 // single accumulator; elements are visited in index order
max@0 141 eT val = eT(0);
max@0 142
max@0 143 for(uword i=0; i<n_elem; ++i)
max@0 144 {
max@0 145 val += A[i] * B[i] * C[i];
max@0 146 }
max@0 147
max@0 148 return val;
max@0 149 }
max@0 150
max@0 151 //! Entry point for the dot product of two operands passed as Base references.
max@0 152 //! Forwards to apply_unwrap() when is_Mat is true for both T1 and T2, otherwise to apply_proxy().
max@0 153 template<typename T1, typename T2>
max@0 154 arma_hot
max@0 155 arma_inline
max@0 156 typename T1::elem_type
max@0 157 op_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
max@0 158 {
max@0 159 arma_extra_debug_sigprint();
max@0 160 // the condition depends only on the template arguments, not on run-time data
max@0 161 if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
max@0 162 {
max@0 163 return op_dot::apply_unwrap(X,Y);
max@0 164 }
max@0 165 else
max@0 166 {
max@0 167 return op_dot::apply_proxy(X,Y);
max@0 168 }
max@0 169 }
max@0 170
max@0 171 //! Dot product via unwrap: obtains a Mat for each operand and calls direct_dot() on their memory.
max@0 172 //! The element counts are compared with arma_debug_check() before the product is computed.
max@0 173 template<typename T1, typename T2>
max@0 174 arma_hot
max@0 175 arma_inline
max@0 176 typename T1::elem_type
max@0 177 op_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
max@0 178 {
max@0 179 arma_extra_debug_sigprint();
max@0 180
max@0 181 typedef typename T1::elem_type eT;
max@0 182
max@0 183 const unwrap<T1> tmp1(X.get_ref());
max@0 184 const unwrap<T2> tmp2(Y.get_ref());
max@0 185 // A and B refer to the Mat members (.M) held by the unwrap objects above
max@0 186 const Mat<eT>& A = tmp1.M;
max@0 187 const Mat<eT>& B = tmp2.M;
max@0 188 // only the total number of elements is compared, not the row/column layout
max@0 189 arma_debug_check( (A.n_elem != B.n_elem), "dot(): objects must have the same number of elements" );
max@0 190
max@0 191 return op_dot::direct_dot(A.n_elem, A.mem, B.mem);
max@0 192 }
max@0 193
max@0 194 //! Dot product via Proxy: reads elements through get_ea() unless prefer_at_accessor is set for both Proxy types.
max@0 195 //! When it is set for both, the proxied operands (A.Q, B.Q) are handed to apply_unwrap() instead.
max@0 196 template<typename T1, typename T2>
max@0 197 arma_hot
max@0 198 inline
max@0 199 typename T1::elem_type
max@0 200 op_dot::apply_proxy(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
max@0 201 {
max@0 202 arma_extra_debug_sigprint();
max@0 203
max@0 204 typedef typename T1::elem_type eT;
max@0 205 typedef typename Proxy<T1>::ea_type ea_type1;
max@0 206 typedef typename Proxy<T2>::ea_type ea_type2;
max@0 207
max@0 208 const Proxy<T1> A(X.get_ref());
max@0 209 const Proxy<T2> B(Y.get_ref());
max@0 210 // true only when BOTH Proxy types report prefer_at_accessor
max@0 211 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor);
max@0 212
max@0 213 if(prefer_at_accessor == false)
max@0 214 {
max@0 215 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "dot(): objects must have the same number of elements" );
max@0 216
max@0 217 const uword N = A.get_n_elem();
max@0 218 ea_type1 PA = A.get_ea();
max@0 219 ea_type2 PB = B.get_ea();
max@0 220 // two accumulators: val1 collects products at even indices, val2 at odd indices
max@0 221 eT val1 = eT(0);
max@0 222 eT val2 = eT(0);
max@0 223
max@0 224 uword i,j;
max@0 225 // elements are consumed in pairs (i, j = i+1) until j reaches N
max@0 226 for(i=0, j=1; j<N; i+=2, j+=2)
max@0 227 {
max@0 228 val1 += PA[i] * PB[i];
max@0 229 val2 += PA[j] * PB[j];
max@0 230 }
max@0 231 // for odd N one element (index i) is left over after the paired loop
max@0 232 if(i < N)
max@0 233 {
max@0 234 val1 += PA[i] * PB[i];
max@0 235 }
max@0 236
max@0 237 return val1 + val2;
max@0 238 }
max@0 239 else
max@0 240 { // the size check for this path is performed inside apply_unwrap()
max@0 241 return op_dot::apply_unwrap(A.Q, B.Q);
max@0 242 }
max@0 243 }
max@0 244
max@0 245
max@0 246
max@0 247 //
max@0 248 // op_norm_dot
max@0 249
max@0 250 //! Normalised dot product: sum(a*b) / sqrt( sum(a*a) * sum(b*b) ), accumulated in a single pass.
max@0 251 //! Uses Proxy element access unless prefer_at_accessor is set for both Proxy types; then apply_unwrap() is used.
max@0 252 template<typename T1, typename T2>
max@0 253 arma_hot
max@0 254 inline
max@0 255 typename T1::elem_type
max@0 256 op_norm_dot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
max@0 257 {
max@0 258 arma_extra_debug_sigprint();
max@0 259
max@0 260 typedef typename T1::elem_type eT;
max@0 261 typedef typename Proxy<T1>::ea_type ea_type1;
max@0 262 typedef typename Proxy<T2>::ea_type ea_type2;
max@0 263 // true only when BOTH Proxy types report prefer_at_accessor
max@0 264 const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor);
max@0 265
max@0 266 if(prefer_at_accessor == false)
max@0 267 { // the Proxy objects are only constructed on this path
max@0 268 const Proxy<T1> A(X.get_ref());
max@0 269 const Proxy<T2> B(Y.get_ref());
max@0 270
max@0 271 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "norm_dot(): objects must have the same number of elements" );
max@0 272
max@0 273 const uword N = A.get_n_elem();
max@0 274 ea_type1 PA = A.get_ea();
max@0 275 ea_type2 PB = B.get_ea();
max@0 276 // acc1 = sum of a*a, acc2 = sum of b*b, acc3 = sum of a*b
max@0 277 eT acc1 = eT(0);
max@0 278 eT acc2 = eT(0);
max@0 279 eT acc3 = eT(0);
max@0 280
max@0 281 for(uword i=0; i<N; ++i)
max@0 282 {
max@0 283 const eT tmpA = PA[i];
max@0 284 const eT tmpB = PB[i];
max@0 285
max@0 286 acc1 += tmpA * tmpA;
max@0 287 acc2 += tmpB * tmpB;
max@0 288 acc3 += tmpA * tmpB;
max@0 289 }
max@0 290 // NOTE(review): no guard is visible here for acc1 * acc2 == 0 (e.g. N == 0 or an all-zero operand)
max@0 291 return acc3 / ( std::sqrt(acc1 * acc2) );
max@0 292 }
max@0 293 else
max@0 294 {
max@0 295 return op_norm_dot::apply_unwrap(X, Y);
max@0 296 }
max@0 297 }
max@0 298
max@0 299 //! Normalised dot product on unwrapped operands, reading elements through the Mat memory pointers.
max@0 300 //! Returns sum(a*b) / sqrt( sum(a*a) * sum(b*b) ) after checking that the element counts match.
max@0 301 template<typename T1, typename T2>
max@0 302 arma_hot
max@0 303 inline
max@0 304 typename T1::elem_type
max@0 305 op_norm_dot::apply_unwrap(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
max@0 306 {
max@0 307 arma_extra_debug_sigprint();
max@0 308
max@0 309 typedef typename T1::elem_type eT;
max@0 310
max@0 311 const unwrap<T1> tmp1(X.get_ref());
max@0 312 const unwrap<T2> tmp2(Y.get_ref());
max@0 313 // A and B refer to the Mat members (.M) held by the unwrap objects above
max@0 314 const Mat<eT>& A = tmp1.M;
max@0 315 const Mat<eT>& B = tmp2.M;
max@0 316
max@0 317 // only the total number of elements is compared, not the row/column layout
max@0 318 arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
max@0 319
max@0 320 const uword N = A.n_elem;
max@0 321
max@0 322 const eT* A_mem = A.memptr();
max@0 323 const eT* B_mem = B.memptr();
max@0 324 // acc1 = sum of a*a, acc2 = sum of b*b, acc3 = sum of a*b
max@0 325 eT acc1 = eT(0);
max@0 326 eT acc2 = eT(0);
max@0 327 eT acc3 = eT(0);
max@0 328
max@0 329 for(uword i=0; i<N; ++i)
max@0 330 {
max@0 331 const eT tmpA = A_mem[i];
max@0 332 const eT tmpB = B_mem[i];
max@0 333
max@0 334 acc1 += tmpA * tmpA;
max@0 335 acc2 += tmpB * tmpB;
max@0 336 acc3 += tmpA * tmpB;
max@0 337 }
max@0 338 // NOTE(review): no guard is visible here for acc1 * acc2 == 0 (e.g. N == 0 or an all-zero operand)
max@0 339 return acc3 / ( std::sqrt(acc1 * acc2) );
max@0 340 }
max@0 341
max@0 342
max@0 343
max@0 344 //
max@0 345 // op_cdot
max@0 346
max@0 347 //! Conjugating dot product: sum over i of std::conj(PA[i]) * PB[i], with elements read via Proxy.
max@0 348 //! Only the elements of the first operand (X) are conjugated.
max@0 349 template<typename T1, typename T2>
max@0 350 arma_hot
max@0 351 arma_inline
max@0 352 typename T1::elem_type
max@0 353 op_cdot::apply(const Base<typename T1::elem_type,T1>& X, const Base<typename T1::elem_type,T2>& Y)
max@0 354 {
max@0 355 arma_extra_debug_sigprint();
max@0 356
max@0 357 typedef typename T1::elem_type eT;
max@0 358 typedef typename Proxy<T1>::ea_type ea_type1;
max@0 359 typedef typename Proxy<T2>::ea_type ea_type2;
max@0 360
max@0 361 const Proxy<T1> A(X.get_ref());
max@0 362 const Proxy<T2> B(Y.get_ref());
max@0 363
max@0 364 arma_debug_check( (A.get_n_elem() != B.get_n_elem()), "cdot(): objects must have the same number of elements" );
max@0 365
max@0 366 const uword N = A.get_n_elem();
max@0 367 ea_type1 PA = A.get_ea();
max@0 368 ea_type2 PB = B.get_ea();
max@0 369 // two accumulators: val1 collects products at even indices, val2 at odd indices
max@0 370 eT val1 = eT(0);
max@0 371 eT val2 = eT(0);
max@0 372
max@0 373 uword i,j;
max@0 374 for(i=0, j=1; j<N; i+=2, j+=2)
max@0 375 {
max@0 376 val1 += std::conj(PA[i]) * PB[i];
max@0 377 val2 += std::conj(PA[j]) * PB[j];
max@0 378 }
max@0 379 // for odd N one element (index i) is left over after the paired loop
max@0 380 if(i < N)
max@0 381 {
max@0 382 val1 += std::conj(PA[i]) * PB[i];
max@0 383 }
max@0 384
max@0 385 return val1 + val2;
max@0 386 }
max@0 387
max@0 388
max@0 389
max@0 390 //! @}