Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-2.4.4/include/armadillo_bits/auxlib_meat.hpp @ 0:8b6102e2a9b0
Armadillo Library
author | maxzanoni76 <max.zanoni@eecs.qmul.ac.uk> |
---|---|
date | Wed, 11 Apr 2012 09:27:06 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:8b6102e2a9b0 |
---|---|
1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au) | |
2 // Copyright (C) 2008-2012 Conrad Sanderson | |
3 // Copyright (C) 2009 Edmund Highcock | |
4 // Copyright (C) 2011 James Sanders | |
5 // Copyright (C) 2011 Stanislav Funiak | |
6 // | |
7 // This file is part of the Armadillo C++ library. | |
8 // It is provided without any warranty of fitness | |
9 // for any purpose. You can redistribute this file | |
10 // and/or modify it under the terms of the GNU | |
11 // Lesser General Public License (LGPL) as published | |
12 // by the Free Software Foundation, either version 3 | |
13 // of the License or (at your option) any later version. | |
14 // (see http://www.opensource.org/licenses for more info) | |
15 | |
16 | |
17 //! \addtogroup auxlib | |
18 //! @{ | |
19 | |
20 | |
21 | |
22 //! immediate matrix inverse | |
23 template<typename eT, typename T1> | |
24 inline | |
25 bool | |
26 auxlib::inv(Mat<eT>& out, const Base<eT,T1>& X, const bool slow) | |
27 { | |
28 arma_extra_debug_sigprint(); | |
29 | |
30 out = X.get_ref(); | |
31 | |
32 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" ); | |
33 | |
34 bool status = false; | |
35 | |
36 const uword N = out.n_rows; | |
37 | |
38 if( (N <= 4) && (slow == false) ) | |
39 { | |
40 status = auxlib::inv_inplace_tinymat(out, N); | |
41 } | |
42 | |
43 if( (N > 4) || (status == false) ) | |
44 { | |
45 status = auxlib::inv_inplace_lapack(out); | |
46 } | |
47 | |
48 return status; | |
49 } | |
50 | |
51 | |
52 | |
53 template<typename eT> | |
54 inline | |
55 bool | |
56 auxlib::inv(Mat<eT>& out, const Mat<eT>& X, const bool slow) | |
57 { | |
58 arma_extra_debug_sigprint(); | |
59 | |
60 arma_debug_check( (X.is_square() == false), "inv(): given matrix is not square" ); | |
61 | |
62 bool status = false; | |
63 | |
64 const uword N = X.n_rows; | |
65 | |
66 if( (N <= 4) && (slow == false) ) | |
67 { | |
68 status = (&out != &X) ? auxlib::inv_noalias_tinymat(out, X, N) : auxlib::inv_inplace_tinymat(out, N); | |
69 } | |
70 | |
71 if( (N > 4) || (status == false) ) | |
72 { | |
73 out = X; | |
74 status = auxlib::inv_inplace_lapack(out); | |
75 } | |
76 | |
77 return status; | |
78 } | |
79 | |
80 | |
81 | |
82 template<typename eT> | |
83 inline | |
84 bool | |
85 auxlib::inv_noalias_tinymat(Mat<eT>& out, const Mat<eT>& X, const uword N) | |
86 { | |
87 arma_extra_debug_sigprint(); | |
88 | |
89 bool det_ok = true; | |
90 | |
91 out.set_size(N,N); | |
92 | |
93 switch(N) | |
94 { | |
95 case 1: | |
96 { | |
97 out[0] = eT(1) / X[0]; | |
98 }; | |
99 break; | |
100 | |
101 case 2: | |
102 { | |
103 const eT* Xm = X.memptr(); | |
104 | |
105 const eT a = Xm[pos<0,0>::n2]; | |
106 const eT b = Xm[pos<0,1>::n2]; | |
107 const eT c = Xm[pos<1,0>::n2]; | |
108 const eT d = Xm[pos<1,1>::n2]; | |
109 | |
110 const eT tmp_det = (a*d - b*c); | |
111 | |
112 if(tmp_det != eT(0)) | |
113 { | |
114 eT* outm = out.memptr(); | |
115 | |
116 outm[pos<0,0>::n2] = d / tmp_det; | |
117 outm[pos<0,1>::n2] = -b / tmp_det; | |
118 outm[pos<1,0>::n2] = -c / tmp_det; | |
119 outm[pos<1,1>::n2] = a / tmp_det; | |
120 } | |
121 else | |
122 { | |
123 det_ok = false; | |
124 } | |
125 }; | |
126 break; | |
127 | |
128 case 3: | |
129 { | |
130 const eT* X_col0 = X.colptr(0); | |
131 const eT a11 = X_col0[0]; | |
132 const eT a21 = X_col0[1]; | |
133 const eT a31 = X_col0[2]; | |
134 | |
135 const eT* X_col1 = X.colptr(1); | |
136 const eT a12 = X_col1[0]; | |
137 const eT a22 = X_col1[1]; | |
138 const eT a32 = X_col1[2]; | |
139 | |
140 const eT* X_col2 = X.colptr(2); | |
141 const eT a13 = X_col2[0]; | |
142 const eT a23 = X_col2[1]; | |
143 const eT a33 = X_col2[2]; | |
144 | |
145 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13); | |
146 | |
147 if(tmp_det != eT(0)) | |
148 { | |
149 eT* out_col0 = out.colptr(0); | |
150 out_col0[0] = (a33*a22 - a32*a23) / tmp_det; | |
151 out_col0[1] = -(a33*a21 - a31*a23) / tmp_det; | |
152 out_col0[2] = (a32*a21 - a31*a22) / tmp_det; | |
153 | |
154 eT* out_col1 = out.colptr(1); | |
155 out_col1[0] = -(a33*a12 - a32*a13) / tmp_det; | |
156 out_col1[1] = (a33*a11 - a31*a13) / tmp_det; | |
157 out_col1[2] = -(a32*a11 - a31*a12) / tmp_det; | |
158 | |
159 eT* out_col2 = out.colptr(2); | |
160 out_col2[0] = (a23*a12 - a22*a13) / tmp_det; | |
161 out_col2[1] = -(a23*a11 - a21*a13) / tmp_det; | |
162 out_col2[2] = (a22*a11 - a21*a12) / tmp_det; | |
163 } | |
164 else | |
165 { | |
166 det_ok = false; | |
167 } | |
168 }; | |
169 break; | |
170 | |
171 case 4: | |
172 { | |
173 const eT tmp_det = det(X); | |
174 | |
175 if(tmp_det != eT(0)) | |
176 { | |
177 const eT* Xm = X.memptr(); | |
178 eT* outm = out.memptr(); | |
179 | |
180 outm[pos<0,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; | |
181 outm[pos<1,0>::n4] = ( Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; | |
182 outm[pos<2,0>::n4] = ( Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; | |
183 outm[pos<3,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det; | |
184 | |
185 outm[pos<0,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; | |
186 outm[pos<1,1>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; | |
187 outm[pos<2,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; | |
188 outm[pos<3,1>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det; | |
189 | |
190 outm[pos<0,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; | |
191 outm[pos<1,2>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; | |
192 outm[pos<2,2>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det; | |
193 outm[pos<3,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det; | |
194 | |
195 outm[pos<0,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det; | |
196 outm[pos<1,3>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det; | |
197 outm[pos<2,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] ) / tmp_det; | |
198 outm[pos<3,3>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] ) / tmp_det; | |
199 } | |
200 else | |
201 { | |
202 det_ok = false; | |
203 } | |
204 }; | |
205 break; | |
206 | |
207 default: | |
208 ; | |
209 } | |
210 | |
211 return det_ok; | |
212 } | |
213 | |
214 | |
215 | |
216 template<typename eT> | |
217 inline | |
218 bool | |
219 auxlib::inv_inplace_tinymat(Mat<eT>& X, const uword N) | |
220 { | |
221 arma_extra_debug_sigprint(); | |
222 | |
223 bool det_ok = true; | |
224 | |
225 // for more info, see: | |
226 // http://www.dr-lex.34sp.com/random/matrix_inv.html | |
227 // http://www.cvl.iis.u-tokyo.ac.jp/~miyazaki/tech/teche23.html | |
228 // http://www.euclideanspace.com/maths/algebra/matrix/functions/inverse/fourD/index.htm | |
229 // http://www.geometrictools.com//LibFoundation/Mathematics/Wm4Matrix4.inl | |
230 | |
231 switch(N) | |
232 { | |
233 case 1: | |
234 { | |
235 X[0] = eT(1) / X[0]; | |
236 }; | |
237 break; | |
238 | |
239 case 2: | |
240 { | |
241 const eT a = X[pos<0,0>::n2]; | |
242 const eT b = X[pos<0,1>::n2]; | |
243 const eT c = X[pos<1,0>::n2]; | |
244 const eT d = X[pos<1,1>::n2]; | |
245 | |
246 const eT tmp_det = (a*d - b*c); | |
247 | |
248 if(tmp_det != eT(0)) | |
249 { | |
250 X[pos<0,0>::n2] = d / tmp_det; | |
251 X[pos<0,1>::n2] = -b / tmp_det; | |
252 X[pos<1,0>::n2] = -c / tmp_det; | |
253 X[pos<1,1>::n2] = a / tmp_det; | |
254 } | |
255 else | |
256 { | |
257 det_ok = false; | |
258 } | |
259 }; | |
260 break; | |
261 | |
262 case 3: | |
263 { | |
264 eT* X_col0 = X.colptr(0); | |
265 eT* X_col1 = X.colptr(1); | |
266 eT* X_col2 = X.colptr(2); | |
267 | |
268 const eT a11 = X_col0[0]; | |
269 const eT a21 = X_col0[1]; | |
270 const eT a31 = X_col0[2]; | |
271 | |
272 const eT a12 = X_col1[0]; | |
273 const eT a22 = X_col1[1]; | |
274 const eT a32 = X_col1[2]; | |
275 | |
276 const eT a13 = X_col2[0]; | |
277 const eT a23 = X_col2[1]; | |
278 const eT a33 = X_col2[2]; | |
279 | |
280 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13); | |
281 | |
282 if(tmp_det != eT(0)) | |
283 { | |
284 X_col0[0] = (a33*a22 - a32*a23) / tmp_det; | |
285 X_col0[1] = -(a33*a21 - a31*a23) / tmp_det; | |
286 X_col0[2] = (a32*a21 - a31*a22) / tmp_det; | |
287 | |
288 X_col1[0] = -(a33*a12 - a32*a13) / tmp_det; | |
289 X_col1[1] = (a33*a11 - a31*a13) / tmp_det; | |
290 X_col1[2] = -(a32*a11 - a31*a12) / tmp_det; | |
291 | |
292 X_col2[0] = (a23*a12 - a22*a13) / tmp_det; | |
293 X_col2[1] = -(a23*a11 - a21*a13) / tmp_det; | |
294 X_col2[2] = (a22*a11 - a21*a12) / tmp_det; | |
295 } | |
296 else | |
297 { | |
298 det_ok = false; | |
299 } | |
300 }; | |
301 break; | |
302 | |
303 case 4: | |
304 { | |
305 const eT tmp_det = det(X); | |
306 | |
307 if(tmp_det != eT(0)) | |
308 { | |
309 const Mat<eT> A(X); | |
310 | |
311 const eT* Am = A.memptr(); | |
312 eT* Xm = X.memptr(); | |
313 | |
314 Xm[pos<0,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] - Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] + Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; | |
315 Xm[pos<1,0>::n4] = ( Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; | |
316 Xm[pos<2,0>::n4] = ( Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det; | |
317 Xm[pos<3,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det; | |
318 | |
319 Xm[pos<0,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] + Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] - Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; | |
320 Xm[pos<1,1>::n4] = ( Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; | |
321 Xm[pos<2,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det; | |
322 Xm[pos<3,1>::n4] = ( Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det; | |
323 | |
324 Xm[pos<0,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] + Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; | |
325 Xm[pos<1,2>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det; | |
326 Xm[pos<2,2>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det; | |
327 Xm[pos<3,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det; | |
328 | |
329 Xm[pos<0,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] + Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det; | |
330 Xm[pos<1,3>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det; | |
331 Xm[pos<2,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] ) / tmp_det; | |
332 Xm[pos<3,3>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] ) / tmp_det; | |
333 } | |
334 else | |
335 { | |
336 det_ok = false; | |
337 } | |
338 }; | |
339 break; | |
340 | |
341 default: | |
342 ; | |
343 } | |
344 | |
345 return det_ok; | |
346 } | |
347 | |
348 | |
349 | |
350 template<typename eT> | |
351 inline | |
352 bool | |
353 auxlib::inv_inplace_lapack(Mat<eT>& out) | |
354 { | |
355 arma_extra_debug_sigprint(); | |
356 | |
357 if(out.is_empty()) | |
358 { | |
359 return true; | |
360 } | |
361 | |
362 #if defined(ARMA_USE_ATLAS) | |
363 { | |
364 podarray<int> ipiv(out.n_rows); | |
365 | |
366 int info = atlas::clapack_getrf(atlas::CblasColMajor, out.n_rows, out.n_cols, out.memptr(), out.n_rows, ipiv.memptr()); | |
367 | |
368 if(info == 0) | |
369 { | |
370 info = atlas::clapack_getri(atlas::CblasColMajor, out.n_rows, out.memptr(), out.n_rows, ipiv.memptr()); | |
371 } | |
372 | |
373 return (info == 0); | |
374 } | |
375 #elif defined(ARMA_USE_LAPACK) | |
376 { | |
377 blas_int n_rows = out.n_rows; | |
378 blas_int n_cols = out.n_cols; | |
379 blas_int info = 0; | |
380 | |
381 podarray<blas_int> ipiv(out.n_rows); | |
382 | |
383 // 84 was empirically found -- it is the maximum value suggested by LAPACK (as provided by ATLAS v3.6) | |
384 // based on tests with various matrix types on 32-bit and 64-bit machines | |
385 // | |
386 // the "work" array is deliberately long so that a secondary (time-consuming) | |
387 // memory allocation is avoided, if possible | |
388 | |
389 blas_int work_len = (std::max)(blas_int(1), n_rows*84); | |
390 podarray<eT> work( static_cast<uword>(work_len) ); | |
391 | |
392 lapack::getrf(&n_rows, &n_cols, out.memptr(), &n_rows, ipiv.memptr(), &info); | |
393 | |
394 if(info == 0) | |
395 { | |
396 // query for optimum size of work_len | |
397 | |
398 blas_int work_len_tmp = -1; | |
399 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &work_len_tmp, &info); | |
400 | |
401 if(info == 0) | |
402 { | |
403 blas_int proposed_work_len = static_cast<blas_int>(access::tmp_real(work[0])); | |
404 | |
405 // if necessary, allocate more memory | |
406 if(work_len < proposed_work_len) | |
407 { | |
408 work_len = proposed_work_len; | |
409 work.set_size( static_cast<uword>(work_len) ); | |
410 } | |
411 } | |
412 | |
413 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &work_len, &info); | |
414 } | |
415 | |
416 return (info == 0); | |
417 } | |
418 #else | |
419 { | |
420 arma_ignore(out); | |
421 arma_stop("inv(): use of ATLAS or LAPACK needs to be enabled"); | |
422 return false; | |
423 } | |
424 #endif | |
425 } | |
426 | |
427 | |
428 | |
429 template<typename eT, typename T1> | |
430 inline | |
431 bool | |
432 auxlib::inv_tr(Mat<eT>& out, const Base<eT,T1>& X, const uword layout) | |
433 { | |
434 arma_extra_debug_sigprint(); | |
435 | |
436 out = X.get_ref(); | |
437 | |
438 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" ); | |
439 | |
440 if(out.is_empty()) | |
441 { | |
442 return true; | |
443 } | |
444 | |
445 bool status; | |
446 | |
447 #if defined(ARMA_USE_LAPACK) | |
448 { | |
449 char uplo = (layout == 0) ? 'U' : 'L'; | |
450 char diag = 'N'; | |
451 blas_int n = blas_int(out.n_rows); | |
452 blas_int info = 0; | |
453 | |
454 lapack::trtri(&uplo, &diag, &n, out.memptr(), &n, &info); | |
455 | |
456 status = (info == 0); | |
457 } | |
458 #else | |
459 { | |
460 arma_ignore(layout); | |
461 arma_stop("inv(): use of LAPACK needs to be enabled"); | |
462 status = false; | |
463 } | |
464 #endif | |
465 | |
466 | |
467 if(status == true) | |
468 { | |
469 if(layout == 0) | |
470 { | |
471 // upper triangular | |
472 out = trimatu(out); | |
473 } | |
474 else | |
475 { | |
476 // lower triangular | |
477 out = trimatl(out); | |
478 } | |
479 } | |
480 | |
481 return status; | |
482 } | |
483 | |
484 | |
485 | |
486 template<typename eT, typename T1> | |
487 inline | |
488 bool | |
489 auxlib::inv_sym(Mat<eT>& out, const Base<eT,T1>& X, const uword layout) | |
490 { | |
491 arma_extra_debug_sigprint(); | |
492 | |
493 out = X.get_ref(); | |
494 | |
495 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" ); | |
496 | |
497 if(out.is_empty()) | |
498 { | |
499 return true; | |
500 } | |
501 | |
502 bool status; | |
503 | |
504 #if defined(ARMA_USE_LAPACK) | |
505 { | |
506 char uplo = (layout == 0) ? 'U' : 'L'; | |
507 blas_int n = blas_int(out.n_rows); | |
508 blas_int lwork = n*n; // TODO: use lwork = -1 to determine optimal size | |
509 blas_int info = 0; | |
510 | |
511 podarray<blas_int> ipiv; | |
512 ipiv.set_size(out.n_rows); | |
513 | |
514 podarray<eT> work; | |
515 work.set_size( uword(lwork) ); | |
516 | |
517 lapack::sytrf(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &lwork, &info); | |
518 | |
519 status = (info == 0); | |
520 | |
521 if(status == true) | |
522 { | |
523 lapack::sytri(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &info); | |
524 | |
525 out = (layout == 0) ? symmatu(out) : symmatl(out); | |
526 | |
527 status = (info == 0); | |
528 } | |
529 } | |
530 #else | |
531 { | |
532 arma_ignore(layout); | |
533 arma_stop("inv(): use of LAPACK needs to be enabled"); | |
534 status = false; | |
535 } | |
536 #endif | |
537 | |
538 return status; | |
539 } | |
540 | |
541 | |
542 | |
543 template<typename eT, typename T1> | |
544 inline | |
545 bool | |
546 auxlib::inv_sympd(Mat<eT>& out, const Base<eT,T1>& X, const uword layout) | |
547 { | |
548 arma_extra_debug_sigprint(); | |
549 | |
550 out = X.get_ref(); | |
551 | |
552 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" ); | |
553 | |
554 if(out.is_empty()) | |
555 { | |
556 return true; | |
557 } | |
558 | |
559 bool status; | |
560 | |
561 #if defined(ARMA_USE_LAPACK) | |
562 { | |
563 char uplo = (layout == 0) ? 'U' : 'L'; | |
564 blas_int n = blas_int(out.n_rows); | |
565 blas_int info = 0; | |
566 | |
567 lapack::potrf(&uplo, &n, out.memptr(), &n, &info); | |
568 | |
569 status = (info == 0); | |
570 | |
571 if(status == true) | |
572 { | |
573 lapack::potri(&uplo, &n, out.memptr(), &n, &info); | |
574 | |
575 out = (layout == 0) ? symmatu(out) : symmatl(out); | |
576 | |
577 status = (info == 0); | |
578 } | |
579 } | |
580 #else | |
581 { | |
582 arma_ignore(layout); | |
583 arma_stop("inv(): use of LAPACK needs to be enabled"); | |
584 status = false; | |
585 } | |
586 #endif | |
587 | |
588 return status; | |
589 } | |
590 | |
591 | |
592 | |
593 template<typename eT, typename T1> | |
594 inline | |
595 eT | |
596 auxlib::det(const Base<eT,T1>& X, const bool slow) | |
597 { | |
598 const unwrap<T1> tmp(X.get_ref()); | |
599 const Mat<eT>& A = tmp.M; | |
600 | |
601 arma_debug_check( (A.is_square() == false), "det(): matrix is not square" ); | |
602 | |
603 const bool make_copy = (is_Mat<T1>::value == true) ? true : false; | |
604 | |
605 if(slow == false) | |
606 { | |
607 const uword N = A.n_rows; | |
608 | |
609 switch(N) | |
610 { | |
611 case 0: | |
612 case 1: | |
613 case 2: | |
614 return auxlib::det_tinymat(A, N); | |
615 break; | |
616 | |
617 case 3: | |
618 case 4: | |
619 { | |
620 const eT tmp_det = auxlib::det_tinymat(A, N); | |
621 return (tmp_det != eT(0)) ? tmp_det : auxlib::det_lapack(A, make_copy); | |
622 } | |
623 break; | |
624 | |
625 default: | |
626 return auxlib::det_lapack(A, make_copy); | |
627 } | |
628 } | |
629 else | |
630 { | |
631 return auxlib::det_lapack(A, make_copy); | |
632 } | |
633 } | |
634 | |
635 | |
636 | |
637 template<typename eT> | |
638 inline | |
639 eT | |
640 auxlib::det_tinymat(const Mat<eT>& X, const uword N) | |
641 { | |
642 arma_extra_debug_sigprint(); | |
643 | |
644 switch(N) | |
645 { | |
646 case 0: | |
647 return eT(1); | |
648 break; | |
649 | |
650 case 1: | |
651 return X[0]; | |
652 break; | |
653 | |
654 case 2: | |
655 { | |
656 const eT* Xm = X.memptr(); | |
657 | |
658 return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] ); | |
659 } | |
660 break; | |
661 | |
662 case 3: | |
663 { | |
664 // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2); | |
665 // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0); | |
666 // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1); | |
667 // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2); | |
668 // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0); | |
669 // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1); | |
670 // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6); | |
671 | |
672 const eT* a_col0 = X.colptr(0); | |
673 const eT a11 = a_col0[0]; | |
674 const eT a21 = a_col0[1]; | |
675 const eT a31 = a_col0[2]; | |
676 | |
677 const eT* a_col1 = X.colptr(1); | |
678 const eT a12 = a_col1[0]; | |
679 const eT a22 = a_col1[1]; | |
680 const eT a32 = a_col1[2]; | |
681 | |
682 const eT* a_col2 = X.colptr(2); | |
683 const eT a13 = a_col2[0]; | |
684 const eT a23 = a_col2[1]; | |
685 const eT a33 = a_col2[2]; | |
686 | |
687 return ( a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13) ); | |
688 } | |
689 break; | |
690 | |
691 case 4: | |
692 { | |
693 const eT* Xm = X.memptr(); | |
694 | |
695 const eT val = \ | |
696 Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \ | |
697 - Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \ | |
698 - Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \ | |
699 + Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \ | |
700 + Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \ | |
701 - Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \ | |
702 - Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \ | |
703 + Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \ | |
704 + Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \ | |
705 - Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \ | |
706 - Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \ | |
707 + Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \ | |
708 + Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \ | |
709 - Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \ | |
710 - Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \ | |
711 + Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \ | |
712 + Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \ | |
713 - Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \ | |
714 - Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \ | |
715 + Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \ | |
716 + Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \ | |
717 - Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \ | |
718 - Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \ | |
719 + Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \ | |
720 ; | |
721 | |
722 return val; | |
723 } | |
724 break; | |
725 | |
726 default: | |
727 return eT(0); | |
728 ; | |
729 } | |
730 } | |
731 | |
732 | |
733 | |
734 //! immediate determinant of a matrix using ATLAS or LAPACK | |
735 template<typename eT> | |
736 inline | |
737 eT | |
738 auxlib::det_lapack(const Mat<eT>& X, const bool make_copy) | |
739 { | |
740 arma_extra_debug_sigprint(); | |
741 | |
742 Mat<eT> X_copy; | |
743 | |
744 if(make_copy == true) | |
745 { | |
746 X_copy = X; | |
747 } | |
748 | |
749 Mat<eT>& tmp = (make_copy == true) ? X_copy : const_cast< Mat<eT>& >(X); | |
750 | |
751 if(tmp.is_empty()) | |
752 { | |
753 return eT(1); | |
754 } | |
755 | |
756 | |
757 #if defined(ARMA_USE_ATLAS) | |
758 { | |
759 podarray<int> ipiv(tmp.n_rows); | |
760 | |
761 //const int info = | |
762 atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr()); | |
763 | |
764 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero | |
765 eT val = tmp.at(0,0); | |
766 for(uword i=1; i < tmp.n_rows; ++i) | |
767 { | |
768 val *= tmp.at(i,i); | |
769 } | |
770 | |
771 int sign = +1; | |
772 for(uword i=0; i < tmp.n_rows; ++i) | |
773 { | |
774 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0 | |
775 { | |
776 sign *= -1; | |
777 } | |
778 } | |
779 | |
780 return ( (sign < 0) ? -val : val ); | |
781 } | |
782 #elif defined(ARMA_USE_LAPACK) | |
783 { | |
784 podarray<blas_int> ipiv(tmp.n_rows); | |
785 | |
786 blas_int info = 0; | |
787 blas_int n_rows = blas_int(tmp.n_rows); | |
788 blas_int n_cols = blas_int(tmp.n_cols); | |
789 | |
790 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info); | |
791 | |
792 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero | |
793 eT val = tmp.at(0,0); | |
794 for(uword i=1; i < tmp.n_rows; ++i) | |
795 { | |
796 val *= tmp.at(i,i); | |
797 } | |
798 | |
799 blas_int sign = +1; | |
800 for(uword i=0; i < tmp.n_rows; ++i) | |
801 { | |
802 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1 | |
803 { | |
804 sign *= -1; | |
805 } | |
806 } | |
807 | |
808 return ( (sign < 0) ? -val : val ); | |
809 } | |
810 #else | |
811 { | |
812 arma_ignore(X); | |
813 arma_ignore(make_copy); | |
814 arma_ignore(tmp); | |
815 arma_stop("det(): use of ATLAS or LAPACK needs to be enabled"); | |
816 return eT(0); | |
817 } | |
818 #endif | |
819 } | |
820 | |
821 | |
822 | |
823 //! immediate log determinant of a matrix using ATLAS or LAPACK | |
824 template<typename eT, typename T1> | |
825 inline | |
826 bool | |
827 auxlib::log_det(eT& out_val, typename get_pod_type<eT>::result& out_sign, const Base<eT,T1>& X) | |
828 { | |
829 arma_extra_debug_sigprint(); | |
830 | |
831 typedef typename get_pod_type<eT>::result T; | |
832 | |
833 #if defined(ARMA_USE_ATLAS) | |
834 { | |
835 Mat<eT> tmp(X.get_ref()); | |
836 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" ); | |
837 | |
838 if(tmp.is_empty()) | |
839 { | |
840 out_val = eT(0); | |
841 out_sign = T(1); | |
842 return true; | |
843 } | |
844 | |
845 podarray<int> ipiv(tmp.n_rows); | |
846 | |
847 const int info = atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr()); | |
848 | |
849 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero | |
850 | |
851 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1; | |
852 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) ); | |
853 | |
854 for(uword i=1; i < tmp.n_rows; ++i) | |
855 { | |
856 const eT x = tmp.at(i,i); | |
857 | |
858 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1; | |
859 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); | |
860 } | |
861 | |
862 for(uword i=0; i < tmp.n_rows; ++i) | |
863 { | |
864 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0 | |
865 { | |
866 sign *= -1; | |
867 } | |
868 } | |
869 | |
870 out_val = val; | |
871 out_sign = T(sign); | |
872 | |
873 return (info == 0); | |
874 } | |
875 #elif defined(ARMA_USE_LAPACK) | |
876 { | |
877 Mat<eT> tmp(X.get_ref()); | |
878 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" ); | |
879 | |
880 if(tmp.is_empty()) | |
881 { | |
882 out_val = eT(0); | |
883 out_sign = T(1); | |
884 return true; | |
885 } | |
886 | |
887 podarray<blas_int> ipiv(tmp.n_rows); | |
888 | |
889 blas_int info = 0; | |
890 blas_int n_rows = blas_int(tmp.n_rows); | |
891 blas_int n_cols = blas_int(tmp.n_cols); | |
892 | |
893 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info); | |
894 | |
895 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero | |
896 | |
897 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1; | |
898 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) ); | |
899 | |
900 for(uword i=1; i < tmp.n_rows; ++i) | |
901 { | |
902 const eT x = tmp.at(i,i); | |
903 | |
904 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1; | |
905 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); | |
906 } | |
907 | |
908 for(uword i=0; i < tmp.n_rows; ++i) | |
909 { | |
910 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1 | |
911 { | |
912 sign *= -1; | |
913 } | |
914 } | |
915 | |
916 out_val = val; | |
917 out_sign = T(sign); | |
918 | |
919 return (info == 0); | |
920 } | |
921 #else | |
922 { | |
923 out_val = eT(0); | |
924 out_sign = T(0); | |
925 | |
926 arma_stop("log_det(): use of ATLAS or LAPACK needs to be enabled"); | |
927 | |
928 return false; | |
929 } | |
930 #endif | |
931 } | |
932 | |
933 | |
934 | |
935 //! immediate LU decomposition of a matrix using ATLAS or LAPACK | |
936 template<typename eT, typename T1> | |
937 inline | |
938 bool | |
939 auxlib::lu(Mat<eT>& L, Mat<eT>& U, podarray<blas_int>& ipiv, const Base<eT,T1>& X) | |
940 { | |
941 arma_extra_debug_sigprint(); | |
942 | |
943 U = X.get_ref(); | |
944 | |
945 const uword U_n_rows = U.n_rows; | |
946 const uword U_n_cols = U.n_cols; | |
947 | |
948 if(U.is_empty()) | |
949 { | |
950 L.set_size(U_n_rows, 0); | |
951 U.set_size(0, U_n_cols); | |
952 ipiv.reset(); | |
953 return true; | |
954 } | |
955 | |
956 #if defined(ARMA_USE_ATLAS) || defined(ARMA_USE_LAPACK) | |
957 { | |
958 bool status; | |
959 | |
960 #if defined(ARMA_USE_ATLAS) | |
961 { | |
962 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) ); | |
963 | |
964 int info = atlas::clapack_getrf(atlas::CblasColMajor, U_n_rows, U_n_cols, U.memptr(), U_n_rows, ipiv.memptr()); | |
965 | |
966 status = (info == 0); | |
967 } | |
968 #elif defined(ARMA_USE_LAPACK) | |
969 { | |
970 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) ); | |
971 | |
972 blas_int info = 0; | |
973 | |
974 blas_int n_rows = U_n_rows; | |
975 blas_int n_cols = U_n_cols; | |
976 | |
977 | |
978 lapack::getrf(&n_rows, &n_cols, U.memptr(), &n_rows, ipiv.memptr(), &info); | |
979 | |
980 // take into account that Fortran counts from 1 | |
981 arrayops::inplace_minus(ipiv.memptr(), blas_int(1), ipiv.n_elem); | |
982 | |
983 status = (info == 0); | |
984 } | |
985 #endif | |
986 | |
987 L.copy_size(U); | |
988 | |
989 for(uword col=0; col < U_n_cols; ++col) | |
990 { | |
991 for(uword row=0; (row < col) && (row < U_n_rows); ++row) | |
992 { | |
993 L.at(row,col) = eT(0); | |
994 } | |
995 | |
996 if( L.in_range(col,col) == true ) | |
997 { | |
998 L.at(col,col) = eT(1); | |
999 } | |
1000 | |
1001 for(uword row = (col+1); row < U_n_rows; ++row) | |
1002 { | |
1003 L.at(row,col) = U.at(row,col); | |
1004 U.at(row,col) = eT(0); | |
1005 } | |
1006 } | |
1007 | |
1008 return status; | |
1009 } | |
1010 #else | |
1011 { | |
1012 arma_stop("lu(): use of ATLAS or LAPACK needs to be enabled"); | |
1013 | |
1014 return false; | |
1015 } | |
1016 #endif | |
1017 } | |
1018 | |
1019 | |
1020 | |
1021 template<typename eT, typename T1> | |
1022 inline | |
1023 bool | |
1024 auxlib::lu(Mat<eT>& L, Mat<eT>& U, Mat<eT>& P, const Base<eT,T1>& X) | |
1025 { | |
1026 arma_extra_debug_sigprint(); | |
1027 | |
1028 podarray<blas_int> ipiv1; | |
1029 const bool status = auxlib::lu(L, U, ipiv1, X); | |
1030 | |
1031 if(status == true) | |
1032 { | |
1033 if(U.is_empty()) | |
1034 { | |
1035 // L and U have been already set to the correct empty matrices | |
1036 P.eye(L.n_rows, L.n_rows); | |
1037 return true; | |
1038 } | |
1039 | |
1040 const uword n = ipiv1.n_elem; | |
1041 const uword P_rows = U.n_rows; | |
1042 | |
1043 podarray<blas_int> ipiv2(P_rows); | |
1044 | |
1045 const blas_int* ipiv1_mem = ipiv1.memptr(); | |
1046 blas_int* ipiv2_mem = ipiv2.memptr(); | |
1047 | |
1048 for(uword i=0; i<P_rows; ++i) | |
1049 { | |
1050 ipiv2_mem[i] = blas_int(i); | |
1051 } | |
1052 | |
1053 for(uword i=0; i<n; ++i) | |
1054 { | |
1055 const uword k = static_cast<uword>(ipiv1_mem[i]); | |
1056 | |
1057 if( ipiv2_mem[i] != ipiv2_mem[k] ) | |
1058 { | |
1059 std::swap( ipiv2_mem[i], ipiv2_mem[k] ); | |
1060 } | |
1061 } | |
1062 | |
1063 P.zeros(P_rows, P_rows); | |
1064 | |
1065 for(uword row=0; row<P_rows; ++row) | |
1066 { | |
1067 P.at(row, static_cast<uword>(ipiv2_mem[row])) = eT(1); | |
1068 } | |
1069 | |
1070 if(L.n_cols > U.n_rows) | |
1071 { | |
1072 L.shed_cols(U.n_rows, L.n_cols-1); | |
1073 } | |
1074 | |
1075 if(U.n_rows > L.n_cols) | |
1076 { | |
1077 U.shed_rows(L.n_cols, U.n_rows-1); | |
1078 } | |
1079 } | |
1080 | |
1081 return status; | |
1082 } | |
1083 | |
1084 | |
1085 | |
1086 template<typename eT, typename T1> | |
1087 inline | |
1088 bool | |
1089 auxlib::lu(Mat<eT>& L, Mat<eT>& U, const Base<eT,T1>& X) | |
1090 { | |
1091 arma_extra_debug_sigprint(); | |
1092 | |
1093 podarray<blas_int> ipiv1; | |
1094 const bool status = auxlib::lu(L, U, ipiv1, X); | |
1095 | |
1096 if(status == true) | |
1097 { | |
1098 if(U.is_empty()) | |
1099 { | |
1100 // L and U have been already set to the correct empty matrices | |
1101 return true; | |
1102 } | |
1103 | |
1104 const uword n = ipiv1.n_elem; | |
1105 const uword P_rows = U.n_rows; | |
1106 | |
1107 podarray<blas_int> ipiv2(P_rows); | |
1108 | |
1109 const blas_int* ipiv1_mem = ipiv1.memptr(); | |
1110 blas_int* ipiv2_mem = ipiv2.memptr(); | |
1111 | |
1112 for(uword i=0; i<P_rows; ++i) | |
1113 { | |
1114 ipiv2_mem[i] = blas_int(i); | |
1115 } | |
1116 | |
1117 for(uword i=0; i<n; ++i) | |
1118 { | |
1119 const uword k = static_cast<uword>(ipiv1_mem[i]); | |
1120 | |
1121 if( ipiv2_mem[i] != ipiv2_mem[k] ) | |
1122 { | |
1123 std::swap( ipiv2_mem[i], ipiv2_mem[k] ); | |
1124 L.swap_rows( static_cast<uword>(ipiv2_mem[i]), static_cast<uword>(ipiv2_mem[k]) ); | |
1125 } | |
1126 } | |
1127 | |
1128 if(L.n_cols > U.n_rows) | |
1129 { | |
1130 L.shed_cols(U.n_rows, L.n_cols-1); | |
1131 } | |
1132 | |
1133 if(U.n_rows > L.n_cols) | |
1134 { | |
1135 U.shed_rows(L.n_cols, U.n_rows-1); | |
1136 } | |
1137 } | |
1138 | |
1139 return status; | |
1140 } | |
1141 | |
1142 | |
1143 | |
1144 //! immediate eigenvalues of a symmetric real matrix using LAPACK | |
1145 template<typename eT, typename T1> | |
1146 inline | |
1147 bool | |
1148 auxlib::eig_sym(Col<eT>& eigval, const Base<eT,T1>& X) | |
1149 { | |
1150 arma_extra_debug_sigprint(); | |
1151 | |
1152 #if defined(ARMA_USE_LAPACK) | |
1153 { | |
1154 Mat<eT> A(X.get_ref()); | |
1155 | |
1156 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not square"); | |
1157 | |
1158 if(A.is_empty()) | |
1159 { | |
1160 eigval.reset(); | |
1161 return true; | |
1162 } | |
1163 | |
1164 // rudimentary "better-than-nothing" test for symmetry | |
1165 //arma_debug_check( (A.at(A.n_rows-1, 0) != A.at(0, A.n_cols-1)), "auxlib::eig(): given matrix is not symmetric" ); | |
1166 | |
1167 char jobz = 'N'; | |
1168 char uplo = 'U'; | |
1169 | |
1170 blas_int n_rows = A.n_rows; | |
1171 blas_int lwork = (std::max)(blas_int(1), 3*n_rows-1); | |
1172 | |
1173 eigval.set_size( static_cast<uword>(n_rows) ); | |
1174 podarray<eT> work( static_cast<uword>(lwork) ); | |
1175 | |
1176 blas_int info; | |
1177 | |
1178 arma_extra_debug_print("lapack::syev()"); | |
1179 lapack::syev(&jobz, &uplo, &n_rows, A.memptr(), &n_rows, eigval.memptr(), work.memptr(), &lwork, &info); | |
1180 | |
1181 return (info == 0); | |
1182 } | |
1183 #else | |
1184 { | |
1185 arma_ignore(eigval); | |
1186 arma_ignore(X); | |
1187 arma_stop("eig_sym(): use of LAPACK needs to be enabled"); | |
1188 return false; | |
1189 } | |
1190 #endif | |
1191 } | |
1192 | |
1193 | |
1194 | |
1195 //! immediate eigenvalues of a hermitian complex matrix using LAPACK | |
1196 template<typename T, typename T1> | |
1197 inline | |
1198 bool | |
1199 auxlib::eig_sym(Col<T>& eigval, const Base<std::complex<T>,T1>& X) | |
1200 { | |
1201 arma_extra_debug_sigprint(); | |
1202 | |
1203 typedef typename std::complex<T> eT; | |
1204 | |
1205 #if defined(ARMA_USE_LAPACK) | |
1206 { | |
1207 Mat<eT> A(X.get_ref()); | |
1208 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not hermitian"); | |
1209 | |
1210 if(A.is_empty()) | |
1211 { | |
1212 eigval.reset(); | |
1213 return true; | |
1214 } | |
1215 | |
1216 char jobz = 'N'; | |
1217 char uplo = 'U'; | |
1218 | |
1219 blas_int n_rows = A.n_rows; | |
1220 blas_int lda = A.n_rows; | |
1221 blas_int lwork = (std::max)(blas_int(1), 2*n_rows - 1); // TODO: automatically find best size of lwork | |
1222 | |
1223 eigval.set_size( static_cast<uword>(n_rows) ); | |
1224 | |
1225 podarray<eT> work( static_cast<uword>(lwork) ); | |
1226 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows - 2)) ); | |
1227 | |
1228 blas_int info; | |
1229 | |
1230 arma_extra_debug_print("lapack::heev()"); | |
1231 lapack::heev(&jobz, &uplo, &n_rows, A.memptr(), &lda, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); | |
1232 | |
1233 return (info == 0); | |
1234 } | |
1235 #else | |
1236 { | |
1237 arma_ignore(eigval); | |
1238 arma_ignore(X); | |
1239 arma_stop("eig_sym(): use of LAPACK needs to be enabled"); | |
1240 return false; | |
1241 } | |
1242 #endif | |
1243 } | |
1244 | |
1245 | |
1246 | |
1247 //! immediate eigenvalues and eigenvectors of a symmetric real matrix using LAPACK | |
1248 template<typename eT, typename T1> | |
1249 inline | |
1250 bool | |
1251 auxlib::eig_sym(Col<eT>& eigval, Mat<eT>& eigvec, const Base<eT,T1>& X) | |
1252 { | |
1253 arma_extra_debug_sigprint(); | |
1254 | |
1255 #if defined(ARMA_USE_LAPACK) | |
1256 { | |
1257 eigvec = X.get_ref(); | |
1258 | |
1259 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" ); | |
1260 | |
1261 if(eigvec.is_empty()) | |
1262 { | |
1263 eigval.reset(); | |
1264 eigvec.reset(); | |
1265 return true; | |
1266 } | |
1267 | |
1268 // rudimentary "better-than-nothing" test for symmetry | |
1269 //arma_debug_check( (A.at(A.n_rows-1, 0) != A.at(0, A.n_cols-1)), "auxlib::eig(): given matrix is not symmetric" ); | |
1270 | |
1271 char jobz = 'V'; | |
1272 char uplo = 'U'; | |
1273 | |
1274 blas_int n_rows = eigvec.n_rows; | |
1275 blas_int lwork = (std::max)(blas_int(1), 3*n_rows-1); | |
1276 | |
1277 eigval.set_size( static_cast<uword>(n_rows) ); | |
1278 podarray<eT> work( static_cast<uword>(lwork) ); | |
1279 | |
1280 blas_int info; | |
1281 | |
1282 arma_extra_debug_print("lapack::syev()"); | |
1283 lapack::syev(&jobz, &uplo, &n_rows, eigvec.memptr(), &n_rows, eigval.memptr(), work.memptr(), &lwork, &info); | |
1284 | |
1285 return (info == 0); | |
1286 } | |
1287 #else | |
1288 { | |
1289 arma_ignore(eigval); | |
1290 arma_ignore(eigvec); | |
1291 arma_stop("eig_sym(): use of LAPACK needs to be enabled"); | |
1292 | |
1293 return false; | |
1294 } | |
1295 #endif | |
1296 } | |
1297 | |
1298 | |
1299 | |
1300 //! immediate eigenvalues and eigenvectors of a hermitian complex matrix using LAPACK | |
1301 template<typename T, typename T1> | |
1302 inline | |
1303 bool | |
1304 auxlib::eig_sym(Col<T>& eigval, Mat< std::complex<T> >& eigvec, const Base<std::complex<T>,T1>& X) | |
1305 { | |
1306 arma_extra_debug_sigprint(); | |
1307 | |
1308 typedef typename std::complex<T> eT; | |
1309 | |
1310 #if defined(ARMA_USE_LAPACK) | |
1311 { | |
1312 eigvec = X.get_ref(); | |
1313 | |
1314 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not hermitian" ); | |
1315 | |
1316 if(eigvec.is_empty()) | |
1317 { | |
1318 eigval.reset(); | |
1319 eigvec.reset(); | |
1320 return true; | |
1321 } | |
1322 | |
1323 char jobz = 'V'; | |
1324 char uplo = 'U'; | |
1325 | |
1326 blas_int n_rows = eigvec.n_rows; | |
1327 blas_int lda = eigvec.n_rows; | |
1328 blas_int lwork = (std::max)(blas_int(1), 2*n_rows - 1); // TODO: automatically find best size of lwork | |
1329 | |
1330 eigval.set_size( static_cast<uword>(n_rows) ); | |
1331 | |
1332 podarray<eT> work( static_cast<uword>(lwork) ); | |
1333 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows - 2)) ); | |
1334 | |
1335 blas_int info; | |
1336 | |
1337 arma_extra_debug_print("lapack::heev()"); | |
1338 lapack::heev(&jobz, &uplo, &n_rows, eigvec.memptr(), &lda, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); | |
1339 | |
1340 return (info == 0); | |
1341 } | |
1342 #else | |
1343 { | |
1344 arma_ignore(eigval); | |
1345 arma_ignore(eigvec); | |
1346 arma_ignore(X); | |
1347 arma_stop("eig_sym(): use of LAPACK needs to be enabled"); | |
1348 return false; | |
1349 } | |
1350 #endif | |
1351 } | |
1352 | |
1353 | |
1354 | |
1355 //! Eigenvalues and eigenvectors of a general square real matrix using LAPACK. | |
1356 //! The argument 'side' specifies which eigenvectors should be calculated | |
1357 //! (see code for mode details). | |
1358 template<typename T, typename T1> | |
1359 inline | |
1360 bool | |
1361 auxlib::eig_gen | |
1362 ( | |
1363 Col< std::complex<T> >& eigval, | |
1364 Mat<T>& l_eigvec, | |
1365 Mat<T>& r_eigvec, | |
1366 const Base<T,T1>& X, | |
1367 const char side | |
1368 ) | |
1369 { | |
1370 arma_extra_debug_sigprint(); | |
1371 | |
1372 #if defined(ARMA_USE_LAPACK) | |
1373 { | |
1374 char jobvl; | |
1375 char jobvr; | |
1376 | |
1377 switch(side) | |
1378 { | |
1379 case 'l': // left | |
1380 jobvl = 'V'; | |
1381 jobvr = 'N'; | |
1382 break; | |
1383 | |
1384 case 'r': // right | |
1385 jobvl = 'N'; | |
1386 jobvr = 'V'; | |
1387 break; | |
1388 | |
1389 case 'b': // both | |
1390 jobvl = 'V'; | |
1391 jobvr = 'V'; | |
1392 break; | |
1393 | |
1394 case 'n': // neither | |
1395 jobvl = 'N'; | |
1396 jobvr = 'N'; | |
1397 break; | |
1398 | |
1399 default: | |
1400 arma_stop("eig_gen(): parameter 'side' is invalid"); | |
1401 return false; | |
1402 } | |
1403 | |
1404 Mat<T> A(X.get_ref()); | |
1405 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" ); | |
1406 | |
1407 if(A.is_empty()) | |
1408 { | |
1409 eigval.reset(); | |
1410 l_eigvec.reset(); | |
1411 r_eigvec.reset(); | |
1412 return true; | |
1413 } | |
1414 | |
1415 uword A_n_rows = A.n_rows; | |
1416 | |
1417 blas_int n_rows = A_n_rows; | |
1418 blas_int lda = A_n_rows; | |
1419 blas_int lwork = (std::max)(blas_int(1), 4*n_rows); // TODO: automatically find best size of lwork | |
1420 | |
1421 eigval.set_size(A_n_rows); | |
1422 l_eigvec.set_size(A_n_rows, A_n_rows); | |
1423 r_eigvec.set_size(A_n_rows, A_n_rows); | |
1424 | |
1425 podarray<T> work( static_cast<uword>(lwork) ); | |
1426 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows)) ); | |
1427 | |
1428 podarray<T> wr(A_n_rows); | |
1429 podarray<T> wi(A_n_rows); | |
1430 | |
1431 Mat<T> A_copy = A; | |
1432 blas_int info; | |
1433 | |
1434 arma_extra_debug_print("lapack::geev()"); | |
1435 lapack::geev(&jobvl, &jobvr, &n_rows, A_copy.memptr(), &lda, wr.memptr(), wi.memptr(), l_eigvec.memptr(), &n_rows, r_eigvec.memptr(), &n_rows, work.memptr(), &lwork, &info); | |
1436 | |
1437 | |
1438 eigval.set_size(A_n_rows); | |
1439 for(uword i=0; i<A_n_rows; ++i) | |
1440 { | |
1441 eigval[i] = std::complex<T>(wr[i], wi[i]); | |
1442 } | |
1443 | |
1444 return (info == 0); | |
1445 } | |
1446 #else | |
1447 { | |
1448 arma_ignore(eigval); | |
1449 arma_ignore(l_eigvec); | |
1450 arma_ignore(r_eigvec); | |
1451 arma_ignore(X); | |
1452 arma_ignore(side); | |
1453 arma_stop("eig_gen(): use of LAPACK needs to be enabled"); | |
1454 return false; | |
1455 } | |
1456 #endif | |
1457 } | |
1458 | |
1459 | |
1460 | |
1461 | |
1462 | |
1463 //! Eigenvalues and eigenvectors of a general square complex matrix using LAPACK | |
1464 //! The argument 'side' specifies which eigenvectors should be calculated | |
1465 //! (see code for mode details). | |
1466 template<typename T, typename T1> | |
1467 inline | |
1468 bool | |
1469 auxlib::eig_gen | |
1470 ( | |
1471 Col< std::complex<T> >& eigval, | |
1472 Mat< std::complex<T> >& l_eigvec, | |
1473 Mat< std::complex<T> >& r_eigvec, | |
1474 const Base< std::complex<T>, T1 >& X, | |
1475 const char side | |
1476 ) | |
1477 { | |
1478 arma_extra_debug_sigprint(); | |
1479 | |
1480 typedef typename std::complex<T> eT; | |
1481 | |
1482 #if defined(ARMA_USE_LAPACK) | |
1483 { | |
1484 char jobvl; | |
1485 char jobvr; | |
1486 | |
1487 switch(side) | |
1488 { | |
1489 case 'l': // left | |
1490 jobvl = 'V'; | |
1491 jobvr = 'N'; | |
1492 break; | |
1493 | |
1494 case 'r': // right | |
1495 jobvl = 'N'; | |
1496 jobvr = 'V'; | |
1497 break; | |
1498 | |
1499 case 'b': // both | |
1500 jobvl = 'V'; | |
1501 jobvr = 'V'; | |
1502 break; | |
1503 | |
1504 case 'n': // neither | |
1505 jobvl = 'N'; | |
1506 jobvr = 'N'; | |
1507 break; | |
1508 | |
1509 default: | |
1510 arma_stop("eig_gen(): parameter 'side' is invalid"); | |
1511 return false; | |
1512 } | |
1513 | |
1514 Mat<eT> A(X.get_ref()); | |
1515 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" ); | |
1516 | |
1517 if(A.is_empty()) | |
1518 { | |
1519 eigval.reset(); | |
1520 l_eigvec.reset(); | |
1521 r_eigvec.reset(); | |
1522 return true; | |
1523 } | |
1524 | |
1525 uword A_n_rows = A.n_rows; | |
1526 | |
1527 blas_int n_rows = A_n_rows; | |
1528 blas_int lda = A_n_rows; | |
1529 blas_int lwork = (std::max)(blas_int(1), 4*n_rows); // TODO: automatically find best size of lwork | |
1530 | |
1531 eigval.set_size(A_n_rows); | |
1532 l_eigvec.set_size(A_n_rows, A_n_rows); | |
1533 r_eigvec.set_size(A_n_rows, A_n_rows); | |
1534 | |
1535 podarray<eT> work( static_cast<uword>(lwork) ); | |
1536 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows)) ); // was 2,3 | |
1537 | |
1538 blas_int info; | |
1539 | |
1540 arma_extra_debug_print("lapack::cx_geev()"); | |
1541 lapack::cx_geev(&jobvl, &jobvr, &n_rows, A.memptr(), &lda, eigval.memptr(), l_eigvec.memptr(), &n_rows, r_eigvec.memptr(), &n_rows, work.memptr(), &lwork, rwork.memptr(), &info); | |
1542 | |
1543 return (info == 0); | |
1544 } | |
1545 #else | |
1546 { | |
1547 arma_ignore(eigval); | |
1548 arma_ignore(l_eigvec); | |
1549 arma_ignore(r_eigvec); | |
1550 arma_ignore(X); | |
1551 arma_ignore(side); | |
1552 arma_stop("eig_gen(): use of LAPACK needs to be enabled"); | |
1553 return false; | |
1554 } | |
1555 #endif | |
1556 } | |
1557 | |
1558 | |
1559 | |
1560 template<typename eT, typename T1> | |
1561 inline | |
1562 bool | |
1563 auxlib::chol(Mat<eT>& out, const Base<eT,T1>& X) | |
1564 { | |
1565 arma_extra_debug_sigprint(); | |
1566 | |
1567 #if defined(ARMA_USE_LAPACK) | |
1568 { | |
1569 out = X.get_ref(); | |
1570 | |
1571 arma_debug_check( (out.is_square() == false), "chol(): given matrix is not square" ); | |
1572 | |
1573 if(out.is_empty()) | |
1574 { | |
1575 return true; | |
1576 } | |
1577 | |
1578 const uword out_n_rows = out.n_rows; | |
1579 | |
1580 char uplo = 'U'; | |
1581 blas_int n = out_n_rows; | |
1582 blas_int info; | |
1583 | |
1584 lapack::potrf(&uplo, &n, out.memptr(), &n, &info); | |
1585 | |
1586 for(uword col=0; col<out_n_rows; ++col) | |
1587 { | |
1588 eT* colptr = out.colptr(col); | |
1589 | |
1590 for(uword row=(col+1); row < out_n_rows; ++row) | |
1591 { | |
1592 colptr[row] = eT(0); | |
1593 } | |
1594 } | |
1595 | |
1596 return (info == 0); | |
1597 } | |
1598 #else | |
1599 { | |
1600 arma_ignore(out); | |
1601 arma_stop("chol(): use of LAPACK needs to be enabled"); | |
1602 return false; | |
1603 } | |
1604 #endif | |
1605 } | |
1606 | |
1607 | |
1608 | |
1609 template<typename eT, typename T1> | |
1610 inline | |
1611 bool | |
1612 auxlib::qr(Mat<eT>& Q, Mat<eT>& R, const Base<eT,T1>& X) | |
1613 { | |
1614 arma_extra_debug_sigprint(); | |
1615 | |
1616 #if defined(ARMA_USE_LAPACK) | |
1617 { | |
1618 R = X.get_ref(); | |
1619 | |
1620 const uword R_n_rows = R.n_rows; | |
1621 const uword R_n_cols = R.n_cols; | |
1622 | |
1623 if(R.is_empty()) | |
1624 { | |
1625 Q.eye(R_n_rows, R_n_rows); | |
1626 return true; | |
1627 } | |
1628 | |
1629 blas_int m = static_cast<blas_int>(R_n_rows); | |
1630 blas_int n = static_cast<blas_int>(R_n_cols); | |
1631 blas_int work_len = (std::max)(blas_int(1),n); | |
1632 blas_int work_len_tmp; | |
1633 blas_int k = (std::min)(m,n); | |
1634 blas_int info; | |
1635 | |
1636 podarray<eT> tau( static_cast<uword>(k) ); | |
1637 podarray<eT> work( static_cast<uword>(work_len) ); | |
1638 | |
1639 // query for the optimum value of work_len | |
1640 work_len_tmp = -1; | |
1641 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info); | |
1642 | |
1643 if(info == 0) | |
1644 { | |
1645 work_len = static_cast<blas_int>(access::tmp_real(work[0])); | |
1646 work.set_size( static_cast<uword>(work_len) ); | |
1647 } | |
1648 | |
1649 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info); | |
1650 | |
1651 Q.set_size(R_n_rows, R_n_rows); | |
1652 | |
1653 arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) ); | |
1654 | |
1655 // | |
1656 // construct R | |
1657 | |
1658 for(uword col=0; col < R_n_cols; ++col) | |
1659 { | |
1660 for(uword row=(col+1); row < R_n_rows; ++row) | |
1661 { | |
1662 R.at(row,col) = eT(0); | |
1663 } | |
1664 } | |
1665 | |
1666 | |
1667 if( (is_float<eT>::value == true) || (is_double<eT>::value == true) ) | |
1668 { | |
1669 // query for the optimum value of work_len | |
1670 work_len_tmp = -1; | |
1671 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info); | |
1672 | |
1673 if(info == 0) | |
1674 { | |
1675 work_len = static_cast<blas_int>(access::tmp_real(work[0])); | |
1676 work.set_size( static_cast<uword>(work_len) ); | |
1677 } | |
1678 | |
1679 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info); | |
1680 } | |
1681 else | |
1682 if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) ) | |
1683 { | |
1684 // query for the optimum value of work_len | |
1685 work_len_tmp = -1; | |
1686 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info); | |
1687 | |
1688 if(info == 0) | |
1689 { | |
1690 work_len = static_cast<blas_int>(access::tmp_real(work[0])); | |
1691 work.set_size( static_cast<uword>(work_len) ); | |
1692 } | |
1693 | |
1694 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info); | |
1695 } | |
1696 | |
1697 return (info == 0); | |
1698 } | |
1699 #else | |
1700 { | |
1701 arma_ignore(Q); | |
1702 arma_ignore(R); | |
1703 arma_ignore(X); | |
1704 arma_stop("qr(): use of LAPACK needs to be enabled"); | |
1705 return false; | |
1706 } | |
1707 #endif | |
1708 } | |
1709 | |
1710 | |
1711 | |
1712 template<typename eT, typename T1> | |
1713 inline | |
1714 bool | |
1715 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X, uword& X_n_rows, uword& X_n_cols) | |
1716 { | |
1717 arma_extra_debug_sigprint(); | |
1718 | |
1719 #if defined(ARMA_USE_LAPACK) | |
1720 { | |
1721 Mat<eT> A(X.get_ref()); | |
1722 | |
1723 X_n_rows = A.n_rows; | |
1724 X_n_cols = A.n_cols; | |
1725 | |
1726 if(A.is_empty()) | |
1727 { | |
1728 S.reset(); | |
1729 return true; | |
1730 } | |
1731 | |
1732 Mat<eT> U(1, 1); | |
1733 Mat<eT> V(1, A.n_cols); | |
1734 | |
1735 char jobu = 'N'; | |
1736 char jobvt = 'N'; | |
1737 | |
1738 blas_int m = A.n_rows; | |
1739 blas_int n = A.n_cols; | |
1740 blas_int lda = A.n_rows; | |
1741 blas_int ldu = U.n_rows; | |
1742 blas_int ldvt = V.n_rows; | |
1743 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) ); | |
1744 blas_int info; | |
1745 | |
1746 S.set_size( static_cast<uword>((std::min)(m, n)) ); | |
1747 | |
1748 podarray<eT> work( static_cast<uword>(lwork) ); | |
1749 | |
1750 | |
1751 // let gesvd_() calculate the optimum size of the workspace | |
1752 blas_int lwork_tmp = -1; | |
1753 | |
1754 lapack::gesvd<eT> | |
1755 ( | |
1756 &jobu, &jobvt, | |
1757 &m,&n, | |
1758 A.memptr(), &lda, | |
1759 S.memptr(), | |
1760 U.memptr(), &ldu, | |
1761 V.memptr(), &ldvt, | |
1762 work.memptr(), &lwork_tmp, | |
1763 &info | |
1764 ); | |
1765 | |
1766 if(info == 0) | |
1767 { | |
1768 blas_int proposed_lwork = static_cast<blas_int>(work[0]); | |
1769 | |
1770 if(proposed_lwork > lwork) | |
1771 { | |
1772 lwork = proposed_lwork; | |
1773 work.set_size( static_cast<uword>(lwork) ); | |
1774 } | |
1775 | |
1776 lapack::gesvd<eT> | |
1777 ( | |
1778 &jobu, &jobvt, | |
1779 &m, &n, | |
1780 A.memptr(), &lda, | |
1781 S.memptr(), | |
1782 U.memptr(), &ldu, | |
1783 V.memptr(), &ldvt, | |
1784 work.memptr(), &lwork, | |
1785 &info | |
1786 ); | |
1787 } | |
1788 | |
1789 return (info == 0); | |
1790 } | |
1791 #else | |
1792 { | |
1793 arma_ignore(S); | |
1794 arma_ignore(X); | |
1795 arma_ignore(X_n_rows); | |
1796 arma_ignore(X_n_cols); | |
1797 arma_stop("svd(): use of LAPACK needs to be enabled"); | |
1798 return false; | |
1799 } | |
1800 #endif | |
1801 } | |
1802 | |
1803 | |
1804 | |
1805 template<typename T, typename T1> | |
1806 inline | |
1807 bool | |
1808 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X, uword& X_n_rows, uword& X_n_cols) | |
1809 { | |
1810 arma_extra_debug_sigprint(); | |
1811 | |
1812 typedef std::complex<T> eT; | |
1813 | |
1814 #if defined(ARMA_USE_LAPACK) | |
1815 { | |
1816 Mat<eT> A(X.get_ref()); | |
1817 | |
1818 X_n_rows = A.n_rows; | |
1819 X_n_cols = A.n_cols; | |
1820 | |
1821 if(A.is_empty()) | |
1822 { | |
1823 S.reset(); | |
1824 return true; | |
1825 } | |
1826 | |
1827 Mat<eT> U(1, 1); | |
1828 Mat<eT> V(1, A.n_cols); | |
1829 | |
1830 char jobu = 'N'; | |
1831 char jobvt = 'N'; | |
1832 | |
1833 blas_int m = A.n_rows; | |
1834 blas_int n = A.n_cols; | |
1835 blas_int lda = A.n_rows; | |
1836 blas_int ldu = U.n_rows; | |
1837 blas_int ldvt = V.n_rows; | |
1838 blas_int lwork = 2 * (std::max)(blas_int(1), 2*(std::min)(m,n)+(std::max)(m,n) ); | |
1839 blas_int info; | |
1840 | |
1841 S.set_size( static_cast<uword>((std::min)(m,n)) ); | |
1842 | |
1843 podarray<eT> work( static_cast<uword>(lwork) ); | |
1844 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) ); | |
1845 | |
1846 // let gesvd_() calculate the optimum size of the workspace | |
1847 blas_int lwork_tmp = -1; | |
1848 | |
1849 lapack::cx_gesvd<T> | |
1850 ( | |
1851 &jobu, &jobvt, | |
1852 &m, &n, | |
1853 A.memptr(), &lda, | |
1854 S.memptr(), | |
1855 U.memptr(), &ldu, | |
1856 V.memptr(), &ldvt, | |
1857 work.memptr(), &lwork_tmp, | |
1858 rwork.memptr(), | |
1859 &info | |
1860 ); | |
1861 | |
1862 if(info == 0) | |
1863 { | |
1864 blas_int proposed_lwork = static_cast<blas_int>(real(work[0])); | |
1865 if(proposed_lwork > lwork) | |
1866 { | |
1867 lwork = proposed_lwork; | |
1868 work.set_size( static_cast<uword>(lwork) ); | |
1869 } | |
1870 | |
1871 lapack::cx_gesvd<T> | |
1872 ( | |
1873 &jobu, &jobvt, | |
1874 &m, &n, | |
1875 A.memptr(), &lda, | |
1876 S.memptr(), | |
1877 U.memptr(), &ldu, | |
1878 V.memptr(), &ldvt, | |
1879 work.memptr(), &lwork, | |
1880 rwork.memptr(), | |
1881 &info | |
1882 ); | |
1883 } | |
1884 | |
1885 return (info == 0); | |
1886 } | |
1887 #else | |
1888 { | |
1889 arma_ignore(S); | |
1890 arma_ignore(X); | |
1891 arma_ignore(X_n_rows); | |
1892 arma_ignore(X_n_cols); | |
1893 | |
1894 arma_stop("svd(): use of LAPACK needs to be enabled"); | |
1895 return false; | |
1896 } | |
1897 #endif | |
1898 } | |
1899 | |
1900 | |
1901 | |
1902 template<typename eT, typename T1> | |
1903 inline | |
1904 bool | |
1905 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X) | |
1906 { | |
1907 arma_extra_debug_sigprint(); | |
1908 | |
1909 uword junk; | |
1910 return auxlib::svd(S, X, junk, junk); | |
1911 } | |
1912 | |
1913 | |
1914 | |
1915 template<typename T, typename T1> | |
1916 inline | |
1917 bool | |
1918 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X) | |
1919 { | |
1920 arma_extra_debug_sigprint(); | |
1921 | |
1922 uword junk; | |
1923 return auxlib::svd(S, X, junk, junk); | |
1924 } | |
1925 | |
1926 | |
1927 | |
1928 template<typename eT, typename T1> | |
1929 inline | |
1930 bool | |
1931 auxlib::svd(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X) | |
1932 { | |
1933 arma_extra_debug_sigprint(); | |
1934 | |
1935 #if defined(ARMA_USE_LAPACK) | |
1936 { | |
1937 Mat<eT> A(X.get_ref()); | |
1938 | |
1939 if(A.is_empty()) | |
1940 { | |
1941 U.eye(A.n_rows, A.n_rows); | |
1942 S.reset(); | |
1943 V.eye(A.n_cols, A.n_cols); | |
1944 return true; | |
1945 } | |
1946 | |
1947 U.set_size(A.n_rows, A.n_rows); | |
1948 V.set_size(A.n_cols, A.n_cols); | |
1949 | |
1950 char jobu = 'A'; | |
1951 char jobvt = 'A'; | |
1952 | |
1953 blas_int m = A.n_rows; | |
1954 blas_int n = A.n_cols; | |
1955 blas_int lda = A.n_rows; | |
1956 blas_int ldu = U.n_rows; | |
1957 blas_int ldvt = V.n_rows; | |
1958 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) ); | |
1959 blas_int info; | |
1960 | |
1961 | |
1962 S.set_size( static_cast<uword>((std::min)(m,n)) ); | |
1963 podarray<eT> work( static_cast<uword>(lwork) ); | |
1964 | |
1965 // let gesvd_() calculate the optimum size of the workspace | |
1966 blas_int lwork_tmp = -1; | |
1967 | |
1968 lapack::gesvd<eT> | |
1969 ( | |
1970 &jobu, &jobvt, | |
1971 &m, &n, | |
1972 A.memptr(), &lda, | |
1973 S.memptr(), | |
1974 U.memptr(), &ldu, | |
1975 V.memptr(), &ldvt, | |
1976 work.memptr(), &lwork_tmp, | |
1977 &info | |
1978 ); | |
1979 | |
1980 if(info == 0) | |
1981 { | |
1982 blas_int proposed_lwork = static_cast<blas_int>(work[0]); | |
1983 if(proposed_lwork > lwork) | |
1984 { | |
1985 lwork = proposed_lwork; | |
1986 work.set_size( static_cast<uword>(lwork) ); | |
1987 } | |
1988 | |
1989 lapack::gesvd<eT> | |
1990 ( | |
1991 &jobu, &jobvt, | |
1992 &m, &n, | |
1993 A.memptr(), &lda, | |
1994 S.memptr(), | |
1995 U.memptr(), &ldu, | |
1996 V.memptr(), &ldvt, | |
1997 work.memptr(), &lwork, | |
1998 &info | |
1999 ); | |
2000 | |
2001 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done | |
2002 } | |
2003 | |
2004 return (info == 0); | |
2005 } | |
2006 #else | |
2007 { | |
2008 arma_ignore(U); | |
2009 arma_ignore(S); | |
2010 arma_ignore(V); | |
2011 arma_ignore(X); | |
2012 arma_stop("svd(): use of LAPACK needs to be enabled"); | |
2013 return false; | |
2014 } | |
2015 #endif | |
2016 } | |
2017 | |
2018 | |
2019 | |
2020 template<typename T, typename T1> | |
2021 inline | |
2022 bool | |
2023 auxlib::svd(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X) | |
2024 { | |
2025 arma_extra_debug_sigprint(); | |
2026 | |
2027 typedef std::complex<T> eT; | |
2028 | |
2029 #if defined(ARMA_USE_LAPACK) | |
2030 { | |
2031 Mat<eT> A(X.get_ref()); | |
2032 | |
2033 if(A.is_empty()) | |
2034 { | |
2035 U.eye(A.n_rows, A.n_rows); | |
2036 S.reset(); | |
2037 V.eye(A.n_cols, A.n_cols); | |
2038 return true; | |
2039 } | |
2040 | |
2041 U.set_size(A.n_rows, A.n_rows); | |
2042 V.set_size(A.n_cols, A.n_cols); | |
2043 | |
2044 char jobu = 'A'; | |
2045 char jobvt = 'A'; | |
2046 | |
2047 blas_int m = A.n_rows; | |
2048 blas_int n = A.n_cols; | |
2049 blas_int lda = A.n_rows; | |
2050 blas_int ldu = U.n_rows; | |
2051 blas_int ldvt = V.n_rows; | |
2052 blas_int lwork = 2 * (std::max)(blas_int(1), 2*(std::min)(m,n)+(std::max)(m,n) ); | |
2053 blas_int info; | |
2054 | |
2055 S.set_size( static_cast<uword>((std::min)(m,n)) ); | |
2056 | |
2057 podarray<eT> work( static_cast<uword>(lwork) ); | |
2058 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) ); | |
2059 | |
2060 // let gesvd_() calculate the optimum size of the workspace | |
2061 blas_int lwork_tmp = -1; | |
2062 lapack::cx_gesvd<T> | |
2063 ( | |
2064 &jobu, &jobvt, | |
2065 &m, &n, | |
2066 A.memptr(), &lda, | |
2067 S.memptr(), | |
2068 U.memptr(), &ldu, | |
2069 V.memptr(), &ldvt, | |
2070 work.memptr(), &lwork_tmp, | |
2071 rwork.memptr(), | |
2072 &info | |
2073 ); | |
2074 | |
2075 if(info == 0) | |
2076 { | |
2077 blas_int proposed_lwork = static_cast<blas_int>(real(work[0])); | |
2078 if(proposed_lwork > lwork) | |
2079 { | |
2080 lwork = proposed_lwork; | |
2081 work.set_size( static_cast<uword>(lwork) ); | |
2082 } | |
2083 | |
2084 lapack::cx_gesvd<T> | |
2085 ( | |
2086 &jobu, &jobvt, | |
2087 &m, &n, | |
2088 A.memptr(), &lda, | |
2089 S.memptr(), | |
2090 U.memptr(), &ldu, | |
2091 V.memptr(), &ldvt, | |
2092 work.memptr(), &lwork, | |
2093 rwork.memptr(), | |
2094 &info | |
2095 ); | |
2096 | |
2097 op_htrans::apply(V,V); // op_htrans will work out that an in-place transpose can be done | |
2098 } | |
2099 | |
2100 return (info == 0); | |
2101 } | |
2102 #else | |
2103 { | |
2104 arma_ignore(U); | |
2105 arma_ignore(S); | |
2106 arma_ignore(V); | |
2107 arma_ignore(X); | |
2108 arma_stop("svd(): use of LAPACK needs to be enabled"); | |
2109 return false; | |
2110 } | |
2111 #endif | |
2112 | |
2113 } | |
2114 | |
2115 | |
2116 | |
2117 template<typename eT, typename T1> | |
2118 inline | |
2119 bool | |
2120 auxlib::svd_econ(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X, const char mode) | |
2121 { | |
2122 arma_extra_debug_sigprint(); | |
2123 | |
2124 #if defined(ARMA_USE_LAPACK) | |
2125 { | |
2126 Mat<eT> A(X.get_ref()); | |
2127 | |
2128 blas_int m = A.n_rows; | |
2129 blas_int n = A.n_cols; | |
2130 blas_int lda = A.n_rows; | |
2131 | |
2132 S.set_size( static_cast<uword>((std::min)(m,n)) ); | |
2133 | |
2134 blas_int ldu = 0; | |
2135 blas_int ldvt = 0; | |
2136 | |
2137 char jobu; | |
2138 char jobvt; | |
2139 | |
2140 switch(mode) | |
2141 { | |
2142 case 'l': | |
2143 jobu = 'S'; | |
2144 jobvt = 'N'; | |
2145 | |
2146 ldu = m; | |
2147 ldvt = 1; | |
2148 | |
2149 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) ); | |
2150 V.reset(); | |
2151 | |
2152 break; | |
2153 | |
2154 | |
2155 case 'r': | |
2156 jobu = 'N'; | |
2157 jobvt = 'S'; | |
2158 | |
2159 ldu = 1; | |
2160 ldvt = (std::min)(m,n); | |
2161 | |
2162 U.reset(); | |
2163 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) ); | |
2164 | |
2165 break; | |
2166 | |
2167 | |
2168 case 'b': | |
2169 jobu = 'S'; | |
2170 jobvt = 'S'; | |
2171 | |
2172 ldu = m; | |
2173 ldvt = (std::min)(m,n); | |
2174 | |
2175 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) ); | |
2176 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) ); | |
2177 | |
2178 break; | |
2179 | |
2180 | |
2181 default: | |
2182 U.reset(); | |
2183 S.reset(); | |
2184 V.reset(); | |
2185 return false; | |
2186 } | |
2187 | |
2188 | |
2189 if(A.is_empty()) | |
2190 { | |
2191 U.eye(); | |
2192 S.reset(); | |
2193 V.eye(); | |
2194 return true; | |
2195 } | |
2196 | |
2197 | |
2198 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) ); | |
2199 blas_int info = 0; | |
2200 | |
2201 | |
2202 podarray<eT> work( static_cast<uword>(lwork) ); | |
2203 | |
2204 // let gesvd_() calculate the optimum size of the workspace | |
2205 blas_int lwork_tmp = -1; | |
2206 | |
2207 lapack::gesvd<eT> | |
2208 ( | |
2209 &jobu, &jobvt, | |
2210 &m, &n, | |
2211 A.memptr(), &lda, | |
2212 S.memptr(), | |
2213 U.memptr(), &ldu, | |
2214 V.memptr(), &ldvt, | |
2215 work.memptr(), &lwork_tmp, | |
2216 &info | |
2217 ); | |
2218 | |
2219 if(info == 0) | |
2220 { | |
2221 blas_int proposed_lwork = static_cast<blas_int>(work[0]); | |
2222 if(proposed_lwork > lwork) | |
2223 { | |
2224 lwork = proposed_lwork; | |
2225 work.set_size( static_cast<uword>(lwork) ); | |
2226 } | |
2227 | |
2228 lapack::gesvd<eT> | |
2229 ( | |
2230 &jobu, &jobvt, | |
2231 &m, &n, | |
2232 A.memptr(), &lda, | |
2233 S.memptr(), | |
2234 U.memptr(), &ldu, | |
2235 V.memptr(), &ldvt, | |
2236 work.memptr(), &lwork, | |
2237 &info | |
2238 ); | |
2239 | |
2240 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done | |
2241 } | |
2242 | |
2243 return (info == 0); | |
2244 } | |
2245 #else | |
2246 { | |
2247 arma_ignore(U); | |
2248 arma_ignore(S); | |
2249 arma_ignore(V); | |
2250 arma_ignore(X); | |
2251 arma_ignore(mode); | |
2252 arma_stop("svd(): use of LAPACK needs to be enabled"); | |
2253 return false; | |
2254 } | |
2255 #endif | |
2256 } | |
2257 | |
2258 | |
2259 | |
2260 template<typename T, typename T1> | |
2261 inline | |
2262 bool | |
2263 auxlib::svd_econ(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X, const char mode) | |
2264 { | |
2265 arma_extra_debug_sigprint(); | |
2266 | |
2267 typedef std::complex<T> eT; | |
2268 | |
2269 #if defined(ARMA_USE_LAPACK) | |
2270 { | |
2271 Mat<eT> A(X.get_ref()); | |
2272 | |
2273 blas_int m = A.n_rows; | |
2274 blas_int n = A.n_cols; | |
2275 blas_int lda = A.n_rows; | |
2276 | |
2277 S.set_size( static_cast<uword>((std::min)(m,n)) ); | |
2278 | |
2279 blas_int ldu = 0; | |
2280 blas_int ldvt = 0; | |
2281 | |
2282 char jobu; | |
2283 char jobvt; | |
2284 | |
2285 switch(mode) | |
2286 { | |
2287 case 'l': | |
2288 jobu = 'S'; | |
2289 jobvt = 'N'; | |
2290 | |
2291 ldu = m; | |
2292 ldvt = 1; | |
2293 | |
2294 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) ); | |
2295 V.reset(); | |
2296 | |
2297 break; | |
2298 | |
2299 | |
2300 case 'r': | |
2301 jobu = 'N'; | |
2302 jobvt = 'S'; | |
2303 | |
2304 ldu = 1; | |
2305 ldvt = (std::min)(m,n); | |
2306 | |
2307 U.reset(); | |
2308 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) ); | |
2309 | |
2310 break; | |
2311 | |
2312 | |
2313 case 'b': | |
2314 jobu = 'S'; | |
2315 jobvt = 'S'; | |
2316 | |
2317 ldu = m; | |
2318 ldvt = (std::min)(m,n); | |
2319 | |
2320 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) ); | |
2321 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) ); | |
2322 | |
2323 break; | |
2324 | |
2325 | |
2326 default: | |
2327 U.reset(); | |
2328 S.reset(); | |
2329 V.reset(); | |
2330 return false; | |
2331 } | |
2332 | |
2333 | |
2334 if(A.is_empty()) | |
2335 { | |
2336 U.eye(); | |
2337 S.reset(); | |
2338 V.eye(); | |
2339 return true; | |
2340 } | |
2341 | |
2342 | |
2343 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) ); | |
2344 blas_int info = 0; | |
2345 | |
2346 | |
2347 podarray<eT> work( static_cast<uword>(lwork) ); | |
2348 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) ); | |
2349 | |
2350 // let gesvd_() calculate the optimum size of the workspace | |
2351 blas_int lwork_tmp = -1; | |
2352 | |
2353 lapack::cx_gesvd<T> | |
2354 ( | |
2355 &jobu, &jobvt, | |
2356 &m, &n, | |
2357 A.memptr(), &lda, | |
2358 S.memptr(), | |
2359 U.memptr(), &ldu, | |
2360 V.memptr(), &ldvt, | |
2361 work.memptr(), &lwork_tmp, | |
2362 rwork.memptr(), | |
2363 &info | |
2364 ); | |
2365 | |
2366 if(info == 0) | |
2367 { | |
2368 blas_int proposed_lwork = static_cast<blas_int>(real(work[0])); | |
2369 if(proposed_lwork > lwork) | |
2370 { | |
2371 lwork = proposed_lwork; | |
2372 work.set_size( static_cast<uword>(lwork) ); | |
2373 } | |
2374 | |
2375 lapack::cx_gesvd<T> | |
2376 ( | |
2377 &jobu, &jobvt, | |
2378 &m, &n, | |
2379 A.memptr(), &lda, | |
2380 S.memptr(), | |
2381 U.memptr(), &ldu, | |
2382 V.memptr(), &ldvt, | |
2383 work.memptr(), &lwork, | |
2384 rwork.memptr(), | |
2385 &info | |
2386 ); | |
2387 | |
2388 op_htrans::apply(V,V); // op_strans will work out that an in-place transpose can be done | |
2389 } | |
2390 | |
2391 return (info == 0); | |
2392 } | |
2393 #else | |
2394 { | |
2395 arma_ignore(U); | |
2396 arma_ignore(S); | |
2397 arma_ignore(V); | |
2398 arma_ignore(X); | |
2399 arma_ignore(mode); | |
2400 arma_stop("svd(): use of LAPACK needs to be enabled"); | |
2401 return false; | |
2402 } | |
2403 #endif | |
2404 } | |
2405 | |
2406 | |
2407 | |
2408 //! Solve a system of linear equations. | |
2409 //! Assumes that A.n_rows = A.n_cols and B.n_rows = A.n_rows | |
2410 template<typename eT> | |
2411 inline | |
2412 bool | |
2413 auxlib::solve(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B, const bool slow) | |
2414 { | |
2415 arma_extra_debug_sigprint(); | |
2416 | |
2417 if(A.is_empty() || B.is_empty()) | |
2418 { | |
2419 out.zeros(A.n_cols, B.n_cols); | |
2420 return true; | |
2421 } | |
2422 else | |
2423 { | |
2424 const uword A_n_rows = A.n_rows; | |
2425 | |
2426 bool status = false; | |
2427 | |
2428 if( (A_n_rows <= 4) && (slow == false) ) | |
2429 { | |
2430 Mat<eT> A_inv; | |
2431 | |
2432 status = auxlib::inv_noalias_tinymat(A_inv, A, A_n_rows); | |
2433 | |
2434 if(status == true) | |
2435 { | |
2436 out.set_size(A_n_rows, B.n_cols); | |
2437 | |
2438 gemm_emul<false,false,false,false>::apply(out, A_inv, B); | |
2439 | |
2440 return true; | |
2441 } | |
2442 } | |
2443 | |
2444 if( (A_n_rows > 4) || (status == false) ) | |
2445 { | |
2446 #if defined(ARMA_USE_ATLAS) | |
2447 { | |
2448 podarray<int> ipiv(A_n_rows); | |
2449 | |
2450 out = B; | |
2451 | |
2452 int info = atlas::clapack_gesv<eT>(atlas::CblasColMajor, A_n_rows, B.n_cols, A.memptr(), A_n_rows, ipiv.memptr(), out.memptr(), A_n_rows); | |
2453 | |
2454 return (info == 0); | |
2455 } | |
2456 #elif defined(ARMA_USE_LAPACK) | |
2457 { | |
2458 blas_int n = A_n_rows; | |
2459 blas_int lda = A_n_rows; | |
2460 blas_int ldb = A_n_rows; | |
2461 blas_int nrhs = B.n_cols; | |
2462 blas_int info; | |
2463 | |
2464 podarray<blas_int> ipiv(A_n_rows); | |
2465 | |
2466 out = B; | |
2467 | |
2468 lapack::gesv<eT>(&n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); | |
2469 | |
2470 return (info == 0); | |
2471 } | |
2472 #else | |
2473 { | |
2474 arma_stop("solve(): use of ATLAS or LAPACK needs to be enabled"); | |
2475 return false; | |
2476 } | |
2477 #endif | |
2478 } | |
2479 } | |
2480 | |
2481 return true; | |
2482 } | |
2483 | |
2484 | |
2485 | |
2486 //! Solve an over-determined system. | |
2487 //! Assumes that A.n_rows > A.n_cols and B.n_rows = A.n_rows | |
2488 template<typename eT> | |
2489 inline | |
2490 bool | |
2491 auxlib::solve_od(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B) | |
2492 { | |
2493 arma_extra_debug_sigprint(); | |
2494 | |
2495 #if defined(ARMA_USE_LAPACK) | |
2496 { | |
2497 if(A.is_empty() || B.is_empty()) | |
2498 { | |
2499 out.zeros(A.n_cols, B.n_cols); | |
2500 return true; | |
2501 } | |
2502 | |
2503 char trans = 'N'; | |
2504 | |
2505 blas_int m = A.n_rows; | |
2506 blas_int n = A.n_cols; | |
2507 blas_int lda = A.n_rows; | |
2508 blas_int ldb = A.n_rows; | |
2509 blas_int nrhs = B.n_cols; | |
2510 blas_int lwork = n + (std::max)(n, nrhs); | |
2511 blas_int info; | |
2512 | |
2513 Mat<eT> tmp = B; | |
2514 | |
2515 podarray<eT> work( static_cast<uword>(lwork) ); | |
2516 | |
2517 arma_extra_debug_print("lapack::gels()"); | |
2518 | |
2519 // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems | |
2520 | |
2521 lapack::gels<eT> | |
2522 ( | |
2523 &trans, &m, &n, &nrhs, | |
2524 A.memptr(), &lda, | |
2525 tmp.memptr(), &ldb, | |
2526 work.memptr(), &lwork, | |
2527 &info | |
2528 ); | |
2529 | |
2530 arma_extra_debug_print("lapack::gels() -- finished"); | |
2531 | |
2532 out.set_size(A.n_cols, B.n_cols); | |
2533 | |
2534 for(uword col=0; col<B.n_cols; ++col) | |
2535 { | |
2536 arrayops::copy( out.colptr(col), tmp.colptr(col), A.n_cols ); | |
2537 } | |
2538 | |
2539 return (info == 0); | |
2540 } | |
2541 #else | |
2542 { | |
2543 arma_ignore(out); | |
2544 arma_ignore(A); | |
2545 arma_ignore(B); | |
2546 arma_stop("solve(): use of LAPACK needs to be enabled"); | |
2547 return false; | |
2548 } | |
2549 #endif | |
2550 } | |
2551 | |
2552 | |
2553 | |
2554 //! Solve an under-determined system. | |
2555 //! Assumes that A.n_rows < A.n_cols and B.n_rows = A.n_rows | |
2556 template<typename eT> | |
2557 inline | |
2558 bool | |
2559 auxlib::solve_ud(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B) | |
2560 { | |
2561 arma_extra_debug_sigprint(); | |
2562 | |
2563 #if defined(ARMA_USE_LAPACK) | |
2564 { | |
2565 if(A.is_empty() || B.is_empty()) | |
2566 { | |
2567 out.zeros(A.n_cols, B.n_cols); | |
2568 return true; | |
2569 } | |
2570 | |
2571 char trans = 'N'; | |
2572 | |
2573 blas_int m = A.n_rows; | |
2574 blas_int n = A.n_cols; | |
2575 blas_int lda = A.n_rows; | |
2576 blas_int ldb = A.n_cols; | |
2577 blas_int nrhs = B.n_cols; | |
2578 blas_int lwork = m + (std::max)(m,nrhs); | |
2579 blas_int info; | |
2580 | |
2581 | |
2582 Mat<eT> tmp; | |
2583 tmp.zeros(A.n_cols, B.n_cols); | |
2584 | |
2585 for(uword col=0; col<B.n_cols; ++col) | |
2586 { | |
2587 eT* tmp_colmem = tmp.colptr(col); | |
2588 | |
2589 arrayops::copy( tmp_colmem, B.colptr(col), B.n_rows ); | |
2590 | |
2591 for(uword row=B.n_rows; row<A.n_cols; ++row) | |
2592 { | |
2593 tmp_colmem[row] = eT(0); | |
2594 } | |
2595 } | |
2596 | |
2597 podarray<eT> work( static_cast<uword>(lwork) ); | |
2598 | |
2599 arma_extra_debug_print("lapack::gels()"); | |
2600 | |
2601 // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems | |
2602 | |
2603 lapack::gels<eT> | |
2604 ( | |
2605 &trans, &m, &n, &nrhs, | |
2606 A.memptr(), &lda, | |
2607 tmp.memptr(), &ldb, | |
2608 work.memptr(), &lwork, | |
2609 &info | |
2610 ); | |
2611 | |
2612 arma_extra_debug_print("lapack::gels() -- finished"); | |
2613 | |
2614 out.set_size(A.n_cols, B.n_cols); | |
2615 | |
2616 for(uword col=0; col<B.n_cols; ++col) | |
2617 { | |
2618 arrayops::copy( out.colptr(col), tmp.colptr(col), A.n_cols ); | |
2619 } | |
2620 | |
2621 return (info == 0); | |
2622 } | |
2623 #else | |
2624 { | |
2625 arma_ignore(out); | |
2626 arma_ignore(A); | |
2627 arma_ignore(B); | |
2628 arma_stop("solve(): use of LAPACK needs to be enabled"); | |
2629 return false; | |
2630 } | |
2631 #endif | |
2632 } | |
2633 | |
2634 | |
2635 | |
2636 // | |
2637 // solve_tr | |
2638 | |
2639 template<typename eT> | |
2640 inline | |
2641 bool | |
2642 auxlib::solve_tr(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const uword layout) | |
2643 { | |
2644 arma_extra_debug_sigprint(); | |
2645 | |
2646 #if defined(ARMA_USE_LAPACK) | |
2647 { | |
2648 if(A.is_empty() || B.is_empty()) | |
2649 { | |
2650 out.zeros(A.n_cols, B.n_cols); | |
2651 return true; | |
2652 } | |
2653 | |
2654 out = B; | |
2655 | |
2656 char uplo = (layout == 0) ? 'U' : 'L'; | |
2657 char trans = 'N'; | |
2658 char diag = 'N'; | |
2659 blas_int n = blas_int(A.n_rows); | |
2660 blas_int nrhs = blas_int(B.n_cols); | |
2661 blas_int info = 0; | |
2662 | |
2663 lapack::trtrs<eT>(&uplo, &trans, &diag, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info); | |
2664 | |
2665 return (info == 0); | |
2666 } | |
2667 #else | |
2668 { | |
2669 arma_ignore(out); | |
2670 arma_ignore(A); | |
2671 arma_ignore(B); | |
2672 arma_ignore(layout); | |
2673 arma_stop("solve(): use of LAPACK needs to be enabled"); | |
2674 return false; | |
2675 } | |
2676 #endif | |
2677 } | |
2678 | |
2679 | |
2680 | |
2681 // | |
2682 // Schur decomposition | |
2683 | |
2684 template<typename eT> | |
2685 inline | |
2686 bool | |
2687 auxlib::schur_dec(Mat<eT>& Z, Mat<eT>& T, const Mat<eT>& A) | |
2688 { | |
2689 arma_extra_debug_sigprint(); | |
2690 | |
2691 #if defined(ARMA_USE_LAPACK) | |
2692 { | |
2693 arma_debug_check( (A.is_square() == false), "schur_dec(): given matrix is not square" ); | |
2694 | |
2695 if(A.is_empty()) | |
2696 { | |
2697 Z.reset(); | |
2698 T.reset(); | |
2699 return true; | |
2700 } | |
2701 | |
2702 const uword A_n_rows = A.n_rows; | |
2703 | |
2704 char jobvs = 'V'; // get Schur vectors (Z) | |
2705 char sort = 'N'; // do not sort eigenvalues/vectors | |
2706 blas_int* select = 0; // pointer to sorting function | |
2707 blas_int n = blas_int(A_n_rows); | |
2708 blas_int sdim = 0; // output for sorting | |
2709 | |
2710 blas_int lwork = 3 * n; // workspace must be at least 3 * n (if set to -1, optimal size is output in work(0) and nothing else is done | |
2711 | |
2712 podarray<eT> work( static_cast<uword>(lwork) ); | |
2713 podarray<blas_int> bwork(A_n_rows); | |
2714 | |
2715 blas_int info = 0; | |
2716 | |
2717 Z.set_size(A_n_rows, A_n_rows); | |
2718 T = A; | |
2719 | |
2720 podarray<eT> wr(A_n_rows); // output for eigenvalues | |
2721 podarray<eT> wi(A_n_rows); // output for eigenvalues | |
2722 | |
2723 lapack::gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, wr.memptr(), wi.memptr(), Z.memptr(), &n, work.memptr(), &lwork, bwork.memptr(), &info); | |
2724 | |
2725 return (info == 0); | |
2726 } | |
2727 #else | |
2728 { | |
2729 arma_ignore(Z); | |
2730 arma_ignore(T); | |
2731 arma_stop("schur_dec(): use of LAPACK needs to be enabled"); | |
2732 return false; | |
2733 } | |
2734 #endif | |
2735 } | |
2736 | |
2737 | |
2738 | |
2739 template<typename cT> | |
2740 inline | |
2741 bool | |
2742 auxlib::schur_dec(Mat<std::complex<cT> >& Z, Mat<std::complex<cT> >& T, const Mat<std::complex<cT> >& A) | |
2743 { | |
2744 arma_extra_debug_sigprint(); | |
2745 | |
2746 #if defined(ARMA_USE_LAPACK) | |
2747 { | |
2748 arma_debug_check( (A.is_square() == false), "schur_dec(): matrix A is not square" ); | |
2749 | |
2750 if(A.is_empty()) | |
2751 { | |
2752 Z.reset(); | |
2753 T.reset(); | |
2754 return true; | |
2755 } | |
2756 | |
2757 typedef std::complex<cT> eT; | |
2758 | |
2759 const uword A_n_rows = A.n_rows; | |
2760 | |
2761 char jobvs = 'V'; // get Schur vectors (Z) | |
2762 char sort = 'N'; // do not sort eigenvalues/vectors | |
2763 blas_int* select = 0; // pointer to sorting function | |
2764 blas_int n = blas_int(A_n_rows); | |
2765 blas_int sdim = 0; // output for sorting | |
2766 | |
2767 blas_int lwork = 3 * n; // workspace must be at least 3 * n (if set to -1, optimal size is output in work(0) and nothing else is done | |
2768 | |
2769 podarray<eT> work( static_cast<uword>(lwork) ); | |
2770 podarray<blas_int> bwork(A_n_rows); | |
2771 | |
2772 blas_int info = 0; | |
2773 | |
2774 Z.set_size(A_n_rows, A_n_rows); | |
2775 T = A; | |
2776 | |
2777 podarray<eT> w(A_n_rows); // output for eigenvalues | |
2778 podarray<cT> rwork(A_n_rows); | |
2779 | |
2780 lapack::cx_gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, w.memptr(), Z.memptr(), &n, work.memptr(), &lwork, rwork.memptr(), bwork.memptr(), &info); | |
2781 | |
2782 return (info == 0); | |
2783 } | |
2784 #else | |
2785 { | |
2786 arma_ignore(Z); | |
2787 arma_ignore(T); | |
2788 arma_stop("schur_dec(): use of LAPACK needs to be enabled"); | |
2789 return false; | |
2790 } | |
2791 #endif | |
2792 } | |
2793 | |
2794 | |
2795 | |
2796 // | |
2797 // syl (solution of the Sylvester equation AX + XB = C) | |
2798 | |
2799 template<typename eT> | |
2800 inline | |
2801 bool | |
2802 auxlib::syl(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& B, const Mat<eT>& C) | |
2803 { | |
2804 arma_extra_debug_sigprint(); | |
2805 | |
2806 arma_debug_check | |
2807 ( | |
2808 (A.is_square() == false) || (B.is_square() == false), | |
2809 "syl(): given matrix is not square" | |
2810 ); | |
2811 | |
2812 arma_debug_check | |
2813 ( | |
2814 (C.n_rows != A.n_rows) || (C.n_cols != B.n_cols), | |
2815 "syl(): matrices are not conformant" | |
2816 ); | |
2817 | |
2818 if(A.is_empty() || B.is_empty() || C.is_empty()) | |
2819 { | |
2820 X.reset(); | |
2821 return true; | |
2822 } | |
2823 | |
2824 #if defined(ARMA_USE_LAPACK) | |
2825 { | |
2826 Mat<eT> Z1, Z2, T1, T2; | |
2827 | |
2828 const bool status_sd1 = auxlib::schur_dec(Z1, T1, A); | |
2829 const bool status_sd2 = auxlib::schur_dec(Z2, T2, B); | |
2830 | |
2831 if( (status_sd1 == false) || (status_sd2 == false) ) | |
2832 { | |
2833 return false; | |
2834 } | |
2835 | |
2836 char trana = 'N'; | |
2837 char tranb = 'N'; | |
2838 blas_int isgn = +1; | |
2839 blas_int m = blas_int(T1.n_rows); | |
2840 blas_int n = blas_int(T2.n_cols); | |
2841 | |
2842 eT scale = eT(0); | |
2843 blas_int info = 0; | |
2844 | |
2845 Mat<eT> Y = trans(Z1) * C * Z2; | |
2846 | |
2847 lapack::trsyl<eT>(&trana, &tranb, &isgn, &m, &n, T1.memptr(), &m, T2.memptr(), &n, Y.memptr(), &m, &scale, &info); | |
2848 | |
2849 //Y /= scale; | |
2850 Y /= (-scale); | |
2851 | |
2852 X = Z1 * Y * trans(Z2); | |
2853 | |
2854 return (info >= 0); | |
2855 } | |
2856 #else | |
2857 { | |
2858 arma_stop("syl(): use of LAPACK needs to be enabled"); | |
2859 return false; | |
2860 } | |
2861 #endif | |
2862 } | |
2863 | |
2864 | |
2865 | |
2866 // | |
2867 // lyap (solution of the continuous Lyapunov equation AX + XA^H + Q = 0) | |
2868 | |
2869 template<typename eT> | |
2870 inline | |
2871 bool | |
2872 auxlib::lyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q) | |
2873 { | |
2874 arma_extra_debug_sigprint(); | |
2875 | |
2876 arma_debug_check( (A.is_square() == false), "lyap(): matrix A is not square"); | |
2877 arma_debug_check( (Q.is_square() == false), "lyap(): matrix Q is not square"); | |
2878 arma_debug_check( (A.n_rows != Q.n_rows), "lyap(): matrices A and Q have different dimensions"); | |
2879 | |
2880 Mat<eT> htransA; | |
2881 op_htrans::apply_noalias(htransA, A); | |
2882 | |
2883 const Mat<eT> mQ = -Q; | |
2884 | |
2885 return auxlib::syl(X, A, htransA, mQ); | |
2886 } | |
2887 | |
2888 | |
2889 | |
2890 // | |
2891 // dlyap (solution of the discrete Lyapunov equation AXA^H - X + Q = 0) | |
2892 | |
2893 template<typename eT> | |
2894 inline | |
2895 bool | |
2896 auxlib::dlyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q) | |
2897 { | |
2898 arma_extra_debug_sigprint(); | |
2899 | |
2900 arma_debug_check( (A.is_square() == false), "dlyap(): matrix A is not square"); | |
2901 arma_debug_check( (Q.is_square() == false), "dlyap(): matrix Q is not square"); | |
2902 arma_debug_check( (A.n_rows != Q.n_rows), "dlyap(): matrices A and Q have different dimensions"); | |
2903 | |
2904 const Col<eT> vecQ = reshape(Q, Q.n_elem, 1); | |
2905 | |
2906 const Mat<eT> M = eye< Mat<eT> >(Q.n_elem, Q.n_elem) - kron(conj(A), A); | |
2907 | |
2908 Col<eT> vecX; | |
2909 | |
2910 const bool status = solve(vecX, M, vecQ); | |
2911 | |
2912 if(status == true) | |
2913 { | |
2914 X = reshape(vecX, Q.n_rows, Q.n_cols); | |
2915 return true; | |
2916 } | |
2917 else | |
2918 { | |
2919 X.reset(); | |
2920 return false; | |
2921 } | |
2922 } | |
2923 | |
2924 | |
2925 | |
2926 //! @} |