Mercurial > hg > segmenter-vamp-plugin
view armadillo-3.900.4/include/armadillo_bits/op_dot_meat.hpp @ 84:55a047986812 tip
Update library URI so as not to be document-local
author | Chris Cannam |
---|---|
date | Wed, 22 Apr 2020 14:21:57 +0100 |
parents | 1ec0e2823891 |
children |
line wrap: on
line source
// Copyright (C) 2008-2013 NICTA (www.nicta.com.au) // Copyright (C) 2008-2013 Conrad Sanderson // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. //! \addtogroup op_dot //! @{ //! for two arrays, generic version for non-complex values template<typename eT> arma_hot arma_pure arma_inline typename arma_not_cx<eT>::result op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) { arma_extra_debug_sigprint(); eT val1 = eT(0); eT val2 = eT(0); uword i, j; for(i=0, j=1; j<n_elem; i+=2, j+=2) { val1 += A[i] * B[i]; val2 += A[j] * B[j]; } if(i < n_elem) { val1 += A[i] * B[i]; } return val1 + val2; } //! for two arrays, generic version for complex values template<typename eT> arma_hot arma_pure inline typename arma_cx_only<eT>::result op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) { arma_extra_debug_sigprint(); typedef typename get_pod_type<eT>::result T; T val_real = T(0); T val_imag = T(0); for(uword i=0; i<n_elem; ++i) { const std::complex<T>& X = A[i]; const std::complex<T>& Y = B[i]; const T a = X.real(); const T b = X.imag(); const T c = Y.real(); const T d = Y.imag(); val_real += (a*c) - (b*d); val_imag += (a*d) + (b*c); } return std::complex<T>(val_real, val_imag); } //! 
for two arrays, float and double version template<typename eT> arma_hot arma_pure inline typename arma_real_only<eT>::result op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) { arma_extra_debug_sigprint(); if( n_elem <= 32u ) { return op_dot::direct_dot_arma(n_elem, A, B); } else { #if defined(ARMA_USE_ATLAS) { arma_extra_debug_print("atlas::cblas_dot()"); return atlas::cblas_dot(n_elem, A, B); } #elif defined(ARMA_USE_BLAS) { arma_extra_debug_print("blas::dot()"); return blas::dot(n_elem, A, B); } #else { return op_dot::direct_dot_arma(n_elem, A, B); } #endif } } //! for two arrays, complex version template<typename eT> inline arma_hot arma_pure typename arma_cx_only<eT>::result op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) { if( n_elem <= 16u ) { return op_dot::direct_dot_arma(n_elem, A, B); } else { #if defined(ARMA_USE_ATLAS) { arma_extra_debug_print("atlas::cx_cblas_dot()"); return atlas::cx_cblas_dot(n_elem, A, B); } #elif defined(ARMA_USE_BLAS) { arma_extra_debug_print("blas::dot()"); return blas::dot(n_elem, A, B); } #else { return op_dot::direct_dot_arma(n_elem, A, B); } #endif } } //! for two arrays, integral version template<typename eT> arma_hot arma_pure inline typename arma_integral_only<eT>::result op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) { return op_dot::direct_dot_arma(n_elem, A, B); } //! 
//! Sum over i of A[i]*B[i]*C[i] for three raw arrays of n_elem elements each.
template<typename eT>
arma_hot
arma_pure
inline
eT
op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C)
  {
  arma_extra_debug_sigprint();
  
  eT val = eT(0);
  
  for(uword i=0; i<n_elem; ++i)
    {
    val += A[i] * B[i] * C[i];
    }
  
  return val;
  }



//! Dot product of two Armadillo objects.
//! Both operands are unwrapped to contiguous memory and passed to direct_dot() when
//! both are Mat, or when either proxy prefers the at() accessor; otherwise elements
//! are read through Proxy (apply_proxy()). A size mismatch is reported via arma_debug_check().
template<typename T1, typename T2>
arma_hot
inline
typename T1::elem_type
op_dot::apply(const T1& X, const T2& Y)
  {
  arma_extra_debug_sigprint();
  
  const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) || (Proxy<T2>::prefer_at_accessor);
  
  const bool do_unwrap = ((is_Mat<T1>::value == true) && (is_Mat<T2>::value == true)) || prefer_at_accessor;
  
  if(do_unwrap == true)
    {
    const unwrap<T1> tmp1(X);
    const unwrap<T2> tmp2(Y);
    
    const typename unwrap<T1>::stored_type& A = tmp1.M;
    const typename unwrap<T2>::stored_type& B = tmp2.M;
    
    arma_debug_check
      (
      (A.n_elem != B.n_elem),
      "dot(): objects must have the same number of elements"
      );
    
    return op_dot::direct_dot(A.n_elem, A.memptr(), B.memptr());
    }
  else
    {
    const Proxy<T1> PA(X);
    const Proxy<T2> PB(Y);
    
    arma_debug_check( (PA.get_n_elem() != PB.get_n_elem()), "dot(): objects must have the same number of elements" );
    
    return op_dot::apply_proxy(PA,PB);
    }
  }



