comparison armadillo-3.900.4/include/armadillo_bits/op_pinv_meat.hpp @ 49:1ec0e2823891

Switch to using subrepo copies of qm-dsp, nnls-chroma, vamp-plugin-sdk; update Armadillo version; assume build without external BLAS/LAPACK
author Chris Cannam
date Thu, 13 Jun 2013 10:25:24 +0100
parents
children
comparison
equal deleted inserted replaced
48:69251e11a913 49:1ec0e2823891
1 // Copyright (C) 2009-2011 NICTA (www.nicta.com.au)
2 // Copyright (C) 2009-2011 Conrad Sanderson
3 // Copyright (C) 2009-2010 Dimitrios Bouzas
4 // Copyright (C) 2011 Stanislav Funiak
5 //
6 // This Source Code Form is subject to the terms of the Mozilla Public
7 // License, v. 2.0. If a copy of the MPL was not distributed with this
8 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10
11
12 //! \addtogroup op_pinv
13 //! @{
14
15
16
17 template<typename eT>
18 inline
19 void
20 op_pinv::direct_pinv(Mat<eT>& out, const Mat<eT>& A, const eT in_tol)
21 {
22 arma_extra_debug_sigprint();
23
24 typedef typename get_pod_type<eT>::result T;
25
26 T tol = access::tmp_real(in_tol);
27
28 arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0");
29
30 const uword n_rows = A.n_rows;
31 const uword n_cols = A.n_cols;
32
33 // economical SVD decomposition
34 Mat<eT> U;
35 Col< T> s;
36 Mat<eT> V;
37
38 const bool status = (n_cols > n_rows) ? auxlib::svd_econ(U,s,V,trans(A),'b') : auxlib::svd_econ(U,s,V,A,'b');
39
40 if(status == false)
41 {
42 out.reset();
43 arma_bad("pinv(): svd failed");
44 return;
45 }
46
47 const uword s_n_elem = s.n_elem;
48 const T* s_mem = s.memptr();
49
50 // set tolerance to default if it hasn't been specified as an argument
51 if( (tol == T(0)) && (s_n_elem > 0) )
52 {
53 tol = (std::max)(n_rows, n_cols) * eop_aux::direct_eps( op_max::direct_max(s_mem, s_n_elem) );
54 }
55
56
57 // count non zero valued elements in s
58
59 uword count = 0;
60
61 for(uword i = 0; i < s_n_elem; ++i)
62 {
63 if(s_mem[i] > tol)
64 {
65 ++count;
66 }
67 }
68
69 if(count != 0)
70 {
71 Col<T> s2(count);
72
73 T* s2_mem = s2.memptr();
74
75 uword count2 = 0;
76
77 for(uword i=0; i < s_n_elem; ++i)
78 {
79 const T val = s_mem[i];
80
81 if(val > tol)
82 {
83 s2_mem[count2] = T(1) / val;
84 ++count2;
85 }
86 }
87
88
89 if(n_rows >= n_cols)
90 {
91 out = ( V.n_cols > count ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( U.n_cols > count ? U.cols(0,count-1) : U );
92 }
93 else
94 {
95 out = ( U.n_cols > count ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( V.n_cols > count ? V.cols(0,count-1) : V );
96 }
97 }
98 else
99 {
100 out.zeros(n_cols, n_rows);
101 }
102 }
103
104
105
106 template<typename T1>
107 inline
108 void
109 op_pinv::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_pinv>& in)
110 {
111 arma_extra_debug_sigprint();
112
113 typedef typename T1::elem_type eT;
114
115 const unwrap<T1> tmp(in.m);
116 const Mat<eT>& A = tmp.M;
117
118 op_pinv::direct_pinv(out, A, in.aux);
119 }
120
121
122
123 //! @}