Chris@49
|
1 // Copyright (C) 2012 Ryan Curtin
|
Chris@49
|
2 // Copyright (C) 2012 Conrad Sanderson
|
Chris@49
|
3 //
|
Chris@49
|
4 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
7
|
Chris@49
|
8
|
Chris@49
|
9 //! \addtogroup spop_var
|
Chris@49
|
10 //! @{
|
Chris@49
|
11
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 template<typename T1>
|
Chris@49
|
15 inline
|
Chris@49
|
16 void
|
Chris@49
|
17 spop_var::apply(SpMat<typename T1::pod_type>& out, const mtSpOp<typename T1::pod_type, T1, spop_var>& in)
|
Chris@49
|
18 {
|
Chris@49
|
19 arma_extra_debug_sigprint();
|
Chris@49
|
20
|
Chris@49
|
21 //typedef typename T1::elem_type in_eT;
|
Chris@49
|
22 typedef typename T1::pod_type out_eT;
|
Chris@49
|
23
|
Chris@49
|
24 const uword norm_type = in.aux_uword_a;
|
Chris@49
|
25 const uword dim = in.aux_uword_b;
|
Chris@49
|
26
|
Chris@49
|
27 arma_debug_check((norm_type > 1), "var(): incorrect usage. norm_type must be 0 or 1");
|
Chris@49
|
28 arma_debug_check((dim > 1), "var(): incorrect usage. dim must be 0 or 1");
|
Chris@49
|
29
|
Chris@49
|
30 SpProxy<T1> p(in.m);
|
Chris@49
|
31
|
Chris@49
|
32 if(p.is_alias(out) == false)
|
Chris@49
|
33 {
|
Chris@49
|
34 spop_var::apply_noalias(out, p, norm_type, dim);
|
Chris@49
|
35 }
|
Chris@49
|
36 else
|
Chris@49
|
37 {
|
Chris@49
|
38 SpMat<out_eT> tmp;
|
Chris@49
|
39
|
Chris@49
|
40 spop_var::apply_noalias(tmp, p, norm_type, dim);
|
Chris@49
|
41
|
Chris@49
|
42 out.steal_mem(tmp);
|
Chris@49
|
43 }
|
Chris@49
|
44 }
|
Chris@49
|
45
|
Chris@49
|
46
|
Chris@49
|
47
|
Chris@49
|
48 template<typename T1>
|
Chris@49
|
49 inline
|
Chris@49
|
50 void
|
Chris@49
|
51 spop_var::apply_noalias
|
Chris@49
|
52 (
|
Chris@49
|
53 SpMat<typename T1::pod_type>& out_ref,
|
Chris@49
|
54 const SpProxy<T1>& p,
|
Chris@49
|
55 const uword norm_type,
|
Chris@49
|
56 const uword dim
|
Chris@49
|
57 )
|
Chris@49
|
58 {
|
Chris@49
|
59 arma_extra_debug_sigprint();
|
Chris@49
|
60
|
Chris@49
|
61 typedef typename T1::elem_type in_eT;
|
Chris@49
|
62 //typedef typename T1::pod_type out_eT;
|
Chris@49
|
63
|
Chris@49
|
64 const uword p_n_rows = p.get_n_rows();
|
Chris@49
|
65 const uword p_n_cols = p.get_n_cols();
|
Chris@49
|
66
|
Chris@49
|
67 if(dim == 0)
|
Chris@49
|
68 {
|
Chris@49
|
69 arma_extra_debug_print("spop_var::apply(), dim = 0");
|
Chris@49
|
70
|
Chris@49
|
71 arma_debug_check((p_n_rows == 0), "var(): given object has zero rows");
|
Chris@49
|
72
|
Chris@49
|
73 out_ref.set_size(1, p_n_cols);
|
Chris@49
|
74
|
Chris@49
|
75 for(uword col = 0; col < p_n_cols; ++col)
|
Chris@49
|
76 {
|
Chris@49
|
77 if(SpProxy<T1>::must_use_iterator == true)
|
Chris@49
|
78 {
|
Chris@49
|
79 // We must use an iterator; we can't access memory directly.
|
Chris@49
|
80 typename SpProxy<T1>::const_iterator_type it = p.begin_col(col);
|
Chris@49
|
81 typename SpProxy<T1>::const_iterator_type end = p.begin_col(col + 1);
|
Chris@49
|
82
|
Chris@49
|
83 const uword n_zero = p.get_n_rows() - (end.pos() - it.pos());
|
Chris@49
|
84
|
Chris@49
|
85 // in_eT is used just to get the specialization right (complex / noncomplex)
|
Chris@49
|
86 out_ref.at(col) = spop_var::iterator_var(it, end, n_zero, norm_type, in_eT(0));
|
Chris@49
|
87 }
|
Chris@49
|
88 else
|
Chris@49
|
89 {
|
Chris@49
|
90 // We can use direct memory access to calculate the variance.
|
Chris@49
|
91 out_ref.at(col) = spop_var::direct_var
|
Chris@49
|
92 (
|
Chris@49
|
93 &p.get_values()[p.get_col_ptrs()[col]],
|
Chris@49
|
94 p.get_col_ptrs()[col + 1] - p.get_col_ptrs()[col],
|
Chris@49
|
95 p.get_n_rows(),
|
Chris@49
|
96 norm_type
|
Chris@49
|
97 );
|
Chris@49
|
98 }
|
Chris@49
|
99 }
|
Chris@49
|
100 }
|
Chris@49
|
101 else if(dim == 1)
|
Chris@49
|
102 {
|
Chris@49
|
103 arma_extra_debug_print("spop_var::apply_noalias(), dim = 1");
|
Chris@49
|
104
|
Chris@49
|
105 arma_debug_check((p_n_cols == 0), "var(): given object has zero columns");
|
Chris@49
|
106
|
Chris@49
|
107 out_ref.set_size(p_n_rows, 1);
|
Chris@49
|
108
|
Chris@49
|
109 for(uword row = 0; row < p_n_rows; ++row)
|
Chris@49
|
110 {
|
Chris@49
|
111 // We have to use an iterator here regardless of whether or not we can
|
Chris@49
|
112 // directly access memory.
|
Chris@49
|
113 typename SpProxy<T1>::const_row_iterator_type it = p.begin_row(row);
|
Chris@49
|
114 typename SpProxy<T1>::const_row_iterator_type end = p.end_row(row);
|
Chris@49
|
115
|
Chris@49
|
116 const uword n_zero = p.get_n_cols() - (end.pos() - it.pos());
|
Chris@49
|
117
|
Chris@49
|
118 out_ref.at(row) = spop_var::iterator_var(it, end, n_zero, norm_type, in_eT(0));
|
Chris@49
|
119 }
|
Chris@49
|
120 }
|
Chris@49
|
121 }
|
Chris@49
|
122
|
Chris@49
|
123
|
Chris@49
|
124
|
Chris@49
|
125 template<typename T1>
|
Chris@49
|
126 inline
|
Chris@49
|
127 typename T1::pod_type
|
Chris@49
|
128 spop_var::var_vec
|
Chris@49
|
129 (
|
Chris@49
|
130 const T1& X,
|
Chris@49
|
131 const uword norm_type
|
Chris@49
|
132 )
|
Chris@49
|
133 {
|
Chris@49
|
134 arma_extra_debug_sigprint();
|
Chris@49
|
135
|
Chris@49
|
136 arma_debug_check((norm_type > 1), "var(): incorrect usage. norm_type must be 0 or 1.");
|
Chris@49
|
137
|
Chris@49
|
138 // conditionally unwrap it into a temporary and then directly operate.
|
Chris@49
|
139
|
Chris@49
|
140 const unwrap_spmat<T1> tmp(X);
|
Chris@49
|
141
|
Chris@49
|
142 return direct_var(tmp.M.values, tmp.M.n_nonzero, tmp.M.n_elem, norm_type);
|
Chris@49
|
143 }
|
Chris@49
|
144
|
Chris@49
|
145
|
Chris@49
|
146
|
Chris@49
|
147 template<typename eT>
|
Chris@49
|
148 inline
|
Chris@49
|
149 eT
|
Chris@49
|
150 spop_var::direct_var
|
Chris@49
|
151 (
|
Chris@49
|
152 const eT* const X,
|
Chris@49
|
153 const uword length,
|
Chris@49
|
154 const uword N,
|
Chris@49
|
155 const uword norm_type
|
Chris@49
|
156 )
|
Chris@49
|
157 {
|
Chris@49
|
158 arma_extra_debug_sigprint();
|
Chris@49
|
159
|
Chris@49
|
160 if(length >= 2 && N >= 2)
|
Chris@49
|
161 {
|
Chris@49
|
162 const eT acc1 = spop_mean::direct_mean(X, length, N);
|
Chris@49
|
163
|
Chris@49
|
164 eT acc2 = eT(0);
|
Chris@49
|
165 eT acc3 = eT(0);
|
Chris@49
|
166
|
Chris@49
|
167 uword i, j;
|
Chris@49
|
168
|
Chris@49
|
169 for(i = 0, j = 1; j < length; i += 2, j += 2)
|
Chris@49
|
170 {
|
Chris@49
|
171 const eT Xi = X[i];
|
Chris@49
|
172 const eT Xj = X[j];
|
Chris@49
|
173
|
Chris@49
|
174 const eT tmpi = acc1 - Xi;
|
Chris@49
|
175 const eT tmpj = acc1 - Xj;
|
Chris@49
|
176
|
Chris@49
|
177 acc2 += tmpi * tmpi + tmpj * tmpj;
|
Chris@49
|
178 acc3 += tmpi + tmpj;
|
Chris@49
|
179 }
|
Chris@49
|
180
|
Chris@49
|
181 if(i < length)
|
Chris@49
|
182 {
|
Chris@49
|
183 const eT Xi = X[i];
|
Chris@49
|
184
|
Chris@49
|
185 const eT tmpi = acc1 - Xi;
|
Chris@49
|
186
|
Chris@49
|
187 acc2 += tmpi * tmpi;
|
Chris@49
|
188 acc3 += tmpi;
|
Chris@49
|
189 }
|
Chris@49
|
190
|
Chris@49
|
191 // Now add in all zero elements.
|
Chris@49
|
192 acc2 += (N - length) * (acc1 * acc1);
|
Chris@49
|
193 acc3 += (N - length) * acc1;
|
Chris@49
|
194
|
Chris@49
|
195 const eT norm_val = (norm_type == 0) ? eT(N - 1) : eT(N);
|
Chris@49
|
196 const eT var_val = (acc2 - (acc3 * acc3) / eT(N)) / norm_val;
|
Chris@49
|
197
|
Chris@49
|
198 return var_val;
|
Chris@49
|
199 }
|
Chris@49
|
200 else if(length == 1 && N > 1) // if N == 1, then variance is zero.
|
Chris@49
|
201 {
|
Chris@49
|
202 const eT mean = X[0] / eT(N);
|
Chris@49
|
203 const eT val = mean - X[0];
|
Chris@49
|
204
|
Chris@49
|
205 const eT acc2 = (val * val) + (N - length) * (mean * mean);
|
Chris@49
|
206 const eT acc3 = val + (N - length) * mean;
|
Chris@49
|
207
|
Chris@49
|
208 const eT norm_val = (norm_type == 0) ? eT(N - 1) : eT(N);
|
Chris@49
|
209 const eT var_val = (acc2 - (acc3 * acc3) / eT(N)) / norm_val;
|
Chris@49
|
210
|
Chris@49
|
211 return var_val;
|
Chris@49
|
212 }
|
Chris@49
|
213 else
|
Chris@49
|
214 {
|
Chris@49
|
215 return eT(0);
|
Chris@49
|
216 }
|
Chris@49
|
217 }
|
Chris@49
|
218
|
Chris@49
|
219
|
Chris@49
|
220
|
Chris@49
|
221 template<typename T>
|
Chris@49
|
222 inline
|
Chris@49
|
223 T
|
Chris@49
|
224 spop_var::direct_var
|
Chris@49
|
225 (
|
Chris@49
|
226 const std::complex<T>* const X,
|
Chris@49
|
227 const uword length,
|
Chris@49
|
228 const uword N,
|
Chris@49
|
229 const uword norm_type
|
Chris@49
|
230 )
|
Chris@49
|
231 {
|
Chris@49
|
232 arma_extra_debug_sigprint();
|
Chris@49
|
233
|
Chris@49
|
234 typedef typename std::complex<T> eT;
|
Chris@49
|
235
|
Chris@49
|
236 if(length >= 2 && N >= 2)
|
Chris@49
|
237 {
|
Chris@49
|
238 const eT acc1 = spop_mean::direct_mean(X, length, N);
|
Chris@49
|
239
|
Chris@49
|
240 T acc2 = T(0);
|
Chris@49
|
241 eT acc3 = eT(0);
|
Chris@49
|
242
|
Chris@49
|
243 for (uword i = 0; i < length; ++i)
|
Chris@49
|
244 {
|
Chris@49
|
245 const eT tmp = acc1 - X[i];
|
Chris@49
|
246
|
Chris@49
|
247 acc2 += std::norm(tmp);
|
Chris@49
|
248 acc3 += tmp;
|
Chris@49
|
249 }
|
Chris@49
|
250
|
Chris@49
|
251 // Add zero elements to sums
|
Chris@49
|
252 acc2 += std::norm(acc1) * T(N - length);
|
Chris@49
|
253 acc3 += acc1 * T(N - length);
|
Chris@49
|
254
|
Chris@49
|
255 const T norm_val = (norm_type == 0) ? T(N - 1) : T(N);
|
Chris@49
|
256 const T var_val = (acc2 - std::norm(acc3) / T(N)) / norm_val;
|
Chris@49
|
257
|
Chris@49
|
258 return var_val;
|
Chris@49
|
259 }
|
Chris@49
|
260 else if(length == 1 && N > 1) // if N == 1, then variance is zero.
|
Chris@49
|
261 {
|
Chris@49
|
262 const eT mean = X[0] / T(N);
|
Chris@49
|
263 const eT val = mean - X[0];
|
Chris@49
|
264
|
Chris@49
|
265 const T acc2 = std::norm(val) + (N - length) * std::norm(mean);
|
Chris@49
|
266 const eT acc3 = val + T(N - length) * mean;
|
Chris@49
|
267
|
Chris@49
|
268 const T norm_val = (norm_type == 0) ? T(N - 1) : T(N);
|
Chris@49
|
269 const T var_val = (acc2 - std::norm(acc3) / T(N)) / norm_val;
|
Chris@49
|
270
|
Chris@49
|
271 return var_val;
|
Chris@49
|
272 }
|
Chris@49
|
273 else
|
Chris@49
|
274 {
|
Chris@49
|
275 return T(0); // All elements are zero
|
Chris@49
|
276 }
|
Chris@49
|
277 }
|
Chris@49
|
278
|
Chris@49
|
279
|
Chris@49
|
280
|
Chris@49
|
281 template<typename T1, typename eT>
|
Chris@49
|
282 inline
|
Chris@49
|
283 eT
|
Chris@49
|
284 spop_var::iterator_var
|
Chris@49
|
285 (
|
Chris@49
|
286 T1& it,
|
Chris@49
|
287 const T1& end,
|
Chris@49
|
288 const uword n_zero,
|
Chris@49
|
289 const uword norm_type,
|
Chris@49
|
290 const eT junk1,
|
Chris@49
|
291 const typename arma_not_cx<eT>::result* junk2
|
Chris@49
|
292 )
|
Chris@49
|
293 {
|
Chris@49
|
294 arma_extra_debug_sigprint();
|
Chris@49
|
295 arma_ignore(junk1);
|
Chris@49
|
296 arma_ignore(junk2);
|
Chris@49
|
297
|
Chris@49
|
298 T1 new_it(it); // for mean
|
Chris@49
|
299 // T1 backup_it(it); // in case we have to call robust iterator_var
|
Chris@49
|
300 eT mean = spop_mean::iterator_mean(new_it, end, n_zero, eT(0));
|
Chris@49
|
301
|
Chris@49
|
302 eT acc2 = eT(0);
|
Chris@49
|
303 eT acc3 = eT(0);
|
Chris@49
|
304
|
Chris@49
|
305 const uword it_begin_pos = it.pos();
|
Chris@49
|
306
|
Chris@49
|
307 while (it != end)
|
Chris@49
|
308 {
|
Chris@49
|
309 const eT tmp = mean - (*it);
|
Chris@49
|
310
|
Chris@49
|
311 acc2 += (tmp * tmp);
|
Chris@49
|
312 acc3 += (tmp);
|
Chris@49
|
313
|
Chris@49
|
314 ++it;
|
Chris@49
|
315 }
|
Chris@49
|
316
|
Chris@49
|
317 const uword n_nonzero = (it.pos() - it_begin_pos);
|
Chris@49
|
318 if (n_nonzero == 0)
|
Chris@49
|
319 {
|
Chris@49
|
320 return eT(0);
|
Chris@49
|
321 }
|
Chris@49
|
322
|
Chris@49
|
323 if (n_nonzero + n_zero == 1)
|
Chris@49
|
324 {
|
Chris@49
|
325 return eT(0); // only one element
|
Chris@49
|
326 }
|
Chris@49
|
327
|
Chris@49
|
328 // Add in entries for zeros.
|
Chris@49
|
329 acc2 += eT(n_zero) * (mean * mean);
|
Chris@49
|
330 acc3 += eT(n_zero) * mean;
|
Chris@49
|
331
|
Chris@49
|
332 const eT norm_val = (norm_type == 0) ? eT(n_zero + n_nonzero - 1) : eT(n_zero + n_nonzero);
|
Chris@49
|
333 const eT var_val = (acc2 - (acc3 * acc3) / eT(n_nonzero + n_zero)) / norm_val;
|
Chris@49
|
334
|
Chris@49
|
335 return var_val;
|
Chris@49
|
336 }
|
Chris@49
|
337
|
Chris@49
|
338
|
Chris@49
|
339
|
Chris@49
|
340 template<typename T1, typename eT>
|
Chris@49
|
341 inline
|
Chris@49
|
342 typename get_pod_type<eT>::result
|
Chris@49
|
343 spop_var::iterator_var
|
Chris@49
|
344 (
|
Chris@49
|
345 T1& it,
|
Chris@49
|
346 const T1& end,
|
Chris@49
|
347 const uword n_zero,
|
Chris@49
|
348 const uword norm_type,
|
Chris@49
|
349 const eT junk1,
|
Chris@49
|
350 const typename arma_cx_only<eT>::result* junk2
|
Chris@49
|
351 )
|
Chris@49
|
352 {
|
Chris@49
|
353 arma_extra_debug_sigprint();
|
Chris@49
|
354 arma_ignore(junk1);
|
Chris@49
|
355 arma_ignore(junk2);
|
Chris@49
|
356
|
Chris@49
|
357 typedef typename get_pod_type<eT>::result T;
|
Chris@49
|
358
|
Chris@49
|
359 T1 new_it(it); // for mean
|
Chris@49
|
360 // T1 backup_it(it); // in case we have to call robust iterator_var
|
Chris@49
|
361 eT mean = spop_mean::iterator_mean(new_it, end, n_zero, eT(0));
|
Chris@49
|
362
|
Chris@49
|
363 T acc2 = T(0);
|
Chris@49
|
364 eT acc3 = eT(0);
|
Chris@49
|
365
|
Chris@49
|
366 const uword it_begin_pos = it.pos();
|
Chris@49
|
367
|
Chris@49
|
368 while (it != end)
|
Chris@49
|
369 {
|
Chris@49
|
370 eT tmp = mean - (*it);
|
Chris@49
|
371
|
Chris@49
|
372 acc2 += std::norm(tmp);
|
Chris@49
|
373 acc3 += (tmp);
|
Chris@49
|
374
|
Chris@49
|
375 ++it;
|
Chris@49
|
376 }
|
Chris@49
|
377
|
Chris@49
|
378 const uword n_nonzero = (it.pos() - it_begin_pos);
|
Chris@49
|
379 if (n_nonzero == 0)
|
Chris@49
|
380 {
|
Chris@49
|
381 return T(0);
|
Chris@49
|
382 }
|
Chris@49
|
383
|
Chris@49
|
384 if (n_nonzero + n_zero == 1)
|
Chris@49
|
385 {
|
Chris@49
|
386 return T(0); // only one element
|
Chris@49
|
387 }
|
Chris@49
|
388
|
Chris@49
|
389 // Add in entries for zero elements.
|
Chris@49
|
390 acc2 += T(n_zero) * std::norm(mean);
|
Chris@49
|
391 acc3 += T(n_zero) * mean;
|
Chris@49
|
392
|
Chris@49
|
393 const T norm_val = (norm_type == 0) ? T(n_zero + n_nonzero - 1) : T(n_zero + n_nonzero);
|
Chris@49
|
394 const T var_val = (acc2 - std::norm(acc3) / T(n_nonzero + n_zero)) / norm_val;
|
Chris@49
|
395
|
Chris@49
|
396 return var_val;
|
Chris@49
|
397 }
|
Chris@49
|
398
|
Chris@49
|
399
|
Chris@49
|
400
|
Chris@49
|
401 //! @}
|