//! Dot product of two non-complex objects read through Proxy linear element access (get_ea()).
//! Uses the same even/odd two-accumulator scheme as direct_dot_arma().
//! Does not check sizes itself; op_dot::apply() checks before calling.
template<typename T1, typename T2>
arma_hot
inline
typename arma_not_cx<typename T1::elem_type>::result
op_dot::apply_proxy(const Proxy<T1>& PA, const Proxy<T2>& PB)
  {
  arma_extra_debug_sigprint();
  
  typedef typename T1::elem_type      eT;
  typedef typename Proxy<T1>::ea_type ea_type1;
  typedef typename Proxy<T2>::ea_type ea_type2;
  
  const uword N = PA.get_n_elem();
  
  ea_type1 A = PA.get_ea();
  ea_type2 B = PB.get_ea();
  
  eT val1 = eT(0);
  eT val2 = eT(0);
  
  uword i,j;
  
  for(i=0, j=1; j<N; i+=2, j+=2)
    {
    val1 += A[i] * B[i];
    val2 += A[j] * B[j];
    }
  
  if(i < N)
    {
    val1 += A[i] * B[i];
    }
  
  return val1 + val2;
  }



//! Dot product (no conjugation) of two complex objects read through Proxy linear element access.
//! Real and imaginary parts are accumulated separately.
//! Does not check sizes itself; op_dot::apply() checks before calling.
template<typename T1, typename T2>
arma_hot
inline
typename arma_cx_only<typename T1::elem_type>::result
op_dot::apply_proxy(const Proxy<T1>& PA, const Proxy<T2>& PB)
  {
  arma_extra_debug_sigprint();
  
  typedef typename T1::elem_type            eT;
  typedef typename get_pod_type<eT>::result T;
  
  typedef typename Proxy<T1>::ea_type ea_type1;
  typedef typename Proxy<T2>::ea_type ea_type2;
  
  const uword N = PA.get_n_elem();
  
  ea_type1 A = PA.get_ea();
  ea_type2 B = PB.get_ea();
  
  T val_real = T(0);
  T val_imag = T(0);
  
  for(uword i=0; i<N; ++i)
    {
    const std::complex<T> xx = A[i];
    const std::complex<T> yy = B[i];
    
    const T a = xx.real();
    const T b = xx.imag();
    
    const T c = yy.real();
    const T d = yy.imag();
    
    val_real += (a*c) - (b*d);
    val_imag += (a*d) + (b*c);
    }
  
  return std::complex<T>(val_real, val_imag);
  }



//! Copies A.at(row, 0..N-1) into out[0..N-1] and, in the same pass, returns the dot
//! product of that row segment with B_mem[0..N-1].
//! Uses the same even/odd two-accumulator scheme as direct_dot_arma().
template<typename eT, typename TA>
arma_hot
inline
eT
op_dot::dot_and_copy_row(eT* out, const TA& A, const uword row, const eT* B_mem, const uword N)
  {
  eT acc1 = eT(0);
  eT acc2 = eT(0);
  
  uword i,j;
  for(i=0, j=1; j < N; i+=2, j+=2)
    {
    const eT val_i = A.at(row, i);
    const eT val_j = A.at(row, j);
    
    out[i] = val_i;
    out[j] = val_j;
    
    acc1 += val_i * B_mem[i];
    acc2 += val_j * B_mem[j];
    }
  
  if(i < N)
    {
    const eT val_i = A.at(row, i);
    
    out[i] = val_i;
    
    acc1 += val_i * B_mem[i];
    }
  
  return acc1 + acc2;
  }



//
// op_norm_dot


