Chris@49: // Copyright (C) 2008-2013 NICTA (www.nicta.com.au) Chris@49: // Copyright (C) 2008-2013 Conrad Sanderson Chris@49: // Chris@49: // This Source Code Form is subject to the terms of the Mozilla Public Chris@49: // License, v. 2.0. If a copy of the MPL was not distributed with this Chris@49: // file, You can obtain one at http://mozilla.org/MPL/2.0/. Chris@49: Chris@49: Chris@49: //! \addtogroup op_dot Chris@49: //! @{ Chris@49: Chris@49: Chris@49: Chris@49: //! for two arrays, generic version for non-complex values Chris@49: template Chris@49: arma_hot Chris@49: arma_pure Chris@49: arma_inline Chris@49: typename arma_not_cx::result Chris@49: op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: eT val1 = eT(0); Chris@49: eT val2 = eT(0); Chris@49: Chris@49: uword i, j; Chris@49: Chris@49: for(i=0, j=1; j Chris@49: arma_hot Chris@49: arma_pure Chris@49: inline Chris@49: typename arma_cx_only::result Chris@49: op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename get_pod_type::result T; Chris@49: Chris@49: T val_real = T(0); Chris@49: T val_imag = T(0); Chris@49: Chris@49: for(uword i=0; i& X = A[i]; Chris@49: const std::complex& Y = B[i]; Chris@49: Chris@49: const T a = X.real(); Chris@49: const T b = X.imag(); Chris@49: Chris@49: const T c = Y.real(); Chris@49: const T d = Y.imag(); Chris@49: Chris@49: val_real += (a*c) - (b*d); Chris@49: val_imag += (a*d) + (b*c); Chris@49: } Chris@49: Chris@49: return std::complex(val_real, val_imag); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! for two arrays, float and double version Chris@49: template Chris@49: arma_hot Chris@49: arma_pure Chris@49: inline Chris@49: typename arma_real_only::result Chris@49: op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: if( n_elem <= 32u ) Chris@49: { Chris@49: return op_dot::direct_dot_arma(n_elem, A, B); Chris@49: } Chris@49: else Chris@49: { Chris@49: #if defined(ARMA_USE_ATLAS) Chris@49: { Chris@49: arma_extra_debug_print("atlas::cblas_dot()"); Chris@49: Chris@49: return atlas::cblas_dot(n_elem, A, B); Chris@49: } Chris@49: #elif defined(ARMA_USE_BLAS) Chris@49: { Chris@49: arma_extra_debug_print("blas::dot()"); Chris@49: Chris@49: return blas::dot(n_elem, A, B); Chris@49: } Chris@49: #else Chris@49: { Chris@49: return op_dot::direct_dot_arma(n_elem, A, B); Chris@49: } Chris@49: #endif Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! for two arrays, complex version Chris@49: template Chris@49: inline Chris@49: arma_hot Chris@49: arma_pure Chris@49: typename arma_cx_only::result Chris@49: op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) Chris@49: { Chris@49: if( n_elem <= 16u ) Chris@49: { Chris@49: return op_dot::direct_dot_arma(n_elem, A, B); Chris@49: } Chris@49: else Chris@49: { Chris@49: #if defined(ARMA_USE_ATLAS) Chris@49: { Chris@49: arma_extra_debug_print("atlas::cx_cblas_dot()"); Chris@49: Chris@49: return atlas::cx_cblas_dot(n_elem, A, B); Chris@49: } Chris@49: #elif defined(ARMA_USE_BLAS) Chris@49: { Chris@49: arma_extra_debug_print("blas::dot()"); Chris@49: Chris@49: return blas::dot(n_elem, A, B); Chris@49: } Chris@49: #else Chris@49: { Chris@49: return op_dot::direct_dot_arma(n_elem, A, B); Chris@49: } Chris@49: #endif Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! for two arrays, integral version Chris@49: template Chris@49: arma_hot Chris@49: arma_pure Chris@49: inline Chris@49: typename arma_integral_only::result Chris@49: op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) Chris@49: { Chris@49: return op_dot::direct_dot_arma(n_elem, A, B); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: Chris@49: //! for three arrays Chris@49: template Chris@49: arma_hot Chris@49: arma_pure Chris@49: inline Chris@49: eT Chris@49: op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: eT val = eT(0); Chris@49: Chris@49: for(uword i=0; i Chris@49: arma_hot Chris@49: inline Chris@49: typename T1::elem_type Chris@49: op_dot::apply(const T1& X, const T2& Y) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: const bool prefer_at_accessor = (Proxy::prefer_at_accessor) || (Proxy::prefer_at_accessor); Chris@49: Chris@49: const bool do_unwrap = ((is_Mat::value == true) && (is_Mat::value == true)) || prefer_at_accessor; Chris@49: Chris@49: if(do_unwrap == true) Chris@49: { Chris@49: const unwrap tmp1(X); Chris@49: const unwrap tmp2(Y); Chris@49: Chris@49: const typename unwrap::stored_type& A = tmp1.M; Chris@49: const typename unwrap::stored_type& B = tmp2.M; Chris@49: Chris@49: arma_debug_check Chris@49: ( Chris@49: (A.n_elem != B.n_elem), Chris@49: "dot(): objects must have the same number of elements" Chris@49: ); Chris@49: Chris@49: return op_dot::direct_dot(A.n_elem, A.memptr(), B.memptr()); Chris@49: } Chris@49: else Chris@49: { Chris@49: const Proxy PA(X); Chris@49: const Proxy PB(Y); Chris@49: Chris@49: arma_debug_check( (PA.get_n_elem() != PB.get_n_elem()), "dot(): objects must have the same number of elements" ); Chris@49: Chris@49: return op_dot::apply_proxy(PA,PB); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: typename arma_not_cx::result Chris@49: op_dot::apply_proxy(const Proxy& PA, const Proxy& PB) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: typedef typename Proxy::ea_type ea_type1; Chris@49: typedef typename Proxy::ea_type ea_type2; Chris@49: Chris@49: const uword N = PA.get_n_elem(); Chris@49: Chris@49: ea_type1 A = PA.get_ea(); Chris@49: ea_type2 B = PB.get_ea(); Chris@49: Chris@49: eT val1 = eT(0); Chris@49: eT val2 = eT(0); Chris@49: Chris@49: uword i,j; Chris@49: Chris@49: for(i=0, j=1; j Chris@49: arma_hot Chris@49: inline Chris@49: typename arma_cx_only::result Chris@49: op_dot::apply_proxy(const Proxy& PA, const Proxy& PB) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: typedef typename get_pod_type::result T; Chris@49: Chris@49: typedef typename Proxy::ea_type ea_type1; Chris@49: typedef typename Proxy::ea_type ea_type2; Chris@49: Chris@49: const uword N = PA.get_n_elem(); Chris@49: Chris@49: ea_type1 A = PA.get_ea(); Chris@49: ea_type2 B = PB.get_ea(); Chris@49: Chris@49: T val_real = T(0); Chris@49: T val_imag = T(0); Chris@49: Chris@49: for(uword i=0; i xx = A[i]; Chris@49: const std::complex yy = B[i]; Chris@49: Chris@49: const T a = xx.real(); Chris@49: const T b = xx.imag(); Chris@49: Chris@49: const T c = yy.real(); Chris@49: const T d = yy.imag(); Chris@49: Chris@49: val_real += (a*c) - (b*d); Chris@49: val_imag += (a*d) + (b*c); Chris@49: } Chris@49: Chris@49: return std::complex(val_real, val_imag); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: eT Chris@49: op_dot::dot_and_copy_row(eT* out, const TA& A, const uword row, const eT* B_mem, const uword N) Chris@49: { Chris@49: eT acc1 = eT(0); Chris@49: eT acc2 = eT(0); Chris@49: Chris@49: uword i,j; Chris@49: for(i=0, j=1; j < N; i+=2, j+=2) Chris@49: { Chris@49: const eT val_i = A.at(row, i); Chris@49: const eT val_j = A.at(row, j); Chris@49: Chris@49: out[i] = val_i; Chris@49: out[j] = val_j; Chris@49: Chris@49: acc1 += val_i * B_mem[i]; Chris@49: acc2 += val_j * B_mem[j]; Chris@49: } Chris@49: Chris@49: if(i < N) Chris@49: { Chris@49: const eT val_i = A.at(row, i); Chris@49: Chris@49: out[i] = val_i; Chris@49: Chris@49: acc1 += val_i * B_mem[i]; Chris@49: } Chris@49: Chris@49: return acc1 + acc2; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: // Chris@49: // op_norm_dot Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: typename T1::elem_type Chris@49: op_norm_dot::apply(const T1& X, const T2& Y) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: typedef typename Proxy::ea_type ea_type1; Chris@49: typedef typename Proxy::ea_type ea_type2; Chris@49: Chris@49: const bool prefer_at_accessor = (Proxy::prefer_at_accessor) && (Proxy::prefer_at_accessor); Chris@49: Chris@49: if(prefer_at_accessor == false) Chris@49: { Chris@49: const Proxy PA(X); Chris@49: const Proxy PB(Y); Chris@49: Chris@49: const uword N = PA.get_n_elem(); Chris@49: Chris@49: arma_debug_check( (N != PB.get_n_elem()), "norm_dot(): objects must have the same number of elements" ); Chris@49: Chris@49: ea_type1 A = PA.get_ea(); Chris@49: ea_type2 B = PB.get_ea(); Chris@49: Chris@49: eT acc1 = eT(0); Chris@49: eT acc2 = eT(0); Chris@49: eT acc3 = eT(0); Chris@49: Chris@49: for(uword i=0; i Chris@49: arma_hot Chris@49: inline Chris@49: typename T1::elem_type Chris@49: op_norm_dot::apply_unwrap(const T1& X, const T2& Y) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const unwrap tmp1(X); Chris@49: const unwrap tmp2(Y); Chris@49: Chris@49: const Mat& A = tmp1.M; Chris@49: const Mat& B = tmp2.M; Chris@49: Chris@49: Chris@49: arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" ); Chris@49: Chris@49: const uword N = A.n_elem; Chris@49: Chris@49: const eT* A_mem = A.memptr(); Chris@49: const eT* B_mem = B.memptr(); Chris@49: Chris@49: eT acc1 = eT(0); Chris@49: eT acc2 = eT(0); Chris@49: eT acc3 = eT(0); Chris@49: Chris@49: for(uword i=0; i Chris@49: arma_hot Chris@49: arma_pure Chris@49: inline Chris@49: eT Chris@49: op_cdot::direct_cdot_arma(const uword n_elem, const eT* const A, const eT* const B) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename get_pod_type::result T; Chris@49: Chris@49: T val_real = T(0); Chris@49: T val_imag = T(0); Chris@49: Chris@49: for(uword i=0; i& X = A[i]; Chris@49: const std::complex& Y = B[i]; Chris@49: Chris@49: const T a = X.real(); Chris@49: const T b = X.imag(); Chris@49: Chris@49: const T c = Y.real(); Chris@49: const T d = Y.imag(); Chris@49: Chris@49: val_real += (a*c) + (b*d); Chris@49: val_imag += (a*d) - (b*c); Chris@49: } Chris@49: Chris@49: return std::complex(val_real, val_imag); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: arma_pure Chris@49: inline Chris@49: eT Chris@49: op_cdot::direct_cdot(const uword n_elem, const eT* const A, const eT* const B) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: if( n_elem <= 32u ) Chris@49: { Chris@49: return op_cdot::direct_cdot_arma(n_elem, A, B); Chris@49: } Chris@49: else Chris@49: { Chris@49: #if defined(ARMA_USE_BLAS) Chris@49: { Chris@49: arma_extra_debug_print("blas::gemv()"); Chris@49: Chris@49: // using gemv() workaround due to compatibility issues with cdotc() and zdotc() Chris@49: Chris@49: const char trans = 'C'; Chris@49: Chris@49: const blas_int m = blas_int(n_elem); Chris@49: const blas_int n = 1; Chris@49: //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1); Chris@49: const blas_int inc = 1; Chris@49: Chris@49: const eT alpha = eT(1); Chris@49: const eT beta = eT(0); Chris@49: Chris@49: eT result[2]; // paranoia: using two elements instead of one Chris@49: Chris@49: //blas::gemv(&trans, &m, &n, &alpha, A, &lda, B, &inc, &beta, &result[0], &inc); Chris@49: blas::gemv(&trans, &m, &n, &alpha, A, &m, B, &inc, &beta, &result[0], &inc); Chris@49: Chris@49: return result[0]; Chris@49: } Chris@49: #elif defined(ARMA_USE_ATLAS) Chris@49: { Chris@49: // TODO: use dedicated atlas functions cblas_cdotc_sub() and cblas_zdotc_sub() and retune threshold Chris@49: Chris@49: return op_cdot::direct_cdot_arma(n_elem, A, B); Chris@49: } Chris@49: #else Chris@49: { Chris@49: return op_cdot::direct_cdot_arma(n_elem, A, B); Chris@49: } Chris@49: #endif Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: typename T1::elem_type Chris@49: op_cdot::apply(const T1& X, const T2& Y) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: if( (is_Mat::value == true) && (is_Mat::value == true) ) Chris@49: { Chris@49: return op_cdot::apply_unwrap(X,Y); Chris@49: } Chris@49: else Chris@49: { Chris@49: return op_cdot::apply_proxy(X,Y); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: typename T1::elem_type Chris@49: op_cdot::apply_unwrap(const T1& X, const T2& Y) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const unwrap tmp1(X); Chris@49: const unwrap tmp2(Y); Chris@49: Chris@49: const Mat& A = tmp1.M; Chris@49: const Mat& B = tmp2.M; Chris@49: Chris@49: arma_debug_check( (A.n_elem != B.n_elem), "cdot(): objects must have the same number of elements" ); Chris@49: Chris@49: return op_cdot::direct_cdot( A.n_elem, A.mem, B.mem ); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: typename T1::elem_type Chris@49: op_cdot::apply_proxy(const T1& X, const T2& Y) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: typedef typename get_pod_type::result T; Chris@49: Chris@49: typedef typename Proxy::ea_type ea_type1; Chris@49: typedef typename Proxy::ea_type ea_type2; Chris@49: Chris@49: const bool prefer_at_accessor = (Proxy::prefer_at_accessor) || (Proxy::prefer_at_accessor); Chris@49: Chris@49: if(prefer_at_accessor == false) Chris@49: { Chris@49: const Proxy PA(X); Chris@49: const Proxy PB(Y); Chris@49: Chris@49: const uword N = PA.get_n_elem(); Chris@49: Chris@49: arma_debug_check( (N != PB.get_n_elem()), "cdot(): objects must have the same number of elements" ); Chris@49: Chris@49: ea_type1 A = PA.get_ea(); Chris@49: ea_type2 B = PB.get_ea(); Chris@49: Chris@49: T val_real = T(0); Chris@49: T val_imag = T(0); Chris@49: Chris@49: for(uword i=0; i AA = A[i]; Chris@49: const std::complex BB = B[i]; Chris@49: Chris@49: const T a = AA.real(); Chris@49: const T b = AA.imag(); Chris@49: Chris@49: const T c = BB.real(); Chris@49: const T d = BB.imag(); Chris@49: Chris@49: val_real += (a*c) + (b*d); Chris@49: val_imag += (a*d) - (b*c); Chris@49: } Chris@49: Chris@49: return std::complex(val_real, val_imag); Chris@49: } Chris@49: else Chris@49: { Chris@49: return op_cdot::apply_unwrap( X, Y ); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! @}