max@0
|
1 // Copyright (C) 2009-2011 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2009-2011 Conrad Sanderson
|
max@0
|
3 //
|
max@0
|
4 // This file is part of the Armadillo C++ library.
|
max@0
|
5 // It is provided without any warranty of fitness
|
max@0
|
6 // for any purpose. You can redistribute this file
|
max@0
|
7 // and/or modify it under the terms of the GNU
|
max@0
|
8 // Lesser General Public License (LGPL) as published
|
max@0
|
9 // by the Free Software Foundation, either version 3
|
max@0
|
10 // of the License or (at your option) any later version.
|
max@0
|
11 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
12
|
max@0
|
13
|
max@0
|
14 //! \addtogroup op_var
|
max@0
|
15 //! @{
|
max@0
|
16
|
max@0
|
17
|
max@0
|
18 //! find the variance of an array
|
max@0
|
19 template<typename eT>
|
max@0
|
20 inline
|
max@0
|
21 eT
|
max@0
|
22 op_var::direct_var(const eT* const X, const uword n_elem, const uword norm_type)
|
max@0
|
23 {
|
max@0
|
24 arma_extra_debug_sigprint();
|
max@0
|
25
|
max@0
|
26 if(n_elem >= 2)
|
max@0
|
27 {
|
max@0
|
28 const eT acc1 = op_mean::direct_mean(X, n_elem);
|
max@0
|
29
|
max@0
|
30 eT acc2 = eT(0);
|
max@0
|
31 eT acc3 = eT(0);
|
max@0
|
32
|
max@0
|
33 uword i,j;
|
max@0
|
34
|
max@0
|
35 for(i=0, j=1; j<n_elem; i+=2, j+=2)
|
max@0
|
36 {
|
max@0
|
37 const eT Xi = X[i];
|
max@0
|
38 const eT Xj = X[j];
|
max@0
|
39
|
max@0
|
40 const eT tmpi = acc1 - Xi;
|
max@0
|
41 const eT tmpj = acc1 - Xj;
|
max@0
|
42
|
max@0
|
43 acc2 += tmpi*tmpi + tmpj*tmpj;
|
max@0
|
44 acc3 += tmpi + tmpj;
|
max@0
|
45 }
|
max@0
|
46
|
max@0
|
47 if(i < n_elem)
|
max@0
|
48 {
|
max@0
|
49 const eT Xi = X[i];
|
max@0
|
50
|
max@0
|
51 const eT tmpi = acc1 - Xi;
|
max@0
|
52
|
max@0
|
53 acc2 += tmpi*tmpi;
|
max@0
|
54 acc3 += tmpi;
|
max@0
|
55 }
|
max@0
|
56
|
max@0
|
57 const eT norm_val = (norm_type == 0) ? eT(n_elem-1) : eT(n_elem);
|
max@0
|
58 const eT var_val = (acc2 - acc3*acc3/eT(n_elem)) / norm_val;
|
max@0
|
59
|
max@0
|
60 return arma_isfinite(var_val) ? var_val : op_var::direct_var_robust(X, n_elem, norm_type);
|
max@0
|
61 }
|
max@0
|
62 else
|
max@0
|
63 {
|
max@0
|
64 return eT(0);
|
max@0
|
65 }
|
max@0
|
66 }
|
max@0
|
67
|
max@0
|
68
|
max@0
|
69
|
max@0
|
70 //! find the variance of an array (version for complex numbers)
|
max@0
|
71 template<typename T>
|
max@0
|
72 inline
|
max@0
|
73 T
|
max@0
|
74 op_var::direct_var(const std::complex<T>* const X, const uword n_elem, const uword norm_type)
|
max@0
|
75 {
|
max@0
|
76 arma_extra_debug_sigprint();
|
max@0
|
77
|
max@0
|
78 typedef typename std::complex<T> eT;
|
max@0
|
79
|
max@0
|
80 if(n_elem >= 2)
|
max@0
|
81 {
|
max@0
|
82 const eT acc1 = op_mean::direct_mean(X, n_elem);
|
max@0
|
83
|
max@0
|
84 T acc2 = T(0);
|
max@0
|
85 eT acc3 = eT(0);
|
max@0
|
86
|
max@0
|
87 for(uword i=0; i<n_elem; ++i)
|
max@0
|
88 {
|
max@0
|
89 const eT tmp = acc1 - X[i];
|
max@0
|
90
|
max@0
|
91 acc2 += std::norm(tmp);
|
max@0
|
92 acc3 += tmp;
|
max@0
|
93 }
|
max@0
|
94
|
max@0
|
95 const T norm_val = (norm_type == 0) ? T(n_elem-1) : T(n_elem);
|
max@0
|
96 const T var_val = (acc2 - std::norm(acc3)/T(n_elem)) / norm_val;
|
max@0
|
97
|
max@0
|
98 return arma_isfinite(var_val) ? var_val : op_var::direct_var_robust(X, n_elem, norm_type);
|
max@0
|
99 }
|
max@0
|
100 else
|
max@0
|
101 {
|
max@0
|
102 return T(0);
|
max@0
|
103 }
|
max@0
|
104 }
|
max@0
|
105
|
max@0
|
106
|
max@0
|
107
|
max@0
|
108 //! find the variance of a subview_row
|
max@0
|
109 template<typename eT>
|
max@0
|
110 inline
|
max@0
|
111 typename get_pod_type<eT>::result
|
max@0
|
112 op_var::direct_var(const subview_row<eT>& X, const uword norm_type)
|
max@0
|
113 {
|
max@0
|
114 arma_extra_debug_sigprint();
|
max@0
|
115
|
max@0
|
116 const uword n_elem = X.n_elem;
|
max@0
|
117
|
max@0
|
118 podarray<eT> tmp(n_elem);
|
max@0
|
119
|
max@0
|
120 eT* tmp_mem = tmp.memptr();
|
max@0
|
121
|
max@0
|
122 for(uword i=0; i<n_elem; ++i)
|
max@0
|
123 {
|
max@0
|
124 tmp_mem[i] = X[i];
|
max@0
|
125 }
|
max@0
|
126
|
max@0
|
127 return op_var::direct_var(tmp_mem, n_elem, norm_type);
|
max@0
|
128 }
|
max@0
|
129
|
max@0
|
130
|
max@0
|
131
|
max@0
|
132 //! find the variance of a subview_col
|
max@0
|
133 template<typename eT>
|
max@0
|
134 inline
|
max@0
|
135 typename get_pod_type<eT>::result
|
max@0
|
136 op_var::direct_var(const subview_col<eT>& X, const uword norm_type)
|
max@0
|
137 {
|
max@0
|
138 arma_extra_debug_sigprint();
|
max@0
|
139
|
max@0
|
140 return op_var::direct_var(X.colptr(0), X.n_elem, norm_type);
|
max@0
|
141 }
|
max@0
|
142
|
max@0
|
143
|
max@0
|
144
|
max@0
|
145 //! find the variance of a diagview
|
max@0
|
146 template<typename eT>
|
max@0
|
147 inline
|
max@0
|
148 typename get_pod_type<eT>::result
|
max@0
|
149 op_var::direct_var(const diagview<eT>& X, const uword norm_type)
|
max@0
|
150 {
|
max@0
|
151 arma_extra_debug_sigprint();
|
max@0
|
152
|
max@0
|
153 const uword n_elem = X.n_elem;
|
max@0
|
154
|
max@0
|
155 podarray<eT> tmp(n_elem);
|
max@0
|
156
|
max@0
|
157 eT* tmp_mem = tmp.memptr();
|
max@0
|
158
|
max@0
|
159 for(uword i=0; i<n_elem; ++i)
|
max@0
|
160 {
|
max@0
|
161 tmp_mem[i] = X[i];
|
max@0
|
162 }
|
max@0
|
163
|
max@0
|
164 return op_var::direct_var(tmp_mem, n_elem, norm_type);
|
max@0
|
165 }
|
max@0
|
166
|
max@0
|
167
|
max@0
|
168
|
max@0
|
169 //! \brief
|
max@0
|
170 //! For each row or for each column, find the variance.
|
max@0
|
171 //! The result is stored in a dense matrix that has either one column or one row.
|
max@0
|
172 //! The dimension, for which the variances are found, is set via the var() function.
|
max@0
|
173 template<typename T1>
|
max@0
|
174 inline
|
max@0
|
175 void
|
max@0
|
176 op_var::apply(Mat<typename T1::pod_type>& out, const mtOp<typename T1::pod_type, T1, op_var>& in)
|
max@0
|
177 {
|
max@0
|
178 arma_extra_debug_sigprint();
|
max@0
|
179
|
max@0
|
180 typedef typename T1::elem_type in_eT;
|
max@0
|
181 typedef typename T1::pod_type out_eT;
|
max@0
|
182
|
max@0
|
183 const unwrap_check_mixed<T1> tmp(in.m, out);
|
max@0
|
184 const Mat<in_eT>& X = tmp.M;
|
max@0
|
185
|
max@0
|
186 const uword norm_type = in.aux_uword_a;
|
max@0
|
187 const uword dim = in.aux_uword_b;
|
max@0
|
188
|
max@0
|
189 arma_debug_check( (norm_type > 1), "var(): incorrect usage. norm_type must be 0 or 1");
|
max@0
|
190 arma_debug_check( (dim > 1), "var(): incorrect usage. dim must be 0 or 1" );
|
max@0
|
191
|
max@0
|
192 const uword X_n_rows = X.n_rows;
|
max@0
|
193 const uword X_n_cols = X.n_cols;
|
max@0
|
194
|
max@0
|
195 if(dim == 0)
|
max@0
|
196 {
|
max@0
|
197 arma_extra_debug_print("op_var::apply(), dim = 0");
|
max@0
|
198
|
max@0
|
199 arma_debug_check( (X_n_rows == 0), "var(): given object has zero rows" );
|
max@0
|
200
|
max@0
|
201 out.set_size(1, X_n_cols);
|
max@0
|
202
|
max@0
|
203 out_eT* out_mem = out.memptr();
|
max@0
|
204
|
max@0
|
205 for(uword col=0; col<X_n_cols; ++col)
|
max@0
|
206 {
|
max@0
|
207 out_mem[col] = op_var::direct_var( X.colptr(col), X_n_rows, norm_type );
|
max@0
|
208 }
|
max@0
|
209 }
|
max@0
|
210 else
|
max@0
|
211 if(dim == 1)
|
max@0
|
212 {
|
max@0
|
213 arma_extra_debug_print("op_var::apply(), dim = 1");
|
max@0
|
214
|
max@0
|
215 arma_debug_check( (X_n_cols == 0), "var(): given object has zero columns" );
|
max@0
|
216
|
max@0
|
217 out.set_size(X_n_rows, 1);
|
max@0
|
218
|
max@0
|
219 podarray<in_eT> tmp(X_n_cols);
|
max@0
|
220
|
max@0
|
221 in_eT* tmp_mem = tmp.memptr();
|
max@0
|
222 out_eT* out_mem = out.memptr();
|
max@0
|
223
|
max@0
|
224 for(uword row=0; row<X_n_rows; ++row)
|
max@0
|
225 {
|
max@0
|
226 tmp.copy_row(X, row);
|
max@0
|
227
|
max@0
|
228 out_mem[row] = op_var::direct_var( tmp_mem, X_n_cols, norm_type );
|
max@0
|
229 }
|
max@0
|
230 }
|
max@0
|
231 }
|
max@0
|
232
|
max@0
|
233
|
max@0
|
234
|
max@0
|
235 //! find the variance of an array (robust but slow)
|
max@0
|
236 template<typename eT>
|
max@0
|
237 inline
|
max@0
|
238 eT
|
max@0
|
239 op_var::direct_var_robust(const eT* const X, const uword n_elem, const uword norm_type)
|
max@0
|
240 {
|
max@0
|
241 arma_extra_debug_sigprint();
|
max@0
|
242
|
max@0
|
243 if(n_elem > 1)
|
max@0
|
244 {
|
max@0
|
245 eT r_mean = X[0];
|
max@0
|
246 eT r_var = eT(0);
|
max@0
|
247
|
max@0
|
248 for(uword i=1; i<n_elem; ++i)
|
max@0
|
249 {
|
max@0
|
250 const eT tmp = X[i] - r_mean;
|
max@0
|
251 const eT i_plus_1 = eT(i+1);
|
max@0
|
252
|
max@0
|
253 r_var = eT(i-1)/eT(i) * r_var + (tmp*tmp)/i_plus_1;
|
max@0
|
254
|
max@0
|
255 r_mean = r_mean + tmp/i_plus_1;
|
max@0
|
256 }
|
max@0
|
257
|
max@0
|
258 return (norm_type == 0) ? r_var : (eT(n_elem-1)/eT(n_elem)) * r_var;
|
max@0
|
259 }
|
max@0
|
260 else
|
max@0
|
261 {
|
max@0
|
262 return eT(0);
|
max@0
|
263 }
|
max@0
|
264 }
|
max@0
|
265
|
max@0
|
266
|
max@0
|
267
|
max@0
|
268 //! find the variance of an array (version for complex numbers) (robust but slow)
|
max@0
|
269 template<typename T>
|
max@0
|
270 inline
|
max@0
|
271 T
|
max@0
|
272 op_var::direct_var_robust(const std::complex<T>* const X, const uword n_elem, const uword norm_type)
|
max@0
|
273 {
|
max@0
|
274 arma_extra_debug_sigprint();
|
max@0
|
275
|
max@0
|
276 typedef typename std::complex<T> eT;
|
max@0
|
277
|
max@0
|
278 if(n_elem > 1)
|
max@0
|
279 {
|
max@0
|
280 eT r_mean = X[0];
|
max@0
|
281 T r_var = T(0);
|
max@0
|
282
|
max@0
|
283 for(uword i=1; i<n_elem; ++i)
|
max@0
|
284 {
|
max@0
|
285 const eT tmp = X[i] - r_mean;
|
max@0
|
286 const T i_plus_1 = T(i+1);
|
max@0
|
287
|
max@0
|
288 r_var = T(i-1)/T(i) * r_var + std::norm(tmp)/i_plus_1;
|
max@0
|
289
|
max@0
|
290 r_mean = r_mean + tmp/i_plus_1;
|
max@0
|
291 }
|
max@0
|
292
|
max@0
|
293 return (norm_type == 0) ? r_var : (T(n_elem-1)/T(n_elem)) * r_var;
|
max@0
|
294 }
|
max@0
|
295 else
|
max@0
|
296 {
|
max@0
|
297 return T(0);
|
max@0
|
298 }
|
max@0
|
299 }
|
max@0
|
300
|
max@0
|
301
|
max@0
|
302
|
max@0
|
303 //! @}
|
max@0
|
304
|