//! Normalised dot product of two objects, accumulated in a single pass:
//!   sum(X.*Y) / sqrt( sum(X.*X) * sum(Y.*Y) )
//! If sum(X.*X) * sum(Y.*Y) is zero (an all-zero operand, or underflow of the product),
//! returns 0 rather than dividing by zero.
//! NOTE(review): for complex eT, tmpA*tmpA is the square rather than |tmpA|^2, so the
//! denominator is not the product of Euclidean norms in that case; confirm intended.
template<typename T1, typename T2>
arma_hot
inline
typename T1::elem_type
op_norm_dot::apply(const T1& X, const T2& Y)
  {
  arma_extra_debug_sigprint();
  
  typedef typename T1::elem_type      eT;
  typedef typename Proxy<T1>::ea_type ea_type1;
  typedef typename Proxy<T2>::ea_type ea_type2;
  
  // the Proxy path is used unless *both* operands prefer the at() accessor
  const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) && (Proxy<T2>::prefer_at_accessor);
  
  if(prefer_at_accessor == false)
    {
    const Proxy<T1> PA(X);
    const Proxy<T2> PB(Y);
    
    const uword N  = PA.get_n_elem();
    
    arma_debug_check( (N != PB.get_n_elem()), "norm_dot(): objects must have the same number of elements" );
    
    ea_type1 A = PA.get_ea();
    ea_type2 B = PB.get_ea();
    
    eT acc1 = eT(0);
    eT acc2 = eT(0);
    eT acc3 = eT(0);
    
    for(uword i=0; i<N; ++i)
      {
      const eT tmpA = A[i];
      const eT tmpB = B[i];
      
      acc1 += tmpA * tmpA;
      acc2 += tmpB * tmpB;
      acc3 += tmpA * tmpB;
      }
    
    // Guard the degenerate case: dividing by sqrt(0) yields NaN/inf, and for integral eT
    // converting that NaN/inf back to eT on return is undefined behaviour.
    if( (acc1 * acc2) == eT(0) )
      {
      return eT(0);
      }
    
    return acc3 / ( std::sqrt(acc1 * acc2) );
    }
  else
    {
    return op_norm_dot::apply_unwrap(X, Y);
    }
  }
//! Normalised dot product computed after unwrapping both operands to Mat storage:
//!   sum(A.*B) / sqrt( sum(A.*A) * sum(B.*B) )
//! If sum(A.*A) * sum(B.*B) is zero (an all-zero operand, or underflow of the product),
//! returns 0 rather than dividing by zero.
//! NOTE(review): for complex eT, tmpA*tmpA is the square rather than |tmpA|^2; confirm intended.
template<typename T1, typename T2>
arma_hot
inline
typename T1::elem_type
op_norm_dot::apply_unwrap(const T1& X, const T2& Y)
  {
  arma_extra_debug_sigprint();
  
  typedef typename T1::elem_type eT;
  
  const unwrap<T1> tmp1(X);
  const unwrap<T2> tmp2(Y);
  
  const Mat<eT>& A = tmp1.M;
  const Mat<eT>& B = tmp2.M;
  
  
  arma_debug_check( (A.n_elem != B.n_elem), "norm_dot(): objects must have the same number of elements" );
  
  const uword N = A.n_elem;
  
  const eT* A_mem = A.memptr();
  const eT* B_mem = B.memptr();
  
  eT acc1 = eT(0);
  eT acc2 = eT(0);
  eT acc3 = eT(0);
  
  for(uword i=0; i<N; ++i)
    {
    const eT tmpA = A_mem[i];
    const eT tmpB = B_mem[i];
    
    acc1 += tmpA * tmpA;
    acc2 += tmpB * tmpB;
    acc3 += tmpA * tmpB;
    }
  
  // Guard the degenerate case: dividing by sqrt(0) yields NaN/inf, and for integral eT
  // converting that NaN/inf back to eT on return is undefined behaviour.
  if( (acc1 * acc2) == eT(0) )
    {
    return eT(0);
    }
  
  return acc3 / ( std::sqrt(acc1 * acc2) );
  }



//
// op_cdot



//! Conjugated dot product of two raw complex arrays, using Armadillo's own loop:
//! sum over i of conj(A[i]) * B[i], with real and imaginary parts accumulated separately.
template<typename eT>
arma_hot
arma_pure
inline
eT
op_cdot::direct_cdot_arma(const uword n_elem, const eT* const A, const eT* const B)
  {
  arma_extra_debug_sigprint();
  
  typedef typename get_pod_type<eT>::result T;
  
  T val_real = T(0);
  T val_imag = T(0);
  
  for(uword i=0; i<n_elem; ++i)
    {
    const std::complex<T>& X = A[i];
    const std::complex<T>& Y = B[i];
    
    const T a = X.real();
    const T b = X.imag();
    
    const T c = Y.real();
    const T d = Y.imag();
    
    // (a - i*b) * (c + i*d)
    val_real += (a*c) + (b*d);
    val_imag += (a*d) - (b*c);
    }
  
  return std::complex<T>(val_real, val_imag);
  }



