Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-2.4.4/include/armadillo_bits/running_stat_vec_meat.hpp @ 0:8b6102e2a9b0
Armadillo Library
author | maxzanoni76 <max.zanoni@eecs.qmul.ac.uk> |
---|---|
date | Wed, 11 Apr 2012 09:27:06 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:8b6102e2a9b0 |
---|---|
1 // Copyright (C) 2009-2011 NICTA (www.nicta.com.au) | |
2 // Copyright (C) 2009-2011 Conrad Sanderson | |
3 // | |
4 // This file is part of the Armadillo C++ library. | |
5 // It is provided without any warranty of fitness | |
6 // for any purpose. You can redistribute this file | |
7 // and/or modify it under the terms of the GNU | |
8 // Lesser General Public License (LGPL) as published | |
9 // by the Free Software Foundation, either version 3 | |
10 // of the License or (at your option) any later version. | |
11 // (see http://www.opensource.org/licenses for more info) | |
12 | |
13 | |
14 //! \addtogroup running_stat_vec | |
15 //! @{ | |
16 | |
17 | |
18 | |
19 template<typename eT> | |
20 running_stat_vec<eT>::~running_stat_vec() | |
21 { | |
22 arma_extra_debug_sigprint_this(this); | |
23 } | |
24 | |
25 | |
26 | |
27 template<typename eT> | |
28 running_stat_vec<eT>::running_stat_vec(const bool in_calc_cov) | |
29 : calc_cov(in_calc_cov) | |
30 { | |
31 arma_extra_debug_sigprint_this(this); | |
32 } | |
33 | |
34 | |
35 | |
36 template<typename eT> | |
37 running_stat_vec<eT>::running_stat_vec(const running_stat_vec<eT>& in_rsv) | |
38 : calc_cov (in_rsv.calc_cov) | |
39 , counter (in_rsv.counter) | |
40 , r_mean (in_rsv.r_mean) | |
41 , r_var (in_rsv.r_var) | |
42 , r_cov (in_rsv.r_cov) | |
43 , min_val (in_rsv.min_val) | |
44 , max_val (in_rsv.max_val) | |
45 , min_val_norm(in_rsv.min_val_norm) | |
46 , max_val_norm(in_rsv.max_val_norm) | |
47 { | |
48 arma_extra_debug_sigprint_this(this); | |
49 } | |
50 | |
51 | |
52 | |
53 template<typename eT> | |
54 const running_stat_vec<eT>& | |
55 running_stat_vec<eT>::operator=(const running_stat_vec<eT>& in_rsv) | |
56 { | |
57 arma_extra_debug_sigprint(); | |
58 | |
59 access::rw(calc_cov) = in_rsv.calc_cov; | |
60 | |
61 counter = in_rsv.counter; | |
62 r_mean = in_rsv.r_mean; | |
63 r_var = in_rsv.r_var; | |
64 r_cov = in_rsv.r_cov; | |
65 min_val = in_rsv.min_val; | |
66 max_val = in_rsv.max_val; | |
67 min_val_norm = in_rsv.min_val_norm; | |
68 max_val_norm = in_rsv.max_val_norm; | |
69 | |
70 return *this; | |
71 } | |
72 | |
73 | |
74 | |
75 //! update statistics to reflect new sample | |
76 template<typename eT> | |
77 template<typename T1> | |
78 arma_hot | |
79 inline | |
80 void | |
81 running_stat_vec<eT>::operator() (const Base<typename get_pod_type<eT>::result, T1>& X) | |
82 { | |
83 arma_extra_debug_sigprint(); | |
84 | |
85 //typedef typename get_pod_type<eT>::result T; | |
86 | |
87 const unwrap<T1> tmp(X.get_ref()); | |
88 const Mat<eT>& sample = tmp.M; | |
89 | |
90 if( sample.is_empty() ) | |
91 { | |
92 return; | |
93 } | |
94 | |
95 if( sample.is_finite() == false ) | |
96 { | |
97 arma_warn(true, "running_stat_vec: sample ignored as it has non-finite elements"); | |
98 return; | |
99 } | |
100 | |
101 running_stat_vec_aux::update_stats(*this, sample); | |
102 } | |
103 | |
104 | |
105 | |
106 //! update statistics to reflect new sample (version for complex numbers) | |
107 template<typename eT> | |
108 template<typename T1> | |
109 arma_hot | |
110 inline | |
111 void | |
112 running_stat_vec<eT>::operator() (const Base<std::complex<typename get_pod_type<eT>::result>, T1>& X) | |
113 { | |
114 arma_extra_debug_sigprint(); | |
115 | |
116 //typedef typename std::complex<typename get_pod_type<eT>::result> eT; | |
117 | |
118 const unwrap<T1> tmp(X.get_ref()); | |
119 const Mat<eT>& sample = tmp.M; | |
120 | |
121 if( sample.is_empty() ) | |
122 { | |
123 return; | |
124 } | |
125 | |
126 if( sample.is_finite() == false ) | |
127 { | |
128 arma_warn(true, "running_stat_vec: sample ignored as it has non-finite elements"); | |
129 return; | |
130 } | |
131 | |
132 running_stat_vec_aux::update_stats(*this, sample); | |
133 } | |
134 | |
135 | |
136 | |
137 //! set all statistics to zero | |
138 template<typename eT> | |
139 inline | |
140 void | |
141 running_stat_vec<eT>::reset() | |
142 { | |
143 arma_extra_debug_sigprint(); | |
144 | |
145 counter.reset(); | |
146 | |
147 r_mean.reset(); | |
148 r_var.reset(); | |
149 r_cov.reset(); | |
150 | |
151 min_val.reset(); | |
152 max_val.reset(); | |
153 | |
154 min_val_norm.reset(); | |
155 max_val_norm.reset(); | |
156 | |
157 r_var_dummy.reset(); | |
158 r_cov_dummy.reset(); | |
159 | |
160 tmp1.reset(); | |
161 tmp2.reset(); | |
162 } | |
163 | |
164 | |
165 | |
166 //! mean or average value | |
167 template<typename eT> | |
168 inline | |
169 const Mat<eT>& | |
170 running_stat_vec<eT>::mean() const | |
171 { | |
172 arma_extra_debug_sigprint(); | |
173 | |
174 return r_mean; | |
175 } | |
176 | |
177 | |
178 | |
179 //! variance | |
180 template<typename eT> | |
181 inline | |
182 const Mat<typename get_pod_type<eT>::result>& | |
183 running_stat_vec<eT>::var(const uword norm_type) | |
184 { | |
185 arma_extra_debug_sigprint(); | |
186 | |
187 const T N = counter.value(); | |
188 | |
189 if(N > T(1)) | |
190 { | |
191 if(norm_type == 0) | |
192 { | |
193 return r_var; | |
194 } | |
195 else | |
196 { | |
197 const T N_minus_1 = counter.value_minus_1(); | |
198 | |
199 r_var_dummy = (N_minus_1/N) * r_var; | |
200 | |
201 return r_var_dummy; | |
202 } | |
203 } | |
204 else | |
205 { | |
206 r_var_dummy.zeros(r_mean.n_rows, r_mean.n_cols); | |
207 | |
208 return r_var_dummy; | |
209 } | |
210 | |
211 } | |
212 | |
213 | |
214 | |
215 //! standard deviation | |
216 template<typename eT> | |
217 inline | |
218 Mat<typename get_pod_type<eT>::result> | |
219 running_stat_vec<eT>::stddev(const uword norm_type) const | |
220 { | |
221 arma_extra_debug_sigprint(); | |
222 | |
223 const T N = counter.value(); | |
224 | |
225 if(N > T(1)) | |
226 { | |
227 if(norm_type == 0) | |
228 { | |
229 return sqrt(r_var); | |
230 } | |
231 else | |
232 { | |
233 const T N_minus_1 = counter.value_minus_1(); | |
234 | |
235 return sqrt( (N_minus_1/N) * r_var ); | |
236 } | |
237 } | |
238 else | |
239 { | |
240 return Mat<T>(); | |
241 } | |
242 } | |
243 | |
244 | |
245 | |
246 //! covariance | |
247 template<typename eT> | |
248 inline | |
249 const Mat<eT>& | |
250 running_stat_vec<eT>::cov(const uword norm_type) | |
251 { | |
252 arma_extra_debug_sigprint(); | |
253 | |
254 if(calc_cov == true) | |
255 { | |
256 const T N = counter.value(); | |
257 | |
258 if(N > T(1)) | |
259 { | |
260 if(norm_type == 0) | |
261 { | |
262 return r_cov; | |
263 } | |
264 else | |
265 { | |
266 const T N_minus_1 = counter.value_minus_1(); | |
267 | |
268 r_cov_dummy = (N_minus_1/N) * r_cov; | |
269 | |
270 return r_cov_dummy; | |
271 } | |
272 } | |
273 else | |
274 { | |
275 r_cov_dummy.zeros(r_mean.n_rows, r_mean.n_cols); | |
276 | |
277 return r_cov_dummy; | |
278 } | |
279 } | |
280 else | |
281 { | |
282 r_cov_dummy.reset(); | |
283 | |
284 return r_cov_dummy; | |
285 } | |
286 | |
287 } | |
288 | |
289 | |
290 | |
291 //! vector with minimum values | |
292 template<typename eT> | |
293 inline | |
294 const Mat<eT>& | |
295 running_stat_vec<eT>::min() const | |
296 { | |
297 arma_extra_debug_sigprint(); | |
298 | |
299 return min_val; | |
300 } | |
301 | |
302 | |
303 | |
304 //! vector with maximum values | |
305 template<typename eT> | |
306 inline | |
307 const Mat<eT>& | |
308 running_stat_vec<eT>::max() const | |
309 { | |
310 arma_extra_debug_sigprint(); | |
311 | |
312 return max_val; | |
313 } | |
314 | |
315 | |
316 | |
317 //! number of samples so far | |
318 template<typename eT> | |
319 inline | |
320 typename get_pod_type<eT>::result | |
321 running_stat_vec<eT>::count() const | |
322 { | |
323 arma_extra_debug_sigprint(); | |
324 | |
325 return counter.value(); | |
326 } | |
327 | |
328 | |
329 | |
330 // | |
331 | |
332 | |
333 | |
334 //! update statistics to reflect new sample | |
335 template<typename eT> | |
336 inline | |
337 void | |
338 running_stat_vec_aux::update_stats(running_stat_vec<eT>& x, const Mat<eT>& sample) | |
339 { | |
340 arma_extra_debug_sigprint(); | |
341 | |
342 typedef typename running_stat_vec<eT>::T T; | |
343 | |
344 const T N = x.counter.value(); | |
345 | |
346 if(N > T(0)) | |
347 { | |
348 arma_debug_assert_same_size(x.r_mean, sample, "running_stat_vec(): dimensionality mismatch"); | |
349 | |
350 const uword n_elem = sample.n_elem; | |
351 const eT* sample_mem = sample.memptr(); | |
352 eT* r_mean_mem = x.r_mean.memptr(); | |
353 T* r_var_mem = x.r_var.memptr(); | |
354 eT* min_val_mem = x.min_val.memptr(); | |
355 eT* max_val_mem = x.max_val.memptr(); | |
356 | |
357 const T N_plus_1 = x.counter.value_plus_1(); | |
358 const T N_minus_1 = x.counter.value_minus_1(); | |
359 | |
360 if(x.calc_cov == true) | |
361 { | |
362 Mat<eT>& tmp1 = x.tmp1; | |
363 Mat<eT>& tmp2 = x.tmp2; | |
364 | |
365 tmp1 = sample - x.r_mean; | |
366 | |
367 if(sample.n_cols == 1) | |
368 { | |
369 tmp2 = tmp1*trans(tmp1); | |
370 } | |
371 else | |
372 { | |
373 tmp2 = trans(tmp1)*tmp1; | |
374 } | |
375 | |
376 x.r_cov *= (N_minus_1/N); | |
377 x.r_cov += tmp2 / N_plus_1; | |
378 } | |
379 | |
380 | |
381 for(uword i=0; i<n_elem; ++i) | |
382 { | |
383 const eT val = sample_mem[i]; | |
384 | |
385 if(val < min_val_mem[i]) | |
386 { | |
387 min_val_mem[i] = val; | |
388 } | |
389 | |
390 if(val > max_val_mem[i]) | |
391 { | |
392 max_val_mem[i] = val; | |
393 } | |
394 | |
395 const eT r_mean_val = r_mean_mem[i]; | |
396 const eT tmp = val - r_mean_val; | |
397 | |
398 r_var_mem[i] = N_minus_1/N * r_var_mem[i] + (tmp*tmp)/N_plus_1; | |
399 | |
400 r_mean_mem[i] = r_mean_val + (val - r_mean_val)/N_plus_1; | |
401 } | |
402 } | |
403 else | |
404 { | |
405 arma_debug_check( (sample.is_vec() == false), "running_stat_vec(): given sample is not a vector"); | |
406 | |
407 x.r_mean.set_size(sample.n_rows, sample.n_cols); | |
408 | |
409 x.r_var.zeros(sample.n_rows, sample.n_cols); | |
410 | |
411 if(x.calc_cov == true) | |
412 { | |
413 x.r_cov.zeros(sample.n_elem, sample.n_elem); | |
414 } | |
415 | |
416 x.min_val.set_size(sample.n_rows, sample.n_cols); | |
417 x.max_val.set_size(sample.n_rows, sample.n_cols); | |
418 | |
419 | |
420 const uword n_elem = sample.n_elem; | |
421 const eT* sample_mem = sample.memptr(); | |
422 eT* r_mean_mem = x.r_mean.memptr(); | |
423 eT* min_val_mem = x.min_val.memptr(); | |
424 eT* max_val_mem = x.max_val.memptr(); | |
425 | |
426 | |
427 for(uword i=0; i<n_elem; ++i) | |
428 { | |
429 const eT val = sample_mem[i]; | |
430 | |
431 r_mean_mem[i] = val; | |
432 min_val_mem[i] = val; | |
433 max_val_mem[i] = val; | |
434 } | |
435 } | |
436 | |
437 x.counter++; | |
438 } | |
439 | |
440 | |
441 | |
442 //! update statistics to reflect new sample (version for complex numbers) | |
443 template<typename T> | |
444 inline | |
445 void | |
446 running_stat_vec_aux::update_stats(running_stat_vec< std::complex<T> >& x, const Mat<T>& sample) | |
447 { | |
448 arma_extra_debug_sigprint(); | |
449 | |
450 const Mat< std::complex<T> > tmp = conv_to< Mat< std::complex<T> > >::from(sample); | |
451 | |
452 running_stat_vec_aux::update_stats(x, tmp); | |
453 } | |
454 | |
455 | |
456 | |
457 //! alter statistics to reflect new sample (version for complex numbers) | |
458 template<typename T> | |
459 inline | |
460 void | |
461 running_stat_vec_aux::update_stats(running_stat_vec< std::complex<T> >& x, const Mat< std::complex<T> >& sample) | |
462 { | |
463 arma_extra_debug_sigprint(); | |
464 | |
465 typedef typename std::complex<T> eT; | |
466 | |
467 const T N = x.counter.value(); | |
468 | |
469 if(N > T(0)) | |
470 { | |
471 arma_debug_assert_same_size(x.r_mean, sample, "running_stat_vec(): dimensionality mismatch"); | |
472 | |
473 const uword n_elem = sample.n_elem; | |
474 const eT* sample_mem = sample.memptr(); | |
475 eT* r_mean_mem = x.r_mean.memptr(); | |
476 T* r_var_mem = x.r_var.memptr(); | |
477 eT* min_val_mem = x.min_val.memptr(); | |
478 eT* max_val_mem = x.max_val.memptr(); | |
479 T* min_val_norm_mem = x.min_val_norm.memptr(); | |
480 T* max_val_norm_mem = x.max_val_norm.memptr(); | |
481 | |
482 const T N_plus_1 = x.counter.value_plus_1(); | |
483 const T N_minus_1 = x.counter.value_minus_1(); | |
484 | |
485 if(x.calc_cov == true) | |
486 { | |
487 Mat<eT>& tmp1 = x.tmp1; | |
488 Mat<eT>& tmp2 = x.tmp2; | |
489 | |
490 tmp1 = sample - x.r_mean; | |
491 | |
492 if(sample.n_cols == 1) | |
493 { | |
494 tmp2 = arma::conj(tmp1)*strans(tmp1); | |
495 } | |
496 else | |
497 { | |
498 tmp2 = trans(tmp1)*tmp1; //tmp2 = strans(conj(tmp1))*tmp1; | |
499 } | |
500 | |
501 x.r_cov *= (N_minus_1/N); | |
502 x.r_cov += tmp2 / N_plus_1; | |
503 } | |
504 | |
505 | |
506 for(uword i=0; i<n_elem; ++i) | |
507 { | |
508 const eT& val = sample_mem[i]; | |
509 const T val_norm = std::norm(val); | |
510 | |
511 if(val_norm < min_val_norm_mem[i]) | |
512 { | |
513 min_val_norm_mem[i] = val_norm; | |
514 min_val_mem[i] = val; | |
515 } | |
516 | |
517 if(val_norm > max_val_norm_mem[i]) | |
518 { | |
519 max_val_norm_mem[i] = val_norm; | |
520 max_val_mem[i] = val; | |
521 } | |
522 | |
523 const eT& r_mean_val = r_mean_mem[i]; | |
524 | |
525 r_var_mem[i] = N_minus_1/N * r_var_mem[i] + std::norm(val - r_mean_val)/N_plus_1; | |
526 | |
527 r_mean_mem[i] = r_mean_val + (val - r_mean_val)/N_plus_1; | |
528 } | |
529 | |
530 } | |
531 else | |
532 { | |
533 arma_debug_check( (sample.is_vec() == false), "running_stat_vec(): given sample is not a vector"); | |
534 | |
535 x.r_mean.set_size(sample.n_rows, sample.n_cols); | |
536 | |
537 x.r_var.zeros(sample.n_rows, sample.n_cols); | |
538 | |
539 if(x.calc_cov == true) | |
540 { | |
541 x.r_cov.zeros(sample.n_elem, sample.n_elem); | |
542 } | |
543 | |
544 x.min_val.set_size(sample.n_rows, sample.n_cols); | |
545 x.max_val.set_size(sample.n_rows, sample.n_cols); | |
546 | |
547 x.min_val_norm.set_size(sample.n_rows, sample.n_cols); | |
548 x.max_val_norm.set_size(sample.n_rows, sample.n_cols); | |
549 | |
550 | |
551 const uword n_elem = sample.n_elem; | |
552 const eT* sample_mem = sample.memptr(); | |
553 eT* r_mean_mem = x.r_mean.memptr(); | |
554 eT* min_val_mem = x.min_val.memptr(); | |
555 eT* max_val_mem = x.max_val.memptr(); | |
556 T* min_val_norm_mem = x.min_val_norm.memptr(); | |
557 T* max_val_norm_mem = x.max_val_norm.memptr(); | |
558 | |
559 for(uword i=0; i<n_elem; ++i) | |
560 { | |
561 const eT& val = sample_mem[i]; | |
562 const T val_norm = std::norm(val); | |
563 | |
564 r_mean_mem[i] = val; | |
565 min_val_mem[i] = val; | |
566 max_val_mem[i] = val; | |
567 | |
568 min_val_norm_mem[i] = val_norm; | |
569 max_val_norm_mem[i] = val_norm; | |
570 } | |
571 } | |
572 | |
573 x.counter++; | |
574 } | |
575 | |
576 | |
577 | |
578 //! @} |