Chris@49: // Copyright (C) 2008-2013 NICTA (www.nicta.com.au) Chris@49: // Copyright (C) 2008-2013 Conrad Sanderson Chris@49: // Copyright (C) 2009 Edmund Highcock Chris@49: // Copyright (C) 2011 James Sanders Chris@49: // Copyright (C) 2011 Stanislav Funiak Chris@49: // Copyright (C) 2012 Eric Jon Sundstrom Chris@49: // Copyright (C) 2012 Michael McNeil Forbes Chris@49: // Chris@49: // This Source Code Form is subject to the terms of the Mozilla Public Chris@49: // License, v. 2.0. If a copy of the MPL was not distributed with this Chris@49: // file, You can obtain one at http://mozilla.org/MPL/2.0/. Chris@49: Chris@49: Chris@49: //! \addtogroup auxlib Chris@49: //! @{ Chris@49: Chris@49: Chris@49: Chris@49: //! immediate matrix inverse Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::inv(Mat& out, const Base& X, const bool slow) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: out = X.get_ref(); Chris@49: Chris@49: arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" ); Chris@49: Chris@49: bool status = false; Chris@49: Chris@49: const uword N = out.n_rows; Chris@49: Chris@49: if( (N <= 4) && (slow == false) ) Chris@49: { Chris@49: status = auxlib::inv_inplace_tinymat(out, N); Chris@49: } Chris@49: Chris@49: if( (N > 4) || (status == false) ) Chris@49: { Chris@49: status = auxlib::inv_inplace_lapack(out); Chris@49: } Chris@49: Chris@49: return status; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::inv(Mat& out, const Mat& X, const bool slow) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: arma_debug_check( (X.is_square() == false), "inv(): given matrix is not square" ); Chris@49: Chris@49: bool status = false; Chris@49: Chris@49: const uword N = X.n_rows; Chris@49: Chris@49: if( (N <= 4) && (slow == false) ) Chris@49: { Chris@49: status = (&out != &X) ? auxlib::inv_noalias_tinymat(out, X, N) : auxlib::inv_inplace_tinymat(out, N); Chris@49: } Chris@49: Chris@49: if( (N > 4) || (status == false) ) Chris@49: { Chris@49: out = X; Chris@49: status = auxlib::inv_inplace_lapack(out); Chris@49: } Chris@49: Chris@49: return status; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::inv_noalias_tinymat(Mat& out, const Mat& X, const uword N) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: bool det_ok = true; Chris@49: Chris@49: out.set_size(N,N); Chris@49: Chris@49: switch(N) Chris@49: { Chris@49: case 1: Chris@49: { Chris@49: out[0] = eT(1) / X[0]; Chris@49: }; Chris@49: break; Chris@49: Chris@49: case 2: Chris@49: { Chris@49: const eT* Xm = X.memptr(); Chris@49: Chris@49: const eT a = Xm[pos<0,0>::n2]; Chris@49: const eT b = Xm[pos<0,1>::n2]; Chris@49: const eT c = Xm[pos<1,0>::n2]; Chris@49: const eT d = Xm[pos<1,1>::n2]; Chris@49: Chris@49: const eT tmp_det = (a*d - b*c); Chris@49: Chris@49: if(tmp_det != eT(0)) Chris@49: { Chris@49: eT* outm = out.memptr(); Chris@49: Chris@49: outm[pos<0,0>::n2] = d / tmp_det; Chris@49: outm[pos<0,1>::n2] = -b / tmp_det; Chris@49: outm[pos<1,0>::n2] = -c / tmp_det; Chris@49: outm[pos<1,1>::n2] = a / tmp_det; Chris@49: } Chris@49: else Chris@49: { Chris@49: det_ok = false; Chris@49: } Chris@49: }; Chris@49: break; Chris@49: Chris@49: case 3: Chris@49: { Chris@49: const eT* X_col0 = X.colptr(0); Chris@49: const eT a11 = X_col0[0]; Chris@49: const eT a21 = X_col0[1]; Chris@49: const eT a31 = X_col0[2]; Chris@49: Chris@49: const eT* X_col1 = X.colptr(1); Chris@49: const eT a12 = X_col1[0]; Chris@49: const eT a22 = X_col1[1]; Chris@49: const eT a32 = X_col1[2]; Chris@49: Chris@49: const eT* X_col2 = X.colptr(2); Chris@49: const eT a13 = X_col2[0]; Chris@49: const eT a23 = X_col2[1]; Chris@49: const eT a33 = X_col2[2]; Chris@49: Chris@49: const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13); Chris@49: Chris@49: if(tmp_det != eT(0)) Chris@49: { Chris@49: eT* out_col0 = out.colptr(0); Chris@49: out_col0[0] = (a33*a22 - a32*a23) / tmp_det; Chris@49: out_col0[1] = -(a33*a21 - a31*a23) / tmp_det; Chris@49: out_col0[2] = (a32*a21 - a31*a22) / tmp_det; Chris@49: Chris@49: eT* out_col1 = out.colptr(1); Chris@49: out_col1[0] = -(a33*a12 - a32*a13) / tmp_det; Chris@49: out_col1[1] = (a33*a11 - a31*a13) / tmp_det; Chris@49: out_col1[2] = -(a32*a11 - a31*a12) / tmp_det; Chris@49: Chris@49: eT* out_col2 = out.colptr(2); Chris@49: out_col2[0] = (a23*a12 - a22*a13) / tmp_det; Chris@49: out_col2[1] = -(a23*a11 - a21*a13) / tmp_det; Chris@49: out_col2[2] = (a22*a11 - a21*a12) / tmp_det; Chris@49: } Chris@49: else Chris@49: { Chris@49: det_ok = false; Chris@49: } Chris@49: }; Chris@49: break; Chris@49: Chris@49: case 4: Chris@49: { Chris@49: const eT tmp_det = det(X); Chris@49: Chris@49: if(tmp_det != eT(0)) Chris@49: { Chris@49: const eT* Xm = X.memptr(); Chris@49: eT* outm = out.memptr(); Chris@49: Chris@49: outm[pos<0,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; Chris@49: outm[pos<1,0>::n4] = ( Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; Chris@49: outm[pos<2,0>::n4] = ( Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; Chris@49: outm[pos<3,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det; Chris@49: Chris@49: outm[pos<0,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; Chris@49: outm[pos<1,1>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; Chris@49: outm[pos<2,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; Chris@49: outm[pos<3,1>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det; Chris@49: Chris@49: outm[pos<0,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; Chris@49: outm[pos<1,2>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; Chris@49: outm[pos<2,2>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; Chris@49: outm[pos<3,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det; Chris@49: Chris@49: outm[pos<0,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det; Chris@49: outm[pos<1,3>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det; Chris@49: outm[pos<2,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] ) / tmp_det; Chris@49: outm[pos<3,3>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] ) / tmp_det; Chris@49: } Chris@49: else Chris@49: { Chris@49: det_ok = false; Chris@49: } Chris@49: }; Chris@49: break; Chris@49: Chris@49: default: Chris@49: ; Chris@49: } Chris@49: Chris@49: return det_ok; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::inv_inplace_tinymat(Mat& X, const uword N) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: bool det_ok = true; Chris@49: Chris@49: // for more info, see: Chris@49: // http://www.dr-lex.34sp.com/random/matrix_inv.html Chris@49: // http://www.cvl.iis.u-tokyo.ac.jp/~miyazaki/tech/teche23.html Chris@49: // http://www.euclideanspace.com/maths/algebra/matrix/functions/inverse/fourD/index.htm Chris@49: // http://www.geometrictools.com//LibFoundation/Mathematics/Wm4Matrix4.inl Chris@49: Chris@49: switch(N) Chris@49: { Chris@49: case 1: Chris@49: { Chris@49: X[0] = eT(1) / X[0]; Chris@49: }; Chris@49: break; Chris@49: Chris@49: case 2: Chris@49: { Chris@49: const eT a = X[pos<0,0>::n2]; Chris@49: const eT b = X[pos<0,1>::n2]; Chris@49: const eT c = X[pos<1,0>::n2]; Chris@49: const eT d = X[pos<1,1>::n2]; Chris@49: Chris@49: const eT tmp_det = (a*d - b*c); Chris@49: Chris@49: if(tmp_det != eT(0)) Chris@49: { Chris@49: X[pos<0,0>::n2] = d / tmp_det; Chris@49: X[pos<0,1>::n2] = -b / tmp_det; Chris@49: X[pos<1,0>::n2] = -c / tmp_det; Chris@49: X[pos<1,1>::n2] = a / tmp_det; Chris@49: } Chris@49: else Chris@49: { Chris@49: det_ok = false; Chris@49: } Chris@49: }; Chris@49: break; Chris@49: Chris@49: case 3: Chris@49: { Chris@49: eT* X_col0 = X.colptr(0); Chris@49: eT* X_col1 = X.colptr(1); Chris@49: eT* X_col2 = X.colptr(2); Chris@49: Chris@49: const eT a11 = X_col0[0]; Chris@49: const eT a21 = X_col0[1]; Chris@49: const eT a31 = X_col0[2]; Chris@49: Chris@49: const eT a12 = X_col1[0]; Chris@49: const eT a22 = X_col1[1]; Chris@49: const eT a32 = X_col1[2]; Chris@49: Chris@49: const eT a13 = X_col2[0]; Chris@49: const eT a23 = X_col2[1]; Chris@49: const eT a33 = X_col2[2]; Chris@49: Chris@49: const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13); Chris@49: Chris@49: if(tmp_det != eT(0)) Chris@49: { Chris@49: X_col0[0] = (a33*a22 - a32*a23) / tmp_det; Chris@49: X_col0[1] = -(a33*a21 - a31*a23) / tmp_det; Chris@49: X_col0[2] = (a32*a21 - a31*a22) / tmp_det; Chris@49: Chris@49: X_col1[0] = -(a33*a12 - a32*a13) / tmp_det; Chris@49: X_col1[1] = (a33*a11 - a31*a13) / tmp_det; Chris@49: X_col1[2] = -(a32*a11 - a31*a12) / tmp_det; Chris@49: Chris@49: X_col2[0] = (a23*a12 - a22*a13) / tmp_det; Chris@49: X_col2[1] = -(a23*a11 - a21*a13) / tmp_det; Chris@49: X_col2[2] = (a22*a11 - a21*a12) / tmp_det; Chris@49: } Chris@49: else Chris@49: { Chris@49: det_ok = false; Chris@49: } Chris@49: }; Chris@49: break; Chris@49: Chris@49: case 4: Chris@49: { Chris@49: const eT tmp_det = det(X); Chris@49: Chris@49: if(tmp_det != eT(0)) Chris@49: { Chris@49: const Mat A(X); Chris@49: Chris@49: const eT* Am = A.memptr(); Chris@49: eT* Xm = X.memptr(); Chris@49: Chris@49: Xm[pos<0,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] - Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] + Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; Chris@49: Xm[pos<1,0>::n4] = ( Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; Chris@49: Xm[pos<2,0>::n4] = ( Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det; Chris@49: Xm[pos<3,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det; Chris@49: Chris@49: Xm[pos<0,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] + Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] - Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; Chris@49: Xm[pos<1,1>::n4] = ( Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; Chris@49: Xm[pos<2,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det; Chris@49: Xm[pos<3,1>::n4] = ( Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det; Chris@49: Chris@49: Xm[pos<0,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] + Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; Chris@49: Xm[pos<1,2>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; Chris@49: Xm[pos<2,2>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det; Chris@49: Xm[pos<3,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det; Chris@49: Chris@49: Xm[pos<0,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] + Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det; Chris@49: Xm[pos<1,3>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det; Chris@49: Xm[pos<2,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] ) / tmp_det; Chris@49: Xm[pos<3,3>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] ) / tmp_det; Chris@49: } Chris@49: else Chris@49: { Chris@49: det_ok = false; Chris@49: } Chris@49: }; Chris@49: break; Chris@49: Chris@49: default: Chris@49: ; Chris@49: } Chris@49: Chris@49: return det_ok; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::inv_inplace_lapack(Mat& out) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: if(out.is_empty()) Chris@49: { Chris@49: return true; Chris@49: } Chris@49: Chris@49: #if defined(ARMA_USE_ATLAS) Chris@49: { Chris@49: podarray ipiv(out.n_rows); Chris@49: Chris@49: int info = atlas::clapack_getrf(atlas::CblasColMajor, out.n_rows, out.n_cols, out.memptr(), out.n_rows, ipiv.memptr()); Chris@49: Chris@49: if(info == 0) Chris@49: { Chris@49: info = atlas::clapack_getri(atlas::CblasColMajor, out.n_rows, out.memptr(), out.n_rows, ipiv.memptr()); Chris@49: } Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #elif defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: blas_int n_rows = out.n_rows; Chris@49: blas_int n_cols = out.n_cols; Chris@49: blas_int lwork = 0; Chris@49: blas_int lwork_min = (std::max)(blas_int(1), n_rows); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray ipiv(out.n_rows); Chris@49: Chris@49: eT work_query[2]; Chris@49: blas_int lwork_query = -1; Chris@49: Chris@49: lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), &work_query[0], &lwork_query, &info); Chris@49: Chris@49: if(info == 0) Chris@49: { Chris@49: const blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); Chris@49: Chris@49: lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min; Chris@49: } Chris@49: else Chris@49: { Chris@49: return false; Chris@49: } Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: Chris@49: lapack::getrf(&n_rows, &n_cols, out.memptr(), &n_rows, ipiv.memptr(), &info); Chris@49: Chris@49: if(info == 0) Chris@49: { Chris@49: lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &lwork, &info); Chris@49: } Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_stop("inv(): use of ATLAS or LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::inv_tr(Mat& out, const Base& X, const uword layout) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: out = X.get_ref(); Chris@49: Chris@49: arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" ); Chris@49: Chris@49: if(out.is_empty()) Chris@49: { Chris@49: return true; Chris@49: } Chris@49: Chris@49: bool status; Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: char uplo = (layout == 0) ? 'U' : 'L'; Chris@49: char diag = 'N'; Chris@49: blas_int n = blas_int(out.n_rows); Chris@49: blas_int info = 0; Chris@49: Chris@49: lapack::trtri(&uplo, &diag, &n, out.memptr(), &n, &info); Chris@49: Chris@49: status = (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(layout); Chris@49: arma_stop("inv(): use of LAPACK needs to be enabled"); Chris@49: status = false; Chris@49: } Chris@49: #endif Chris@49: Chris@49: Chris@49: if(status == true) Chris@49: { Chris@49: if(layout == 0) Chris@49: { Chris@49: // upper triangular Chris@49: out = trimatu(out); Chris@49: } Chris@49: else Chris@49: { Chris@49: // lower triangular Chris@49: out = trimatl(out); Chris@49: } Chris@49: } Chris@49: Chris@49: return status; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::inv_sym(Mat& out, const Base& X, const uword layout) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: out = X.get_ref(); Chris@49: Chris@49: arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" ); Chris@49: Chris@49: if(out.is_empty()) Chris@49: { Chris@49: return true; Chris@49: } Chris@49: Chris@49: bool status; Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: char uplo = (layout == 0) ? 'U' : 'L'; Chris@49: blas_int n = blas_int(out.n_rows); Chris@49: blas_int lwork = 3 * (n*n); // TODO: use lwork = -1 to determine optimal size Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray ipiv; Chris@49: ipiv.set_size(out.n_rows); Chris@49: Chris@49: podarray work; Chris@49: work.set_size( uword(lwork) ); Chris@49: Chris@49: lapack::sytrf(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &lwork, &info); Chris@49: Chris@49: status = (info == 0); Chris@49: Chris@49: if(status == true) Chris@49: { Chris@49: lapack::sytri(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &info); Chris@49: Chris@49: out = (layout == 0) ? symmatu(out) : symmatl(out); Chris@49: Chris@49: status = (info == 0); Chris@49: } Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(layout); Chris@49: arma_stop("inv(): use of LAPACK needs to be enabled"); Chris@49: status = false; Chris@49: } Chris@49: #endif Chris@49: Chris@49: return status; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::inv_sympd(Mat& out, const Base& X, const uword layout) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: out = X.get_ref(); Chris@49: Chris@49: arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" ); Chris@49: Chris@49: if(out.is_empty()) Chris@49: { Chris@49: return true; Chris@49: } Chris@49: Chris@49: bool status; Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: char uplo = (layout == 0) ? 'U' : 'L'; Chris@49: blas_int n = blas_int(out.n_rows); Chris@49: blas_int info = 0; Chris@49: Chris@49: lapack::potrf(&uplo, &n, out.memptr(), &n, &info); Chris@49: Chris@49: status = (info == 0); Chris@49: Chris@49: if(status == true) Chris@49: { Chris@49: lapack::potri(&uplo, &n, out.memptr(), &n, &info); Chris@49: Chris@49: out = (layout == 0) ? symmatu(out) : symmatl(out); Chris@49: Chris@49: status = (info == 0); Chris@49: } Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(layout); Chris@49: arma_stop("inv(): use of LAPACK needs to be enabled"); Chris@49: status = false; Chris@49: } Chris@49: #endif Chris@49: Chris@49: return status; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: eT Chris@49: auxlib::det(const Base& X, const bool slow) Chris@49: { Chris@49: const unwrap tmp(X.get_ref()); Chris@49: const Mat& A = tmp.M; Chris@49: Chris@49: arma_debug_check( (A.is_square() == false), "det(): matrix is not square" ); Chris@49: Chris@49: const bool make_copy = (is_Mat::value == true) ? true : false; Chris@49: Chris@49: if(slow == false) Chris@49: { Chris@49: const uword N = A.n_rows; Chris@49: Chris@49: switch(N) Chris@49: { Chris@49: case 0: Chris@49: case 1: Chris@49: case 2: Chris@49: return auxlib::det_tinymat(A, N); Chris@49: break; Chris@49: Chris@49: case 3: Chris@49: case 4: Chris@49: { Chris@49: const eT tmp_det = auxlib::det_tinymat(A, N); Chris@49: return (tmp_det != eT(0)) ? tmp_det : auxlib::det_lapack(A, make_copy); Chris@49: } Chris@49: break; Chris@49: Chris@49: default: Chris@49: return auxlib::det_lapack(A, make_copy); Chris@49: } Chris@49: } Chris@49: Chris@49: return auxlib::det_lapack(A, make_copy); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: eT Chris@49: auxlib::det_tinymat(const Mat& X, const uword N) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: switch(N) Chris@49: { Chris@49: case 0: Chris@49: return eT(1); Chris@49: break; Chris@49: Chris@49: case 1: Chris@49: return X[0]; Chris@49: break; Chris@49: Chris@49: case 2: Chris@49: { Chris@49: const eT* Xm = X.memptr(); Chris@49: Chris@49: return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] ); Chris@49: } Chris@49: break; Chris@49: Chris@49: case 3: Chris@49: { Chris@49: // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2); Chris@49: // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0); Chris@49: // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1); Chris@49: // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2); Chris@49: // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0); Chris@49: // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1); Chris@49: // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6); Chris@49: Chris@49: const eT* a_col0 = X.colptr(0); Chris@49: const eT a11 = a_col0[0]; Chris@49: const eT a21 = a_col0[1]; Chris@49: const eT a31 = a_col0[2]; Chris@49: Chris@49: const eT* a_col1 = X.colptr(1); Chris@49: const eT a12 = a_col1[0]; Chris@49: const eT a22 = a_col1[1]; Chris@49: const eT a32 = a_col1[2]; Chris@49: Chris@49: const eT* a_col2 = X.colptr(2); Chris@49: const eT a13 = a_col2[0]; Chris@49: const eT a23 = a_col2[1]; Chris@49: const eT a33 = a_col2[2]; Chris@49: Chris@49: return ( a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13) ); Chris@49: } Chris@49: break; Chris@49: Chris@49: case 4: Chris@49: { Chris@49: const eT* Xm = X.memptr(); Chris@49: Chris@49: const eT val = \ Chris@49: Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \ Chris@49: - Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \ Chris@49: - Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \ Chris@49: + Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \ Chris@49: + Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \ Chris@49: - Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \ Chris@49: - Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \ Chris@49: + Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \ Chris@49: + Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \ Chris@49: - Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \ Chris@49: - Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \ Chris@49: + Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \ Chris@49: + Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \ Chris@49: - Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \ Chris@49: - Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \ Chris@49: + Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \ Chris@49: + Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \ Chris@49: - Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \ Chris@49: - Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \ Chris@49: + Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \ Chris@49: + Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \ Chris@49: - Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \ Chris@49: - Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \ Chris@49: + Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \ Chris@49: ; Chris@49: Chris@49: return val; Chris@49: } Chris@49: break; Chris@49: Chris@49: default: Chris@49: return eT(0); Chris@49: ; Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! immediate determinant of a matrix using ATLAS or LAPACK Chris@49: template Chris@49: inline Chris@49: eT Chris@49: auxlib::det_lapack(const Mat& X, const bool make_copy) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: Mat X_copy; Chris@49: Chris@49: if(make_copy == true) Chris@49: { Chris@49: X_copy = X; Chris@49: } Chris@49: Chris@49: Mat& tmp = (make_copy == true) ? X_copy : const_cast< Mat& >(X); Chris@49: Chris@49: if(tmp.is_empty()) Chris@49: { Chris@49: return eT(1); Chris@49: } Chris@49: Chris@49: Chris@49: #if defined(ARMA_USE_ATLAS) Chris@49: { Chris@49: podarray ipiv(tmp.n_rows); Chris@49: Chris@49: //const int info = Chris@49: atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr()); Chris@49: Chris@49: // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero Chris@49: eT val = tmp.at(0,0); Chris@49: for(uword i=1; i < tmp.n_rows; ++i) Chris@49: { Chris@49: val *= tmp.at(i,i); Chris@49: } Chris@49: Chris@49: int sign = +1; Chris@49: for(uword i=0; i < tmp.n_rows; ++i) Chris@49: { Chris@49: if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0 Chris@49: { Chris@49: sign *= -1; Chris@49: } Chris@49: } Chris@49: Chris@49: return ( (sign < 0) ? -val : val ); Chris@49: } Chris@49: #elif defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: podarray ipiv(tmp.n_rows); Chris@49: Chris@49: blas_int info = 0; Chris@49: blas_int n_rows = blas_int(tmp.n_rows); Chris@49: blas_int n_cols = blas_int(tmp.n_cols); Chris@49: Chris@49: lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info); Chris@49: Chris@49: // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero Chris@49: eT val = tmp.at(0,0); Chris@49: for(uword i=1; i < tmp.n_rows; ++i) Chris@49: { Chris@49: val *= tmp.at(i,i); Chris@49: } Chris@49: Chris@49: blas_int sign = +1; Chris@49: for(uword i=0; i < tmp.n_rows; ++i) Chris@49: { Chris@49: if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1 Chris@49: { Chris@49: sign *= -1; Chris@49: } Chris@49: } Chris@49: Chris@49: return ( (sign < 0) ? -val : val ); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(X); Chris@49: arma_ignore(make_copy); Chris@49: arma_ignore(tmp); Chris@49: arma_stop("det(): use of ATLAS or LAPACK needs to be enabled"); Chris@49: return eT(0); Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! immediate log determinant of a matrix using ATLAS or LAPACK Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::log_det(eT& out_val, typename get_pod_type::result& out_sign, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename get_pod_type::result T; Chris@49: Chris@49: #if defined(ARMA_USE_ATLAS) Chris@49: { Chris@49: Mat tmp(X.get_ref()); Chris@49: arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" ); Chris@49: Chris@49: if(tmp.is_empty()) Chris@49: { Chris@49: out_val = eT(0); Chris@49: out_sign = T(1); Chris@49: return true; Chris@49: } Chris@49: Chris@49: podarray ipiv(tmp.n_rows); Chris@49: Chris@49: const int info = atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr()); Chris@49: Chris@49: // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero Chris@49: Chris@49: sword sign = (is_complex::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1; Chris@49: eT val = (is_complex::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) ); Chris@49: Chris@49: for(uword i=1; i < tmp.n_rows; ++i) Chris@49: { Chris@49: const eT x = tmp.at(i,i); Chris@49: Chris@49: sign *= (is_complex::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1; Chris@49: val += (is_complex::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); Chris@49: } Chris@49: Chris@49: for(uword i=0; i < tmp.n_rows; ++i) Chris@49: { Chris@49: if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0 Chris@49: { Chris@49: sign *= -1; Chris@49: } Chris@49: } Chris@49: Chris@49: out_val = val; Chris@49: out_sign = T(sign); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #elif defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: Mat tmp(X.get_ref()); Chris@49: arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" ); Chris@49: Chris@49: if(tmp.is_empty()) Chris@49: { Chris@49: out_val = eT(0); Chris@49: out_sign = T(1); Chris@49: return true; Chris@49: } Chris@49: Chris@49: podarray ipiv(tmp.n_rows); Chris@49: Chris@49: blas_int info = 0; Chris@49: blas_int n_rows = blas_int(tmp.n_rows); Chris@49: blas_int n_cols = blas_int(tmp.n_cols); Chris@49: Chris@49: lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info); Chris@49: Chris@49: // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero Chris@49: Chris@49: sword sign = (is_complex::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1; Chris@49: eT val = (is_complex::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) ); Chris@49: Chris@49: for(uword i=1; i < tmp.n_rows; ++i) Chris@49: { Chris@49: const eT x = tmp.at(i,i); Chris@49: Chris@49: sign *= (is_complex::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1; Chris@49: val += (is_complex::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); Chris@49: } Chris@49: Chris@49: for(uword i=0; i < tmp.n_rows; ++i) Chris@49: { Chris@49: if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1 Chris@49: { Chris@49: sign *= -1; Chris@49: } Chris@49: } Chris@49: Chris@49: out_val = val; Chris@49: out_sign = T(sign); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(X); Chris@49: Chris@49: out_val = eT(0); Chris@49: out_sign = T(0); Chris@49: Chris@49: arma_stop("log_det(): use of ATLAS or LAPACK needs to be enabled"); Chris@49: Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! immediate LU decomposition of a matrix using ATLAS or LAPACK Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::lu(Mat& L, Mat& U, podarray& ipiv, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: U = X.get_ref(); Chris@49: Chris@49: const uword U_n_rows = U.n_rows; Chris@49: const uword U_n_cols = U.n_cols; Chris@49: Chris@49: if(U.is_empty()) Chris@49: { Chris@49: L.set_size(U_n_rows, 0); Chris@49: U.set_size(0, U_n_cols); Chris@49: ipiv.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: #if defined(ARMA_USE_ATLAS) || defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: bool status; Chris@49: Chris@49: #if defined(ARMA_USE_ATLAS) Chris@49: { Chris@49: ipiv.set_size( (std::min)(U_n_rows, U_n_cols) ); Chris@49: Chris@49: int info = atlas::clapack_getrf(atlas::CblasColMajor, U_n_rows, U_n_cols, U.memptr(), U_n_rows, ipiv.memptr()); Chris@49: Chris@49: status = (info == 0); Chris@49: } Chris@49: #elif defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: ipiv.set_size( (std::min)(U_n_rows, U_n_cols) ); Chris@49: Chris@49: blas_int info = 0; Chris@49: Chris@49: blas_int n_rows = U_n_rows; Chris@49: blas_int n_cols = U_n_cols; Chris@49: Chris@49: Chris@49: lapack::getrf(&n_rows, &n_cols, U.memptr(), &n_rows, ipiv.memptr(), &info); Chris@49: Chris@49: // take into account that Fortran counts from 1 Chris@49: arrayops::inplace_minus(ipiv.memptr(), blas_int(1), ipiv.n_elem); Chris@49: Chris@49: status = (info == 0); Chris@49: } Chris@49: #endif Chris@49: Chris@49: L.copy_size(U); Chris@49: Chris@49: for(uword col=0; col < U_n_cols; ++col) Chris@49: { Chris@49: for(uword row=0; (row < col) && (row < U_n_rows); ++row) Chris@49: { Chris@49: L.at(row,col) = eT(0); Chris@49: } Chris@49: Chris@49: if( L.in_range(col,col) == true ) Chris@49: { Chris@49: L.at(col,col) = eT(1); Chris@49: } Chris@49: Chris@49: for(uword row = (col+1); row < U_n_rows; ++row) Chris@49: { Chris@49: L.at(row,col) = U.at(row,col); Chris@49: U.at(row,col) = eT(0); Chris@49: } Chris@49: } Chris@49: Chris@49: return status; Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_stop("lu(): use of ATLAS or LAPACK needs to be enabled"); Chris@49: Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::lu(Mat& L, Mat& U, Mat& P, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: podarray ipiv1; Chris@49: const bool status = auxlib::lu(L, U, ipiv1, X); Chris@49: Chris@49: if(status == true) Chris@49: { Chris@49: if(U.is_empty()) Chris@49: { Chris@49: // L and U have been already set to the correct empty matrices Chris@49: P.eye(L.n_rows, L.n_rows); Chris@49: return true; Chris@49: } Chris@49: Chris@49: const uword n = ipiv1.n_elem; Chris@49: const uword P_rows = U.n_rows; Chris@49: Chris@49: podarray ipiv2(P_rows); Chris@49: Chris@49: const blas_int* ipiv1_mem = ipiv1.memptr(); Chris@49: blas_int* ipiv2_mem = ipiv2.memptr(); Chris@49: Chris@49: for(uword i=0; i(ipiv1_mem[i]); Chris@49: Chris@49: if( ipiv2_mem[i] != ipiv2_mem[k] ) Chris@49: { Chris@49: std::swap( ipiv2_mem[i], ipiv2_mem[k] ); Chris@49: } Chris@49: } Chris@49: Chris@49: P.zeros(P_rows, P_rows); Chris@49: Chris@49: for(uword row=0; row(ipiv2_mem[row])) = eT(1); Chris@49: } Chris@49: Chris@49: if(L.n_cols > U.n_rows) Chris@49: { Chris@49: L.shed_cols(U.n_rows, L.n_cols-1); Chris@49: } Chris@49: Chris@49: if(U.n_rows > L.n_cols) Chris@49: { Chris@49: U.shed_rows(L.n_cols, U.n_rows-1); Chris@49: } Chris@49: } Chris@49: Chris@49: return status; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::lu(Mat& L, Mat& U, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: podarray ipiv1; Chris@49: const bool status = auxlib::lu(L, U, ipiv1, X); Chris@49: Chris@49: if(status == true) Chris@49: { Chris@49: if(U.is_empty()) Chris@49: { Chris@49: // L and U have been already set to the correct empty matrices Chris@49: return true; Chris@49: } Chris@49: Chris@49: const uword n = ipiv1.n_elem; Chris@49: const uword P_rows = U.n_rows; Chris@49: Chris@49: podarray ipiv2(P_rows); Chris@49: Chris@49: const blas_int* ipiv1_mem = ipiv1.memptr(); Chris@49: blas_int* ipiv2_mem = ipiv2.memptr(); Chris@49: Chris@49: for(uword i=0; i(ipiv1_mem[i]); Chris@49: Chris@49: if( ipiv2_mem[i] != ipiv2_mem[k] ) Chris@49: { Chris@49: std::swap( ipiv2_mem[i], ipiv2_mem[k] ); Chris@49: L.swap_rows( static_cast(ipiv2_mem[i]), static_cast(ipiv2_mem[k]) ); Chris@49: } Chris@49: } Chris@49: Chris@49: if(L.n_cols > U.n_rows) Chris@49: { Chris@49: L.shed_cols(U.n_rows, L.n_cols-1); Chris@49: } Chris@49: Chris@49: if(U.n_rows > L.n_cols) Chris@49: { Chris@49: U.shed_rows(L.n_cols, U.n_rows-1); Chris@49: } Chris@49: } Chris@49: Chris@49: return status; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! immediate eigenvalues of a symmetric real matrix using LAPACK Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::eig_sym(Col& eigval, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: Mat A(X.get_ref()); Chris@49: Chris@49: arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not square"); Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: eigval.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: eigval.set_size(A.n_rows); Chris@49: Chris@49: char jobz = 'N'; Chris@49: char uplo = 'U'; Chris@49: Chris@49: blas_int N = blas_int(A.n_rows); Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), 3*N-1) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: Chris@49: lapack::syev(&jobz, &uplo, &N, A.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, &info); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(eigval); Chris@49: arma_ignore(X); Chris@49: arma_stop("eig_sym(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! immediate eigenvalues of a hermitian complex matrix using LAPACK Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::eig_sym(Col& eigval, const Base,T1>& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: typedef typename std::complex eT; Chris@49: Chris@49: Mat A(X.get_ref()); Chris@49: Chris@49: arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not square"); Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: eigval.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: eigval.set_size(A.n_rows); Chris@49: Chris@49: char jobz = 'N'; Chris@49: char uplo = 'U'; Chris@49: Chris@49: blas_int N = blas_int(A.n_rows); Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*N-1) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: podarray rwork( static_cast( (std::max)(blas_int(1), 3*N-2) ) ); Chris@49: Chris@49: arma_extra_debug_print("lapack::heev()"); Chris@49: lapack::heev(&jobz, &uplo, &N, A.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(eigval); Chris@49: arma_ignore(X); Chris@49: arma_stop("eig_sym(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! immediate eigenvalues and eigenvectors of a symmetric real matrix using LAPACK Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::eig_sym(Col& eigval, Mat& eigvec, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: eigvec = X.get_ref(); Chris@49: Chris@49: arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" ); Chris@49: Chris@49: if(eigvec.is_empty()) Chris@49: { Chris@49: eigval.reset(); Chris@49: eigvec.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: eigval.set_size(eigvec.n_rows); Chris@49: Chris@49: char jobz = 'V'; Chris@49: char uplo = 'U'; Chris@49: Chris@49: blas_int N = blas_int(eigvec.n_rows); Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), 3*N-1) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: Chris@49: lapack::syev(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, &info); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(eigval); Chris@49: arma_ignore(eigvec); Chris@49: arma_ignore(X); Chris@49: arma_stop("eig_sym(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! immediate eigenvalues and eigenvectors of a hermitian complex matrix using LAPACK Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::eig_sym(Col& eigval, Mat< std::complex >& eigvec, const Base,T1>& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: typedef typename std::complex eT; Chris@49: Chris@49: eigvec = X.get_ref(); Chris@49: Chris@49: arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" ); Chris@49: Chris@49: if(eigvec.is_empty()) Chris@49: { Chris@49: eigval.reset(); Chris@49: eigvec.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: eigval.set_size(eigvec.n_rows); Chris@49: Chris@49: char jobz = 'V'; Chris@49: char uplo = 'U'; Chris@49: Chris@49: blas_int N = blas_int(eigvec.n_rows); Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*N-1) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: podarray rwork( static_cast((std::max)(blas_int(1), 3*N-2)) ); Chris@49: Chris@49: arma_extra_debug_print("lapack::heev()"); Chris@49: lapack::heev(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(eigval); Chris@49: arma_ignore(eigvec); Chris@49: arma_ignore(X); Chris@49: arma_stop("eig_sym(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! immediate eigenvalues and eigenvectors of a symmetric real matrix using LAPACK (divide and conquer algorithm) Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::eig_sym_dc(Col& eigval, Mat& eigvec, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: eigvec = X.get_ref(); Chris@49: Chris@49: arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" ); Chris@49: Chris@49: if(eigvec.is_empty()) Chris@49: { Chris@49: eigval.reset(); Chris@49: eigvec.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: eigval.set_size(eigvec.n_rows); Chris@49: Chris@49: char jobz = 'V'; Chris@49: char uplo = 'U'; Chris@49: Chris@49: blas_int N = blas_int(eigvec.n_rows); Chris@49: blas_int lwork = 3 * (1 + 6*N + 2*(N*N)); Chris@49: blas_int liwork = 3 * (3 + 5*N + 2); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray work( static_cast( lwork) ); Chris@49: podarray iwork( static_cast(liwork) ); Chris@49: Chris@49: arma_extra_debug_print("lapack::syevd()"); Chris@49: lapack::syevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, iwork.memptr(), &liwork, &info); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(eigval); Chris@49: arma_ignore(eigvec); Chris@49: arma_ignore(X); Chris@49: arma_stop("eig_sym(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! immediate eigenvalues and eigenvectors of a hermitian complex matrix using LAPACK (divide and conquer algorithm) Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::eig_sym_dc(Col& eigval, Mat< std::complex >& eigvec, const Base,T1>& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: typedef typename std::complex eT; Chris@49: Chris@49: eigvec = X.get_ref(); Chris@49: Chris@49: arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" ); Chris@49: Chris@49: if(eigvec.is_empty()) Chris@49: { Chris@49: eigval.reset(); Chris@49: eigvec.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: eigval.set_size(eigvec.n_rows); Chris@49: Chris@49: char jobz = 'V'; Chris@49: char uplo = 'U'; Chris@49: Chris@49: blas_int N = blas_int(eigvec.n_rows); Chris@49: blas_int lwork = 3 * (2*N + N*N); Chris@49: blas_int lrwork = 3 * (1 + 5*N + 2*(N*N)); Chris@49: blas_int liwork = 3 * (3 + 5*N); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: podarray rwork( static_cast(lrwork) ); Chris@49: podarray iwork( static_cast(liwork) ); Chris@49: Chris@49: arma_extra_debug_print("lapack::heevd()"); Chris@49: lapack::heevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &lrwork, iwork.memptr(), &liwork, &info); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(eigval); Chris@49: arma_ignore(eigvec); Chris@49: arma_ignore(X); Chris@49: arma_stop("eig_sym(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! Eigenvalues and eigenvectors of a general square real matrix using LAPACK. Chris@49: //! The argument 'side' specifies which eigenvectors should be calculated Chris@49: //! (see code for mode details). Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::eig_gen Chris@49: ( Chris@49: Col< std::complex >& eigval, Chris@49: Mat& l_eigvec, Chris@49: Mat& r_eigvec, Chris@49: const Base& X, Chris@49: const char side Chris@49: ) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: char jobvl; Chris@49: char jobvr; Chris@49: Chris@49: switch(side) Chris@49: { Chris@49: case 'l': // left Chris@49: jobvl = 'V'; Chris@49: jobvr = 'N'; Chris@49: break; Chris@49: Chris@49: case 'r': // right Chris@49: jobvl = 'N'; Chris@49: jobvr = 'V'; Chris@49: break; Chris@49: Chris@49: case 'b': // both Chris@49: jobvl = 'V'; Chris@49: jobvr = 'V'; Chris@49: break; Chris@49: Chris@49: case 'n': // neither Chris@49: jobvl = 'N'; Chris@49: jobvr = 'N'; Chris@49: break; Chris@49: Chris@49: default: Chris@49: arma_stop("eig_gen(): parameter 'side' is invalid"); Chris@49: return false; Chris@49: } Chris@49: Chris@49: Mat A(X.get_ref()); Chris@49: arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" ); Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: eigval.reset(); Chris@49: l_eigvec.reset(); Chris@49: r_eigvec.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: const uword A_n_rows = A.n_rows; Chris@49: Chris@49: eigval.set_size(A_n_rows); Chris@49: Chris@49: l_eigvec.set_size(A_n_rows, A_n_rows); Chris@49: r_eigvec.set_size(A_n_rows, A_n_rows); Chris@49: Chris@49: blas_int N = blas_int(A_n_rows); Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), 4*N) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: Chris@49: podarray wr(A_n_rows); Chris@49: podarray wi(A_n_rows); Chris@49: Chris@49: arma_extra_debug_print("lapack::geev()"); Chris@49: lapack::geev(&jobvl, &jobvr, &N, A.memptr(), &N, wr.memptr(), wi.memptr(), l_eigvec.memptr(), &N, r_eigvec.memptr(), &N, work.memptr(), &lwork, &info); Chris@49: Chris@49: eigval.set_size(A_n_rows); Chris@49: for(uword i=0; i(wr[i], wi[i]); Chris@49: } Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(eigval); Chris@49: arma_ignore(l_eigvec); Chris@49: arma_ignore(r_eigvec); Chris@49: arma_ignore(X); Chris@49: arma_ignore(side); Chris@49: arma_stop("eig_gen(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: Chris@49: Chris@49: //! Eigenvalues and eigenvectors of a general square complex matrix using LAPACK Chris@49: //! The argument 'side' specifies which eigenvectors should be calculated Chris@49: //! (see code for mode details). Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::eig_gen Chris@49: ( Chris@49: Col< std::complex >& eigval, Chris@49: Mat< std::complex >& l_eigvec, Chris@49: Mat< std::complex >& r_eigvec, Chris@49: const Base< std::complex, T1 >& X, Chris@49: const char side Chris@49: ) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: typedef typename std::complex eT; Chris@49: Chris@49: char jobvl; Chris@49: char jobvr; Chris@49: Chris@49: switch(side) Chris@49: { Chris@49: case 'l': // left Chris@49: jobvl = 'V'; Chris@49: jobvr = 'N'; Chris@49: break; Chris@49: Chris@49: case 'r': // right Chris@49: jobvl = 'N'; Chris@49: jobvr = 'V'; Chris@49: break; Chris@49: Chris@49: case 'b': // both Chris@49: jobvl = 'V'; Chris@49: jobvr = 'V'; Chris@49: break; Chris@49: Chris@49: case 'n': // neither Chris@49: jobvl = 'N'; Chris@49: jobvr = 'N'; Chris@49: break; Chris@49: Chris@49: default: Chris@49: arma_stop("eig_gen(): parameter 'side' is invalid"); Chris@49: return false; Chris@49: } Chris@49: Chris@49: Mat A(X.get_ref()); Chris@49: arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" ); Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: eigval.reset(); Chris@49: l_eigvec.reset(); Chris@49: r_eigvec.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: const uword A_n_rows = A.n_rows; Chris@49: Chris@49: eigval.set_size(A_n_rows); Chris@49: Chris@49: l_eigvec.set_size(A_n_rows, A_n_rows); Chris@49: r_eigvec.set_size(A_n_rows, A_n_rows); Chris@49: Chris@49: blas_int N = blas_int(A_n_rows); Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*N) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: podarray rwork( static_cast(2*N) ); Chris@49: Chris@49: arma_extra_debug_print("lapack::cx_geev()"); Chris@49: lapack::cx_geev(&jobvl, &jobvr, &N, A.memptr(), &N, eigval.memptr(), l_eigvec.memptr(), &N, r_eigvec.memptr(), &N, work.memptr(), &lwork, rwork.memptr(), &info); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(eigval); Chris@49: arma_ignore(l_eigvec); Chris@49: arma_ignore(r_eigvec); Chris@49: arma_ignore(X); Chris@49: arma_ignore(side); Chris@49: arma_stop("eig_gen(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::chol(Mat& out, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: out = X.get_ref(); Chris@49: Chris@49: arma_debug_check( (out.is_square() == false), "chol(): given matrix is not square" ); Chris@49: Chris@49: if(out.is_empty()) Chris@49: { Chris@49: return true; Chris@49: } Chris@49: Chris@49: const uword out_n_rows = out.n_rows; Chris@49: Chris@49: char uplo = 'U'; Chris@49: blas_int n = out_n_rows; Chris@49: blas_int info = 0; Chris@49: Chris@49: lapack::potrf(&uplo, &n, out.memptr(), &n, &info); Chris@49: Chris@49: for(uword col=0; col Chris@49: inline Chris@49: bool Chris@49: auxlib::qr(Mat& Q, Mat& R, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: R = X.get_ref(); Chris@49: Chris@49: const uword R_n_rows = R.n_rows; Chris@49: const uword R_n_cols = R.n_cols; Chris@49: Chris@49: if(R.is_empty()) Chris@49: { Chris@49: Q.eye(R_n_rows, R_n_rows); Chris@49: return true; Chris@49: } Chris@49: Chris@49: blas_int m = static_cast(R_n_rows); Chris@49: blas_int n = static_cast(R_n_cols); Chris@49: blas_int lwork = 0; Chris@49: blas_int lwork_min = (std::max)(blas_int(1), (std::max)(m,n)); // take into account requirements of geqrf() _and_ orgqr()/ungqr() Chris@49: blas_int k = (std::min)(m,n); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray tau( static_cast(k) ); Chris@49: Chris@49: eT work_query[2]; Chris@49: blas_int lwork_query = -1; Chris@49: Chris@49: lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), &work_query[0], &lwork_query, &info); Chris@49: Chris@49: if(info == 0) Chris@49: { Chris@49: const blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); Chris@49: Chris@49: lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min; Chris@49: } Chris@49: else Chris@49: { Chris@49: return false; Chris@49: } Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: Chris@49: lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info); Chris@49: Chris@49: Q.set_size(R_n_rows, R_n_rows); Chris@49: Chris@49: arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) ); Chris@49: Chris@49: // Chris@49: // construct R Chris@49: Chris@49: for(uword col=0; col < R_n_cols; ++col) Chris@49: { Chris@49: for(uword row=(col+1); row < R_n_rows; ++row) Chris@49: { Chris@49: R.at(row,col) = eT(0); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: if( (is_float::value == true) || (is_double::value == true) ) Chris@49: { Chris@49: lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info); Chris@49: } Chris@49: else Chris@49: if( (is_supported_complex_float::value == true) || (is_supported_complex_double::value == true) ) Chris@49: { Chris@49: lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info); Chris@49: } Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(Q); Chris@49: arma_ignore(R); Chris@49: arma_ignore(X); Chris@49: arma_stop("qr(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::qr_econ(Mat& Q, Mat& R, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: // This function implements a memory-efficient QR for a non-square X that has dimensions m x n. Chris@49: // This basically discards the basis for the null-space. Chris@49: // Chris@49: // if m <= n: (use standard routine) Chris@49: // Q[m,m]*R[m,n] = X[m,n] Chris@49: // geqrf Needs A[m,n]: Uses R Chris@49: // orgqr Needs A[m,m]: Uses Q Chris@49: // otherwise: (memory-efficient routine) Chris@49: // Q[m,n]*R[n,n] = X[m,n] Chris@49: // geqrf Needs A[m,n]: Uses Q Chris@49: // geqrf Needs A[m,n]: Uses Q Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: if(is_Mat::value == true) Chris@49: { Chris@49: const unwrap tmp(X.get_ref()); Chris@49: const Mat& M = tmp.M; Chris@49: Chris@49: if(M.n_rows < M.n_cols) Chris@49: { Chris@49: return auxlib::qr(Q, R, X); Chris@49: } Chris@49: } Chris@49: Chris@49: Q = X.get_ref(); Chris@49: Chris@49: const uword Q_n_rows = Q.n_rows; Chris@49: const uword Q_n_cols = Q.n_cols; Chris@49: Chris@49: if( Q_n_rows <= Q_n_cols ) Chris@49: { Chris@49: return auxlib::qr(Q, R, Q); Chris@49: } Chris@49: Chris@49: if(Q.is_empty()) Chris@49: { Chris@49: Q.set_size(Q_n_rows, 0 ); Chris@49: R.set_size(0, Q_n_cols); Chris@49: return true; Chris@49: } Chris@49: Chris@49: blas_int m = static_cast(Q_n_rows); Chris@49: blas_int n = static_cast(Q_n_cols); Chris@49: blas_int lwork = 0; Chris@49: blas_int lwork_min = (std::max)(blas_int(1), (std::max)(m,n)); // take into account requirements of geqrf() _and_ orgqr()/ungqr() Chris@49: blas_int k = (std::min)(m,n); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray tau( static_cast(k) ); Chris@49: Chris@49: eT work_query[2]; Chris@49: blas_int lwork_query = -1; Chris@49: Chris@49: lapack::geqrf(&m, &n, Q.memptr(), &m, tau.memptr(), &work_query[0], &lwork_query, &info); Chris@49: Chris@49: if(info == 0) Chris@49: { Chris@49: const blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); Chris@49: Chris@49: lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min; Chris@49: } Chris@49: else Chris@49: { Chris@49: return false; Chris@49: } Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: Chris@49: lapack::geqrf(&m, &n, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info); Chris@49: Chris@49: // Q now has the elements on and above the diagonal of the array Chris@49: // contain the min(M,N)-by-N upper trapezoidal matrix Q Chris@49: // (Q is upper triangular if m >= n); Chris@49: // the elements below the diagonal, with the array TAU, Chris@49: // represent the orthogonal matrix Q as a product of min(m,n) elementary reflectors. Chris@49: Chris@49: R.set_size(Q_n_cols, Q_n_cols); Chris@49: Chris@49: // Chris@49: // construct R Chris@49: Chris@49: for(uword col=0; col < Q_n_cols; ++col) Chris@49: { Chris@49: for(uword row=0; row <= col; ++row) Chris@49: { Chris@49: R.at(row,col) = Q.at(row,col); Chris@49: } Chris@49: Chris@49: for(uword row=(col+1); row < Q_n_cols; ++row) Chris@49: { Chris@49: R.at(row,col) = eT(0); Chris@49: } Chris@49: } Chris@49: Chris@49: if( (is_float::value == true) || (is_double::value == true) ) Chris@49: { Chris@49: lapack::orgqr(&m, &n, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info); Chris@49: } Chris@49: else Chris@49: if( (is_supported_complex_float::value == true) || (is_supported_complex_double::value == true) ) Chris@49: { Chris@49: lapack::ungqr(&m, &n, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info); Chris@49: } Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(Q); Chris@49: arma_ignore(R); Chris@49: arma_ignore(X); Chris@49: arma_stop("qr_econ(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::svd(Col& S, const Base& X, uword& X_n_rows, uword& X_n_cols) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: Mat A(X.get_ref()); Chris@49: Chris@49: X_n_rows = A.n_rows; Chris@49: X_n_cols = A.n_cols; Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: S.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: Mat U(1, 1); Chris@49: Mat V(1, A.n_cols); Chris@49: Chris@49: char jobu = 'N'; Chris@49: char jobvt = 'N'; Chris@49: Chris@49: blas_int m = A.n_rows; Chris@49: blas_int n = A.n_cols; Chris@49: blas_int min_mn = (std::min)(m,n); Chris@49: blas_int lda = A.n_rows; Chris@49: blas_int ldu = U.n_rows; Chris@49: blas_int ldvt = V.n_rows; Chris@49: blas_int lwork = 0; Chris@49: blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: S.set_size( static_cast(min_mn) ); Chris@49: Chris@49: eT work_query[2]; Chris@49: blas_int lwork_query = -1; Chris@49: Chris@49: lapack::gesvd Chris@49: ( Chris@49: &jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info Chris@49: ); Chris@49: Chris@49: if(info == 0) Chris@49: { Chris@49: const blas_int lwork_proposed = static_cast( work_query[0] ); Chris@49: Chris@49: lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min; Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: Chris@49: lapack::gesvd Chris@49: ( Chris@49: &jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, &info Chris@49: ); Chris@49: } Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(S); Chris@49: arma_ignore(X); Chris@49: arma_ignore(X_n_rows); Chris@49: arma_ignore(X_n_cols); Chris@49: arma_stop("svd(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::svd(Col& S, const Base, T1>& X, uword& X_n_rows, uword& X_n_cols) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: typedef std::complex eT; Chris@49: Chris@49: Mat A(X.get_ref()); Chris@49: Chris@49: X_n_rows = A.n_rows; Chris@49: X_n_cols = A.n_cols; Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: S.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: Mat U(1, 1); Chris@49: Mat V(1, A.n_cols); Chris@49: Chris@49: char jobu = 'N'; Chris@49: char jobvt = 'N'; Chris@49: Chris@49: blas_int m = A.n_rows; Chris@49: blas_int n = A.n_cols; Chris@49: blas_int min_mn = (std::min)(m,n); Chris@49: blas_int lda = A.n_rows; Chris@49: blas_int ldu = U.n_rows; Chris@49: blas_int ldvt = V.n_rows; Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*min_mn+(std::max)(m,n) ) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: S.set_size( static_cast(min_mn) ); Chris@49: Chris@49: podarray work( static_cast(lwork ) ); Chris@49: podarray< T> rwork( static_cast(5*min_mn) ); Chris@49: Chris@49: // let gesvd_() calculate the optimum size of the workspace Chris@49: blas_int lwork_tmp = -1; Chris@49: Chris@49: lapack::cx_gesvd Chris@49: ( Chris@49: &jobu, &jobvt, Chris@49: &m, &n, Chris@49: A.memptr(), &lda, Chris@49: S.memptr(), Chris@49: U.memptr(), &ldu, Chris@49: V.memptr(), &ldvt, Chris@49: work.memptr(), &lwork_tmp, Chris@49: rwork.memptr(), Chris@49: &info Chris@49: ); Chris@49: Chris@49: if(info == 0) Chris@49: { Chris@49: blas_int proposed_lwork = static_cast(real(work[0])); Chris@49: if(proposed_lwork > lwork) Chris@49: { Chris@49: lwork = proposed_lwork; Chris@49: work.set_size( static_cast(lwork) ); Chris@49: } Chris@49: Chris@49: lapack::cx_gesvd Chris@49: ( Chris@49: &jobu, &jobvt, Chris@49: &m, &n, Chris@49: A.memptr(), &lda, Chris@49: S.memptr(), Chris@49: U.memptr(), &ldu, Chris@49: V.memptr(), &ldvt, Chris@49: work.memptr(), &lwork, Chris@49: rwork.memptr(), Chris@49: &info Chris@49: ); Chris@49: } Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(S); Chris@49: arma_ignore(X); Chris@49: arma_ignore(X_n_rows); Chris@49: arma_ignore(X_n_cols); Chris@49: Chris@49: arma_stop("svd(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::svd(Col& S, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: uword junk; Chris@49: return auxlib::svd(S, X, junk, junk); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::svd(Col& S, const Base, T1>& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: uword junk; Chris@49: return auxlib::svd(S, X, junk, junk); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::svd(Mat& U, Col& S, Mat& V, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: Mat A(X.get_ref()); Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: U.eye(A.n_rows, A.n_rows); Chris@49: S.reset(); Chris@49: V.eye(A.n_cols, A.n_cols); Chris@49: return true; Chris@49: } Chris@49: Chris@49: U.set_size(A.n_rows, A.n_rows); Chris@49: V.set_size(A.n_cols, A.n_cols); Chris@49: Chris@49: char jobu = 'A'; Chris@49: char jobvt = 'A'; Chris@49: Chris@49: blas_int m = blas_int(A.n_rows); Chris@49: blas_int n = blas_int(A.n_cols); Chris@49: blas_int min_mn = (std::min)(m,n); Chris@49: blas_int lda = blas_int(A.n_rows); Chris@49: blas_int ldu = blas_int(U.n_rows); Chris@49: blas_int ldvt = blas_int(V.n_rows); Chris@49: blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ); Chris@49: blas_int lwork = 0; Chris@49: blas_int info = 0; Chris@49: Chris@49: S.set_size( static_cast(min_mn) ); Chris@49: Chris@49: // let gesvd_() calculate the optimum size of the workspace Chris@49: eT work_query[2]; Chris@49: blas_int lwork_query = -1; Chris@49: Chris@49: lapack::gesvd Chris@49: ( Chris@49: &jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info Chris@49: ); Chris@49: Chris@49: if(info == 0) Chris@49: { Chris@49: const blas_int lwork_proposed = static_cast( work_query[0] ); Chris@49: Chris@49: lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min; Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: Chris@49: lapack::gesvd Chris@49: ( Chris@49: &jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, &info Chris@49: ); Chris@49: Chris@49: op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done Chris@49: } Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(U); Chris@49: arma_ignore(S); Chris@49: arma_ignore(V); Chris@49: arma_ignore(X); Chris@49: arma_stop("svd(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::svd(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, const Base< std::complex, T1>& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: typedef std::complex eT; Chris@49: Chris@49: Mat A(X.get_ref()); Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: U.eye(A.n_rows, A.n_rows); Chris@49: S.reset(); Chris@49: V.eye(A.n_cols, A.n_cols); Chris@49: return true; Chris@49: } Chris@49: Chris@49: U.set_size(A.n_rows, A.n_rows); Chris@49: V.set_size(A.n_cols, A.n_cols); Chris@49: Chris@49: char jobu = 'A'; Chris@49: char jobvt = 'A'; Chris@49: Chris@49: blas_int m = blas_int(A.n_rows); Chris@49: blas_int n = blas_int(A.n_cols); Chris@49: blas_int min_mn = (std::min)(m,n); Chris@49: blas_int lda = blas_int(A.n_rows); Chris@49: blas_int ldu = blas_int(U.n_rows); Chris@49: blas_int ldvt = blas_int(V.n_rows); Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*min_mn + (std::max)(m,n) ) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: S.set_size( static_cast(min_mn) ); Chris@49: Chris@49: podarray work( static_cast(lwork ) ); Chris@49: podarray rwork( static_cast(5*min_mn) ); Chris@49: Chris@49: // let gesvd_() calculate the optimum size of the workspace Chris@49: blas_int lwork_tmp = -1; Chris@49: lapack::cx_gesvd Chris@49: ( Chris@49: &jobu, &jobvt, Chris@49: &m, &n, Chris@49: A.memptr(), &lda, Chris@49: S.memptr(), Chris@49: U.memptr(), &ldu, Chris@49: V.memptr(), &ldvt, Chris@49: work.memptr(), &lwork_tmp, Chris@49: rwork.memptr(), Chris@49: &info Chris@49: ); Chris@49: Chris@49: if(info == 0) Chris@49: { Chris@49: blas_int proposed_lwork = static_cast(real(work[0])); Chris@49: Chris@49: if(proposed_lwork > lwork) Chris@49: { Chris@49: lwork = proposed_lwork; Chris@49: work.set_size( static_cast(lwork) ); Chris@49: } Chris@49: Chris@49: lapack::cx_gesvd Chris@49: ( Chris@49: &jobu, &jobvt, Chris@49: &m, &n, Chris@49: A.memptr(), &lda, Chris@49: S.memptr(), Chris@49: U.memptr(), &ldu, Chris@49: V.memptr(), &ldvt, Chris@49: work.memptr(), &lwork, Chris@49: rwork.memptr(), Chris@49: &info Chris@49: ); Chris@49: Chris@49: op_htrans::apply(V,V); // op_htrans will work out that an in-place transpose can be done Chris@49: } Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(U); Chris@49: arma_ignore(S); Chris@49: arma_ignore(V); Chris@49: arma_ignore(X); Chris@49: arma_stop("svd(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::svd_econ(Mat& U, Col& S, Mat& V, const Base& X, const char mode) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: Mat A(X.get_ref()); Chris@49: Chris@49: blas_int m = blas_int(A.n_rows); Chris@49: blas_int n = blas_int(A.n_cols); Chris@49: blas_int min_mn = (std::min)(m,n); Chris@49: blas_int lda = blas_int(A.n_rows); Chris@49: Chris@49: S.set_size( static_cast(min_mn) ); Chris@49: Chris@49: blas_int ldu = 0; Chris@49: blas_int ldvt = 0; Chris@49: Chris@49: char jobu; Chris@49: char jobvt; Chris@49: Chris@49: switch(mode) Chris@49: { Chris@49: case 'l': Chris@49: jobu = 'S'; Chris@49: jobvt = 'N'; Chris@49: Chris@49: ldu = m; Chris@49: ldvt = 1; Chris@49: Chris@49: U.set_size( static_cast(ldu), static_cast(min_mn) ); Chris@49: V.reset(); Chris@49: Chris@49: break; Chris@49: Chris@49: Chris@49: case 'r': Chris@49: jobu = 'N'; Chris@49: jobvt = 'S'; Chris@49: Chris@49: ldu = 1; Chris@49: ldvt = (std::min)(m,n); Chris@49: Chris@49: U.reset(); Chris@49: V.set_size( static_cast(ldvt), static_cast(n) ); Chris@49: Chris@49: break; Chris@49: Chris@49: Chris@49: case 'b': Chris@49: jobu = 'S'; Chris@49: jobvt = 'S'; Chris@49: Chris@49: ldu = m; Chris@49: ldvt = (std::min)(m,n); Chris@49: Chris@49: U.set_size( static_cast(ldu), static_cast(min_mn) ); Chris@49: V.set_size( static_cast(ldvt), static_cast(n ) ); Chris@49: Chris@49: break; Chris@49: Chris@49: Chris@49: default: Chris@49: U.reset(); Chris@49: S.reset(); Chris@49: V.reset(); Chris@49: return false; Chris@49: } Chris@49: Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: U.eye(); Chris@49: S.reset(); Chris@49: V.eye(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: Chris@49: // let gesvd_() calculate the optimum size of the workspace Chris@49: blas_int lwork_tmp = -1; Chris@49: Chris@49: lapack::gesvd Chris@49: ( Chris@49: &jobu, &jobvt, Chris@49: &m, &n, Chris@49: A.memptr(), &lda, Chris@49: S.memptr(), Chris@49: U.memptr(), &ldu, Chris@49: V.memptr(), &ldvt, Chris@49: work.memptr(), &lwork_tmp, Chris@49: &info Chris@49: ); Chris@49: Chris@49: if(info == 0) Chris@49: { Chris@49: blas_int proposed_lwork = static_cast(work[0]); Chris@49: if(proposed_lwork > lwork) Chris@49: { Chris@49: lwork = proposed_lwork; Chris@49: work.set_size( static_cast(lwork) ); Chris@49: } Chris@49: Chris@49: lapack::gesvd Chris@49: ( Chris@49: &jobu, &jobvt, Chris@49: &m, &n, Chris@49: A.memptr(), &lda, Chris@49: S.memptr(), Chris@49: U.memptr(), &ldu, Chris@49: V.memptr(), &ldvt, Chris@49: work.memptr(), &lwork, Chris@49: &info Chris@49: ); Chris@49: Chris@49: op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done Chris@49: } Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(U); Chris@49: arma_ignore(S); Chris@49: arma_ignore(V); Chris@49: arma_ignore(X); Chris@49: arma_ignore(mode); Chris@49: arma_stop("svd(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::svd_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, const Base< std::complex, T1>& X, const char mode) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: typedef std::complex eT; Chris@49: Chris@49: Mat A(X.get_ref()); Chris@49: Chris@49: blas_int m = blas_int(A.n_rows); Chris@49: blas_int n = blas_int(A.n_cols); Chris@49: blas_int min_mn = (std::min)(m,n); Chris@49: blas_int lda = blas_int(A.n_rows); Chris@49: Chris@49: S.set_size( static_cast(min_mn) ); Chris@49: Chris@49: blas_int ldu = 0; Chris@49: blas_int ldvt = 0; Chris@49: Chris@49: char jobu; Chris@49: char jobvt; Chris@49: Chris@49: switch(mode) Chris@49: { Chris@49: case 'l': Chris@49: jobu = 'S'; Chris@49: jobvt = 'N'; Chris@49: Chris@49: ldu = m; Chris@49: ldvt = 1; Chris@49: Chris@49: U.set_size( static_cast(ldu), static_cast(min_mn) ); Chris@49: V.reset(); Chris@49: Chris@49: break; Chris@49: Chris@49: Chris@49: case 'r': Chris@49: jobu = 'N'; Chris@49: jobvt = 'S'; Chris@49: Chris@49: ldu = 1; Chris@49: ldvt = (std::min)(m,n); Chris@49: Chris@49: U.reset(); Chris@49: V.set_size( static_cast(ldvt), static_cast(n) ); Chris@49: Chris@49: break; Chris@49: Chris@49: Chris@49: case 'b': Chris@49: jobu = 'S'; Chris@49: jobvt = 'S'; Chris@49: Chris@49: ldu = m; Chris@49: ldvt = (std::min)(m,n); Chris@49: Chris@49: U.set_size( static_cast(ldu), static_cast(min_mn) ); Chris@49: V.set_size( static_cast(ldvt), static_cast(n) ); Chris@49: Chris@49: break; Chris@49: Chris@49: Chris@49: default: Chris@49: U.reset(); Chris@49: S.reset(); Chris@49: V.reset(); Chris@49: return false; Chris@49: } Chris@49: Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: U.eye(); Chris@49: S.reset(); Chris@49: V.eye(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: Chris@49: podarray work( static_cast(lwork ) ); Chris@49: podarray rwork( static_cast(5*min_mn) ); Chris@49: Chris@49: // let gesvd_() calculate the optimum size of the workspace Chris@49: blas_int lwork_tmp = -1; Chris@49: Chris@49: lapack::cx_gesvd Chris@49: ( Chris@49: &jobu, &jobvt, Chris@49: &m, &n, Chris@49: A.memptr(), &lda, Chris@49: S.memptr(), Chris@49: U.memptr(), &ldu, Chris@49: V.memptr(), &ldvt, Chris@49: work.memptr(), &lwork_tmp, Chris@49: rwork.memptr(), Chris@49: &info Chris@49: ); Chris@49: Chris@49: if(info == 0) Chris@49: { Chris@49: blas_int proposed_lwork = static_cast(real(work[0])); Chris@49: if(proposed_lwork > lwork) Chris@49: { Chris@49: lwork = proposed_lwork; Chris@49: work.set_size( static_cast(lwork) ); Chris@49: } Chris@49: Chris@49: lapack::cx_gesvd Chris@49: ( Chris@49: &jobu, &jobvt, Chris@49: &m, &n, Chris@49: A.memptr(), &lda, Chris@49: S.memptr(), Chris@49: U.memptr(), &ldu, Chris@49: V.memptr(), &ldvt, Chris@49: work.memptr(), &lwork, Chris@49: rwork.memptr(), Chris@49: &info Chris@49: ); Chris@49: Chris@49: op_htrans::apply(V,V); // op_strans will work out that an in-place transpose can be done Chris@49: } Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(U); Chris@49: arma_ignore(S); Chris@49: arma_ignore(V); Chris@49: arma_ignore(X); Chris@49: arma_ignore(mode); Chris@49: arma_stop("svd(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::svd_dc(Mat& U, Col& S, Mat& V, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: Mat A(X.get_ref()); Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: U.eye(A.n_rows, A.n_rows); Chris@49: S.reset(); Chris@49: V.eye(A.n_cols, A.n_cols); Chris@49: return true; Chris@49: } Chris@49: Chris@49: U.set_size(A.n_rows, A.n_rows); Chris@49: V.set_size(A.n_cols, A.n_cols); Chris@49: Chris@49: char jobz = 'A'; Chris@49: Chris@49: blas_int m = blas_int(A.n_rows); Chris@49: blas_int n = blas_int(A.n_cols); Chris@49: blas_int min_mn = (std::min)(m,n); Chris@49: blas_int lda = blas_int(A.n_rows); Chris@49: blas_int ldu = blas_int(U.n_rows); Chris@49: blas_int ldvt = blas_int(V.n_rows); Chris@49: blas_int lwork = 3 * ( 3*min_mn*min_mn + (std::max)( (std::max)(m,n), 4*min_mn*min_mn + 4*min_mn ) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: S.set_size( static_cast(min_mn) ); Chris@49: Chris@49: podarray work( static_cast(lwork ) ); Chris@49: podarray iwork( static_cast(8*min_mn) ); Chris@49: Chris@49: lapack::gesdd Chris@49: ( Chris@49: &jobz, &m, &n, Chris@49: A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, Chris@49: work.memptr(), &lwork, iwork.memptr(), &info Chris@49: ); Chris@49: Chris@49: op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(U); Chris@49: arma_ignore(S); Chris@49: arma_ignore(V); Chris@49: arma_ignore(X); Chris@49: arma_stop("svd(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::svd_dc(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, const Base< std::complex, T1>& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: typedef std::complex eT; Chris@49: Chris@49: Mat A(X.get_ref()); Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: U.eye(A.n_rows, A.n_rows); Chris@49: S.reset(); Chris@49: V.eye(A.n_cols, A.n_cols); Chris@49: return true; Chris@49: } Chris@49: Chris@49: U.set_size(A.n_rows, A.n_rows); Chris@49: V.set_size(A.n_cols, A.n_cols); Chris@49: Chris@49: char jobz = 'A'; Chris@49: Chris@49: blas_int m = blas_int(A.n_rows); Chris@49: blas_int n = blas_int(A.n_cols); Chris@49: blas_int min_mn = (std::min)(m,n); Chris@49: blas_int lda = blas_int(A.n_rows); Chris@49: blas_int ldu = blas_int(U.n_rows); Chris@49: blas_int ldvt = blas_int(V.n_rows); Chris@49: blas_int lwork = 3 * (min_mn*min_mn + 2*min_mn + (std::max)(m,n)); Chris@49: blas_int info = 0; Chris@49: Chris@49: S.set_size( static_cast(min_mn) ); Chris@49: Chris@49: podarray work( static_cast(lwork ) ); Chris@49: podarray rwork( static_cast(5*min_mn*min_mn + 7*min_mn) ); Chris@49: podarray iwork( static_cast(8*min_mn ) ); Chris@49: Chris@49: lapack::cx_gesdd Chris@49: ( Chris@49: &jobz, &m, &n, Chris@49: A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, Chris@49: work.memptr(), &lwork, rwork.memptr(), iwork.memptr(), &info Chris@49: ); Chris@49: Chris@49: op_htrans::apply(V,V); // op_htrans will work out that an in-place transpose can be done Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(U); Chris@49: arma_ignore(S); Chris@49: arma_ignore(V); Chris@49: arma_ignore(X); Chris@49: arma_stop("svd(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! Solve a system of linear equations. Chris@49: //! Assumes that A.n_rows = A.n_cols and B.n_rows = A.n_rows Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::solve(Mat& out, Mat& A, const Base& X, const bool slow) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: bool status = false; Chris@49: Chris@49: const uword A_n_rows = A.n_rows; Chris@49: Chris@49: if( (A_n_rows <= 4) && (slow == false) ) Chris@49: { Chris@49: Mat A_inv; Chris@49: Chris@49: status = auxlib::inv_noalias_tinymat(A_inv, A, A_n_rows); Chris@49: Chris@49: if(status == true) Chris@49: { Chris@49: const unwrap_check Y( X.get_ref(), out ); Chris@49: const Mat& B = Y.M; Chris@49: Chris@49: const uword B_n_rows = B.n_rows; Chris@49: const uword B_n_cols = B.n_cols; Chris@49: Chris@49: arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given objects must be the same" ); Chris@49: Chris@49: if(A.is_empty() || B.is_empty()) Chris@49: { Chris@49: out.zeros(A.n_cols, B_n_cols); Chris@49: return true; Chris@49: } Chris@49: Chris@49: out.set_size(A_n_rows, B_n_cols); Chris@49: Chris@49: gemm_emul::apply(out, A_inv, B); Chris@49: Chris@49: return true; Chris@49: } Chris@49: } Chris@49: Chris@49: if( (A_n_rows > 4) || (status == false) ) Chris@49: { Chris@49: out = X.get_ref(); Chris@49: Chris@49: const uword B_n_rows = out.n_rows; Chris@49: const uword B_n_cols = out.n_cols; Chris@49: Chris@49: arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given objects must be the same" ); Chris@49: Chris@49: if(A.is_empty() || out.is_empty()) Chris@49: { Chris@49: out.zeros(A.n_cols, B_n_cols); Chris@49: return true; Chris@49: } Chris@49: Chris@49: #if defined(ARMA_USE_ATLAS) Chris@49: { Chris@49: podarray ipiv(A_n_rows + 2); // +2 for paranoia: old versions of Atlas might be trashing memory Chris@49: Chris@49: int info = atlas::clapack_gesv(atlas::CblasColMajor, A_n_rows, B_n_cols, A.memptr(), A_n_rows, ipiv.memptr(), out.memptr(), A_n_rows); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #elif defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: blas_int n = blas_int(A_n_rows); // assuming A is square Chris@49: blas_int lda = blas_int(A_n_rows); Chris@49: blas_int ldb = blas_int(A_n_rows); Chris@49: blas_int nrhs = blas_int(B_n_cols); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray ipiv(A_n_rows + 2); // +2 for paranoia: some versions of Lapack might be trashing memory Chris@49: Chris@49: arma_extra_debug_print("lapack::gesv()"); Chris@49: lapack::gesv(&n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); Chris@49: Chris@49: arma_extra_debug_print("lapack::gesv() -- finished"); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_stop("solve(): use of ATLAS or LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: return true; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! Solve an over-determined system. Chris@49: //! Assumes that A.n_rows > A.n_cols and B.n_rows = A.n_rows Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::solve_od(Mat& out, Mat& A, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: Mat tmp = X.get_ref(); Chris@49: Chris@49: const uword A_n_rows = A.n_rows; Chris@49: const uword A_n_cols = A.n_cols; Chris@49: Chris@49: const uword B_n_rows = tmp.n_rows; Chris@49: const uword B_n_cols = tmp.n_cols; Chris@49: Chris@49: arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given objects must be the same" ); Chris@49: Chris@49: out.set_size(A_n_cols, B_n_cols); Chris@49: Chris@49: if(A.is_empty() || tmp.is_empty()) Chris@49: { Chris@49: out.zeros(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: char trans = 'N'; Chris@49: Chris@49: blas_int m = blas_int(A_n_rows); Chris@49: blas_int n = blas_int(A_n_cols); Chris@49: blas_int lda = blas_int(A_n_rows); Chris@49: blas_int ldb = blas_int(A_n_rows); Chris@49: blas_int nrhs = blas_int(B_n_cols); Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), n + (std::max)(n, nrhs)) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: Chris@49: // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems Chris@49: arma_extra_debug_print("lapack::gels()"); Chris@49: lapack::gels( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, work.memptr(), &lwork, &info ); Chris@49: Chris@49: arma_extra_debug_print("lapack::gels() -- finished"); Chris@49: Chris@49: for(uword col=0; col Chris@49: inline Chris@49: bool Chris@49: auxlib::solve_ud(Mat& out, Mat& A, const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: // TODO: this function provides the same results as Octave 3.4.2. Chris@49: // TODO: however, these results are different than Matlab 7.12.0.635. Chris@49: // TODO: figure out whether both Octave and Matlab are correct, or only one of them Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: const unwrap Y( X.get_ref() ); Chris@49: const Mat& B = Y.M; Chris@49: Chris@49: const uword A_n_rows = A.n_rows; Chris@49: const uword A_n_cols = A.n_cols; Chris@49: Chris@49: const uword B_n_rows = B.n_rows; Chris@49: const uword B_n_cols = B.n_cols; Chris@49: Chris@49: arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given objects must be the same" ); Chris@49: Chris@49: // B could be an alias of "out", hence we need to check whether B is empty before setting the size of "out" Chris@49: if(A.is_empty() || B.is_empty()) Chris@49: { Chris@49: out.zeros(A_n_cols, B_n_cols); Chris@49: return true; Chris@49: } Chris@49: Chris@49: char trans = 'N'; Chris@49: Chris@49: blas_int m = blas_int(A_n_rows); Chris@49: blas_int n = blas_int(A_n_cols); Chris@49: blas_int lda = blas_int(A_n_rows); Chris@49: blas_int ldb = blas_int(A_n_cols); Chris@49: blas_int nrhs = blas_int(B_n_cols); Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), m + (std::max)(m,nrhs)) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: Mat tmp(A_n_cols, B_n_cols); Chris@49: tmp.zeros(); Chris@49: Chris@49: for(uword col=0; col work( static_cast(lwork) ); Chris@49: Chris@49: // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems Chris@49: arma_extra_debug_print("lapack::gels()"); Chris@49: lapack::gels( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, work.memptr(), &lwork, &info ); Chris@49: Chris@49: arma_extra_debug_print("lapack::gels() -- finished"); Chris@49: Chris@49: out.set_size(A_n_cols, B_n_cols); Chris@49: Chris@49: for(uword col=0; col Chris@49: inline Chris@49: bool Chris@49: auxlib::solve_tr(Mat& out, const Mat& A, const Mat& B, const uword layout) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: if(A.is_empty() || B.is_empty()) Chris@49: { Chris@49: out.zeros(A.n_cols, B.n_cols); Chris@49: return true; Chris@49: } Chris@49: Chris@49: out = B; Chris@49: Chris@49: char uplo = (layout == 0) ? 'U' : 'L'; Chris@49: char trans = 'N'; Chris@49: char diag = 'N'; Chris@49: blas_int n = blas_int(A.n_rows); Chris@49: blas_int nrhs = blas_int(B.n_cols); Chris@49: blas_int info = 0; Chris@49: Chris@49: lapack::trtrs(&uplo, &trans, &diag, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(out); Chris@49: arma_ignore(A); Chris@49: arma_ignore(B); Chris@49: arma_ignore(layout); Chris@49: arma_stop("solve(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: // Chris@49: // Schur decomposition Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::schur_dec(Mat& Z, Mat& T, const Mat& A) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: arma_debug_check( (A.is_square() == false), "schur_dec(): given matrix is not square" ); Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: Z.reset(); Chris@49: T.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: const uword A_n_rows = A.n_rows; Chris@49: Chris@49: Z.set_size(A_n_rows, A_n_rows); Chris@49: T = A; Chris@49: Chris@49: char jobvs = 'V'; // get Schur vectors (Z) Chris@49: char sort = 'N'; // do not sort eigenvalues/vectors Chris@49: blas_int* select = 0; // pointer to sorting function Chris@49: blas_int n = blas_int(A_n_rows); Chris@49: blas_int sdim = 0; // output for sorting Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), 3*n) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: podarray bwork(A_n_rows); Chris@49: Chris@49: podarray wr(A_n_rows); // output for eigenvalues Chris@49: podarray wi(A_n_rows); // output for eigenvalues Chris@49: Chris@49: lapack::gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, wr.memptr(), wi.memptr(), Z.memptr(), &n, work.memptr(), &lwork, bwork.memptr(), &info); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(Z); Chris@49: arma_ignore(T); Chris@49: arma_ignore(A); Chris@49: Chris@49: arma_stop("schur_dec(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::schur_dec(Mat >& Z, Mat >& T, const Mat >& A) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: arma_debug_check( (A.is_square() == false), "schur_dec(): matrix A is not square" ); Chris@49: Chris@49: if(A.is_empty()) Chris@49: { Chris@49: Z.reset(); Chris@49: T.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: typedef std::complex eT; Chris@49: Chris@49: const uword A_n_rows = A.n_rows; Chris@49: Chris@49: Z.set_size(A_n_rows, A_n_rows); Chris@49: T = A; Chris@49: Chris@49: char jobvs = 'V'; // get Schur vectors (Z) Chris@49: char sort = 'N'; // do not sort eigenvalues/vectors Chris@49: blas_int* select = 0; // pointer to sorting function Chris@49: blas_int n = blas_int(A_n_rows); Chris@49: blas_int sdim = 0; // output for sorting Chris@49: blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*n) ); Chris@49: blas_int info = 0; Chris@49: Chris@49: podarray work( static_cast(lwork) ); Chris@49: podarray bwork(A_n_rows); Chris@49: Chris@49: podarray w(A_n_rows); // output for eigenvalues Chris@49: podarray rwork(A_n_rows); Chris@49: Chris@49: lapack::cx_gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, w.memptr(), Z.memptr(), &n, work.memptr(), &lwork, rwork.memptr(), bwork.memptr(), &info); Chris@49: Chris@49: return (info == 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_ignore(Z); Chris@49: arma_ignore(T); Chris@49: arma_ignore(A); Chris@49: Chris@49: arma_stop("schur_dec(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: // Chris@49: // syl (solution of the Sylvester equation AX + XB = C) Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::syl(Mat& X, const Mat& A, const Mat& B, const Mat& C) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: arma_debug_check Chris@49: ( Chris@49: (A.is_square() == false) || (B.is_square() == false), Chris@49: "syl(): given matrix is not square" Chris@49: ); Chris@49: Chris@49: arma_debug_check Chris@49: ( Chris@49: (C.n_rows != A.n_rows) || (C.n_cols != B.n_cols), Chris@49: "syl(): matrices are not conformant" Chris@49: ); Chris@49: Chris@49: if(A.is_empty() || B.is_empty() || C.is_empty()) Chris@49: { Chris@49: X.reset(); Chris@49: return true; Chris@49: } Chris@49: Chris@49: #if defined(ARMA_USE_LAPACK) Chris@49: { Chris@49: Mat Z1, Z2, T1, T2; Chris@49: Chris@49: const bool status_sd1 = auxlib::schur_dec(Z1, T1, A); Chris@49: const bool status_sd2 = auxlib::schur_dec(Z2, T2, B); Chris@49: Chris@49: if( (status_sd1 == false) || (status_sd2 == false) ) Chris@49: { Chris@49: return false; Chris@49: } Chris@49: Chris@49: char trana = 'N'; Chris@49: char tranb = 'N'; Chris@49: blas_int isgn = +1; Chris@49: blas_int m = blas_int(T1.n_rows); Chris@49: blas_int n = blas_int(T2.n_cols); Chris@49: Chris@49: eT scale = eT(0); Chris@49: blas_int info = 0; Chris@49: Chris@49: Mat Y = trans(Z1) * C * Z2; Chris@49: Chris@49: lapack::trsyl(&trana, &tranb, &isgn, &m, &n, T1.memptr(), &m, T2.memptr(), &n, Y.memptr(), &m, &scale, &info); Chris@49: Chris@49: //Y /= scale; Chris@49: Y /= (-scale); Chris@49: Chris@49: X = Z1 * Y * trans(Z2); Chris@49: Chris@49: return (info >= 0); Chris@49: } Chris@49: #else Chris@49: { Chris@49: arma_stop("syl(): use of LAPACK needs to be enabled"); Chris@49: return false; Chris@49: } Chris@49: #endif Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: // Chris@49: // lyap (solution of the continuous Lyapunov equation AX + XA^H + Q = 0) Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::lyap(Mat& X, const Mat& A, const Mat& Q) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: arma_debug_check( (A.is_square() == false), "lyap(): matrix A is not square"); Chris@49: arma_debug_check( (Q.is_square() == false), "lyap(): matrix Q is not square"); Chris@49: arma_debug_check( (A.n_rows != Q.n_rows), "lyap(): matrices A and Q have different dimensions"); Chris@49: Chris@49: Mat htransA; Chris@49: op_htrans::apply_noalias(htransA, A); Chris@49: Chris@49: const Mat mQ = -Q; Chris@49: Chris@49: return auxlib::syl(X, A, htransA, mQ); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: // Chris@49: // dlyap (solution of the discrete Lyapunov equation AXA^H - X + Q = 0) Chris@49: Chris@49: template Chris@49: inline Chris@49: bool Chris@49: auxlib::dlyap(Mat& X, const Mat& A, const Mat& Q) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: arma_debug_check( (A.is_square() == false), "dlyap(): matrix A is not square"); Chris@49: arma_debug_check( (Q.is_square() == false), "dlyap(): matrix Q is not square"); Chris@49: arma_debug_check( (A.n_rows != Q.n_rows), "dlyap(): matrices A and Q have different dimensions"); Chris@49: Chris@49: const Col vecQ = reshape(Q, Q.n_elem, 1); Chris@49: Chris@49: const Mat M = eye< Mat >(Q.n_elem, Q.n_elem) - kron(conj(A), A); Chris@49: Chris@49: Col vecX; Chris@49: Chris@49: const bool status = solve(vecX, M, vecQ); Chris@49: Chris@49: if(status == true) Chris@49: { Chris@49: X = reshape(vecX, Q.n_rows, Q.n_cols); Chris@49: return true; Chris@49: } Chris@49: else Chris@49: { Chris@49: X.reset(); Chris@49: return false; Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! @}