max@0: // Copyright (C) 2009-2011 NICTA (www.nicta.com.au) max@0: // Copyright (C) 2009-2011 Conrad Sanderson max@0: // Copyright (C) 2009-2010 Dimitrios Bouzas max@0: // Copyright (C) 2011 Stanislav Funiak max@0: // max@0: // This file is part of the Armadillo C++ library. max@0: // It is provided without any warranty of fitness max@0: // for any purpose. You can redistribute this file max@0: // and/or modify it under the terms of the GNU max@0: // Lesser General Public License (LGPL) as published max@0: // by the Free Software Foundation, either version 3 max@0: // of the License or (at your option) any later version. max@0: // (see http://www.opensource.org/licenses for more info) max@0: max@0: max@0: max@0: //! \addtogroup op_pinv max@0: //! @{ max@0: max@0: max@0: max@0: template max@0: inline max@0: void max@0: op_pinv::direct_pinv(Mat& out, const Mat& A, const eT in_tol) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename get_pod_type::result T; max@0: max@0: T tol = access::tmp_real(in_tol); max@0: max@0: arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0"); max@0: max@0: const uword n_rows = A.n_rows; max@0: const uword n_cols = A.n_cols; max@0: max@0: // economical SVD decomposition max@0: Mat U; max@0: Col< T> s; max@0: Mat V; max@0: max@0: const bool status = (n_cols > n_rows) ? auxlib::svd_econ(U,s,V,trans(A),'b') : auxlib::svd_econ(U,s,V,A,'b'); max@0: max@0: if(status == false) max@0: { max@0: out.reset(); max@0: arma_bad("pinv(): svd failed"); max@0: return; max@0: } max@0: max@0: const uword s_n_elem = s.n_elem; max@0: const T* s_mem = s.memptr(); max@0: max@0: // set tolerance to default if it hasn't been specified as an argument max@0: if( (tol == T(0)) && (s_n_elem > 0) ) max@0: { max@0: tol = (std::max)(n_rows, n_cols) * eop_aux::direct_eps( op_max::direct_max(s_mem, s_n_elem) ); max@0: } max@0: max@0: max@0: // count non zero valued elements in s max@0: max@0: uword count = 0; max@0: max@0: for(uword i = 0; i < s_n_elem; ++i) max@0: { max@0: if(s_mem[i] > tol) max@0: { max@0: ++count; max@0: } max@0: } max@0: max@0: if(count != 0) max@0: { max@0: Col s2(count); max@0: max@0: T* s2_mem = s2.memptr(); max@0: max@0: uword count2 = 0; max@0: max@0: for(uword i=0; i < s_n_elem; ++i) max@0: { max@0: const T val = s_mem[i]; max@0: max@0: if(val > tol) max@0: { max@0: s2_mem[count2] = T(1) / val; max@0: ++count2; max@0: } max@0: } max@0: max@0: max@0: if(n_rows >= n_cols) max@0: { max@0: out = ( V.n_cols > count ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( U.n_cols > count ? U.cols(0,count-1) : U ); max@0: } max@0: else max@0: { max@0: out = ( U.n_cols > count ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( V.n_cols > count ? V.cols(0,count-1) : V ); max@0: } max@0: } max@0: else max@0: { max@0: out.zeros(n_cols, n_rows); max@0: } max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: void max@0: op_pinv::apply(Mat& out, const Op& in) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: const unwrap tmp(in.m); max@0: const Mat& A = tmp.M; max@0: max@0: op_pinv::direct_pinv(out, A, in.aux); max@0: } max@0: max@0: max@0: max@0: //! @}