Chris@49: // Copyright (C) 2009-2011 NICTA (www.nicta.com.au) Chris@49: // Copyright (C) 2009-2011 Conrad Sanderson Chris@49: // Copyright (C) 2009-2010 Dimitrios Bouzas Chris@49: // Copyright (C) 2011 Stanislav Funiak Chris@49: // Chris@49: // This Source Code Form is subject to the terms of the Mozilla Public Chris@49: // License, v. 2.0. If a copy of the MPL was not distributed with this Chris@49: // file, You can obtain one at http://mozilla.org/MPL/2.0/. Chris@49: Chris@49: Chris@49: Chris@49: //! \addtogroup op_pinv Chris@49: //! @{ Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: void Chris@49: op_pinv::direct_pinv(Mat& out, const Mat& A, const eT in_tol) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename get_pod_type::result T; Chris@49: Chris@49: T tol = access::tmp_real(in_tol); Chris@49: Chris@49: arma_debug_check((tol < T(0)), "pinv(): tolerance must be >= 0"); Chris@49: Chris@49: const uword n_rows = A.n_rows; Chris@49: const uword n_cols = A.n_cols; Chris@49: Chris@49: // economical SVD decomposition Chris@49: Mat U; Chris@49: Col< T> s; Chris@49: Mat V; Chris@49: Chris@49: const bool status = (n_cols > n_rows) ? auxlib::svd_econ(U,s,V,trans(A),'b') : auxlib::svd_econ(U,s,V,A,'b'); Chris@49: Chris@49: if(status == false) Chris@49: { Chris@49: out.reset(); Chris@49: arma_bad("pinv(): svd failed"); Chris@49: return; Chris@49: } Chris@49: Chris@49: const uword s_n_elem = s.n_elem; Chris@49: const T* s_mem = s.memptr(); Chris@49: Chris@49: // set tolerance to default if it hasn't been specified as an argument Chris@49: if( (tol == T(0)) && (s_n_elem > 0) ) Chris@49: { Chris@49: tol = (std::max)(n_rows, n_cols) * eop_aux::direct_eps( op_max::direct_max(s_mem, s_n_elem) ); Chris@49: } Chris@49: Chris@49: Chris@49: // count non zero valued elements in s Chris@49: Chris@49: uword count = 0; Chris@49: Chris@49: for(uword i = 0; i < s_n_elem; ++i) Chris@49: { Chris@49: if(s_mem[i] > tol) Chris@49: { Chris@49: ++count; Chris@49: } Chris@49: } Chris@49: Chris@49: if(count != 0) Chris@49: { Chris@49: Col s2(count); Chris@49: Chris@49: T* s2_mem = s2.memptr(); Chris@49: Chris@49: uword count2 = 0; Chris@49: Chris@49: for(uword i=0; i < s_n_elem; ++i) Chris@49: { Chris@49: const T val = s_mem[i]; Chris@49: Chris@49: if(val > tol) Chris@49: { Chris@49: s2_mem[count2] = T(1) / val; Chris@49: ++count2; Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: if(n_rows >= n_cols) Chris@49: { Chris@49: out = ( V.n_cols > count ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( U.n_cols > count ? U.cols(0,count-1) : U ); Chris@49: } Chris@49: else Chris@49: { Chris@49: out = ( U.n_cols > count ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( V.n_cols > count ? V.cols(0,count-1) : V ); Chris@49: } Chris@49: } Chris@49: else Chris@49: { Chris@49: out.zeros(n_cols, n_rows); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: void Chris@49: op_pinv::apply(Mat& out, const Op& in) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const unwrap tmp(in.m); Chris@49: const Mat& A = tmp.M; Chris@49: Chris@49: op_pinv::direct_pinv(out, A, in.aux); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! @}