Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-2.4.4/include/armadillo_bits/op_pinv_meat.hpp @ 0:8b6102e2a9b0
Armadillo Library
author | maxzanoni76 <max.zanoni@eecs.qmul.ac.uk> |
---|---|
date | Wed, 11 Apr 2012 09:27:06 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:8b6102e2a9b0 |
---|---|
1 // Copyright (C) 2009-2011 NICTA (www.nicta.com.au) | |
2 // Copyright (C) 2009-2011 Conrad Sanderson | |
3 // Copyright (C) 2009-2010 Dimitrios Bouzas | |
4 // Copyright (C) 2011 Stanislav Funiak | |
5 // | |
6 // This file is part of the Armadillo C++ library. | |
7 // It is provided without any warranty of fitness | |
8 // for any purpose. You can redistribute this file | |
9 // and/or modify it under the terms of the GNU | |
10 // Lesser General Public License (LGPL) as published | |
11 // by the Free Software Foundation, either version 3 | |
12 // of the License or (at your option) any later version. | |
13 // (see http://www.opensource.org/licenses for more info) | |
14 | |
15 | |
16 | |
17 //! \addtogroup op_pinv | |
18 //! @{ | |
19 | |
20 | |
21 | |
22 template<typename eT> | |
23 inline | |
24 void | |
25 op_pinv::direct_pinv(Mat<eT>& out, const Mat<eT>& A, const eT in_tol) | |
26 { | |
27 arma_extra_debug_sigprint(); | |
28 | |
29 typedef typename get_pod_type<eT>::result T; | |
30 | |
31 T tol = access::tmp_real(in_tol); | |
32 | |
33 arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0"); | |
34 | |
35 const uword n_rows = A.n_rows; | |
36 const uword n_cols = A.n_cols; | |
37 | |
38 // economical SVD decomposition | |
39 Mat<eT> U; | |
40 Col< T> s; | |
41 Mat<eT> V; | |
42 | |
43 const bool status = (n_cols > n_rows) ? auxlib::svd_econ(U,s,V,trans(A),'b') : auxlib::svd_econ(U,s,V,A,'b'); | |
44 | |
45 if(status == false) | |
46 { | |
47 out.reset(); | |
48 arma_bad("pinv(): svd failed"); | |
49 return; | |
50 } | |
51 | |
52 const uword s_n_elem = s.n_elem; | |
53 const T* s_mem = s.memptr(); | |
54 | |
55 // set tolerance to default if it hasn't been specified as an argument | |
56 if( (tol == T(0)) && (s_n_elem > 0) ) | |
57 { | |
58 tol = (std::max)(n_rows, n_cols) * eop_aux::direct_eps( op_max::direct_max(s_mem, s_n_elem) ); | |
59 } | |
60 | |
61 | |
62 // count non zero valued elements in s | |
63 | |
64 uword count = 0; | |
65 | |
66 for(uword i = 0; i < s_n_elem; ++i) | |
67 { | |
68 if(s_mem[i] > tol) | |
69 { | |
70 ++count; | |
71 } | |
72 } | |
73 | |
74 if(count != 0) | |
75 { | |
76 Col<T> s2(count); | |
77 | |
78 T* s2_mem = s2.memptr(); | |
79 | |
80 uword count2 = 0; | |
81 | |
82 for(uword i=0; i < s_n_elem; ++i) | |
83 { | |
84 const T val = s_mem[i]; | |
85 | |
86 if(val > tol) | |
87 { | |
88 s2_mem[count2] = T(1) / val; | |
89 ++count2; | |
90 } | |
91 } | |
92 | |
93 | |
94 if(n_rows >= n_cols) | |
95 { | |
96 out = ( V.n_cols > count ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( U.n_cols > count ? U.cols(0,count-1) : U ); | |
97 } | |
98 else | |
99 { | |
100 out = ( U.n_cols > count ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( V.n_cols > count ? V.cols(0,count-1) : V ); | |
101 } | |
102 } | |
103 else | |
104 { | |
105 out.zeros(n_cols, n_rows); | |
106 } | |
107 } | |
108 | |
109 | |
110 | |
111 template<typename T1> | |
112 inline | |
113 void | |
114 op_pinv::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_pinv>& in) | |
115 { | |
116 arma_extra_debug_sigprint(); | |
117 | |
118 typedef typename T1::elem_type eT; | |
119 | |
120 const unwrap<T1> tmp(in.m); | |
121 const Mat<eT>& A = tmp.M; | |
122 | |
123 op_pinv::direct_pinv(out, A, in.aux); | |
124 } | |
125 | |
126 | |
127 | |
128 //! @} |