Chris@49
|
1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2012 Conrad Sanderson
|
Chris@49
|
3 // Copyright (C) 2012 Ryan Curtin
|
Chris@49
|
4 //
|
Chris@49
|
5 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
8
|
Chris@49
|
9
|
Chris@49
|
10 //! \addtogroup fn_accu
|
Chris@49
|
11 //! @{
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14
|
Chris@49
|
15 template<typename T1>
|
Chris@49
|
16 arma_hot
|
Chris@49
|
17 inline
|
Chris@49
|
18 typename T1::elem_type
|
Chris@49
|
19 accu_proxy_linear(const Proxy<T1>& P)
|
Chris@49
|
20 {
|
Chris@49
|
21 typedef typename T1::elem_type eT;
|
Chris@49
|
22 typedef typename Proxy<T1>::ea_type ea_type;
|
Chris@49
|
23
|
Chris@49
|
24 ea_type A = P.get_ea();
|
Chris@49
|
25 const uword n_elem = P.get_n_elem();
|
Chris@49
|
26
|
Chris@49
|
27 eT val1 = eT(0);
|
Chris@49
|
28 eT val2 = eT(0);
|
Chris@49
|
29
|
Chris@49
|
30 uword i,j;
|
Chris@49
|
31 for(i=0, j=1; j < n_elem; i+=2, j+=2)
|
Chris@49
|
32 {
|
Chris@49
|
33 val1 += A[i];
|
Chris@49
|
34 val2 += A[j];
|
Chris@49
|
35 }
|
Chris@49
|
36
|
Chris@49
|
37 if(i < n_elem)
|
Chris@49
|
38 {
|
Chris@49
|
39 val1 += A[i]; // equivalent to: val1 += A[n_elem-1];
|
Chris@49
|
40 }
|
Chris@49
|
41
|
Chris@49
|
42 return (val1 + val2);
|
Chris@49
|
43 }
|
Chris@49
|
44
|
Chris@49
|
45
|
Chris@49
|
46
|
Chris@49
|
47 template<typename T1>
|
Chris@49
|
48 arma_hot
|
Chris@49
|
49 inline
|
Chris@49
|
50 typename T1::elem_type
|
Chris@49
|
51 accu_proxy_at(const Proxy<T1>& P)
|
Chris@49
|
52 {
|
Chris@49
|
53 typedef typename T1::elem_type eT;
|
Chris@49
|
54
|
Chris@49
|
55 const uword n_rows = P.get_n_rows();
|
Chris@49
|
56 const uword n_cols = P.get_n_cols();
|
Chris@49
|
57
|
Chris@49
|
58 eT val = eT(0);
|
Chris@49
|
59
|
Chris@49
|
60 if(n_rows != 1)
|
Chris@49
|
61 {
|
Chris@49
|
62 for(uword col=0; col < n_cols; ++col)
|
Chris@49
|
63 for(uword row=0; row < n_rows; ++row)
|
Chris@49
|
64 {
|
Chris@49
|
65 val += P.at(row,col);
|
Chris@49
|
66 }
|
Chris@49
|
67 }
|
Chris@49
|
68 else
|
Chris@49
|
69 {
|
Chris@49
|
70 for(uword col=0; col < n_cols; ++col)
|
Chris@49
|
71 {
|
Chris@49
|
72 val += P.at(0,col);
|
Chris@49
|
73 }
|
Chris@49
|
74 }
|
Chris@49
|
75
|
Chris@49
|
76 return val;
|
Chris@49
|
77 }
|
Chris@49
|
78
|
Chris@49
|
79
|
Chris@49
|
80
|
Chris@49
|
81 //! accumulate the elements of a matrix
|
Chris@49
|
82 template<typename T1>
|
Chris@49
|
83 arma_hot
|
Chris@49
|
84 inline
|
Chris@49
|
85 typename enable_if2< is_arma_type<T1>::value, typename T1::elem_type >::result
|
Chris@49
|
86 accu(const T1& X)
|
Chris@49
|
87 {
|
Chris@49
|
88 arma_extra_debug_sigprint();
|
Chris@49
|
89
|
Chris@49
|
90 const Proxy<T1> P(X);
|
Chris@49
|
91
|
Chris@49
|
92 return (Proxy<T1>::prefer_at_accessor == false) ? accu_proxy_linear(P) : accu_proxy_at(P);
|
Chris@49
|
93 }
|
Chris@49
|
94
|
Chris@49
|
95
|
Chris@49
|
96
|
Chris@49
|
97 //! explicit handling of Hamming norm (also known as zero norm)
|
Chris@49
|
98 template<typename T1>
|
Chris@49
|
99 inline
|
Chris@49
|
100 arma_warn_unused
|
Chris@49
|
101 uword
|
Chris@49
|
102 accu(const mtOp<uword,T1,op_rel_noteq>& X)
|
Chris@49
|
103 {
|
Chris@49
|
104 arma_extra_debug_sigprint();
|
Chris@49
|
105
|
Chris@49
|
106 typedef typename T1::elem_type eT;
|
Chris@49
|
107
|
Chris@49
|
108 const eT val = X.aux;
|
Chris@49
|
109
|
Chris@49
|
110 const Proxy<T1> P(X.m);
|
Chris@49
|
111
|
Chris@49
|
112 uword n_nonzero = 0;
|
Chris@49
|
113
|
Chris@49
|
114 if(Proxy<T1>::prefer_at_accessor == false)
|
Chris@49
|
115 {
|
Chris@49
|
116 typedef typename Proxy<T1>::ea_type ea_type;
|
Chris@49
|
117
|
Chris@49
|
118 ea_type A = P.get_ea();
|
Chris@49
|
119 const uword n_elem = P.get_n_elem();
|
Chris@49
|
120
|
Chris@49
|
121 for(uword i=0; i<n_elem; ++i)
|
Chris@49
|
122 {
|
Chris@49
|
123 if(A[i] != val) { ++n_nonzero; }
|
Chris@49
|
124 }
|
Chris@49
|
125 }
|
Chris@49
|
126 else
|
Chris@49
|
127 {
|
Chris@49
|
128 const uword P_n_cols = P.get_n_cols();
|
Chris@49
|
129 const uword P_n_rows = P.get_n_rows();
|
Chris@49
|
130
|
Chris@49
|
131 if(P_n_rows == 1)
|
Chris@49
|
132 {
|
Chris@49
|
133 for(uword col=0; col < P_n_cols; ++col)
|
Chris@49
|
134 {
|
Chris@49
|
135 if(P.at(0,col) != val) { ++n_nonzero; }
|
Chris@49
|
136 }
|
Chris@49
|
137 }
|
Chris@49
|
138 else
|
Chris@49
|
139 {
|
Chris@49
|
140 for(uword col=0; col < P_n_cols; ++col)
|
Chris@49
|
141 for(uword row=0; row < P_n_rows; ++row)
|
Chris@49
|
142 {
|
Chris@49
|
143 if(P.at(row,col) != val) { ++n_nonzero; }
|
Chris@49
|
144 }
|
Chris@49
|
145 }
|
Chris@49
|
146 }
|
Chris@49
|
147
|
Chris@49
|
148 return n_nonzero;
|
Chris@49
|
149 }
|
Chris@49
|
150
|
Chris@49
|
151
|
Chris@49
|
152
|
Chris@49
|
153 //! accumulate the elements of a subview (submatrix)
|
Chris@49
|
154 template<typename eT>
|
Chris@49
|
155 arma_hot
|
Chris@49
|
156 arma_pure
|
Chris@49
|
157 arma_warn_unused
|
Chris@49
|
158 inline
|
Chris@49
|
159 eT
|
Chris@49
|
160 accu(const subview<eT>& X)
|
Chris@49
|
161 {
|
Chris@49
|
162 arma_extra_debug_sigprint();
|
Chris@49
|
163
|
Chris@49
|
164 const uword X_n_rows = X.n_rows;
|
Chris@49
|
165 const uword X_n_cols = X.n_cols;
|
Chris@49
|
166
|
Chris@49
|
167 eT val = eT(0);
|
Chris@49
|
168
|
Chris@49
|
169 if(X_n_rows == 1)
|
Chris@49
|
170 {
|
Chris@49
|
171 const Mat<eT>& A = X.m;
|
Chris@49
|
172
|
Chris@49
|
173 const uword start_row = X.aux_row1;
|
Chris@49
|
174 const uword start_col = X.aux_col1;
|
Chris@49
|
175
|
Chris@49
|
176 const uword end_col_p1 = start_col + X_n_cols;
|
Chris@49
|
177
|
Chris@49
|
178 uword i,j;
|
Chris@49
|
179 for(i=start_col, j=start_col+1; j < end_col_p1; i+=2, j+=2)
|
Chris@49
|
180 {
|
Chris@49
|
181 val += A.at(start_row, i);
|
Chris@49
|
182 val += A.at(start_row, j);
|
Chris@49
|
183 }
|
Chris@49
|
184
|
Chris@49
|
185 if(i < end_col_p1)
|
Chris@49
|
186 {
|
Chris@49
|
187 val += A.at(start_row, i);
|
Chris@49
|
188 }
|
Chris@49
|
189 }
|
Chris@49
|
190 else
|
Chris@49
|
191 if(X_n_cols == 1)
|
Chris@49
|
192 {
|
Chris@49
|
193 val = arrayops::accumulate( X.colptr(0), X_n_rows );
|
Chris@49
|
194 }
|
Chris@49
|
195 else
|
Chris@49
|
196 {
|
Chris@49
|
197 for(uword col=0; col < X_n_cols; ++col)
|
Chris@49
|
198 {
|
Chris@49
|
199 val += arrayops::accumulate( X.colptr(col), X_n_rows );
|
Chris@49
|
200 }
|
Chris@49
|
201 }
|
Chris@49
|
202
|
Chris@49
|
203 return val;
|
Chris@49
|
204 }
|
Chris@49
|
205
|
Chris@49
|
206
|
Chris@49
|
207
|
Chris@49
|
208 template<typename eT>
|
Chris@49
|
209 arma_hot
|
Chris@49
|
210 arma_pure
|
Chris@49
|
211 arma_warn_unused
|
Chris@49
|
212 inline
|
Chris@49
|
213 eT
|
Chris@49
|
214 accu(const subview_col<eT>& X)
|
Chris@49
|
215 {
|
Chris@49
|
216 arma_extra_debug_sigprint();
|
Chris@49
|
217
|
Chris@49
|
218 return arrayops::accumulate( X.colptr(0), X.n_rows );
|
Chris@49
|
219 }
|
Chris@49
|
220
|
Chris@49
|
221
|
Chris@49
|
222
|
Chris@49
|
223 //! accumulate the elements of a cube
|
Chris@49
|
224 template<typename T1>
|
Chris@49
|
225 arma_hot
|
Chris@49
|
226 arma_warn_unused
|
Chris@49
|
227 inline
|
Chris@49
|
228 typename T1::elem_type
|
Chris@49
|
229 accu(const BaseCube<typename T1::elem_type,T1>& X)
|
Chris@49
|
230 {
|
Chris@49
|
231 arma_extra_debug_sigprint();
|
Chris@49
|
232
|
Chris@49
|
233 typedef typename T1::elem_type eT;
|
Chris@49
|
234 typedef typename ProxyCube<T1>::ea_type ea_type;
|
Chris@49
|
235
|
Chris@49
|
236 const ProxyCube<T1> A(X.get_ref());
|
Chris@49
|
237
|
Chris@49
|
238 if(ProxyCube<T1>::prefer_at_accessor == false)
|
Chris@49
|
239 {
|
Chris@49
|
240 ea_type P = A.get_ea();
|
Chris@49
|
241 const uword n_elem = A.get_n_elem();
|
Chris@49
|
242
|
Chris@49
|
243 eT val1 = eT(0);
|
Chris@49
|
244 eT val2 = eT(0);
|
Chris@49
|
245
|
Chris@49
|
246 uword i,j;
|
Chris@49
|
247
|
Chris@49
|
248 for(i=0, j=1; j<n_elem; i+=2, j+=2)
|
Chris@49
|
249 {
|
Chris@49
|
250 val1 += P[i];
|
Chris@49
|
251 val2 += P[j];
|
Chris@49
|
252 }
|
Chris@49
|
253
|
Chris@49
|
254 if(i < n_elem)
|
Chris@49
|
255 {
|
Chris@49
|
256 val1 += P[i];
|
Chris@49
|
257 }
|
Chris@49
|
258
|
Chris@49
|
259 return val1 + val2;
|
Chris@49
|
260 }
|
Chris@49
|
261 else
|
Chris@49
|
262 {
|
Chris@49
|
263 const uword n_rows = A.get_n_rows();
|
Chris@49
|
264 const uword n_cols = A.get_n_cols();
|
Chris@49
|
265 const uword n_slices = A.get_n_slices();
|
Chris@49
|
266
|
Chris@49
|
267 eT val = eT(0);
|
Chris@49
|
268
|
Chris@49
|
269 for(uword slice=0; slice<n_slices; ++slice)
|
Chris@49
|
270 for(uword col=0; col<n_cols; ++col)
|
Chris@49
|
271 for(uword row=0; row<n_rows; ++row)
|
Chris@49
|
272 {
|
Chris@49
|
273 val += A.at(row,col,slice);
|
Chris@49
|
274 }
|
Chris@49
|
275
|
Chris@49
|
276 return val;
|
Chris@49
|
277 }
|
Chris@49
|
278 }
|
Chris@49
|
279
|
Chris@49
|
280
|
Chris@49
|
281
|
Chris@49
|
282 template<typename T>
|
Chris@49
|
283 arma_inline
|
Chris@49
|
284 arma_warn_unused
|
Chris@49
|
285 const typename arma_scalar_only<T>::result &
|
Chris@49
|
286 accu(const T& x)
|
Chris@49
|
287 {
|
Chris@49
|
288 return x;
|
Chris@49
|
289 }
|
Chris@49
|
290
|
Chris@49
|
291
|
Chris@49
|
292
|
Chris@49
|
293 //! accumulate values in a sparse object
|
Chris@49
|
294 template<typename T1>
|
Chris@49
|
295 arma_hot
|
Chris@49
|
296 inline
|
Chris@49
|
297 arma_warn_unused
|
Chris@49
|
298 typename enable_if2<is_arma_sparse_type<T1>::value, typename T1::elem_type>::result
|
Chris@49
|
299 accu(const T1& x)
|
Chris@49
|
300 {
|
Chris@49
|
301 arma_extra_debug_sigprint();
|
Chris@49
|
302
|
Chris@49
|
303 typedef typename T1::elem_type eT;
|
Chris@49
|
304
|
Chris@49
|
305 const SpProxy<T1> p(x);
|
Chris@49
|
306
|
Chris@49
|
307 if(SpProxy<T1>::must_use_iterator == false)
|
Chris@49
|
308 {
|
Chris@49
|
309 // direct counting
|
Chris@49
|
310 return arrayops::accumulate(p.get_values(), p.get_n_nonzero());
|
Chris@49
|
311 }
|
Chris@49
|
312 else
|
Chris@49
|
313 {
|
Chris@49
|
314 typename SpProxy<T1>::const_iterator_type it = p.begin();
|
Chris@49
|
315 typename SpProxy<T1>::const_iterator_type it_end = p.end();
|
Chris@49
|
316
|
Chris@49
|
317 eT result = eT(0);
|
Chris@49
|
318
|
Chris@49
|
319 while(it != it_end)
|
Chris@49
|
320 {
|
Chris@49
|
321 result += (*it);
|
Chris@49
|
322 ++it;
|
Chris@49
|
323 }
|
Chris@49
|
324
|
Chris@49
|
325 return result;
|
Chris@49
|
326 }
|
Chris@49
|
327 }
|
Chris@49
|
328
|
Chris@49
|
329
|
Chris@49
|
330
|
Chris@49
|
331 //! @}
|