comparison armadillo-2.4.4/include/armadillo_bits/auxlib_meat.hpp @ 0:8b6102e2a9b0

Armadillo Library
author maxzanoni76 <max.zanoni@eecs.qmul.ac.uk>
date Wed, 11 Apr 2012 09:27:06 +0100
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:8b6102e2a9b0
1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
2 // Copyright (C) 2008-2012 Conrad Sanderson
3 // Copyright (C) 2009 Edmund Highcock
4 // Copyright (C) 2011 James Sanders
5 // Copyright (C) 2011 Stanislav Funiak
6 //
7 // This file is part of the Armadillo C++ library.
8 // It is provided without any warranty of fitness
9 // for any purpose. You can redistribute this file
10 // and/or modify it under the terms of the GNU
11 // Lesser General Public License (LGPL) as published
12 // by the Free Software Foundation, either version 3
13 // of the License or (at your option) any later version.
14 // (see http://www.opensource.org/licenses for more info)
15
16
17 //! \addtogroup auxlib
18 //! @{
19
20
21
22 //! immediate matrix inverse
23 template<typename eT, typename T1>
24 inline
25 bool
26 auxlib::inv(Mat<eT>& out, const Base<eT,T1>& X, const bool slow)
27 {
28 arma_extra_debug_sigprint();
29
30 out = X.get_ref();
31
32 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
33
34 bool status = false;
35
36 const uword N = out.n_rows;
37
38 if( (N <= 4) && (slow == false) )
39 {
40 status = auxlib::inv_inplace_tinymat(out, N);
41 }
42
43 if( (N > 4) || (status == false) )
44 {
45 status = auxlib::inv_inplace_lapack(out);
46 }
47
48 return status;
49 }
50
51
52
53 template<typename eT>
54 inline
55 bool
56 auxlib::inv(Mat<eT>& out, const Mat<eT>& X, const bool slow)
57 {
58 arma_extra_debug_sigprint();
59
60 arma_debug_check( (X.is_square() == false), "inv(): given matrix is not square" );
61
62 bool status = false;
63
64 const uword N = X.n_rows;
65
66 if( (N <= 4) && (slow == false) )
67 {
68 status = (&out != &X) ? auxlib::inv_noalias_tinymat(out, X, N) : auxlib::inv_inplace_tinymat(out, N);
69 }
70
71 if( (N > 4) || (status == false) )
72 {
73 out = X;
74 status = auxlib::inv_inplace_lapack(out);
75 }
76
77 return status;
78 }
79
80
81
82 template<typename eT>
83 inline
84 bool
85 auxlib::inv_noalias_tinymat(Mat<eT>& out, const Mat<eT>& X, const uword N)
86 {
87 arma_extra_debug_sigprint();
88
89 bool det_ok = true;
90
91 out.set_size(N,N);
92
93 switch(N)
94 {
95 case 1:
96 {
97 out[0] = eT(1) / X[0];
98 };
99 break;
100
101 case 2:
102 {
103 const eT* Xm = X.memptr();
104
105 const eT a = Xm[pos<0,0>::n2];
106 const eT b = Xm[pos<0,1>::n2];
107 const eT c = Xm[pos<1,0>::n2];
108 const eT d = Xm[pos<1,1>::n2];
109
110 const eT tmp_det = (a*d - b*c);
111
112 if(tmp_det != eT(0))
113 {
114 eT* outm = out.memptr();
115
116 outm[pos<0,0>::n2] = d / tmp_det;
117 outm[pos<0,1>::n2] = -b / tmp_det;
118 outm[pos<1,0>::n2] = -c / tmp_det;
119 outm[pos<1,1>::n2] = a / tmp_det;
120 }
121 else
122 {
123 det_ok = false;
124 }
125 };
126 break;
127
128 case 3:
129 {
130 const eT* X_col0 = X.colptr(0);
131 const eT a11 = X_col0[0];
132 const eT a21 = X_col0[1];
133 const eT a31 = X_col0[2];
134
135 const eT* X_col1 = X.colptr(1);
136 const eT a12 = X_col1[0];
137 const eT a22 = X_col1[1];
138 const eT a32 = X_col1[2];
139
140 const eT* X_col2 = X.colptr(2);
141 const eT a13 = X_col2[0];
142 const eT a23 = X_col2[1];
143 const eT a33 = X_col2[2];
144
145 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13);
146
147 if(tmp_det != eT(0))
148 {
149 eT* out_col0 = out.colptr(0);
150 out_col0[0] = (a33*a22 - a32*a23) / tmp_det;
151 out_col0[1] = -(a33*a21 - a31*a23) / tmp_det;
152 out_col0[2] = (a32*a21 - a31*a22) / tmp_det;
153
154 eT* out_col1 = out.colptr(1);
155 out_col1[0] = -(a33*a12 - a32*a13) / tmp_det;
156 out_col1[1] = (a33*a11 - a31*a13) / tmp_det;
157 out_col1[2] = -(a32*a11 - a31*a12) / tmp_det;
158
159 eT* out_col2 = out.colptr(2);
160 out_col2[0] = (a23*a12 - a22*a13) / tmp_det;
161 out_col2[1] = -(a23*a11 - a21*a13) / tmp_det;
162 out_col2[2] = (a22*a11 - a21*a12) / tmp_det;
163 }
164 else
165 {
166 det_ok = false;
167 }
168 };
169 break;
170
171 case 4:
172 {
173 const eT tmp_det = det(X);
174
175 if(tmp_det != eT(0))
176 {
177 const eT* Xm = X.memptr();
178 eT* outm = out.memptr();
179
180 outm[pos<0,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
181 outm[pos<1,0>::n4] = ( Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
182 outm[pos<2,0>::n4] = ( Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
183 outm[pos<3,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
184
185 outm[pos<0,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
186 outm[pos<1,1>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
187 outm[pos<2,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
188 outm[pos<3,1>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
189
190 outm[pos<0,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
191 outm[pos<1,2>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
192 outm[pos<2,2>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
193 outm[pos<3,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
194
195 outm[pos<0,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
196 outm[pos<1,3>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
197 outm[pos<2,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
198 outm[pos<3,3>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] ) / tmp_det;
199 }
200 else
201 {
202 det_ok = false;
203 }
204 };
205 break;
206
207 default:
208 ;
209 }
210
211 return det_ok;
212 }
213
214
215
216 template<typename eT>
217 inline
218 bool
219 auxlib::inv_inplace_tinymat(Mat<eT>& X, const uword N)
220 {
221 arma_extra_debug_sigprint();
222
223 bool det_ok = true;
224
225 // for more info, see:
226 // http://www.dr-lex.34sp.com/random/matrix_inv.html
227 // http://www.cvl.iis.u-tokyo.ac.jp/~miyazaki/tech/teche23.html
228 // http://www.euclideanspace.com/maths/algebra/matrix/functions/inverse/fourD/index.htm
229 // http://www.geometrictools.com//LibFoundation/Mathematics/Wm4Matrix4.inl
230
231 switch(N)
232 {
233 case 1:
234 {
235 X[0] = eT(1) / X[0];
236 };
237 break;
238
239 case 2:
240 {
241 const eT a = X[pos<0,0>::n2];
242 const eT b = X[pos<0,1>::n2];
243 const eT c = X[pos<1,0>::n2];
244 const eT d = X[pos<1,1>::n2];
245
246 const eT tmp_det = (a*d - b*c);
247
248 if(tmp_det != eT(0))
249 {
250 X[pos<0,0>::n2] = d / tmp_det;
251 X[pos<0,1>::n2] = -b / tmp_det;
252 X[pos<1,0>::n2] = -c / tmp_det;
253 X[pos<1,1>::n2] = a / tmp_det;
254 }
255 else
256 {
257 det_ok = false;
258 }
259 };
260 break;
261
262 case 3:
263 {
264 eT* X_col0 = X.colptr(0);
265 eT* X_col1 = X.colptr(1);
266 eT* X_col2 = X.colptr(2);
267
268 const eT a11 = X_col0[0];
269 const eT a21 = X_col0[1];
270 const eT a31 = X_col0[2];
271
272 const eT a12 = X_col1[0];
273 const eT a22 = X_col1[1];
274 const eT a32 = X_col1[2];
275
276 const eT a13 = X_col2[0];
277 const eT a23 = X_col2[1];
278 const eT a33 = X_col2[2];
279
280 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13);
281
282 if(tmp_det != eT(0))
283 {
284 X_col0[0] = (a33*a22 - a32*a23) / tmp_det;
285 X_col0[1] = -(a33*a21 - a31*a23) / tmp_det;
286 X_col0[2] = (a32*a21 - a31*a22) / tmp_det;
287
288 X_col1[0] = -(a33*a12 - a32*a13) / tmp_det;
289 X_col1[1] = (a33*a11 - a31*a13) / tmp_det;
290 X_col1[2] = -(a32*a11 - a31*a12) / tmp_det;
291
292 X_col2[0] = (a23*a12 - a22*a13) / tmp_det;
293 X_col2[1] = -(a23*a11 - a21*a13) / tmp_det;
294 X_col2[2] = (a22*a11 - a21*a12) / tmp_det;
295 }
296 else
297 {
298 det_ok = false;
299 }
300 };
301 break;
302
303 case 4:
304 {
305 const eT tmp_det = det(X);
306
307 if(tmp_det != eT(0))
308 {
309 const Mat<eT> A(X);
310
311 const eT* Am = A.memptr();
312 eT* Xm = X.memptr();
313
314 Xm[pos<0,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] - Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] + Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
315 Xm[pos<1,0>::n4] = ( Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
316 Xm[pos<2,0>::n4] = ( Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
317 Xm[pos<3,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
318
319 Xm[pos<0,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] + Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] - Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
320 Xm[pos<1,1>::n4] = ( Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
321 Xm[pos<2,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
322 Xm[pos<3,1>::n4] = ( Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
323
324 Xm[pos<0,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] + Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
325 Xm[pos<1,2>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
326 Xm[pos<2,2>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
327 Xm[pos<3,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
328
329 Xm[pos<0,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] + Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
330 Xm[pos<1,3>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
331 Xm[pos<2,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
332 Xm[pos<3,3>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] ) / tmp_det;
333 }
334 else
335 {
336 det_ok = false;
337 }
338 };
339 break;
340
341 default:
342 ;
343 }
344
345 return det_ok;
346 }
347
348
349
350 template<typename eT>
351 inline
352 bool
353 auxlib::inv_inplace_lapack(Mat<eT>& out)
354 {
355 arma_extra_debug_sigprint();
356
357 if(out.is_empty())
358 {
359 return true;
360 }
361
362 #if defined(ARMA_USE_ATLAS)
363 {
364 podarray<int> ipiv(out.n_rows);
365
366 int info = atlas::clapack_getrf(atlas::CblasColMajor, out.n_rows, out.n_cols, out.memptr(), out.n_rows, ipiv.memptr());
367
368 if(info == 0)
369 {
370 info = atlas::clapack_getri(atlas::CblasColMajor, out.n_rows, out.memptr(), out.n_rows, ipiv.memptr());
371 }
372
373 return (info == 0);
374 }
375 #elif defined(ARMA_USE_LAPACK)
376 {
377 blas_int n_rows = out.n_rows;
378 blas_int n_cols = out.n_cols;
379 blas_int info = 0;
380
381 podarray<blas_int> ipiv(out.n_rows);
382
383 // 84 was empirically found -- it is the maximum value suggested by LAPACK (as provided by ATLAS v3.6)
384 // based on tests with various matrix types on 32-bit and 64-bit machines
385 //
386 // the "work" array is deliberately long so that a secondary (time-consuming)
387 // memory allocation is avoided, if possible
388
389 blas_int work_len = (std::max)(blas_int(1), n_rows*84);
390 podarray<eT> work( static_cast<uword>(work_len) );
391
392 lapack::getrf(&n_rows, &n_cols, out.memptr(), &n_rows, ipiv.memptr(), &info);
393
394 if(info == 0)
395 {
396 // query for optimum size of work_len
397
398 blas_int work_len_tmp = -1;
399 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &work_len_tmp, &info);
400
401 if(info == 0)
402 {
403 blas_int proposed_work_len = static_cast<blas_int>(access::tmp_real(work[0]));
404
405 // if necessary, allocate more memory
406 if(work_len < proposed_work_len)
407 {
408 work_len = proposed_work_len;
409 work.set_size( static_cast<uword>(work_len) );
410 }
411 }
412
413 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &work_len, &info);
414 }
415
416 return (info == 0);
417 }
418 #else
419 {
420 arma_ignore(out);
421 arma_stop("inv(): use of ATLAS or LAPACK needs to be enabled");
422 return false;
423 }
424 #endif
425 }
426
427
428
429 template<typename eT, typename T1>
430 inline
431 bool
432 auxlib::inv_tr(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
433 {
434 arma_extra_debug_sigprint();
435
436 out = X.get_ref();
437
438 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
439
440 if(out.is_empty())
441 {
442 return true;
443 }
444
445 bool status;
446
447 #if defined(ARMA_USE_LAPACK)
448 {
449 char uplo = (layout == 0) ? 'U' : 'L';
450 char diag = 'N';
451 blas_int n = blas_int(out.n_rows);
452 blas_int info = 0;
453
454 lapack::trtri(&uplo, &diag, &n, out.memptr(), &n, &info);
455
456 status = (info == 0);
457 }
458 #else
459 {
460 arma_ignore(layout);
461 arma_stop("inv(): use of LAPACK needs to be enabled");
462 status = false;
463 }
464 #endif
465
466
467 if(status == true)
468 {
469 if(layout == 0)
470 {
471 // upper triangular
472 out = trimatu(out);
473 }
474 else
475 {
476 // lower triangular
477 out = trimatl(out);
478 }
479 }
480
481 return status;
482 }
483
484
485
486 template<typename eT, typename T1>
487 inline
488 bool
489 auxlib::inv_sym(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
490 {
491 arma_extra_debug_sigprint();
492
493 out = X.get_ref();
494
495 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
496
497 if(out.is_empty())
498 {
499 return true;
500 }
501
502 bool status;
503
504 #if defined(ARMA_USE_LAPACK)
505 {
506 char uplo = (layout == 0) ? 'U' : 'L';
507 blas_int n = blas_int(out.n_rows);
508 blas_int lwork = n*n; // TODO: use lwork = -1 to determine optimal size
509 blas_int info = 0;
510
511 podarray<blas_int> ipiv;
512 ipiv.set_size(out.n_rows);
513
514 podarray<eT> work;
515 work.set_size( uword(lwork) );
516
517 lapack::sytrf(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &lwork, &info);
518
519 status = (info == 0);
520
521 if(status == true)
522 {
523 lapack::sytri(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &info);
524
525 out = (layout == 0) ? symmatu(out) : symmatl(out);
526
527 status = (info == 0);
528 }
529 }
530 #else
531 {
532 arma_ignore(layout);
533 arma_stop("inv(): use of LAPACK needs to be enabled");
534 status = false;
535 }
536 #endif
537
538 return status;
539 }
540
541
542
543 template<typename eT, typename T1>
544 inline
545 bool
546 auxlib::inv_sympd(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
547 {
548 arma_extra_debug_sigprint();
549
550 out = X.get_ref();
551
552 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
553
554 if(out.is_empty())
555 {
556 return true;
557 }
558
559 bool status;
560
561 #if defined(ARMA_USE_LAPACK)
562 {
563 char uplo = (layout == 0) ? 'U' : 'L';
564 blas_int n = blas_int(out.n_rows);
565 blas_int info = 0;
566
567 lapack::potrf(&uplo, &n, out.memptr(), &n, &info);
568
569 status = (info == 0);
570
571 if(status == true)
572 {
573 lapack::potri(&uplo, &n, out.memptr(), &n, &info);
574
575 out = (layout == 0) ? symmatu(out) : symmatl(out);
576
577 status = (info == 0);
578 }
579 }
580 #else
581 {
582 arma_ignore(layout);
583 arma_stop("inv(): use of LAPACK needs to be enabled");
584 status = false;
585 }
586 #endif
587
588 return status;
589 }
590
591
592
593 template<typename eT, typename T1>
594 inline
595 eT
596 auxlib::det(const Base<eT,T1>& X, const bool slow)
597 {
598 const unwrap<T1> tmp(X.get_ref());
599 const Mat<eT>& A = tmp.M;
600
601 arma_debug_check( (A.is_square() == false), "det(): matrix is not square" );
602
603 const bool make_copy = (is_Mat<T1>::value == true) ? true : false;
604
605 if(slow == false)
606 {
607 const uword N = A.n_rows;
608
609 switch(N)
610 {
611 case 0:
612 case 1:
613 case 2:
614 return auxlib::det_tinymat(A, N);
615 break;
616
617 case 3:
618 case 4:
619 {
620 const eT tmp_det = auxlib::det_tinymat(A, N);
621 return (tmp_det != eT(0)) ? tmp_det : auxlib::det_lapack(A, make_copy);
622 }
623 break;
624
625 default:
626 return auxlib::det_lapack(A, make_copy);
627 }
628 }
629 else
630 {
631 return auxlib::det_lapack(A, make_copy);
632 }
633 }
634
635
636
637 template<typename eT>
638 inline
639 eT
640 auxlib::det_tinymat(const Mat<eT>& X, const uword N)
641 {
642 arma_extra_debug_sigprint();
643
644 switch(N)
645 {
646 case 0:
647 return eT(1);
648 break;
649
650 case 1:
651 return X[0];
652 break;
653
654 case 2:
655 {
656 const eT* Xm = X.memptr();
657
658 return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] );
659 }
660 break;
661
662 case 3:
663 {
664 // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2);
665 // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0);
666 // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1);
667 // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2);
668 // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0);
669 // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1);
670 // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6);
671
672 const eT* a_col0 = X.colptr(0);
673 const eT a11 = a_col0[0];
674 const eT a21 = a_col0[1];
675 const eT a31 = a_col0[2];
676
677 const eT* a_col1 = X.colptr(1);
678 const eT a12 = a_col1[0];
679 const eT a22 = a_col1[1];
680 const eT a32 = a_col1[2];
681
682 const eT* a_col2 = X.colptr(2);
683 const eT a13 = a_col2[0];
684 const eT a23 = a_col2[1];
685 const eT a33 = a_col2[2];
686
687 return ( a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13) );
688 }
689 break;
690
691 case 4:
692 {
693 const eT* Xm = X.memptr();
694
695 const eT val = \
696 Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \
697 - Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \
698 - Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \
699 + Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \
700 + Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \
701 - Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \
702 - Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \
703 + Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \
704 + Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \
705 - Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \
706 - Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \
707 + Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \
708 + Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \
709 - Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \
710 - Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \
711 + Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \
712 + Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \
713 - Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \
714 - Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \
715 + Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \
716 + Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \
717 - Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \
718 - Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \
719 + Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \
720 ;
721
722 return val;
723 }
724 break;
725
726 default:
727 return eT(0);
728 ;
729 }
730 }
731
732
733
734 //! immediate determinant of a matrix using ATLAS or LAPACK
735 template<typename eT>
736 inline
737 eT
738 auxlib::det_lapack(const Mat<eT>& X, const bool make_copy)
739 {
740 arma_extra_debug_sigprint();
741
742 Mat<eT> X_copy;
743
744 if(make_copy == true)
745 {
746 X_copy = X;
747 }
748
749 Mat<eT>& tmp = (make_copy == true) ? X_copy : const_cast< Mat<eT>& >(X);
750
751 if(tmp.is_empty())
752 {
753 return eT(1);
754 }
755
756
757 #if defined(ARMA_USE_ATLAS)
758 {
759 podarray<int> ipiv(tmp.n_rows);
760
761 //const int info =
762 atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr());
763
764 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
765 eT val = tmp.at(0,0);
766 for(uword i=1; i < tmp.n_rows; ++i)
767 {
768 val *= tmp.at(i,i);
769 }
770
771 int sign = +1;
772 for(uword i=0; i < tmp.n_rows; ++i)
773 {
774 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0
775 {
776 sign *= -1;
777 }
778 }
779
780 return ( (sign < 0) ? -val : val );
781 }
782 #elif defined(ARMA_USE_LAPACK)
783 {
784 podarray<blas_int> ipiv(tmp.n_rows);
785
786 blas_int info = 0;
787 blas_int n_rows = blas_int(tmp.n_rows);
788 blas_int n_cols = blas_int(tmp.n_cols);
789
790 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info);
791
792 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
793 eT val = tmp.at(0,0);
794 for(uword i=1; i < tmp.n_rows; ++i)
795 {
796 val *= tmp.at(i,i);
797 }
798
799 blas_int sign = +1;
800 for(uword i=0; i < tmp.n_rows; ++i)
801 {
802 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1
803 {
804 sign *= -1;
805 }
806 }
807
808 return ( (sign < 0) ? -val : val );
809 }
810 #else
811 {
812 arma_ignore(X);
813 arma_ignore(make_copy);
814 arma_ignore(tmp);
815 arma_stop("det(): use of ATLAS or LAPACK needs to be enabled");
816 return eT(0);
817 }
818 #endif
819 }
820
821
822
823 //! immediate log determinant of a matrix using ATLAS or LAPACK
824 template<typename eT, typename T1>
825 inline
826 bool
827 auxlib::log_det(eT& out_val, typename get_pod_type<eT>::result& out_sign, const Base<eT,T1>& X)
828 {
829 arma_extra_debug_sigprint();
830
831 typedef typename get_pod_type<eT>::result T;
832
833 #if defined(ARMA_USE_ATLAS)
834 {
835 Mat<eT> tmp(X.get_ref());
836 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" );
837
838 if(tmp.is_empty())
839 {
840 out_val = eT(0);
841 out_sign = T(1);
842 return true;
843 }
844
845 podarray<int> ipiv(tmp.n_rows);
846
847 const int info = atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr());
848
849 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
850
851 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1;
852 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) );
853
854 for(uword i=1; i < tmp.n_rows; ++i)
855 {
856 const eT x = tmp.at(i,i);
857
858 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1;
859 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
860 }
861
862 for(uword i=0; i < tmp.n_rows; ++i)
863 {
864 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0
865 {
866 sign *= -1;
867 }
868 }
869
870 out_val = val;
871 out_sign = T(sign);
872
873 return (info == 0);
874 }
875 #elif defined(ARMA_USE_LAPACK)
876 {
877 Mat<eT> tmp(X.get_ref());
878 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" );
879
880 if(tmp.is_empty())
881 {
882 out_val = eT(0);
883 out_sign = T(1);
884 return true;
885 }
886
887 podarray<blas_int> ipiv(tmp.n_rows);
888
889 blas_int info = 0;
890 blas_int n_rows = blas_int(tmp.n_rows);
891 blas_int n_cols = blas_int(tmp.n_cols);
892
893 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info);
894
895 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
896
897 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1;
898 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) );
899
900 for(uword i=1; i < tmp.n_rows; ++i)
901 {
902 const eT x = tmp.at(i,i);
903
904 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1;
905 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
906 }
907
908 for(uword i=0; i < tmp.n_rows; ++i)
909 {
910 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1
911 {
912 sign *= -1;
913 }
914 }
915
916 out_val = val;
917 out_sign = T(sign);
918
919 return (info == 0);
920 }
921 #else
922 {
923 out_val = eT(0);
924 out_sign = T(0);
925
926 arma_stop("log_det(): use of ATLAS or LAPACK needs to be enabled");
927
928 return false;
929 }
930 #endif
931 }
932
933
934
935 //! immediate LU decomposition of a matrix using ATLAS or LAPACK
936 template<typename eT, typename T1>
937 inline
938 bool
939 auxlib::lu(Mat<eT>& L, Mat<eT>& U, podarray<blas_int>& ipiv, const Base<eT,T1>& X)
940 {
941 arma_extra_debug_sigprint();
942
943 U = X.get_ref();
944
945 const uword U_n_rows = U.n_rows;
946 const uword U_n_cols = U.n_cols;
947
948 if(U.is_empty())
949 {
950 L.set_size(U_n_rows, 0);
951 U.set_size(0, U_n_cols);
952 ipiv.reset();
953 return true;
954 }
955
956 #if defined(ARMA_USE_ATLAS) || defined(ARMA_USE_LAPACK)
957 {
958 bool status;
959
960 #if defined(ARMA_USE_ATLAS)
961 {
962 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) );
963
964 int info = atlas::clapack_getrf(atlas::CblasColMajor, U_n_rows, U_n_cols, U.memptr(), U_n_rows, ipiv.memptr());
965
966 status = (info == 0);
967 }
968 #elif defined(ARMA_USE_LAPACK)
969 {
970 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) );
971
972 blas_int info = 0;
973
974 blas_int n_rows = U_n_rows;
975 blas_int n_cols = U_n_cols;
976
977
978 lapack::getrf(&n_rows, &n_cols, U.memptr(), &n_rows, ipiv.memptr(), &info);
979
980 // take into account that Fortran counts from 1
981 arrayops::inplace_minus(ipiv.memptr(), blas_int(1), ipiv.n_elem);
982
983 status = (info == 0);
984 }
985 #endif
986
987 L.copy_size(U);
988
989 for(uword col=0; col < U_n_cols; ++col)
990 {
991 for(uword row=0; (row < col) && (row < U_n_rows); ++row)
992 {
993 L.at(row,col) = eT(0);
994 }
995
996 if( L.in_range(col,col) == true )
997 {
998 L.at(col,col) = eT(1);
999 }
1000
1001 for(uword row = (col+1); row < U_n_rows; ++row)
1002 {
1003 L.at(row,col) = U.at(row,col);
1004 U.at(row,col) = eT(0);
1005 }
1006 }
1007
1008 return status;
1009 }
1010 #else
1011 {
1012 arma_stop("lu(): use of ATLAS or LAPACK needs to be enabled");
1013
1014 return false;
1015 }
1016 #endif
1017 }
1018
1019
1020
1021 template<typename eT, typename T1>
1022 inline
1023 bool
1024 auxlib::lu(Mat<eT>& L, Mat<eT>& U, Mat<eT>& P, const Base<eT,T1>& X)
1025 {
1026 arma_extra_debug_sigprint();
1027
1028 podarray<blas_int> ipiv1;
1029 const bool status = auxlib::lu(L, U, ipiv1, X);
1030
1031 if(status == true)
1032 {
1033 if(U.is_empty())
1034 {
1035 // L and U have been already set to the correct empty matrices
1036 P.eye(L.n_rows, L.n_rows);
1037 return true;
1038 }
1039
1040 const uword n = ipiv1.n_elem;
1041 const uword P_rows = U.n_rows;
1042
1043 podarray<blas_int> ipiv2(P_rows);
1044
1045 const blas_int* ipiv1_mem = ipiv1.memptr();
1046 blas_int* ipiv2_mem = ipiv2.memptr();
1047
1048 for(uword i=0; i<P_rows; ++i)
1049 {
1050 ipiv2_mem[i] = blas_int(i);
1051 }
1052
1053 for(uword i=0; i<n; ++i)
1054 {
1055 const uword k = static_cast<uword>(ipiv1_mem[i]);
1056
1057 if( ipiv2_mem[i] != ipiv2_mem[k] )
1058 {
1059 std::swap( ipiv2_mem[i], ipiv2_mem[k] );
1060 }
1061 }
1062
1063 P.zeros(P_rows, P_rows);
1064
1065 for(uword row=0; row<P_rows; ++row)
1066 {
1067 P.at(row, static_cast<uword>(ipiv2_mem[row])) = eT(1);
1068 }
1069
1070 if(L.n_cols > U.n_rows)
1071 {
1072 L.shed_cols(U.n_rows, L.n_cols-1);
1073 }
1074
1075 if(U.n_rows > L.n_cols)
1076 {
1077 U.shed_rows(L.n_cols, U.n_rows-1);
1078 }
1079 }
1080
1081 return status;
1082 }
1083
1084
1085
1086 template<typename eT, typename T1>
1087 inline
1088 bool
1089 auxlib::lu(Mat<eT>& L, Mat<eT>& U, const Base<eT,T1>& X)
1090 {
1091 arma_extra_debug_sigprint();
1092
1093 podarray<blas_int> ipiv1;
1094 const bool status = auxlib::lu(L, U, ipiv1, X);
1095
1096 if(status == true)
1097 {
1098 if(U.is_empty())
1099 {
1100 // L and U have been already set to the correct empty matrices
1101 return true;
1102 }
1103
1104 const uword n = ipiv1.n_elem;
1105 const uword P_rows = U.n_rows;
1106
1107 podarray<blas_int> ipiv2(P_rows);
1108
1109 const blas_int* ipiv1_mem = ipiv1.memptr();
1110 blas_int* ipiv2_mem = ipiv2.memptr();
1111
1112 for(uword i=0; i<P_rows; ++i)
1113 {
1114 ipiv2_mem[i] = blas_int(i);
1115 }
1116
1117 for(uword i=0; i<n; ++i)
1118 {
1119 const uword k = static_cast<uword>(ipiv1_mem[i]);
1120
1121 if( ipiv2_mem[i] != ipiv2_mem[k] )
1122 {
1123 std::swap( ipiv2_mem[i], ipiv2_mem[k] );
1124 L.swap_rows( static_cast<uword>(ipiv2_mem[i]), static_cast<uword>(ipiv2_mem[k]) );
1125 }
1126 }
1127
1128 if(L.n_cols > U.n_rows)
1129 {
1130 L.shed_cols(U.n_rows, L.n_cols-1);
1131 }
1132
1133 if(U.n_rows > L.n_cols)
1134 {
1135 U.shed_rows(L.n_cols, U.n_rows-1);
1136 }
1137 }
1138
1139 return status;
1140 }
1141
1142
1143
1144 //! immediate eigenvalues of a symmetric real matrix using LAPACK
1145 template<typename eT, typename T1>
1146 inline
1147 bool
1148 auxlib::eig_sym(Col<eT>& eigval, const Base<eT,T1>& X)
1149 {
1150 arma_extra_debug_sigprint();
1151
1152 #if defined(ARMA_USE_LAPACK)
1153 {
1154 Mat<eT> A(X.get_ref());
1155
1156 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not square");
1157
1158 if(A.is_empty())
1159 {
1160 eigval.reset();
1161 return true;
1162 }
1163
1164 // rudimentary "better-than-nothing" test for symmetry
1165 //arma_debug_check( (A.at(A.n_rows-1, 0) != A.at(0, A.n_cols-1)), "auxlib::eig(): given matrix is not symmetric" );
1166
1167 char jobz = 'N';
1168 char uplo = 'U';
1169
1170 blas_int n_rows = A.n_rows;
1171 blas_int lwork = (std::max)(blas_int(1), 3*n_rows-1);
1172
1173 eigval.set_size( static_cast<uword>(n_rows) );
1174 podarray<eT> work( static_cast<uword>(lwork) );
1175
1176 blas_int info;
1177
1178 arma_extra_debug_print("lapack::syev()");
1179 lapack::syev(&jobz, &uplo, &n_rows, A.memptr(), &n_rows, eigval.memptr(), work.memptr(), &lwork, &info);
1180
1181 return (info == 0);
1182 }
1183 #else
1184 {
1185 arma_ignore(eigval);
1186 arma_ignore(X);
1187 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
1188 return false;
1189 }
1190 #endif
1191 }
1192
1193
1194
1195 //! immediate eigenvalues of a hermitian complex matrix using LAPACK
1196 template<typename T, typename T1>
1197 inline
1198 bool
1199 auxlib::eig_sym(Col<T>& eigval, const Base<std::complex<T>,T1>& X)
1200 {
1201 arma_extra_debug_sigprint();
1202
1203 typedef typename std::complex<T> eT;
1204
1205 #if defined(ARMA_USE_LAPACK)
1206 {
1207 Mat<eT> A(X.get_ref());
1208 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not hermitian");
1209
1210 if(A.is_empty())
1211 {
1212 eigval.reset();
1213 return true;
1214 }
1215
1216 char jobz = 'N';
1217 char uplo = 'U';
1218
1219 blas_int n_rows = A.n_rows;
1220 blas_int lda = A.n_rows;
1221 blas_int lwork = (std::max)(blas_int(1), 2*n_rows - 1); // TODO: automatically find best size of lwork
1222
1223 eigval.set_size( static_cast<uword>(n_rows) );
1224
1225 podarray<eT> work( static_cast<uword>(lwork) );
1226 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows - 2)) );
1227
1228 blas_int info;
1229
1230 arma_extra_debug_print("lapack::heev()");
1231 lapack::heev(&jobz, &uplo, &n_rows, A.memptr(), &lda, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info);
1232
1233 return (info == 0);
1234 }
1235 #else
1236 {
1237 arma_ignore(eigval);
1238 arma_ignore(X);
1239 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
1240 return false;
1241 }
1242 #endif
1243 }
1244
1245
1246
1247 //! immediate eigenvalues and eigenvectors of a symmetric real matrix using LAPACK
1248 template<typename eT, typename T1>
1249 inline
1250 bool
1251 auxlib::eig_sym(Col<eT>& eigval, Mat<eT>& eigvec, const Base<eT,T1>& X)
1252 {
1253 arma_extra_debug_sigprint();
1254
1255 #if defined(ARMA_USE_LAPACK)
1256 {
1257 eigvec = X.get_ref();
1258
1259 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" );
1260
1261 if(eigvec.is_empty())
1262 {
1263 eigval.reset();
1264 eigvec.reset();
1265 return true;
1266 }
1267
1268 // rudimentary "better-than-nothing" test for symmetry
1269 //arma_debug_check( (A.at(A.n_rows-1, 0) != A.at(0, A.n_cols-1)), "auxlib::eig(): given matrix is not symmetric" );
1270
1271 char jobz = 'V';
1272 char uplo = 'U';
1273
1274 blas_int n_rows = eigvec.n_rows;
1275 blas_int lwork = (std::max)(blas_int(1), 3*n_rows-1);
1276
1277 eigval.set_size( static_cast<uword>(n_rows) );
1278 podarray<eT> work( static_cast<uword>(lwork) );
1279
1280 blas_int info;
1281
1282 arma_extra_debug_print("lapack::syev()");
1283 lapack::syev(&jobz, &uplo, &n_rows, eigvec.memptr(), &n_rows, eigval.memptr(), work.memptr(), &lwork, &info);
1284
1285 return (info == 0);
1286 }
1287 #else
1288 {
1289 arma_ignore(eigval);
1290 arma_ignore(eigvec);
1291 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
1292
1293 return false;
1294 }
1295 #endif
1296 }
1297
1298
1299
1300 //! immediate eigenvalues and eigenvectors of a hermitian complex matrix using LAPACK
1301 template<typename T, typename T1>
1302 inline
1303 bool
1304 auxlib::eig_sym(Col<T>& eigval, Mat< std::complex<T> >& eigvec, const Base<std::complex<T>,T1>& X)
1305 {
1306 arma_extra_debug_sigprint();
1307
1308 typedef typename std::complex<T> eT;
1309
1310 #if defined(ARMA_USE_LAPACK)
1311 {
1312 eigvec = X.get_ref();
1313
1314 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not hermitian" );
1315
1316 if(eigvec.is_empty())
1317 {
1318 eigval.reset();
1319 eigvec.reset();
1320 return true;
1321 }
1322
1323 char jobz = 'V';
1324 char uplo = 'U';
1325
1326 blas_int n_rows = eigvec.n_rows;
1327 blas_int lda = eigvec.n_rows;
1328 blas_int lwork = (std::max)(blas_int(1), 2*n_rows - 1); // TODO: automatically find best size of lwork
1329
1330 eigval.set_size( static_cast<uword>(n_rows) );
1331
1332 podarray<eT> work( static_cast<uword>(lwork) );
1333 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows - 2)) );
1334
1335 blas_int info;
1336
1337 arma_extra_debug_print("lapack::heev()");
1338 lapack::heev(&jobz, &uplo, &n_rows, eigvec.memptr(), &lda, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info);
1339
1340 return (info == 0);
1341 }
1342 #else
1343 {
1344 arma_ignore(eigval);
1345 arma_ignore(eigvec);
1346 arma_ignore(X);
1347 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
1348 return false;
1349 }
1350 #endif
1351 }
1352
1353
1354
1355 //! Eigenvalues and eigenvectors of a general square real matrix using LAPACK.
1356 //! The argument 'side' specifies which eigenvectors should be calculated
1357 //! (see code for mode details).
1358 template<typename T, typename T1>
1359 inline
1360 bool
1361 auxlib::eig_gen
1362 (
1363 Col< std::complex<T> >& eigval,
1364 Mat<T>& l_eigvec,
1365 Mat<T>& r_eigvec,
1366 const Base<T,T1>& X,
1367 const char side
1368 )
1369 {
1370 arma_extra_debug_sigprint();
1371
1372 #if defined(ARMA_USE_LAPACK)
1373 {
1374 char jobvl;
1375 char jobvr;
1376
1377 switch(side)
1378 {
1379 case 'l': // left
1380 jobvl = 'V';
1381 jobvr = 'N';
1382 break;
1383
1384 case 'r': // right
1385 jobvl = 'N';
1386 jobvr = 'V';
1387 break;
1388
1389 case 'b': // both
1390 jobvl = 'V';
1391 jobvr = 'V';
1392 break;
1393
1394 case 'n': // neither
1395 jobvl = 'N';
1396 jobvr = 'N';
1397 break;
1398
1399 default:
1400 arma_stop("eig_gen(): parameter 'side' is invalid");
1401 return false;
1402 }
1403
1404 Mat<T> A(X.get_ref());
1405 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" );
1406
1407 if(A.is_empty())
1408 {
1409 eigval.reset();
1410 l_eigvec.reset();
1411 r_eigvec.reset();
1412 return true;
1413 }
1414
1415 uword A_n_rows = A.n_rows;
1416
1417 blas_int n_rows = A_n_rows;
1418 blas_int lda = A_n_rows;
1419 blas_int lwork = (std::max)(blas_int(1), 4*n_rows); // TODO: automatically find best size of lwork
1420
1421 eigval.set_size(A_n_rows);
1422 l_eigvec.set_size(A_n_rows, A_n_rows);
1423 r_eigvec.set_size(A_n_rows, A_n_rows);
1424
1425 podarray<T> work( static_cast<uword>(lwork) );
1426 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows)) );
1427
1428 podarray<T> wr(A_n_rows);
1429 podarray<T> wi(A_n_rows);
1430
1431 Mat<T> A_copy = A;
1432 blas_int info;
1433
1434 arma_extra_debug_print("lapack::geev()");
1435 lapack::geev(&jobvl, &jobvr, &n_rows, A_copy.memptr(), &lda, wr.memptr(), wi.memptr(), l_eigvec.memptr(), &n_rows, r_eigvec.memptr(), &n_rows, work.memptr(), &lwork, &info);
1436
1437
1438 eigval.set_size(A_n_rows);
1439 for(uword i=0; i<A_n_rows; ++i)
1440 {
1441 eigval[i] = std::complex<T>(wr[i], wi[i]);
1442 }
1443
1444 return (info == 0);
1445 }
1446 #else
1447 {
1448 arma_ignore(eigval);
1449 arma_ignore(l_eigvec);
1450 arma_ignore(r_eigvec);
1451 arma_ignore(X);
1452 arma_ignore(side);
1453 arma_stop("eig_gen(): use of LAPACK needs to be enabled");
1454 return false;
1455 }
1456 #endif
1457 }
1458
1459
1460
1461
1462
1463 //! Eigenvalues and eigenvectors of a general square complex matrix using LAPACK
1464 //! The argument 'side' specifies which eigenvectors should be calculated
1465 //! (see code for mode details).
1466 template<typename T, typename T1>
1467 inline
1468 bool
1469 auxlib::eig_gen
1470 (
1471 Col< std::complex<T> >& eigval,
1472 Mat< std::complex<T> >& l_eigvec,
1473 Mat< std::complex<T> >& r_eigvec,
1474 const Base< std::complex<T>, T1 >& X,
1475 const char side
1476 )
1477 {
1478 arma_extra_debug_sigprint();
1479
1480 typedef typename std::complex<T> eT;
1481
1482 #if defined(ARMA_USE_LAPACK)
1483 {
1484 char jobvl;
1485 char jobvr;
1486
1487 switch(side)
1488 {
1489 case 'l': // left
1490 jobvl = 'V';
1491 jobvr = 'N';
1492 break;
1493
1494 case 'r': // right
1495 jobvl = 'N';
1496 jobvr = 'V';
1497 break;
1498
1499 case 'b': // both
1500 jobvl = 'V';
1501 jobvr = 'V';
1502 break;
1503
1504 case 'n': // neither
1505 jobvl = 'N';
1506 jobvr = 'N';
1507 break;
1508
1509 default:
1510 arma_stop("eig_gen(): parameter 'side' is invalid");
1511 return false;
1512 }
1513
1514 Mat<eT> A(X.get_ref());
1515 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" );
1516
1517 if(A.is_empty())
1518 {
1519 eigval.reset();
1520 l_eigvec.reset();
1521 r_eigvec.reset();
1522 return true;
1523 }
1524
1525 uword A_n_rows = A.n_rows;
1526
1527 blas_int n_rows = A_n_rows;
1528 blas_int lda = A_n_rows;
1529 blas_int lwork = (std::max)(blas_int(1), 4*n_rows); // TODO: automatically find best size of lwork
1530
1531 eigval.set_size(A_n_rows);
1532 l_eigvec.set_size(A_n_rows, A_n_rows);
1533 r_eigvec.set_size(A_n_rows, A_n_rows);
1534
1535 podarray<eT> work( static_cast<uword>(lwork) );
1536 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows)) ); // was 2,3
1537
1538 blas_int info;
1539
1540 arma_extra_debug_print("lapack::cx_geev()");
1541 lapack::cx_geev(&jobvl, &jobvr, &n_rows, A.memptr(), &lda, eigval.memptr(), l_eigvec.memptr(), &n_rows, r_eigvec.memptr(), &n_rows, work.memptr(), &lwork, rwork.memptr(), &info);
1542
1543 return (info == 0);
1544 }
1545 #else
1546 {
1547 arma_ignore(eigval);
1548 arma_ignore(l_eigvec);
1549 arma_ignore(r_eigvec);
1550 arma_ignore(X);
1551 arma_ignore(side);
1552 arma_stop("eig_gen(): use of LAPACK needs to be enabled");
1553 return false;
1554 }
1555 #endif
1556 }
1557
1558
1559
1560 template<typename eT, typename T1>
1561 inline
1562 bool
1563 auxlib::chol(Mat<eT>& out, const Base<eT,T1>& X)
1564 {
1565 arma_extra_debug_sigprint();
1566
1567 #if defined(ARMA_USE_LAPACK)
1568 {
1569 out = X.get_ref();
1570
1571 arma_debug_check( (out.is_square() == false), "chol(): given matrix is not square" );
1572
1573 if(out.is_empty())
1574 {
1575 return true;
1576 }
1577
1578 const uword out_n_rows = out.n_rows;
1579
1580 char uplo = 'U';
1581 blas_int n = out_n_rows;
1582 blas_int info;
1583
1584 lapack::potrf(&uplo, &n, out.memptr(), &n, &info);
1585
1586 for(uword col=0; col<out_n_rows; ++col)
1587 {
1588 eT* colptr = out.colptr(col);
1589
1590 for(uword row=(col+1); row < out_n_rows; ++row)
1591 {
1592 colptr[row] = eT(0);
1593 }
1594 }
1595
1596 return (info == 0);
1597 }
1598 #else
1599 {
1600 arma_ignore(out);
1601 arma_stop("chol(): use of LAPACK needs to be enabled");
1602 return false;
1603 }
1604 #endif
1605 }
1606
1607
1608
1609 template<typename eT, typename T1>
1610 inline
1611 bool
1612 auxlib::qr(Mat<eT>& Q, Mat<eT>& R, const Base<eT,T1>& X)
1613 {
1614 arma_extra_debug_sigprint();
1615
1616 #if defined(ARMA_USE_LAPACK)
1617 {
1618 R = X.get_ref();
1619
1620 const uword R_n_rows = R.n_rows;
1621 const uword R_n_cols = R.n_cols;
1622
1623 if(R.is_empty())
1624 {
1625 Q.eye(R_n_rows, R_n_rows);
1626 return true;
1627 }
1628
1629 blas_int m = static_cast<blas_int>(R_n_rows);
1630 blas_int n = static_cast<blas_int>(R_n_cols);
1631 blas_int work_len = (std::max)(blas_int(1),n);
1632 blas_int work_len_tmp;
1633 blas_int k = (std::min)(m,n);
1634 blas_int info;
1635
1636 podarray<eT> tau( static_cast<uword>(k) );
1637 podarray<eT> work( static_cast<uword>(work_len) );
1638
1639 // query for the optimum value of work_len
1640 work_len_tmp = -1;
1641 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info);
1642
1643 if(info == 0)
1644 {
1645 work_len = static_cast<blas_int>(access::tmp_real(work[0]));
1646 work.set_size( static_cast<uword>(work_len) );
1647 }
1648
1649 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info);
1650
1651 Q.set_size(R_n_rows, R_n_rows);
1652
1653 arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) );
1654
1655 //
1656 // construct R
1657
1658 for(uword col=0; col < R_n_cols; ++col)
1659 {
1660 for(uword row=(col+1); row < R_n_rows; ++row)
1661 {
1662 R.at(row,col) = eT(0);
1663 }
1664 }
1665
1666
1667 if( (is_float<eT>::value == true) || (is_double<eT>::value == true) )
1668 {
1669 // query for the optimum value of work_len
1670 work_len_tmp = -1;
1671 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info);
1672
1673 if(info == 0)
1674 {
1675 work_len = static_cast<blas_int>(access::tmp_real(work[0]));
1676 work.set_size( static_cast<uword>(work_len) );
1677 }
1678
1679 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info);
1680 }
1681 else
1682 if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) )
1683 {
1684 // query for the optimum value of work_len
1685 work_len_tmp = -1;
1686 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info);
1687
1688 if(info == 0)
1689 {
1690 work_len = static_cast<blas_int>(access::tmp_real(work[0]));
1691 work.set_size( static_cast<uword>(work_len) );
1692 }
1693
1694 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info);
1695 }
1696
1697 return (info == 0);
1698 }
1699 #else
1700 {
1701 arma_ignore(Q);
1702 arma_ignore(R);
1703 arma_ignore(X);
1704 arma_stop("qr(): use of LAPACK needs to be enabled");
1705 return false;
1706 }
1707 #endif
1708 }
1709
1710
1711
1712 template<typename eT, typename T1>
1713 inline
1714 bool
1715 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X, uword& X_n_rows, uword& X_n_cols)
1716 {
1717 arma_extra_debug_sigprint();
1718
1719 #if defined(ARMA_USE_LAPACK)
1720 {
1721 Mat<eT> A(X.get_ref());
1722
1723 X_n_rows = A.n_rows;
1724 X_n_cols = A.n_cols;
1725
1726 if(A.is_empty())
1727 {
1728 S.reset();
1729 return true;
1730 }
1731
1732 Mat<eT> U(1, 1);
1733 Mat<eT> V(1, A.n_cols);
1734
1735 char jobu = 'N';
1736 char jobvt = 'N';
1737
1738 blas_int m = A.n_rows;
1739 blas_int n = A.n_cols;
1740 blas_int lda = A.n_rows;
1741 blas_int ldu = U.n_rows;
1742 blas_int ldvt = V.n_rows;
1743 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
1744 blas_int info;
1745
1746 S.set_size( static_cast<uword>((std::min)(m, n)) );
1747
1748 podarray<eT> work( static_cast<uword>(lwork) );
1749
1750
1751 // let gesvd_() calculate the optimum size of the workspace
1752 blas_int lwork_tmp = -1;
1753
1754 lapack::gesvd<eT>
1755 (
1756 &jobu, &jobvt,
1757 &m,&n,
1758 A.memptr(), &lda,
1759 S.memptr(),
1760 U.memptr(), &ldu,
1761 V.memptr(), &ldvt,
1762 work.memptr(), &lwork_tmp,
1763 &info
1764 );
1765
1766 if(info == 0)
1767 {
1768 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
1769
1770 if(proposed_lwork > lwork)
1771 {
1772 lwork = proposed_lwork;
1773 work.set_size( static_cast<uword>(lwork) );
1774 }
1775
1776 lapack::gesvd<eT>
1777 (
1778 &jobu, &jobvt,
1779 &m, &n,
1780 A.memptr(), &lda,
1781 S.memptr(),
1782 U.memptr(), &ldu,
1783 V.memptr(), &ldvt,
1784 work.memptr(), &lwork,
1785 &info
1786 );
1787 }
1788
1789 return (info == 0);
1790 }
1791 #else
1792 {
1793 arma_ignore(S);
1794 arma_ignore(X);
1795 arma_ignore(X_n_rows);
1796 arma_ignore(X_n_cols);
1797 arma_stop("svd(): use of LAPACK needs to be enabled");
1798 return false;
1799 }
1800 #endif
1801 }
1802
1803
1804
1805 template<typename T, typename T1>
1806 inline
1807 bool
1808 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X, uword& X_n_rows, uword& X_n_cols)
1809 {
1810 arma_extra_debug_sigprint();
1811
1812 typedef std::complex<T> eT;
1813
1814 #if defined(ARMA_USE_LAPACK)
1815 {
1816 Mat<eT> A(X.get_ref());
1817
1818 X_n_rows = A.n_rows;
1819 X_n_cols = A.n_cols;
1820
1821 if(A.is_empty())
1822 {
1823 S.reset();
1824 return true;
1825 }
1826
1827 Mat<eT> U(1, 1);
1828 Mat<eT> V(1, A.n_cols);
1829
1830 char jobu = 'N';
1831 char jobvt = 'N';
1832
1833 blas_int m = A.n_rows;
1834 blas_int n = A.n_cols;
1835 blas_int lda = A.n_rows;
1836 blas_int ldu = U.n_rows;
1837 blas_int ldvt = V.n_rows;
1838 blas_int lwork = 2 * (std::max)(blas_int(1), 2*(std::min)(m,n)+(std::max)(m,n) );
1839 blas_int info;
1840
1841 S.set_size( static_cast<uword>((std::min)(m,n)) );
1842
1843 podarray<eT> work( static_cast<uword>(lwork) );
1844 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) );
1845
1846 // let gesvd_() calculate the optimum size of the workspace
1847 blas_int lwork_tmp = -1;
1848
1849 lapack::cx_gesvd<T>
1850 (
1851 &jobu, &jobvt,
1852 &m, &n,
1853 A.memptr(), &lda,
1854 S.memptr(),
1855 U.memptr(), &ldu,
1856 V.memptr(), &ldvt,
1857 work.memptr(), &lwork_tmp,
1858 rwork.memptr(),
1859 &info
1860 );
1861
1862 if(info == 0)
1863 {
1864 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
1865 if(proposed_lwork > lwork)
1866 {
1867 lwork = proposed_lwork;
1868 work.set_size( static_cast<uword>(lwork) );
1869 }
1870
1871 lapack::cx_gesvd<T>
1872 (
1873 &jobu, &jobvt,
1874 &m, &n,
1875 A.memptr(), &lda,
1876 S.memptr(),
1877 U.memptr(), &ldu,
1878 V.memptr(), &ldvt,
1879 work.memptr(), &lwork,
1880 rwork.memptr(),
1881 &info
1882 );
1883 }
1884
1885 return (info == 0);
1886 }
1887 #else
1888 {
1889 arma_ignore(S);
1890 arma_ignore(X);
1891 arma_ignore(X_n_rows);
1892 arma_ignore(X_n_cols);
1893
1894 arma_stop("svd(): use of LAPACK needs to be enabled");
1895 return false;
1896 }
1897 #endif
1898 }
1899
1900
1901
1902 template<typename eT, typename T1>
1903 inline
1904 bool
1905 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X)
1906 {
1907 arma_extra_debug_sigprint();
1908
1909 uword junk;
1910 return auxlib::svd(S, X, junk, junk);
1911 }
1912
1913
1914
1915 template<typename T, typename T1>
1916 inline
1917 bool
1918 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X)
1919 {
1920 arma_extra_debug_sigprint();
1921
1922 uword junk;
1923 return auxlib::svd(S, X, junk, junk);
1924 }
1925
1926
1927
1928 template<typename eT, typename T1>
1929 inline
1930 bool
1931 auxlib::svd(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X)
1932 {
1933 arma_extra_debug_sigprint();
1934
1935 #if defined(ARMA_USE_LAPACK)
1936 {
1937 Mat<eT> A(X.get_ref());
1938
1939 if(A.is_empty())
1940 {
1941 U.eye(A.n_rows, A.n_rows);
1942 S.reset();
1943 V.eye(A.n_cols, A.n_cols);
1944 return true;
1945 }
1946
1947 U.set_size(A.n_rows, A.n_rows);
1948 V.set_size(A.n_cols, A.n_cols);
1949
1950 char jobu = 'A';
1951 char jobvt = 'A';
1952
1953 blas_int m = A.n_rows;
1954 blas_int n = A.n_cols;
1955 blas_int lda = A.n_rows;
1956 blas_int ldu = U.n_rows;
1957 blas_int ldvt = V.n_rows;
1958 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
1959 blas_int info;
1960
1961
1962 S.set_size( static_cast<uword>((std::min)(m,n)) );
1963 podarray<eT> work( static_cast<uword>(lwork) );
1964
1965 // let gesvd_() calculate the optimum size of the workspace
1966 blas_int lwork_tmp = -1;
1967
1968 lapack::gesvd<eT>
1969 (
1970 &jobu, &jobvt,
1971 &m, &n,
1972 A.memptr(), &lda,
1973 S.memptr(),
1974 U.memptr(), &ldu,
1975 V.memptr(), &ldvt,
1976 work.memptr(), &lwork_tmp,
1977 &info
1978 );
1979
1980 if(info == 0)
1981 {
1982 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
1983 if(proposed_lwork > lwork)
1984 {
1985 lwork = proposed_lwork;
1986 work.set_size( static_cast<uword>(lwork) );
1987 }
1988
1989 lapack::gesvd<eT>
1990 (
1991 &jobu, &jobvt,
1992 &m, &n,
1993 A.memptr(), &lda,
1994 S.memptr(),
1995 U.memptr(), &ldu,
1996 V.memptr(), &ldvt,
1997 work.memptr(), &lwork,
1998 &info
1999 );
2000
2001 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done
2002 }
2003
2004 return (info == 0);
2005 }
2006 #else
2007 {
2008 arma_ignore(U);
2009 arma_ignore(S);
2010 arma_ignore(V);
2011 arma_ignore(X);
2012 arma_stop("svd(): use of LAPACK needs to be enabled");
2013 return false;
2014 }
2015 #endif
2016 }
2017
2018
2019
2020 template<typename T, typename T1>
2021 inline
2022 bool
2023 auxlib::svd(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X)
2024 {
2025 arma_extra_debug_sigprint();
2026
2027 typedef std::complex<T> eT;
2028
2029 #if defined(ARMA_USE_LAPACK)
2030 {
2031 Mat<eT> A(X.get_ref());
2032
2033 if(A.is_empty())
2034 {
2035 U.eye(A.n_rows, A.n_rows);
2036 S.reset();
2037 V.eye(A.n_cols, A.n_cols);
2038 return true;
2039 }
2040
2041 U.set_size(A.n_rows, A.n_rows);
2042 V.set_size(A.n_cols, A.n_cols);
2043
2044 char jobu = 'A';
2045 char jobvt = 'A';
2046
2047 blas_int m = A.n_rows;
2048 blas_int n = A.n_cols;
2049 blas_int lda = A.n_rows;
2050 blas_int ldu = U.n_rows;
2051 blas_int ldvt = V.n_rows;
2052 blas_int lwork = 2 * (std::max)(blas_int(1), 2*(std::min)(m,n)+(std::max)(m,n) );
2053 blas_int info;
2054
2055 S.set_size( static_cast<uword>((std::min)(m,n)) );
2056
2057 podarray<eT> work( static_cast<uword>(lwork) );
2058 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) );
2059
2060 // let gesvd_() calculate the optimum size of the workspace
2061 blas_int lwork_tmp = -1;
2062 lapack::cx_gesvd<T>
2063 (
2064 &jobu, &jobvt,
2065 &m, &n,
2066 A.memptr(), &lda,
2067 S.memptr(),
2068 U.memptr(), &ldu,
2069 V.memptr(), &ldvt,
2070 work.memptr(), &lwork_tmp,
2071 rwork.memptr(),
2072 &info
2073 );
2074
2075 if(info == 0)
2076 {
2077 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
2078 if(proposed_lwork > lwork)
2079 {
2080 lwork = proposed_lwork;
2081 work.set_size( static_cast<uword>(lwork) );
2082 }
2083
2084 lapack::cx_gesvd<T>
2085 (
2086 &jobu, &jobvt,
2087 &m, &n,
2088 A.memptr(), &lda,
2089 S.memptr(),
2090 U.memptr(), &ldu,
2091 V.memptr(), &ldvt,
2092 work.memptr(), &lwork,
2093 rwork.memptr(),
2094 &info
2095 );
2096
2097 op_htrans::apply(V,V); // op_htrans will work out that an in-place transpose can be done
2098 }
2099
2100 return (info == 0);
2101 }
2102 #else
2103 {
2104 arma_ignore(U);
2105 arma_ignore(S);
2106 arma_ignore(V);
2107 arma_ignore(X);
2108 arma_stop("svd(): use of LAPACK needs to be enabled");
2109 return false;
2110 }
2111 #endif
2112
2113 }
2114
2115
2116
2117 template<typename eT, typename T1>
2118 inline
2119 bool
2120 auxlib::svd_econ(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X, const char mode)
2121 {
2122 arma_extra_debug_sigprint();
2123
2124 #if defined(ARMA_USE_LAPACK)
2125 {
2126 Mat<eT> A(X.get_ref());
2127
2128 blas_int m = A.n_rows;
2129 blas_int n = A.n_cols;
2130 blas_int lda = A.n_rows;
2131
2132 S.set_size( static_cast<uword>((std::min)(m,n)) );
2133
2134 blas_int ldu = 0;
2135 blas_int ldvt = 0;
2136
2137 char jobu;
2138 char jobvt;
2139
2140 switch(mode)
2141 {
2142 case 'l':
2143 jobu = 'S';
2144 jobvt = 'N';
2145
2146 ldu = m;
2147 ldvt = 1;
2148
2149 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
2150 V.reset();
2151
2152 break;
2153
2154
2155 case 'r':
2156 jobu = 'N';
2157 jobvt = 'S';
2158
2159 ldu = 1;
2160 ldvt = (std::min)(m,n);
2161
2162 U.reset();
2163 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
2164
2165 break;
2166
2167
2168 case 'b':
2169 jobu = 'S';
2170 jobvt = 'S';
2171
2172 ldu = m;
2173 ldvt = (std::min)(m,n);
2174
2175 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
2176 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
2177
2178 break;
2179
2180
2181 default:
2182 U.reset();
2183 S.reset();
2184 V.reset();
2185 return false;
2186 }
2187
2188
2189 if(A.is_empty())
2190 {
2191 U.eye();
2192 S.reset();
2193 V.eye();
2194 return true;
2195 }
2196
2197
2198 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
2199 blas_int info = 0;
2200
2201
2202 podarray<eT> work( static_cast<uword>(lwork) );
2203
2204 // let gesvd_() calculate the optimum size of the workspace
2205 blas_int lwork_tmp = -1;
2206
2207 lapack::gesvd<eT>
2208 (
2209 &jobu, &jobvt,
2210 &m, &n,
2211 A.memptr(), &lda,
2212 S.memptr(),
2213 U.memptr(), &ldu,
2214 V.memptr(), &ldvt,
2215 work.memptr(), &lwork_tmp,
2216 &info
2217 );
2218
2219 if(info == 0)
2220 {
2221 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
2222 if(proposed_lwork > lwork)
2223 {
2224 lwork = proposed_lwork;
2225 work.set_size( static_cast<uword>(lwork) );
2226 }
2227
2228 lapack::gesvd<eT>
2229 (
2230 &jobu, &jobvt,
2231 &m, &n,
2232 A.memptr(), &lda,
2233 S.memptr(),
2234 U.memptr(), &ldu,
2235 V.memptr(), &ldvt,
2236 work.memptr(), &lwork,
2237 &info
2238 );
2239
2240 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done
2241 }
2242
2243 return (info == 0);
2244 }
2245 #else
2246 {
2247 arma_ignore(U);
2248 arma_ignore(S);
2249 arma_ignore(V);
2250 arma_ignore(X);
2251 arma_ignore(mode);
2252 arma_stop("svd(): use of LAPACK needs to be enabled");
2253 return false;
2254 }
2255 #endif
2256 }
2257
2258
2259
2260 template<typename T, typename T1>
2261 inline
2262 bool
2263 auxlib::svd_econ(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X, const char mode)
2264 {
2265 arma_extra_debug_sigprint();
2266
2267 typedef std::complex<T> eT;
2268
2269 #if defined(ARMA_USE_LAPACK)
2270 {
2271 Mat<eT> A(X.get_ref());
2272
2273 blas_int m = A.n_rows;
2274 blas_int n = A.n_cols;
2275 blas_int lda = A.n_rows;
2276
2277 S.set_size( static_cast<uword>((std::min)(m,n)) );
2278
2279 blas_int ldu = 0;
2280 blas_int ldvt = 0;
2281
2282 char jobu;
2283 char jobvt;
2284
2285 switch(mode)
2286 {
2287 case 'l':
2288 jobu = 'S';
2289 jobvt = 'N';
2290
2291 ldu = m;
2292 ldvt = 1;
2293
2294 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
2295 V.reset();
2296
2297 break;
2298
2299
2300 case 'r':
2301 jobu = 'N';
2302 jobvt = 'S';
2303
2304 ldu = 1;
2305 ldvt = (std::min)(m,n);
2306
2307 U.reset();
2308 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
2309
2310 break;
2311
2312
2313 case 'b':
2314 jobu = 'S';
2315 jobvt = 'S';
2316
2317 ldu = m;
2318 ldvt = (std::min)(m,n);
2319
2320 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
2321 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
2322
2323 break;
2324
2325
2326 default:
2327 U.reset();
2328 S.reset();
2329 V.reset();
2330 return false;
2331 }
2332
2333
2334 if(A.is_empty())
2335 {
2336 U.eye();
2337 S.reset();
2338 V.eye();
2339 return true;
2340 }
2341
2342
2343 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
2344 blas_int info = 0;
2345
2346
2347 podarray<eT> work( static_cast<uword>(lwork) );
2348 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) );
2349
2350 // let gesvd_() calculate the optimum size of the workspace
2351 blas_int lwork_tmp = -1;
2352
2353 lapack::cx_gesvd<T>
2354 (
2355 &jobu, &jobvt,
2356 &m, &n,
2357 A.memptr(), &lda,
2358 S.memptr(),
2359 U.memptr(), &ldu,
2360 V.memptr(), &ldvt,
2361 work.memptr(), &lwork_tmp,
2362 rwork.memptr(),
2363 &info
2364 );
2365
2366 if(info == 0)
2367 {
2368 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
2369 if(proposed_lwork > lwork)
2370 {
2371 lwork = proposed_lwork;
2372 work.set_size( static_cast<uword>(lwork) );
2373 }
2374
2375 lapack::cx_gesvd<T>
2376 (
2377 &jobu, &jobvt,
2378 &m, &n,
2379 A.memptr(), &lda,
2380 S.memptr(),
2381 U.memptr(), &ldu,
2382 V.memptr(), &ldvt,
2383 work.memptr(), &lwork,
2384 rwork.memptr(),
2385 &info
2386 );
2387
2388 op_htrans::apply(V,V); // op_strans will work out that an in-place transpose can be done
2389 }
2390
2391 return (info == 0);
2392 }
2393 #else
2394 {
2395 arma_ignore(U);
2396 arma_ignore(S);
2397 arma_ignore(V);
2398 arma_ignore(X);
2399 arma_ignore(mode);
2400 arma_stop("svd(): use of LAPACK needs to be enabled");
2401 return false;
2402 }
2403 #endif
2404 }
2405
2406
2407
2408 //! Solve a system of linear equations.
2409 //! Assumes that A.n_rows = A.n_cols and B.n_rows = A.n_rows
2410 template<typename eT>
2411 inline
2412 bool
2413 auxlib::solve(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B, const bool slow)
2414 {
2415 arma_extra_debug_sigprint();
2416
2417 if(A.is_empty() || B.is_empty())
2418 {
2419 out.zeros(A.n_cols, B.n_cols);
2420 return true;
2421 }
2422 else
2423 {
2424 const uword A_n_rows = A.n_rows;
2425
2426 bool status = false;
2427
2428 if( (A_n_rows <= 4) && (slow == false) )
2429 {
2430 Mat<eT> A_inv;
2431
2432 status = auxlib::inv_noalias_tinymat(A_inv, A, A_n_rows);
2433
2434 if(status == true)
2435 {
2436 out.set_size(A_n_rows, B.n_cols);
2437
2438 gemm_emul<false,false,false,false>::apply(out, A_inv, B);
2439
2440 return true;
2441 }
2442 }
2443
2444 if( (A_n_rows > 4) || (status == false) )
2445 {
2446 #if defined(ARMA_USE_ATLAS)
2447 {
2448 podarray<int> ipiv(A_n_rows);
2449
2450 out = B;
2451
2452 int info = atlas::clapack_gesv<eT>(atlas::CblasColMajor, A_n_rows, B.n_cols, A.memptr(), A_n_rows, ipiv.memptr(), out.memptr(), A_n_rows);
2453
2454 return (info == 0);
2455 }
2456 #elif defined(ARMA_USE_LAPACK)
2457 {
2458 blas_int n = A_n_rows;
2459 blas_int lda = A_n_rows;
2460 blas_int ldb = A_n_rows;
2461 blas_int nrhs = B.n_cols;
2462 blas_int info;
2463
2464 podarray<blas_int> ipiv(A_n_rows);
2465
2466 out = B;
2467
2468 lapack::gesv<eT>(&n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info);
2469
2470 return (info == 0);
2471 }
2472 #else
2473 {
2474 arma_stop("solve(): use of ATLAS or LAPACK needs to be enabled");
2475 return false;
2476 }
2477 #endif
2478 }
2479 }
2480
2481 return true;
2482 }
2483
2484
2485
2486 //! Solve an over-determined system.
2487 //! Assumes that A.n_rows > A.n_cols and B.n_rows = A.n_rows
2488 template<typename eT>
2489 inline
2490 bool
2491 auxlib::solve_od(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B)
2492 {
2493 arma_extra_debug_sigprint();
2494
2495 #if defined(ARMA_USE_LAPACK)
2496 {
2497 if(A.is_empty() || B.is_empty())
2498 {
2499 out.zeros(A.n_cols, B.n_cols);
2500 return true;
2501 }
2502
2503 char trans = 'N';
2504
2505 blas_int m = A.n_rows;
2506 blas_int n = A.n_cols;
2507 blas_int lda = A.n_rows;
2508 blas_int ldb = A.n_rows;
2509 blas_int nrhs = B.n_cols;
2510 blas_int lwork = n + (std::max)(n, nrhs);
2511 blas_int info;
2512
2513 Mat<eT> tmp = B;
2514
2515 podarray<eT> work( static_cast<uword>(lwork) );
2516
2517 arma_extra_debug_print("lapack::gels()");
2518
2519 // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems
2520
2521 lapack::gels<eT>
2522 (
2523 &trans, &m, &n, &nrhs,
2524 A.memptr(), &lda,
2525 tmp.memptr(), &ldb,
2526 work.memptr(), &lwork,
2527 &info
2528 );
2529
2530 arma_extra_debug_print("lapack::gels() -- finished");
2531
2532 out.set_size(A.n_cols, B.n_cols);
2533
2534 for(uword col=0; col<B.n_cols; ++col)
2535 {
2536 arrayops::copy( out.colptr(col), tmp.colptr(col), A.n_cols );
2537 }
2538
2539 return (info == 0);
2540 }
2541 #else
2542 {
2543 arma_ignore(out);
2544 arma_ignore(A);
2545 arma_ignore(B);
2546 arma_stop("solve(): use of LAPACK needs to be enabled");
2547 return false;
2548 }
2549 #endif
2550 }
2551
2552
2553
2554 //! Solve an under-determined system.
2555 //! Assumes that A.n_rows < A.n_cols and B.n_rows = A.n_rows
2556 template<typename eT>
2557 inline
2558 bool
2559 auxlib::solve_ud(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B)
2560 {
2561 arma_extra_debug_sigprint();
2562
2563 #if defined(ARMA_USE_LAPACK)
2564 {
2565 if(A.is_empty() || B.is_empty())
2566 {
2567 out.zeros(A.n_cols, B.n_cols);
2568 return true;
2569 }
2570
2571 char trans = 'N';
2572
2573 blas_int m = A.n_rows;
2574 blas_int n = A.n_cols;
2575 blas_int lda = A.n_rows;
2576 blas_int ldb = A.n_cols;
2577 blas_int nrhs = B.n_cols;
2578 blas_int lwork = m + (std::max)(m,nrhs);
2579 blas_int info;
2580
2581
2582 Mat<eT> tmp;
2583 tmp.zeros(A.n_cols, B.n_cols);
2584
2585 for(uword col=0; col<B.n_cols; ++col)
2586 {
2587 eT* tmp_colmem = tmp.colptr(col);
2588
2589 arrayops::copy( tmp_colmem, B.colptr(col), B.n_rows );
2590
2591 for(uword row=B.n_rows; row<A.n_cols; ++row)
2592 {
2593 tmp_colmem[row] = eT(0);
2594 }
2595 }
2596
2597 podarray<eT> work( static_cast<uword>(lwork) );
2598
2599 arma_extra_debug_print("lapack::gels()");
2600
2601 // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems
2602
2603 lapack::gels<eT>
2604 (
2605 &trans, &m, &n, &nrhs,
2606 A.memptr(), &lda,
2607 tmp.memptr(), &ldb,
2608 work.memptr(), &lwork,
2609 &info
2610 );
2611
2612 arma_extra_debug_print("lapack::gels() -- finished");
2613
2614 out.set_size(A.n_cols, B.n_cols);
2615
2616 for(uword col=0; col<B.n_cols; ++col)
2617 {
2618 arrayops::copy( out.colptr(col), tmp.colptr(col), A.n_cols );
2619 }
2620
2621 return (info == 0);
2622 }
2623 #else
2624 {
2625 arma_ignore(out);
2626 arma_ignore(A);
2627 arma_ignore(B);
2628 arma_stop("solve(): use of LAPACK needs to be enabled");
2629 return false;
2630 }
2631 #endif
2632 }
2633
2634
2635
2636 //
2637 // solve_tr
2638
2639 template<typename eT>
2640 inline
2641 bool
2642 auxlib::solve_tr(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const uword layout)
2643 {
2644 arma_extra_debug_sigprint();
2645
2646 #if defined(ARMA_USE_LAPACK)
2647 {
2648 if(A.is_empty() || B.is_empty())
2649 {
2650 out.zeros(A.n_cols, B.n_cols);
2651 return true;
2652 }
2653
2654 out = B;
2655
2656 char uplo = (layout == 0) ? 'U' : 'L';
2657 char trans = 'N';
2658 char diag = 'N';
2659 blas_int n = blas_int(A.n_rows);
2660 blas_int nrhs = blas_int(B.n_cols);
2661 blas_int info = 0;
2662
2663 lapack::trtrs<eT>(&uplo, &trans, &diag, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info);
2664
2665 return (info == 0);
2666 }
2667 #else
2668 {
2669 arma_ignore(out);
2670 arma_ignore(A);
2671 arma_ignore(B);
2672 arma_ignore(layout);
2673 arma_stop("solve(): use of LAPACK needs to be enabled");
2674 return false;
2675 }
2676 #endif
2677 }
2678
2679
2680
2681 //
2682 // Schur decomposition
2683
2684 template<typename eT>
2685 inline
2686 bool
2687 auxlib::schur_dec(Mat<eT>& Z, Mat<eT>& T, const Mat<eT>& A)
2688 {
2689 arma_extra_debug_sigprint();
2690
2691 #if defined(ARMA_USE_LAPACK)
2692 {
2693 arma_debug_check( (A.is_square() == false), "schur_dec(): given matrix is not square" );
2694
2695 if(A.is_empty())
2696 {
2697 Z.reset();
2698 T.reset();
2699 return true;
2700 }
2701
2702 const uword A_n_rows = A.n_rows;
2703
2704 char jobvs = 'V'; // get Schur vectors (Z)
2705 char sort = 'N'; // do not sort eigenvalues/vectors
2706 blas_int* select = 0; // pointer to sorting function
2707 blas_int n = blas_int(A_n_rows);
2708 blas_int sdim = 0; // output for sorting
2709
2710 blas_int lwork = 3 * n; // workspace must be at least 3 * n (if set to -1, optimal size is output in work(0) and nothing else is done
2711
2712 podarray<eT> work( static_cast<uword>(lwork) );
2713 podarray<blas_int> bwork(A_n_rows);
2714
2715 blas_int info = 0;
2716
2717 Z.set_size(A_n_rows, A_n_rows);
2718 T = A;
2719
2720 podarray<eT> wr(A_n_rows); // output for eigenvalues
2721 podarray<eT> wi(A_n_rows); // output for eigenvalues
2722
2723 lapack::gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, wr.memptr(), wi.memptr(), Z.memptr(), &n, work.memptr(), &lwork, bwork.memptr(), &info);
2724
2725 return (info == 0);
2726 }
2727 #else
2728 {
2729 arma_ignore(Z);
2730 arma_ignore(T);
2731 arma_stop("schur_dec(): use of LAPACK needs to be enabled");
2732 return false;
2733 }
2734 #endif
2735 }
2736
2737
2738
2739 template<typename cT>
2740 inline
2741 bool
2742 auxlib::schur_dec(Mat<std::complex<cT> >& Z, Mat<std::complex<cT> >& T, const Mat<std::complex<cT> >& A)
2743 {
2744 arma_extra_debug_sigprint();
2745
2746 #if defined(ARMA_USE_LAPACK)
2747 {
2748 arma_debug_check( (A.is_square() == false), "schur_dec(): matrix A is not square" );
2749
2750 if(A.is_empty())
2751 {
2752 Z.reset();
2753 T.reset();
2754 return true;
2755 }
2756
2757 typedef std::complex<cT> eT;
2758
2759 const uword A_n_rows = A.n_rows;
2760
2761 char jobvs = 'V'; // get Schur vectors (Z)
2762 char sort = 'N'; // do not sort eigenvalues/vectors
2763 blas_int* select = 0; // pointer to sorting function
2764 blas_int n = blas_int(A_n_rows);
2765 blas_int sdim = 0; // output for sorting
2766
2767 blas_int lwork = 3 * n; // workspace must be at least 3 * n (if set to -1, optimal size is output in work(0) and nothing else is done
2768
2769 podarray<eT> work( static_cast<uword>(lwork) );
2770 podarray<blas_int> bwork(A_n_rows);
2771
2772 blas_int info = 0;
2773
2774 Z.set_size(A_n_rows, A_n_rows);
2775 T = A;
2776
2777 podarray<eT> w(A_n_rows); // output for eigenvalues
2778 podarray<cT> rwork(A_n_rows);
2779
2780 lapack::cx_gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, w.memptr(), Z.memptr(), &n, work.memptr(), &lwork, rwork.memptr(), bwork.memptr(), &info);
2781
2782 return (info == 0);
2783 }
2784 #else
2785 {
2786 arma_ignore(Z);
2787 arma_ignore(T);
2788 arma_stop("schur_dec(): use of LAPACK needs to be enabled");
2789 return false;
2790 }
2791 #endif
2792 }
2793
2794
2795
2796 //
2797 // syl (solution of the Sylvester equation AX + XB = C)
2798
2799 template<typename eT>
2800 inline
2801 bool
2802 auxlib::syl(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& B, const Mat<eT>& C)
2803 {
2804 arma_extra_debug_sigprint();
2805
2806 arma_debug_check
2807 (
2808 (A.is_square() == false) || (B.is_square() == false),
2809 "syl(): given matrix is not square"
2810 );
2811
2812 arma_debug_check
2813 (
2814 (C.n_rows != A.n_rows) || (C.n_cols != B.n_cols),
2815 "syl(): matrices are not conformant"
2816 );
2817
2818 if(A.is_empty() || B.is_empty() || C.is_empty())
2819 {
2820 X.reset();
2821 return true;
2822 }
2823
2824 #if defined(ARMA_USE_LAPACK)
2825 {
2826 Mat<eT> Z1, Z2, T1, T2;
2827
2828 const bool status_sd1 = auxlib::schur_dec(Z1, T1, A);
2829 const bool status_sd2 = auxlib::schur_dec(Z2, T2, B);
2830
2831 if( (status_sd1 == false) || (status_sd2 == false) )
2832 {
2833 return false;
2834 }
2835
2836 char trana = 'N';
2837 char tranb = 'N';
2838 blas_int isgn = +1;
2839 blas_int m = blas_int(T1.n_rows);
2840 blas_int n = blas_int(T2.n_cols);
2841
2842 eT scale = eT(0);
2843 blas_int info = 0;
2844
2845 Mat<eT> Y = trans(Z1) * C * Z2;
2846
2847 lapack::trsyl<eT>(&trana, &tranb, &isgn, &m, &n, T1.memptr(), &m, T2.memptr(), &n, Y.memptr(), &m, &scale, &info);
2848
2849 //Y /= scale;
2850 Y /= (-scale);
2851
2852 X = Z1 * Y * trans(Z2);
2853
2854 return (info >= 0);
2855 }
2856 #else
2857 {
2858 arma_stop("syl(): use of LAPACK needs to be enabled");
2859 return false;
2860 }
2861 #endif
2862 }
2863
2864
2865
2866 //
2867 // lyap (solution of the continuous Lyapunov equation AX + XA^H + Q = 0)
2868
2869 template<typename eT>
2870 inline
2871 bool
2872 auxlib::lyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
2873 {
2874 arma_extra_debug_sigprint();
2875
2876 arma_debug_check( (A.is_square() == false), "lyap(): matrix A is not square");
2877 arma_debug_check( (Q.is_square() == false), "lyap(): matrix Q is not square");
2878 arma_debug_check( (A.n_rows != Q.n_rows), "lyap(): matrices A and Q have different dimensions");
2879
2880 Mat<eT> htransA;
2881 op_htrans::apply_noalias(htransA, A);
2882
2883 const Mat<eT> mQ = -Q;
2884
2885 return auxlib::syl(X, A, htransA, mQ);
2886 }
2887
2888
2889
2890 //
2891 // dlyap (solution of the discrete Lyapunov equation AXA^H - X + Q = 0)
2892
2893 template<typename eT>
2894 inline
2895 bool
2896 auxlib::dlyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
2897 {
2898 arma_extra_debug_sigprint();
2899
2900 arma_debug_check( (A.is_square() == false), "dlyap(): matrix A is not square");
2901 arma_debug_check( (Q.is_square() == false), "dlyap(): matrix Q is not square");
2902 arma_debug_check( (A.n_rows != Q.n_rows), "dlyap(): matrices A and Q have different dimensions");
2903
2904 const Col<eT> vecQ = reshape(Q, Q.n_elem, 1);
2905
2906 const Mat<eT> M = eye< Mat<eT> >(Q.n_elem, Q.n_elem) - kron(conj(A), A);
2907
2908 Col<eT> vecX;
2909
2910 const bool status = solve(vecX, M, vecQ);
2911
2912 if(status == true)
2913 {
2914 X = reshape(vecX, Q.n_rows, Q.n_cols);
2915 return true;
2916 }
2917 else
2918 {
2919 X.reset();
2920 return false;
2921 }
2922 }
2923
2924
2925
2926 //! @}