Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-3.900.4/include/armadillo_bits/op_pinv_meat.hpp @ 49:1ec0e2823891
Switch to using subrepo copies of qm-dsp, nnls-chroma, vamp-plugin-sdk; update Armadillo version; assume build without external BLAS/LAPACK
author | Chris Cannam |
---|---|
date | Thu, 13 Jun 2013 10:25:24 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
48:69251e11a913 | 49:1ec0e2823891 |
---|---|
1 // Copyright (C) 2009-2011 NICTA (www.nicta.com.au) | |
2 // Copyright (C) 2009-2011 Conrad Sanderson | |
3 // Copyright (C) 2009-2010 Dimitrios Bouzas | |
4 // Copyright (C) 2011 Stanislav Funiak | |
5 // | |
6 // This Source Code Form is subject to the terms of the Mozilla Public | |
7 // License, v. 2.0. If a copy of the MPL was not distributed with this | |
8 // file, You can obtain one at http://mozilla.org/MPL/2.0/. | |
9 | |
10 | |
11 | |
12 //! \addtogroup op_pinv | |
13 //! @{ | |
14 | |
15 | |
16 | |
17 template<typename eT> | |
18 inline | |
19 void | |
20 op_pinv::direct_pinv(Mat<eT>& out, const Mat<eT>& A, const eT in_tol) | |
21 { | |
22 arma_extra_debug_sigprint(); | |
23 | |
24 typedef typename get_pod_type<eT>::result T; | |
25 | |
26 T tol = access::tmp_real(in_tol); | |
27 | |
28 arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0"); | |
29 | |
30 const uword n_rows = A.n_rows; | |
31 const uword n_cols = A.n_cols; | |
32 | |
33 // economical SVD decomposition | |
34 Mat<eT> U; | |
35 Col< T> s; | |
36 Mat<eT> V; | |
37 | |
38 const bool status = (n_cols > n_rows) ? auxlib::svd_econ(U,s,V,trans(A),'b') : auxlib::svd_econ(U,s,V,A,'b'); | |
39 | |
40 if(status == false) | |
41 { | |
42 out.reset(); | |
43 arma_bad("pinv(): svd failed"); | |
44 return; | |
45 } | |
46 | |
47 const uword s_n_elem = s.n_elem; | |
48 const T* s_mem = s.memptr(); | |
49 | |
50 // set tolerance to default if it hasn't been specified as an argument | |
51 if( (tol == T(0)) && (s_n_elem > 0) ) | |
52 { | |
53 tol = (std::max)(n_rows, n_cols) * eop_aux::direct_eps( op_max::direct_max(s_mem, s_n_elem) ); | |
54 } | |
55 | |
56 | |
57 // count non zero valued elements in s | |
58 | |
59 uword count = 0; | |
60 | |
61 for(uword i = 0; i < s_n_elem; ++i) | |
62 { | |
63 if(s_mem[i] > tol) | |
64 { | |
65 ++count; | |
66 } | |
67 } | |
68 | |
69 if(count != 0) | |
70 { | |
71 Col<T> s2(count); | |
72 | |
73 T* s2_mem = s2.memptr(); | |
74 | |
75 uword count2 = 0; | |
76 | |
77 for(uword i=0; i < s_n_elem; ++i) | |
78 { | |
79 const T val = s_mem[i]; | |
80 | |
81 if(val > tol) | |
82 { | |
83 s2_mem[count2] = T(1) / val; | |
84 ++count2; | |
85 } | |
86 } | |
87 | |
88 | |
89 if(n_rows >= n_cols) | |
90 { | |
91 out = ( V.n_cols > count ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( U.n_cols > count ? U.cols(0,count-1) : U ); | |
92 } | |
93 else | |
94 { | |
95 out = ( U.n_cols > count ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( V.n_cols > count ? V.cols(0,count-1) : V ); | |
96 } | |
97 } | |
98 else | |
99 { | |
100 out.zeros(n_cols, n_rows); | |
101 } | |
102 } | |
103 | |
104 | |
105 | |
106 template<typename T1> | |
107 inline | |
108 void | |
109 op_pinv::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_pinv>& in) | |
110 { | |
111 arma_extra_debug_sigprint(); | |
112 | |
113 typedef typename T1::elem_type eT; | |
114 | |
115 const unwrap<T1> tmp(in.m); | |
116 const Mat<eT>& A = tmp.M; | |
117 | |
118 op_pinv::direct_pinv(out, A, in.aux); | |
119 } | |
120 | |
121 | |
122 | |
123 //! @} |