annotate armadillo-3.900.4/include/armadillo_bits/auxlib_meat.hpp @ 84:55a047986812 tip

Update library URI so as not to be document-local
author Chris Cannam
date Wed, 22 Apr 2020 14:21:57 +0100
parents 1ec0e2823891
children
rev   line source
Chris@49 1 // Copyright (C) 2008-2013 NICTA (www.nicta.com.au)
Chris@49 2 // Copyright (C) 2008-2013 Conrad Sanderson
Chris@49 3 // Copyright (C) 2009 Edmund Highcock
Chris@49 4 // Copyright (C) 2011 James Sanders
Chris@49 5 // Copyright (C) 2011 Stanislav Funiak
Chris@49 6 // Copyright (C) 2012 Eric Jon Sundstrom
Chris@49 7 // Copyright (C) 2012 Michael McNeil Forbes
Chris@49 8 //
Chris@49 9 // This Source Code Form is subject to the terms of the Mozilla Public
Chris@49 10 // License, v. 2.0. If a copy of the MPL was not distributed with this
Chris@49 11 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
Chris@49 12
Chris@49 13
Chris@49 14 //! \addtogroup auxlib
Chris@49 15 //! @{
Chris@49 16
Chris@49 17
Chris@49 18
Chris@49 19 //! immediate matrix inverse
Chris@49 20 template<typename eT, typename T1>
Chris@49 21 inline
Chris@49 22 bool
Chris@49 23 auxlib::inv(Mat<eT>& out, const Base<eT,T1>& X, const bool slow)
Chris@49 24 {
Chris@49 25 arma_extra_debug_sigprint();
Chris@49 26
Chris@49 27 out = X.get_ref();
Chris@49 28
Chris@49 29 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
Chris@49 30
Chris@49 31 bool status = false;
Chris@49 32
Chris@49 33 const uword N = out.n_rows;
Chris@49 34
Chris@49 35 if( (N <= 4) && (slow == false) )
Chris@49 36 {
Chris@49 37 status = auxlib::inv_inplace_tinymat(out, N);
Chris@49 38 }
Chris@49 39
Chris@49 40 if( (N > 4) || (status == false) )
Chris@49 41 {
Chris@49 42 status = auxlib::inv_inplace_lapack(out);
Chris@49 43 }
Chris@49 44
Chris@49 45 return status;
Chris@49 46 }
Chris@49 47
Chris@49 48
Chris@49 49
Chris@49 50 template<typename eT>
Chris@49 51 inline
Chris@49 52 bool
Chris@49 53 auxlib::inv(Mat<eT>& out, const Mat<eT>& X, const bool slow)
Chris@49 54 {
Chris@49 55 arma_extra_debug_sigprint();
Chris@49 56
Chris@49 57 arma_debug_check( (X.is_square() == false), "inv(): given matrix is not square" );
Chris@49 58
Chris@49 59 bool status = false;
Chris@49 60
Chris@49 61 const uword N = X.n_rows;
Chris@49 62
Chris@49 63 if( (N <= 4) && (slow == false) )
Chris@49 64 {
Chris@49 65 status = (&out != &X) ? auxlib::inv_noalias_tinymat(out, X, N) : auxlib::inv_inplace_tinymat(out, N);
Chris@49 66 }
Chris@49 67
Chris@49 68 if( (N > 4) || (status == false) )
Chris@49 69 {
Chris@49 70 out = X;
Chris@49 71 status = auxlib::inv_inplace_lapack(out);
Chris@49 72 }
Chris@49 73
Chris@49 74 return status;
Chris@49 75 }
Chris@49 76
Chris@49 77
Chris@49 78
Chris@49 79 template<typename eT>
Chris@49 80 inline
Chris@49 81 bool
Chris@49 82 auxlib::inv_noalias_tinymat(Mat<eT>& out, const Mat<eT>& X, const uword N)
Chris@49 83 {
Chris@49 84 arma_extra_debug_sigprint();
Chris@49 85
Chris@49 86 bool det_ok = true;
Chris@49 87
Chris@49 88 out.set_size(N,N);
Chris@49 89
Chris@49 90 switch(N)
Chris@49 91 {
Chris@49 92 case 1:
Chris@49 93 {
Chris@49 94 out[0] = eT(1) / X[0];
Chris@49 95 };
Chris@49 96 break;
Chris@49 97
Chris@49 98 case 2:
Chris@49 99 {
Chris@49 100 const eT* Xm = X.memptr();
Chris@49 101
Chris@49 102 const eT a = Xm[pos<0,0>::n2];
Chris@49 103 const eT b = Xm[pos<0,1>::n2];
Chris@49 104 const eT c = Xm[pos<1,0>::n2];
Chris@49 105 const eT d = Xm[pos<1,1>::n2];
Chris@49 106
Chris@49 107 const eT tmp_det = (a*d - b*c);
Chris@49 108
Chris@49 109 if(tmp_det != eT(0))
Chris@49 110 {
Chris@49 111 eT* outm = out.memptr();
Chris@49 112
Chris@49 113 outm[pos<0,0>::n2] = d / tmp_det;
Chris@49 114 outm[pos<0,1>::n2] = -b / tmp_det;
Chris@49 115 outm[pos<1,0>::n2] = -c / tmp_det;
Chris@49 116 outm[pos<1,1>::n2] = a / tmp_det;
Chris@49 117 }
Chris@49 118 else
Chris@49 119 {
Chris@49 120 det_ok = false;
Chris@49 121 }
Chris@49 122 };
Chris@49 123 break;
Chris@49 124
Chris@49 125 case 3:
Chris@49 126 {
Chris@49 127 const eT* X_col0 = X.colptr(0);
Chris@49 128 const eT a11 = X_col0[0];
Chris@49 129 const eT a21 = X_col0[1];
Chris@49 130 const eT a31 = X_col0[2];
Chris@49 131
Chris@49 132 const eT* X_col1 = X.colptr(1);
Chris@49 133 const eT a12 = X_col1[0];
Chris@49 134 const eT a22 = X_col1[1];
Chris@49 135 const eT a32 = X_col1[2];
Chris@49 136
Chris@49 137 const eT* X_col2 = X.colptr(2);
Chris@49 138 const eT a13 = X_col2[0];
Chris@49 139 const eT a23 = X_col2[1];
Chris@49 140 const eT a33 = X_col2[2];
Chris@49 141
Chris@49 142 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13);
Chris@49 143
Chris@49 144 if(tmp_det != eT(0))
Chris@49 145 {
Chris@49 146 eT* out_col0 = out.colptr(0);
Chris@49 147 out_col0[0] = (a33*a22 - a32*a23) / tmp_det;
Chris@49 148 out_col0[1] = -(a33*a21 - a31*a23) / tmp_det;
Chris@49 149 out_col0[2] = (a32*a21 - a31*a22) / tmp_det;
Chris@49 150
Chris@49 151 eT* out_col1 = out.colptr(1);
Chris@49 152 out_col1[0] = -(a33*a12 - a32*a13) / tmp_det;
Chris@49 153 out_col1[1] = (a33*a11 - a31*a13) / tmp_det;
Chris@49 154 out_col1[2] = -(a32*a11 - a31*a12) / tmp_det;
Chris@49 155
Chris@49 156 eT* out_col2 = out.colptr(2);
Chris@49 157 out_col2[0] = (a23*a12 - a22*a13) / tmp_det;
Chris@49 158 out_col2[1] = -(a23*a11 - a21*a13) / tmp_det;
Chris@49 159 out_col2[2] = (a22*a11 - a21*a12) / tmp_det;
Chris@49 160 }
Chris@49 161 else
Chris@49 162 {
Chris@49 163 det_ok = false;
Chris@49 164 }
Chris@49 165 };
Chris@49 166 break;
Chris@49 167
Chris@49 168 case 4:
Chris@49 169 {
Chris@49 170 const eT tmp_det = det(X);
Chris@49 171
Chris@49 172 if(tmp_det != eT(0))
Chris@49 173 {
Chris@49 174 const eT* Xm = X.memptr();
Chris@49 175 eT* outm = out.memptr();
Chris@49 176
Chris@49 177 outm[pos<0,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
Chris@49 178 outm[pos<1,0>::n4] = ( Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
Chris@49 179 outm[pos<2,0>::n4] = ( Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
Chris@49 180 outm[pos<3,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
Chris@49 181
Chris@49 182 outm[pos<0,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
Chris@49 183 outm[pos<1,1>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
Chris@49 184 outm[pos<2,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
Chris@49 185 outm[pos<3,1>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
Chris@49 186
Chris@49 187 outm[pos<0,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
Chris@49 188 outm[pos<1,2>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
Chris@49 189 outm[pos<2,2>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
Chris@49 190 outm[pos<3,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
Chris@49 191
Chris@49 192 outm[pos<0,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
Chris@49 193 outm[pos<1,3>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
Chris@49 194 outm[pos<2,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
Chris@49 195 outm[pos<3,3>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] ) / tmp_det;
Chris@49 196 }
Chris@49 197 else
Chris@49 198 {
Chris@49 199 det_ok = false;
Chris@49 200 }
Chris@49 201 };
Chris@49 202 break;
Chris@49 203
Chris@49 204 default:
Chris@49 205 ;
Chris@49 206 }
Chris@49 207
Chris@49 208 return det_ok;
Chris@49 209 }
Chris@49 210
Chris@49 211
Chris@49 212
Chris@49 213 template<typename eT>
Chris@49 214 inline
Chris@49 215 bool
Chris@49 216 auxlib::inv_inplace_tinymat(Mat<eT>& X, const uword N)
Chris@49 217 {
Chris@49 218 arma_extra_debug_sigprint();
Chris@49 219
Chris@49 220 bool det_ok = true;
Chris@49 221
Chris@49 222 // for more info, see:
Chris@49 223 // http://www.dr-lex.34sp.com/random/matrix_inv.html
Chris@49 224 // http://www.cvl.iis.u-tokyo.ac.jp/~miyazaki/tech/teche23.html
Chris@49 225 // http://www.euclideanspace.com/maths/algebra/matrix/functions/inverse/fourD/index.htm
Chris@49 226 // http://www.geometrictools.com//LibFoundation/Mathematics/Wm4Matrix4.inl
Chris@49 227
Chris@49 228 switch(N)
Chris@49 229 {
Chris@49 230 case 1:
Chris@49 231 {
Chris@49 232 X[0] = eT(1) / X[0];
Chris@49 233 };
Chris@49 234 break;
Chris@49 235
Chris@49 236 case 2:
Chris@49 237 {
Chris@49 238 const eT a = X[pos<0,0>::n2];
Chris@49 239 const eT b = X[pos<0,1>::n2];
Chris@49 240 const eT c = X[pos<1,0>::n2];
Chris@49 241 const eT d = X[pos<1,1>::n2];
Chris@49 242
Chris@49 243 const eT tmp_det = (a*d - b*c);
Chris@49 244
Chris@49 245 if(tmp_det != eT(0))
Chris@49 246 {
Chris@49 247 X[pos<0,0>::n2] = d / tmp_det;
Chris@49 248 X[pos<0,1>::n2] = -b / tmp_det;
Chris@49 249 X[pos<1,0>::n2] = -c / tmp_det;
Chris@49 250 X[pos<1,1>::n2] = a / tmp_det;
Chris@49 251 }
Chris@49 252 else
Chris@49 253 {
Chris@49 254 det_ok = false;
Chris@49 255 }
Chris@49 256 };
Chris@49 257 break;
Chris@49 258
Chris@49 259 case 3:
Chris@49 260 {
Chris@49 261 eT* X_col0 = X.colptr(0);
Chris@49 262 eT* X_col1 = X.colptr(1);
Chris@49 263 eT* X_col2 = X.colptr(2);
Chris@49 264
Chris@49 265 const eT a11 = X_col0[0];
Chris@49 266 const eT a21 = X_col0[1];
Chris@49 267 const eT a31 = X_col0[2];
Chris@49 268
Chris@49 269 const eT a12 = X_col1[0];
Chris@49 270 const eT a22 = X_col1[1];
Chris@49 271 const eT a32 = X_col1[2];
Chris@49 272
Chris@49 273 const eT a13 = X_col2[0];
Chris@49 274 const eT a23 = X_col2[1];
Chris@49 275 const eT a33 = X_col2[2];
Chris@49 276
Chris@49 277 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13);
Chris@49 278
Chris@49 279 if(tmp_det != eT(0))
Chris@49 280 {
Chris@49 281 X_col0[0] = (a33*a22 - a32*a23) / tmp_det;
Chris@49 282 X_col0[1] = -(a33*a21 - a31*a23) / tmp_det;
Chris@49 283 X_col0[2] = (a32*a21 - a31*a22) / tmp_det;
Chris@49 284
Chris@49 285 X_col1[0] = -(a33*a12 - a32*a13) / tmp_det;
Chris@49 286 X_col1[1] = (a33*a11 - a31*a13) / tmp_det;
Chris@49 287 X_col1[2] = -(a32*a11 - a31*a12) / tmp_det;
Chris@49 288
Chris@49 289 X_col2[0] = (a23*a12 - a22*a13) / tmp_det;
Chris@49 290 X_col2[1] = -(a23*a11 - a21*a13) / tmp_det;
Chris@49 291 X_col2[2] = (a22*a11 - a21*a12) / tmp_det;
Chris@49 292 }
Chris@49 293 else
Chris@49 294 {
Chris@49 295 det_ok = false;
Chris@49 296 }
Chris@49 297 };
Chris@49 298 break;
Chris@49 299
Chris@49 300 case 4:
Chris@49 301 {
Chris@49 302 const eT tmp_det = det(X);
Chris@49 303
Chris@49 304 if(tmp_det != eT(0))
Chris@49 305 {
Chris@49 306 const Mat<eT> A(X);
Chris@49 307
Chris@49 308 const eT* Am = A.memptr();
Chris@49 309 eT* Xm = X.memptr();
Chris@49 310
Chris@49 311 Xm[pos<0,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] - Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] + Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
Chris@49 312 Xm[pos<1,0>::n4] = ( Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
Chris@49 313 Xm[pos<2,0>::n4] = ( Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
Chris@49 314 Xm[pos<3,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
Chris@49 315
Chris@49 316 Xm[pos<0,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] + Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] - Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
Chris@49 317 Xm[pos<1,1>::n4] = ( Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
Chris@49 318 Xm[pos<2,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
Chris@49 319 Xm[pos<3,1>::n4] = ( Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
Chris@49 320
Chris@49 321 Xm[pos<0,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] + Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
Chris@49 322 Xm[pos<1,2>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
Chris@49 323 Xm[pos<2,2>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
Chris@49 324 Xm[pos<3,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
Chris@49 325
Chris@49 326 Xm[pos<0,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] + Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
Chris@49 327 Xm[pos<1,3>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
Chris@49 328 Xm[pos<2,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
Chris@49 329 Xm[pos<3,3>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] ) / tmp_det;
Chris@49 330 }
Chris@49 331 else
Chris@49 332 {
Chris@49 333 det_ok = false;
Chris@49 334 }
Chris@49 335 };
Chris@49 336 break;
Chris@49 337
Chris@49 338 default:
Chris@49 339 ;
Chris@49 340 }
Chris@49 341
Chris@49 342 return det_ok;
Chris@49 343 }
Chris@49 344
Chris@49 345
Chris@49 346
Chris@49 347 template<typename eT>
Chris@49 348 inline
Chris@49 349 bool
Chris@49 350 auxlib::inv_inplace_lapack(Mat<eT>& out)
Chris@49 351 {
Chris@49 352 arma_extra_debug_sigprint();
Chris@49 353
Chris@49 354 if(out.is_empty())
Chris@49 355 {
Chris@49 356 return true;
Chris@49 357 }
Chris@49 358
Chris@49 359 #if defined(ARMA_USE_ATLAS)
Chris@49 360 {
Chris@49 361 podarray<int> ipiv(out.n_rows);
Chris@49 362
Chris@49 363 int info = atlas::clapack_getrf(atlas::CblasColMajor, out.n_rows, out.n_cols, out.memptr(), out.n_rows, ipiv.memptr());
Chris@49 364
Chris@49 365 if(info == 0)
Chris@49 366 {
Chris@49 367 info = atlas::clapack_getri(atlas::CblasColMajor, out.n_rows, out.memptr(), out.n_rows, ipiv.memptr());
Chris@49 368 }
Chris@49 369
Chris@49 370 return (info == 0);
Chris@49 371 }
Chris@49 372 #elif defined(ARMA_USE_LAPACK)
Chris@49 373 {
Chris@49 374 blas_int n_rows = out.n_rows;
Chris@49 375 blas_int n_cols = out.n_cols;
Chris@49 376 blas_int lwork = 0;
Chris@49 377 blas_int lwork_min = (std::max)(blas_int(1), n_rows);
Chris@49 378 blas_int info = 0;
Chris@49 379
Chris@49 380 podarray<blas_int> ipiv(out.n_rows);
Chris@49 381
Chris@49 382 eT work_query[2];
Chris@49 383 blas_int lwork_query = -1;
Chris@49 384
Chris@49 385 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), &work_query[0], &lwork_query, &info);
Chris@49 386
Chris@49 387 if(info == 0)
Chris@49 388 {
Chris@49 389 const blas_int lwork_proposed = static_cast<blas_int>( access::tmp_real(work_query[0]) );
Chris@49 390
Chris@49 391 lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min;
Chris@49 392 }
Chris@49 393 else
Chris@49 394 {
Chris@49 395 return false;
Chris@49 396 }
Chris@49 397
Chris@49 398 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 399
Chris@49 400 lapack::getrf(&n_rows, &n_cols, out.memptr(), &n_rows, ipiv.memptr(), &info);
Chris@49 401
Chris@49 402 if(info == 0)
Chris@49 403 {
Chris@49 404 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &lwork, &info);
Chris@49 405 }
Chris@49 406
Chris@49 407 return (info == 0);
Chris@49 408 }
Chris@49 409 #else
Chris@49 410 {
Chris@49 411 arma_stop("inv(): use of ATLAS or LAPACK needs to be enabled");
Chris@49 412 return false;
Chris@49 413 }
Chris@49 414 #endif
Chris@49 415 }
Chris@49 416
Chris@49 417
Chris@49 418
Chris@49 419 template<typename eT, typename T1>
Chris@49 420 inline
Chris@49 421 bool
Chris@49 422 auxlib::inv_tr(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
Chris@49 423 {
Chris@49 424 arma_extra_debug_sigprint();
Chris@49 425
Chris@49 426 out = X.get_ref();
Chris@49 427
Chris@49 428 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
Chris@49 429
Chris@49 430 if(out.is_empty())
Chris@49 431 {
Chris@49 432 return true;
Chris@49 433 }
Chris@49 434
Chris@49 435 bool status;
Chris@49 436
Chris@49 437 #if defined(ARMA_USE_LAPACK)
Chris@49 438 {
Chris@49 439 char uplo = (layout == 0) ? 'U' : 'L';
Chris@49 440 char diag = 'N';
Chris@49 441 blas_int n = blas_int(out.n_rows);
Chris@49 442 blas_int info = 0;
Chris@49 443
Chris@49 444 lapack::trtri(&uplo, &diag, &n, out.memptr(), &n, &info);
Chris@49 445
Chris@49 446 status = (info == 0);
Chris@49 447 }
Chris@49 448 #else
Chris@49 449 {
Chris@49 450 arma_ignore(layout);
Chris@49 451 arma_stop("inv(): use of LAPACK needs to be enabled");
Chris@49 452 status = false;
Chris@49 453 }
Chris@49 454 #endif
Chris@49 455
Chris@49 456
Chris@49 457 if(status == true)
Chris@49 458 {
Chris@49 459 if(layout == 0)
Chris@49 460 {
Chris@49 461 // upper triangular
Chris@49 462 out = trimatu(out);
Chris@49 463 }
Chris@49 464 else
Chris@49 465 {
Chris@49 466 // lower triangular
Chris@49 467 out = trimatl(out);
Chris@49 468 }
Chris@49 469 }
Chris@49 470
Chris@49 471 return status;
Chris@49 472 }
Chris@49 473
Chris@49 474
Chris@49 475
Chris@49 476 template<typename eT, typename T1>
Chris@49 477 inline
Chris@49 478 bool
Chris@49 479 auxlib::inv_sym(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
Chris@49 480 {
Chris@49 481 arma_extra_debug_sigprint();
Chris@49 482
Chris@49 483 out = X.get_ref();
Chris@49 484
Chris@49 485 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
Chris@49 486
Chris@49 487 if(out.is_empty())
Chris@49 488 {
Chris@49 489 return true;
Chris@49 490 }
Chris@49 491
Chris@49 492 bool status;
Chris@49 493
Chris@49 494 #if defined(ARMA_USE_LAPACK)
Chris@49 495 {
Chris@49 496 char uplo = (layout == 0) ? 'U' : 'L';
Chris@49 497 blas_int n = blas_int(out.n_rows);
Chris@49 498 blas_int lwork = 3 * (n*n); // TODO: use lwork = -1 to determine optimal size
Chris@49 499 blas_int info = 0;
Chris@49 500
Chris@49 501 podarray<blas_int> ipiv;
Chris@49 502 ipiv.set_size(out.n_rows);
Chris@49 503
Chris@49 504 podarray<eT> work;
Chris@49 505 work.set_size( uword(lwork) );
Chris@49 506
Chris@49 507 lapack::sytrf(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &lwork, &info);
Chris@49 508
Chris@49 509 status = (info == 0);
Chris@49 510
Chris@49 511 if(status == true)
Chris@49 512 {
Chris@49 513 lapack::sytri(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &info);
Chris@49 514
Chris@49 515 out = (layout == 0) ? symmatu(out) : symmatl(out);
Chris@49 516
Chris@49 517 status = (info == 0);
Chris@49 518 }
Chris@49 519 }
Chris@49 520 #else
Chris@49 521 {
Chris@49 522 arma_ignore(layout);
Chris@49 523 arma_stop("inv(): use of LAPACK needs to be enabled");
Chris@49 524 status = false;
Chris@49 525 }
Chris@49 526 #endif
Chris@49 527
Chris@49 528 return status;
Chris@49 529 }
Chris@49 530
Chris@49 531
Chris@49 532
Chris@49 533 template<typename eT, typename T1>
Chris@49 534 inline
Chris@49 535 bool
Chris@49 536 auxlib::inv_sympd(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
Chris@49 537 {
Chris@49 538 arma_extra_debug_sigprint();
Chris@49 539
Chris@49 540 out = X.get_ref();
Chris@49 541
Chris@49 542 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
Chris@49 543
Chris@49 544 if(out.is_empty())
Chris@49 545 {
Chris@49 546 return true;
Chris@49 547 }
Chris@49 548
Chris@49 549 bool status;
Chris@49 550
Chris@49 551 #if defined(ARMA_USE_LAPACK)
Chris@49 552 {
Chris@49 553 char uplo = (layout == 0) ? 'U' : 'L';
Chris@49 554 blas_int n = blas_int(out.n_rows);
Chris@49 555 blas_int info = 0;
Chris@49 556
Chris@49 557 lapack::potrf(&uplo, &n, out.memptr(), &n, &info);
Chris@49 558
Chris@49 559 status = (info == 0);
Chris@49 560
Chris@49 561 if(status == true)
Chris@49 562 {
Chris@49 563 lapack::potri(&uplo, &n, out.memptr(), &n, &info);
Chris@49 564
Chris@49 565 out = (layout == 0) ? symmatu(out) : symmatl(out);
Chris@49 566
Chris@49 567 status = (info == 0);
Chris@49 568 }
Chris@49 569 }
Chris@49 570 #else
Chris@49 571 {
Chris@49 572 arma_ignore(layout);
Chris@49 573 arma_stop("inv(): use of LAPACK needs to be enabled");
Chris@49 574 status = false;
Chris@49 575 }
Chris@49 576 #endif
Chris@49 577
Chris@49 578 return status;
Chris@49 579 }
Chris@49 580
Chris@49 581
Chris@49 582
Chris@49 583 template<typename eT, typename T1>
Chris@49 584 inline
Chris@49 585 eT
Chris@49 586 auxlib::det(const Base<eT,T1>& X, const bool slow)
Chris@49 587 {
Chris@49 588 const unwrap<T1> tmp(X.get_ref());
Chris@49 589 const Mat<eT>& A = tmp.M;
Chris@49 590
Chris@49 591 arma_debug_check( (A.is_square() == false), "det(): matrix is not square" );
Chris@49 592
Chris@49 593 const bool make_copy = (is_Mat<T1>::value == true) ? true : false;
Chris@49 594
Chris@49 595 if(slow == false)
Chris@49 596 {
Chris@49 597 const uword N = A.n_rows;
Chris@49 598
Chris@49 599 switch(N)
Chris@49 600 {
Chris@49 601 case 0:
Chris@49 602 case 1:
Chris@49 603 case 2:
Chris@49 604 return auxlib::det_tinymat(A, N);
Chris@49 605 break;
Chris@49 606
Chris@49 607 case 3:
Chris@49 608 case 4:
Chris@49 609 {
Chris@49 610 const eT tmp_det = auxlib::det_tinymat(A, N);
Chris@49 611 return (tmp_det != eT(0)) ? tmp_det : auxlib::det_lapack(A, make_copy);
Chris@49 612 }
Chris@49 613 break;
Chris@49 614
Chris@49 615 default:
Chris@49 616 return auxlib::det_lapack(A, make_copy);
Chris@49 617 }
Chris@49 618 }
Chris@49 619
Chris@49 620 return auxlib::det_lapack(A, make_copy);
Chris@49 621 }
Chris@49 622
Chris@49 623
Chris@49 624
Chris@49 625 template<typename eT>
Chris@49 626 inline
Chris@49 627 eT
Chris@49 628 auxlib::det_tinymat(const Mat<eT>& X, const uword N)
Chris@49 629 {
Chris@49 630 arma_extra_debug_sigprint();
Chris@49 631
Chris@49 632 switch(N)
Chris@49 633 {
Chris@49 634 case 0:
Chris@49 635 return eT(1);
Chris@49 636 break;
Chris@49 637
Chris@49 638 case 1:
Chris@49 639 return X[0];
Chris@49 640 break;
Chris@49 641
Chris@49 642 case 2:
Chris@49 643 {
Chris@49 644 const eT* Xm = X.memptr();
Chris@49 645
Chris@49 646 return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] );
Chris@49 647 }
Chris@49 648 break;
Chris@49 649
Chris@49 650 case 3:
Chris@49 651 {
Chris@49 652 // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2);
Chris@49 653 // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0);
Chris@49 654 // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1);
Chris@49 655 // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2);
Chris@49 656 // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0);
Chris@49 657 // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1);
Chris@49 658 // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6);
Chris@49 659
Chris@49 660 const eT* a_col0 = X.colptr(0);
Chris@49 661 const eT a11 = a_col0[0];
Chris@49 662 const eT a21 = a_col0[1];
Chris@49 663 const eT a31 = a_col0[2];
Chris@49 664
Chris@49 665 const eT* a_col1 = X.colptr(1);
Chris@49 666 const eT a12 = a_col1[0];
Chris@49 667 const eT a22 = a_col1[1];
Chris@49 668 const eT a32 = a_col1[2];
Chris@49 669
Chris@49 670 const eT* a_col2 = X.colptr(2);
Chris@49 671 const eT a13 = a_col2[0];
Chris@49 672 const eT a23 = a_col2[1];
Chris@49 673 const eT a33 = a_col2[2];
Chris@49 674
Chris@49 675 return ( a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13) );
Chris@49 676 }
Chris@49 677 break;
Chris@49 678
Chris@49 679 case 4:
Chris@49 680 {
Chris@49 681 const eT* Xm = X.memptr();
Chris@49 682
Chris@49 683 const eT val = \
Chris@49 684 Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \
Chris@49 685 - Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \
Chris@49 686 - Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \
Chris@49 687 + Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \
Chris@49 688 + Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \
Chris@49 689 - Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \
Chris@49 690 - Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \
Chris@49 691 + Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \
Chris@49 692 + Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \
Chris@49 693 - Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \
Chris@49 694 - Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \
Chris@49 695 + Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \
Chris@49 696 + Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \
Chris@49 697 - Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \
Chris@49 698 - Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \
Chris@49 699 + Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \
Chris@49 700 + Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \
Chris@49 701 - Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \
Chris@49 702 - Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \
Chris@49 703 + Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \
Chris@49 704 + Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \
Chris@49 705 - Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \
Chris@49 706 - Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \
Chris@49 707 + Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \
Chris@49 708 ;
Chris@49 709
Chris@49 710 return val;
Chris@49 711 }
Chris@49 712 break;
Chris@49 713
Chris@49 714 default:
Chris@49 715 return eT(0);
Chris@49 716 ;
Chris@49 717 }
Chris@49 718 }
Chris@49 719
Chris@49 720
Chris@49 721
Chris@49 722 //! immediate determinant of a matrix using ATLAS or LAPACK
Chris@49 723 template<typename eT>
Chris@49 724 inline
Chris@49 725 eT
Chris@49 726 auxlib::det_lapack(const Mat<eT>& X, const bool make_copy)
Chris@49 727 {
Chris@49 728 arma_extra_debug_sigprint();
Chris@49 729
Chris@49 730 Mat<eT> X_copy;
Chris@49 731
Chris@49 732 if(make_copy == true)
Chris@49 733 {
Chris@49 734 X_copy = X;
Chris@49 735 }
Chris@49 736
Chris@49 737 Mat<eT>& tmp = (make_copy == true) ? X_copy : const_cast< Mat<eT>& >(X);
Chris@49 738
Chris@49 739 if(tmp.is_empty())
Chris@49 740 {
Chris@49 741 return eT(1);
Chris@49 742 }
Chris@49 743
Chris@49 744
Chris@49 745 #if defined(ARMA_USE_ATLAS)
Chris@49 746 {
Chris@49 747 podarray<int> ipiv(tmp.n_rows);
Chris@49 748
Chris@49 749 //const int info =
Chris@49 750 atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr());
Chris@49 751
Chris@49 752 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
Chris@49 753 eT val = tmp.at(0,0);
Chris@49 754 for(uword i=1; i < tmp.n_rows; ++i)
Chris@49 755 {
Chris@49 756 val *= tmp.at(i,i);
Chris@49 757 }
Chris@49 758
Chris@49 759 int sign = +1;
Chris@49 760 for(uword i=0; i < tmp.n_rows; ++i)
Chris@49 761 {
Chris@49 762 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0
Chris@49 763 {
Chris@49 764 sign *= -1;
Chris@49 765 }
Chris@49 766 }
Chris@49 767
Chris@49 768 return ( (sign < 0) ? -val : val );
Chris@49 769 }
Chris@49 770 #elif defined(ARMA_USE_LAPACK)
Chris@49 771 {
Chris@49 772 podarray<blas_int> ipiv(tmp.n_rows);
Chris@49 773
Chris@49 774 blas_int info = 0;
Chris@49 775 blas_int n_rows = blas_int(tmp.n_rows);
Chris@49 776 blas_int n_cols = blas_int(tmp.n_cols);
Chris@49 777
Chris@49 778 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info);
Chris@49 779
Chris@49 780 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
Chris@49 781 eT val = tmp.at(0,0);
Chris@49 782 for(uword i=1; i < tmp.n_rows; ++i)
Chris@49 783 {
Chris@49 784 val *= tmp.at(i,i);
Chris@49 785 }
Chris@49 786
Chris@49 787 blas_int sign = +1;
Chris@49 788 for(uword i=0; i < tmp.n_rows; ++i)
Chris@49 789 {
Chris@49 790 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1
Chris@49 791 {
Chris@49 792 sign *= -1;
Chris@49 793 }
Chris@49 794 }
Chris@49 795
Chris@49 796 return ( (sign < 0) ? -val : val );
Chris@49 797 }
Chris@49 798 #else
Chris@49 799 {
Chris@49 800 arma_ignore(X);
Chris@49 801 arma_ignore(make_copy);
Chris@49 802 arma_ignore(tmp);
Chris@49 803 arma_stop("det(): use of ATLAS or LAPACK needs to be enabled");
Chris@49 804 return eT(0);
Chris@49 805 }
Chris@49 806 #endif
Chris@49 807 }
Chris@49 808
Chris@49 809
Chris@49 810
Chris@49 811 //! immediate log determinant of a matrix using ATLAS or LAPACK
Chris@49 812 template<typename eT, typename T1>
Chris@49 813 inline
Chris@49 814 bool
Chris@49 815 auxlib::log_det(eT& out_val, typename get_pod_type<eT>::result& out_sign, const Base<eT,T1>& X)
Chris@49 816 {
Chris@49 817 arma_extra_debug_sigprint();
Chris@49 818
Chris@49 819 typedef typename get_pod_type<eT>::result T;
Chris@49 820
Chris@49 821 #if defined(ARMA_USE_ATLAS)
Chris@49 822 {
Chris@49 823 Mat<eT> tmp(X.get_ref());
Chris@49 824 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" );
Chris@49 825
Chris@49 826 if(tmp.is_empty())
Chris@49 827 {
Chris@49 828 out_val = eT(0);
Chris@49 829 out_sign = T(1);
Chris@49 830 return true;
Chris@49 831 }
Chris@49 832
Chris@49 833 podarray<int> ipiv(tmp.n_rows);
Chris@49 834
Chris@49 835 const int info = atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr());
Chris@49 836
Chris@49 837 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
Chris@49 838
Chris@49 839 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1;
Chris@49 840 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) );
Chris@49 841
Chris@49 842 for(uword i=1; i < tmp.n_rows; ++i)
Chris@49 843 {
Chris@49 844 const eT x = tmp.at(i,i);
Chris@49 845
Chris@49 846 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1;
Chris@49 847 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
Chris@49 848 }
Chris@49 849
Chris@49 850 for(uword i=0; i < tmp.n_rows; ++i)
Chris@49 851 {
Chris@49 852 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0
Chris@49 853 {
Chris@49 854 sign *= -1;
Chris@49 855 }
Chris@49 856 }
Chris@49 857
Chris@49 858 out_val = val;
Chris@49 859 out_sign = T(sign);
Chris@49 860
Chris@49 861 return (info == 0);
Chris@49 862 }
Chris@49 863 #elif defined(ARMA_USE_LAPACK)
Chris@49 864 {
Chris@49 865 Mat<eT> tmp(X.get_ref());
Chris@49 866 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" );
Chris@49 867
Chris@49 868 if(tmp.is_empty())
Chris@49 869 {
Chris@49 870 out_val = eT(0);
Chris@49 871 out_sign = T(1);
Chris@49 872 return true;
Chris@49 873 }
Chris@49 874
Chris@49 875 podarray<blas_int> ipiv(tmp.n_rows);
Chris@49 876
Chris@49 877 blas_int info = 0;
Chris@49 878 blas_int n_rows = blas_int(tmp.n_rows);
Chris@49 879 blas_int n_cols = blas_int(tmp.n_cols);
Chris@49 880
Chris@49 881 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info);
Chris@49 882
Chris@49 883 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
Chris@49 884
Chris@49 885 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1;
Chris@49 886 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) );
Chris@49 887
Chris@49 888 for(uword i=1; i < tmp.n_rows; ++i)
Chris@49 889 {
Chris@49 890 const eT x = tmp.at(i,i);
Chris@49 891
Chris@49 892 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1;
Chris@49 893 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
Chris@49 894 }
Chris@49 895
Chris@49 896 for(uword i=0; i < tmp.n_rows; ++i)
Chris@49 897 {
Chris@49 898 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1
Chris@49 899 {
Chris@49 900 sign *= -1;
Chris@49 901 }
Chris@49 902 }
Chris@49 903
Chris@49 904 out_val = val;
Chris@49 905 out_sign = T(sign);
Chris@49 906
Chris@49 907 return (info == 0);
Chris@49 908 }
Chris@49 909 #else
Chris@49 910 {
Chris@49 911 arma_ignore(X);
Chris@49 912
Chris@49 913 out_val = eT(0);
Chris@49 914 out_sign = T(0);
Chris@49 915
Chris@49 916 arma_stop("log_det(): use of ATLAS or LAPACK needs to be enabled");
Chris@49 917
Chris@49 918 return false;
Chris@49 919 }
Chris@49 920 #endif
Chris@49 921 }
Chris@49 922
Chris@49 923
Chris@49 924
Chris@49 925 //! immediate LU decomposition of a matrix using ATLAS or LAPACK
Chris@49 926 template<typename eT, typename T1>
Chris@49 927 inline
Chris@49 928 bool
Chris@49 929 auxlib::lu(Mat<eT>& L, Mat<eT>& U, podarray<blas_int>& ipiv, const Base<eT,T1>& X)
Chris@49 930 {
Chris@49 931 arma_extra_debug_sigprint();
Chris@49 932
Chris@49 933 U = X.get_ref();
Chris@49 934
Chris@49 935 const uword U_n_rows = U.n_rows;
Chris@49 936 const uword U_n_cols = U.n_cols;
Chris@49 937
Chris@49 938 if(U.is_empty())
Chris@49 939 {
Chris@49 940 L.set_size(U_n_rows, 0);
Chris@49 941 U.set_size(0, U_n_cols);
Chris@49 942 ipiv.reset();
Chris@49 943 return true;
Chris@49 944 }
Chris@49 945
Chris@49 946 #if defined(ARMA_USE_ATLAS) || defined(ARMA_USE_LAPACK)
Chris@49 947 {
Chris@49 948 bool status;
Chris@49 949
Chris@49 950 #if defined(ARMA_USE_ATLAS)
Chris@49 951 {
Chris@49 952 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) );
Chris@49 953
Chris@49 954 int info = atlas::clapack_getrf(atlas::CblasColMajor, U_n_rows, U_n_cols, U.memptr(), U_n_rows, ipiv.memptr());
Chris@49 955
Chris@49 956 status = (info == 0);
Chris@49 957 }
Chris@49 958 #elif defined(ARMA_USE_LAPACK)
Chris@49 959 {
Chris@49 960 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) );
Chris@49 961
Chris@49 962 blas_int info = 0;
Chris@49 963
Chris@49 964 blas_int n_rows = U_n_rows;
Chris@49 965 blas_int n_cols = U_n_cols;
Chris@49 966
Chris@49 967
Chris@49 968 lapack::getrf(&n_rows, &n_cols, U.memptr(), &n_rows, ipiv.memptr(), &info);
Chris@49 969
Chris@49 970 // take into account that Fortran counts from 1
Chris@49 971 arrayops::inplace_minus(ipiv.memptr(), blas_int(1), ipiv.n_elem);
Chris@49 972
Chris@49 973 status = (info == 0);
Chris@49 974 }
Chris@49 975 #endif
Chris@49 976
Chris@49 977 L.copy_size(U);
Chris@49 978
Chris@49 979 for(uword col=0; col < U_n_cols; ++col)
Chris@49 980 {
Chris@49 981 for(uword row=0; (row < col) && (row < U_n_rows); ++row)
Chris@49 982 {
Chris@49 983 L.at(row,col) = eT(0);
Chris@49 984 }
Chris@49 985
Chris@49 986 if( L.in_range(col,col) == true )
Chris@49 987 {
Chris@49 988 L.at(col,col) = eT(1);
Chris@49 989 }
Chris@49 990
Chris@49 991 for(uword row = (col+1); row < U_n_rows; ++row)
Chris@49 992 {
Chris@49 993 L.at(row,col) = U.at(row,col);
Chris@49 994 U.at(row,col) = eT(0);
Chris@49 995 }
Chris@49 996 }
Chris@49 997
Chris@49 998 return status;
Chris@49 999 }
Chris@49 1000 #else
Chris@49 1001 {
Chris@49 1002 arma_stop("lu(): use of ATLAS or LAPACK needs to be enabled");
Chris@49 1003
Chris@49 1004 return false;
Chris@49 1005 }
Chris@49 1006 #endif
Chris@49 1007 }
Chris@49 1008
Chris@49 1009
Chris@49 1010
Chris@49 1011 template<typename eT, typename T1>
Chris@49 1012 inline
Chris@49 1013 bool
Chris@49 1014 auxlib::lu(Mat<eT>& L, Mat<eT>& U, Mat<eT>& P, const Base<eT,T1>& X)
Chris@49 1015 {
Chris@49 1016 arma_extra_debug_sigprint();
Chris@49 1017
Chris@49 1018 podarray<blas_int> ipiv1;
Chris@49 1019 const bool status = auxlib::lu(L, U, ipiv1, X);
Chris@49 1020
Chris@49 1021 if(status == true)
Chris@49 1022 {
Chris@49 1023 if(U.is_empty())
Chris@49 1024 {
Chris@49 1025 // L and U have been already set to the correct empty matrices
Chris@49 1026 P.eye(L.n_rows, L.n_rows);
Chris@49 1027 return true;
Chris@49 1028 }
Chris@49 1029
Chris@49 1030 const uword n = ipiv1.n_elem;
Chris@49 1031 const uword P_rows = U.n_rows;
Chris@49 1032
Chris@49 1033 podarray<blas_int> ipiv2(P_rows);
Chris@49 1034
Chris@49 1035 const blas_int* ipiv1_mem = ipiv1.memptr();
Chris@49 1036 blas_int* ipiv2_mem = ipiv2.memptr();
Chris@49 1037
Chris@49 1038 for(uword i=0; i<P_rows; ++i)
Chris@49 1039 {
Chris@49 1040 ipiv2_mem[i] = blas_int(i);
Chris@49 1041 }
Chris@49 1042
Chris@49 1043 for(uword i=0; i<n; ++i)
Chris@49 1044 {
Chris@49 1045 const uword k = static_cast<uword>(ipiv1_mem[i]);
Chris@49 1046
Chris@49 1047 if( ipiv2_mem[i] != ipiv2_mem[k] )
Chris@49 1048 {
Chris@49 1049 std::swap( ipiv2_mem[i], ipiv2_mem[k] );
Chris@49 1050 }
Chris@49 1051 }
Chris@49 1052
Chris@49 1053 P.zeros(P_rows, P_rows);
Chris@49 1054
Chris@49 1055 for(uword row=0; row<P_rows; ++row)
Chris@49 1056 {
Chris@49 1057 P.at(row, static_cast<uword>(ipiv2_mem[row])) = eT(1);
Chris@49 1058 }
Chris@49 1059
Chris@49 1060 if(L.n_cols > U.n_rows)
Chris@49 1061 {
Chris@49 1062 L.shed_cols(U.n_rows, L.n_cols-1);
Chris@49 1063 }
Chris@49 1064
Chris@49 1065 if(U.n_rows > L.n_cols)
Chris@49 1066 {
Chris@49 1067 U.shed_rows(L.n_cols, U.n_rows-1);
Chris@49 1068 }
Chris@49 1069 }
Chris@49 1070
Chris@49 1071 return status;
Chris@49 1072 }
Chris@49 1073
Chris@49 1074
Chris@49 1075
Chris@49 1076 template<typename eT, typename T1>
Chris@49 1077 inline
Chris@49 1078 bool
Chris@49 1079 auxlib::lu(Mat<eT>& L, Mat<eT>& U, const Base<eT,T1>& X)
Chris@49 1080 {
Chris@49 1081 arma_extra_debug_sigprint();
Chris@49 1082
Chris@49 1083 podarray<blas_int> ipiv1;
Chris@49 1084 const bool status = auxlib::lu(L, U, ipiv1, X);
Chris@49 1085
Chris@49 1086 if(status == true)
Chris@49 1087 {
Chris@49 1088 if(U.is_empty())
Chris@49 1089 {
Chris@49 1090 // L and U have been already set to the correct empty matrices
Chris@49 1091 return true;
Chris@49 1092 }
Chris@49 1093
Chris@49 1094 const uword n = ipiv1.n_elem;
Chris@49 1095 const uword P_rows = U.n_rows;
Chris@49 1096
Chris@49 1097 podarray<blas_int> ipiv2(P_rows);
Chris@49 1098
Chris@49 1099 const blas_int* ipiv1_mem = ipiv1.memptr();
Chris@49 1100 blas_int* ipiv2_mem = ipiv2.memptr();
Chris@49 1101
Chris@49 1102 for(uword i=0; i<P_rows; ++i)
Chris@49 1103 {
Chris@49 1104 ipiv2_mem[i] = blas_int(i);
Chris@49 1105 }
Chris@49 1106
Chris@49 1107 for(uword i=0; i<n; ++i)
Chris@49 1108 {
Chris@49 1109 const uword k = static_cast<uword>(ipiv1_mem[i]);
Chris@49 1110
Chris@49 1111 if( ipiv2_mem[i] != ipiv2_mem[k] )
Chris@49 1112 {
Chris@49 1113 std::swap( ipiv2_mem[i], ipiv2_mem[k] );
Chris@49 1114 L.swap_rows( static_cast<uword>(ipiv2_mem[i]), static_cast<uword>(ipiv2_mem[k]) );
Chris@49 1115 }
Chris@49 1116 }
Chris@49 1117
Chris@49 1118 if(L.n_cols > U.n_rows)
Chris@49 1119 {
Chris@49 1120 L.shed_cols(U.n_rows, L.n_cols-1);
Chris@49 1121 }
Chris@49 1122
Chris@49 1123 if(U.n_rows > L.n_cols)
Chris@49 1124 {
Chris@49 1125 U.shed_rows(L.n_cols, U.n_rows-1);
Chris@49 1126 }
Chris@49 1127 }
Chris@49 1128
Chris@49 1129 return status;
Chris@49 1130 }
Chris@49 1131
Chris@49 1132
Chris@49 1133
Chris@49 1134 //! immediate eigenvalues of a symmetric real matrix using LAPACK
Chris@49 1135 template<typename eT, typename T1>
Chris@49 1136 inline
Chris@49 1137 bool
Chris@49 1138 auxlib::eig_sym(Col<eT>& eigval, const Base<eT,T1>& X)
Chris@49 1139 {
Chris@49 1140 arma_extra_debug_sigprint();
Chris@49 1141
Chris@49 1142 #if defined(ARMA_USE_LAPACK)
Chris@49 1143 {
Chris@49 1144 Mat<eT> A(X.get_ref());
Chris@49 1145
Chris@49 1146 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not square");
Chris@49 1147
Chris@49 1148 if(A.is_empty())
Chris@49 1149 {
Chris@49 1150 eigval.reset();
Chris@49 1151 return true;
Chris@49 1152 }
Chris@49 1153
Chris@49 1154 eigval.set_size(A.n_rows);
Chris@49 1155
Chris@49 1156 char jobz = 'N';
Chris@49 1157 char uplo = 'U';
Chris@49 1158
Chris@49 1159 blas_int N = blas_int(A.n_rows);
Chris@49 1160 blas_int lwork = 3 * ( (std::max)(blas_int(1), 3*N-1) );
Chris@49 1161 blas_int info = 0;
Chris@49 1162
Chris@49 1163 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 1164
Chris@49 1165 lapack::syev(&jobz, &uplo, &N, A.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, &info);
Chris@49 1166
Chris@49 1167 return (info == 0);
Chris@49 1168 }
Chris@49 1169 #else
Chris@49 1170 {
Chris@49 1171 arma_ignore(eigval);
Chris@49 1172 arma_ignore(X);
Chris@49 1173 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
Chris@49 1174 return false;
Chris@49 1175 }
Chris@49 1176 #endif
Chris@49 1177 }
Chris@49 1178
Chris@49 1179
Chris@49 1180
Chris@49 1181 //! immediate eigenvalues of a hermitian complex matrix using LAPACK
Chris@49 1182 template<typename T, typename T1>
Chris@49 1183 inline
Chris@49 1184 bool
Chris@49 1185 auxlib::eig_sym(Col<T>& eigval, const Base<std::complex<T>,T1>& X)
Chris@49 1186 {
Chris@49 1187 arma_extra_debug_sigprint();
Chris@49 1188
Chris@49 1189 #if defined(ARMA_USE_LAPACK)
Chris@49 1190 {
Chris@49 1191 typedef typename std::complex<T> eT;
Chris@49 1192
Chris@49 1193 Mat<eT> A(X.get_ref());
Chris@49 1194
Chris@49 1195 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not square");
Chris@49 1196
Chris@49 1197 if(A.is_empty())
Chris@49 1198 {
Chris@49 1199 eigval.reset();
Chris@49 1200 return true;
Chris@49 1201 }
Chris@49 1202
Chris@49 1203 eigval.set_size(A.n_rows);
Chris@49 1204
Chris@49 1205 char jobz = 'N';
Chris@49 1206 char uplo = 'U';
Chris@49 1207
Chris@49 1208 blas_int N = blas_int(A.n_rows);
Chris@49 1209 blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*N-1) );
Chris@49 1210 blas_int info = 0;
Chris@49 1211
Chris@49 1212 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 1213 podarray<T> rwork( static_cast<uword>( (std::max)(blas_int(1), 3*N-2) ) );
Chris@49 1214
Chris@49 1215 arma_extra_debug_print("lapack::heev()");
Chris@49 1216 lapack::heev(&jobz, &uplo, &N, A.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info);
Chris@49 1217
Chris@49 1218 return (info == 0);
Chris@49 1219 }
Chris@49 1220 #else
Chris@49 1221 {
Chris@49 1222 arma_ignore(eigval);
Chris@49 1223 arma_ignore(X);
Chris@49 1224 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
Chris@49 1225 return false;
Chris@49 1226 }
Chris@49 1227 #endif
Chris@49 1228 }
Chris@49 1229
Chris@49 1230
Chris@49 1231
Chris@49 1232 //! immediate eigenvalues and eigenvectors of a symmetric real matrix using LAPACK
Chris@49 1233 template<typename eT, typename T1>
Chris@49 1234 inline
Chris@49 1235 bool
Chris@49 1236 auxlib::eig_sym(Col<eT>& eigval, Mat<eT>& eigvec, const Base<eT,T1>& X)
Chris@49 1237 {
Chris@49 1238 arma_extra_debug_sigprint();
Chris@49 1239
Chris@49 1240 #if defined(ARMA_USE_LAPACK)
Chris@49 1241 {
Chris@49 1242 eigvec = X.get_ref();
Chris@49 1243
Chris@49 1244 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" );
Chris@49 1245
Chris@49 1246 if(eigvec.is_empty())
Chris@49 1247 {
Chris@49 1248 eigval.reset();
Chris@49 1249 eigvec.reset();
Chris@49 1250 return true;
Chris@49 1251 }
Chris@49 1252
Chris@49 1253 eigval.set_size(eigvec.n_rows);
Chris@49 1254
Chris@49 1255 char jobz = 'V';
Chris@49 1256 char uplo = 'U';
Chris@49 1257
Chris@49 1258 blas_int N = blas_int(eigvec.n_rows);
Chris@49 1259 blas_int lwork = 3 * ( (std::max)(blas_int(1), 3*N-1) );
Chris@49 1260 blas_int info = 0;
Chris@49 1261
Chris@49 1262 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 1263
Chris@49 1264 lapack::syev(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, &info);
Chris@49 1265
Chris@49 1266 return (info == 0);
Chris@49 1267 }
Chris@49 1268 #else
Chris@49 1269 {
Chris@49 1270 arma_ignore(eigval);
Chris@49 1271 arma_ignore(eigvec);
Chris@49 1272 arma_ignore(X);
Chris@49 1273 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
Chris@49 1274 return false;
Chris@49 1275 }
Chris@49 1276 #endif
Chris@49 1277 }
Chris@49 1278
Chris@49 1279
Chris@49 1280
Chris@49 1281 //! immediate eigenvalues and eigenvectors of a hermitian complex matrix using LAPACK
Chris@49 1282 template<typename T, typename T1>
Chris@49 1283 inline
Chris@49 1284 bool
Chris@49 1285 auxlib::eig_sym(Col<T>& eigval, Mat< std::complex<T> >& eigvec, const Base<std::complex<T>,T1>& X)
Chris@49 1286 {
Chris@49 1287 arma_extra_debug_sigprint();
Chris@49 1288
Chris@49 1289 #if defined(ARMA_USE_LAPACK)
Chris@49 1290 {
Chris@49 1291 typedef typename std::complex<T> eT;
Chris@49 1292
Chris@49 1293 eigvec = X.get_ref();
Chris@49 1294
Chris@49 1295 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" );
Chris@49 1296
Chris@49 1297 if(eigvec.is_empty())
Chris@49 1298 {
Chris@49 1299 eigval.reset();
Chris@49 1300 eigvec.reset();
Chris@49 1301 return true;
Chris@49 1302 }
Chris@49 1303
Chris@49 1304 eigval.set_size(eigvec.n_rows);
Chris@49 1305
Chris@49 1306 char jobz = 'V';
Chris@49 1307 char uplo = 'U';
Chris@49 1308
Chris@49 1309 blas_int N = blas_int(eigvec.n_rows);
Chris@49 1310 blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*N-1) );
Chris@49 1311 blas_int info = 0;
Chris@49 1312
Chris@49 1313 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 1314 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*N-2)) );
Chris@49 1315
Chris@49 1316 arma_extra_debug_print("lapack::heev()");
Chris@49 1317 lapack::heev(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info);
Chris@49 1318
Chris@49 1319 return (info == 0);
Chris@49 1320 }
Chris@49 1321 #else
Chris@49 1322 {
Chris@49 1323 arma_ignore(eigval);
Chris@49 1324 arma_ignore(eigvec);
Chris@49 1325 arma_ignore(X);
Chris@49 1326 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
Chris@49 1327 return false;
Chris@49 1328 }
Chris@49 1329 #endif
Chris@49 1330 }
Chris@49 1331
Chris@49 1332
Chris@49 1333
Chris@49 1334 //! immediate eigenvalues and eigenvectors of a symmetric real matrix using LAPACK (divide and conquer algorithm)
Chris@49 1335 template<typename eT, typename T1>
Chris@49 1336 inline
Chris@49 1337 bool
Chris@49 1338 auxlib::eig_sym_dc(Col<eT>& eigval, Mat<eT>& eigvec, const Base<eT,T1>& X)
Chris@49 1339 {
Chris@49 1340 arma_extra_debug_sigprint();
Chris@49 1341
Chris@49 1342 #if defined(ARMA_USE_LAPACK)
Chris@49 1343 {
Chris@49 1344 eigvec = X.get_ref();
Chris@49 1345
Chris@49 1346 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" );
Chris@49 1347
Chris@49 1348 if(eigvec.is_empty())
Chris@49 1349 {
Chris@49 1350 eigval.reset();
Chris@49 1351 eigvec.reset();
Chris@49 1352 return true;
Chris@49 1353 }
Chris@49 1354
Chris@49 1355 eigval.set_size(eigvec.n_rows);
Chris@49 1356
Chris@49 1357 char jobz = 'V';
Chris@49 1358 char uplo = 'U';
Chris@49 1359
Chris@49 1360 blas_int N = blas_int(eigvec.n_rows);
Chris@49 1361 blas_int lwork = 3 * (1 + 6*N + 2*(N*N));
Chris@49 1362 blas_int liwork = 3 * (3 + 5*N + 2);
Chris@49 1363 blas_int info = 0;
Chris@49 1364
Chris@49 1365 podarray<eT> work( static_cast<uword>( lwork) );
Chris@49 1366 podarray<blas_int> iwork( static_cast<uword>(liwork) );
Chris@49 1367
Chris@49 1368 arma_extra_debug_print("lapack::syevd()");
Chris@49 1369 lapack::syevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, iwork.memptr(), &liwork, &info);
Chris@49 1370
Chris@49 1371 return (info == 0);
Chris@49 1372 }
Chris@49 1373 #else
Chris@49 1374 {
Chris@49 1375 arma_ignore(eigval);
Chris@49 1376 arma_ignore(eigvec);
Chris@49 1377 arma_ignore(X);
Chris@49 1378 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
Chris@49 1379 return false;
Chris@49 1380 }
Chris@49 1381 #endif
Chris@49 1382 }
Chris@49 1383
Chris@49 1384
Chris@49 1385
Chris@49 1386 //! immediate eigenvalues and eigenvectors of a hermitian complex matrix using LAPACK (divide and conquer algorithm)
Chris@49 1387 template<typename T, typename T1>
Chris@49 1388 inline
Chris@49 1389 bool
Chris@49 1390 auxlib::eig_sym_dc(Col<T>& eigval, Mat< std::complex<T> >& eigvec, const Base<std::complex<T>,T1>& X)
Chris@49 1391 {
Chris@49 1392 arma_extra_debug_sigprint();
Chris@49 1393
Chris@49 1394 #if defined(ARMA_USE_LAPACK)
Chris@49 1395 {
Chris@49 1396 typedef typename std::complex<T> eT;
Chris@49 1397
Chris@49 1398 eigvec = X.get_ref();
Chris@49 1399
Chris@49 1400 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" );
Chris@49 1401
Chris@49 1402 if(eigvec.is_empty())
Chris@49 1403 {
Chris@49 1404 eigval.reset();
Chris@49 1405 eigvec.reset();
Chris@49 1406 return true;
Chris@49 1407 }
Chris@49 1408
Chris@49 1409 eigval.set_size(eigvec.n_rows);
Chris@49 1410
Chris@49 1411 char jobz = 'V';
Chris@49 1412 char uplo = 'U';
Chris@49 1413
Chris@49 1414 blas_int N = blas_int(eigvec.n_rows);
Chris@49 1415 blas_int lwork = 3 * (2*N + N*N);
Chris@49 1416 blas_int lrwork = 3 * (1 + 5*N + 2*(N*N));
Chris@49 1417 blas_int liwork = 3 * (3 + 5*N);
Chris@49 1418 blas_int info = 0;
Chris@49 1419
Chris@49 1420 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 1421 podarray<T> rwork( static_cast<uword>(lrwork) );
Chris@49 1422 podarray<blas_int> iwork( static_cast<uword>(liwork) );
Chris@49 1423
Chris@49 1424 arma_extra_debug_print("lapack::heevd()");
Chris@49 1425 lapack::heevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &lrwork, iwork.memptr(), &liwork, &info);
Chris@49 1426
Chris@49 1427 return (info == 0);
Chris@49 1428 }
Chris@49 1429 #else
Chris@49 1430 {
Chris@49 1431 arma_ignore(eigval);
Chris@49 1432 arma_ignore(eigvec);
Chris@49 1433 arma_ignore(X);
Chris@49 1434 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
Chris@49 1435 return false;
Chris@49 1436 }
Chris@49 1437 #endif
Chris@49 1438 }
Chris@49 1439
Chris@49 1440
Chris@49 1441
Chris@49 1442 //! Eigenvalues and eigenvectors of a general square real matrix using LAPACK.
Chris@49 1443 //! The argument 'side' specifies which eigenvectors should be calculated
Chris@49 1444 //! (see code for mode details).
Chris@49 1445 template<typename T, typename T1>
Chris@49 1446 inline
Chris@49 1447 bool
Chris@49 1448 auxlib::eig_gen
Chris@49 1449 (
Chris@49 1450 Col< std::complex<T> >& eigval,
Chris@49 1451 Mat<T>& l_eigvec,
Chris@49 1452 Mat<T>& r_eigvec,
Chris@49 1453 const Base<T,T1>& X,
Chris@49 1454 const char side
Chris@49 1455 )
Chris@49 1456 {
Chris@49 1457 arma_extra_debug_sigprint();
Chris@49 1458
Chris@49 1459 #if defined(ARMA_USE_LAPACK)
Chris@49 1460 {
Chris@49 1461 char jobvl;
Chris@49 1462 char jobvr;
Chris@49 1463
Chris@49 1464 switch(side)
Chris@49 1465 {
Chris@49 1466 case 'l': // left
Chris@49 1467 jobvl = 'V';
Chris@49 1468 jobvr = 'N';
Chris@49 1469 break;
Chris@49 1470
Chris@49 1471 case 'r': // right
Chris@49 1472 jobvl = 'N';
Chris@49 1473 jobvr = 'V';
Chris@49 1474 break;
Chris@49 1475
Chris@49 1476 case 'b': // both
Chris@49 1477 jobvl = 'V';
Chris@49 1478 jobvr = 'V';
Chris@49 1479 break;
Chris@49 1480
Chris@49 1481 case 'n': // neither
Chris@49 1482 jobvl = 'N';
Chris@49 1483 jobvr = 'N';
Chris@49 1484 break;
Chris@49 1485
Chris@49 1486 default:
Chris@49 1487 arma_stop("eig_gen(): parameter 'side' is invalid");
Chris@49 1488 return false;
Chris@49 1489 }
Chris@49 1490
Chris@49 1491 Mat<T> A(X.get_ref());
Chris@49 1492 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" );
Chris@49 1493
Chris@49 1494 if(A.is_empty())
Chris@49 1495 {
Chris@49 1496 eigval.reset();
Chris@49 1497 l_eigvec.reset();
Chris@49 1498 r_eigvec.reset();
Chris@49 1499 return true;
Chris@49 1500 }
Chris@49 1501
Chris@49 1502 const uword A_n_rows = A.n_rows;
Chris@49 1503
Chris@49 1504 eigval.set_size(A_n_rows);
Chris@49 1505
Chris@49 1506 l_eigvec.set_size(A_n_rows, A_n_rows);
Chris@49 1507 r_eigvec.set_size(A_n_rows, A_n_rows);
Chris@49 1508
Chris@49 1509 blas_int N = blas_int(A_n_rows);
Chris@49 1510 blas_int lwork = 3 * ( (std::max)(blas_int(1), 4*N) );
Chris@49 1511 blas_int info = 0;
Chris@49 1512
Chris@49 1513 podarray<T> work( static_cast<uword>(lwork) );
Chris@49 1514
Chris@49 1515 podarray<T> wr(A_n_rows);
Chris@49 1516 podarray<T> wi(A_n_rows);
Chris@49 1517
Chris@49 1518 arma_extra_debug_print("lapack::geev()");
Chris@49 1519 lapack::geev(&jobvl, &jobvr, &N, A.memptr(), &N, wr.memptr(), wi.memptr(), l_eigvec.memptr(), &N, r_eigvec.memptr(), &N, work.memptr(), &lwork, &info);
Chris@49 1520
Chris@49 1521 eigval.set_size(A_n_rows);
Chris@49 1522 for(uword i=0; i<A_n_rows; ++i)
Chris@49 1523 {
Chris@49 1524 eigval[i] = std::complex<T>(wr[i], wi[i]);
Chris@49 1525 }
Chris@49 1526
Chris@49 1527 return (info == 0);
Chris@49 1528 }
Chris@49 1529 #else
Chris@49 1530 {
Chris@49 1531 arma_ignore(eigval);
Chris@49 1532 arma_ignore(l_eigvec);
Chris@49 1533 arma_ignore(r_eigvec);
Chris@49 1534 arma_ignore(X);
Chris@49 1535 arma_ignore(side);
Chris@49 1536 arma_stop("eig_gen(): use of LAPACK needs to be enabled");
Chris@49 1537 return false;
Chris@49 1538 }
Chris@49 1539 #endif
Chris@49 1540 }
Chris@49 1541
Chris@49 1542
Chris@49 1543
Chris@49 1544
Chris@49 1545
Chris@49 1546 //! Eigenvalues and eigenvectors of a general square complex matrix using LAPACK
Chris@49 1547 //! The argument 'side' specifies which eigenvectors should be calculated
Chris@49 1548 //! (see code for mode details).
Chris@49 1549 template<typename T, typename T1>
Chris@49 1550 inline
Chris@49 1551 bool
Chris@49 1552 auxlib::eig_gen
Chris@49 1553 (
Chris@49 1554 Col< std::complex<T> >& eigval,
Chris@49 1555 Mat< std::complex<T> >& l_eigvec,
Chris@49 1556 Mat< std::complex<T> >& r_eigvec,
Chris@49 1557 const Base< std::complex<T>, T1 >& X,
Chris@49 1558 const char side
Chris@49 1559 )
Chris@49 1560 {
Chris@49 1561 arma_extra_debug_sigprint();
Chris@49 1562
Chris@49 1563 #if defined(ARMA_USE_LAPACK)
Chris@49 1564 {
Chris@49 1565 typedef typename std::complex<T> eT;
Chris@49 1566
Chris@49 1567 char jobvl;
Chris@49 1568 char jobvr;
Chris@49 1569
Chris@49 1570 switch(side)
Chris@49 1571 {
Chris@49 1572 case 'l': // left
Chris@49 1573 jobvl = 'V';
Chris@49 1574 jobvr = 'N';
Chris@49 1575 break;
Chris@49 1576
Chris@49 1577 case 'r': // right
Chris@49 1578 jobvl = 'N';
Chris@49 1579 jobvr = 'V';
Chris@49 1580 break;
Chris@49 1581
Chris@49 1582 case 'b': // both
Chris@49 1583 jobvl = 'V';
Chris@49 1584 jobvr = 'V';
Chris@49 1585 break;
Chris@49 1586
Chris@49 1587 case 'n': // neither
Chris@49 1588 jobvl = 'N';
Chris@49 1589 jobvr = 'N';
Chris@49 1590 break;
Chris@49 1591
Chris@49 1592 default:
Chris@49 1593 arma_stop("eig_gen(): parameter 'side' is invalid");
Chris@49 1594 return false;
Chris@49 1595 }
Chris@49 1596
Chris@49 1597 Mat<eT> A(X.get_ref());
Chris@49 1598 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" );
Chris@49 1599
Chris@49 1600 if(A.is_empty())
Chris@49 1601 {
Chris@49 1602 eigval.reset();
Chris@49 1603 l_eigvec.reset();
Chris@49 1604 r_eigvec.reset();
Chris@49 1605 return true;
Chris@49 1606 }
Chris@49 1607
Chris@49 1608 const uword A_n_rows = A.n_rows;
Chris@49 1609
Chris@49 1610 eigval.set_size(A_n_rows);
Chris@49 1611
Chris@49 1612 l_eigvec.set_size(A_n_rows, A_n_rows);
Chris@49 1613 r_eigvec.set_size(A_n_rows, A_n_rows);
Chris@49 1614
Chris@49 1615 blas_int N = blas_int(A_n_rows);
Chris@49 1616 blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*N) );
Chris@49 1617 blas_int info = 0;
Chris@49 1618
Chris@49 1619 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 1620 podarray<T> rwork( static_cast<uword>(2*N) );
Chris@49 1621
Chris@49 1622 arma_extra_debug_print("lapack::cx_geev()");
Chris@49 1623 lapack::cx_geev(&jobvl, &jobvr, &N, A.memptr(), &N, eigval.memptr(), l_eigvec.memptr(), &N, r_eigvec.memptr(), &N, work.memptr(), &lwork, rwork.memptr(), &info);
Chris@49 1624
Chris@49 1625 return (info == 0);
Chris@49 1626 }
Chris@49 1627 #else
Chris@49 1628 {
Chris@49 1629 arma_ignore(eigval);
Chris@49 1630 arma_ignore(l_eigvec);
Chris@49 1631 arma_ignore(r_eigvec);
Chris@49 1632 arma_ignore(X);
Chris@49 1633 arma_ignore(side);
Chris@49 1634 arma_stop("eig_gen(): use of LAPACK needs to be enabled");
Chris@49 1635 return false;
Chris@49 1636 }
Chris@49 1637 #endif
Chris@49 1638 }
Chris@49 1639
Chris@49 1640
Chris@49 1641
Chris@49 1642 template<typename eT, typename T1>
Chris@49 1643 inline
Chris@49 1644 bool
Chris@49 1645 auxlib::chol(Mat<eT>& out, const Base<eT,T1>& X)
Chris@49 1646 {
Chris@49 1647 arma_extra_debug_sigprint();
Chris@49 1648
Chris@49 1649 #if defined(ARMA_USE_LAPACK)
Chris@49 1650 {
Chris@49 1651 out = X.get_ref();
Chris@49 1652
Chris@49 1653 arma_debug_check( (out.is_square() == false), "chol(): given matrix is not square" );
Chris@49 1654
Chris@49 1655 if(out.is_empty())
Chris@49 1656 {
Chris@49 1657 return true;
Chris@49 1658 }
Chris@49 1659
Chris@49 1660 const uword out_n_rows = out.n_rows;
Chris@49 1661
Chris@49 1662 char uplo = 'U';
Chris@49 1663 blas_int n = out_n_rows;
Chris@49 1664 blas_int info = 0;
Chris@49 1665
Chris@49 1666 lapack::potrf(&uplo, &n, out.memptr(), &n, &info);
Chris@49 1667
Chris@49 1668 for(uword col=0; col<out_n_rows; ++col)
Chris@49 1669 {
Chris@49 1670 eT* colptr = out.colptr(col);
Chris@49 1671
Chris@49 1672 for(uword row=(col+1); row < out_n_rows; ++row)
Chris@49 1673 {
Chris@49 1674 colptr[row] = eT(0);
Chris@49 1675 }
Chris@49 1676 }
Chris@49 1677
Chris@49 1678 return (info == 0);
Chris@49 1679 }
Chris@49 1680 #else
Chris@49 1681 {
Chris@49 1682 arma_ignore(out);
Chris@49 1683 arma_ignore(X);
Chris@49 1684
Chris@49 1685 arma_stop("chol(): use of LAPACK needs to be enabled");
Chris@49 1686 return false;
Chris@49 1687 }
Chris@49 1688 #endif
Chris@49 1689 }
Chris@49 1690
Chris@49 1691
Chris@49 1692
Chris@49 1693 template<typename eT, typename T1>
Chris@49 1694 inline
Chris@49 1695 bool
Chris@49 1696 auxlib::qr(Mat<eT>& Q, Mat<eT>& R, const Base<eT,T1>& X)
Chris@49 1697 {
Chris@49 1698 arma_extra_debug_sigprint();
Chris@49 1699
Chris@49 1700 #if defined(ARMA_USE_LAPACK)
Chris@49 1701 {
Chris@49 1702 R = X.get_ref();
Chris@49 1703
Chris@49 1704 const uword R_n_rows = R.n_rows;
Chris@49 1705 const uword R_n_cols = R.n_cols;
Chris@49 1706
Chris@49 1707 if(R.is_empty())
Chris@49 1708 {
Chris@49 1709 Q.eye(R_n_rows, R_n_rows);
Chris@49 1710 return true;
Chris@49 1711 }
Chris@49 1712
Chris@49 1713 blas_int m = static_cast<blas_int>(R_n_rows);
Chris@49 1714 blas_int n = static_cast<blas_int>(R_n_cols);
Chris@49 1715 blas_int lwork = 0;
Chris@49 1716 blas_int lwork_min = (std::max)(blas_int(1), (std::max)(m,n)); // take into account requirements of geqrf() _and_ orgqr()/ungqr()
Chris@49 1717 blas_int k = (std::min)(m,n);
Chris@49 1718 blas_int info = 0;
Chris@49 1719
Chris@49 1720 podarray<eT> tau( static_cast<uword>(k) );
Chris@49 1721
Chris@49 1722 eT work_query[2];
Chris@49 1723 blas_int lwork_query = -1;
Chris@49 1724
Chris@49 1725 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), &work_query[0], &lwork_query, &info);
Chris@49 1726
Chris@49 1727 if(info == 0)
Chris@49 1728 {
Chris@49 1729 const blas_int lwork_proposed = static_cast<blas_int>( access::tmp_real(work_query[0]) );
Chris@49 1730
Chris@49 1731 lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min;
Chris@49 1732 }
Chris@49 1733 else
Chris@49 1734 {
Chris@49 1735 return false;
Chris@49 1736 }
Chris@49 1737
Chris@49 1738 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 1739
Chris@49 1740 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info);
Chris@49 1741
Chris@49 1742 Q.set_size(R_n_rows, R_n_rows);
Chris@49 1743
Chris@49 1744 arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) );
Chris@49 1745
Chris@49 1746 //
Chris@49 1747 // construct R
Chris@49 1748
Chris@49 1749 for(uword col=0; col < R_n_cols; ++col)
Chris@49 1750 {
Chris@49 1751 for(uword row=(col+1); row < R_n_rows; ++row)
Chris@49 1752 {
Chris@49 1753 R.at(row,col) = eT(0);
Chris@49 1754 }
Chris@49 1755 }
Chris@49 1756
Chris@49 1757
Chris@49 1758 if( (is_float<eT>::value == true) || (is_double<eT>::value == true) )
Chris@49 1759 {
Chris@49 1760 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info);
Chris@49 1761 }
Chris@49 1762 else
Chris@49 1763 if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) )
Chris@49 1764 {
Chris@49 1765 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info);
Chris@49 1766 }
Chris@49 1767
Chris@49 1768 return (info == 0);
Chris@49 1769 }
Chris@49 1770 #else
Chris@49 1771 {
Chris@49 1772 arma_ignore(Q);
Chris@49 1773 arma_ignore(R);
Chris@49 1774 arma_ignore(X);
Chris@49 1775 arma_stop("qr(): use of LAPACK needs to be enabled");
Chris@49 1776 return false;
Chris@49 1777 }
Chris@49 1778 #endif
Chris@49 1779 }
Chris@49 1780
Chris@49 1781
Chris@49 1782
Chris@49 1783 template<typename eT, typename T1>
Chris@49 1784 inline
Chris@49 1785 bool
Chris@49 1786 auxlib::qr_econ(Mat<eT>& Q, Mat<eT>& R, const Base<eT,T1>& X)
Chris@49 1787 {
Chris@49 1788 arma_extra_debug_sigprint();
Chris@49 1789
Chris@49 1790 // This function implements a memory-efficient QR for a non-square X that has dimensions m x n.
Chris@49 1791 // This basically discards the basis for the null-space.
Chris@49 1792 //
Chris@49 1793 // if m <= n: (use standard routine)
Chris@49 1794 // Q[m,m]*R[m,n] = X[m,n]
Chris@49 1795 // geqrf Needs A[m,n]: Uses R
Chris@49 1796 // orgqr Needs A[m,m]: Uses Q
Chris@49 1797 // otherwise: (memory-efficient routine)
Chris@49 1798 // Q[m,n]*R[n,n] = X[m,n]
Chris@49 1799 // geqrf Needs A[m,n]: Uses Q
Chris@49 1800 // geqrf Needs A[m,n]: Uses Q
Chris@49 1801
Chris@49 1802 #if defined(ARMA_USE_LAPACK)
Chris@49 1803 {
Chris@49 1804 if(is_Mat<T1>::value == true)
Chris@49 1805 {
Chris@49 1806 const unwrap<T1> tmp(X.get_ref());
Chris@49 1807 const Mat<eT>& M = tmp.M;
Chris@49 1808
Chris@49 1809 if(M.n_rows < M.n_cols)
Chris@49 1810 {
Chris@49 1811 return auxlib::qr(Q, R, X);
Chris@49 1812 }
Chris@49 1813 }
Chris@49 1814
Chris@49 1815 Q = X.get_ref();
Chris@49 1816
Chris@49 1817 const uword Q_n_rows = Q.n_rows;
Chris@49 1818 const uword Q_n_cols = Q.n_cols;
Chris@49 1819
Chris@49 1820 if( Q_n_rows <= Q_n_cols )
Chris@49 1821 {
Chris@49 1822 return auxlib::qr(Q, R, Q);
Chris@49 1823 }
Chris@49 1824
Chris@49 1825 if(Q.is_empty())
Chris@49 1826 {
Chris@49 1827 Q.set_size(Q_n_rows, 0 );
Chris@49 1828 R.set_size(0, Q_n_cols);
Chris@49 1829 return true;
Chris@49 1830 }
Chris@49 1831
Chris@49 1832 blas_int m = static_cast<blas_int>(Q_n_rows);
Chris@49 1833 blas_int n = static_cast<blas_int>(Q_n_cols);
Chris@49 1834 blas_int lwork = 0;
Chris@49 1835 blas_int lwork_min = (std::max)(blas_int(1), (std::max)(m,n)); // take into account requirements of geqrf() _and_ orgqr()/ungqr()
Chris@49 1836 blas_int k = (std::min)(m,n);
Chris@49 1837 blas_int info = 0;
Chris@49 1838
Chris@49 1839 podarray<eT> tau( static_cast<uword>(k) );
Chris@49 1840
Chris@49 1841 eT work_query[2];
Chris@49 1842 blas_int lwork_query = -1;
Chris@49 1843
Chris@49 1844 lapack::geqrf(&m, &n, Q.memptr(), &m, tau.memptr(), &work_query[0], &lwork_query, &info);
Chris@49 1845
Chris@49 1846 if(info == 0)
Chris@49 1847 {
Chris@49 1848 const blas_int lwork_proposed = static_cast<blas_int>( access::tmp_real(work_query[0]) );
Chris@49 1849
Chris@49 1850 lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min;
Chris@49 1851 }
Chris@49 1852 else
Chris@49 1853 {
Chris@49 1854 return false;
Chris@49 1855 }
Chris@49 1856
Chris@49 1857 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 1858
Chris@49 1859 lapack::geqrf(&m, &n, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info);
Chris@49 1860
Chris@49 1861 // Q now has the elements on and above the diagonal of the array
Chris@49 1862 // contain the min(M,N)-by-N upper trapezoidal matrix Q
Chris@49 1863 // (Q is upper triangular if m >= n);
Chris@49 1864 // the elements below the diagonal, with the array TAU,
Chris@49 1865 // represent the orthogonal matrix Q as a product of min(m,n) elementary reflectors.
Chris@49 1866
Chris@49 1867 R.set_size(Q_n_cols, Q_n_cols);
Chris@49 1868
Chris@49 1869 //
Chris@49 1870 // construct R
Chris@49 1871
Chris@49 1872 for(uword col=0; col < Q_n_cols; ++col)
Chris@49 1873 {
Chris@49 1874 for(uword row=0; row <= col; ++row)
Chris@49 1875 {
Chris@49 1876 R.at(row,col) = Q.at(row,col);
Chris@49 1877 }
Chris@49 1878
Chris@49 1879 for(uword row=(col+1); row < Q_n_cols; ++row)
Chris@49 1880 {
Chris@49 1881 R.at(row,col) = eT(0);
Chris@49 1882 }
Chris@49 1883 }
Chris@49 1884
Chris@49 1885 if( (is_float<eT>::value == true) || (is_double<eT>::value == true) )
Chris@49 1886 {
Chris@49 1887 lapack::orgqr(&m, &n, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info);
Chris@49 1888 }
Chris@49 1889 else
Chris@49 1890 if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) )
Chris@49 1891 {
Chris@49 1892 lapack::ungqr(&m, &n, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info);
Chris@49 1893 }
Chris@49 1894
Chris@49 1895 return (info == 0);
Chris@49 1896 }
Chris@49 1897 #else
Chris@49 1898 {
Chris@49 1899 arma_ignore(Q);
Chris@49 1900 arma_ignore(R);
Chris@49 1901 arma_ignore(X);
Chris@49 1902 arma_stop("qr_econ(): use of LAPACK needs to be enabled");
Chris@49 1903 return false;
Chris@49 1904 }
Chris@49 1905 #endif
Chris@49 1906 }
Chris@49 1907
Chris@49 1908
Chris@49 1909
Chris@49 1910 template<typename eT, typename T1>
Chris@49 1911 inline
Chris@49 1912 bool
Chris@49 1913 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X, uword& X_n_rows, uword& X_n_cols)
Chris@49 1914 {
Chris@49 1915 arma_extra_debug_sigprint();
Chris@49 1916
Chris@49 1917 #if defined(ARMA_USE_LAPACK)
Chris@49 1918 {
Chris@49 1919 Mat<eT> A(X.get_ref());
Chris@49 1920
Chris@49 1921 X_n_rows = A.n_rows;
Chris@49 1922 X_n_cols = A.n_cols;
Chris@49 1923
Chris@49 1924 if(A.is_empty())
Chris@49 1925 {
Chris@49 1926 S.reset();
Chris@49 1927 return true;
Chris@49 1928 }
Chris@49 1929
Chris@49 1930 Mat<eT> U(1, 1);
Chris@49 1931 Mat<eT> V(1, A.n_cols);
Chris@49 1932
Chris@49 1933 char jobu = 'N';
Chris@49 1934 char jobvt = 'N';
Chris@49 1935
Chris@49 1936 blas_int m = A.n_rows;
Chris@49 1937 blas_int n = A.n_cols;
Chris@49 1938 blas_int min_mn = (std::min)(m,n);
Chris@49 1939 blas_int lda = A.n_rows;
Chris@49 1940 blas_int ldu = U.n_rows;
Chris@49 1941 blas_int ldvt = V.n_rows;
Chris@49 1942 blas_int lwork = 0;
Chris@49 1943 blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) );
Chris@49 1944 blas_int info = 0;
Chris@49 1945
Chris@49 1946 S.set_size( static_cast<uword>(min_mn) );
Chris@49 1947
Chris@49 1948 eT work_query[2];
Chris@49 1949 blas_int lwork_query = -1;
Chris@49 1950
Chris@49 1951 lapack::gesvd<eT>
Chris@49 1952 (
Chris@49 1953 &jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info
Chris@49 1954 );
Chris@49 1955
Chris@49 1956 if(info == 0)
Chris@49 1957 {
Chris@49 1958 const blas_int lwork_proposed = static_cast<blas_int>( work_query[0] );
Chris@49 1959
Chris@49 1960 lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min;
Chris@49 1961
Chris@49 1962 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 1963
Chris@49 1964 lapack::gesvd<eT>
Chris@49 1965 (
Chris@49 1966 &jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, &info
Chris@49 1967 );
Chris@49 1968 }
Chris@49 1969
Chris@49 1970 return (info == 0);
Chris@49 1971 }
Chris@49 1972 #else
Chris@49 1973 {
Chris@49 1974 arma_ignore(S);
Chris@49 1975 arma_ignore(X);
Chris@49 1976 arma_ignore(X_n_rows);
Chris@49 1977 arma_ignore(X_n_cols);
Chris@49 1978 arma_stop("svd(): use of LAPACK needs to be enabled");
Chris@49 1979 return false;
Chris@49 1980 }
Chris@49 1981 #endif
Chris@49 1982 }
Chris@49 1983
Chris@49 1984
Chris@49 1985
Chris@49 1986 template<typename T, typename T1>
Chris@49 1987 inline
Chris@49 1988 bool
Chris@49 1989 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X, uword& X_n_rows, uword& X_n_cols)
Chris@49 1990 {
Chris@49 1991 arma_extra_debug_sigprint();
Chris@49 1992
Chris@49 1993 #if defined(ARMA_USE_LAPACK)
Chris@49 1994 {
Chris@49 1995 typedef std::complex<T> eT;
Chris@49 1996
Chris@49 1997 Mat<eT> A(X.get_ref());
Chris@49 1998
Chris@49 1999 X_n_rows = A.n_rows;
Chris@49 2000 X_n_cols = A.n_cols;
Chris@49 2001
Chris@49 2002 if(A.is_empty())
Chris@49 2003 {
Chris@49 2004 S.reset();
Chris@49 2005 return true;
Chris@49 2006 }
Chris@49 2007
Chris@49 2008 Mat<eT> U(1, 1);
Chris@49 2009 Mat<eT> V(1, A.n_cols);
Chris@49 2010
Chris@49 2011 char jobu = 'N';
Chris@49 2012 char jobvt = 'N';
Chris@49 2013
Chris@49 2014 blas_int m = A.n_rows;
Chris@49 2015 blas_int n = A.n_cols;
Chris@49 2016 blas_int min_mn = (std::min)(m,n);
Chris@49 2017 blas_int lda = A.n_rows;
Chris@49 2018 blas_int ldu = U.n_rows;
Chris@49 2019 blas_int ldvt = V.n_rows;
Chris@49 2020 blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*min_mn+(std::max)(m,n) ) );
Chris@49 2021 blas_int info = 0;
Chris@49 2022
Chris@49 2023 S.set_size( static_cast<uword>(min_mn) );
Chris@49 2024
Chris@49 2025 podarray<eT> work( static_cast<uword>(lwork ) );
Chris@49 2026 podarray< T> rwork( static_cast<uword>(5*min_mn) );
Chris@49 2027
Chris@49 2028 // let gesvd_() calculate the optimum size of the workspace
Chris@49 2029 blas_int lwork_tmp = -1;
Chris@49 2030
Chris@49 2031 lapack::cx_gesvd<T>
Chris@49 2032 (
Chris@49 2033 &jobu, &jobvt,
Chris@49 2034 &m, &n,
Chris@49 2035 A.memptr(), &lda,
Chris@49 2036 S.memptr(),
Chris@49 2037 U.memptr(), &ldu,
Chris@49 2038 V.memptr(), &ldvt,
Chris@49 2039 work.memptr(), &lwork_tmp,
Chris@49 2040 rwork.memptr(),
Chris@49 2041 &info
Chris@49 2042 );
Chris@49 2043
Chris@49 2044 if(info == 0)
Chris@49 2045 {
Chris@49 2046 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
Chris@49 2047 if(proposed_lwork > lwork)
Chris@49 2048 {
Chris@49 2049 lwork = proposed_lwork;
Chris@49 2050 work.set_size( static_cast<uword>(lwork) );
Chris@49 2051 }
Chris@49 2052
Chris@49 2053 lapack::cx_gesvd<T>
Chris@49 2054 (
Chris@49 2055 &jobu, &jobvt,
Chris@49 2056 &m, &n,
Chris@49 2057 A.memptr(), &lda,
Chris@49 2058 S.memptr(),
Chris@49 2059 U.memptr(), &ldu,
Chris@49 2060 V.memptr(), &ldvt,
Chris@49 2061 work.memptr(), &lwork,
Chris@49 2062 rwork.memptr(),
Chris@49 2063 &info
Chris@49 2064 );
Chris@49 2065 }
Chris@49 2066
Chris@49 2067 return (info == 0);
Chris@49 2068 }
Chris@49 2069 #else
Chris@49 2070 {
Chris@49 2071 arma_ignore(S);
Chris@49 2072 arma_ignore(X);
Chris@49 2073 arma_ignore(X_n_rows);
Chris@49 2074 arma_ignore(X_n_cols);
Chris@49 2075
Chris@49 2076 arma_stop("svd(): use of LAPACK needs to be enabled");
Chris@49 2077 return false;
Chris@49 2078 }
Chris@49 2079 #endif
Chris@49 2080 }
Chris@49 2081
Chris@49 2082
Chris@49 2083
Chris@49 2084 template<typename eT, typename T1>
Chris@49 2085 inline
Chris@49 2086 bool
Chris@49 2087 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X)
Chris@49 2088 {
Chris@49 2089 arma_extra_debug_sigprint();
Chris@49 2090
Chris@49 2091 uword junk;
Chris@49 2092 return auxlib::svd(S, X, junk, junk);
Chris@49 2093 }
Chris@49 2094
Chris@49 2095
Chris@49 2096
Chris@49 2097 template<typename T, typename T1>
Chris@49 2098 inline
Chris@49 2099 bool
Chris@49 2100 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X)
Chris@49 2101 {
Chris@49 2102 arma_extra_debug_sigprint();
Chris@49 2103
Chris@49 2104 uword junk;
Chris@49 2105 return auxlib::svd(S, X, junk, junk);
Chris@49 2106 }
Chris@49 2107
Chris@49 2108
Chris@49 2109
Chris@49 2110 template<typename eT, typename T1>
Chris@49 2111 inline
Chris@49 2112 bool
Chris@49 2113 auxlib::svd(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X)
Chris@49 2114 {
Chris@49 2115 arma_extra_debug_sigprint();
Chris@49 2116
Chris@49 2117 #if defined(ARMA_USE_LAPACK)
Chris@49 2118 {
Chris@49 2119 Mat<eT> A(X.get_ref());
Chris@49 2120
Chris@49 2121 if(A.is_empty())
Chris@49 2122 {
Chris@49 2123 U.eye(A.n_rows, A.n_rows);
Chris@49 2124 S.reset();
Chris@49 2125 V.eye(A.n_cols, A.n_cols);
Chris@49 2126 return true;
Chris@49 2127 }
Chris@49 2128
Chris@49 2129 U.set_size(A.n_rows, A.n_rows);
Chris@49 2130 V.set_size(A.n_cols, A.n_cols);
Chris@49 2131
Chris@49 2132 char jobu = 'A';
Chris@49 2133 char jobvt = 'A';
Chris@49 2134
Chris@49 2135 blas_int m = blas_int(A.n_rows);
Chris@49 2136 blas_int n = blas_int(A.n_cols);
Chris@49 2137 blas_int min_mn = (std::min)(m,n);
Chris@49 2138 blas_int lda = blas_int(A.n_rows);
Chris@49 2139 blas_int ldu = blas_int(U.n_rows);
Chris@49 2140 blas_int ldvt = blas_int(V.n_rows);
Chris@49 2141 blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) );
Chris@49 2142 blas_int lwork = 0;
Chris@49 2143 blas_int info = 0;
Chris@49 2144
Chris@49 2145 S.set_size( static_cast<uword>(min_mn) );
Chris@49 2146
Chris@49 2147 // let gesvd_() calculate the optimum size of the workspace
Chris@49 2148 eT work_query[2];
Chris@49 2149 blas_int lwork_query = -1;
Chris@49 2150
Chris@49 2151 lapack::gesvd<eT>
Chris@49 2152 (
Chris@49 2153 &jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info
Chris@49 2154 );
Chris@49 2155
Chris@49 2156 if(info == 0)
Chris@49 2157 {
Chris@49 2158 const blas_int lwork_proposed = static_cast<blas_int>( work_query[0] );
Chris@49 2159
Chris@49 2160 lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min;
Chris@49 2161
Chris@49 2162 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 2163
Chris@49 2164 lapack::gesvd<eT>
Chris@49 2165 (
Chris@49 2166 &jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, &info
Chris@49 2167 );
Chris@49 2168
Chris@49 2169 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done
Chris@49 2170 }
Chris@49 2171
Chris@49 2172 return (info == 0);
Chris@49 2173 }
Chris@49 2174 #else
Chris@49 2175 {
Chris@49 2176 arma_ignore(U);
Chris@49 2177 arma_ignore(S);
Chris@49 2178 arma_ignore(V);
Chris@49 2179 arma_ignore(X);
Chris@49 2180 arma_stop("svd(): use of LAPACK needs to be enabled");
Chris@49 2181 return false;
Chris@49 2182 }
Chris@49 2183 #endif
Chris@49 2184 }
Chris@49 2185
Chris@49 2186
Chris@49 2187
Chris@49 2188 template<typename T, typename T1>
Chris@49 2189 inline
Chris@49 2190 bool
Chris@49 2191 auxlib::svd(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X)
Chris@49 2192 {
Chris@49 2193 arma_extra_debug_sigprint();
Chris@49 2194
Chris@49 2195 #if defined(ARMA_USE_LAPACK)
Chris@49 2196 {
Chris@49 2197 typedef std::complex<T> eT;
Chris@49 2198
Chris@49 2199 Mat<eT> A(X.get_ref());
Chris@49 2200
Chris@49 2201 if(A.is_empty())
Chris@49 2202 {
Chris@49 2203 U.eye(A.n_rows, A.n_rows);
Chris@49 2204 S.reset();
Chris@49 2205 V.eye(A.n_cols, A.n_cols);
Chris@49 2206 return true;
Chris@49 2207 }
Chris@49 2208
Chris@49 2209 U.set_size(A.n_rows, A.n_rows);
Chris@49 2210 V.set_size(A.n_cols, A.n_cols);
Chris@49 2211
Chris@49 2212 char jobu = 'A';
Chris@49 2213 char jobvt = 'A';
Chris@49 2214
Chris@49 2215 blas_int m = blas_int(A.n_rows);
Chris@49 2216 blas_int n = blas_int(A.n_cols);
Chris@49 2217 blas_int min_mn = (std::min)(m,n);
Chris@49 2218 blas_int lda = blas_int(A.n_rows);
Chris@49 2219 blas_int ldu = blas_int(U.n_rows);
Chris@49 2220 blas_int ldvt = blas_int(V.n_rows);
Chris@49 2221 blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*min_mn + (std::max)(m,n) ) );
Chris@49 2222 blas_int info = 0;
Chris@49 2223
Chris@49 2224 S.set_size( static_cast<uword>(min_mn) );
Chris@49 2225
Chris@49 2226 podarray<eT> work( static_cast<uword>(lwork ) );
Chris@49 2227 podarray<T> rwork( static_cast<uword>(5*min_mn) );
Chris@49 2228
Chris@49 2229 // let gesvd_() calculate the optimum size of the workspace
Chris@49 2230 blas_int lwork_tmp = -1;
Chris@49 2231 lapack::cx_gesvd<T>
Chris@49 2232 (
Chris@49 2233 &jobu, &jobvt,
Chris@49 2234 &m, &n,
Chris@49 2235 A.memptr(), &lda,
Chris@49 2236 S.memptr(),
Chris@49 2237 U.memptr(), &ldu,
Chris@49 2238 V.memptr(), &ldvt,
Chris@49 2239 work.memptr(), &lwork_tmp,
Chris@49 2240 rwork.memptr(),
Chris@49 2241 &info
Chris@49 2242 );
Chris@49 2243
Chris@49 2244 if(info == 0)
Chris@49 2245 {
Chris@49 2246 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
Chris@49 2247
Chris@49 2248 if(proposed_lwork > lwork)
Chris@49 2249 {
Chris@49 2250 lwork = proposed_lwork;
Chris@49 2251 work.set_size( static_cast<uword>(lwork) );
Chris@49 2252 }
Chris@49 2253
Chris@49 2254 lapack::cx_gesvd<T>
Chris@49 2255 (
Chris@49 2256 &jobu, &jobvt,
Chris@49 2257 &m, &n,
Chris@49 2258 A.memptr(), &lda,
Chris@49 2259 S.memptr(),
Chris@49 2260 U.memptr(), &ldu,
Chris@49 2261 V.memptr(), &ldvt,
Chris@49 2262 work.memptr(), &lwork,
Chris@49 2263 rwork.memptr(),
Chris@49 2264 &info
Chris@49 2265 );
Chris@49 2266
Chris@49 2267 op_htrans::apply(V,V); // op_htrans will work out that an in-place transpose can be done
Chris@49 2268 }
Chris@49 2269
Chris@49 2270 return (info == 0);
Chris@49 2271 }
Chris@49 2272 #else
Chris@49 2273 {
Chris@49 2274 arma_ignore(U);
Chris@49 2275 arma_ignore(S);
Chris@49 2276 arma_ignore(V);
Chris@49 2277 arma_ignore(X);
Chris@49 2278 arma_stop("svd(): use of LAPACK needs to be enabled");
Chris@49 2279 return false;
Chris@49 2280 }
Chris@49 2281 #endif
Chris@49 2282 }
Chris@49 2283
Chris@49 2284
Chris@49 2285
Chris@49 2286 template<typename eT, typename T1>
Chris@49 2287 inline
Chris@49 2288 bool
Chris@49 2289 auxlib::svd_econ(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X, const char mode)
Chris@49 2290 {
Chris@49 2291 arma_extra_debug_sigprint();
Chris@49 2292
Chris@49 2293 #if defined(ARMA_USE_LAPACK)
Chris@49 2294 {
Chris@49 2295 Mat<eT> A(X.get_ref());
Chris@49 2296
Chris@49 2297 blas_int m = blas_int(A.n_rows);
Chris@49 2298 blas_int n = blas_int(A.n_cols);
Chris@49 2299 blas_int min_mn = (std::min)(m,n);
Chris@49 2300 blas_int lda = blas_int(A.n_rows);
Chris@49 2301
Chris@49 2302 S.set_size( static_cast<uword>(min_mn) );
Chris@49 2303
Chris@49 2304 blas_int ldu = 0;
Chris@49 2305 blas_int ldvt = 0;
Chris@49 2306
Chris@49 2307 char jobu;
Chris@49 2308 char jobvt;
Chris@49 2309
Chris@49 2310 switch(mode)
Chris@49 2311 {
Chris@49 2312 case 'l':
Chris@49 2313 jobu = 'S';
Chris@49 2314 jobvt = 'N';
Chris@49 2315
Chris@49 2316 ldu = m;
Chris@49 2317 ldvt = 1;
Chris@49 2318
Chris@49 2319 U.set_size( static_cast<uword>(ldu), static_cast<uword>(min_mn) );
Chris@49 2320 V.reset();
Chris@49 2321
Chris@49 2322 break;
Chris@49 2323
Chris@49 2324
Chris@49 2325 case 'r':
Chris@49 2326 jobu = 'N';
Chris@49 2327 jobvt = 'S';
Chris@49 2328
Chris@49 2329 ldu = 1;
Chris@49 2330 ldvt = (std::min)(m,n);
Chris@49 2331
Chris@49 2332 U.reset();
Chris@49 2333 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
Chris@49 2334
Chris@49 2335 break;
Chris@49 2336
Chris@49 2337
Chris@49 2338 case 'b':
Chris@49 2339 jobu = 'S';
Chris@49 2340 jobvt = 'S';
Chris@49 2341
Chris@49 2342 ldu = m;
Chris@49 2343 ldvt = (std::min)(m,n);
Chris@49 2344
Chris@49 2345 U.set_size( static_cast<uword>(ldu), static_cast<uword>(min_mn) );
Chris@49 2346 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n ) );
Chris@49 2347
Chris@49 2348 break;
Chris@49 2349
Chris@49 2350
Chris@49 2351 default:
Chris@49 2352 U.reset();
Chris@49 2353 S.reset();
Chris@49 2354 V.reset();
Chris@49 2355 return false;
Chris@49 2356 }
Chris@49 2357
Chris@49 2358
Chris@49 2359 if(A.is_empty())
Chris@49 2360 {
Chris@49 2361 U.eye();
Chris@49 2362 S.reset();
Chris@49 2363 V.eye();
Chris@49 2364 return true;
Chris@49 2365 }
Chris@49 2366
Chris@49 2367
Chris@49 2368 blas_int lwork = 3 * ( (std::max)(blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ) );
Chris@49 2369 blas_int info = 0;
Chris@49 2370
Chris@49 2371
Chris@49 2372 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 2373
Chris@49 2374 // let gesvd_() calculate the optimum size of the workspace
Chris@49 2375 blas_int lwork_tmp = -1;
Chris@49 2376
Chris@49 2377 lapack::gesvd<eT>
Chris@49 2378 (
Chris@49 2379 &jobu, &jobvt,
Chris@49 2380 &m, &n,
Chris@49 2381 A.memptr(), &lda,
Chris@49 2382 S.memptr(),
Chris@49 2383 U.memptr(), &ldu,
Chris@49 2384 V.memptr(), &ldvt,
Chris@49 2385 work.memptr(), &lwork_tmp,
Chris@49 2386 &info
Chris@49 2387 );
Chris@49 2388
Chris@49 2389 if(info == 0)
Chris@49 2390 {
Chris@49 2391 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
Chris@49 2392 if(proposed_lwork > lwork)
Chris@49 2393 {
Chris@49 2394 lwork = proposed_lwork;
Chris@49 2395 work.set_size( static_cast<uword>(lwork) );
Chris@49 2396 }
Chris@49 2397
Chris@49 2398 lapack::gesvd<eT>
Chris@49 2399 (
Chris@49 2400 &jobu, &jobvt,
Chris@49 2401 &m, &n,
Chris@49 2402 A.memptr(), &lda,
Chris@49 2403 S.memptr(),
Chris@49 2404 U.memptr(), &ldu,
Chris@49 2405 V.memptr(), &ldvt,
Chris@49 2406 work.memptr(), &lwork,
Chris@49 2407 &info
Chris@49 2408 );
Chris@49 2409
Chris@49 2410 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done
Chris@49 2411 }
Chris@49 2412
Chris@49 2413 return (info == 0);
Chris@49 2414 }
Chris@49 2415 #else
Chris@49 2416 {
Chris@49 2417 arma_ignore(U);
Chris@49 2418 arma_ignore(S);
Chris@49 2419 arma_ignore(V);
Chris@49 2420 arma_ignore(X);
Chris@49 2421 arma_ignore(mode);
Chris@49 2422 arma_stop("svd(): use of LAPACK needs to be enabled");
Chris@49 2423 return false;
Chris@49 2424 }
Chris@49 2425 #endif
Chris@49 2426 }
Chris@49 2427
Chris@49 2428
Chris@49 2429
Chris@49 2430 template<typename T, typename T1>
Chris@49 2431 inline
Chris@49 2432 bool
Chris@49 2433 auxlib::svd_econ(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X, const char mode)
Chris@49 2434 {
Chris@49 2435 arma_extra_debug_sigprint();
Chris@49 2436
Chris@49 2437 #if defined(ARMA_USE_LAPACK)
Chris@49 2438 {
Chris@49 2439 typedef std::complex<T> eT;
Chris@49 2440
Chris@49 2441 Mat<eT> A(X.get_ref());
Chris@49 2442
Chris@49 2443 blas_int m = blas_int(A.n_rows);
Chris@49 2444 blas_int n = blas_int(A.n_cols);
Chris@49 2445 blas_int min_mn = (std::min)(m,n);
Chris@49 2446 blas_int lda = blas_int(A.n_rows);
Chris@49 2447
Chris@49 2448 S.set_size( static_cast<uword>(min_mn) );
Chris@49 2449
Chris@49 2450 blas_int ldu = 0;
Chris@49 2451 blas_int ldvt = 0;
Chris@49 2452
Chris@49 2453 char jobu;
Chris@49 2454 char jobvt;
Chris@49 2455
Chris@49 2456 switch(mode)
Chris@49 2457 {
Chris@49 2458 case 'l':
Chris@49 2459 jobu = 'S';
Chris@49 2460 jobvt = 'N';
Chris@49 2461
Chris@49 2462 ldu = m;
Chris@49 2463 ldvt = 1;
Chris@49 2464
Chris@49 2465 U.set_size( static_cast<uword>(ldu), static_cast<uword>(min_mn) );
Chris@49 2466 V.reset();
Chris@49 2467
Chris@49 2468 break;
Chris@49 2469
Chris@49 2470
Chris@49 2471 case 'r':
Chris@49 2472 jobu = 'N';
Chris@49 2473 jobvt = 'S';
Chris@49 2474
Chris@49 2475 ldu = 1;
Chris@49 2476 ldvt = (std::min)(m,n);
Chris@49 2477
Chris@49 2478 U.reset();
Chris@49 2479 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
Chris@49 2480
Chris@49 2481 break;
Chris@49 2482
Chris@49 2483
Chris@49 2484 case 'b':
Chris@49 2485 jobu = 'S';
Chris@49 2486 jobvt = 'S';
Chris@49 2487
Chris@49 2488 ldu = m;
Chris@49 2489 ldvt = (std::min)(m,n);
Chris@49 2490
Chris@49 2491 U.set_size( static_cast<uword>(ldu), static_cast<uword>(min_mn) );
Chris@49 2492 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
Chris@49 2493
Chris@49 2494 break;
Chris@49 2495
Chris@49 2496
Chris@49 2497 default:
Chris@49 2498 U.reset();
Chris@49 2499 S.reset();
Chris@49 2500 V.reset();
Chris@49 2501 return false;
Chris@49 2502 }
Chris@49 2503
Chris@49 2504
Chris@49 2505 if(A.is_empty())
Chris@49 2506 {
Chris@49 2507 U.eye();
Chris@49 2508 S.reset();
Chris@49 2509 V.eye();
Chris@49 2510 return true;
Chris@49 2511 }
Chris@49 2512
Chris@49 2513
Chris@49 2514 blas_int lwork = 3 * ( (std::max)(blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ) );
Chris@49 2515 blas_int info = 0;
Chris@49 2516
Chris@49 2517
Chris@49 2518 podarray<eT> work( static_cast<uword>(lwork ) );
Chris@49 2519 podarray<T> rwork( static_cast<uword>(5*min_mn) );
Chris@49 2520
Chris@49 2521 // let gesvd_() calculate the optimum size of the workspace
Chris@49 2522 blas_int lwork_tmp = -1;
Chris@49 2523
Chris@49 2524 lapack::cx_gesvd<T>
Chris@49 2525 (
Chris@49 2526 &jobu, &jobvt,
Chris@49 2527 &m, &n,
Chris@49 2528 A.memptr(), &lda,
Chris@49 2529 S.memptr(),
Chris@49 2530 U.memptr(), &ldu,
Chris@49 2531 V.memptr(), &ldvt,
Chris@49 2532 work.memptr(), &lwork_tmp,
Chris@49 2533 rwork.memptr(),
Chris@49 2534 &info
Chris@49 2535 );
Chris@49 2536
Chris@49 2537 if(info == 0)
Chris@49 2538 {
Chris@49 2539 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
Chris@49 2540 if(proposed_lwork > lwork)
Chris@49 2541 {
Chris@49 2542 lwork = proposed_lwork;
Chris@49 2543 work.set_size( static_cast<uword>(lwork) );
Chris@49 2544 }
Chris@49 2545
Chris@49 2546 lapack::cx_gesvd<T>
Chris@49 2547 (
Chris@49 2548 &jobu, &jobvt,
Chris@49 2549 &m, &n,
Chris@49 2550 A.memptr(), &lda,
Chris@49 2551 S.memptr(),
Chris@49 2552 U.memptr(), &ldu,
Chris@49 2553 V.memptr(), &ldvt,
Chris@49 2554 work.memptr(), &lwork,
Chris@49 2555 rwork.memptr(),
Chris@49 2556 &info
Chris@49 2557 );
Chris@49 2558
Chris@49 2559 op_htrans::apply(V,V); // op_strans will work out that an in-place transpose can be done
Chris@49 2560 }
Chris@49 2561
Chris@49 2562 return (info == 0);
Chris@49 2563 }
Chris@49 2564 #else
Chris@49 2565 {
Chris@49 2566 arma_ignore(U);
Chris@49 2567 arma_ignore(S);
Chris@49 2568 arma_ignore(V);
Chris@49 2569 arma_ignore(X);
Chris@49 2570 arma_ignore(mode);
Chris@49 2571 arma_stop("svd(): use of LAPACK needs to be enabled");
Chris@49 2572 return false;
Chris@49 2573 }
Chris@49 2574 #endif
Chris@49 2575 }
Chris@49 2576
Chris@49 2577
Chris@49 2578
Chris@49 2579 template<typename eT, typename T1>
Chris@49 2580 inline
Chris@49 2581 bool
Chris@49 2582 auxlib::svd_dc(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X)
Chris@49 2583 {
Chris@49 2584 arma_extra_debug_sigprint();
Chris@49 2585
Chris@49 2586 #if defined(ARMA_USE_LAPACK)
Chris@49 2587 {
Chris@49 2588 Mat<eT> A(X.get_ref());
Chris@49 2589
Chris@49 2590 if(A.is_empty())
Chris@49 2591 {
Chris@49 2592 U.eye(A.n_rows, A.n_rows);
Chris@49 2593 S.reset();
Chris@49 2594 V.eye(A.n_cols, A.n_cols);
Chris@49 2595 return true;
Chris@49 2596 }
Chris@49 2597
Chris@49 2598 U.set_size(A.n_rows, A.n_rows);
Chris@49 2599 V.set_size(A.n_cols, A.n_cols);
Chris@49 2600
Chris@49 2601 char jobz = 'A';
Chris@49 2602
Chris@49 2603 blas_int m = blas_int(A.n_rows);
Chris@49 2604 blas_int n = blas_int(A.n_cols);
Chris@49 2605 blas_int min_mn = (std::min)(m,n);
Chris@49 2606 blas_int lda = blas_int(A.n_rows);
Chris@49 2607 blas_int ldu = blas_int(U.n_rows);
Chris@49 2608 blas_int ldvt = blas_int(V.n_rows);
Chris@49 2609 blas_int lwork = 3 * ( 3*min_mn*min_mn + (std::max)( (std::max)(m,n), 4*min_mn*min_mn + 4*min_mn ) );
Chris@49 2610 blas_int info = 0;
Chris@49 2611
Chris@49 2612 S.set_size( static_cast<uword>(min_mn) );
Chris@49 2613
Chris@49 2614 podarray<eT> work( static_cast<uword>(lwork ) );
Chris@49 2615 podarray<blas_int> iwork( static_cast<uword>(8*min_mn) );
Chris@49 2616
Chris@49 2617 lapack::gesdd<eT>
Chris@49 2618 (
Chris@49 2619 &jobz, &m, &n,
Chris@49 2620 A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt,
Chris@49 2621 work.memptr(), &lwork, iwork.memptr(), &info
Chris@49 2622 );
Chris@49 2623
Chris@49 2624 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done
Chris@49 2625
Chris@49 2626 return (info == 0);
Chris@49 2627 }
Chris@49 2628 #else
Chris@49 2629 {
Chris@49 2630 arma_ignore(U);
Chris@49 2631 arma_ignore(S);
Chris@49 2632 arma_ignore(V);
Chris@49 2633 arma_ignore(X);
Chris@49 2634 arma_stop("svd(): use of LAPACK needs to be enabled");
Chris@49 2635 return false;
Chris@49 2636 }
Chris@49 2637 #endif
Chris@49 2638 }
Chris@49 2639
Chris@49 2640
Chris@49 2641
Chris@49 2642 template<typename T, typename T1>
Chris@49 2643 inline
Chris@49 2644 bool
Chris@49 2645 auxlib::svd_dc(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X)
Chris@49 2646 {
Chris@49 2647 arma_extra_debug_sigprint();
Chris@49 2648
Chris@49 2649 #if defined(ARMA_USE_LAPACK)
Chris@49 2650 {
Chris@49 2651 typedef std::complex<T> eT;
Chris@49 2652
Chris@49 2653 Mat<eT> A(X.get_ref());
Chris@49 2654
Chris@49 2655 if(A.is_empty())
Chris@49 2656 {
Chris@49 2657 U.eye(A.n_rows, A.n_rows);
Chris@49 2658 S.reset();
Chris@49 2659 V.eye(A.n_cols, A.n_cols);
Chris@49 2660 return true;
Chris@49 2661 }
Chris@49 2662
Chris@49 2663 U.set_size(A.n_rows, A.n_rows);
Chris@49 2664 V.set_size(A.n_cols, A.n_cols);
Chris@49 2665
Chris@49 2666 char jobz = 'A';
Chris@49 2667
Chris@49 2668 blas_int m = blas_int(A.n_rows);
Chris@49 2669 blas_int n = blas_int(A.n_cols);
Chris@49 2670 blas_int min_mn = (std::min)(m,n);
Chris@49 2671 blas_int lda = blas_int(A.n_rows);
Chris@49 2672 blas_int ldu = blas_int(U.n_rows);
Chris@49 2673 blas_int ldvt = blas_int(V.n_rows);
Chris@49 2674 blas_int lwork = 3 * (min_mn*min_mn + 2*min_mn + (std::max)(m,n));
Chris@49 2675 blas_int info = 0;
Chris@49 2676
Chris@49 2677 S.set_size( static_cast<uword>(min_mn) );
Chris@49 2678
Chris@49 2679 podarray<eT> work( static_cast<uword>(lwork ) );
Chris@49 2680 podarray<T> rwork( static_cast<uword>(5*min_mn*min_mn + 7*min_mn) );
Chris@49 2681 podarray<blas_int> iwork( static_cast<uword>(8*min_mn ) );
Chris@49 2682
Chris@49 2683 lapack::cx_gesdd<T>
Chris@49 2684 (
Chris@49 2685 &jobz, &m, &n,
Chris@49 2686 A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt,
Chris@49 2687 work.memptr(), &lwork, rwork.memptr(), iwork.memptr(), &info
Chris@49 2688 );
Chris@49 2689
Chris@49 2690 op_htrans::apply(V,V); // op_htrans will work out that an in-place transpose can be done
Chris@49 2691
Chris@49 2692 return (info == 0);
Chris@49 2693 }
Chris@49 2694 #else
Chris@49 2695 {
Chris@49 2696 arma_ignore(U);
Chris@49 2697 arma_ignore(S);
Chris@49 2698 arma_ignore(V);
Chris@49 2699 arma_ignore(X);
Chris@49 2700 arma_stop("svd(): use of LAPACK needs to be enabled");
Chris@49 2701 return false;
Chris@49 2702 }
Chris@49 2703 #endif
Chris@49 2704 }
Chris@49 2705
Chris@49 2706
Chris@49 2707
Chris@49 2708 //! Solve a system of linear equations.
Chris@49 2709 //! Assumes that A.n_rows = A.n_cols and B.n_rows = A.n_rows
Chris@49 2710 template<typename eT, typename T1>
Chris@49 2711 inline
Chris@49 2712 bool
Chris@49 2713 auxlib::solve(Mat<eT>& out, Mat<eT>& A, const Base<eT,T1>& X, const bool slow)
Chris@49 2714 {
Chris@49 2715 arma_extra_debug_sigprint();
Chris@49 2716
Chris@49 2717 bool status = false;
Chris@49 2718
Chris@49 2719 const uword A_n_rows = A.n_rows;
Chris@49 2720
Chris@49 2721 if( (A_n_rows <= 4) && (slow == false) )
Chris@49 2722 {
Chris@49 2723 Mat<eT> A_inv;
Chris@49 2724
Chris@49 2725 status = auxlib::inv_noalias_tinymat(A_inv, A, A_n_rows);
Chris@49 2726
Chris@49 2727 if(status == true)
Chris@49 2728 {
Chris@49 2729 const unwrap_check<T1> Y( X.get_ref(), out );
Chris@49 2730 const Mat<eT>& B = Y.M;
Chris@49 2731
Chris@49 2732 const uword B_n_rows = B.n_rows;
Chris@49 2733 const uword B_n_cols = B.n_cols;
Chris@49 2734
Chris@49 2735 arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given objects must be the same" );
Chris@49 2736
Chris@49 2737 if(A.is_empty() || B.is_empty())
Chris@49 2738 {
Chris@49 2739 out.zeros(A.n_cols, B_n_cols);
Chris@49 2740 return true;
Chris@49 2741 }
Chris@49 2742
Chris@49 2743 out.set_size(A_n_rows, B_n_cols);
Chris@49 2744
Chris@49 2745 gemm_emul<false,false,false,false>::apply(out, A_inv, B);
Chris@49 2746
Chris@49 2747 return true;
Chris@49 2748 }
Chris@49 2749 }
Chris@49 2750
Chris@49 2751 if( (A_n_rows > 4) || (status == false) )
Chris@49 2752 {
Chris@49 2753 out = X.get_ref();
Chris@49 2754
Chris@49 2755 const uword B_n_rows = out.n_rows;
Chris@49 2756 const uword B_n_cols = out.n_cols;
Chris@49 2757
Chris@49 2758 arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given objects must be the same" );
Chris@49 2759
Chris@49 2760 if(A.is_empty() || out.is_empty())
Chris@49 2761 {
Chris@49 2762 out.zeros(A.n_cols, B_n_cols);
Chris@49 2763 return true;
Chris@49 2764 }
Chris@49 2765
Chris@49 2766 #if defined(ARMA_USE_ATLAS)
Chris@49 2767 {
Chris@49 2768 podarray<int> ipiv(A_n_rows + 2); // +2 for paranoia: old versions of Atlas might be trashing memory
Chris@49 2769
Chris@49 2770 int info = atlas::clapack_gesv<eT>(atlas::CblasColMajor, A_n_rows, B_n_cols, A.memptr(), A_n_rows, ipiv.memptr(), out.memptr(), A_n_rows);
Chris@49 2771
Chris@49 2772 return (info == 0);
Chris@49 2773 }
Chris@49 2774 #elif defined(ARMA_USE_LAPACK)
Chris@49 2775 {
Chris@49 2776 blas_int n = blas_int(A_n_rows); // assuming A is square
Chris@49 2777 blas_int lda = blas_int(A_n_rows);
Chris@49 2778 blas_int ldb = blas_int(A_n_rows);
Chris@49 2779 blas_int nrhs = blas_int(B_n_cols);
Chris@49 2780 blas_int info = 0;
Chris@49 2781
Chris@49 2782 podarray<blas_int> ipiv(A_n_rows + 2); // +2 for paranoia: some versions of Lapack might be trashing memory
Chris@49 2783
Chris@49 2784 arma_extra_debug_print("lapack::gesv()");
Chris@49 2785 lapack::gesv<eT>(&n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info);
Chris@49 2786
Chris@49 2787 arma_extra_debug_print("lapack::gesv() -- finished");
Chris@49 2788
Chris@49 2789 return (info == 0);
Chris@49 2790 }
Chris@49 2791 #else
Chris@49 2792 {
Chris@49 2793 arma_stop("solve(): use of ATLAS or LAPACK needs to be enabled");
Chris@49 2794 return false;
Chris@49 2795 }
Chris@49 2796 #endif
Chris@49 2797 }
Chris@49 2798
Chris@49 2799 return true;
Chris@49 2800 }
Chris@49 2801
Chris@49 2802
Chris@49 2803
Chris@49 2804 //! Solve an over-determined system.
Chris@49 2805 //! Assumes that A.n_rows > A.n_cols and B.n_rows = A.n_rows
Chris@49 2806 template<typename eT, typename T1>
Chris@49 2807 inline
Chris@49 2808 bool
Chris@49 2809 auxlib::solve_od(Mat<eT>& out, Mat<eT>& A, const Base<eT,T1>& X)
Chris@49 2810 {
Chris@49 2811 arma_extra_debug_sigprint();
Chris@49 2812
Chris@49 2813 #if defined(ARMA_USE_LAPACK)
Chris@49 2814 {
Chris@49 2815 Mat<eT> tmp = X.get_ref();
Chris@49 2816
Chris@49 2817 const uword A_n_rows = A.n_rows;
Chris@49 2818 const uword A_n_cols = A.n_cols;
Chris@49 2819
Chris@49 2820 const uword B_n_rows = tmp.n_rows;
Chris@49 2821 const uword B_n_cols = tmp.n_cols;
Chris@49 2822
Chris@49 2823 arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given objects must be the same" );
Chris@49 2824
Chris@49 2825 out.set_size(A_n_cols, B_n_cols);
Chris@49 2826
Chris@49 2827 if(A.is_empty() || tmp.is_empty())
Chris@49 2828 {
Chris@49 2829 out.zeros();
Chris@49 2830 return true;
Chris@49 2831 }
Chris@49 2832
Chris@49 2833 char trans = 'N';
Chris@49 2834
Chris@49 2835 blas_int m = blas_int(A_n_rows);
Chris@49 2836 blas_int n = blas_int(A_n_cols);
Chris@49 2837 blas_int lda = blas_int(A_n_rows);
Chris@49 2838 blas_int ldb = blas_int(A_n_rows);
Chris@49 2839 blas_int nrhs = blas_int(B_n_cols);
Chris@49 2840 blas_int lwork = 3 * ( (std::max)(blas_int(1), n + (std::max)(n, nrhs)) );
Chris@49 2841 blas_int info = 0;
Chris@49 2842
Chris@49 2843 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 2844
Chris@49 2845 // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems
Chris@49 2846 arma_extra_debug_print("lapack::gels()");
Chris@49 2847 lapack::gels<eT>( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, work.memptr(), &lwork, &info );
Chris@49 2848
Chris@49 2849 arma_extra_debug_print("lapack::gels() -- finished");
Chris@49 2850
Chris@49 2851 for(uword col=0; col<B_n_cols; ++col)
Chris@49 2852 {
Chris@49 2853 arrayops::copy( out.colptr(col), tmp.colptr(col), A_n_cols );
Chris@49 2854 }
Chris@49 2855
Chris@49 2856 return (info == 0);
Chris@49 2857 }
Chris@49 2858 #else
Chris@49 2859 {
Chris@49 2860 arma_ignore(out);
Chris@49 2861 arma_ignore(A);
Chris@49 2862 arma_ignore(X);
Chris@49 2863 arma_stop("solve(): use of LAPACK needs to be enabled");
Chris@49 2864 return false;
Chris@49 2865 }
Chris@49 2866 #endif
Chris@49 2867 }
Chris@49 2868
Chris@49 2869
Chris@49 2870
Chris@49 2871 //! Solve an under-determined system.
Chris@49 2872 //! Assumes that A.n_rows < A.n_cols and B.n_rows = A.n_rows
Chris@49 2873 template<typename eT, typename T1>
Chris@49 2874 inline
Chris@49 2875 bool
Chris@49 2876 auxlib::solve_ud(Mat<eT>& out, Mat<eT>& A, const Base<eT,T1>& X)
Chris@49 2877 {
Chris@49 2878 arma_extra_debug_sigprint();
Chris@49 2879
Chris@49 2880 // TODO: this function provides the same results as Octave 3.4.2.
Chris@49 2881 // TODO: however, these results are different than Matlab 7.12.0.635.
Chris@49 2882 // TODO: figure out whether both Octave and Matlab are correct, or only one of them
Chris@49 2883
Chris@49 2884 #if defined(ARMA_USE_LAPACK)
Chris@49 2885 {
Chris@49 2886 const unwrap<T1> Y( X.get_ref() );
Chris@49 2887 const Mat<eT>& B = Y.M;
Chris@49 2888
Chris@49 2889 const uword A_n_rows = A.n_rows;
Chris@49 2890 const uword A_n_cols = A.n_cols;
Chris@49 2891
Chris@49 2892 const uword B_n_rows = B.n_rows;
Chris@49 2893 const uword B_n_cols = B.n_cols;
Chris@49 2894
Chris@49 2895 arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given objects must be the same" );
Chris@49 2896
Chris@49 2897 // B could be an alias of "out", hence we need to check whether B is empty before setting the size of "out"
Chris@49 2898 if(A.is_empty() || B.is_empty())
Chris@49 2899 {
Chris@49 2900 out.zeros(A_n_cols, B_n_cols);
Chris@49 2901 return true;
Chris@49 2902 }
Chris@49 2903
Chris@49 2904 char trans = 'N';
Chris@49 2905
Chris@49 2906 blas_int m = blas_int(A_n_rows);
Chris@49 2907 blas_int n = blas_int(A_n_cols);
Chris@49 2908 blas_int lda = blas_int(A_n_rows);
Chris@49 2909 blas_int ldb = blas_int(A_n_cols);
Chris@49 2910 blas_int nrhs = blas_int(B_n_cols);
Chris@49 2911 blas_int lwork = 3 * ( (std::max)(blas_int(1), m + (std::max)(m,nrhs)) );
Chris@49 2912 blas_int info = 0;
Chris@49 2913
Chris@49 2914 Mat<eT> tmp(A_n_cols, B_n_cols);
Chris@49 2915 tmp.zeros();
Chris@49 2916
Chris@49 2917 for(uword col=0; col<B_n_cols; ++col)
Chris@49 2918 {
Chris@49 2919 eT* tmp_colmem = tmp.colptr(col);
Chris@49 2920
Chris@49 2921 arrayops::copy( tmp_colmem, B.colptr(col), B_n_rows );
Chris@49 2922
Chris@49 2923 for(uword row=B_n_rows; row<A_n_cols; ++row)
Chris@49 2924 {
Chris@49 2925 tmp_colmem[row] = eT(0);
Chris@49 2926 }
Chris@49 2927 }
Chris@49 2928
Chris@49 2929 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 2930
Chris@49 2931 // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems
Chris@49 2932 arma_extra_debug_print("lapack::gels()");
Chris@49 2933 lapack::gels<eT>( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, work.memptr(), &lwork, &info );
Chris@49 2934
Chris@49 2935 arma_extra_debug_print("lapack::gels() -- finished");
Chris@49 2936
Chris@49 2937 out.set_size(A_n_cols, B_n_cols);
Chris@49 2938
Chris@49 2939 for(uword col=0; col<B_n_cols; ++col)
Chris@49 2940 {
Chris@49 2941 arrayops::copy( out.colptr(col), tmp.colptr(col), A_n_cols );
Chris@49 2942 }
Chris@49 2943
Chris@49 2944 return (info == 0);
Chris@49 2945 }
Chris@49 2946 #else
Chris@49 2947 {
Chris@49 2948 arma_ignore(out);
Chris@49 2949 arma_ignore(A);
Chris@49 2950 arma_ignore(X);
Chris@49 2951 arma_stop("solve(): use of LAPACK needs to be enabled");
Chris@49 2952 return false;
Chris@49 2953 }
Chris@49 2954 #endif
Chris@49 2955 }
Chris@49 2956
Chris@49 2957
Chris@49 2958
Chris@49 2959 //
Chris@49 2960 // solve_tr
Chris@49 2961
Chris@49 2962 template<typename eT>
Chris@49 2963 inline
Chris@49 2964 bool
Chris@49 2965 auxlib::solve_tr(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const uword layout)
Chris@49 2966 {
Chris@49 2967 arma_extra_debug_sigprint();
Chris@49 2968
Chris@49 2969 #if defined(ARMA_USE_LAPACK)
Chris@49 2970 {
Chris@49 2971 if(A.is_empty() || B.is_empty())
Chris@49 2972 {
Chris@49 2973 out.zeros(A.n_cols, B.n_cols);
Chris@49 2974 return true;
Chris@49 2975 }
Chris@49 2976
Chris@49 2977 out = B;
Chris@49 2978
Chris@49 2979 char uplo = (layout == 0) ? 'U' : 'L';
Chris@49 2980 char trans = 'N';
Chris@49 2981 char diag = 'N';
Chris@49 2982 blas_int n = blas_int(A.n_rows);
Chris@49 2983 blas_int nrhs = blas_int(B.n_cols);
Chris@49 2984 blas_int info = 0;
Chris@49 2985
Chris@49 2986 lapack::trtrs<eT>(&uplo, &trans, &diag, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info);
Chris@49 2987
Chris@49 2988 return (info == 0);
Chris@49 2989 }
Chris@49 2990 #else
Chris@49 2991 {
Chris@49 2992 arma_ignore(out);
Chris@49 2993 arma_ignore(A);
Chris@49 2994 arma_ignore(B);
Chris@49 2995 arma_ignore(layout);
Chris@49 2996 arma_stop("solve(): use of LAPACK needs to be enabled");
Chris@49 2997 return false;
Chris@49 2998 }
Chris@49 2999 #endif
Chris@49 3000 }
Chris@49 3001
Chris@49 3002
Chris@49 3003
Chris@49 3004 //
Chris@49 3005 // Schur decomposition
Chris@49 3006
Chris@49 3007 template<typename eT>
Chris@49 3008 inline
Chris@49 3009 bool
Chris@49 3010 auxlib::schur_dec(Mat<eT>& Z, Mat<eT>& T, const Mat<eT>& A)
Chris@49 3011 {
Chris@49 3012 arma_extra_debug_sigprint();
Chris@49 3013
Chris@49 3014 #if defined(ARMA_USE_LAPACK)
Chris@49 3015 {
Chris@49 3016 arma_debug_check( (A.is_square() == false), "schur_dec(): given matrix is not square" );
Chris@49 3017
Chris@49 3018 if(A.is_empty())
Chris@49 3019 {
Chris@49 3020 Z.reset();
Chris@49 3021 T.reset();
Chris@49 3022 return true;
Chris@49 3023 }
Chris@49 3024
Chris@49 3025 const uword A_n_rows = A.n_rows;
Chris@49 3026
Chris@49 3027 Z.set_size(A_n_rows, A_n_rows);
Chris@49 3028 T = A;
Chris@49 3029
Chris@49 3030 char jobvs = 'V'; // get Schur vectors (Z)
Chris@49 3031 char sort = 'N'; // do not sort eigenvalues/vectors
Chris@49 3032 blas_int* select = 0; // pointer to sorting function
Chris@49 3033 blas_int n = blas_int(A_n_rows);
Chris@49 3034 blas_int sdim = 0; // output for sorting
Chris@49 3035 blas_int lwork = 3 * ( (std::max)(blas_int(1), 3*n) );
Chris@49 3036 blas_int info = 0;
Chris@49 3037
Chris@49 3038 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 3039 podarray<blas_int> bwork(A_n_rows);
Chris@49 3040
Chris@49 3041 podarray<eT> wr(A_n_rows); // output for eigenvalues
Chris@49 3042 podarray<eT> wi(A_n_rows); // output for eigenvalues
Chris@49 3043
Chris@49 3044 lapack::gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, wr.memptr(), wi.memptr(), Z.memptr(), &n, work.memptr(), &lwork, bwork.memptr(), &info);
Chris@49 3045
Chris@49 3046 return (info == 0);
Chris@49 3047 }
Chris@49 3048 #else
Chris@49 3049 {
Chris@49 3050 arma_ignore(Z);
Chris@49 3051 arma_ignore(T);
Chris@49 3052 arma_ignore(A);
Chris@49 3053
Chris@49 3054 arma_stop("schur_dec(): use of LAPACK needs to be enabled");
Chris@49 3055 return false;
Chris@49 3056 }
Chris@49 3057 #endif
Chris@49 3058 }
Chris@49 3059
Chris@49 3060
Chris@49 3061
Chris@49 3062 template<typename cT>
Chris@49 3063 inline
Chris@49 3064 bool
Chris@49 3065 auxlib::schur_dec(Mat<std::complex<cT> >& Z, Mat<std::complex<cT> >& T, const Mat<std::complex<cT> >& A)
Chris@49 3066 {
Chris@49 3067 arma_extra_debug_sigprint();
Chris@49 3068
Chris@49 3069 #if defined(ARMA_USE_LAPACK)
Chris@49 3070 {
Chris@49 3071 arma_debug_check( (A.is_square() == false), "schur_dec(): matrix A is not square" );
Chris@49 3072
Chris@49 3073 if(A.is_empty())
Chris@49 3074 {
Chris@49 3075 Z.reset();
Chris@49 3076 T.reset();
Chris@49 3077 return true;
Chris@49 3078 }
Chris@49 3079
Chris@49 3080 typedef std::complex<cT> eT;
Chris@49 3081
Chris@49 3082 const uword A_n_rows = A.n_rows;
Chris@49 3083
Chris@49 3084 Z.set_size(A_n_rows, A_n_rows);
Chris@49 3085 T = A;
Chris@49 3086
Chris@49 3087 char jobvs = 'V'; // get Schur vectors (Z)
Chris@49 3088 char sort = 'N'; // do not sort eigenvalues/vectors
Chris@49 3089 blas_int* select = 0; // pointer to sorting function
Chris@49 3090 blas_int n = blas_int(A_n_rows);
Chris@49 3091 blas_int sdim = 0; // output for sorting
Chris@49 3092 blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*n) );
Chris@49 3093 blas_int info = 0;
Chris@49 3094
Chris@49 3095 podarray<eT> work( static_cast<uword>(lwork) );
Chris@49 3096 podarray<blas_int> bwork(A_n_rows);
Chris@49 3097
Chris@49 3098 podarray<eT> w(A_n_rows); // output for eigenvalues
Chris@49 3099 podarray<cT> rwork(A_n_rows);
Chris@49 3100
Chris@49 3101 lapack::cx_gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, w.memptr(), Z.memptr(), &n, work.memptr(), &lwork, rwork.memptr(), bwork.memptr(), &info);
Chris@49 3102
Chris@49 3103 return (info == 0);
Chris@49 3104 }
Chris@49 3105 #else
Chris@49 3106 {
Chris@49 3107 arma_ignore(Z);
Chris@49 3108 arma_ignore(T);
Chris@49 3109 arma_ignore(A);
Chris@49 3110
Chris@49 3111 arma_stop("schur_dec(): use of LAPACK needs to be enabled");
Chris@49 3112 return false;
Chris@49 3113 }
Chris@49 3114 #endif
Chris@49 3115 }
Chris@49 3116
Chris@49 3117
Chris@49 3118
Chris@49 3119 //
Chris@49 3120 // syl (solution of the Sylvester equation AX + XB = C)
Chris@49 3121
Chris@49 3122 template<typename eT>
Chris@49 3123 inline
Chris@49 3124 bool
Chris@49 3125 auxlib::syl(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& B, const Mat<eT>& C)
Chris@49 3126 {
Chris@49 3127 arma_extra_debug_sigprint();
Chris@49 3128
Chris@49 3129 arma_debug_check
Chris@49 3130 (
Chris@49 3131 (A.is_square() == false) || (B.is_square() == false),
Chris@49 3132 "syl(): given matrix is not square"
Chris@49 3133 );
Chris@49 3134
Chris@49 3135 arma_debug_check
Chris@49 3136 (
Chris@49 3137 (C.n_rows != A.n_rows) || (C.n_cols != B.n_cols),
Chris@49 3138 "syl(): matrices are not conformant"
Chris@49 3139 );
Chris@49 3140
Chris@49 3141 if(A.is_empty() || B.is_empty() || C.is_empty())
Chris@49 3142 {
Chris@49 3143 X.reset();
Chris@49 3144 return true;
Chris@49 3145 }
Chris@49 3146
Chris@49 3147 #if defined(ARMA_USE_LAPACK)
Chris@49 3148 {
Chris@49 3149 Mat<eT> Z1, Z2, T1, T2;
Chris@49 3150
Chris@49 3151 const bool status_sd1 = auxlib::schur_dec(Z1, T1, A);
Chris@49 3152 const bool status_sd2 = auxlib::schur_dec(Z2, T2, B);
Chris@49 3153
Chris@49 3154 if( (status_sd1 == false) || (status_sd2 == false) )
Chris@49 3155 {
Chris@49 3156 return false;
Chris@49 3157 }
Chris@49 3158
Chris@49 3159 char trana = 'N';
Chris@49 3160 char tranb = 'N';
Chris@49 3161 blas_int isgn = +1;
Chris@49 3162 blas_int m = blas_int(T1.n_rows);
Chris@49 3163 blas_int n = blas_int(T2.n_cols);
Chris@49 3164
Chris@49 3165 eT scale = eT(0);
Chris@49 3166 blas_int info = 0;
Chris@49 3167
Chris@49 3168 Mat<eT> Y = trans(Z1) * C * Z2;
Chris@49 3169
Chris@49 3170 lapack::trsyl<eT>(&trana, &tranb, &isgn, &m, &n, T1.memptr(), &m, T2.memptr(), &n, Y.memptr(), &m, &scale, &info);
Chris@49 3171
Chris@49 3172 //Y /= scale;
Chris@49 3173 Y /= (-scale);
Chris@49 3174
Chris@49 3175 X = Z1 * Y * trans(Z2);
Chris@49 3176
Chris@49 3177 return (info >= 0);
Chris@49 3178 }
Chris@49 3179 #else
Chris@49 3180 {
Chris@49 3181 arma_stop("syl(): use of LAPACK needs to be enabled");
Chris@49 3182 return false;
Chris@49 3183 }
Chris@49 3184 #endif
Chris@49 3185 }
Chris@49 3186
Chris@49 3187
Chris@49 3188
Chris@49 3189 //
Chris@49 3190 // lyap (solution of the continuous Lyapunov equation AX + XA^H + Q = 0)
Chris@49 3191
Chris@49 3192 template<typename eT>
Chris@49 3193 inline
Chris@49 3194 bool
Chris@49 3195 auxlib::lyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
Chris@49 3196 {
Chris@49 3197 arma_extra_debug_sigprint();
Chris@49 3198
Chris@49 3199 arma_debug_check( (A.is_square() == false), "lyap(): matrix A is not square");
Chris@49 3200 arma_debug_check( (Q.is_square() == false), "lyap(): matrix Q is not square");
Chris@49 3201 arma_debug_check( (A.n_rows != Q.n_rows), "lyap(): matrices A and Q have different dimensions");
Chris@49 3202
Chris@49 3203 Mat<eT> htransA;
Chris@49 3204 op_htrans::apply_noalias(htransA, A);
Chris@49 3205
Chris@49 3206 const Mat<eT> mQ = -Q;
Chris@49 3207
Chris@49 3208 return auxlib::syl(X, A, htransA, mQ);
Chris@49 3209 }
Chris@49 3210
Chris@49 3211
Chris@49 3212
Chris@49 3213 //
Chris@49 3214 // dlyap (solution of the discrete Lyapunov equation AXA^H - X + Q = 0)
Chris@49 3215
Chris@49 3216 template<typename eT>
Chris@49 3217 inline
Chris@49 3218 bool
Chris@49 3219 auxlib::dlyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
Chris@49 3220 {
Chris@49 3221 arma_extra_debug_sigprint();
Chris@49 3222
Chris@49 3223 arma_debug_check( (A.is_square() == false), "dlyap(): matrix A is not square");
Chris@49 3224 arma_debug_check( (Q.is_square() == false), "dlyap(): matrix Q is not square");
Chris@49 3225 arma_debug_check( (A.n_rows != Q.n_rows), "dlyap(): matrices A and Q have different dimensions");
Chris@49 3226
Chris@49 3227 const Col<eT> vecQ = reshape(Q, Q.n_elem, 1);
Chris@49 3228
Chris@49 3229 const Mat<eT> M = eye< Mat<eT> >(Q.n_elem, Q.n_elem) - kron(conj(A), A);
Chris@49 3230
Chris@49 3231 Col<eT> vecX;
Chris@49 3232
Chris@49 3233 const bool status = solve(vecX, M, vecQ);
Chris@49 3234
Chris@49 3235 if(status == true)
Chris@49 3236 {
Chris@49 3237 X = reshape(vecX, Q.n_rows, Q.n_cols);
Chris@49 3238 return true;
Chris@49 3239 }
Chris@49 3240 else
Chris@49 3241 {
Chris@49 3242 X.reset();
Chris@49 3243 return false;
Chris@49 3244 }
Chris@49 3245 }
Chris@49 3246
Chris@49 3247
Chris@49 3248
Chris@49 3249 //! @}