Chris@49
|
1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2012 Conrad Sanderson
|
Chris@49
|
3 // Copyright (C) 2012 Ryan Curtin
|
Chris@49
|
4 //
|
Chris@49
|
5 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
8
|
Chris@49
|
9
|
Chris@49
|
10 //! \addtogroup fn_trace
|
Chris@49
|
11 //! @{
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 //! Immediate trace (sum of diagonal elements) of a square dense matrix
|
Chris@49
|
15 template<typename T1>
|
Chris@49
|
16 arma_hot
|
Chris@49
|
17 arma_warn_unused
|
Chris@49
|
18 inline
|
Chris@49
|
19 typename enable_if2<is_arma_type<T1>::value, typename T1::elem_type>::result
|
Chris@49
|
20 trace(const T1& X)
|
Chris@49
|
21 {
|
Chris@49
|
22 arma_extra_debug_sigprint();
|
Chris@49
|
23
|
Chris@49
|
24 typedef typename T1::elem_type eT;
|
Chris@49
|
25
|
Chris@49
|
26 const Proxy<T1> A(X);
|
Chris@49
|
27
|
Chris@49
|
28 arma_debug_check( (A.get_n_rows() != A.get_n_cols()), "trace(): matrix must be square sized" );
|
Chris@49
|
29
|
Chris@49
|
30 const uword N = A.get_n_rows();
|
Chris@49
|
31
|
Chris@49
|
32 eT val1 = eT(0);
|
Chris@49
|
33 eT val2 = eT(0);
|
Chris@49
|
34
|
Chris@49
|
35 uword i,j;
|
Chris@49
|
36 for(i=0, j=1; j<N; i+=2, j+=2)
|
Chris@49
|
37 {
|
Chris@49
|
38 val1 += A.at(i,i);
|
Chris@49
|
39 val2 += A.at(j,j);
|
Chris@49
|
40 }
|
Chris@49
|
41
|
Chris@49
|
42 if(i < N)
|
Chris@49
|
43 {
|
Chris@49
|
44 val1 += A.at(i,i);
|
Chris@49
|
45 }
|
Chris@49
|
46
|
Chris@49
|
47 return val1 + val2;
|
Chris@49
|
48 }
|
Chris@49
|
49
|
Chris@49
|
50
|
Chris@49
|
51
|
Chris@49
|
52 template<typename T1>
|
Chris@49
|
53 arma_hot
|
Chris@49
|
54 arma_warn_unused
|
Chris@49
|
55 inline
|
Chris@49
|
56 typename T1::elem_type
|
Chris@49
|
57 trace(const Op<T1, op_diagmat>& X)
|
Chris@49
|
58 {
|
Chris@49
|
59 arma_extra_debug_sigprint();
|
Chris@49
|
60
|
Chris@49
|
61 typedef typename T1::elem_type eT;
|
Chris@49
|
62
|
Chris@49
|
63 const diagmat_proxy<T1> A(X.m);
|
Chris@49
|
64
|
Chris@49
|
65 const uword N = A.n_elem;
|
Chris@49
|
66
|
Chris@49
|
67 eT val = eT(0);
|
Chris@49
|
68
|
Chris@49
|
69 for(uword i=0; i<N; ++i)
|
Chris@49
|
70 {
|
Chris@49
|
71 val += A[i];
|
Chris@49
|
72 }
|
Chris@49
|
73
|
Chris@49
|
74 return val;
|
Chris@49
|
75 }
|
Chris@49
|
76
|
Chris@49
|
77
|
Chris@49
|
78
|
Chris@49
|
79 //! speedup for trace(A*B), where the result of A*B is a square sized matrix
|
Chris@49
|
80 template<typename T1, typename T2>
|
Chris@49
|
81 arma_hot
|
Chris@49
|
82 inline
|
Chris@49
|
83 typename T1::elem_type
|
Chris@49
|
84 trace_mul_unwrap(const T1& XA, const T2& XB)
|
Chris@49
|
85 {
|
Chris@49
|
86 arma_extra_debug_sigprint();
|
Chris@49
|
87
|
Chris@49
|
88 typedef typename T1::elem_type eT;
|
Chris@49
|
89
|
Chris@49
|
90 const Proxy<T1> PA(XA);
|
Chris@49
|
91 const unwrap<T2> tmpB(XB);
|
Chris@49
|
92
|
Chris@49
|
93 const Mat<eT>& B = tmpB.M;
|
Chris@49
|
94
|
Chris@49
|
95 arma_debug_assert_mul_size(PA.get_n_rows(), PA.get_n_cols(), B.n_rows, B.n_cols, "matrix multiplication");
|
Chris@49
|
96
|
Chris@49
|
97 arma_debug_check( (PA.get_n_rows() != B.n_cols), "trace(): matrix must be square sized" );
|
Chris@49
|
98
|
Chris@49
|
99 const uword N1 = PA.get_n_rows(); // equivalent to B.n_cols, due to square size requirements
|
Chris@49
|
100 const uword N2 = PA.get_n_cols(); // equivalent to B.n_rows, due to matrix multiplication requirements
|
Chris@49
|
101
|
Chris@49
|
102 eT val = eT(0);
|
Chris@49
|
103
|
Chris@49
|
104 for(uword i=0; i<N1; ++i)
|
Chris@49
|
105 {
|
Chris@49
|
106 const eT* B_colmem = B.colptr(i);
|
Chris@49
|
107
|
Chris@49
|
108 eT acc1 = eT(0);
|
Chris@49
|
109 eT acc2 = eT(0);
|
Chris@49
|
110
|
Chris@49
|
111 uword j,k;
|
Chris@49
|
112 for(j=0, k=1; k < N2; j+=2, k+=2)
|
Chris@49
|
113 {
|
Chris@49
|
114 const eT tmp_j = B_colmem[j];
|
Chris@49
|
115 const eT tmp_k = B_colmem[k];
|
Chris@49
|
116
|
Chris@49
|
117 acc1 += PA.at(i,j) * tmp_j;
|
Chris@49
|
118 acc2 += PA.at(i,k) * tmp_k;
|
Chris@49
|
119 }
|
Chris@49
|
120
|
Chris@49
|
121 if(j < N2)
|
Chris@49
|
122 {
|
Chris@49
|
123 acc1 += PA.at(i,j) * B_colmem[j];
|
Chris@49
|
124 }
|
Chris@49
|
125
|
Chris@49
|
126 val += (acc1 + acc2);
|
Chris@49
|
127 }
|
Chris@49
|
128
|
Chris@49
|
129 return val;
|
Chris@49
|
130 }
|
Chris@49
|
131
|
Chris@49
|
132
|
Chris@49
|
133
|
Chris@49
|
134 //! speedup for trace(A*B), where the result of A*B is a square sized matrix
|
Chris@49
|
135 template<typename T1, typename T2>
|
Chris@49
|
136 arma_hot
|
Chris@49
|
137 inline
|
Chris@49
|
138 typename T1::elem_type
|
Chris@49
|
139 trace_mul_proxy(const T1& XA, const T2& XB)
|
Chris@49
|
140 {
|
Chris@49
|
141 arma_extra_debug_sigprint();
|
Chris@49
|
142
|
Chris@49
|
143 typedef typename T1::elem_type eT;
|
Chris@49
|
144
|
Chris@49
|
145 const Proxy<T1> PA(XA);
|
Chris@49
|
146 const Proxy<T2> PB(XB);
|
Chris@49
|
147
|
Chris@49
|
148 if(is_Mat<typename Proxy<T2>::stored_type>::value == true)
|
Chris@49
|
149 {
|
Chris@49
|
150 return trace_mul_unwrap(PA.Q, PB.Q);
|
Chris@49
|
151 }
|
Chris@49
|
152
|
Chris@49
|
153 arma_debug_assert_mul_size(PA.get_n_rows(), PA.get_n_cols(), PB.get_n_rows(), PB.get_n_cols(), "matrix multiplication");
|
Chris@49
|
154
|
Chris@49
|
155 arma_debug_check( (PA.get_n_rows() != PB.get_n_cols()), "trace(): matrix must be square sized" );
|
Chris@49
|
156
|
Chris@49
|
157 const uword N1 = PA.get_n_rows(); // equivalent to PB.get_n_cols(), due to square size requirements
|
Chris@49
|
158 const uword N2 = PA.get_n_cols(); // equivalent to PB.get_n_rows(), due to matrix multiplication requirements
|
Chris@49
|
159
|
Chris@49
|
160 eT val = eT(0);
|
Chris@49
|
161
|
Chris@49
|
162 for(uword i=0; i<N1; ++i)
|
Chris@49
|
163 {
|
Chris@49
|
164 eT acc1 = eT(0);
|
Chris@49
|
165 eT acc2 = eT(0);
|
Chris@49
|
166
|
Chris@49
|
167 uword j,k;
|
Chris@49
|
168 for(j=0, k=1; k < N2; j+=2, k+=2)
|
Chris@49
|
169 {
|
Chris@49
|
170 const eT tmp_j = PB.at(j,i);
|
Chris@49
|
171 const eT tmp_k = PB.at(k,i);
|
Chris@49
|
172
|
Chris@49
|
173 acc1 += PA.at(i,j) * tmp_j;
|
Chris@49
|
174 acc2 += PA.at(i,k) * tmp_k;
|
Chris@49
|
175 }
|
Chris@49
|
176
|
Chris@49
|
177 if(j < N2)
|
Chris@49
|
178 {
|
Chris@49
|
179 acc1 += PA.at(i,j) * PB.at(j,i);
|
Chris@49
|
180 }
|
Chris@49
|
181
|
Chris@49
|
182 val += (acc1 + acc2);
|
Chris@49
|
183 }
|
Chris@49
|
184
|
Chris@49
|
185 return val;
|
Chris@49
|
186 }
|
Chris@49
|
187
|
Chris@49
|
188
|
Chris@49
|
189
|
Chris@49
|
190 //! speedup for trace(A*B), where the result of A*B is a square sized matrix
|
Chris@49
|
191 template<typename T1, typename T2>
|
Chris@49
|
192 arma_hot
|
Chris@49
|
193 arma_warn_unused
|
Chris@49
|
194 inline
|
Chris@49
|
195 typename T1::elem_type
|
Chris@49
|
196 trace(const Glue<T1, T2, glue_times>& X)
|
Chris@49
|
197 {
|
Chris@49
|
198 arma_extra_debug_sigprint();
|
Chris@49
|
199
|
Chris@49
|
200 return (is_Mat<T2>::value) ? trace_mul_unwrap(X.A, X.B) : trace_mul_proxy(X.A, X.B);
|
Chris@49
|
201 }
|
Chris@49
|
202
|
Chris@49
|
203
|
Chris@49
|
204
|
Chris@49
|
205 //! trace of sparse object
|
Chris@49
|
206 template<typename T1>
|
Chris@49
|
207 arma_hot
|
Chris@49
|
208 arma_warn_unused
|
Chris@49
|
209 inline
|
Chris@49
|
210 typename enable_if2<is_arma_sparse_type<T1>::value, typename T1::elem_type>::result
|
Chris@49
|
211 trace(const T1& x)
|
Chris@49
|
212 {
|
Chris@49
|
213 arma_extra_debug_sigprint();
|
Chris@49
|
214
|
Chris@49
|
215 const SpProxy<T1> p(x);
|
Chris@49
|
216
|
Chris@49
|
217 arma_debug_check( (p.get_n_rows() != p.get_n_cols()), "trace(): matrix must be square sized" );
|
Chris@49
|
218
|
Chris@49
|
219 typedef typename T1::elem_type eT;
|
Chris@49
|
220
|
Chris@49
|
221 eT result = eT(0);
|
Chris@49
|
222
|
Chris@49
|
223 typename SpProxy<T1>::const_iterator_type it = p.begin();
|
Chris@49
|
224 typename SpProxy<T1>::const_iterator_type it_end = p.end();
|
Chris@49
|
225
|
Chris@49
|
226 while(it != it_end)
|
Chris@49
|
227 {
|
Chris@49
|
228 if(it.row() == it.col())
|
Chris@49
|
229 {
|
Chris@49
|
230 result += (*it);
|
Chris@49
|
231 }
|
Chris@49
|
232
|
Chris@49
|
233 ++it;
|
Chris@49
|
234 }
|
Chris@49
|
235
|
Chris@49
|
236 return result;
|
Chris@49
|
237 }
|
Chris@49
|
238
|
Chris@49
|
239
|
Chris@49
|
240
|
Chris@49
|
241 //! @}
|