//! Conjugated dot product of two raw complex arrays.
//! Arrays with at most 32 elements use direct_cdot_arma(). Longer arrays use BLAS gemv()
//! when ARMA_USE_BLAS is defined; otherwise (including ATLAS builds) direct_cdot_arma().
template<typename eT>
arma_hot
arma_pure
inline
eT
op_cdot::direct_cdot(const uword n_elem, const eT* const A, const eT* const B)
  {
  arma_extra_debug_sigprint();
  
  if( n_elem <= 32u )
    {
    return op_cdot::direct_cdot_arma(n_elem, A, B);
    }
  else
    {
    #if defined(ARMA_USE_BLAS)
      {
      arma_extra_debug_print("blas::gemv()");
      
      // using gemv() workaround due to compatibility issues with cdotc() and zdotc()
      // A is passed as an n_elem x 1 column matrix; per the BLAS ?gemv definition, trans = 'C'
      // computes result := alpha * A^H * B + beta * result, i.e. sum over i of conj(A[i]) * B[i].
      
      const char     trans = 'C';
      const blas_int m     = blas_int(n_elem);
      const blas_int n     = 1;
      const blas_int inc   = 1;
      
      const eT alpha = eT(1);
      const eT beta  = eT(0);
      
      // paranoia: using two elements instead of one; zero-initialised so the buffer
      // never holds indeterminate values even if an implementation reads y when beta == 0
      eT result[2] = { eT(0), eT(0) };
      
      // the leading dimension is m (= n_elem, which is > 32 in this branch, so always >= 1)
      blas::gemv(&trans, &m, &n, &alpha, A, &m, B, &inc, &beta, &result[0], &inc);
      
      return result[0];
      }
    #elif defined(ARMA_USE_ATLAS)
      {
      // TODO: use dedicated atlas functions cblas_cdotc_sub() and cblas_zdotc_sub() and retune threshold
      return op_cdot::direct_cdot_arma(n_elem, A, B);
      }
    #else
      {
      return op_cdot::direct_cdot_arma(n_elem, A, B);
      }
    #endif
    }
  }



//! Conjugated dot product of two Armadillo objects.
//! Uses apply_unwrap() when both operands are Mat, apply_proxy() otherwise.
template<typename T1, typename T2>
arma_hot
inline
typename T1::elem_type
op_cdot::apply(const T1& X, const T2& Y)
  {
  arma_extra_debug_sigprint();
  
  if( (is_Mat<T1>::value == true) && (is_Mat<T2>::value == true) )
    {
    return op_cdot::apply_unwrap(X,Y);
    }
  else
    {
    return op_cdot::apply_proxy(X,Y);
    }
  }



//! Conjugated dot product after unwrapping both operands to Mat storage.
//! A size mismatch is reported via arma_debug_check().
template<typename T1, typename T2>
arma_hot
inline
typename T1::elem_type
op_cdot::apply_unwrap(const T1& X, const T2& Y)
  {
  arma_extra_debug_sigprint();
  
  typedef typename T1::elem_type eT;
  
  const unwrap<T1> tmp1(X);
  const unwrap<T2> tmp2(Y);
  
  const Mat<eT>& A = tmp1.M;
  const Mat<eT>& B = tmp2.M;
  
  arma_debug_check( (A.n_elem != B.n_elem), "cdot(): objects must have the same number of elements" );
  
  // memptr() for consistency with op_dot::apply()
  return op_cdot::direct_cdot( A.n_elem, A.memptr(), B.memptr() );
  }



//! Conjugated dot product with elements read through Proxy linear access (get_ea()).
//! Falls back to apply_unwrap() when either proxy prefers the at() accessor.
//! A size mismatch is reported via arma_debug_check().
template<typename T1, typename T2>
arma_hot
inline
typename T1::elem_type
op_cdot::apply_proxy(const T1& X, const T2& Y)
  {
  arma_extra_debug_sigprint();
  
  typedef typename T1::elem_type            eT;
  typedef typename get_pod_type<eT>::result T;
  
  typedef typename Proxy<T1>::ea_type ea_type1;
  typedef typename Proxy<T2>::ea_type ea_type2;
  
  const bool prefer_at_accessor = (Proxy<T1>::prefer_at_accessor) || (Proxy<T2>::prefer_at_accessor);
  
  if(prefer_at_accessor == false)
    {
    const Proxy<T1> PA(X);
    const Proxy<T2> PB(Y);
    
    const uword N = PA.get_n_elem();
    
    arma_debug_check( (N != PB.get_n_elem()), "cdot(): objects must have the same number of elements" );
    
    ea_type1 A = PA.get_ea();
    ea_type2 B = PB.get_ea();
    
    T val_real = T(0);
    T val_imag = T(0);
    
    for(uword i=0; i<N; ++i)
      {
      const std::complex<T> AA = A[i];
      const std::complex<T> BB = B[i];
      
      const T a = AA.real();
      const T b = AA.imag();
      
      const T c = BB.real();
      const T d = BB.imag();
      
      // (a - i*b) * (c + i*d)
      val_real += (a*c) + (b*d);
      val_imag += (a*d) - (b*c);
      }
    
    return std::complex<T>(val_real, val_imag);
    }
  else
    {
    return op_cdot::apply_unwrap( X, Y );
    }
  }



//! @}