Mercurial > hg > segmenter-vamp-plugin
view armadillo-3.900.4/include/armadillo_bits/running_stat_vec_meat.hpp @ 84:55a047986812 tip
Update library URI so as not to be document-local
author | Chris Cannam |
---|---|
date | Wed, 22 Apr 2020 14:21:57 +0100 |
parents | 1ec0e2823891 |
children |
line wrap: on
line source
// Copyright (C) 2009-2011 NICTA (www.nicta.com.au) // Copyright (C) 2009-2011 Conrad Sanderson // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. //! \addtogroup running_stat_vec //! @{ template<typename eT> running_stat_vec<eT>::~running_stat_vec() { arma_extra_debug_sigprint_this(this); } template<typename eT> running_stat_vec<eT>::running_stat_vec(const bool in_calc_cov) : calc_cov(in_calc_cov) { arma_extra_debug_sigprint_this(this); } template<typename eT> running_stat_vec<eT>::running_stat_vec(const running_stat_vec<eT>& in_rsv) : calc_cov (in_rsv.calc_cov) , counter (in_rsv.counter) , r_mean (in_rsv.r_mean) , r_var (in_rsv.r_var) , r_cov (in_rsv.r_cov) , min_val (in_rsv.min_val) , max_val (in_rsv.max_val) , min_val_norm(in_rsv.min_val_norm) , max_val_norm(in_rsv.max_val_norm) { arma_extra_debug_sigprint_this(this); } template<typename eT> const running_stat_vec<eT>& running_stat_vec<eT>::operator=(const running_stat_vec<eT>& in_rsv) { arma_extra_debug_sigprint(); access::rw(calc_cov) = in_rsv.calc_cov; counter = in_rsv.counter; r_mean = in_rsv.r_mean; r_var = in_rsv.r_var; r_cov = in_rsv.r_cov; min_val = in_rsv.min_val; max_val = in_rsv.max_val; min_val_norm = in_rsv.min_val_norm; max_val_norm = in_rsv.max_val_norm; return *this; } //! update statistics to reflect new sample template<typename eT> template<typename T1> arma_hot inline void running_stat_vec<eT>::operator() (const Base<typename get_pod_type<eT>::result, T1>& X) { arma_extra_debug_sigprint(); //typedef typename get_pod_type<eT>::result T; const unwrap<T1> tmp(X.get_ref()); const Mat<eT>& sample = tmp.M; if( sample.is_empty() ) { return; } if( sample.is_finite() == false ) { arma_warn(true, "running_stat_vec: sample ignored as it has non-finite elements"); return; } running_stat_vec_aux::update_stats(*this, sample); } //! update statistics to reflect new sample (version for complex numbers) template<typename eT> template<typename T1> arma_hot inline void running_stat_vec<eT>::operator() (const Base<std::complex<typename get_pod_type<eT>::result>, T1>& X) { arma_extra_debug_sigprint(); //typedef typename std::complex<typename get_pod_type<eT>::result> eT; const unwrap<T1> tmp(X.get_ref()); const Mat<eT>& sample = tmp.M; if( sample.is_empty() ) { return; } if( sample.is_finite() == false ) { arma_warn(true, "running_stat_vec: sample ignored as it has non-finite elements"); return; } running_stat_vec_aux::update_stats(*this, sample); } //! set all statistics to zero template<typename eT> inline void running_stat_vec<eT>::reset() { arma_extra_debug_sigprint(); counter.reset(); r_mean.reset(); r_var.reset(); r_cov.reset(); min_val.reset(); max_val.reset(); min_val_norm.reset(); max_val_norm.reset(); r_var_dummy.reset(); r_cov_dummy.reset(); tmp1.reset(); tmp2.reset(); } //! mean or average value template<typename eT> inline const Mat<eT>& running_stat_vec<eT>::mean() const { arma_extra_debug_sigprint(); return r_mean; } //! variance template<typename eT> inline const Mat<typename get_pod_type<eT>::result>& running_stat_vec<eT>::var(const uword norm_type) { arma_extra_debug_sigprint(); const T N = counter.value(); if(N > T(1)) { if(norm_type == 0) { return r_var; } else { const T N_minus_1 = counter.value_minus_1(); r_var_dummy = (N_minus_1/N) * r_var; return r_var_dummy; } } else { r_var_dummy.zeros(r_mean.n_rows, r_mean.n_cols); return r_var_dummy; } } //! standard deviation template<typename eT> inline Mat<typename get_pod_type<eT>::result> running_stat_vec<eT>::stddev(const uword norm_type) const { arma_extra_debug_sigprint(); const T N = counter.value(); if(N > T(1)) { if(norm_type == 0) { return sqrt(r_var); } else { const T N_minus_1 = counter.value_minus_1(); return sqrt( (N_minus_1/N) * r_var ); } } else { return Mat<T>(); } } //! covariance template<typename eT> inline const Mat<eT>& running_stat_vec<eT>::cov(const uword norm_type) { arma_extra_debug_sigprint(); if(calc_cov == true) { const T N = counter.value(); if(N > T(1)) { if(norm_type == 0) { return r_cov; } else { const T N_minus_1 = counter.value_minus_1(); r_cov_dummy = (N_minus_1/N) * r_cov; return r_cov_dummy; } } else { r_cov_dummy.zeros(r_mean.n_rows, r_mean.n_cols); return r_cov_dummy; } } else { r_cov_dummy.reset(); return r_cov_dummy; } } //! vector with minimum values template<typename eT> inline const Mat<eT>& running_stat_vec<eT>::min() const { arma_extra_debug_sigprint(); return min_val; } //! vector with maximum values template<typename eT> inline const Mat<eT>& running_stat_vec<eT>::max() const { arma_extra_debug_sigprint(); return max_val; } //! number of samples so far template<typename eT> inline typename get_pod_type<eT>::result running_stat_vec<eT>::count() const { arma_extra_debug_sigprint(); return counter.value(); } // //! update statistics to reflect new sample template<typename eT> inline void running_stat_vec_aux::update_stats(running_stat_vec<eT>& x, const Mat<eT>& sample) { arma_extra_debug_sigprint(); typedef typename running_stat_vec<eT>::T T; const T N = x.counter.value(); if(N > T(0)) { arma_debug_assert_same_size(x.r_mean, sample, "running_stat_vec(): dimensionality mismatch"); const uword n_elem = sample.n_elem; const eT* sample_mem = sample.memptr(); eT* r_mean_mem = x.r_mean.memptr(); T* r_var_mem = x.r_var.memptr(); eT* min_val_mem = x.min_val.memptr(); eT* max_val_mem = x.max_val.memptr(); const T N_plus_1 = x.counter.value_plus_1(); const T N_minus_1 = x.counter.value_minus_1(); if(x.calc_cov == true) { Mat<eT>& tmp1 = x.tmp1; Mat<eT>& tmp2 = x.tmp2; tmp1 = sample - x.r_mean; if(sample.n_cols == 1) { tmp2 = tmp1*trans(tmp1); } else { tmp2 = trans(tmp1)*tmp1; } x.r_cov *= (N_minus_1/N); x.r_cov += tmp2 / N_plus_1; } for(uword i=0; i<n_elem; ++i) { const eT val = sample_mem[i]; if(val < min_val_mem[i]) { min_val_mem[i] = val; } if(val > max_val_mem[i]) { max_val_mem[i] = val; } const eT r_mean_val = r_mean_mem[i]; const eT tmp = val - r_mean_val; r_var_mem[i] = N_minus_1/N * r_var_mem[i] + (tmp*tmp)/N_plus_1; r_mean_mem[i] = r_mean_val + (val - r_mean_val)/N_plus_1; } } else { arma_debug_check( (sample.is_vec() == false), "running_stat_vec(): given sample is not a vector"); x.r_mean.set_size(sample.n_rows, sample.n_cols); x.r_var.zeros(sample.n_rows, sample.n_cols); if(x.calc_cov == true) { x.r_cov.zeros(sample.n_elem, sample.n_elem); } x.min_val.set_size(sample.n_rows, sample.n_cols); x.max_val.set_size(sample.n_rows, sample.n_cols); const uword n_elem = sample.n_elem; const eT* sample_mem = sample.memptr(); eT* r_mean_mem = x.r_mean.memptr(); eT* min_val_mem = x.min_val.memptr(); eT* max_val_mem = x.max_val.memptr(); for(uword i=0; i<n_elem; ++i) { const eT val = sample_mem[i]; r_mean_mem[i] = val; min_val_mem[i] = val; max_val_mem[i] = val; } } x.counter++; } //! update statistics to reflect new sample (version for complex numbers) template<typename T> inline void running_stat_vec_aux::update_stats(running_stat_vec< std::complex<T> >& x, const Mat<T>& sample) { arma_extra_debug_sigprint(); const Mat< std::complex<T> > tmp = conv_to< Mat< std::complex<T> > >::from(sample); running_stat_vec_aux::update_stats(x, tmp); } //! alter statistics to reflect new sample (version for complex numbers) template<typename T> inline void running_stat_vec_aux::update_stats(running_stat_vec< std::complex<T> >& x, const Mat< std::complex<T> >& sample) { arma_extra_debug_sigprint(); typedef typename std::complex<T> eT; const T N = x.counter.value(); if(N > T(0)) { arma_debug_assert_same_size(x.r_mean, sample, "running_stat_vec(): dimensionality mismatch"); const uword n_elem = sample.n_elem; const eT* sample_mem = sample.memptr(); eT* r_mean_mem = x.r_mean.memptr(); T* r_var_mem = x.r_var.memptr(); eT* min_val_mem = x.min_val.memptr(); eT* max_val_mem = x.max_val.memptr(); T* min_val_norm_mem = x.min_val_norm.memptr(); T* max_val_norm_mem = x.max_val_norm.memptr(); const T N_plus_1 = x.counter.value_plus_1(); const T N_minus_1 = x.counter.value_minus_1(); if(x.calc_cov == true) { Mat<eT>& tmp1 = x.tmp1; Mat<eT>& tmp2 = x.tmp2; tmp1 = sample - x.r_mean; if(sample.n_cols == 1) { tmp2 = arma::conj(tmp1)*strans(tmp1); } else { tmp2 = trans(tmp1)*tmp1; //tmp2 = strans(conj(tmp1))*tmp1; } x.r_cov *= (N_minus_1/N); x.r_cov += tmp2 / N_plus_1; } for(uword i=0; i<n_elem; ++i) { const eT& val = sample_mem[i]; const T val_norm = std::norm(val); if(val_norm < min_val_norm_mem[i]) { min_val_norm_mem[i] = val_norm; min_val_mem[i] = val; } if(val_norm > max_val_norm_mem[i]) { max_val_norm_mem[i] = val_norm; max_val_mem[i] = val; } const eT& r_mean_val = r_mean_mem[i]; r_var_mem[i] = N_minus_1/N * r_var_mem[i] + std::norm(val - r_mean_val)/N_plus_1; r_mean_mem[i] = r_mean_val + (val - r_mean_val)/N_plus_1; } } else { arma_debug_check( (sample.is_vec() == false), "running_stat_vec(): given sample is not a vector"); x.r_mean.set_size(sample.n_rows, sample.n_cols); x.r_var.zeros(sample.n_rows, sample.n_cols); if(x.calc_cov == true) { x.r_cov.zeros(sample.n_elem, sample.n_elem); } x.min_val.set_size(sample.n_rows, sample.n_cols); x.max_val.set_size(sample.n_rows, sample.n_cols); x.min_val_norm.set_size(sample.n_rows, sample.n_cols); x.max_val_norm.set_size(sample.n_rows, sample.n_cols); const uword n_elem = sample.n_elem; const eT* sample_mem = sample.memptr(); eT* r_mean_mem = x.r_mean.memptr(); eT* min_val_mem = x.min_val.memptr(); eT* max_val_mem = x.max_val.memptr(); T* min_val_norm_mem = x.min_val_norm.memptr(); T* max_val_norm_mem = x.max_val_norm.memptr(); for(uword i=0; i<n_elem; ++i) { const eT& val = sample_mem[i]; const T val_norm = std::norm(val); r_mean_mem[i] = val; min_val_mem[i] = val; max_val_mem[i] = val; min_val_norm_mem[i] = val_norm; max_val_norm_mem[i] = val_norm; } } x.counter++; } //! @}