annotate armadillo-2.4.4/include/armadillo_bits/op_pinv_meat.hpp @ 0:8b6102e2a9b0

Armadillo Library
author maxzanoni76 <max.zanoni@eecs.qmul.ac.uk>
date Wed, 11 Apr 2012 09:27:06 +0100
parents
children
rev   line source
max@0 1 // Copyright (C) 2009-2011 NICTA (www.nicta.com.au)
max@0 2 // Copyright (C) 2009-2011 Conrad Sanderson
max@0 3 // Copyright (C) 2009-2010 Dimitrios Bouzas
max@0 4 // Copyright (C) 2011 Stanislav Funiak
max@0 5 //
max@0 6 // This file is part of the Armadillo C++ library.
max@0 7 // It is provided without any warranty of fitness
max@0 8 // for any purpose. You can redistribute this file
max@0 9 // and/or modify it under the terms of the GNU
max@0 10 // Lesser General Public License (LGPL) as published
max@0 11 // by the Free Software Foundation, either version 3
max@0 12 // of the License or (at your option) any later version.
max@0 13 // (see http://www.opensource.org/licenses for more info)
max@0 14
max@0 15
max@0 16
max@0 17 //! \addtogroup op_pinv
max@0 18 //! @{
max@0 19
max@0 20
max@0 21
max@0 22 template<typename eT>
max@0 23 inline
max@0 24 void
max@0 25 op_pinv::direct_pinv(Mat<eT>& out, const Mat<eT>& A, const eT in_tol)
max@0 26 {
max@0 27 arma_extra_debug_sigprint();
max@0 28
max@0 29 typedef typename get_pod_type<eT>::result T;
max@0 30
max@0 31 T tol = access::tmp_real(in_tol);
max@0 32
max@0 33 arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0");
max@0 34
max@0 35 const uword n_rows = A.n_rows;
max@0 36 const uword n_cols = A.n_cols;
max@0 37
max@0 38 // economical SVD decomposition
max@0 39 Mat<eT> U;
max@0 40 Col< T> s;
max@0 41 Mat<eT> V;
max@0 42
max@0 43 const bool status = (n_cols > n_rows) ? auxlib::svd_econ(U,s,V,trans(A),'b') : auxlib::svd_econ(U,s,V,A,'b');
max@0 44
max@0 45 if(status == false)
max@0 46 {
max@0 47 out.reset();
max@0 48 arma_bad("pinv(): svd failed");
max@0 49 return;
max@0 50 }
max@0 51
max@0 52 const uword s_n_elem = s.n_elem;
max@0 53 const T* s_mem = s.memptr();
max@0 54
max@0 55 // set tolerance to default if it hasn't been specified as an argument
max@0 56 if( (tol == T(0)) && (s_n_elem > 0) )
max@0 57 {
max@0 58 tol = (std::max)(n_rows, n_cols) * eop_aux::direct_eps( op_max::direct_max(s_mem, s_n_elem) );
max@0 59 }
max@0 60
max@0 61
max@0 62 // count non zero valued elements in s
max@0 63
max@0 64 uword count = 0;
max@0 65
max@0 66 for(uword i = 0; i < s_n_elem; ++i)
max@0 67 {
max@0 68 if(s_mem[i] > tol)
max@0 69 {
max@0 70 ++count;
max@0 71 }
max@0 72 }
max@0 73
max@0 74 if(count != 0)
max@0 75 {
max@0 76 Col<T> s2(count);
max@0 77
max@0 78 T* s2_mem = s2.memptr();
max@0 79
max@0 80 uword count2 = 0;
max@0 81
max@0 82 for(uword i=0; i < s_n_elem; ++i)
max@0 83 {
max@0 84 const T val = s_mem[i];
max@0 85
max@0 86 if(val > tol)
max@0 87 {
max@0 88 s2_mem[count2] = T(1) / val;
max@0 89 ++count2;
max@0 90 }
max@0 91 }
max@0 92
max@0 93
max@0 94 if(n_rows >= n_cols)
max@0 95 {
max@0 96 out = ( V.n_cols > count ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( U.n_cols > count ? U.cols(0,count-1) : U );
max@0 97 }
max@0 98 else
max@0 99 {
max@0 100 out = ( U.n_cols > count ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( V.n_cols > count ? V.cols(0,count-1) : V );
max@0 101 }
max@0 102 }
max@0 103 else
max@0 104 {
max@0 105 out.zeros(n_cols, n_rows);
max@0 106 }
max@0 107 }
max@0 108
max@0 109
max@0 110
max@0 111 template<typename T1>
max@0 112 inline
max@0 113 void
max@0 114 op_pinv::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_pinv>& in)
max@0 115 {
max@0 116 arma_extra_debug_sigprint();
max@0 117
max@0 118 typedef typename T1::elem_type eT;
max@0 119
max@0 120 const unwrap<T1> tmp(in.m);
max@0 121 const Mat<eT>& A = tmp.M;
max@0 122
max@0 123 op_pinv::direct_pinv(out, A, in.aux);
max@0 124 }
max@0 125
max@0 126
max@0 127
max@0 128 //! @}