Chris@49: // Copyright (C) 2009-2012 NICTA (www.nicta.com.au) Chris@49: // Copyright (C) 2009-2012 Conrad Sanderson Chris@49: // Copyright (C) 2009-2010 Dimitrios Bouzas Chris@49: // Chris@49: // This Source Code Form is subject to the terms of the Mozilla Public Chris@49: // License, v. 2.0. If a copy of the MPL was not distributed with this Chris@49: // file, You can obtain one at http://mozilla.org/MPL/2.0/. Chris@49: Chris@49: Chris@49: //! \addtogroup glue_cov Chris@49: //! @{ Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: void Chris@49: glue_cov::direct_cov(Mat& out, const Mat& A, const Mat& B, const uword norm_type) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: if(A.is_vec() && B.is_vec()) Chris@49: { Chris@49: arma_debug_check( (A.n_elem != B.n_elem), "cov(): the number of elements in A and B must match" ); Chris@49: Chris@49: const eT* A_ptr = A.memptr(); Chris@49: const eT* B_ptr = B.memptr(); Chris@49: Chris@49: eT A_acc = eT(0); Chris@49: eT B_acc = eT(0); Chris@49: eT out_acc = eT(0); Chris@49: Chris@49: const uword N = A.n_elem; Chris@49: Chris@49: for(uword i=0; i 1) ? eT(N-1) : eT(1) ) : eT(N); Chris@49: Chris@49: out.set_size(1,1); Chris@49: out[0] = out_acc/norm_val; Chris@49: } Chris@49: else Chris@49: { Chris@49: arma_debug_assert_mul_size(A, B, true, false, "cov()"); Chris@49: Chris@49: const uword N = A.n_rows; Chris@49: const eT norm_val = (norm_type == 0) ? ( (N > 1) ? eT(N-1) : eT(1) ) : eT(N); Chris@49: Chris@49: out = trans(A) * B; Chris@49: out -= (trans(sum(A)) * sum(B))/eT(N); Chris@49: out /= norm_val; Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: void Chris@49: glue_cov::direct_cov(Mat< std::complex >& out, const Mat< std::complex >& A, const Mat< std::complex >& B, const uword norm_type) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename std::complex eT; Chris@49: Chris@49: if(A.is_vec() && B.is_vec()) Chris@49: { Chris@49: arma_debug_check( (A.n_elem != B.n_elem), "cov(): the number of elements in A and B must match" ); Chris@49: Chris@49: const eT* A_ptr = A.memptr(); Chris@49: const eT* B_ptr = B.memptr(); Chris@49: Chris@49: eT A_acc = eT(0); Chris@49: eT B_acc = eT(0); Chris@49: eT out_acc = eT(0); Chris@49: Chris@49: const uword N = A.n_elem; Chris@49: Chris@49: for(uword i=0; i 1) ? eT(N-1) : eT(1) ) : eT(N); Chris@49: Chris@49: out.set_size(1,1); Chris@49: out[0] = out_acc/norm_val; Chris@49: } Chris@49: else Chris@49: { Chris@49: arma_debug_assert_mul_size(A, B, true, false, "cov()"); Chris@49: Chris@49: const uword N = A.n_rows; Chris@49: const eT norm_val = (norm_type == 0) ? ( (N > 1) ? eT(N-1) : eT(1) ) : eT(N); Chris@49: Chris@49: out = trans(A) * B; // out = strans(conj(A)) * B; Chris@49: out -= (trans(sum(A)) * sum(B))/eT(N); // out -= (strans(conj(sum(A))) * sum(B))/eT(N); Chris@49: out /= norm_val; Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: void Chris@49: glue_cov::apply(Mat& out, const Glue& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const unwrap_check A_tmp(X.A, out); Chris@49: const unwrap_check B_tmp(X.B, out); Chris@49: Chris@49: const Mat& A = A_tmp.M; Chris@49: const Mat& B = B_tmp.M; Chris@49: Chris@49: const uword norm_type = X.aux_uword; Chris@49: Chris@49: if(&A != &B) Chris@49: { Chris@49: glue_cov::direct_cov(out, A, B, norm_type); Chris@49: } Chris@49: else Chris@49: { Chris@49: op_cov::direct_cov(out, A, norm_type); Chris@49: } Chris@49: Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! @}