Chris@49
|
1 // Copyright (C) 2009-2011 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2009-2011 Conrad Sanderson
|
Chris@49
|
3 // Copyright (C) 2009-2010 Dimitrios Bouzas
|
Chris@49
|
4 // Copyright (C) 2011 Stanislav Funiak
|
Chris@49
|
5 //
|
Chris@49
|
6 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
7 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
8 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
9
|
Chris@49
|
10
|
Chris@49
|
11
|
Chris@49
|
12 //! \addtogroup op_pinv
|
Chris@49
|
13 //! @{
|
Chris@49
|
14
|
Chris@49
|
15
|
Chris@49
|
16
|
Chris@49
|
17 template<typename eT>
|
Chris@49
|
18 inline
|
Chris@49
|
19 void
|
Chris@49
|
20 op_pinv::direct_pinv(Mat<eT>& out, const Mat<eT>& A, const eT in_tol)
|
Chris@49
|
21 {
|
Chris@49
|
22 arma_extra_debug_sigprint();
|
Chris@49
|
23
|
Chris@49
|
24 typedef typename get_pod_type<eT>::result T;
|
Chris@49
|
25
|
Chris@49
|
26 T tol = access::tmp_real(in_tol);
|
Chris@49
|
27
|
Chris@49
|
28 arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0");
|
Chris@49
|
29
|
Chris@49
|
30 const uword n_rows = A.n_rows;
|
Chris@49
|
31 const uword n_cols = A.n_cols;
|
Chris@49
|
32
|
Chris@49
|
33 // economical SVD decomposition
|
Chris@49
|
34 Mat<eT> U;
|
Chris@49
|
35 Col< T> s;
|
Chris@49
|
36 Mat<eT> V;
|
Chris@49
|
37
|
Chris@49
|
38 const bool status = (n_cols > n_rows) ? auxlib::svd_econ(U,s,V,trans(A),'b') : auxlib::svd_econ(U,s,V,A,'b');
|
Chris@49
|
39
|
Chris@49
|
40 if(status == false)
|
Chris@49
|
41 {
|
Chris@49
|
42 out.reset();
|
Chris@49
|
43 arma_bad("pinv(): svd failed");
|
Chris@49
|
44 return;
|
Chris@49
|
45 }
|
Chris@49
|
46
|
Chris@49
|
47 const uword s_n_elem = s.n_elem;
|
Chris@49
|
48 const T* s_mem = s.memptr();
|
Chris@49
|
49
|
Chris@49
|
50 // set tolerance to default if it hasn't been specified as an argument
|
Chris@49
|
51 if( (tol == T(0)) && (s_n_elem > 0) )
|
Chris@49
|
52 {
|
Chris@49
|
53 tol = (std::max)(n_rows, n_cols) * eop_aux::direct_eps( op_max::direct_max(s_mem, s_n_elem) );
|
Chris@49
|
54 }
|
Chris@49
|
55
|
Chris@49
|
56
|
Chris@49
|
57 // count non zero valued elements in s
|
Chris@49
|
58
|
Chris@49
|
59 uword count = 0;
|
Chris@49
|
60
|
Chris@49
|
61 for(uword i = 0; i < s_n_elem; ++i)
|
Chris@49
|
62 {
|
Chris@49
|
63 if(s_mem[i] > tol)
|
Chris@49
|
64 {
|
Chris@49
|
65 ++count;
|
Chris@49
|
66 }
|
Chris@49
|
67 }
|
Chris@49
|
68
|
Chris@49
|
69 if(count != 0)
|
Chris@49
|
70 {
|
Chris@49
|
71 Col<T> s2(count);
|
Chris@49
|
72
|
Chris@49
|
73 T* s2_mem = s2.memptr();
|
Chris@49
|
74
|
Chris@49
|
75 uword count2 = 0;
|
Chris@49
|
76
|
Chris@49
|
77 for(uword i=0; i < s_n_elem; ++i)
|
Chris@49
|
78 {
|
Chris@49
|
79 const T val = s_mem[i];
|
Chris@49
|
80
|
Chris@49
|
81 if(val > tol)
|
Chris@49
|
82 {
|
Chris@49
|
83 s2_mem[count2] = T(1) / val;
|
Chris@49
|
84 ++count2;
|
Chris@49
|
85 }
|
Chris@49
|
86 }
|
Chris@49
|
87
|
Chris@49
|
88
|
Chris@49
|
89 if(n_rows >= n_cols)
|
Chris@49
|
90 {
|
Chris@49
|
91 out = ( V.n_cols > count ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( U.n_cols > count ? U.cols(0,count-1) : U );
|
Chris@49
|
92 }
|
Chris@49
|
93 else
|
Chris@49
|
94 {
|
Chris@49
|
95 out = ( U.n_cols > count ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( V.n_cols > count ? V.cols(0,count-1) : V );
|
Chris@49
|
96 }
|
Chris@49
|
97 }
|
Chris@49
|
98 else
|
Chris@49
|
99 {
|
Chris@49
|
100 out.zeros(n_cols, n_rows);
|
Chris@49
|
101 }
|
Chris@49
|
102 }
|
Chris@49
|
103
|
Chris@49
|
104
|
Chris@49
|
105
|
Chris@49
|
106 template<typename T1>
|
Chris@49
|
107 inline
|
Chris@49
|
108 void
|
Chris@49
|
109 op_pinv::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_pinv>& in)
|
Chris@49
|
110 {
|
Chris@49
|
111 arma_extra_debug_sigprint();
|
Chris@49
|
112
|
Chris@49
|
113 typedef typename T1::elem_type eT;
|
Chris@49
|
114
|
Chris@49
|
115 const unwrap<T1> tmp(in.m);
|
Chris@49
|
116 const Mat<eT>& A = tmp.M;
|
Chris@49
|
117
|
Chris@49
|
118 op_pinv::direct_pinv(out, A, in.aux);
|
Chris@49
|
119 }
|
Chris@49
|
120
|
Chris@49
|
121
|
Chris@49
|
122
|
Chris@49
|
123 //! @}
|