max@0
|
1 // Copyright (C) 2009-2011 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2009-2011 Conrad Sanderson
|
max@0
|
3 // Copyright (C) 2009-2010 Dimitrios Bouzas
|
max@0
|
4 // Copyright (C) 2011 Stanislav Funiak
|
max@0
|
5 //
|
max@0
|
6 // This file is part of the Armadillo C++ library.
|
max@0
|
7 // It is provided without any warranty of fitness
|
max@0
|
8 // for any purpose. You can redistribute this file
|
max@0
|
9 // and/or modify it under the terms of the GNU
|
max@0
|
10 // Lesser General Public License (LGPL) as published
|
max@0
|
11 // by the Free Software Foundation, either version 3
|
max@0
|
12 // of the License or (at your option) any later version.
|
max@0
|
13 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
14
|
max@0
|
15
|
max@0
|
16
|
max@0
|
17 //! \addtogroup op_pinv
|
max@0
|
18 //! @{
|
max@0
|
19
|
max@0
|
20
|
max@0
|
21
|
max@0
|
22 template<typename eT>
|
max@0
|
23 inline
|
max@0
|
24 void
|
max@0
|
25 op_pinv::direct_pinv(Mat<eT>& out, const Mat<eT>& A, const eT in_tol)
|
max@0
|
26 {
|
max@0
|
27 arma_extra_debug_sigprint();
|
max@0
|
28
|
max@0
|
29 typedef typename get_pod_type<eT>::result T;
|
max@0
|
30
|
max@0
|
31 T tol = access::tmp_real(in_tol);
|
max@0
|
32
|
max@0
|
33 arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0");
|
max@0
|
34
|
max@0
|
35 const uword n_rows = A.n_rows;
|
max@0
|
36 const uword n_cols = A.n_cols;
|
max@0
|
37
|
max@0
|
38 // economical SVD decomposition
|
max@0
|
39 Mat<eT> U;
|
max@0
|
40 Col< T> s;
|
max@0
|
41 Mat<eT> V;
|
max@0
|
42
|
max@0
|
43 const bool status = (n_cols > n_rows) ? auxlib::svd_econ(U,s,V,trans(A),'b') : auxlib::svd_econ(U,s,V,A,'b');
|
max@0
|
44
|
max@0
|
45 if(status == false)
|
max@0
|
46 {
|
max@0
|
47 out.reset();
|
max@0
|
48 arma_bad("pinv(): svd failed");
|
max@0
|
49 return;
|
max@0
|
50 }
|
max@0
|
51
|
max@0
|
52 const uword s_n_elem = s.n_elem;
|
max@0
|
53 const T* s_mem = s.memptr();
|
max@0
|
54
|
max@0
|
55 // set tolerance to default if it hasn't been specified as an argument
|
max@0
|
56 if( (tol == T(0)) && (s_n_elem > 0) )
|
max@0
|
57 {
|
max@0
|
58 tol = (std::max)(n_rows, n_cols) * eop_aux::direct_eps( op_max::direct_max(s_mem, s_n_elem) );
|
max@0
|
59 }
|
max@0
|
60
|
max@0
|
61
|
max@0
|
62 // count non zero valued elements in s
|
max@0
|
63
|
max@0
|
64 uword count = 0;
|
max@0
|
65
|
max@0
|
66 for(uword i = 0; i < s_n_elem; ++i)
|
max@0
|
67 {
|
max@0
|
68 if(s_mem[i] > tol)
|
max@0
|
69 {
|
max@0
|
70 ++count;
|
max@0
|
71 }
|
max@0
|
72 }
|
max@0
|
73
|
max@0
|
74 if(count != 0)
|
max@0
|
75 {
|
max@0
|
76 Col<T> s2(count);
|
max@0
|
77
|
max@0
|
78 T* s2_mem = s2.memptr();
|
max@0
|
79
|
max@0
|
80 uword count2 = 0;
|
max@0
|
81
|
max@0
|
82 for(uword i=0; i < s_n_elem; ++i)
|
max@0
|
83 {
|
max@0
|
84 const T val = s_mem[i];
|
max@0
|
85
|
max@0
|
86 if(val > tol)
|
max@0
|
87 {
|
max@0
|
88 s2_mem[count2] = T(1) / val;
|
max@0
|
89 ++count2;
|
max@0
|
90 }
|
max@0
|
91 }
|
max@0
|
92
|
max@0
|
93
|
max@0
|
94 if(n_rows >= n_cols)
|
max@0
|
95 {
|
max@0
|
96 out = ( V.n_cols > count ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( U.n_cols > count ? U.cols(0,count-1) : U );
|
max@0
|
97 }
|
max@0
|
98 else
|
max@0
|
99 {
|
max@0
|
100 out = ( U.n_cols > count ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( V.n_cols > count ? V.cols(0,count-1) : V );
|
max@0
|
101 }
|
max@0
|
102 }
|
max@0
|
103 else
|
max@0
|
104 {
|
max@0
|
105 out.zeros(n_cols, n_rows);
|
max@0
|
106 }
|
max@0
|
107 }
|
max@0
|
108
|
max@0
|
109
|
max@0
|
110
|
max@0
|
111 template<typename T1>
|
max@0
|
112 inline
|
max@0
|
113 void
|
max@0
|
114 op_pinv::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_pinv>& in)
|
max@0
|
115 {
|
max@0
|
116 arma_extra_debug_sigprint();
|
max@0
|
117
|
max@0
|
118 typedef typename T1::elem_type eT;
|
max@0
|
119
|
max@0
|
120 const unwrap<T1> tmp(in.m);
|
max@0
|
121 const Mat<eT>& A = tmp.M;
|
max@0
|
122
|
max@0
|
123 op_pinv::direct_pinv(out, A, in.aux);
|
max@0
|
124 }
|
max@0
|
125
|
max@0
|
126
|
max@0
|
127
|
max@0
|
128 //! @}
|