annotate armadillo-2.4.4/include/armadillo_bits/auxlib_meat.hpp @ 18:8d046a9d36aa slimline

Back out rev 13:ac07c60aa798. Like an idiot, I committed a whole pile of unrelated changes in the guise of a single typo fix. Will re-commit in stages
author Chris Cannam
date Thu, 10 May 2012 10:45:44 +0100
parents 8b6102e2a9b0
children
rev   line source
max@0 1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
max@0 2 // Copyright (C) 2008-2012 Conrad Sanderson
max@0 3 // Copyright (C) 2009 Edmund Highcock
max@0 4 // Copyright (C) 2011 James Sanders
max@0 5 // Copyright (C) 2011 Stanislav Funiak
max@0 6 //
max@0 7 // This file is part of the Armadillo C++ library.
max@0 8 // It is provided without any warranty of fitness
max@0 9 // for any purpose. You can redistribute this file
max@0 10 // and/or modify it under the terms of the GNU
max@0 11 // Lesser General Public License (LGPL) as published
max@0 12 // by the Free Software Foundation, either version 3
max@0 13 // of the License or (at your option) any later version.
max@0 14 // (see http://www.opensource.org/licenses for more info)
max@0 15
max@0 16
max@0 17 //! \addtogroup auxlib
max@0 18 //! @{
max@0 19
max@0 20
max@0 21
max@0 22 //! immediate matrix inverse
max@0 23 template<typename eT, typename T1>
max@0 24 inline
max@0 25 bool
max@0 26 auxlib::inv(Mat<eT>& out, const Base<eT,T1>& X, const bool slow)
max@0 27 {
max@0 28 arma_extra_debug_sigprint();
max@0 29
max@0 30 out = X.get_ref();
max@0 31
max@0 32 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
max@0 33
max@0 34 bool status = false;
max@0 35
max@0 36 const uword N = out.n_rows;
max@0 37
max@0 38 if( (N <= 4) && (slow == false) )
max@0 39 {
max@0 40 status = auxlib::inv_inplace_tinymat(out, N);
max@0 41 }
max@0 42
max@0 43 if( (N > 4) || (status == false) )
max@0 44 {
max@0 45 status = auxlib::inv_inplace_lapack(out);
max@0 46 }
max@0 47
max@0 48 return status;
max@0 49 }
max@0 50
max@0 51
max@0 52
max@0 53 template<typename eT>
max@0 54 inline
max@0 55 bool
max@0 56 auxlib::inv(Mat<eT>& out, const Mat<eT>& X, const bool slow)
max@0 57 {
max@0 58 arma_extra_debug_sigprint();
max@0 59
max@0 60 arma_debug_check( (X.is_square() == false), "inv(): given matrix is not square" );
max@0 61
max@0 62 bool status = false;
max@0 63
max@0 64 const uword N = X.n_rows;
max@0 65
max@0 66 if( (N <= 4) && (slow == false) )
max@0 67 {
max@0 68 status = (&out != &X) ? auxlib::inv_noalias_tinymat(out, X, N) : auxlib::inv_inplace_tinymat(out, N);
max@0 69 }
max@0 70
max@0 71 if( (N > 4) || (status == false) )
max@0 72 {
max@0 73 out = X;
max@0 74 status = auxlib::inv_inplace_lapack(out);
max@0 75 }
max@0 76
max@0 77 return status;
max@0 78 }
max@0 79
max@0 80
max@0 81
max@0 82 template<typename eT>
max@0 83 inline
max@0 84 bool
max@0 85 auxlib::inv_noalias_tinymat(Mat<eT>& out, const Mat<eT>& X, const uword N)
max@0 86 {
max@0 87 arma_extra_debug_sigprint();
max@0 88
max@0 89 bool det_ok = true;
max@0 90
max@0 91 out.set_size(N,N);
max@0 92
max@0 93 switch(N)
max@0 94 {
max@0 95 case 1:
max@0 96 {
max@0 97 out[0] = eT(1) / X[0];
max@0 98 };
max@0 99 break;
max@0 100
max@0 101 case 2:
max@0 102 {
max@0 103 const eT* Xm = X.memptr();
max@0 104
max@0 105 const eT a = Xm[pos<0,0>::n2];
max@0 106 const eT b = Xm[pos<0,1>::n2];
max@0 107 const eT c = Xm[pos<1,0>::n2];
max@0 108 const eT d = Xm[pos<1,1>::n2];
max@0 109
max@0 110 const eT tmp_det = (a*d - b*c);
max@0 111
max@0 112 if(tmp_det != eT(0))
max@0 113 {
max@0 114 eT* outm = out.memptr();
max@0 115
max@0 116 outm[pos<0,0>::n2] = d / tmp_det;
max@0 117 outm[pos<0,1>::n2] = -b / tmp_det;
max@0 118 outm[pos<1,0>::n2] = -c / tmp_det;
max@0 119 outm[pos<1,1>::n2] = a / tmp_det;
max@0 120 }
max@0 121 else
max@0 122 {
max@0 123 det_ok = false;
max@0 124 }
max@0 125 };
max@0 126 break;
max@0 127
max@0 128 case 3:
max@0 129 {
max@0 130 const eT* X_col0 = X.colptr(0);
max@0 131 const eT a11 = X_col0[0];
max@0 132 const eT a21 = X_col0[1];
max@0 133 const eT a31 = X_col0[2];
max@0 134
max@0 135 const eT* X_col1 = X.colptr(1);
max@0 136 const eT a12 = X_col1[0];
max@0 137 const eT a22 = X_col1[1];
max@0 138 const eT a32 = X_col1[2];
max@0 139
max@0 140 const eT* X_col2 = X.colptr(2);
max@0 141 const eT a13 = X_col2[0];
max@0 142 const eT a23 = X_col2[1];
max@0 143 const eT a33 = X_col2[2];
max@0 144
max@0 145 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13);
max@0 146
max@0 147 if(tmp_det != eT(0))
max@0 148 {
max@0 149 eT* out_col0 = out.colptr(0);
max@0 150 out_col0[0] = (a33*a22 - a32*a23) / tmp_det;
max@0 151 out_col0[1] = -(a33*a21 - a31*a23) / tmp_det;
max@0 152 out_col0[2] = (a32*a21 - a31*a22) / tmp_det;
max@0 153
max@0 154 eT* out_col1 = out.colptr(1);
max@0 155 out_col1[0] = -(a33*a12 - a32*a13) / tmp_det;
max@0 156 out_col1[1] = (a33*a11 - a31*a13) / tmp_det;
max@0 157 out_col1[2] = -(a32*a11 - a31*a12) / tmp_det;
max@0 158
max@0 159 eT* out_col2 = out.colptr(2);
max@0 160 out_col2[0] = (a23*a12 - a22*a13) / tmp_det;
max@0 161 out_col2[1] = -(a23*a11 - a21*a13) / tmp_det;
max@0 162 out_col2[2] = (a22*a11 - a21*a12) / tmp_det;
max@0 163 }
max@0 164 else
max@0 165 {
max@0 166 det_ok = false;
max@0 167 }
max@0 168 };
max@0 169 break;
max@0 170
max@0 171 case 4:
max@0 172 {
max@0 173 const eT tmp_det = det(X);
max@0 174
max@0 175 if(tmp_det != eT(0))
max@0 176 {
max@0 177 const eT* Xm = X.memptr();
max@0 178 eT* outm = out.memptr();
max@0 179
max@0 180 outm[pos<0,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
max@0 181 outm[pos<1,0>::n4] = ( Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
max@0 182 outm[pos<2,0>::n4] = ( Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
max@0 183 outm[pos<3,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
max@0 184
max@0 185 outm[pos<0,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
max@0 186 outm[pos<1,1>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
max@0 187 outm[pos<2,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
max@0 188 outm[pos<3,1>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
max@0 189
max@0 190 outm[pos<0,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
max@0 191 outm[pos<1,2>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
max@0 192 outm[pos<2,2>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
max@0 193 outm[pos<3,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
max@0 194
max@0 195 outm[pos<0,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
max@0 196 outm[pos<1,3>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
max@0 197 outm[pos<2,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
max@0 198 outm[pos<3,3>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] ) / tmp_det;
max@0 199 }
max@0 200 else
max@0 201 {
max@0 202 det_ok = false;
max@0 203 }
max@0 204 };
max@0 205 break;
max@0 206
max@0 207 default:
max@0 208 ;
max@0 209 }
max@0 210
max@0 211 return det_ok;
max@0 212 }
max@0 213
max@0 214
max@0 215
max@0 216 template<typename eT>
max@0 217 inline
max@0 218 bool
max@0 219 auxlib::inv_inplace_tinymat(Mat<eT>& X, const uword N)
max@0 220 {
max@0 221 arma_extra_debug_sigprint();
max@0 222
max@0 223 bool det_ok = true;
max@0 224
max@0 225 // for more info, see:
max@0 226 // http://www.dr-lex.34sp.com/random/matrix_inv.html
max@0 227 // http://www.cvl.iis.u-tokyo.ac.jp/~miyazaki/tech/teche23.html
max@0 228 // http://www.euclideanspace.com/maths/algebra/matrix/functions/inverse/fourD/index.htm
max@0 229 // http://www.geometrictools.com//LibFoundation/Mathematics/Wm4Matrix4.inl
max@0 230
max@0 231 switch(N)
max@0 232 {
max@0 233 case 1:
max@0 234 {
max@0 235 X[0] = eT(1) / X[0];
max@0 236 };
max@0 237 break;
max@0 238
max@0 239 case 2:
max@0 240 {
max@0 241 const eT a = X[pos<0,0>::n2];
max@0 242 const eT b = X[pos<0,1>::n2];
max@0 243 const eT c = X[pos<1,0>::n2];
max@0 244 const eT d = X[pos<1,1>::n2];
max@0 245
max@0 246 const eT tmp_det = (a*d - b*c);
max@0 247
max@0 248 if(tmp_det != eT(0))
max@0 249 {
max@0 250 X[pos<0,0>::n2] = d / tmp_det;
max@0 251 X[pos<0,1>::n2] = -b / tmp_det;
max@0 252 X[pos<1,0>::n2] = -c / tmp_det;
max@0 253 X[pos<1,1>::n2] = a / tmp_det;
max@0 254 }
max@0 255 else
max@0 256 {
max@0 257 det_ok = false;
max@0 258 }
max@0 259 };
max@0 260 break;
max@0 261
max@0 262 case 3:
max@0 263 {
max@0 264 eT* X_col0 = X.colptr(0);
max@0 265 eT* X_col1 = X.colptr(1);
max@0 266 eT* X_col2 = X.colptr(2);
max@0 267
max@0 268 const eT a11 = X_col0[0];
max@0 269 const eT a21 = X_col0[1];
max@0 270 const eT a31 = X_col0[2];
max@0 271
max@0 272 const eT a12 = X_col1[0];
max@0 273 const eT a22 = X_col1[1];
max@0 274 const eT a32 = X_col1[2];
max@0 275
max@0 276 const eT a13 = X_col2[0];
max@0 277 const eT a23 = X_col2[1];
max@0 278 const eT a33 = X_col2[2];
max@0 279
max@0 280 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13);
max@0 281
max@0 282 if(tmp_det != eT(0))
max@0 283 {
max@0 284 X_col0[0] = (a33*a22 - a32*a23) / tmp_det;
max@0 285 X_col0[1] = -(a33*a21 - a31*a23) / tmp_det;
max@0 286 X_col0[2] = (a32*a21 - a31*a22) / tmp_det;
max@0 287
max@0 288 X_col1[0] = -(a33*a12 - a32*a13) / tmp_det;
max@0 289 X_col1[1] = (a33*a11 - a31*a13) / tmp_det;
max@0 290 X_col1[2] = -(a32*a11 - a31*a12) / tmp_det;
max@0 291
max@0 292 X_col2[0] = (a23*a12 - a22*a13) / tmp_det;
max@0 293 X_col2[1] = -(a23*a11 - a21*a13) / tmp_det;
max@0 294 X_col2[2] = (a22*a11 - a21*a12) / tmp_det;
max@0 295 }
max@0 296 else
max@0 297 {
max@0 298 det_ok = false;
max@0 299 }
max@0 300 };
max@0 301 break;
max@0 302
max@0 303 case 4:
max@0 304 {
max@0 305 const eT tmp_det = det(X);
max@0 306
max@0 307 if(tmp_det != eT(0))
max@0 308 {
max@0 309 const Mat<eT> A(X);
max@0 310
max@0 311 const eT* Am = A.memptr();
max@0 312 eT* Xm = X.memptr();
max@0 313
max@0 314 Xm[pos<0,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] - Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] + Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
max@0 315 Xm[pos<1,0>::n4] = ( Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
max@0 316 Xm[pos<2,0>::n4] = ( Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
max@0 317 Xm[pos<3,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
max@0 318
max@0 319 Xm[pos<0,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] + Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] - Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
max@0 320 Xm[pos<1,1>::n4] = ( Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
max@0 321 Xm[pos<2,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
max@0 322 Xm[pos<3,1>::n4] = ( Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
max@0 323
max@0 324 Xm[pos<0,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] + Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
max@0 325 Xm[pos<1,2>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
max@0 326 Xm[pos<2,2>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
max@0 327 Xm[pos<3,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
max@0 328
max@0 329 Xm[pos<0,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] + Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
max@0 330 Xm[pos<1,3>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
max@0 331 Xm[pos<2,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
max@0 332 Xm[pos<3,3>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] ) / tmp_det;
max@0 333 }
max@0 334 else
max@0 335 {
max@0 336 det_ok = false;
max@0 337 }
max@0 338 };
max@0 339 break;
max@0 340
max@0 341 default:
max@0 342 ;
max@0 343 }
max@0 344
max@0 345 return det_ok;
max@0 346 }
max@0 347
max@0 348
max@0 349
max@0 350 template<typename eT>
max@0 351 inline
max@0 352 bool
max@0 353 auxlib::inv_inplace_lapack(Mat<eT>& out)
max@0 354 {
max@0 355 arma_extra_debug_sigprint();
max@0 356
max@0 357 if(out.is_empty())
max@0 358 {
max@0 359 return true;
max@0 360 }
max@0 361
max@0 362 #if defined(ARMA_USE_ATLAS)
max@0 363 {
max@0 364 podarray<int> ipiv(out.n_rows);
max@0 365
max@0 366 int info = atlas::clapack_getrf(atlas::CblasColMajor, out.n_rows, out.n_cols, out.memptr(), out.n_rows, ipiv.memptr());
max@0 367
max@0 368 if(info == 0)
max@0 369 {
max@0 370 info = atlas::clapack_getri(atlas::CblasColMajor, out.n_rows, out.memptr(), out.n_rows, ipiv.memptr());
max@0 371 }
max@0 372
max@0 373 return (info == 0);
max@0 374 }
max@0 375 #elif defined(ARMA_USE_LAPACK)
max@0 376 {
max@0 377 blas_int n_rows = out.n_rows;
max@0 378 blas_int n_cols = out.n_cols;
max@0 379 blas_int info = 0;
max@0 380
max@0 381 podarray<blas_int> ipiv(out.n_rows);
max@0 382
max@0 383 // 84 was empirically found -- it is the maximum value suggested by LAPACK (as provided by ATLAS v3.6)
max@0 384 // based on tests with various matrix types on 32-bit and 64-bit machines
max@0 385 //
max@0 386 // the "work" array is deliberately long so that a secondary (time-consuming)
max@0 387 // memory allocation is avoided, if possible
max@0 388
max@0 389 blas_int work_len = (std::max)(blas_int(1), n_rows*84);
max@0 390 podarray<eT> work( static_cast<uword>(work_len) );
max@0 391
max@0 392 lapack::getrf(&n_rows, &n_cols, out.memptr(), &n_rows, ipiv.memptr(), &info);
max@0 393
max@0 394 if(info == 0)
max@0 395 {
max@0 396 // query for optimum size of work_len
max@0 397
max@0 398 blas_int work_len_tmp = -1;
max@0 399 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &work_len_tmp, &info);
max@0 400
max@0 401 if(info == 0)
max@0 402 {
max@0 403 blas_int proposed_work_len = static_cast<blas_int>(access::tmp_real(work[0]));
max@0 404
max@0 405 // if necessary, allocate more memory
max@0 406 if(work_len < proposed_work_len)
max@0 407 {
max@0 408 work_len = proposed_work_len;
max@0 409 work.set_size( static_cast<uword>(work_len) );
max@0 410 }
max@0 411 }
max@0 412
max@0 413 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &work_len, &info);
max@0 414 }
max@0 415
max@0 416 return (info == 0);
max@0 417 }
max@0 418 #else
max@0 419 {
max@0 420 arma_ignore(out);
max@0 421 arma_stop("inv(): use of ATLAS or LAPACK needs to be enabled");
max@0 422 return false;
max@0 423 }
max@0 424 #endif
max@0 425 }
max@0 426
max@0 427
max@0 428
max@0 429 template<typename eT, typename T1>
max@0 430 inline
max@0 431 bool
max@0 432 auxlib::inv_tr(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
max@0 433 {
max@0 434 arma_extra_debug_sigprint();
max@0 435
max@0 436 out = X.get_ref();
max@0 437
max@0 438 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
max@0 439
max@0 440 if(out.is_empty())
max@0 441 {
max@0 442 return true;
max@0 443 }
max@0 444
max@0 445 bool status;
max@0 446
max@0 447 #if defined(ARMA_USE_LAPACK)
max@0 448 {
max@0 449 char uplo = (layout == 0) ? 'U' : 'L';
max@0 450 char diag = 'N';
max@0 451 blas_int n = blas_int(out.n_rows);
max@0 452 blas_int info = 0;
max@0 453
max@0 454 lapack::trtri(&uplo, &diag, &n, out.memptr(), &n, &info);
max@0 455
max@0 456 status = (info == 0);
max@0 457 }
max@0 458 #else
max@0 459 {
max@0 460 arma_ignore(layout);
max@0 461 arma_stop("inv(): use of LAPACK needs to be enabled");
max@0 462 status = false;
max@0 463 }
max@0 464 #endif
max@0 465
max@0 466
max@0 467 if(status == true)
max@0 468 {
max@0 469 if(layout == 0)
max@0 470 {
max@0 471 // upper triangular
max@0 472 out = trimatu(out);
max@0 473 }
max@0 474 else
max@0 475 {
max@0 476 // lower triangular
max@0 477 out = trimatl(out);
max@0 478 }
max@0 479 }
max@0 480
max@0 481 return status;
max@0 482 }
max@0 483
max@0 484
max@0 485
max@0 486 template<typename eT, typename T1>
max@0 487 inline
max@0 488 bool
max@0 489 auxlib::inv_sym(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
max@0 490 {
max@0 491 arma_extra_debug_sigprint();
max@0 492
max@0 493 out = X.get_ref();
max@0 494
max@0 495 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
max@0 496
max@0 497 if(out.is_empty())
max@0 498 {
max@0 499 return true;
max@0 500 }
max@0 501
max@0 502 bool status;
max@0 503
max@0 504 #if defined(ARMA_USE_LAPACK)
max@0 505 {
max@0 506 char uplo = (layout == 0) ? 'U' : 'L';
max@0 507 blas_int n = blas_int(out.n_rows);
max@0 508 blas_int lwork = n*n; // TODO: use lwork = -1 to determine optimal size
max@0 509 blas_int info = 0;
max@0 510
max@0 511 podarray<blas_int> ipiv;
max@0 512 ipiv.set_size(out.n_rows);
max@0 513
max@0 514 podarray<eT> work;
max@0 515 work.set_size( uword(lwork) );
max@0 516
max@0 517 lapack::sytrf(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &lwork, &info);
max@0 518
max@0 519 status = (info == 0);
max@0 520
max@0 521 if(status == true)
max@0 522 {
max@0 523 lapack::sytri(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &info);
max@0 524
max@0 525 out = (layout == 0) ? symmatu(out) : symmatl(out);
max@0 526
max@0 527 status = (info == 0);
max@0 528 }
max@0 529 }
max@0 530 #else
max@0 531 {
max@0 532 arma_ignore(layout);
max@0 533 arma_stop("inv(): use of LAPACK needs to be enabled");
max@0 534 status = false;
max@0 535 }
max@0 536 #endif
max@0 537
max@0 538 return status;
max@0 539 }
max@0 540
max@0 541
max@0 542
max@0 543 template<typename eT, typename T1>
max@0 544 inline
max@0 545 bool
max@0 546 auxlib::inv_sympd(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
max@0 547 {
max@0 548 arma_extra_debug_sigprint();
max@0 549
max@0 550 out = X.get_ref();
max@0 551
max@0 552 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
max@0 553
max@0 554 if(out.is_empty())
max@0 555 {
max@0 556 return true;
max@0 557 }
max@0 558
max@0 559 bool status;
max@0 560
max@0 561 #if defined(ARMA_USE_LAPACK)
max@0 562 {
max@0 563 char uplo = (layout == 0) ? 'U' : 'L';
max@0 564 blas_int n = blas_int(out.n_rows);
max@0 565 blas_int info = 0;
max@0 566
max@0 567 lapack::potrf(&uplo, &n, out.memptr(), &n, &info);
max@0 568
max@0 569 status = (info == 0);
max@0 570
max@0 571 if(status == true)
max@0 572 {
max@0 573 lapack::potri(&uplo, &n, out.memptr(), &n, &info);
max@0 574
max@0 575 out = (layout == 0) ? symmatu(out) : symmatl(out);
max@0 576
max@0 577 status = (info == 0);
max@0 578 }
max@0 579 }
max@0 580 #else
max@0 581 {
max@0 582 arma_ignore(layout);
max@0 583 arma_stop("inv(): use of LAPACK needs to be enabled");
max@0 584 status = false;
max@0 585 }
max@0 586 #endif
max@0 587
max@0 588 return status;
max@0 589 }
max@0 590
max@0 591
max@0 592
max@0 593 template<typename eT, typename T1>
max@0 594 inline
max@0 595 eT
max@0 596 auxlib::det(const Base<eT,T1>& X, const bool slow)
max@0 597 {
max@0 598 const unwrap<T1> tmp(X.get_ref());
max@0 599 const Mat<eT>& A = tmp.M;
max@0 600
max@0 601 arma_debug_check( (A.is_square() == false), "det(): matrix is not square" );
max@0 602
max@0 603 const bool make_copy = (is_Mat<T1>::value == true) ? true : false;
max@0 604
max@0 605 if(slow == false)
max@0 606 {
max@0 607 const uword N = A.n_rows;
max@0 608
max@0 609 switch(N)
max@0 610 {
max@0 611 case 0:
max@0 612 case 1:
max@0 613 case 2:
max@0 614 return auxlib::det_tinymat(A, N);
max@0 615 break;
max@0 616
max@0 617 case 3:
max@0 618 case 4:
max@0 619 {
max@0 620 const eT tmp_det = auxlib::det_tinymat(A, N);
max@0 621 return (tmp_det != eT(0)) ? tmp_det : auxlib::det_lapack(A, make_copy);
max@0 622 }
max@0 623 break;
max@0 624
max@0 625 default:
max@0 626 return auxlib::det_lapack(A, make_copy);
max@0 627 }
max@0 628 }
max@0 629 else
max@0 630 {
max@0 631 return auxlib::det_lapack(A, make_copy);
max@0 632 }
max@0 633 }
max@0 634
max@0 635
max@0 636
max@0 637 template<typename eT>
max@0 638 inline
max@0 639 eT
max@0 640 auxlib::det_tinymat(const Mat<eT>& X, const uword N)
max@0 641 {
max@0 642 arma_extra_debug_sigprint();
max@0 643
max@0 644 switch(N)
max@0 645 {
max@0 646 case 0:
max@0 647 return eT(1);
max@0 648 break;
max@0 649
max@0 650 case 1:
max@0 651 return X[0];
max@0 652 break;
max@0 653
max@0 654 case 2:
max@0 655 {
max@0 656 const eT* Xm = X.memptr();
max@0 657
max@0 658 return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] );
max@0 659 }
max@0 660 break;
max@0 661
max@0 662 case 3:
max@0 663 {
max@0 664 // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2);
max@0 665 // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0);
max@0 666 // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1);
max@0 667 // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2);
max@0 668 // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0);
max@0 669 // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1);
max@0 670 // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6);
max@0 671
max@0 672 const eT* a_col0 = X.colptr(0);
max@0 673 const eT a11 = a_col0[0];
max@0 674 const eT a21 = a_col0[1];
max@0 675 const eT a31 = a_col0[2];
max@0 676
max@0 677 const eT* a_col1 = X.colptr(1);
max@0 678 const eT a12 = a_col1[0];
max@0 679 const eT a22 = a_col1[1];
max@0 680 const eT a32 = a_col1[2];
max@0 681
max@0 682 const eT* a_col2 = X.colptr(2);
max@0 683 const eT a13 = a_col2[0];
max@0 684 const eT a23 = a_col2[1];
max@0 685 const eT a33 = a_col2[2];
max@0 686
max@0 687 return ( a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13) );
max@0 688 }
max@0 689 break;
max@0 690
max@0 691 case 4:
max@0 692 {
max@0 693 const eT* Xm = X.memptr();
max@0 694
max@0 695 const eT val = \
max@0 696 Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \
max@0 697 - Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \
max@0 698 - Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \
max@0 699 + Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \
max@0 700 + Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \
max@0 701 - Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \
max@0 702 - Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \
max@0 703 + Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \
max@0 704 + Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \
max@0 705 - Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \
max@0 706 - Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \
max@0 707 + Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \
max@0 708 + Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \
max@0 709 - Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \
max@0 710 - Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \
max@0 711 + Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \
max@0 712 + Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \
max@0 713 - Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \
max@0 714 - Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \
max@0 715 + Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \
max@0 716 + Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \
max@0 717 - Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \
max@0 718 - Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \
max@0 719 + Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \
max@0 720 ;
max@0 721
max@0 722 return val;
max@0 723 }
max@0 724 break;
max@0 725
max@0 726 default:
max@0 727 return eT(0);
max@0 728 ;
max@0 729 }
max@0 730 }
max@0 731
max@0 732
max@0 733
max@0 734 //! immediate determinant of a matrix using ATLAS or LAPACK
max@0 735 template<typename eT>
max@0 736 inline
max@0 737 eT
max@0 738 auxlib::det_lapack(const Mat<eT>& X, const bool make_copy)
max@0 739 {
max@0 740 arma_extra_debug_sigprint();
max@0 741
max@0 742 Mat<eT> X_copy;
max@0 743
max@0 744 if(make_copy == true)
max@0 745 {
max@0 746 X_copy = X;
max@0 747 }
max@0 748
max@0 749 Mat<eT>& tmp = (make_copy == true) ? X_copy : const_cast< Mat<eT>& >(X);
max@0 750
max@0 751 if(tmp.is_empty())
max@0 752 {
max@0 753 return eT(1);
max@0 754 }
max@0 755
max@0 756
max@0 757 #if defined(ARMA_USE_ATLAS)
max@0 758 {
max@0 759 podarray<int> ipiv(tmp.n_rows);
max@0 760
max@0 761 //const int info =
max@0 762 atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr());
max@0 763
max@0 764 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
max@0 765 eT val = tmp.at(0,0);
max@0 766 for(uword i=1; i < tmp.n_rows; ++i)
max@0 767 {
max@0 768 val *= tmp.at(i,i);
max@0 769 }
max@0 770
max@0 771 int sign = +1;
max@0 772 for(uword i=0; i < tmp.n_rows; ++i)
max@0 773 {
max@0 774 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0
max@0 775 {
max@0 776 sign *= -1;
max@0 777 }
max@0 778 }
max@0 779
max@0 780 return ( (sign < 0) ? -val : val );
max@0 781 }
max@0 782 #elif defined(ARMA_USE_LAPACK)
max@0 783 {
max@0 784 podarray<blas_int> ipiv(tmp.n_rows);
max@0 785
max@0 786 blas_int info = 0;
max@0 787 blas_int n_rows = blas_int(tmp.n_rows);
max@0 788 blas_int n_cols = blas_int(tmp.n_cols);
max@0 789
max@0 790 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info);
max@0 791
max@0 792 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
max@0 793 eT val = tmp.at(0,0);
max@0 794 for(uword i=1; i < tmp.n_rows; ++i)
max@0 795 {
max@0 796 val *= tmp.at(i,i);
max@0 797 }
max@0 798
max@0 799 blas_int sign = +1;
max@0 800 for(uword i=0; i < tmp.n_rows; ++i)
max@0 801 {
max@0 802 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1
max@0 803 {
max@0 804 sign *= -1;
max@0 805 }
max@0 806 }
max@0 807
max@0 808 return ( (sign < 0) ? -val : val );
max@0 809 }
max@0 810 #else
max@0 811 {
max@0 812 arma_ignore(X);
max@0 813 arma_ignore(make_copy);
max@0 814 arma_ignore(tmp);
max@0 815 arma_stop("det(): use of ATLAS or LAPACK needs to be enabled");
max@0 816 return eT(0);
max@0 817 }
max@0 818 #endif
max@0 819 }
max@0 820
max@0 821
max@0 822
max@0 823 //! immediate log determinant of a matrix using ATLAS or LAPACK
max@0 824 template<typename eT, typename T1>
max@0 825 inline
max@0 826 bool
max@0 827 auxlib::log_det(eT& out_val, typename get_pod_type<eT>::result& out_sign, const Base<eT,T1>& X)
max@0 828 {
max@0 829 arma_extra_debug_sigprint();
max@0 830
max@0 831 typedef typename get_pod_type<eT>::result T;
max@0 832
max@0 833 #if defined(ARMA_USE_ATLAS)
max@0 834 {
max@0 835 Mat<eT> tmp(X.get_ref());
max@0 836 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" );
max@0 837
max@0 838 if(tmp.is_empty())
max@0 839 {
max@0 840 out_val = eT(0);
max@0 841 out_sign = T(1);
max@0 842 return true;
max@0 843 }
max@0 844
max@0 845 podarray<int> ipiv(tmp.n_rows);
max@0 846
max@0 847 const int info = atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr());
max@0 848
max@0 849 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
max@0 850
max@0 851 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1;
max@0 852 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) );
max@0 853
max@0 854 for(uword i=1; i < tmp.n_rows; ++i)
max@0 855 {
max@0 856 const eT x = tmp.at(i,i);
max@0 857
max@0 858 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1;
max@0 859 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
max@0 860 }
max@0 861
max@0 862 for(uword i=0; i < tmp.n_rows; ++i)
max@0 863 {
max@0 864 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0
max@0 865 {
max@0 866 sign *= -1;
max@0 867 }
max@0 868 }
max@0 869
max@0 870 out_val = val;
max@0 871 out_sign = T(sign);
max@0 872
max@0 873 return (info == 0);
max@0 874 }
max@0 875 #elif defined(ARMA_USE_LAPACK)
max@0 876 {
max@0 877 Mat<eT> tmp(X.get_ref());
max@0 878 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" );
max@0 879
max@0 880 if(tmp.is_empty())
max@0 881 {
max@0 882 out_val = eT(0);
max@0 883 out_sign = T(1);
max@0 884 return true;
max@0 885 }
max@0 886
max@0 887 podarray<blas_int> ipiv(tmp.n_rows);
max@0 888
max@0 889 blas_int info = 0;
max@0 890 blas_int n_rows = blas_int(tmp.n_rows);
max@0 891 blas_int n_cols = blas_int(tmp.n_cols);
max@0 892
max@0 893 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info);
max@0 894
max@0 895 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
max@0 896
max@0 897 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1;
max@0 898 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) );
max@0 899
max@0 900 for(uword i=1; i < tmp.n_rows; ++i)
max@0 901 {
max@0 902 const eT x = tmp.at(i,i);
max@0 903
max@0 904 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1;
max@0 905 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
max@0 906 }
max@0 907
max@0 908 for(uword i=0; i < tmp.n_rows; ++i)
max@0 909 {
max@0 910 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1
max@0 911 {
max@0 912 sign *= -1;
max@0 913 }
max@0 914 }
max@0 915
max@0 916 out_val = val;
max@0 917 out_sign = T(sign);
max@0 918
max@0 919 return (info == 0);
max@0 920 }
max@0 921 #else
max@0 922 {
max@0 923 out_val = eT(0);
max@0 924 out_sign = T(0);
max@0 925
max@0 926 arma_stop("log_det(): use of ATLAS or LAPACK needs to be enabled");
max@0 927
max@0 928 return false;
max@0 929 }
max@0 930 #endif
max@0 931 }
max@0 932
max@0 933
max@0 934
max@0 935 //! immediate LU decomposition of a matrix using ATLAS or LAPACK
max@0 936 template<typename eT, typename T1>
max@0 937 inline
max@0 938 bool
max@0 939 auxlib::lu(Mat<eT>& L, Mat<eT>& U, podarray<blas_int>& ipiv, const Base<eT,T1>& X)
max@0 940 {
max@0 941 arma_extra_debug_sigprint();
max@0 942
max@0 943 U = X.get_ref();
max@0 944
max@0 945 const uword U_n_rows = U.n_rows;
max@0 946 const uword U_n_cols = U.n_cols;
max@0 947
max@0 948 if(U.is_empty())
max@0 949 {
max@0 950 L.set_size(U_n_rows, 0);
max@0 951 U.set_size(0, U_n_cols);
max@0 952 ipiv.reset();
max@0 953 return true;
max@0 954 }
max@0 955
max@0 956 #if defined(ARMA_USE_ATLAS) || defined(ARMA_USE_LAPACK)
max@0 957 {
max@0 958 bool status;
max@0 959
max@0 960 #if defined(ARMA_USE_ATLAS)
max@0 961 {
max@0 962 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) );
max@0 963
max@0 964 int info = atlas::clapack_getrf(atlas::CblasColMajor, U_n_rows, U_n_cols, U.memptr(), U_n_rows, ipiv.memptr());
max@0 965
max@0 966 status = (info == 0);
max@0 967 }
max@0 968 #elif defined(ARMA_USE_LAPACK)
max@0 969 {
max@0 970 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) );
max@0 971
max@0 972 blas_int info = 0;
max@0 973
max@0 974 blas_int n_rows = U_n_rows;
max@0 975 blas_int n_cols = U_n_cols;
max@0 976
max@0 977
max@0 978 lapack::getrf(&n_rows, &n_cols, U.memptr(), &n_rows, ipiv.memptr(), &info);
max@0 979
max@0 980 // take into account that Fortran counts from 1
max@0 981 arrayops::inplace_minus(ipiv.memptr(), blas_int(1), ipiv.n_elem);
max@0 982
max@0 983 status = (info == 0);
max@0 984 }
max@0 985 #endif
max@0 986
max@0 987 L.copy_size(U);
max@0 988
max@0 989 for(uword col=0; col < U_n_cols; ++col)
max@0 990 {
max@0 991 for(uword row=0; (row < col) && (row < U_n_rows); ++row)
max@0 992 {
max@0 993 L.at(row,col) = eT(0);
max@0 994 }
max@0 995
max@0 996 if( L.in_range(col,col) == true )
max@0 997 {
max@0 998 L.at(col,col) = eT(1);
max@0 999 }
max@0 1000
max@0 1001 for(uword row = (col+1); row < U_n_rows; ++row)
max@0 1002 {
max@0 1003 L.at(row,col) = U.at(row,col);
max@0 1004 U.at(row,col) = eT(0);
max@0 1005 }
max@0 1006 }
max@0 1007
max@0 1008 return status;
max@0 1009 }
max@0 1010 #else
max@0 1011 {
max@0 1012 arma_stop("lu(): use of ATLAS or LAPACK needs to be enabled");
max@0 1013
max@0 1014 return false;
max@0 1015 }
max@0 1016 #endif
max@0 1017 }
max@0 1018
max@0 1019
max@0 1020
max@0 1021 template<typename eT, typename T1>
max@0 1022 inline
max@0 1023 bool
max@0 1024 auxlib::lu(Mat<eT>& L, Mat<eT>& U, Mat<eT>& P, const Base<eT,T1>& X)
max@0 1025 {
max@0 1026 arma_extra_debug_sigprint();
max@0 1027
max@0 1028 podarray<blas_int> ipiv1;
max@0 1029 const bool status = auxlib::lu(L, U, ipiv1, X);
max@0 1030
max@0 1031 if(status == true)
max@0 1032 {
max@0 1033 if(U.is_empty())
max@0 1034 {
max@0 1035 // L and U have been already set to the correct empty matrices
max@0 1036 P.eye(L.n_rows, L.n_rows);
max@0 1037 return true;
max@0 1038 }
max@0 1039
max@0 1040 const uword n = ipiv1.n_elem;
max@0 1041 const uword P_rows = U.n_rows;
max@0 1042
max@0 1043 podarray<blas_int> ipiv2(P_rows);
max@0 1044
max@0 1045 const blas_int* ipiv1_mem = ipiv1.memptr();
max@0 1046 blas_int* ipiv2_mem = ipiv2.memptr();
max@0 1047
max@0 1048 for(uword i=0; i<P_rows; ++i)
max@0 1049 {
max@0 1050 ipiv2_mem[i] = blas_int(i);
max@0 1051 }
max@0 1052
max@0 1053 for(uword i=0; i<n; ++i)
max@0 1054 {
max@0 1055 const uword k = static_cast<uword>(ipiv1_mem[i]);
max@0 1056
max@0 1057 if( ipiv2_mem[i] != ipiv2_mem[k] )
max@0 1058 {
max@0 1059 std::swap( ipiv2_mem[i], ipiv2_mem[k] );
max@0 1060 }
max@0 1061 }
max@0 1062
max@0 1063 P.zeros(P_rows, P_rows);
max@0 1064
max@0 1065 for(uword row=0; row<P_rows; ++row)
max@0 1066 {
max@0 1067 P.at(row, static_cast<uword>(ipiv2_mem[row])) = eT(1);
max@0 1068 }
max@0 1069
max@0 1070 if(L.n_cols > U.n_rows)
max@0 1071 {
max@0 1072 L.shed_cols(U.n_rows, L.n_cols-1);
max@0 1073 }
max@0 1074
max@0 1075 if(U.n_rows > L.n_cols)
max@0 1076 {
max@0 1077 U.shed_rows(L.n_cols, U.n_rows-1);
max@0 1078 }
max@0 1079 }
max@0 1080
max@0 1081 return status;
max@0 1082 }
max@0 1083
max@0 1084
max@0 1085
max@0 1086 template<typename eT, typename T1>
max@0 1087 inline
max@0 1088 bool
max@0 1089 auxlib::lu(Mat<eT>& L, Mat<eT>& U, const Base<eT,T1>& X)
max@0 1090 {
max@0 1091 arma_extra_debug_sigprint();
max@0 1092
max@0 1093 podarray<blas_int> ipiv1;
max@0 1094 const bool status = auxlib::lu(L, U, ipiv1, X);
max@0 1095
max@0 1096 if(status == true)
max@0 1097 {
max@0 1098 if(U.is_empty())
max@0 1099 {
max@0 1100 // L and U have been already set to the correct empty matrices
max@0 1101 return true;
max@0 1102 }
max@0 1103
max@0 1104 const uword n = ipiv1.n_elem;
max@0 1105 const uword P_rows = U.n_rows;
max@0 1106
max@0 1107 podarray<blas_int> ipiv2(P_rows);
max@0 1108
max@0 1109 const blas_int* ipiv1_mem = ipiv1.memptr();
max@0 1110 blas_int* ipiv2_mem = ipiv2.memptr();
max@0 1111
max@0 1112 for(uword i=0; i<P_rows; ++i)
max@0 1113 {
max@0 1114 ipiv2_mem[i] = blas_int(i);
max@0 1115 }
max@0 1116
max@0 1117 for(uword i=0; i<n; ++i)
max@0 1118 {
max@0 1119 const uword k = static_cast<uword>(ipiv1_mem[i]);
max@0 1120
max@0 1121 if( ipiv2_mem[i] != ipiv2_mem[k] )
max@0 1122 {
max@0 1123 std::swap( ipiv2_mem[i], ipiv2_mem[k] );
max@0 1124 L.swap_rows( static_cast<uword>(ipiv2_mem[i]), static_cast<uword>(ipiv2_mem[k]) );
max@0 1125 }
max@0 1126 }
max@0 1127
max@0 1128 if(L.n_cols > U.n_rows)
max@0 1129 {
max@0 1130 L.shed_cols(U.n_rows, L.n_cols-1);
max@0 1131 }
max@0 1132
max@0 1133 if(U.n_rows > L.n_cols)
max@0 1134 {
max@0 1135 U.shed_rows(L.n_cols, U.n_rows-1);
max@0 1136 }
max@0 1137 }
max@0 1138
max@0 1139 return status;
max@0 1140 }
max@0 1141
max@0 1142
max@0 1143
max@0 1144 //! immediate eigenvalues of a symmetric real matrix using LAPACK
max@0 1145 template<typename eT, typename T1>
max@0 1146 inline
max@0 1147 bool
max@0 1148 auxlib::eig_sym(Col<eT>& eigval, const Base<eT,T1>& X)
max@0 1149 {
max@0 1150 arma_extra_debug_sigprint();
max@0 1151
max@0 1152 #if defined(ARMA_USE_LAPACK)
max@0 1153 {
max@0 1154 Mat<eT> A(X.get_ref());
max@0 1155
max@0 1156 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not square");
max@0 1157
max@0 1158 if(A.is_empty())
max@0 1159 {
max@0 1160 eigval.reset();
max@0 1161 return true;
max@0 1162 }
max@0 1163
max@0 1164 // rudimentary "better-than-nothing" test for symmetry
max@0 1165 //arma_debug_check( (A.at(A.n_rows-1, 0) != A.at(0, A.n_cols-1)), "auxlib::eig(): given matrix is not symmetric" );
max@0 1166
max@0 1167 char jobz = 'N';
max@0 1168 char uplo = 'U';
max@0 1169
max@0 1170 blas_int n_rows = A.n_rows;
max@0 1171 blas_int lwork = (std::max)(blas_int(1), 3*n_rows-1);
max@0 1172
max@0 1173 eigval.set_size( static_cast<uword>(n_rows) );
max@0 1174 podarray<eT> work( static_cast<uword>(lwork) );
max@0 1175
max@0 1176 blas_int info;
max@0 1177
max@0 1178 arma_extra_debug_print("lapack::syev()");
max@0 1179 lapack::syev(&jobz, &uplo, &n_rows, A.memptr(), &n_rows, eigval.memptr(), work.memptr(), &lwork, &info);
max@0 1180
max@0 1181 return (info == 0);
max@0 1182 }
max@0 1183 #else
max@0 1184 {
max@0 1185 arma_ignore(eigval);
max@0 1186 arma_ignore(X);
max@0 1187 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
max@0 1188 return false;
max@0 1189 }
max@0 1190 #endif
max@0 1191 }
max@0 1192
max@0 1193
max@0 1194
max@0 1195 //! immediate eigenvalues of a hermitian complex matrix using LAPACK
max@0 1196 template<typename T, typename T1>
max@0 1197 inline
max@0 1198 bool
max@0 1199 auxlib::eig_sym(Col<T>& eigval, const Base<std::complex<T>,T1>& X)
max@0 1200 {
max@0 1201 arma_extra_debug_sigprint();
max@0 1202
max@0 1203 typedef typename std::complex<T> eT;
max@0 1204
max@0 1205 #if defined(ARMA_USE_LAPACK)
max@0 1206 {
max@0 1207 Mat<eT> A(X.get_ref());
max@0 1208 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not hermitian");
max@0 1209
max@0 1210 if(A.is_empty())
max@0 1211 {
max@0 1212 eigval.reset();
max@0 1213 return true;
max@0 1214 }
max@0 1215
max@0 1216 char jobz = 'N';
max@0 1217 char uplo = 'U';
max@0 1218
max@0 1219 blas_int n_rows = A.n_rows;
max@0 1220 blas_int lda = A.n_rows;
max@0 1221 blas_int lwork = (std::max)(blas_int(1), 2*n_rows - 1); // TODO: automatically find best size of lwork
max@0 1222
max@0 1223 eigval.set_size( static_cast<uword>(n_rows) );
max@0 1224
max@0 1225 podarray<eT> work( static_cast<uword>(lwork) );
max@0 1226 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows - 2)) );
max@0 1227
max@0 1228 blas_int info;
max@0 1229
max@0 1230 arma_extra_debug_print("lapack::heev()");
max@0 1231 lapack::heev(&jobz, &uplo, &n_rows, A.memptr(), &lda, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info);
max@0 1232
max@0 1233 return (info == 0);
max@0 1234 }
max@0 1235 #else
max@0 1236 {
max@0 1237 arma_ignore(eigval);
max@0 1238 arma_ignore(X);
max@0 1239 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
max@0 1240 return false;
max@0 1241 }
max@0 1242 #endif
max@0 1243 }
max@0 1244
max@0 1245
max@0 1246
max@0 1247 //! immediate eigenvalues and eigenvectors of a symmetric real matrix using LAPACK
max@0 1248 template<typename eT, typename T1>
max@0 1249 inline
max@0 1250 bool
max@0 1251 auxlib::eig_sym(Col<eT>& eigval, Mat<eT>& eigvec, const Base<eT,T1>& X)
max@0 1252 {
max@0 1253 arma_extra_debug_sigprint();
max@0 1254
max@0 1255 #if defined(ARMA_USE_LAPACK)
max@0 1256 {
max@0 1257 eigvec = X.get_ref();
max@0 1258
max@0 1259 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" );
max@0 1260
max@0 1261 if(eigvec.is_empty())
max@0 1262 {
max@0 1263 eigval.reset();
max@0 1264 eigvec.reset();
max@0 1265 return true;
max@0 1266 }
max@0 1267
max@0 1268 // rudimentary "better-than-nothing" test for symmetry
max@0 1269 //arma_debug_check( (A.at(A.n_rows-1, 0) != A.at(0, A.n_cols-1)), "auxlib::eig(): given matrix is not symmetric" );
max@0 1270
max@0 1271 char jobz = 'V';
max@0 1272 char uplo = 'U';
max@0 1273
max@0 1274 blas_int n_rows = eigvec.n_rows;
max@0 1275 blas_int lwork = (std::max)(blas_int(1), 3*n_rows-1);
max@0 1276
max@0 1277 eigval.set_size( static_cast<uword>(n_rows) );
max@0 1278 podarray<eT> work( static_cast<uword>(lwork) );
max@0 1279
max@0 1280 blas_int info;
max@0 1281
max@0 1282 arma_extra_debug_print("lapack::syev()");
max@0 1283 lapack::syev(&jobz, &uplo, &n_rows, eigvec.memptr(), &n_rows, eigval.memptr(), work.memptr(), &lwork, &info);
max@0 1284
max@0 1285 return (info == 0);
max@0 1286 }
max@0 1287 #else
max@0 1288 {
max@0 1289 arma_ignore(eigval);
max@0 1290 arma_ignore(eigvec);
max@0 1291 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
max@0 1292
max@0 1293 return false;
max@0 1294 }
max@0 1295 #endif
max@0 1296 }
max@0 1297
max@0 1298
max@0 1299
max@0 1300 //! immediate eigenvalues and eigenvectors of a hermitian complex matrix using LAPACK
max@0 1301 template<typename T, typename T1>
max@0 1302 inline
max@0 1303 bool
max@0 1304 auxlib::eig_sym(Col<T>& eigval, Mat< std::complex<T> >& eigvec, const Base<std::complex<T>,T1>& X)
max@0 1305 {
max@0 1306 arma_extra_debug_sigprint();
max@0 1307
max@0 1308 typedef typename std::complex<T> eT;
max@0 1309
max@0 1310 #if defined(ARMA_USE_LAPACK)
max@0 1311 {
max@0 1312 eigvec = X.get_ref();
max@0 1313
max@0 1314 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not hermitian" );
max@0 1315
max@0 1316 if(eigvec.is_empty())
max@0 1317 {
max@0 1318 eigval.reset();
max@0 1319 eigvec.reset();
max@0 1320 return true;
max@0 1321 }
max@0 1322
max@0 1323 char jobz = 'V';
max@0 1324 char uplo = 'U';
max@0 1325
max@0 1326 blas_int n_rows = eigvec.n_rows;
max@0 1327 blas_int lda = eigvec.n_rows;
max@0 1328 blas_int lwork = (std::max)(blas_int(1), 2*n_rows - 1); // TODO: automatically find best size of lwork
max@0 1329
max@0 1330 eigval.set_size( static_cast<uword>(n_rows) );
max@0 1331
max@0 1332 podarray<eT> work( static_cast<uword>(lwork) );
max@0 1333 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows - 2)) );
max@0 1334
max@0 1335 blas_int info;
max@0 1336
max@0 1337 arma_extra_debug_print("lapack::heev()");
max@0 1338 lapack::heev(&jobz, &uplo, &n_rows, eigvec.memptr(), &lda, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info);
max@0 1339
max@0 1340 return (info == 0);
max@0 1341 }
max@0 1342 #else
max@0 1343 {
max@0 1344 arma_ignore(eigval);
max@0 1345 arma_ignore(eigvec);
max@0 1346 arma_ignore(X);
max@0 1347 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
max@0 1348 return false;
max@0 1349 }
max@0 1350 #endif
max@0 1351 }
max@0 1352
max@0 1353
max@0 1354
max@0 1355 //! Eigenvalues and eigenvectors of a general square real matrix using LAPACK.
max@0 1356 //! The argument 'side' specifies which eigenvectors should be calculated
max@0 1357 //! (see code for mode details).
max@0 1358 template<typename T, typename T1>
max@0 1359 inline
max@0 1360 bool
max@0 1361 auxlib::eig_gen
max@0 1362 (
max@0 1363 Col< std::complex<T> >& eigval,
max@0 1364 Mat<T>& l_eigvec,
max@0 1365 Mat<T>& r_eigvec,
max@0 1366 const Base<T,T1>& X,
max@0 1367 const char side
max@0 1368 )
max@0 1369 {
max@0 1370 arma_extra_debug_sigprint();
max@0 1371
max@0 1372 #if defined(ARMA_USE_LAPACK)
max@0 1373 {
max@0 1374 char jobvl;
max@0 1375 char jobvr;
max@0 1376
max@0 1377 switch(side)
max@0 1378 {
max@0 1379 case 'l': // left
max@0 1380 jobvl = 'V';
max@0 1381 jobvr = 'N';
max@0 1382 break;
max@0 1383
max@0 1384 case 'r': // right
max@0 1385 jobvl = 'N';
max@0 1386 jobvr = 'V';
max@0 1387 break;
max@0 1388
max@0 1389 case 'b': // both
max@0 1390 jobvl = 'V';
max@0 1391 jobvr = 'V';
max@0 1392 break;
max@0 1393
max@0 1394 case 'n': // neither
max@0 1395 jobvl = 'N';
max@0 1396 jobvr = 'N';
max@0 1397 break;
max@0 1398
max@0 1399 default:
max@0 1400 arma_stop("eig_gen(): parameter 'side' is invalid");
max@0 1401 return false;
max@0 1402 }
max@0 1403
max@0 1404 Mat<T> A(X.get_ref());
max@0 1405 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" );
max@0 1406
max@0 1407 if(A.is_empty())
max@0 1408 {
max@0 1409 eigval.reset();
max@0 1410 l_eigvec.reset();
max@0 1411 r_eigvec.reset();
max@0 1412 return true;
max@0 1413 }
max@0 1414
max@0 1415 uword A_n_rows = A.n_rows;
max@0 1416
max@0 1417 blas_int n_rows = A_n_rows;
max@0 1418 blas_int lda = A_n_rows;
max@0 1419 blas_int lwork = (std::max)(blas_int(1), 4*n_rows); // TODO: automatically find best size of lwork
max@0 1420
max@0 1421 eigval.set_size(A_n_rows);
max@0 1422 l_eigvec.set_size(A_n_rows, A_n_rows);
max@0 1423 r_eigvec.set_size(A_n_rows, A_n_rows);
max@0 1424
max@0 1425 podarray<T> work( static_cast<uword>(lwork) );
max@0 1426 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows)) );
max@0 1427
max@0 1428 podarray<T> wr(A_n_rows);
max@0 1429 podarray<T> wi(A_n_rows);
max@0 1430
max@0 1431 Mat<T> A_copy = A;
max@0 1432 blas_int info;
max@0 1433
max@0 1434 arma_extra_debug_print("lapack::geev()");
max@0 1435 lapack::geev(&jobvl, &jobvr, &n_rows, A_copy.memptr(), &lda, wr.memptr(), wi.memptr(), l_eigvec.memptr(), &n_rows, r_eigvec.memptr(), &n_rows, work.memptr(), &lwork, &info);
max@0 1436
max@0 1437
max@0 1438 eigval.set_size(A_n_rows);
max@0 1439 for(uword i=0; i<A_n_rows; ++i)
max@0 1440 {
max@0 1441 eigval[i] = std::complex<T>(wr[i], wi[i]);
max@0 1442 }
max@0 1443
max@0 1444 return (info == 0);
max@0 1445 }
max@0 1446 #else
max@0 1447 {
max@0 1448 arma_ignore(eigval);
max@0 1449 arma_ignore(l_eigvec);
max@0 1450 arma_ignore(r_eigvec);
max@0 1451 arma_ignore(X);
max@0 1452 arma_ignore(side);
max@0 1453 arma_stop("eig_gen(): use of LAPACK needs to be enabled");
max@0 1454 return false;
max@0 1455 }
max@0 1456 #endif
max@0 1457 }
max@0 1458
max@0 1459
max@0 1460
max@0 1461
max@0 1462
max@0 1463 //! Eigenvalues and eigenvectors of a general square complex matrix using LAPACK
max@0 1464 //! The argument 'side' specifies which eigenvectors should be calculated
max@0 1465 //! (see code for mode details).
max@0 1466 template<typename T, typename T1>
max@0 1467 inline
max@0 1468 bool
max@0 1469 auxlib::eig_gen
max@0 1470 (
max@0 1471 Col< std::complex<T> >& eigval,
max@0 1472 Mat< std::complex<T> >& l_eigvec,
max@0 1473 Mat< std::complex<T> >& r_eigvec,
max@0 1474 const Base< std::complex<T>, T1 >& X,
max@0 1475 const char side
max@0 1476 )
max@0 1477 {
max@0 1478 arma_extra_debug_sigprint();
max@0 1479
max@0 1480 typedef typename std::complex<T> eT;
max@0 1481
max@0 1482 #if defined(ARMA_USE_LAPACK)
max@0 1483 {
max@0 1484 char jobvl;
max@0 1485 char jobvr;
max@0 1486
max@0 1487 switch(side)
max@0 1488 {
max@0 1489 case 'l': // left
max@0 1490 jobvl = 'V';
max@0 1491 jobvr = 'N';
max@0 1492 break;
max@0 1493
max@0 1494 case 'r': // right
max@0 1495 jobvl = 'N';
max@0 1496 jobvr = 'V';
max@0 1497 break;
max@0 1498
max@0 1499 case 'b': // both
max@0 1500 jobvl = 'V';
max@0 1501 jobvr = 'V';
max@0 1502 break;
max@0 1503
max@0 1504 case 'n': // neither
max@0 1505 jobvl = 'N';
max@0 1506 jobvr = 'N';
max@0 1507 break;
max@0 1508
max@0 1509 default:
max@0 1510 arma_stop("eig_gen(): parameter 'side' is invalid");
max@0 1511 return false;
max@0 1512 }
max@0 1513
max@0 1514 Mat<eT> A(X.get_ref());
max@0 1515 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" );
max@0 1516
max@0 1517 if(A.is_empty())
max@0 1518 {
max@0 1519 eigval.reset();
max@0 1520 l_eigvec.reset();
max@0 1521 r_eigvec.reset();
max@0 1522 return true;
max@0 1523 }
max@0 1524
max@0 1525 uword A_n_rows = A.n_rows;
max@0 1526
max@0 1527 blas_int n_rows = A_n_rows;
max@0 1528 blas_int lda = A_n_rows;
max@0 1529 blas_int lwork = (std::max)(blas_int(1), 4*n_rows); // TODO: automatically find best size of lwork
max@0 1530
max@0 1531 eigval.set_size(A_n_rows);
max@0 1532 l_eigvec.set_size(A_n_rows, A_n_rows);
max@0 1533 r_eigvec.set_size(A_n_rows, A_n_rows);
max@0 1534
max@0 1535 podarray<eT> work( static_cast<uword>(lwork) );
max@0 1536 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows)) ); // was 2,3
max@0 1537
max@0 1538 blas_int info;
max@0 1539
max@0 1540 arma_extra_debug_print("lapack::cx_geev()");
max@0 1541 lapack::cx_geev(&jobvl, &jobvr, &n_rows, A.memptr(), &lda, eigval.memptr(), l_eigvec.memptr(), &n_rows, r_eigvec.memptr(), &n_rows, work.memptr(), &lwork, rwork.memptr(), &info);
max@0 1542
max@0 1543 return (info == 0);
max@0 1544 }
max@0 1545 #else
max@0 1546 {
max@0 1547 arma_ignore(eigval);
max@0 1548 arma_ignore(l_eigvec);
max@0 1549 arma_ignore(r_eigvec);
max@0 1550 arma_ignore(X);
max@0 1551 arma_ignore(side);
max@0 1552 arma_stop("eig_gen(): use of LAPACK needs to be enabled");
max@0 1553 return false;
max@0 1554 }
max@0 1555 #endif
max@0 1556 }
max@0 1557
max@0 1558
max@0 1559
max@0 1560 template<typename eT, typename T1>
max@0 1561 inline
max@0 1562 bool
max@0 1563 auxlib::chol(Mat<eT>& out, const Base<eT,T1>& X)
max@0 1564 {
max@0 1565 arma_extra_debug_sigprint();
max@0 1566
max@0 1567 #if defined(ARMA_USE_LAPACK)
max@0 1568 {
max@0 1569 out = X.get_ref();
max@0 1570
max@0 1571 arma_debug_check( (out.is_square() == false), "chol(): given matrix is not square" );
max@0 1572
max@0 1573 if(out.is_empty())
max@0 1574 {
max@0 1575 return true;
max@0 1576 }
max@0 1577
max@0 1578 const uword out_n_rows = out.n_rows;
max@0 1579
max@0 1580 char uplo = 'U';
max@0 1581 blas_int n = out_n_rows;
max@0 1582 blas_int info;
max@0 1583
max@0 1584 lapack::potrf(&uplo, &n, out.memptr(), &n, &info);
max@0 1585
max@0 1586 for(uword col=0; col<out_n_rows; ++col)
max@0 1587 {
max@0 1588 eT* colptr = out.colptr(col);
max@0 1589
max@0 1590 for(uword row=(col+1); row < out_n_rows; ++row)
max@0 1591 {
max@0 1592 colptr[row] = eT(0);
max@0 1593 }
max@0 1594 }
max@0 1595
max@0 1596 return (info == 0);
max@0 1597 }
max@0 1598 #else
max@0 1599 {
max@0 1600 arma_ignore(out);
max@0 1601 arma_stop("chol(): use of LAPACK needs to be enabled");
max@0 1602 return false;
max@0 1603 }
max@0 1604 #endif
max@0 1605 }
max@0 1606
max@0 1607
max@0 1608
max@0 1609 template<typename eT, typename T1>
max@0 1610 inline
max@0 1611 bool
max@0 1612 auxlib::qr(Mat<eT>& Q, Mat<eT>& R, const Base<eT,T1>& X)
max@0 1613 {
max@0 1614 arma_extra_debug_sigprint();
max@0 1615
max@0 1616 #if defined(ARMA_USE_LAPACK)
max@0 1617 {
max@0 1618 R = X.get_ref();
max@0 1619
max@0 1620 const uword R_n_rows = R.n_rows;
max@0 1621 const uword R_n_cols = R.n_cols;
max@0 1622
max@0 1623 if(R.is_empty())
max@0 1624 {
max@0 1625 Q.eye(R_n_rows, R_n_rows);
max@0 1626 return true;
max@0 1627 }
max@0 1628
max@0 1629 blas_int m = static_cast<blas_int>(R_n_rows);
max@0 1630 blas_int n = static_cast<blas_int>(R_n_cols);
max@0 1631 blas_int work_len = (std::max)(blas_int(1),n);
max@0 1632 blas_int work_len_tmp;
max@0 1633 blas_int k = (std::min)(m,n);
max@0 1634 blas_int info;
max@0 1635
max@0 1636 podarray<eT> tau( static_cast<uword>(k) );
max@0 1637 podarray<eT> work( static_cast<uword>(work_len) );
max@0 1638
max@0 1639 // query for the optimum value of work_len
max@0 1640 work_len_tmp = -1;
max@0 1641 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info);
max@0 1642
max@0 1643 if(info == 0)
max@0 1644 {
max@0 1645 work_len = static_cast<blas_int>(access::tmp_real(work[0]));
max@0 1646 work.set_size( static_cast<uword>(work_len) );
max@0 1647 }
max@0 1648
max@0 1649 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info);
max@0 1650
max@0 1651 Q.set_size(R_n_rows, R_n_rows);
max@0 1652
max@0 1653 arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) );
max@0 1654
max@0 1655 //
max@0 1656 // construct R
max@0 1657
max@0 1658 for(uword col=0; col < R_n_cols; ++col)
max@0 1659 {
max@0 1660 for(uword row=(col+1); row < R_n_rows; ++row)
max@0 1661 {
max@0 1662 R.at(row,col) = eT(0);
max@0 1663 }
max@0 1664 }
max@0 1665
max@0 1666
max@0 1667 if( (is_float<eT>::value == true) || (is_double<eT>::value == true) )
max@0 1668 {
max@0 1669 // query for the optimum value of work_len
max@0 1670 work_len_tmp = -1;
max@0 1671 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info);
max@0 1672
max@0 1673 if(info == 0)
max@0 1674 {
max@0 1675 work_len = static_cast<blas_int>(access::tmp_real(work[0]));
max@0 1676 work.set_size( static_cast<uword>(work_len) );
max@0 1677 }
max@0 1678
max@0 1679 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info);
max@0 1680 }
max@0 1681 else
max@0 1682 if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) )
max@0 1683 {
max@0 1684 // query for the optimum value of work_len
max@0 1685 work_len_tmp = -1;
max@0 1686 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info);
max@0 1687
max@0 1688 if(info == 0)
max@0 1689 {
max@0 1690 work_len = static_cast<blas_int>(access::tmp_real(work[0]));
max@0 1691 work.set_size( static_cast<uword>(work_len) );
max@0 1692 }
max@0 1693
max@0 1694 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info);
max@0 1695 }
max@0 1696
max@0 1697 return (info == 0);
max@0 1698 }
max@0 1699 #else
max@0 1700 {
max@0 1701 arma_ignore(Q);
max@0 1702 arma_ignore(R);
max@0 1703 arma_ignore(X);
max@0 1704 arma_stop("qr(): use of LAPACK needs to be enabled");
max@0 1705 return false;
max@0 1706 }
max@0 1707 #endif
max@0 1708 }
max@0 1709
max@0 1710
max@0 1711
max@0 1712 template<typename eT, typename T1>
max@0 1713 inline
max@0 1714 bool
max@0 1715 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X, uword& X_n_rows, uword& X_n_cols)
max@0 1716 {
max@0 1717 arma_extra_debug_sigprint();
max@0 1718
max@0 1719 #if defined(ARMA_USE_LAPACK)
max@0 1720 {
max@0 1721 Mat<eT> A(X.get_ref());
max@0 1722
max@0 1723 X_n_rows = A.n_rows;
max@0 1724 X_n_cols = A.n_cols;
max@0 1725
max@0 1726 if(A.is_empty())
max@0 1727 {
max@0 1728 S.reset();
max@0 1729 return true;
max@0 1730 }
max@0 1731
max@0 1732 Mat<eT> U(1, 1);
max@0 1733 Mat<eT> V(1, A.n_cols);
max@0 1734
max@0 1735 char jobu = 'N';
max@0 1736 char jobvt = 'N';
max@0 1737
max@0 1738 blas_int m = A.n_rows;
max@0 1739 blas_int n = A.n_cols;
max@0 1740 blas_int lda = A.n_rows;
max@0 1741 blas_int ldu = U.n_rows;
max@0 1742 blas_int ldvt = V.n_rows;
max@0 1743 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
max@0 1744 blas_int info;
max@0 1745
max@0 1746 S.set_size( static_cast<uword>((std::min)(m, n)) );
max@0 1747
max@0 1748 podarray<eT> work( static_cast<uword>(lwork) );
max@0 1749
max@0 1750
max@0 1751 // let gesvd_() calculate the optimum size of the workspace
max@0 1752 blas_int lwork_tmp = -1;
max@0 1753
max@0 1754 lapack::gesvd<eT>
max@0 1755 (
max@0 1756 &jobu, &jobvt,
max@0 1757 &m,&n,
max@0 1758 A.memptr(), &lda,
max@0 1759 S.memptr(),
max@0 1760 U.memptr(), &ldu,
max@0 1761 V.memptr(), &ldvt,
max@0 1762 work.memptr(), &lwork_tmp,
max@0 1763 &info
max@0 1764 );
max@0 1765
max@0 1766 if(info == 0)
max@0 1767 {
max@0 1768 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
max@0 1769
max@0 1770 if(proposed_lwork > lwork)
max@0 1771 {
max@0 1772 lwork = proposed_lwork;
max@0 1773 work.set_size( static_cast<uword>(lwork) );
max@0 1774 }
max@0 1775
max@0 1776 lapack::gesvd<eT>
max@0 1777 (
max@0 1778 &jobu, &jobvt,
max@0 1779 &m, &n,
max@0 1780 A.memptr(), &lda,
max@0 1781 S.memptr(),
max@0 1782 U.memptr(), &ldu,
max@0 1783 V.memptr(), &ldvt,
max@0 1784 work.memptr(), &lwork,
max@0 1785 &info
max@0 1786 );
max@0 1787 }
max@0 1788
max@0 1789 return (info == 0);
max@0 1790 }
max@0 1791 #else
max@0 1792 {
max@0 1793 arma_ignore(S);
max@0 1794 arma_ignore(X);
max@0 1795 arma_ignore(X_n_rows);
max@0 1796 arma_ignore(X_n_cols);
max@0 1797 arma_stop("svd(): use of LAPACK needs to be enabled");
max@0 1798 return false;
max@0 1799 }
max@0 1800 #endif
max@0 1801 }
max@0 1802
max@0 1803
max@0 1804
max@0 1805 template<typename T, typename T1>
max@0 1806 inline
max@0 1807 bool
max@0 1808 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X, uword& X_n_rows, uword& X_n_cols)
max@0 1809 {
max@0 1810 arma_extra_debug_sigprint();
max@0 1811
max@0 1812 typedef std::complex<T> eT;
max@0 1813
max@0 1814 #if defined(ARMA_USE_LAPACK)
max@0 1815 {
max@0 1816 Mat<eT> A(X.get_ref());
max@0 1817
max@0 1818 X_n_rows = A.n_rows;
max@0 1819 X_n_cols = A.n_cols;
max@0 1820
max@0 1821 if(A.is_empty())
max@0 1822 {
max@0 1823 S.reset();
max@0 1824 return true;
max@0 1825 }
max@0 1826
max@0 1827 Mat<eT> U(1, 1);
max@0 1828 Mat<eT> V(1, A.n_cols);
max@0 1829
max@0 1830 char jobu = 'N';
max@0 1831 char jobvt = 'N';
max@0 1832
max@0 1833 blas_int m = A.n_rows;
max@0 1834 blas_int n = A.n_cols;
max@0 1835 blas_int lda = A.n_rows;
max@0 1836 blas_int ldu = U.n_rows;
max@0 1837 blas_int ldvt = V.n_rows;
max@0 1838 blas_int lwork = 2 * (std::max)(blas_int(1), 2*(std::min)(m,n)+(std::max)(m,n) );
max@0 1839 blas_int info;
max@0 1840
max@0 1841 S.set_size( static_cast<uword>((std::min)(m,n)) );
max@0 1842
max@0 1843 podarray<eT> work( static_cast<uword>(lwork) );
max@0 1844 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) );
max@0 1845
max@0 1846 // let gesvd_() calculate the optimum size of the workspace
max@0 1847 blas_int lwork_tmp = -1;
max@0 1848
max@0 1849 lapack::cx_gesvd<T>
max@0 1850 (
max@0 1851 &jobu, &jobvt,
max@0 1852 &m, &n,
max@0 1853 A.memptr(), &lda,
max@0 1854 S.memptr(),
max@0 1855 U.memptr(), &ldu,
max@0 1856 V.memptr(), &ldvt,
max@0 1857 work.memptr(), &lwork_tmp,
max@0 1858 rwork.memptr(),
max@0 1859 &info
max@0 1860 );
max@0 1861
max@0 1862 if(info == 0)
max@0 1863 {
max@0 1864 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
max@0 1865 if(proposed_lwork > lwork)
max@0 1866 {
max@0 1867 lwork = proposed_lwork;
max@0 1868 work.set_size( static_cast<uword>(lwork) );
max@0 1869 }
max@0 1870
max@0 1871 lapack::cx_gesvd<T>
max@0 1872 (
max@0 1873 &jobu, &jobvt,
max@0 1874 &m, &n,
max@0 1875 A.memptr(), &lda,
max@0 1876 S.memptr(),
max@0 1877 U.memptr(), &ldu,
max@0 1878 V.memptr(), &ldvt,
max@0 1879 work.memptr(), &lwork,
max@0 1880 rwork.memptr(),
max@0 1881 &info
max@0 1882 );
max@0 1883 }
max@0 1884
max@0 1885 return (info == 0);
max@0 1886 }
max@0 1887 #else
max@0 1888 {
max@0 1889 arma_ignore(S);
max@0 1890 arma_ignore(X);
max@0 1891 arma_ignore(X_n_rows);
max@0 1892 arma_ignore(X_n_cols);
max@0 1893
max@0 1894 arma_stop("svd(): use of LAPACK needs to be enabled");
max@0 1895 return false;
max@0 1896 }
max@0 1897 #endif
max@0 1898 }
max@0 1899
max@0 1900
max@0 1901
max@0 1902 template<typename eT, typename T1>
max@0 1903 inline
max@0 1904 bool
max@0 1905 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X)
max@0 1906 {
max@0 1907 arma_extra_debug_sigprint();
max@0 1908
max@0 1909 uword junk;
max@0 1910 return auxlib::svd(S, X, junk, junk);
max@0 1911 }
max@0 1912
max@0 1913
max@0 1914
max@0 1915 template<typename T, typename T1>
max@0 1916 inline
max@0 1917 bool
max@0 1918 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X)
max@0 1919 {
max@0 1920 arma_extra_debug_sigprint();
max@0 1921
max@0 1922 uword junk;
max@0 1923 return auxlib::svd(S, X, junk, junk);
max@0 1924 }
max@0 1925
max@0 1926
max@0 1927
max@0 1928 template<typename eT, typename T1>
max@0 1929 inline
max@0 1930 bool
max@0 1931 auxlib::svd(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X)
max@0 1932 {
max@0 1933 arma_extra_debug_sigprint();
max@0 1934
max@0 1935 #if defined(ARMA_USE_LAPACK)
max@0 1936 {
max@0 1937 Mat<eT> A(X.get_ref());
max@0 1938
max@0 1939 if(A.is_empty())
max@0 1940 {
max@0 1941 U.eye(A.n_rows, A.n_rows);
max@0 1942 S.reset();
max@0 1943 V.eye(A.n_cols, A.n_cols);
max@0 1944 return true;
max@0 1945 }
max@0 1946
max@0 1947 U.set_size(A.n_rows, A.n_rows);
max@0 1948 V.set_size(A.n_cols, A.n_cols);
max@0 1949
max@0 1950 char jobu = 'A';
max@0 1951 char jobvt = 'A';
max@0 1952
max@0 1953 blas_int m = A.n_rows;
max@0 1954 blas_int n = A.n_cols;
max@0 1955 blas_int lda = A.n_rows;
max@0 1956 blas_int ldu = U.n_rows;
max@0 1957 blas_int ldvt = V.n_rows;
max@0 1958 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
max@0 1959 blas_int info;
max@0 1960
max@0 1961
max@0 1962 S.set_size( static_cast<uword>((std::min)(m,n)) );
max@0 1963 podarray<eT> work( static_cast<uword>(lwork) );
max@0 1964
max@0 1965 // let gesvd_() calculate the optimum size of the workspace
max@0 1966 blas_int lwork_tmp = -1;
max@0 1967
max@0 1968 lapack::gesvd<eT>
max@0 1969 (
max@0 1970 &jobu, &jobvt,
max@0 1971 &m, &n,
max@0 1972 A.memptr(), &lda,
max@0 1973 S.memptr(),
max@0 1974 U.memptr(), &ldu,
max@0 1975 V.memptr(), &ldvt,
max@0 1976 work.memptr(), &lwork_tmp,
max@0 1977 &info
max@0 1978 );
max@0 1979
max@0 1980 if(info == 0)
max@0 1981 {
max@0 1982 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
max@0 1983 if(proposed_lwork > lwork)
max@0 1984 {
max@0 1985 lwork = proposed_lwork;
max@0 1986 work.set_size( static_cast<uword>(lwork) );
max@0 1987 }
max@0 1988
max@0 1989 lapack::gesvd<eT>
max@0 1990 (
max@0 1991 &jobu, &jobvt,
max@0 1992 &m, &n,
max@0 1993 A.memptr(), &lda,
max@0 1994 S.memptr(),
max@0 1995 U.memptr(), &ldu,
max@0 1996 V.memptr(), &ldvt,
max@0 1997 work.memptr(), &lwork,
max@0 1998 &info
max@0 1999 );
max@0 2000
max@0 2001 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done
max@0 2002 }
max@0 2003
max@0 2004 return (info == 0);
max@0 2005 }
max@0 2006 #else
max@0 2007 {
max@0 2008 arma_ignore(U);
max@0 2009 arma_ignore(S);
max@0 2010 arma_ignore(V);
max@0 2011 arma_ignore(X);
max@0 2012 arma_stop("svd(): use of LAPACK needs to be enabled");
max@0 2013 return false;
max@0 2014 }
max@0 2015 #endif
max@0 2016 }
max@0 2017
max@0 2018
max@0 2019
max@0 2020 template<typename T, typename T1>
max@0 2021 inline
max@0 2022 bool
max@0 2023 auxlib::svd(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X)
max@0 2024 {
max@0 2025 arma_extra_debug_sigprint();
max@0 2026
max@0 2027 typedef std::complex<T> eT;
max@0 2028
max@0 2029 #if defined(ARMA_USE_LAPACK)
max@0 2030 {
max@0 2031 Mat<eT> A(X.get_ref());
max@0 2032
max@0 2033 if(A.is_empty())
max@0 2034 {
max@0 2035 U.eye(A.n_rows, A.n_rows);
max@0 2036 S.reset();
max@0 2037 V.eye(A.n_cols, A.n_cols);
max@0 2038 return true;
max@0 2039 }
max@0 2040
max@0 2041 U.set_size(A.n_rows, A.n_rows);
max@0 2042 V.set_size(A.n_cols, A.n_cols);
max@0 2043
max@0 2044 char jobu = 'A';
max@0 2045 char jobvt = 'A';
max@0 2046
max@0 2047 blas_int m = A.n_rows;
max@0 2048 blas_int n = A.n_cols;
max@0 2049 blas_int lda = A.n_rows;
max@0 2050 blas_int ldu = U.n_rows;
max@0 2051 blas_int ldvt = V.n_rows;
max@0 2052 blas_int lwork = 2 * (std::max)(blas_int(1), 2*(std::min)(m,n)+(std::max)(m,n) );
max@0 2053 blas_int info;
max@0 2054
max@0 2055 S.set_size( static_cast<uword>((std::min)(m,n)) );
max@0 2056
max@0 2057 podarray<eT> work( static_cast<uword>(lwork) );
max@0 2058 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) );
max@0 2059
max@0 2060 // let gesvd_() calculate the optimum size of the workspace
max@0 2061 blas_int lwork_tmp = -1;
max@0 2062 lapack::cx_gesvd<T>
max@0 2063 (
max@0 2064 &jobu, &jobvt,
max@0 2065 &m, &n,
max@0 2066 A.memptr(), &lda,
max@0 2067 S.memptr(),
max@0 2068 U.memptr(), &ldu,
max@0 2069 V.memptr(), &ldvt,
max@0 2070 work.memptr(), &lwork_tmp,
max@0 2071 rwork.memptr(),
max@0 2072 &info
max@0 2073 );
max@0 2074
max@0 2075 if(info == 0)
max@0 2076 {
max@0 2077 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
max@0 2078 if(proposed_lwork > lwork)
max@0 2079 {
max@0 2080 lwork = proposed_lwork;
max@0 2081 work.set_size( static_cast<uword>(lwork) );
max@0 2082 }
max@0 2083
max@0 2084 lapack::cx_gesvd<T>
max@0 2085 (
max@0 2086 &jobu, &jobvt,
max@0 2087 &m, &n,
max@0 2088 A.memptr(), &lda,
max@0 2089 S.memptr(),
max@0 2090 U.memptr(), &ldu,
max@0 2091 V.memptr(), &ldvt,
max@0 2092 work.memptr(), &lwork,
max@0 2093 rwork.memptr(),
max@0 2094 &info
max@0 2095 );
max@0 2096
max@0 2097 op_htrans::apply(V,V); // op_htrans will work out that an in-place transpose can be done
max@0 2098 }
max@0 2099
max@0 2100 return (info == 0);
max@0 2101 }
max@0 2102 #else
max@0 2103 {
max@0 2104 arma_ignore(U);
max@0 2105 arma_ignore(S);
max@0 2106 arma_ignore(V);
max@0 2107 arma_ignore(X);
max@0 2108 arma_stop("svd(): use of LAPACK needs to be enabled");
max@0 2109 return false;
max@0 2110 }
max@0 2111 #endif
max@0 2112
max@0 2113 }
max@0 2114
max@0 2115
max@0 2116
max@0 2117 template<typename eT, typename T1>
max@0 2118 inline
max@0 2119 bool
max@0 2120 auxlib::svd_econ(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X, const char mode)
max@0 2121 {
max@0 2122 arma_extra_debug_sigprint();
max@0 2123
max@0 2124 #if defined(ARMA_USE_LAPACK)
max@0 2125 {
max@0 2126 Mat<eT> A(X.get_ref());
max@0 2127
max@0 2128 blas_int m = A.n_rows;
max@0 2129 blas_int n = A.n_cols;
max@0 2130 blas_int lda = A.n_rows;
max@0 2131
max@0 2132 S.set_size( static_cast<uword>((std::min)(m,n)) );
max@0 2133
max@0 2134 blas_int ldu = 0;
max@0 2135 blas_int ldvt = 0;
max@0 2136
max@0 2137 char jobu;
max@0 2138 char jobvt;
max@0 2139
max@0 2140 switch(mode)
max@0 2141 {
max@0 2142 case 'l':
max@0 2143 jobu = 'S';
max@0 2144 jobvt = 'N';
max@0 2145
max@0 2146 ldu = m;
max@0 2147 ldvt = 1;
max@0 2148
max@0 2149 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
max@0 2150 V.reset();
max@0 2151
max@0 2152 break;
max@0 2153
max@0 2154
max@0 2155 case 'r':
max@0 2156 jobu = 'N';
max@0 2157 jobvt = 'S';
max@0 2158
max@0 2159 ldu = 1;
max@0 2160 ldvt = (std::min)(m,n);
max@0 2161
max@0 2162 U.reset();
max@0 2163 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
max@0 2164
max@0 2165 break;
max@0 2166
max@0 2167
max@0 2168 case 'b':
max@0 2169 jobu = 'S';
max@0 2170 jobvt = 'S';
max@0 2171
max@0 2172 ldu = m;
max@0 2173 ldvt = (std::min)(m,n);
max@0 2174
max@0 2175 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
max@0 2176 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
max@0 2177
max@0 2178 break;
max@0 2179
max@0 2180
max@0 2181 default:
max@0 2182 U.reset();
max@0 2183 S.reset();
max@0 2184 V.reset();
max@0 2185 return false;
max@0 2186 }
max@0 2187
max@0 2188
max@0 2189 if(A.is_empty())
max@0 2190 {
max@0 2191 U.eye();
max@0 2192 S.reset();
max@0 2193 V.eye();
max@0 2194 return true;
max@0 2195 }
max@0 2196
max@0 2197
max@0 2198 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
max@0 2199 blas_int info = 0;
max@0 2200
max@0 2201
max@0 2202 podarray<eT> work( static_cast<uword>(lwork) );
max@0 2203
max@0 2204 // let gesvd_() calculate the optimum size of the workspace
max@0 2205 blas_int lwork_tmp = -1;
max@0 2206
max@0 2207 lapack::gesvd<eT>
max@0 2208 (
max@0 2209 &jobu, &jobvt,
max@0 2210 &m, &n,
max@0 2211 A.memptr(), &lda,
max@0 2212 S.memptr(),
max@0 2213 U.memptr(), &ldu,
max@0 2214 V.memptr(), &ldvt,
max@0 2215 work.memptr(), &lwork_tmp,
max@0 2216 &info
max@0 2217 );
max@0 2218
max@0 2219 if(info == 0)
max@0 2220 {
max@0 2221 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
max@0 2222 if(proposed_lwork > lwork)
max@0 2223 {
max@0 2224 lwork = proposed_lwork;
max@0 2225 work.set_size( static_cast<uword>(lwork) );
max@0 2226 }
max@0 2227
max@0 2228 lapack::gesvd<eT>
max@0 2229 (
max@0 2230 &jobu, &jobvt,
max@0 2231 &m, &n,
max@0 2232 A.memptr(), &lda,
max@0 2233 S.memptr(),
max@0 2234 U.memptr(), &ldu,
max@0 2235 V.memptr(), &ldvt,
max@0 2236 work.memptr(), &lwork,
max@0 2237 &info
max@0 2238 );
max@0 2239
max@0 2240 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done
max@0 2241 }
max@0 2242
max@0 2243 return (info == 0);
max@0 2244 }
max@0 2245 #else
max@0 2246 {
max@0 2247 arma_ignore(U);
max@0 2248 arma_ignore(S);
max@0 2249 arma_ignore(V);
max@0 2250 arma_ignore(X);
max@0 2251 arma_ignore(mode);
max@0 2252 arma_stop("svd(): use of LAPACK needs to be enabled");
max@0 2253 return false;
max@0 2254 }
max@0 2255 #endif
max@0 2256 }
max@0 2257
max@0 2258
max@0 2259
max@0 2260 template<typename T, typename T1>
max@0 2261 inline
max@0 2262 bool
max@0 2263 auxlib::svd_econ(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X, const char mode)
max@0 2264 {
max@0 2265 arma_extra_debug_sigprint();
max@0 2266
max@0 2267 typedef std::complex<T> eT;
max@0 2268
max@0 2269 #if defined(ARMA_USE_LAPACK)
max@0 2270 {
max@0 2271 Mat<eT> A(X.get_ref());
max@0 2272
max@0 2273 blas_int m = A.n_rows;
max@0 2274 blas_int n = A.n_cols;
max@0 2275 blas_int lda = A.n_rows;
max@0 2276
max@0 2277 S.set_size( static_cast<uword>((std::min)(m,n)) );
max@0 2278
max@0 2279 blas_int ldu = 0;
max@0 2280 blas_int ldvt = 0;
max@0 2281
max@0 2282 char jobu;
max@0 2283 char jobvt;
max@0 2284
max@0 2285 switch(mode)
max@0 2286 {
max@0 2287 case 'l':
max@0 2288 jobu = 'S';
max@0 2289 jobvt = 'N';
max@0 2290
max@0 2291 ldu = m;
max@0 2292 ldvt = 1;
max@0 2293
max@0 2294 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
max@0 2295 V.reset();
max@0 2296
max@0 2297 break;
max@0 2298
max@0 2299
max@0 2300 case 'r':
max@0 2301 jobu = 'N';
max@0 2302 jobvt = 'S';
max@0 2303
max@0 2304 ldu = 1;
max@0 2305 ldvt = (std::min)(m,n);
max@0 2306
max@0 2307 U.reset();
max@0 2308 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
max@0 2309
max@0 2310 break;
max@0 2311
max@0 2312
max@0 2313 case 'b':
max@0 2314 jobu = 'S';
max@0 2315 jobvt = 'S';
max@0 2316
max@0 2317 ldu = m;
max@0 2318 ldvt = (std::min)(m,n);
max@0 2319
max@0 2320 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
max@0 2321 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
max@0 2322
max@0 2323 break;
max@0 2324
max@0 2325
max@0 2326 default:
max@0 2327 U.reset();
max@0 2328 S.reset();
max@0 2329 V.reset();
max@0 2330 return false;
max@0 2331 }
max@0 2332
max@0 2333
max@0 2334 if(A.is_empty())
max@0 2335 {
max@0 2336 U.eye();
max@0 2337 S.reset();
max@0 2338 V.eye();
max@0 2339 return true;
max@0 2340 }
max@0 2341
max@0 2342
max@0 2343 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
max@0 2344 blas_int info = 0;
max@0 2345
max@0 2346
max@0 2347 podarray<eT> work( static_cast<uword>(lwork) );
max@0 2348 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) );
max@0 2349
max@0 2350 // let gesvd_() calculate the optimum size of the workspace
max@0 2351 blas_int lwork_tmp = -1;
max@0 2352
max@0 2353 lapack::cx_gesvd<T>
max@0 2354 (
max@0 2355 &jobu, &jobvt,
max@0 2356 &m, &n,
max@0 2357 A.memptr(), &lda,
max@0 2358 S.memptr(),
max@0 2359 U.memptr(), &ldu,
max@0 2360 V.memptr(), &ldvt,
max@0 2361 work.memptr(), &lwork_tmp,
max@0 2362 rwork.memptr(),
max@0 2363 &info
max@0 2364 );
max@0 2365
max@0 2366 if(info == 0)
max@0 2367 {
max@0 2368 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
max@0 2369 if(proposed_lwork > lwork)
max@0 2370 {
max@0 2371 lwork = proposed_lwork;
max@0 2372 work.set_size( static_cast<uword>(lwork) );
max@0 2373 }
max@0 2374
max@0 2375 lapack::cx_gesvd<T>
max@0 2376 (
max@0 2377 &jobu, &jobvt,
max@0 2378 &m, &n,
max@0 2379 A.memptr(), &lda,
max@0 2380 S.memptr(),
max@0 2381 U.memptr(), &ldu,
max@0 2382 V.memptr(), &ldvt,
max@0 2383 work.memptr(), &lwork,
max@0 2384 rwork.memptr(),
max@0 2385 &info
max@0 2386 );
max@0 2387
max@0 2388 op_htrans::apply(V,V); // op_strans will work out that an in-place transpose can be done
max@0 2389 }
max@0 2390
max@0 2391 return (info == 0);
max@0 2392 }
max@0 2393 #else
max@0 2394 {
max@0 2395 arma_ignore(U);
max@0 2396 arma_ignore(S);
max@0 2397 arma_ignore(V);
max@0 2398 arma_ignore(X);
max@0 2399 arma_ignore(mode);
max@0 2400 arma_stop("svd(): use of LAPACK needs to be enabled");
max@0 2401 return false;
max@0 2402 }
max@0 2403 #endif
max@0 2404 }
max@0 2405
max@0 2406
max@0 2407
max@0 2408 //! Solve a system of linear equations.
max@0 2409 //! Assumes that A.n_rows = A.n_cols and B.n_rows = A.n_rows
max@0 2410 template<typename eT>
max@0 2411 inline
max@0 2412 bool
max@0 2413 auxlib::solve(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B, const bool slow)
max@0 2414 {
max@0 2415 arma_extra_debug_sigprint();
max@0 2416
max@0 2417 if(A.is_empty() || B.is_empty())
max@0 2418 {
max@0 2419 out.zeros(A.n_cols, B.n_cols);
max@0 2420 return true;
max@0 2421 }
max@0 2422 else
max@0 2423 {
max@0 2424 const uword A_n_rows = A.n_rows;
max@0 2425
max@0 2426 bool status = false;
max@0 2427
max@0 2428 if( (A_n_rows <= 4) && (slow == false) )
max@0 2429 {
max@0 2430 Mat<eT> A_inv;
max@0 2431
max@0 2432 status = auxlib::inv_noalias_tinymat(A_inv, A, A_n_rows);
max@0 2433
max@0 2434 if(status == true)
max@0 2435 {
max@0 2436 out.set_size(A_n_rows, B.n_cols);
max@0 2437
max@0 2438 gemm_emul<false,false,false,false>::apply(out, A_inv, B);
max@0 2439
max@0 2440 return true;
max@0 2441 }
max@0 2442 }
max@0 2443
max@0 2444 if( (A_n_rows > 4) || (status == false) )
max@0 2445 {
max@0 2446 #if defined(ARMA_USE_ATLAS)
max@0 2447 {
max@0 2448 podarray<int> ipiv(A_n_rows);
max@0 2449
max@0 2450 out = B;
max@0 2451
max@0 2452 int info = atlas::clapack_gesv<eT>(atlas::CblasColMajor, A_n_rows, B.n_cols, A.memptr(), A_n_rows, ipiv.memptr(), out.memptr(), A_n_rows);
max@0 2453
max@0 2454 return (info == 0);
max@0 2455 }
max@0 2456 #elif defined(ARMA_USE_LAPACK)
max@0 2457 {
max@0 2458 blas_int n = A_n_rows;
max@0 2459 blas_int lda = A_n_rows;
max@0 2460 blas_int ldb = A_n_rows;
max@0 2461 blas_int nrhs = B.n_cols;
max@0 2462 blas_int info;
max@0 2463
max@0 2464 podarray<blas_int> ipiv(A_n_rows);
max@0 2465
max@0 2466 out = B;
max@0 2467
max@0 2468 lapack::gesv<eT>(&n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info);
max@0 2469
max@0 2470 return (info == 0);
max@0 2471 }
max@0 2472 #else
max@0 2473 {
max@0 2474 arma_stop("solve(): use of ATLAS or LAPACK needs to be enabled");
max@0 2475 return false;
max@0 2476 }
max@0 2477 #endif
max@0 2478 }
max@0 2479 }
max@0 2480
max@0 2481 return true;
max@0 2482 }
max@0 2483
max@0 2484
max@0 2485
max@0 2486 //! Solve an over-determined system.
max@0 2487 //! Assumes that A.n_rows > A.n_cols and B.n_rows = A.n_rows
max@0 2488 template<typename eT>
max@0 2489 inline
max@0 2490 bool
max@0 2491 auxlib::solve_od(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B)
max@0 2492 {
max@0 2493 arma_extra_debug_sigprint();
max@0 2494
max@0 2495 #if defined(ARMA_USE_LAPACK)
max@0 2496 {
max@0 2497 if(A.is_empty() || B.is_empty())
max@0 2498 {
max@0 2499 out.zeros(A.n_cols, B.n_cols);
max@0 2500 return true;
max@0 2501 }
max@0 2502
max@0 2503 char trans = 'N';
max@0 2504
max@0 2505 blas_int m = A.n_rows;
max@0 2506 blas_int n = A.n_cols;
max@0 2507 blas_int lda = A.n_rows;
max@0 2508 blas_int ldb = A.n_rows;
max@0 2509 blas_int nrhs = B.n_cols;
max@0 2510 blas_int lwork = n + (std::max)(n, nrhs);
max@0 2511 blas_int info;
max@0 2512
max@0 2513 Mat<eT> tmp = B;
max@0 2514
max@0 2515 podarray<eT> work( static_cast<uword>(lwork) );
max@0 2516
max@0 2517 arma_extra_debug_print("lapack::gels()");
max@0 2518
max@0 2519 // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems
max@0 2520
max@0 2521 lapack::gels<eT>
max@0 2522 (
max@0 2523 &trans, &m, &n, &nrhs,
max@0 2524 A.memptr(), &lda,
max@0 2525 tmp.memptr(), &ldb,
max@0 2526 work.memptr(), &lwork,
max@0 2527 &info
max@0 2528 );
max@0 2529
max@0 2530 arma_extra_debug_print("lapack::gels() -- finished");
max@0 2531
max@0 2532 out.set_size(A.n_cols, B.n_cols);
max@0 2533
max@0 2534 for(uword col=0; col<B.n_cols; ++col)
max@0 2535 {
max@0 2536 arrayops::copy( out.colptr(col), tmp.colptr(col), A.n_cols );
max@0 2537 }
max@0 2538
max@0 2539 return (info == 0);
max@0 2540 }
max@0 2541 #else
max@0 2542 {
max@0 2543 arma_ignore(out);
max@0 2544 arma_ignore(A);
max@0 2545 arma_ignore(B);
max@0 2546 arma_stop("solve(): use of LAPACK needs to be enabled");
max@0 2547 return false;
max@0 2548 }
max@0 2549 #endif
max@0 2550 }
max@0 2551
max@0 2552
max@0 2553
max@0 2554 //! Solve an under-determined system.
max@0 2555 //! Assumes that A.n_rows < A.n_cols and B.n_rows = A.n_rows
max@0 2556 template<typename eT>
max@0 2557 inline
max@0 2558 bool
max@0 2559 auxlib::solve_ud(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B)
max@0 2560 {
max@0 2561 arma_extra_debug_sigprint();
max@0 2562
max@0 2563 #if defined(ARMA_USE_LAPACK)
max@0 2564 {
max@0 2565 if(A.is_empty() || B.is_empty())
max@0 2566 {
max@0 2567 out.zeros(A.n_cols, B.n_cols);
max@0 2568 return true;
max@0 2569 }
max@0 2570
max@0 2571 char trans = 'N';
max@0 2572
max@0 2573 blas_int m = A.n_rows;
max@0 2574 blas_int n = A.n_cols;
max@0 2575 blas_int lda = A.n_rows;
max@0 2576 blas_int ldb = A.n_cols;
max@0 2577 blas_int nrhs = B.n_cols;
max@0 2578 blas_int lwork = m + (std::max)(m,nrhs);
max@0 2579 blas_int info;
max@0 2580
max@0 2581
max@0 2582 Mat<eT> tmp;
max@0 2583 tmp.zeros(A.n_cols, B.n_cols);
max@0 2584
max@0 2585 for(uword col=0; col<B.n_cols; ++col)
max@0 2586 {
max@0 2587 eT* tmp_colmem = tmp.colptr(col);
max@0 2588
max@0 2589 arrayops::copy( tmp_colmem, B.colptr(col), B.n_rows );
max@0 2590
max@0 2591 for(uword row=B.n_rows; row<A.n_cols; ++row)
max@0 2592 {
max@0 2593 tmp_colmem[row] = eT(0);
max@0 2594 }
max@0 2595 }
max@0 2596
max@0 2597 podarray<eT> work( static_cast<uword>(lwork) );
max@0 2598
max@0 2599 arma_extra_debug_print("lapack::gels()");
max@0 2600
max@0 2601 // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems
max@0 2602
max@0 2603 lapack::gels<eT>
max@0 2604 (
max@0 2605 &trans, &m, &n, &nrhs,
max@0 2606 A.memptr(), &lda,
max@0 2607 tmp.memptr(), &ldb,
max@0 2608 work.memptr(), &lwork,
max@0 2609 &info
max@0 2610 );
max@0 2611
max@0 2612 arma_extra_debug_print("lapack::gels() -- finished");
max@0 2613
max@0 2614 out.set_size(A.n_cols, B.n_cols);
max@0 2615
max@0 2616 for(uword col=0; col<B.n_cols; ++col)
max@0 2617 {
max@0 2618 arrayops::copy( out.colptr(col), tmp.colptr(col), A.n_cols );
max@0 2619 }
max@0 2620
max@0 2621 return (info == 0);
max@0 2622 }
max@0 2623 #else
max@0 2624 {
max@0 2625 arma_ignore(out);
max@0 2626 arma_ignore(A);
max@0 2627 arma_ignore(B);
max@0 2628 arma_stop("solve(): use of LAPACK needs to be enabled");
max@0 2629 return false;
max@0 2630 }
max@0 2631 #endif
max@0 2632 }
max@0 2633
max@0 2634
max@0 2635
max@0 2636 //
max@0 2637 // solve_tr
max@0 2638
max@0 2639 template<typename eT>
max@0 2640 inline
max@0 2641 bool
max@0 2642 auxlib::solve_tr(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const uword layout)
max@0 2643 {
max@0 2644 arma_extra_debug_sigprint();
max@0 2645
max@0 2646 #if defined(ARMA_USE_LAPACK)
max@0 2647 {
max@0 2648 if(A.is_empty() || B.is_empty())
max@0 2649 {
max@0 2650 out.zeros(A.n_cols, B.n_cols);
max@0 2651 return true;
max@0 2652 }
max@0 2653
max@0 2654 out = B;
max@0 2655
max@0 2656 char uplo = (layout == 0) ? 'U' : 'L';
max@0 2657 char trans = 'N';
max@0 2658 char diag = 'N';
max@0 2659 blas_int n = blas_int(A.n_rows);
max@0 2660 blas_int nrhs = blas_int(B.n_cols);
max@0 2661 blas_int info = 0;
max@0 2662
max@0 2663 lapack::trtrs<eT>(&uplo, &trans, &diag, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info);
max@0 2664
max@0 2665 return (info == 0);
max@0 2666 }
max@0 2667 #else
max@0 2668 {
max@0 2669 arma_ignore(out);
max@0 2670 arma_ignore(A);
max@0 2671 arma_ignore(B);
max@0 2672 arma_ignore(layout);
max@0 2673 arma_stop("solve(): use of LAPACK needs to be enabled");
max@0 2674 return false;
max@0 2675 }
max@0 2676 #endif
max@0 2677 }
max@0 2678
max@0 2679
max@0 2680
max@0 2681 //
max@0 2682 // Schur decomposition
max@0 2683
max@0 2684 template<typename eT>
max@0 2685 inline
max@0 2686 bool
max@0 2687 auxlib::schur_dec(Mat<eT>& Z, Mat<eT>& T, const Mat<eT>& A)
max@0 2688 {
max@0 2689 arma_extra_debug_sigprint();
max@0 2690
max@0 2691 #if defined(ARMA_USE_LAPACK)
max@0 2692 {
max@0 2693 arma_debug_check( (A.is_square() == false), "schur_dec(): given matrix is not square" );
max@0 2694
max@0 2695 if(A.is_empty())
max@0 2696 {
max@0 2697 Z.reset();
max@0 2698 T.reset();
max@0 2699 return true;
max@0 2700 }
max@0 2701
max@0 2702 const uword A_n_rows = A.n_rows;
max@0 2703
max@0 2704 char jobvs = 'V'; // get Schur vectors (Z)
max@0 2705 char sort = 'N'; // do not sort eigenvalues/vectors
max@0 2706 blas_int* select = 0; // pointer to sorting function
max@0 2707 blas_int n = blas_int(A_n_rows);
max@0 2708 blas_int sdim = 0; // output for sorting
max@0 2709
max@0 2710 blas_int lwork = 3 * n; // workspace must be at least 3 * n (if set to -1, optimal size is output in work(0) and nothing else is done
max@0 2711
max@0 2712 podarray<eT> work( static_cast<uword>(lwork) );
max@0 2713 podarray<blas_int> bwork(A_n_rows);
max@0 2714
max@0 2715 blas_int info = 0;
max@0 2716
max@0 2717 Z.set_size(A_n_rows, A_n_rows);
max@0 2718 T = A;
max@0 2719
max@0 2720 podarray<eT> wr(A_n_rows); // output for eigenvalues
max@0 2721 podarray<eT> wi(A_n_rows); // output for eigenvalues
max@0 2722
max@0 2723 lapack::gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, wr.memptr(), wi.memptr(), Z.memptr(), &n, work.memptr(), &lwork, bwork.memptr(), &info);
max@0 2724
max@0 2725 return (info == 0);
max@0 2726 }
max@0 2727 #else
max@0 2728 {
max@0 2729 arma_ignore(Z);
max@0 2730 arma_ignore(T);
max@0 2731 arma_stop("schur_dec(): use of LAPACK needs to be enabled");
max@0 2732 return false;
max@0 2733 }
max@0 2734 #endif
max@0 2735 }
max@0 2736
max@0 2737
max@0 2738
max@0 2739 template<typename cT>
max@0 2740 inline
max@0 2741 bool
max@0 2742 auxlib::schur_dec(Mat<std::complex<cT> >& Z, Mat<std::complex<cT> >& T, const Mat<std::complex<cT> >& A)
max@0 2743 {
max@0 2744 arma_extra_debug_sigprint();
max@0 2745
max@0 2746 #if defined(ARMA_USE_LAPACK)
max@0 2747 {
max@0 2748 arma_debug_check( (A.is_square() == false), "schur_dec(): matrix A is not square" );
max@0 2749
max@0 2750 if(A.is_empty())
max@0 2751 {
max@0 2752 Z.reset();
max@0 2753 T.reset();
max@0 2754 return true;
max@0 2755 }
max@0 2756
max@0 2757 typedef std::complex<cT> eT;
max@0 2758
max@0 2759 const uword A_n_rows = A.n_rows;
max@0 2760
max@0 2761 char jobvs = 'V'; // get Schur vectors (Z)
max@0 2762 char sort = 'N'; // do not sort eigenvalues/vectors
max@0 2763 blas_int* select = 0; // pointer to sorting function
max@0 2764 blas_int n = blas_int(A_n_rows);
max@0 2765 blas_int sdim = 0; // output for sorting
max@0 2766
max@0 2767 blas_int lwork = 3 * n; // workspace must be at least 3 * n (if set to -1, optimal size is output in work(0) and nothing else is done
max@0 2768
max@0 2769 podarray<eT> work( static_cast<uword>(lwork) );
max@0 2770 podarray<blas_int> bwork(A_n_rows);
max@0 2771
max@0 2772 blas_int info = 0;
max@0 2773
max@0 2774 Z.set_size(A_n_rows, A_n_rows);
max@0 2775 T = A;
max@0 2776
max@0 2777 podarray<eT> w(A_n_rows); // output for eigenvalues
max@0 2778 podarray<cT> rwork(A_n_rows);
max@0 2779
max@0 2780 lapack::cx_gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, w.memptr(), Z.memptr(), &n, work.memptr(), &lwork, rwork.memptr(), bwork.memptr(), &info);
max@0 2781
max@0 2782 return (info == 0);
max@0 2783 }
max@0 2784 #else
max@0 2785 {
max@0 2786 arma_ignore(Z);
max@0 2787 arma_ignore(T);
max@0 2788 arma_stop("schur_dec(): use of LAPACK needs to be enabled");
max@0 2789 return false;
max@0 2790 }
max@0 2791 #endif
max@0 2792 }
max@0 2793
max@0 2794
max@0 2795
max@0 2796 //
max@0 2797 // syl (solution of the Sylvester equation AX + XB = C)
max@0 2798
max@0 2799 template<typename eT>
max@0 2800 inline
max@0 2801 bool
max@0 2802 auxlib::syl(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& B, const Mat<eT>& C)
max@0 2803 {
max@0 2804 arma_extra_debug_sigprint();
max@0 2805
max@0 2806 arma_debug_check
max@0 2807 (
max@0 2808 (A.is_square() == false) || (B.is_square() == false),
max@0 2809 "syl(): given matrix is not square"
max@0 2810 );
max@0 2811
max@0 2812 arma_debug_check
max@0 2813 (
max@0 2814 (C.n_rows != A.n_rows) || (C.n_cols != B.n_cols),
max@0 2815 "syl(): matrices are not conformant"
max@0 2816 );
max@0 2817
max@0 2818 if(A.is_empty() || B.is_empty() || C.is_empty())
max@0 2819 {
max@0 2820 X.reset();
max@0 2821 return true;
max@0 2822 }
max@0 2823
max@0 2824 #if defined(ARMA_USE_LAPACK)
max@0 2825 {
max@0 2826 Mat<eT> Z1, Z2, T1, T2;
max@0 2827
max@0 2828 const bool status_sd1 = auxlib::schur_dec(Z1, T1, A);
max@0 2829 const bool status_sd2 = auxlib::schur_dec(Z2, T2, B);
max@0 2830
max@0 2831 if( (status_sd1 == false) || (status_sd2 == false) )
max@0 2832 {
max@0 2833 return false;
max@0 2834 }
max@0 2835
max@0 2836 char trana = 'N';
max@0 2837 char tranb = 'N';
max@0 2838 blas_int isgn = +1;
max@0 2839 blas_int m = blas_int(T1.n_rows);
max@0 2840 blas_int n = blas_int(T2.n_cols);
max@0 2841
max@0 2842 eT scale = eT(0);
max@0 2843 blas_int info = 0;
max@0 2844
max@0 2845 Mat<eT> Y = trans(Z1) * C * Z2;
max@0 2846
max@0 2847 lapack::trsyl<eT>(&trana, &tranb, &isgn, &m, &n, T1.memptr(), &m, T2.memptr(), &n, Y.memptr(), &m, &scale, &info);
max@0 2848
max@0 2849 //Y /= scale;
max@0 2850 Y /= (-scale);
max@0 2851
max@0 2852 X = Z1 * Y * trans(Z2);
max@0 2853
max@0 2854 return (info >= 0);
max@0 2855 }
max@0 2856 #else
max@0 2857 {
max@0 2858 arma_stop("syl(): use of LAPACK needs to be enabled");
max@0 2859 return false;
max@0 2860 }
max@0 2861 #endif
max@0 2862 }
max@0 2863
max@0 2864
max@0 2865
max@0 2866 //
max@0 2867 // lyap (solution of the continuous Lyapunov equation AX + XA^H + Q = 0)
max@0 2868
max@0 2869 template<typename eT>
max@0 2870 inline
max@0 2871 bool
max@0 2872 auxlib::lyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
max@0 2873 {
max@0 2874 arma_extra_debug_sigprint();
max@0 2875
max@0 2876 arma_debug_check( (A.is_square() == false), "lyap(): matrix A is not square");
max@0 2877 arma_debug_check( (Q.is_square() == false), "lyap(): matrix Q is not square");
max@0 2878 arma_debug_check( (A.n_rows != Q.n_rows), "lyap(): matrices A and Q have different dimensions");
max@0 2879
max@0 2880 Mat<eT> htransA;
max@0 2881 op_htrans::apply_noalias(htransA, A);
max@0 2882
max@0 2883 const Mat<eT> mQ = -Q;
max@0 2884
max@0 2885 return auxlib::syl(X, A, htransA, mQ);
max@0 2886 }
max@0 2887
max@0 2888
max@0 2889
max@0 2890 //
max@0 2891 // dlyap (solution of the discrete Lyapunov equation AXA^H - X + Q = 0)
max@0 2892
max@0 2893 template<typename eT>
max@0 2894 inline
max@0 2895 bool
max@0 2896 auxlib::dlyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
max@0 2897 {
max@0 2898 arma_extra_debug_sigprint();
max@0 2899
max@0 2900 arma_debug_check( (A.is_square() == false), "dlyap(): matrix A is not square");
max@0 2901 arma_debug_check( (Q.is_square() == false), "dlyap(): matrix Q is not square");
max@0 2902 arma_debug_check( (A.n_rows != Q.n_rows), "dlyap(): matrices A and Q have different dimensions");
max@0 2903
max@0 2904 const Col<eT> vecQ = reshape(Q, Q.n_elem, 1);
max@0 2905
max@0 2906 const Mat<eT> M = eye< Mat<eT> >(Q.n_elem, Q.n_elem) - kron(conj(A), A);
max@0 2907
max@0 2908 Col<eT> vecX;
max@0 2909
max@0 2910 const bool status = solve(vecX, M, vecQ);
max@0 2911
max@0 2912 if(status == true)
max@0 2913 {
max@0 2914 X = reshape(vecX, Q.n_rows, Q.n_cols);
max@0 2915 return true;
max@0 2916 }
max@0 2917 else
max@0 2918 {
max@0 2919 X.reset();
max@0 2920 return false;
max@0 2921 }
max@0 2922 }
max@0 2923
max@0 2924
max@0 2925
max@0 2926 //! @}