annotate armadillo-3.900.4/include/armadillo_bits/fn_trace.hpp @ 84:55a047986812 tip

Update library URI so as not to be document-local
author Chris Cannam
date Wed, 22 Apr 2020 14:21:57 +0100
parents 1ec0e2823891
children
rev   line source
Chris@49 1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
Chris@49 2 // Copyright (C) 2008-2012 Conrad Sanderson
Chris@49 3 // Copyright (C) 2012 Ryan Curtin
Chris@49 4 //
Chris@49 5 // This Source Code Form is subject to the terms of the Mozilla Public
Chris@49 6 // License, v. 2.0. If a copy of the MPL was not distributed with this
Chris@49 7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
Chris@49 8
Chris@49 9
Chris@49 10 //! \addtogroup fn_trace
Chris@49 11 //! @{
Chris@49 12
Chris@49 13
Chris@49 14 //! Immediate trace (sum of diagonal elements) of a square dense matrix
Chris@49 15 template<typename T1>
Chris@49 16 arma_hot
Chris@49 17 arma_warn_unused
Chris@49 18 inline
Chris@49 19 typename enable_if2<is_arma_type<T1>::value, typename T1::elem_type>::result
Chris@49 20 trace(const T1& X)
Chris@49 21 {
Chris@49 22 arma_extra_debug_sigprint();
Chris@49 23
Chris@49 24 typedef typename T1::elem_type eT;
Chris@49 25
Chris@49 26 const Proxy<T1> A(X);
Chris@49 27
Chris@49 28 arma_debug_check( (A.get_n_rows() != A.get_n_cols()), "trace(): matrix must be square sized" );
Chris@49 29
Chris@49 30 const uword N = A.get_n_rows();
Chris@49 31
Chris@49 32 eT val1 = eT(0);
Chris@49 33 eT val2 = eT(0);
Chris@49 34
Chris@49 35 uword i,j;
Chris@49 36 for(i=0, j=1; j<N; i+=2, j+=2)
Chris@49 37 {
Chris@49 38 val1 += A.at(i,i);
Chris@49 39 val2 += A.at(j,j);
Chris@49 40 }
Chris@49 41
Chris@49 42 if(i < N)
Chris@49 43 {
Chris@49 44 val1 += A.at(i,i);
Chris@49 45 }
Chris@49 46
Chris@49 47 return val1 + val2;
Chris@49 48 }
Chris@49 49
Chris@49 50
Chris@49 51
Chris@49 52 template<typename T1>
Chris@49 53 arma_hot
Chris@49 54 arma_warn_unused
Chris@49 55 inline
Chris@49 56 typename T1::elem_type
Chris@49 57 trace(const Op<T1, op_diagmat>& X)
Chris@49 58 {
Chris@49 59 arma_extra_debug_sigprint();
Chris@49 60
Chris@49 61 typedef typename T1::elem_type eT;
Chris@49 62
Chris@49 63 const diagmat_proxy<T1> A(X.m);
Chris@49 64
Chris@49 65 const uword N = A.n_elem;
Chris@49 66
Chris@49 67 eT val = eT(0);
Chris@49 68
Chris@49 69 for(uword i=0; i<N; ++i)
Chris@49 70 {
Chris@49 71 val += A[i];
Chris@49 72 }
Chris@49 73
Chris@49 74 return val;
Chris@49 75 }
Chris@49 76
Chris@49 77
Chris@49 78
Chris@49 79 //! speedup for trace(A*B), where the result of A*B is a square sized matrix
Chris@49 80 template<typename T1, typename T2>
Chris@49 81 arma_hot
Chris@49 82 inline
Chris@49 83 typename T1::elem_type
Chris@49 84 trace_mul_unwrap(const T1& XA, const T2& XB)
Chris@49 85 {
Chris@49 86 arma_extra_debug_sigprint();
Chris@49 87
Chris@49 88 typedef typename T1::elem_type eT;
Chris@49 89
Chris@49 90 const Proxy<T1> PA(XA);
Chris@49 91 const unwrap<T2> tmpB(XB);
Chris@49 92
Chris@49 93 const Mat<eT>& B = tmpB.M;
Chris@49 94
Chris@49 95 arma_debug_assert_mul_size(PA.get_n_rows(), PA.get_n_cols(), B.n_rows, B.n_cols, "matrix multiplication");
Chris@49 96
Chris@49 97 arma_debug_check( (PA.get_n_rows() != B.n_cols), "trace(): matrix must be square sized" );
Chris@49 98
Chris@49 99 const uword N1 = PA.get_n_rows(); // equivalent to B.n_cols, due to square size requirements
Chris@49 100 const uword N2 = PA.get_n_cols(); // equivalent to B.n_rows, due to matrix multiplication requirements
Chris@49 101
Chris@49 102 eT val = eT(0);
Chris@49 103
Chris@49 104 for(uword i=0; i<N1; ++i)
Chris@49 105 {
Chris@49 106 const eT* B_colmem = B.colptr(i);
Chris@49 107
Chris@49 108 eT acc1 = eT(0);
Chris@49 109 eT acc2 = eT(0);
Chris@49 110
Chris@49 111 uword j,k;
Chris@49 112 for(j=0, k=1; k < N2; j+=2, k+=2)
Chris@49 113 {
Chris@49 114 const eT tmp_j = B_colmem[j];
Chris@49 115 const eT tmp_k = B_colmem[k];
Chris@49 116
Chris@49 117 acc1 += PA.at(i,j) * tmp_j;
Chris@49 118 acc2 += PA.at(i,k) * tmp_k;
Chris@49 119 }
Chris@49 120
Chris@49 121 if(j < N2)
Chris@49 122 {
Chris@49 123 acc1 += PA.at(i,j) * B_colmem[j];
Chris@49 124 }
Chris@49 125
Chris@49 126 val += (acc1 + acc2);
Chris@49 127 }
Chris@49 128
Chris@49 129 return val;
Chris@49 130 }
Chris@49 131
Chris@49 132
Chris@49 133
Chris@49 134 //! speedup for trace(A*B), where the result of A*B is a square sized matrix
Chris@49 135 template<typename T1, typename T2>
Chris@49 136 arma_hot
Chris@49 137 inline
Chris@49 138 typename T1::elem_type
Chris@49 139 trace_mul_proxy(const T1& XA, const T2& XB)
Chris@49 140 {
Chris@49 141 arma_extra_debug_sigprint();
Chris@49 142
Chris@49 143 typedef typename T1::elem_type eT;
Chris@49 144
Chris@49 145 const Proxy<T1> PA(XA);
Chris@49 146 const Proxy<T2> PB(XB);
Chris@49 147
Chris@49 148 if(is_Mat<typename Proxy<T2>::stored_type>::value == true)
Chris@49 149 {
Chris@49 150 return trace_mul_unwrap(PA.Q, PB.Q);
Chris@49 151 }
Chris@49 152
Chris@49 153 arma_debug_assert_mul_size(PA.get_n_rows(), PA.get_n_cols(), PB.get_n_rows(), PB.get_n_cols(), "matrix multiplication");
Chris@49 154
Chris@49 155 arma_debug_check( (PA.get_n_rows() != PB.get_n_cols()), "trace(): matrix must be square sized" );
Chris@49 156
Chris@49 157 const uword N1 = PA.get_n_rows(); // equivalent to PB.get_n_cols(), due to square size requirements
Chris@49 158 const uword N2 = PA.get_n_cols(); // equivalent to PB.get_n_rows(), due to matrix multiplication requirements
Chris@49 159
Chris@49 160 eT val = eT(0);
Chris@49 161
Chris@49 162 for(uword i=0; i<N1; ++i)
Chris@49 163 {
Chris@49 164 eT acc1 = eT(0);
Chris@49 165 eT acc2 = eT(0);
Chris@49 166
Chris@49 167 uword j,k;
Chris@49 168 for(j=0, k=1; k < N2; j+=2, k+=2)
Chris@49 169 {
Chris@49 170 const eT tmp_j = PB.at(j,i);
Chris@49 171 const eT tmp_k = PB.at(k,i);
Chris@49 172
Chris@49 173 acc1 += PA.at(i,j) * tmp_j;
Chris@49 174 acc2 += PA.at(i,k) * tmp_k;
Chris@49 175 }
Chris@49 176
Chris@49 177 if(j < N2)
Chris@49 178 {
Chris@49 179 acc1 += PA.at(i,j) * PB.at(j,i);
Chris@49 180 }
Chris@49 181
Chris@49 182 val += (acc1 + acc2);
Chris@49 183 }
Chris@49 184
Chris@49 185 return val;
Chris@49 186 }
Chris@49 187
Chris@49 188
Chris@49 189
Chris@49 190 //! speedup for trace(A*B), where the result of A*B is a square sized matrix
Chris@49 191 template<typename T1, typename T2>
Chris@49 192 arma_hot
Chris@49 193 arma_warn_unused
Chris@49 194 inline
Chris@49 195 typename T1::elem_type
Chris@49 196 trace(const Glue<T1, T2, glue_times>& X)
Chris@49 197 {
Chris@49 198 arma_extra_debug_sigprint();
Chris@49 199
Chris@49 200 return (is_Mat<T2>::value) ? trace_mul_unwrap(X.A, X.B) : trace_mul_proxy(X.A, X.B);
Chris@49 201 }
Chris@49 202
Chris@49 203
Chris@49 204
Chris@49 205 //! trace of sparse object
Chris@49 206 template<typename T1>
Chris@49 207 arma_hot
Chris@49 208 arma_warn_unused
Chris@49 209 inline
Chris@49 210 typename enable_if2<is_arma_sparse_type<T1>::value, typename T1::elem_type>::result
Chris@49 211 trace(const T1& x)
Chris@49 212 {
Chris@49 213 arma_extra_debug_sigprint();
Chris@49 214
Chris@49 215 const SpProxy<T1> p(x);
Chris@49 216
Chris@49 217 arma_debug_check( (p.get_n_rows() != p.get_n_cols()), "trace(): matrix must be square sized" );
Chris@49 218
Chris@49 219 typedef typename T1::elem_type eT;
Chris@49 220
Chris@49 221 eT result = eT(0);
Chris@49 222
Chris@49 223 typename SpProxy<T1>::const_iterator_type it = p.begin();
Chris@49 224 typename SpProxy<T1>::const_iterator_type it_end = p.end();
Chris@49 225
Chris@49 226 while(it != it_end)
Chris@49 227 {
Chris@49 228 if(it.row() == it.col())
Chris@49 229 {
Chris@49 230 result += (*it);
Chris@49 231 }
Chris@49 232
Chris@49 233 ++it;
Chris@49 234 }
Chris@49 235
Chris@49 236 return result;
Chris@49 237 }
Chris@49 238
Chris@49 239
Chris@49 240
Chris@49 241 //! @}