Chris@49
|
1 // Copyright (C) 2008-2013 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2013 Conrad Sanderson
|
Chris@49
|
3 // Copyright (C) 2009 Edmund Highcock
|
Chris@49
|
4 // Copyright (C) 2011 James Sanders
|
Chris@49
|
5 // Copyright (C) 2011 Stanislav Funiak
|
Chris@49
|
6 // Copyright (C) 2012 Eric Jon Sundstrom
|
Chris@49
|
7 // Copyright (C) 2012 Michael McNeil Forbes
|
Chris@49
|
8 //
|
Chris@49
|
9 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
10 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
11 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 //! \addtogroup auxlib
|
Chris@49
|
15 //! @{
|
Chris@49
|
16
|
Chris@49
|
17
|
Chris@49
|
18
|
Chris@49
|
19 //! immediate matrix inverse
|
Chris@49
|
20 template<typename eT, typename T1>
|
Chris@49
|
21 inline
|
Chris@49
|
22 bool
|
Chris@49
|
23 auxlib::inv(Mat<eT>& out, const Base<eT,T1>& X, const bool slow)
|
Chris@49
|
24 {
|
Chris@49
|
25 arma_extra_debug_sigprint();
|
Chris@49
|
26
|
Chris@49
|
27 out = X.get_ref();
|
Chris@49
|
28
|
Chris@49
|
29 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
|
Chris@49
|
30
|
Chris@49
|
31 bool status = false;
|
Chris@49
|
32
|
Chris@49
|
33 const uword N = out.n_rows;
|
Chris@49
|
34
|
Chris@49
|
35 if( (N <= 4) && (slow == false) )
|
Chris@49
|
36 {
|
Chris@49
|
37 status = auxlib::inv_inplace_tinymat(out, N);
|
Chris@49
|
38 }
|
Chris@49
|
39
|
Chris@49
|
40 if( (N > 4) || (status == false) )
|
Chris@49
|
41 {
|
Chris@49
|
42 status = auxlib::inv_inplace_lapack(out);
|
Chris@49
|
43 }
|
Chris@49
|
44
|
Chris@49
|
45 return status;
|
Chris@49
|
46 }
|
Chris@49
|
47
|
Chris@49
|
48
|
Chris@49
|
49
|
Chris@49
|
50 template<typename eT>
|
Chris@49
|
51 inline
|
Chris@49
|
52 bool
|
Chris@49
|
53 auxlib::inv(Mat<eT>& out, const Mat<eT>& X, const bool slow)
|
Chris@49
|
54 {
|
Chris@49
|
55 arma_extra_debug_sigprint();
|
Chris@49
|
56
|
Chris@49
|
57 arma_debug_check( (X.is_square() == false), "inv(): given matrix is not square" );
|
Chris@49
|
58
|
Chris@49
|
59 bool status = false;
|
Chris@49
|
60
|
Chris@49
|
61 const uword N = X.n_rows;
|
Chris@49
|
62
|
Chris@49
|
63 if( (N <= 4) && (slow == false) )
|
Chris@49
|
64 {
|
Chris@49
|
65 status = (&out != &X) ? auxlib::inv_noalias_tinymat(out, X, N) : auxlib::inv_inplace_tinymat(out, N);
|
Chris@49
|
66 }
|
Chris@49
|
67
|
Chris@49
|
68 if( (N > 4) || (status == false) )
|
Chris@49
|
69 {
|
Chris@49
|
70 out = X;
|
Chris@49
|
71 status = auxlib::inv_inplace_lapack(out);
|
Chris@49
|
72 }
|
Chris@49
|
73
|
Chris@49
|
74 return status;
|
Chris@49
|
75 }
|
Chris@49
|
76
|
Chris@49
|
77
|
Chris@49
|
78
|
Chris@49
|
79 template<typename eT>
|
Chris@49
|
80 inline
|
Chris@49
|
81 bool
|
Chris@49
|
82 auxlib::inv_noalias_tinymat(Mat<eT>& out, const Mat<eT>& X, const uword N)
|
Chris@49
|
83 {
|
Chris@49
|
84 arma_extra_debug_sigprint();
|
Chris@49
|
85
|
Chris@49
|
86 bool det_ok = true;
|
Chris@49
|
87
|
Chris@49
|
88 out.set_size(N,N);
|
Chris@49
|
89
|
Chris@49
|
90 switch(N)
|
Chris@49
|
91 {
|
Chris@49
|
92 case 1:
|
Chris@49
|
93 {
|
Chris@49
|
94 out[0] = eT(1) / X[0];
|
Chris@49
|
95 };
|
Chris@49
|
96 break;
|
Chris@49
|
97
|
Chris@49
|
98 case 2:
|
Chris@49
|
99 {
|
Chris@49
|
100 const eT* Xm = X.memptr();
|
Chris@49
|
101
|
Chris@49
|
102 const eT a = Xm[pos<0,0>::n2];
|
Chris@49
|
103 const eT b = Xm[pos<0,1>::n2];
|
Chris@49
|
104 const eT c = Xm[pos<1,0>::n2];
|
Chris@49
|
105 const eT d = Xm[pos<1,1>::n2];
|
Chris@49
|
106
|
Chris@49
|
107 const eT tmp_det = (a*d - b*c);
|
Chris@49
|
108
|
Chris@49
|
109 if(tmp_det != eT(0))
|
Chris@49
|
110 {
|
Chris@49
|
111 eT* outm = out.memptr();
|
Chris@49
|
112
|
Chris@49
|
113 outm[pos<0,0>::n2] = d / tmp_det;
|
Chris@49
|
114 outm[pos<0,1>::n2] = -b / tmp_det;
|
Chris@49
|
115 outm[pos<1,0>::n2] = -c / tmp_det;
|
Chris@49
|
116 outm[pos<1,1>::n2] = a / tmp_det;
|
Chris@49
|
117 }
|
Chris@49
|
118 else
|
Chris@49
|
119 {
|
Chris@49
|
120 det_ok = false;
|
Chris@49
|
121 }
|
Chris@49
|
122 };
|
Chris@49
|
123 break;
|
Chris@49
|
124
|
Chris@49
|
125 case 3:
|
Chris@49
|
126 {
|
Chris@49
|
127 const eT* X_col0 = X.colptr(0);
|
Chris@49
|
128 const eT a11 = X_col0[0];
|
Chris@49
|
129 const eT a21 = X_col0[1];
|
Chris@49
|
130 const eT a31 = X_col0[2];
|
Chris@49
|
131
|
Chris@49
|
132 const eT* X_col1 = X.colptr(1);
|
Chris@49
|
133 const eT a12 = X_col1[0];
|
Chris@49
|
134 const eT a22 = X_col1[1];
|
Chris@49
|
135 const eT a32 = X_col1[2];
|
Chris@49
|
136
|
Chris@49
|
137 const eT* X_col2 = X.colptr(2);
|
Chris@49
|
138 const eT a13 = X_col2[0];
|
Chris@49
|
139 const eT a23 = X_col2[1];
|
Chris@49
|
140 const eT a33 = X_col2[2];
|
Chris@49
|
141
|
Chris@49
|
142 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13);
|
Chris@49
|
143
|
Chris@49
|
144 if(tmp_det != eT(0))
|
Chris@49
|
145 {
|
Chris@49
|
146 eT* out_col0 = out.colptr(0);
|
Chris@49
|
147 out_col0[0] = (a33*a22 - a32*a23) / tmp_det;
|
Chris@49
|
148 out_col0[1] = -(a33*a21 - a31*a23) / tmp_det;
|
Chris@49
|
149 out_col0[2] = (a32*a21 - a31*a22) / tmp_det;
|
Chris@49
|
150
|
Chris@49
|
151 eT* out_col1 = out.colptr(1);
|
Chris@49
|
152 out_col1[0] = -(a33*a12 - a32*a13) / tmp_det;
|
Chris@49
|
153 out_col1[1] = (a33*a11 - a31*a13) / tmp_det;
|
Chris@49
|
154 out_col1[2] = -(a32*a11 - a31*a12) / tmp_det;
|
Chris@49
|
155
|
Chris@49
|
156 eT* out_col2 = out.colptr(2);
|
Chris@49
|
157 out_col2[0] = (a23*a12 - a22*a13) / tmp_det;
|
Chris@49
|
158 out_col2[1] = -(a23*a11 - a21*a13) / tmp_det;
|
Chris@49
|
159 out_col2[2] = (a22*a11 - a21*a12) / tmp_det;
|
Chris@49
|
160 }
|
Chris@49
|
161 else
|
Chris@49
|
162 {
|
Chris@49
|
163 det_ok = false;
|
Chris@49
|
164 }
|
Chris@49
|
165 };
|
Chris@49
|
166 break;
|
Chris@49
|
167
|
Chris@49
|
168 case 4:
|
Chris@49
|
169 {
|
Chris@49
|
170 const eT tmp_det = det(X);
|
Chris@49
|
171
|
Chris@49
|
172 if(tmp_det != eT(0))
|
Chris@49
|
173 {
|
Chris@49
|
174 const eT* Xm = X.memptr();
|
Chris@49
|
175 eT* outm = out.memptr();
|
Chris@49
|
176
|
Chris@49
|
177 outm[pos<0,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
178 outm[pos<1,0>::n4] = ( Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
179 outm[pos<2,0>::n4] = ( Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
180 outm[pos<3,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
|
Chris@49
|
181
|
Chris@49
|
182 outm[pos<0,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
183 outm[pos<1,1>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
184 outm[pos<2,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
185 outm[pos<3,1>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
|
Chris@49
|
186
|
Chris@49
|
187 outm[pos<0,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
188 outm[pos<1,2>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
189 outm[pos<2,2>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
190 outm[pos<3,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
|
Chris@49
|
191
|
Chris@49
|
192 outm[pos<0,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
|
Chris@49
|
193 outm[pos<1,3>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
|
Chris@49
|
194 outm[pos<2,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
|
Chris@49
|
195 outm[pos<3,3>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] ) / tmp_det;
|
Chris@49
|
196 }
|
Chris@49
|
197 else
|
Chris@49
|
198 {
|
Chris@49
|
199 det_ok = false;
|
Chris@49
|
200 }
|
Chris@49
|
201 };
|
Chris@49
|
202 break;
|
Chris@49
|
203
|
Chris@49
|
204 default:
|
Chris@49
|
205 ;
|
Chris@49
|
206 }
|
Chris@49
|
207
|
Chris@49
|
208 return det_ok;
|
Chris@49
|
209 }
|
Chris@49
|
210
|
Chris@49
|
211
|
Chris@49
|
212
|
Chris@49
|
213 template<typename eT>
|
Chris@49
|
214 inline
|
Chris@49
|
215 bool
|
Chris@49
|
216 auxlib::inv_inplace_tinymat(Mat<eT>& X, const uword N)
|
Chris@49
|
217 {
|
Chris@49
|
218 arma_extra_debug_sigprint();
|
Chris@49
|
219
|
Chris@49
|
220 bool det_ok = true;
|
Chris@49
|
221
|
Chris@49
|
222 // for more info, see:
|
Chris@49
|
223 // http://www.dr-lex.34sp.com/random/matrix_inv.html
|
Chris@49
|
224 // http://www.cvl.iis.u-tokyo.ac.jp/~miyazaki/tech/teche23.html
|
Chris@49
|
225 // http://www.euclideanspace.com/maths/algebra/matrix/functions/inverse/fourD/index.htm
|
Chris@49
|
226 // http://www.geometrictools.com//LibFoundation/Mathematics/Wm4Matrix4.inl
|
Chris@49
|
227
|
Chris@49
|
228 switch(N)
|
Chris@49
|
229 {
|
Chris@49
|
230 case 1:
|
Chris@49
|
231 {
|
Chris@49
|
232 X[0] = eT(1) / X[0];
|
Chris@49
|
233 };
|
Chris@49
|
234 break;
|
Chris@49
|
235
|
Chris@49
|
236 case 2:
|
Chris@49
|
237 {
|
Chris@49
|
238 const eT a = X[pos<0,0>::n2];
|
Chris@49
|
239 const eT b = X[pos<0,1>::n2];
|
Chris@49
|
240 const eT c = X[pos<1,0>::n2];
|
Chris@49
|
241 const eT d = X[pos<1,1>::n2];
|
Chris@49
|
242
|
Chris@49
|
243 const eT tmp_det = (a*d - b*c);
|
Chris@49
|
244
|
Chris@49
|
245 if(tmp_det != eT(0))
|
Chris@49
|
246 {
|
Chris@49
|
247 X[pos<0,0>::n2] = d / tmp_det;
|
Chris@49
|
248 X[pos<0,1>::n2] = -b / tmp_det;
|
Chris@49
|
249 X[pos<1,0>::n2] = -c / tmp_det;
|
Chris@49
|
250 X[pos<1,1>::n2] = a / tmp_det;
|
Chris@49
|
251 }
|
Chris@49
|
252 else
|
Chris@49
|
253 {
|
Chris@49
|
254 det_ok = false;
|
Chris@49
|
255 }
|
Chris@49
|
256 };
|
Chris@49
|
257 break;
|
Chris@49
|
258
|
Chris@49
|
259 case 3:
|
Chris@49
|
260 {
|
Chris@49
|
261 eT* X_col0 = X.colptr(0);
|
Chris@49
|
262 eT* X_col1 = X.colptr(1);
|
Chris@49
|
263 eT* X_col2 = X.colptr(2);
|
Chris@49
|
264
|
Chris@49
|
265 const eT a11 = X_col0[0];
|
Chris@49
|
266 const eT a21 = X_col0[1];
|
Chris@49
|
267 const eT a31 = X_col0[2];
|
Chris@49
|
268
|
Chris@49
|
269 const eT a12 = X_col1[0];
|
Chris@49
|
270 const eT a22 = X_col1[1];
|
Chris@49
|
271 const eT a32 = X_col1[2];
|
Chris@49
|
272
|
Chris@49
|
273 const eT a13 = X_col2[0];
|
Chris@49
|
274 const eT a23 = X_col2[1];
|
Chris@49
|
275 const eT a33 = X_col2[2];
|
Chris@49
|
276
|
Chris@49
|
277 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13);
|
Chris@49
|
278
|
Chris@49
|
279 if(tmp_det != eT(0))
|
Chris@49
|
280 {
|
Chris@49
|
281 X_col0[0] = (a33*a22 - a32*a23) / tmp_det;
|
Chris@49
|
282 X_col0[1] = -(a33*a21 - a31*a23) / tmp_det;
|
Chris@49
|
283 X_col0[2] = (a32*a21 - a31*a22) / tmp_det;
|
Chris@49
|
284
|
Chris@49
|
285 X_col1[0] = -(a33*a12 - a32*a13) / tmp_det;
|
Chris@49
|
286 X_col1[1] = (a33*a11 - a31*a13) / tmp_det;
|
Chris@49
|
287 X_col1[2] = -(a32*a11 - a31*a12) / tmp_det;
|
Chris@49
|
288
|
Chris@49
|
289 X_col2[0] = (a23*a12 - a22*a13) / tmp_det;
|
Chris@49
|
290 X_col2[1] = -(a23*a11 - a21*a13) / tmp_det;
|
Chris@49
|
291 X_col2[2] = (a22*a11 - a21*a12) / tmp_det;
|
Chris@49
|
292 }
|
Chris@49
|
293 else
|
Chris@49
|
294 {
|
Chris@49
|
295 det_ok = false;
|
Chris@49
|
296 }
|
Chris@49
|
297 };
|
Chris@49
|
298 break;
|
Chris@49
|
299
|
Chris@49
|
300 case 4:
|
Chris@49
|
301 {
|
Chris@49
|
302 const eT tmp_det = det(X);
|
Chris@49
|
303
|
Chris@49
|
304 if(tmp_det != eT(0))
|
Chris@49
|
305 {
|
Chris@49
|
306 const Mat<eT> A(X);
|
Chris@49
|
307
|
Chris@49
|
308 const eT* Am = A.memptr();
|
Chris@49
|
309 eT* Xm = X.memptr();
|
Chris@49
|
310
|
Chris@49
|
311 Xm[pos<0,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] - Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] + Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
312 Xm[pos<1,0>::n4] = ( Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
313 Xm[pos<2,0>::n4] = ( Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
314 Xm[pos<3,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
|
Chris@49
|
315
|
Chris@49
|
316 Xm[pos<0,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] + Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] - Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
317 Xm[pos<1,1>::n4] = ( Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
318 Xm[pos<2,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
319 Xm[pos<3,1>::n4] = ( Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
|
Chris@49
|
320
|
Chris@49
|
321 Xm[pos<0,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] + Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
322 Xm[pos<1,2>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
323 Xm[pos<2,2>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
Chris@49
|
324 Xm[pos<3,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
|
Chris@49
|
325
|
Chris@49
|
326 Xm[pos<0,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] + Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
|
Chris@49
|
327 Xm[pos<1,3>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
|
Chris@49
|
328 Xm[pos<2,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
|
Chris@49
|
329 Xm[pos<3,3>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] ) / tmp_det;
|
Chris@49
|
330 }
|
Chris@49
|
331 else
|
Chris@49
|
332 {
|
Chris@49
|
333 det_ok = false;
|
Chris@49
|
334 }
|
Chris@49
|
335 };
|
Chris@49
|
336 break;
|
Chris@49
|
337
|
Chris@49
|
338 default:
|
Chris@49
|
339 ;
|
Chris@49
|
340 }
|
Chris@49
|
341
|
Chris@49
|
342 return det_ok;
|
Chris@49
|
343 }
|
Chris@49
|
344
|
Chris@49
|
345
|
Chris@49
|
346
|
Chris@49
|
347 template<typename eT>
|
Chris@49
|
348 inline
|
Chris@49
|
349 bool
|
Chris@49
|
350 auxlib::inv_inplace_lapack(Mat<eT>& out)
|
Chris@49
|
351 {
|
Chris@49
|
352 arma_extra_debug_sigprint();
|
Chris@49
|
353
|
Chris@49
|
354 if(out.is_empty())
|
Chris@49
|
355 {
|
Chris@49
|
356 return true;
|
Chris@49
|
357 }
|
Chris@49
|
358
|
Chris@49
|
359 #if defined(ARMA_USE_ATLAS)
|
Chris@49
|
360 {
|
Chris@49
|
361 podarray<int> ipiv(out.n_rows);
|
Chris@49
|
362
|
Chris@49
|
363 int info = atlas::clapack_getrf(atlas::CblasColMajor, out.n_rows, out.n_cols, out.memptr(), out.n_rows, ipiv.memptr());
|
Chris@49
|
364
|
Chris@49
|
365 if(info == 0)
|
Chris@49
|
366 {
|
Chris@49
|
367 info = atlas::clapack_getri(atlas::CblasColMajor, out.n_rows, out.memptr(), out.n_rows, ipiv.memptr());
|
Chris@49
|
368 }
|
Chris@49
|
369
|
Chris@49
|
370 return (info == 0);
|
Chris@49
|
371 }
|
Chris@49
|
372 #elif defined(ARMA_USE_LAPACK)
|
Chris@49
|
373 {
|
Chris@49
|
374 blas_int n_rows = out.n_rows;
|
Chris@49
|
375 blas_int n_cols = out.n_cols;
|
Chris@49
|
376 blas_int lwork = 0;
|
Chris@49
|
377 blas_int lwork_min = (std::max)(blas_int(1), n_rows);
|
Chris@49
|
378 blas_int info = 0;
|
Chris@49
|
379
|
Chris@49
|
380 podarray<blas_int> ipiv(out.n_rows);
|
Chris@49
|
381
|
Chris@49
|
382 eT work_query[2];
|
Chris@49
|
383 blas_int lwork_query = -1;
|
Chris@49
|
384
|
Chris@49
|
385 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), &work_query[0], &lwork_query, &info);
|
Chris@49
|
386
|
Chris@49
|
387 if(info == 0)
|
Chris@49
|
388 {
|
Chris@49
|
389 const blas_int lwork_proposed = static_cast<blas_int>( access::tmp_real(work_query[0]) );
|
Chris@49
|
390
|
Chris@49
|
391 lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min;
|
Chris@49
|
392 }
|
Chris@49
|
393 else
|
Chris@49
|
394 {
|
Chris@49
|
395 return false;
|
Chris@49
|
396 }
|
Chris@49
|
397
|
Chris@49
|
398 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
399
|
Chris@49
|
400 lapack::getrf(&n_rows, &n_cols, out.memptr(), &n_rows, ipiv.memptr(), &info);
|
Chris@49
|
401
|
Chris@49
|
402 if(info == 0)
|
Chris@49
|
403 {
|
Chris@49
|
404 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &lwork, &info);
|
Chris@49
|
405 }
|
Chris@49
|
406
|
Chris@49
|
407 return (info == 0);
|
Chris@49
|
408 }
|
Chris@49
|
409 #else
|
Chris@49
|
410 {
|
Chris@49
|
411 arma_stop("inv(): use of ATLAS or LAPACK needs to be enabled");
|
Chris@49
|
412 return false;
|
Chris@49
|
413 }
|
Chris@49
|
414 #endif
|
Chris@49
|
415 }
|
Chris@49
|
416
|
Chris@49
|
417
|
Chris@49
|
418
|
Chris@49
|
419 template<typename eT, typename T1>
|
Chris@49
|
420 inline
|
Chris@49
|
421 bool
|
Chris@49
|
422 auxlib::inv_tr(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
|
Chris@49
|
423 {
|
Chris@49
|
424 arma_extra_debug_sigprint();
|
Chris@49
|
425
|
Chris@49
|
426 out = X.get_ref();
|
Chris@49
|
427
|
Chris@49
|
428 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
|
Chris@49
|
429
|
Chris@49
|
430 if(out.is_empty())
|
Chris@49
|
431 {
|
Chris@49
|
432 return true;
|
Chris@49
|
433 }
|
Chris@49
|
434
|
Chris@49
|
435 bool status;
|
Chris@49
|
436
|
Chris@49
|
437 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
438 {
|
Chris@49
|
439 char uplo = (layout == 0) ? 'U' : 'L';
|
Chris@49
|
440 char diag = 'N';
|
Chris@49
|
441 blas_int n = blas_int(out.n_rows);
|
Chris@49
|
442 blas_int info = 0;
|
Chris@49
|
443
|
Chris@49
|
444 lapack::trtri(&uplo, &diag, &n, out.memptr(), &n, &info);
|
Chris@49
|
445
|
Chris@49
|
446 status = (info == 0);
|
Chris@49
|
447 }
|
Chris@49
|
448 #else
|
Chris@49
|
449 {
|
Chris@49
|
450 arma_ignore(layout);
|
Chris@49
|
451 arma_stop("inv(): use of LAPACK needs to be enabled");
|
Chris@49
|
452 status = false;
|
Chris@49
|
453 }
|
Chris@49
|
454 #endif
|
Chris@49
|
455
|
Chris@49
|
456
|
Chris@49
|
457 if(status == true)
|
Chris@49
|
458 {
|
Chris@49
|
459 if(layout == 0)
|
Chris@49
|
460 {
|
Chris@49
|
461 // upper triangular
|
Chris@49
|
462 out = trimatu(out);
|
Chris@49
|
463 }
|
Chris@49
|
464 else
|
Chris@49
|
465 {
|
Chris@49
|
466 // lower triangular
|
Chris@49
|
467 out = trimatl(out);
|
Chris@49
|
468 }
|
Chris@49
|
469 }
|
Chris@49
|
470
|
Chris@49
|
471 return status;
|
Chris@49
|
472 }
|
Chris@49
|
473
|
Chris@49
|
474
|
Chris@49
|
475
|
Chris@49
|
476 template<typename eT, typename T1>
|
Chris@49
|
477 inline
|
Chris@49
|
478 bool
|
Chris@49
|
479 auxlib::inv_sym(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
|
Chris@49
|
480 {
|
Chris@49
|
481 arma_extra_debug_sigprint();
|
Chris@49
|
482
|
Chris@49
|
483 out = X.get_ref();
|
Chris@49
|
484
|
Chris@49
|
485 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
|
Chris@49
|
486
|
Chris@49
|
487 if(out.is_empty())
|
Chris@49
|
488 {
|
Chris@49
|
489 return true;
|
Chris@49
|
490 }
|
Chris@49
|
491
|
Chris@49
|
492 bool status;
|
Chris@49
|
493
|
Chris@49
|
494 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
495 {
|
Chris@49
|
496 char uplo = (layout == 0) ? 'U' : 'L';
|
Chris@49
|
497 blas_int n = blas_int(out.n_rows);
|
Chris@49
|
498 blas_int lwork = 3 * (n*n); // TODO: use lwork = -1 to determine optimal size
|
Chris@49
|
499 blas_int info = 0;
|
Chris@49
|
500
|
Chris@49
|
501 podarray<blas_int> ipiv;
|
Chris@49
|
502 ipiv.set_size(out.n_rows);
|
Chris@49
|
503
|
Chris@49
|
504 podarray<eT> work;
|
Chris@49
|
505 work.set_size( uword(lwork) );
|
Chris@49
|
506
|
Chris@49
|
507 lapack::sytrf(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &lwork, &info);
|
Chris@49
|
508
|
Chris@49
|
509 status = (info == 0);
|
Chris@49
|
510
|
Chris@49
|
511 if(status == true)
|
Chris@49
|
512 {
|
Chris@49
|
513 lapack::sytri(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &info);
|
Chris@49
|
514
|
Chris@49
|
515 out = (layout == 0) ? symmatu(out) : symmatl(out);
|
Chris@49
|
516
|
Chris@49
|
517 status = (info == 0);
|
Chris@49
|
518 }
|
Chris@49
|
519 }
|
Chris@49
|
520 #else
|
Chris@49
|
521 {
|
Chris@49
|
522 arma_ignore(layout);
|
Chris@49
|
523 arma_stop("inv(): use of LAPACK needs to be enabled");
|
Chris@49
|
524 status = false;
|
Chris@49
|
525 }
|
Chris@49
|
526 #endif
|
Chris@49
|
527
|
Chris@49
|
528 return status;
|
Chris@49
|
529 }
|
Chris@49
|
530
|
Chris@49
|
531
|
Chris@49
|
532
|
Chris@49
|
533 template<typename eT, typename T1>
|
Chris@49
|
534 inline
|
Chris@49
|
535 bool
|
Chris@49
|
536 auxlib::inv_sympd(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
|
Chris@49
|
537 {
|
Chris@49
|
538 arma_extra_debug_sigprint();
|
Chris@49
|
539
|
Chris@49
|
540 out = X.get_ref();
|
Chris@49
|
541
|
Chris@49
|
542 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
|
Chris@49
|
543
|
Chris@49
|
544 if(out.is_empty())
|
Chris@49
|
545 {
|
Chris@49
|
546 return true;
|
Chris@49
|
547 }
|
Chris@49
|
548
|
Chris@49
|
549 bool status;
|
Chris@49
|
550
|
Chris@49
|
551 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
552 {
|
Chris@49
|
553 char uplo = (layout == 0) ? 'U' : 'L';
|
Chris@49
|
554 blas_int n = blas_int(out.n_rows);
|
Chris@49
|
555 blas_int info = 0;
|
Chris@49
|
556
|
Chris@49
|
557 lapack::potrf(&uplo, &n, out.memptr(), &n, &info);
|
Chris@49
|
558
|
Chris@49
|
559 status = (info == 0);
|
Chris@49
|
560
|
Chris@49
|
561 if(status == true)
|
Chris@49
|
562 {
|
Chris@49
|
563 lapack::potri(&uplo, &n, out.memptr(), &n, &info);
|
Chris@49
|
564
|
Chris@49
|
565 out = (layout == 0) ? symmatu(out) : symmatl(out);
|
Chris@49
|
566
|
Chris@49
|
567 status = (info == 0);
|
Chris@49
|
568 }
|
Chris@49
|
569 }
|
Chris@49
|
570 #else
|
Chris@49
|
571 {
|
Chris@49
|
572 arma_ignore(layout);
|
Chris@49
|
573 arma_stop("inv(): use of LAPACK needs to be enabled");
|
Chris@49
|
574 status = false;
|
Chris@49
|
575 }
|
Chris@49
|
576 #endif
|
Chris@49
|
577
|
Chris@49
|
578 return status;
|
Chris@49
|
579 }
|
Chris@49
|
580
|
Chris@49
|
581
|
Chris@49
|
582
|
Chris@49
|
583 template<typename eT, typename T1>
|
Chris@49
|
584 inline
|
Chris@49
|
585 eT
|
Chris@49
|
586 auxlib::det(const Base<eT,T1>& X, const bool slow)
|
Chris@49
|
587 {
|
Chris@49
|
588 const unwrap<T1> tmp(X.get_ref());
|
Chris@49
|
589 const Mat<eT>& A = tmp.M;
|
Chris@49
|
590
|
Chris@49
|
591 arma_debug_check( (A.is_square() == false), "det(): matrix is not square" );
|
Chris@49
|
592
|
Chris@49
|
593 const bool make_copy = (is_Mat<T1>::value == true) ? true : false;
|
Chris@49
|
594
|
Chris@49
|
595 if(slow == false)
|
Chris@49
|
596 {
|
Chris@49
|
597 const uword N = A.n_rows;
|
Chris@49
|
598
|
Chris@49
|
599 switch(N)
|
Chris@49
|
600 {
|
Chris@49
|
601 case 0:
|
Chris@49
|
602 case 1:
|
Chris@49
|
603 case 2:
|
Chris@49
|
604 return auxlib::det_tinymat(A, N);
|
Chris@49
|
605 break;
|
Chris@49
|
606
|
Chris@49
|
607 case 3:
|
Chris@49
|
608 case 4:
|
Chris@49
|
609 {
|
Chris@49
|
610 const eT tmp_det = auxlib::det_tinymat(A, N);
|
Chris@49
|
611 return (tmp_det != eT(0)) ? tmp_det : auxlib::det_lapack(A, make_copy);
|
Chris@49
|
612 }
|
Chris@49
|
613 break;
|
Chris@49
|
614
|
Chris@49
|
615 default:
|
Chris@49
|
616 return auxlib::det_lapack(A, make_copy);
|
Chris@49
|
617 }
|
Chris@49
|
618 }
|
Chris@49
|
619
|
Chris@49
|
620 return auxlib::det_lapack(A, make_copy);
|
Chris@49
|
621 }
|
Chris@49
|
622
|
Chris@49
|
623
|
Chris@49
|
624
|
Chris@49
|
625 template<typename eT>
|
Chris@49
|
626 inline
|
Chris@49
|
627 eT
|
Chris@49
|
628 auxlib::det_tinymat(const Mat<eT>& X, const uword N)
|
Chris@49
|
629 {
|
Chris@49
|
630 arma_extra_debug_sigprint();
|
Chris@49
|
631
|
Chris@49
|
632 switch(N)
|
Chris@49
|
633 {
|
Chris@49
|
634 case 0:
|
Chris@49
|
635 return eT(1);
|
Chris@49
|
636 break;
|
Chris@49
|
637
|
Chris@49
|
638 case 1:
|
Chris@49
|
639 return X[0];
|
Chris@49
|
640 break;
|
Chris@49
|
641
|
Chris@49
|
642 case 2:
|
Chris@49
|
643 {
|
Chris@49
|
644 const eT* Xm = X.memptr();
|
Chris@49
|
645
|
Chris@49
|
646 return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] );
|
Chris@49
|
647 }
|
Chris@49
|
648 break;
|
Chris@49
|
649
|
Chris@49
|
650 case 3:
|
Chris@49
|
651 {
|
Chris@49
|
652 // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2);
|
Chris@49
|
653 // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0);
|
Chris@49
|
654 // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1);
|
Chris@49
|
655 // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2);
|
Chris@49
|
656 // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0);
|
Chris@49
|
657 // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1);
|
Chris@49
|
658 // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6);
|
Chris@49
|
659
|
Chris@49
|
660 const eT* a_col0 = X.colptr(0);
|
Chris@49
|
661 const eT a11 = a_col0[0];
|
Chris@49
|
662 const eT a21 = a_col0[1];
|
Chris@49
|
663 const eT a31 = a_col0[2];
|
Chris@49
|
664
|
Chris@49
|
665 const eT* a_col1 = X.colptr(1);
|
Chris@49
|
666 const eT a12 = a_col1[0];
|
Chris@49
|
667 const eT a22 = a_col1[1];
|
Chris@49
|
668 const eT a32 = a_col1[2];
|
Chris@49
|
669
|
Chris@49
|
670 const eT* a_col2 = X.colptr(2);
|
Chris@49
|
671 const eT a13 = a_col2[0];
|
Chris@49
|
672 const eT a23 = a_col2[1];
|
Chris@49
|
673 const eT a33 = a_col2[2];
|
Chris@49
|
674
|
Chris@49
|
675 return ( a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13) );
|
Chris@49
|
676 }
|
Chris@49
|
677 break;
|
Chris@49
|
678
|
Chris@49
|
679 case 4:
|
Chris@49
|
680 {
|
Chris@49
|
681 const eT* Xm = X.memptr();
|
Chris@49
|
682
|
Chris@49
|
683 const eT val = \
|
Chris@49
|
684 Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \
|
Chris@49
|
685 - Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \
|
Chris@49
|
686 - Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \
|
Chris@49
|
687 + Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \
|
Chris@49
|
688 + Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \
|
Chris@49
|
689 - Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \
|
Chris@49
|
690 - Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \
|
Chris@49
|
691 + Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \
|
Chris@49
|
692 + Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \
|
Chris@49
|
693 - Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \
|
Chris@49
|
694 - Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \
|
Chris@49
|
695 + Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \
|
Chris@49
|
696 + Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \
|
Chris@49
|
697 - Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \
|
Chris@49
|
698 - Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \
|
Chris@49
|
699 + Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \
|
Chris@49
|
700 + Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \
|
Chris@49
|
701 - Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \
|
Chris@49
|
702 - Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \
|
Chris@49
|
703 + Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \
|
Chris@49
|
704 + Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \
|
Chris@49
|
705 - Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \
|
Chris@49
|
706 - Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \
|
Chris@49
|
707 + Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \
|
Chris@49
|
708 ;
|
Chris@49
|
709
|
Chris@49
|
710 return val;
|
Chris@49
|
711 }
|
Chris@49
|
712 break;
|
Chris@49
|
713
|
Chris@49
|
714 default:
|
Chris@49
|
715 return eT(0);
|
Chris@49
|
716 ;
|
Chris@49
|
717 }
|
Chris@49
|
718 }
|
Chris@49
|
719
|
Chris@49
|
720
|
Chris@49
|
721
|
Chris@49
|
722 //! immediate determinant of a matrix using ATLAS or LAPACK
|
Chris@49
|
723 template<typename eT>
|
Chris@49
|
724 inline
|
Chris@49
|
725 eT
|
Chris@49
|
726 auxlib::det_lapack(const Mat<eT>& X, const bool make_copy)
|
Chris@49
|
727 {
|
Chris@49
|
728 arma_extra_debug_sigprint();
|
Chris@49
|
729
|
Chris@49
|
730 Mat<eT> X_copy;
|
Chris@49
|
731
|
Chris@49
|
732 if(make_copy == true)
|
Chris@49
|
733 {
|
Chris@49
|
734 X_copy = X;
|
Chris@49
|
735 }
|
Chris@49
|
736
|
Chris@49
|
737 Mat<eT>& tmp = (make_copy == true) ? X_copy : const_cast< Mat<eT>& >(X);
|
Chris@49
|
738
|
Chris@49
|
739 if(tmp.is_empty())
|
Chris@49
|
740 {
|
Chris@49
|
741 return eT(1);
|
Chris@49
|
742 }
|
Chris@49
|
743
|
Chris@49
|
744
|
Chris@49
|
745 #if defined(ARMA_USE_ATLAS)
|
Chris@49
|
746 {
|
Chris@49
|
747 podarray<int> ipiv(tmp.n_rows);
|
Chris@49
|
748
|
Chris@49
|
749 //const int info =
|
Chris@49
|
750 atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr());
|
Chris@49
|
751
|
Chris@49
|
752 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
|
Chris@49
|
753 eT val = tmp.at(0,0);
|
Chris@49
|
754 for(uword i=1; i < tmp.n_rows; ++i)
|
Chris@49
|
755 {
|
Chris@49
|
756 val *= tmp.at(i,i);
|
Chris@49
|
757 }
|
Chris@49
|
758
|
Chris@49
|
759 int sign = +1;
|
Chris@49
|
760 for(uword i=0; i < tmp.n_rows; ++i)
|
Chris@49
|
761 {
|
Chris@49
|
762 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0
|
Chris@49
|
763 {
|
Chris@49
|
764 sign *= -1;
|
Chris@49
|
765 }
|
Chris@49
|
766 }
|
Chris@49
|
767
|
Chris@49
|
768 return ( (sign < 0) ? -val : val );
|
Chris@49
|
769 }
|
Chris@49
|
770 #elif defined(ARMA_USE_LAPACK)
|
Chris@49
|
771 {
|
Chris@49
|
772 podarray<blas_int> ipiv(tmp.n_rows);
|
Chris@49
|
773
|
Chris@49
|
774 blas_int info = 0;
|
Chris@49
|
775 blas_int n_rows = blas_int(tmp.n_rows);
|
Chris@49
|
776 blas_int n_cols = blas_int(tmp.n_cols);
|
Chris@49
|
777
|
Chris@49
|
778 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info);
|
Chris@49
|
779
|
Chris@49
|
780 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
|
Chris@49
|
781 eT val = tmp.at(0,0);
|
Chris@49
|
782 for(uword i=1; i < tmp.n_rows; ++i)
|
Chris@49
|
783 {
|
Chris@49
|
784 val *= tmp.at(i,i);
|
Chris@49
|
785 }
|
Chris@49
|
786
|
Chris@49
|
787 blas_int sign = +1;
|
Chris@49
|
788 for(uword i=0; i < tmp.n_rows; ++i)
|
Chris@49
|
789 {
|
Chris@49
|
790 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1
|
Chris@49
|
791 {
|
Chris@49
|
792 sign *= -1;
|
Chris@49
|
793 }
|
Chris@49
|
794 }
|
Chris@49
|
795
|
Chris@49
|
796 return ( (sign < 0) ? -val : val );
|
Chris@49
|
797 }
|
Chris@49
|
798 #else
|
Chris@49
|
799 {
|
Chris@49
|
800 arma_ignore(X);
|
Chris@49
|
801 arma_ignore(make_copy);
|
Chris@49
|
802 arma_ignore(tmp);
|
Chris@49
|
803 arma_stop("det(): use of ATLAS or LAPACK needs to be enabled");
|
Chris@49
|
804 return eT(0);
|
Chris@49
|
805 }
|
Chris@49
|
806 #endif
|
Chris@49
|
807 }
|
Chris@49
|
808
|
Chris@49
|
809
|
Chris@49
|
810
|
Chris@49
|
811 //! immediate log determinant of a matrix using ATLAS or LAPACK
|
Chris@49
|
812 template<typename eT, typename T1>
|
Chris@49
|
813 inline
|
Chris@49
|
814 bool
|
Chris@49
|
815 auxlib::log_det(eT& out_val, typename get_pod_type<eT>::result& out_sign, const Base<eT,T1>& X)
|
Chris@49
|
816 {
|
Chris@49
|
817 arma_extra_debug_sigprint();
|
Chris@49
|
818
|
Chris@49
|
819 typedef typename get_pod_type<eT>::result T;
|
Chris@49
|
820
|
Chris@49
|
821 #if defined(ARMA_USE_ATLAS)
|
Chris@49
|
822 {
|
Chris@49
|
823 Mat<eT> tmp(X.get_ref());
|
Chris@49
|
824 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" );
|
Chris@49
|
825
|
Chris@49
|
826 if(tmp.is_empty())
|
Chris@49
|
827 {
|
Chris@49
|
828 out_val = eT(0);
|
Chris@49
|
829 out_sign = T(1);
|
Chris@49
|
830 return true;
|
Chris@49
|
831 }
|
Chris@49
|
832
|
Chris@49
|
833 podarray<int> ipiv(tmp.n_rows);
|
Chris@49
|
834
|
Chris@49
|
835 const int info = atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr());
|
Chris@49
|
836
|
Chris@49
|
837 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
|
Chris@49
|
838
|
Chris@49
|
839 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1;
|
Chris@49
|
840 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) );
|
Chris@49
|
841
|
Chris@49
|
842 for(uword i=1; i < tmp.n_rows; ++i)
|
Chris@49
|
843 {
|
Chris@49
|
844 const eT x = tmp.at(i,i);
|
Chris@49
|
845
|
Chris@49
|
846 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1;
|
Chris@49
|
847 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
|
Chris@49
|
848 }
|
Chris@49
|
849
|
Chris@49
|
850 for(uword i=0; i < tmp.n_rows; ++i)
|
Chris@49
|
851 {
|
Chris@49
|
852 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0
|
Chris@49
|
853 {
|
Chris@49
|
854 sign *= -1;
|
Chris@49
|
855 }
|
Chris@49
|
856 }
|
Chris@49
|
857
|
Chris@49
|
858 out_val = val;
|
Chris@49
|
859 out_sign = T(sign);
|
Chris@49
|
860
|
Chris@49
|
861 return (info == 0);
|
Chris@49
|
862 }
|
Chris@49
|
863 #elif defined(ARMA_USE_LAPACK)
|
Chris@49
|
864 {
|
Chris@49
|
865 Mat<eT> tmp(X.get_ref());
|
Chris@49
|
866 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" );
|
Chris@49
|
867
|
Chris@49
|
868 if(tmp.is_empty())
|
Chris@49
|
869 {
|
Chris@49
|
870 out_val = eT(0);
|
Chris@49
|
871 out_sign = T(1);
|
Chris@49
|
872 return true;
|
Chris@49
|
873 }
|
Chris@49
|
874
|
Chris@49
|
875 podarray<blas_int> ipiv(tmp.n_rows);
|
Chris@49
|
876
|
Chris@49
|
877 blas_int info = 0;
|
Chris@49
|
878 blas_int n_rows = blas_int(tmp.n_rows);
|
Chris@49
|
879 blas_int n_cols = blas_int(tmp.n_cols);
|
Chris@49
|
880
|
Chris@49
|
881 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info);
|
Chris@49
|
882
|
Chris@49
|
883 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
|
Chris@49
|
884
|
Chris@49
|
885 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1;
|
Chris@49
|
886 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) );
|
Chris@49
|
887
|
Chris@49
|
888 for(uword i=1; i < tmp.n_rows; ++i)
|
Chris@49
|
889 {
|
Chris@49
|
890 const eT x = tmp.at(i,i);
|
Chris@49
|
891
|
Chris@49
|
892 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1;
|
Chris@49
|
893 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
|
Chris@49
|
894 }
|
Chris@49
|
895
|
Chris@49
|
896 for(uword i=0; i < tmp.n_rows; ++i)
|
Chris@49
|
897 {
|
Chris@49
|
898 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1
|
Chris@49
|
899 {
|
Chris@49
|
900 sign *= -1;
|
Chris@49
|
901 }
|
Chris@49
|
902 }
|
Chris@49
|
903
|
Chris@49
|
904 out_val = val;
|
Chris@49
|
905 out_sign = T(sign);
|
Chris@49
|
906
|
Chris@49
|
907 return (info == 0);
|
Chris@49
|
908 }
|
Chris@49
|
909 #else
|
Chris@49
|
910 {
|
Chris@49
|
911 arma_ignore(X);
|
Chris@49
|
912
|
Chris@49
|
913 out_val = eT(0);
|
Chris@49
|
914 out_sign = T(0);
|
Chris@49
|
915
|
Chris@49
|
916 arma_stop("log_det(): use of ATLAS or LAPACK needs to be enabled");
|
Chris@49
|
917
|
Chris@49
|
918 return false;
|
Chris@49
|
919 }
|
Chris@49
|
920 #endif
|
Chris@49
|
921 }
|
Chris@49
|
922
|
Chris@49
|
923
|
Chris@49
|
924
|
Chris@49
|
925 //! immediate LU decomposition of a matrix using ATLAS or LAPACK
|
Chris@49
|
926 template<typename eT, typename T1>
|
Chris@49
|
927 inline
|
Chris@49
|
928 bool
|
Chris@49
|
929 auxlib::lu(Mat<eT>& L, Mat<eT>& U, podarray<blas_int>& ipiv, const Base<eT,T1>& X)
|
Chris@49
|
930 {
|
Chris@49
|
931 arma_extra_debug_sigprint();
|
Chris@49
|
932
|
Chris@49
|
933 U = X.get_ref();
|
Chris@49
|
934
|
Chris@49
|
935 const uword U_n_rows = U.n_rows;
|
Chris@49
|
936 const uword U_n_cols = U.n_cols;
|
Chris@49
|
937
|
Chris@49
|
938 if(U.is_empty())
|
Chris@49
|
939 {
|
Chris@49
|
940 L.set_size(U_n_rows, 0);
|
Chris@49
|
941 U.set_size(0, U_n_cols);
|
Chris@49
|
942 ipiv.reset();
|
Chris@49
|
943 return true;
|
Chris@49
|
944 }
|
Chris@49
|
945
|
Chris@49
|
946 #if defined(ARMA_USE_ATLAS) || defined(ARMA_USE_LAPACK)
|
Chris@49
|
947 {
|
Chris@49
|
948 bool status;
|
Chris@49
|
949
|
Chris@49
|
950 #if defined(ARMA_USE_ATLAS)
|
Chris@49
|
951 {
|
Chris@49
|
952 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) );
|
Chris@49
|
953
|
Chris@49
|
954 int info = atlas::clapack_getrf(atlas::CblasColMajor, U_n_rows, U_n_cols, U.memptr(), U_n_rows, ipiv.memptr());
|
Chris@49
|
955
|
Chris@49
|
956 status = (info == 0);
|
Chris@49
|
957 }
|
Chris@49
|
958 #elif defined(ARMA_USE_LAPACK)
|
Chris@49
|
959 {
|
Chris@49
|
960 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) );
|
Chris@49
|
961
|
Chris@49
|
962 blas_int info = 0;
|
Chris@49
|
963
|
Chris@49
|
964 blas_int n_rows = U_n_rows;
|
Chris@49
|
965 blas_int n_cols = U_n_cols;
|
Chris@49
|
966
|
Chris@49
|
967
|
Chris@49
|
968 lapack::getrf(&n_rows, &n_cols, U.memptr(), &n_rows, ipiv.memptr(), &info);
|
Chris@49
|
969
|
Chris@49
|
970 // take into account that Fortran counts from 1
|
Chris@49
|
971 arrayops::inplace_minus(ipiv.memptr(), blas_int(1), ipiv.n_elem);
|
Chris@49
|
972
|
Chris@49
|
973 status = (info == 0);
|
Chris@49
|
974 }
|
Chris@49
|
975 #endif
|
Chris@49
|
976
|
Chris@49
|
977 L.copy_size(U);
|
Chris@49
|
978
|
Chris@49
|
979 for(uword col=0; col < U_n_cols; ++col)
|
Chris@49
|
980 {
|
Chris@49
|
981 for(uword row=0; (row < col) && (row < U_n_rows); ++row)
|
Chris@49
|
982 {
|
Chris@49
|
983 L.at(row,col) = eT(0);
|
Chris@49
|
984 }
|
Chris@49
|
985
|
Chris@49
|
986 if( L.in_range(col,col) == true )
|
Chris@49
|
987 {
|
Chris@49
|
988 L.at(col,col) = eT(1);
|
Chris@49
|
989 }
|
Chris@49
|
990
|
Chris@49
|
991 for(uword row = (col+1); row < U_n_rows; ++row)
|
Chris@49
|
992 {
|
Chris@49
|
993 L.at(row,col) = U.at(row,col);
|
Chris@49
|
994 U.at(row,col) = eT(0);
|
Chris@49
|
995 }
|
Chris@49
|
996 }
|
Chris@49
|
997
|
Chris@49
|
998 return status;
|
Chris@49
|
999 }
|
Chris@49
|
1000 #else
|
Chris@49
|
1001 {
|
Chris@49
|
1002 arma_stop("lu(): use of ATLAS or LAPACK needs to be enabled");
|
Chris@49
|
1003
|
Chris@49
|
1004 return false;
|
Chris@49
|
1005 }
|
Chris@49
|
1006 #endif
|
Chris@49
|
1007 }
|
Chris@49
|
1008
|
Chris@49
|
1009
|
Chris@49
|
1010
|
Chris@49
|
1011 template<typename eT, typename T1>
|
Chris@49
|
1012 inline
|
Chris@49
|
1013 bool
|
Chris@49
|
1014 auxlib::lu(Mat<eT>& L, Mat<eT>& U, Mat<eT>& P, const Base<eT,T1>& X)
|
Chris@49
|
1015 {
|
Chris@49
|
1016 arma_extra_debug_sigprint();
|
Chris@49
|
1017
|
Chris@49
|
1018 podarray<blas_int> ipiv1;
|
Chris@49
|
1019 const bool status = auxlib::lu(L, U, ipiv1, X);
|
Chris@49
|
1020
|
Chris@49
|
1021 if(status == true)
|
Chris@49
|
1022 {
|
Chris@49
|
1023 if(U.is_empty())
|
Chris@49
|
1024 {
|
Chris@49
|
1025 // L and U have been already set to the correct empty matrices
|
Chris@49
|
1026 P.eye(L.n_rows, L.n_rows);
|
Chris@49
|
1027 return true;
|
Chris@49
|
1028 }
|
Chris@49
|
1029
|
Chris@49
|
1030 const uword n = ipiv1.n_elem;
|
Chris@49
|
1031 const uword P_rows = U.n_rows;
|
Chris@49
|
1032
|
Chris@49
|
1033 podarray<blas_int> ipiv2(P_rows);
|
Chris@49
|
1034
|
Chris@49
|
1035 const blas_int* ipiv1_mem = ipiv1.memptr();
|
Chris@49
|
1036 blas_int* ipiv2_mem = ipiv2.memptr();
|
Chris@49
|
1037
|
Chris@49
|
1038 for(uword i=0; i<P_rows; ++i)
|
Chris@49
|
1039 {
|
Chris@49
|
1040 ipiv2_mem[i] = blas_int(i);
|
Chris@49
|
1041 }
|
Chris@49
|
1042
|
Chris@49
|
1043 for(uword i=0; i<n; ++i)
|
Chris@49
|
1044 {
|
Chris@49
|
1045 const uword k = static_cast<uword>(ipiv1_mem[i]);
|
Chris@49
|
1046
|
Chris@49
|
1047 if( ipiv2_mem[i] != ipiv2_mem[k] )
|
Chris@49
|
1048 {
|
Chris@49
|
1049 std::swap( ipiv2_mem[i], ipiv2_mem[k] );
|
Chris@49
|
1050 }
|
Chris@49
|
1051 }
|
Chris@49
|
1052
|
Chris@49
|
1053 P.zeros(P_rows, P_rows);
|
Chris@49
|
1054
|
Chris@49
|
1055 for(uword row=0; row<P_rows; ++row)
|
Chris@49
|
1056 {
|
Chris@49
|
1057 P.at(row, static_cast<uword>(ipiv2_mem[row])) = eT(1);
|
Chris@49
|
1058 }
|
Chris@49
|
1059
|
Chris@49
|
1060 if(L.n_cols > U.n_rows)
|
Chris@49
|
1061 {
|
Chris@49
|
1062 L.shed_cols(U.n_rows, L.n_cols-1);
|
Chris@49
|
1063 }
|
Chris@49
|
1064
|
Chris@49
|
1065 if(U.n_rows > L.n_cols)
|
Chris@49
|
1066 {
|
Chris@49
|
1067 U.shed_rows(L.n_cols, U.n_rows-1);
|
Chris@49
|
1068 }
|
Chris@49
|
1069 }
|
Chris@49
|
1070
|
Chris@49
|
1071 return status;
|
Chris@49
|
1072 }
|
Chris@49
|
1073
|
Chris@49
|
1074
|
Chris@49
|
1075
|
Chris@49
|
1076 template<typename eT, typename T1>
|
Chris@49
|
1077 inline
|
Chris@49
|
1078 bool
|
Chris@49
|
1079 auxlib::lu(Mat<eT>& L, Mat<eT>& U, const Base<eT,T1>& X)
|
Chris@49
|
1080 {
|
Chris@49
|
1081 arma_extra_debug_sigprint();
|
Chris@49
|
1082
|
Chris@49
|
1083 podarray<blas_int> ipiv1;
|
Chris@49
|
1084 const bool status = auxlib::lu(L, U, ipiv1, X);
|
Chris@49
|
1085
|
Chris@49
|
1086 if(status == true)
|
Chris@49
|
1087 {
|
Chris@49
|
1088 if(U.is_empty())
|
Chris@49
|
1089 {
|
Chris@49
|
1090 // L and U have been already set to the correct empty matrices
|
Chris@49
|
1091 return true;
|
Chris@49
|
1092 }
|
Chris@49
|
1093
|
Chris@49
|
1094 const uword n = ipiv1.n_elem;
|
Chris@49
|
1095 const uword P_rows = U.n_rows;
|
Chris@49
|
1096
|
Chris@49
|
1097 podarray<blas_int> ipiv2(P_rows);
|
Chris@49
|
1098
|
Chris@49
|
1099 const blas_int* ipiv1_mem = ipiv1.memptr();
|
Chris@49
|
1100 blas_int* ipiv2_mem = ipiv2.memptr();
|
Chris@49
|
1101
|
Chris@49
|
1102 for(uword i=0; i<P_rows; ++i)
|
Chris@49
|
1103 {
|
Chris@49
|
1104 ipiv2_mem[i] = blas_int(i);
|
Chris@49
|
1105 }
|
Chris@49
|
1106
|
Chris@49
|
1107 for(uword i=0; i<n; ++i)
|
Chris@49
|
1108 {
|
Chris@49
|
1109 const uword k = static_cast<uword>(ipiv1_mem[i]);
|
Chris@49
|
1110
|
Chris@49
|
1111 if( ipiv2_mem[i] != ipiv2_mem[k] )
|
Chris@49
|
1112 {
|
Chris@49
|
1113 std::swap( ipiv2_mem[i], ipiv2_mem[k] );
|
Chris@49
|
1114 L.swap_rows( static_cast<uword>(ipiv2_mem[i]), static_cast<uword>(ipiv2_mem[k]) );
|
Chris@49
|
1115 }
|
Chris@49
|
1116 }
|
Chris@49
|
1117
|
Chris@49
|
1118 if(L.n_cols > U.n_rows)
|
Chris@49
|
1119 {
|
Chris@49
|
1120 L.shed_cols(U.n_rows, L.n_cols-1);
|
Chris@49
|
1121 }
|
Chris@49
|
1122
|
Chris@49
|
1123 if(U.n_rows > L.n_cols)
|
Chris@49
|
1124 {
|
Chris@49
|
1125 U.shed_rows(L.n_cols, U.n_rows-1);
|
Chris@49
|
1126 }
|
Chris@49
|
1127 }
|
Chris@49
|
1128
|
Chris@49
|
1129 return status;
|
Chris@49
|
1130 }
|
Chris@49
|
1131
|
Chris@49
|
1132
|
Chris@49
|
1133
|
Chris@49
|
1134 //! immediate eigenvalues of a symmetric real matrix using LAPACK
|
Chris@49
|
1135 template<typename eT, typename T1>
|
Chris@49
|
1136 inline
|
Chris@49
|
1137 bool
|
Chris@49
|
1138 auxlib::eig_sym(Col<eT>& eigval, const Base<eT,T1>& X)
|
Chris@49
|
1139 {
|
Chris@49
|
1140 arma_extra_debug_sigprint();
|
Chris@49
|
1141
|
Chris@49
|
1142 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1143 {
|
Chris@49
|
1144 Mat<eT> A(X.get_ref());
|
Chris@49
|
1145
|
Chris@49
|
1146 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not square");
|
Chris@49
|
1147
|
Chris@49
|
1148 if(A.is_empty())
|
Chris@49
|
1149 {
|
Chris@49
|
1150 eigval.reset();
|
Chris@49
|
1151 return true;
|
Chris@49
|
1152 }
|
Chris@49
|
1153
|
Chris@49
|
1154 eigval.set_size(A.n_rows);
|
Chris@49
|
1155
|
Chris@49
|
1156 char jobz = 'N';
|
Chris@49
|
1157 char uplo = 'U';
|
Chris@49
|
1158
|
Chris@49
|
1159 blas_int N = blas_int(A.n_rows);
|
Chris@49
|
1160 blas_int lwork = 3 * ( (std::max)(blas_int(1), 3*N-1) );
|
Chris@49
|
1161 blas_int info = 0;
|
Chris@49
|
1162
|
Chris@49
|
1163 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
1164
|
Chris@49
|
1165 lapack::syev(&jobz, &uplo, &N, A.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, &info);
|
Chris@49
|
1166
|
Chris@49
|
1167 return (info == 0);
|
Chris@49
|
1168 }
|
Chris@49
|
1169 #else
|
Chris@49
|
1170 {
|
Chris@49
|
1171 arma_ignore(eigval);
|
Chris@49
|
1172 arma_ignore(X);
|
Chris@49
|
1173 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
|
Chris@49
|
1174 return false;
|
Chris@49
|
1175 }
|
Chris@49
|
1176 #endif
|
Chris@49
|
1177 }
|
Chris@49
|
1178
|
Chris@49
|
1179
|
Chris@49
|
1180
|
Chris@49
|
1181 //! immediate eigenvalues of a hermitian complex matrix using LAPACK
|
Chris@49
|
1182 template<typename T, typename T1>
|
Chris@49
|
1183 inline
|
Chris@49
|
1184 bool
|
Chris@49
|
1185 auxlib::eig_sym(Col<T>& eigval, const Base<std::complex<T>,T1>& X)
|
Chris@49
|
1186 {
|
Chris@49
|
1187 arma_extra_debug_sigprint();
|
Chris@49
|
1188
|
Chris@49
|
1189 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1190 {
|
Chris@49
|
1191 typedef typename std::complex<T> eT;
|
Chris@49
|
1192
|
Chris@49
|
1193 Mat<eT> A(X.get_ref());
|
Chris@49
|
1194
|
Chris@49
|
1195 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not square");
|
Chris@49
|
1196
|
Chris@49
|
1197 if(A.is_empty())
|
Chris@49
|
1198 {
|
Chris@49
|
1199 eigval.reset();
|
Chris@49
|
1200 return true;
|
Chris@49
|
1201 }
|
Chris@49
|
1202
|
Chris@49
|
1203 eigval.set_size(A.n_rows);
|
Chris@49
|
1204
|
Chris@49
|
1205 char jobz = 'N';
|
Chris@49
|
1206 char uplo = 'U';
|
Chris@49
|
1207
|
Chris@49
|
1208 blas_int N = blas_int(A.n_rows);
|
Chris@49
|
1209 blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*N-1) );
|
Chris@49
|
1210 blas_int info = 0;
|
Chris@49
|
1211
|
Chris@49
|
1212 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
1213 podarray<T> rwork( static_cast<uword>( (std::max)(blas_int(1), 3*N-2) ) );
|
Chris@49
|
1214
|
Chris@49
|
1215 arma_extra_debug_print("lapack::heev()");
|
Chris@49
|
1216 lapack::heev(&jobz, &uplo, &N, A.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info);
|
Chris@49
|
1217
|
Chris@49
|
1218 return (info == 0);
|
Chris@49
|
1219 }
|
Chris@49
|
1220 #else
|
Chris@49
|
1221 {
|
Chris@49
|
1222 arma_ignore(eigval);
|
Chris@49
|
1223 arma_ignore(X);
|
Chris@49
|
1224 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
|
Chris@49
|
1225 return false;
|
Chris@49
|
1226 }
|
Chris@49
|
1227 #endif
|
Chris@49
|
1228 }
|
Chris@49
|
1229
|
Chris@49
|
1230
|
Chris@49
|
1231
|
Chris@49
|
1232 //! immediate eigenvalues and eigenvectors of a symmetric real matrix using LAPACK
|
Chris@49
|
1233 template<typename eT, typename T1>
|
Chris@49
|
1234 inline
|
Chris@49
|
1235 bool
|
Chris@49
|
1236 auxlib::eig_sym(Col<eT>& eigval, Mat<eT>& eigvec, const Base<eT,T1>& X)
|
Chris@49
|
1237 {
|
Chris@49
|
1238 arma_extra_debug_sigprint();
|
Chris@49
|
1239
|
Chris@49
|
1240 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1241 {
|
Chris@49
|
1242 eigvec = X.get_ref();
|
Chris@49
|
1243
|
Chris@49
|
1244 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" );
|
Chris@49
|
1245
|
Chris@49
|
1246 if(eigvec.is_empty())
|
Chris@49
|
1247 {
|
Chris@49
|
1248 eigval.reset();
|
Chris@49
|
1249 eigvec.reset();
|
Chris@49
|
1250 return true;
|
Chris@49
|
1251 }
|
Chris@49
|
1252
|
Chris@49
|
1253 eigval.set_size(eigvec.n_rows);
|
Chris@49
|
1254
|
Chris@49
|
1255 char jobz = 'V';
|
Chris@49
|
1256 char uplo = 'U';
|
Chris@49
|
1257
|
Chris@49
|
1258 blas_int N = blas_int(eigvec.n_rows);
|
Chris@49
|
1259 blas_int lwork = 3 * ( (std::max)(blas_int(1), 3*N-1) );
|
Chris@49
|
1260 blas_int info = 0;
|
Chris@49
|
1261
|
Chris@49
|
1262 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
1263
|
Chris@49
|
1264 lapack::syev(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, &info);
|
Chris@49
|
1265
|
Chris@49
|
1266 return (info == 0);
|
Chris@49
|
1267 }
|
Chris@49
|
1268 #else
|
Chris@49
|
1269 {
|
Chris@49
|
1270 arma_ignore(eigval);
|
Chris@49
|
1271 arma_ignore(eigvec);
|
Chris@49
|
1272 arma_ignore(X);
|
Chris@49
|
1273 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
|
Chris@49
|
1274 return false;
|
Chris@49
|
1275 }
|
Chris@49
|
1276 #endif
|
Chris@49
|
1277 }
|
Chris@49
|
1278
|
Chris@49
|
1279
|
Chris@49
|
1280
|
Chris@49
|
1281 //! immediate eigenvalues and eigenvectors of a hermitian complex matrix using LAPACK
|
Chris@49
|
1282 template<typename T, typename T1>
|
Chris@49
|
1283 inline
|
Chris@49
|
1284 bool
|
Chris@49
|
1285 auxlib::eig_sym(Col<T>& eigval, Mat< std::complex<T> >& eigvec, const Base<std::complex<T>,T1>& X)
|
Chris@49
|
1286 {
|
Chris@49
|
1287 arma_extra_debug_sigprint();
|
Chris@49
|
1288
|
Chris@49
|
1289 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1290 {
|
Chris@49
|
1291 typedef typename std::complex<T> eT;
|
Chris@49
|
1292
|
Chris@49
|
1293 eigvec = X.get_ref();
|
Chris@49
|
1294
|
Chris@49
|
1295 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" );
|
Chris@49
|
1296
|
Chris@49
|
1297 if(eigvec.is_empty())
|
Chris@49
|
1298 {
|
Chris@49
|
1299 eigval.reset();
|
Chris@49
|
1300 eigvec.reset();
|
Chris@49
|
1301 return true;
|
Chris@49
|
1302 }
|
Chris@49
|
1303
|
Chris@49
|
1304 eigval.set_size(eigvec.n_rows);
|
Chris@49
|
1305
|
Chris@49
|
1306 char jobz = 'V';
|
Chris@49
|
1307 char uplo = 'U';
|
Chris@49
|
1308
|
Chris@49
|
1309 blas_int N = blas_int(eigvec.n_rows);
|
Chris@49
|
1310 blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*N-1) );
|
Chris@49
|
1311 blas_int info = 0;
|
Chris@49
|
1312
|
Chris@49
|
1313 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
1314 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*N-2)) );
|
Chris@49
|
1315
|
Chris@49
|
1316 arma_extra_debug_print("lapack::heev()");
|
Chris@49
|
1317 lapack::heev(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info);
|
Chris@49
|
1318
|
Chris@49
|
1319 return (info == 0);
|
Chris@49
|
1320 }
|
Chris@49
|
1321 #else
|
Chris@49
|
1322 {
|
Chris@49
|
1323 arma_ignore(eigval);
|
Chris@49
|
1324 arma_ignore(eigvec);
|
Chris@49
|
1325 arma_ignore(X);
|
Chris@49
|
1326 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
|
Chris@49
|
1327 return false;
|
Chris@49
|
1328 }
|
Chris@49
|
1329 #endif
|
Chris@49
|
1330 }
|
Chris@49
|
1331
|
Chris@49
|
1332
|
Chris@49
|
1333
|
Chris@49
|
1334 //! immediate eigenvalues and eigenvectors of a symmetric real matrix using LAPACK (divide and conquer algorithm)
|
Chris@49
|
1335 template<typename eT, typename T1>
|
Chris@49
|
1336 inline
|
Chris@49
|
1337 bool
|
Chris@49
|
1338 auxlib::eig_sym_dc(Col<eT>& eigval, Mat<eT>& eigvec, const Base<eT,T1>& X)
|
Chris@49
|
1339 {
|
Chris@49
|
1340 arma_extra_debug_sigprint();
|
Chris@49
|
1341
|
Chris@49
|
1342 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1343 {
|
Chris@49
|
1344 eigvec = X.get_ref();
|
Chris@49
|
1345
|
Chris@49
|
1346 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" );
|
Chris@49
|
1347
|
Chris@49
|
1348 if(eigvec.is_empty())
|
Chris@49
|
1349 {
|
Chris@49
|
1350 eigval.reset();
|
Chris@49
|
1351 eigvec.reset();
|
Chris@49
|
1352 return true;
|
Chris@49
|
1353 }
|
Chris@49
|
1354
|
Chris@49
|
1355 eigval.set_size(eigvec.n_rows);
|
Chris@49
|
1356
|
Chris@49
|
1357 char jobz = 'V';
|
Chris@49
|
1358 char uplo = 'U';
|
Chris@49
|
1359
|
Chris@49
|
1360 blas_int N = blas_int(eigvec.n_rows);
|
Chris@49
|
1361 blas_int lwork = 3 * (1 + 6*N + 2*(N*N));
|
Chris@49
|
1362 blas_int liwork = 3 * (3 + 5*N + 2);
|
Chris@49
|
1363 blas_int info = 0;
|
Chris@49
|
1364
|
Chris@49
|
1365 podarray<eT> work( static_cast<uword>( lwork) );
|
Chris@49
|
1366 podarray<blas_int> iwork( static_cast<uword>(liwork) );
|
Chris@49
|
1367
|
Chris@49
|
1368 arma_extra_debug_print("lapack::syevd()");
|
Chris@49
|
1369 lapack::syevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, iwork.memptr(), &liwork, &info);
|
Chris@49
|
1370
|
Chris@49
|
1371 return (info == 0);
|
Chris@49
|
1372 }
|
Chris@49
|
1373 #else
|
Chris@49
|
1374 {
|
Chris@49
|
1375 arma_ignore(eigval);
|
Chris@49
|
1376 arma_ignore(eigvec);
|
Chris@49
|
1377 arma_ignore(X);
|
Chris@49
|
1378 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
|
Chris@49
|
1379 return false;
|
Chris@49
|
1380 }
|
Chris@49
|
1381 #endif
|
Chris@49
|
1382 }
|
Chris@49
|
1383
|
Chris@49
|
1384
|
Chris@49
|
1385
|
Chris@49
|
1386 //! immediate eigenvalues and eigenvectors of a hermitian complex matrix using LAPACK (divide and conquer algorithm)
|
Chris@49
|
1387 template<typename T, typename T1>
|
Chris@49
|
1388 inline
|
Chris@49
|
1389 bool
|
Chris@49
|
1390 auxlib::eig_sym_dc(Col<T>& eigval, Mat< std::complex<T> >& eigvec, const Base<std::complex<T>,T1>& X)
|
Chris@49
|
1391 {
|
Chris@49
|
1392 arma_extra_debug_sigprint();
|
Chris@49
|
1393
|
Chris@49
|
1394 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1395 {
|
Chris@49
|
1396 typedef typename std::complex<T> eT;
|
Chris@49
|
1397
|
Chris@49
|
1398 eigvec = X.get_ref();
|
Chris@49
|
1399
|
Chris@49
|
1400 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" );
|
Chris@49
|
1401
|
Chris@49
|
1402 if(eigvec.is_empty())
|
Chris@49
|
1403 {
|
Chris@49
|
1404 eigval.reset();
|
Chris@49
|
1405 eigvec.reset();
|
Chris@49
|
1406 return true;
|
Chris@49
|
1407 }
|
Chris@49
|
1408
|
Chris@49
|
1409 eigval.set_size(eigvec.n_rows);
|
Chris@49
|
1410
|
Chris@49
|
1411 char jobz = 'V';
|
Chris@49
|
1412 char uplo = 'U';
|
Chris@49
|
1413
|
Chris@49
|
1414 blas_int N = blas_int(eigvec.n_rows);
|
Chris@49
|
1415 blas_int lwork = 3 * (2*N + N*N);
|
Chris@49
|
1416 blas_int lrwork = 3 * (1 + 5*N + 2*(N*N));
|
Chris@49
|
1417 blas_int liwork = 3 * (3 + 5*N);
|
Chris@49
|
1418 blas_int info = 0;
|
Chris@49
|
1419
|
Chris@49
|
1420 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
1421 podarray<T> rwork( static_cast<uword>(lrwork) );
|
Chris@49
|
1422 podarray<blas_int> iwork( static_cast<uword>(liwork) );
|
Chris@49
|
1423
|
Chris@49
|
1424 arma_extra_debug_print("lapack::heevd()");
|
Chris@49
|
1425 lapack::heevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &lrwork, iwork.memptr(), &liwork, &info);
|
Chris@49
|
1426
|
Chris@49
|
1427 return (info == 0);
|
Chris@49
|
1428 }
|
Chris@49
|
1429 #else
|
Chris@49
|
1430 {
|
Chris@49
|
1431 arma_ignore(eigval);
|
Chris@49
|
1432 arma_ignore(eigvec);
|
Chris@49
|
1433 arma_ignore(X);
|
Chris@49
|
1434 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
|
Chris@49
|
1435 return false;
|
Chris@49
|
1436 }
|
Chris@49
|
1437 #endif
|
Chris@49
|
1438 }
|
Chris@49
|
1439
|
Chris@49
|
1440
|
Chris@49
|
1441
|
Chris@49
|
1442 //! Eigenvalues and eigenvectors of a general square real matrix using LAPACK.
|
Chris@49
|
1443 //! The argument 'side' specifies which eigenvectors should be calculated
|
Chris@49
|
1444 //! (see code for mode details).
|
Chris@49
|
1445 template<typename T, typename T1>
|
Chris@49
|
1446 inline
|
Chris@49
|
1447 bool
|
Chris@49
|
1448 auxlib::eig_gen
|
Chris@49
|
1449 (
|
Chris@49
|
1450 Col< std::complex<T> >& eigval,
|
Chris@49
|
1451 Mat<T>& l_eigvec,
|
Chris@49
|
1452 Mat<T>& r_eigvec,
|
Chris@49
|
1453 const Base<T,T1>& X,
|
Chris@49
|
1454 const char side
|
Chris@49
|
1455 )
|
Chris@49
|
1456 {
|
Chris@49
|
1457 arma_extra_debug_sigprint();
|
Chris@49
|
1458
|
Chris@49
|
1459 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1460 {
|
Chris@49
|
1461 char jobvl;
|
Chris@49
|
1462 char jobvr;
|
Chris@49
|
1463
|
Chris@49
|
1464 switch(side)
|
Chris@49
|
1465 {
|
Chris@49
|
1466 case 'l': // left
|
Chris@49
|
1467 jobvl = 'V';
|
Chris@49
|
1468 jobvr = 'N';
|
Chris@49
|
1469 break;
|
Chris@49
|
1470
|
Chris@49
|
1471 case 'r': // right
|
Chris@49
|
1472 jobvl = 'N';
|
Chris@49
|
1473 jobvr = 'V';
|
Chris@49
|
1474 break;
|
Chris@49
|
1475
|
Chris@49
|
1476 case 'b': // both
|
Chris@49
|
1477 jobvl = 'V';
|
Chris@49
|
1478 jobvr = 'V';
|
Chris@49
|
1479 break;
|
Chris@49
|
1480
|
Chris@49
|
1481 case 'n': // neither
|
Chris@49
|
1482 jobvl = 'N';
|
Chris@49
|
1483 jobvr = 'N';
|
Chris@49
|
1484 break;
|
Chris@49
|
1485
|
Chris@49
|
1486 default:
|
Chris@49
|
1487 arma_stop("eig_gen(): parameter 'side' is invalid");
|
Chris@49
|
1488 return false;
|
Chris@49
|
1489 }
|
Chris@49
|
1490
|
Chris@49
|
1491 Mat<T> A(X.get_ref());
|
Chris@49
|
1492 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" );
|
Chris@49
|
1493
|
Chris@49
|
1494 if(A.is_empty())
|
Chris@49
|
1495 {
|
Chris@49
|
1496 eigval.reset();
|
Chris@49
|
1497 l_eigvec.reset();
|
Chris@49
|
1498 r_eigvec.reset();
|
Chris@49
|
1499 return true;
|
Chris@49
|
1500 }
|
Chris@49
|
1501
|
Chris@49
|
1502 const uword A_n_rows = A.n_rows;
|
Chris@49
|
1503
|
Chris@49
|
1504 eigval.set_size(A_n_rows);
|
Chris@49
|
1505
|
Chris@49
|
1506 l_eigvec.set_size(A_n_rows, A_n_rows);
|
Chris@49
|
1507 r_eigvec.set_size(A_n_rows, A_n_rows);
|
Chris@49
|
1508
|
Chris@49
|
1509 blas_int N = blas_int(A_n_rows);
|
Chris@49
|
1510 blas_int lwork = 3 * ( (std::max)(blas_int(1), 4*N) );
|
Chris@49
|
1511 blas_int info = 0;
|
Chris@49
|
1512
|
Chris@49
|
1513 podarray<T> work( static_cast<uword>(lwork) );
|
Chris@49
|
1514
|
Chris@49
|
1515 podarray<T> wr(A_n_rows);
|
Chris@49
|
1516 podarray<T> wi(A_n_rows);
|
Chris@49
|
1517
|
Chris@49
|
1518 arma_extra_debug_print("lapack::geev()");
|
Chris@49
|
1519 lapack::geev(&jobvl, &jobvr, &N, A.memptr(), &N, wr.memptr(), wi.memptr(), l_eigvec.memptr(), &N, r_eigvec.memptr(), &N, work.memptr(), &lwork, &info);
|
Chris@49
|
1520
|
Chris@49
|
1521 eigval.set_size(A_n_rows);
|
Chris@49
|
1522 for(uword i=0; i<A_n_rows; ++i)
|
Chris@49
|
1523 {
|
Chris@49
|
1524 eigval[i] = std::complex<T>(wr[i], wi[i]);
|
Chris@49
|
1525 }
|
Chris@49
|
1526
|
Chris@49
|
1527 return (info == 0);
|
Chris@49
|
1528 }
|
Chris@49
|
1529 #else
|
Chris@49
|
1530 {
|
Chris@49
|
1531 arma_ignore(eigval);
|
Chris@49
|
1532 arma_ignore(l_eigvec);
|
Chris@49
|
1533 arma_ignore(r_eigvec);
|
Chris@49
|
1534 arma_ignore(X);
|
Chris@49
|
1535 arma_ignore(side);
|
Chris@49
|
1536 arma_stop("eig_gen(): use of LAPACK needs to be enabled");
|
Chris@49
|
1537 return false;
|
Chris@49
|
1538 }
|
Chris@49
|
1539 #endif
|
Chris@49
|
1540 }
|
Chris@49
|
1541
|
Chris@49
|
1542
|
Chris@49
|
1543
|
Chris@49
|
1544
|
Chris@49
|
1545
|
Chris@49
|
1546 //! Eigenvalues and eigenvectors of a general square complex matrix using LAPACK
|
Chris@49
|
1547 //! The argument 'side' specifies which eigenvectors should be calculated
|
Chris@49
|
1548 //! (see code for mode details).
|
Chris@49
|
1549 template<typename T, typename T1>
|
Chris@49
|
1550 inline
|
Chris@49
|
1551 bool
|
Chris@49
|
1552 auxlib::eig_gen
|
Chris@49
|
1553 (
|
Chris@49
|
1554 Col< std::complex<T> >& eigval,
|
Chris@49
|
1555 Mat< std::complex<T> >& l_eigvec,
|
Chris@49
|
1556 Mat< std::complex<T> >& r_eigvec,
|
Chris@49
|
1557 const Base< std::complex<T>, T1 >& X,
|
Chris@49
|
1558 const char side
|
Chris@49
|
1559 )
|
Chris@49
|
1560 {
|
Chris@49
|
1561 arma_extra_debug_sigprint();
|
Chris@49
|
1562
|
Chris@49
|
1563 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1564 {
|
Chris@49
|
1565 typedef typename std::complex<T> eT;
|
Chris@49
|
1566
|
Chris@49
|
1567 char jobvl;
|
Chris@49
|
1568 char jobvr;
|
Chris@49
|
1569
|
Chris@49
|
1570 switch(side)
|
Chris@49
|
1571 {
|
Chris@49
|
1572 case 'l': // left
|
Chris@49
|
1573 jobvl = 'V';
|
Chris@49
|
1574 jobvr = 'N';
|
Chris@49
|
1575 break;
|
Chris@49
|
1576
|
Chris@49
|
1577 case 'r': // right
|
Chris@49
|
1578 jobvl = 'N';
|
Chris@49
|
1579 jobvr = 'V';
|
Chris@49
|
1580 break;
|
Chris@49
|
1581
|
Chris@49
|
1582 case 'b': // both
|
Chris@49
|
1583 jobvl = 'V';
|
Chris@49
|
1584 jobvr = 'V';
|
Chris@49
|
1585 break;
|
Chris@49
|
1586
|
Chris@49
|
1587 case 'n': // neither
|
Chris@49
|
1588 jobvl = 'N';
|
Chris@49
|
1589 jobvr = 'N';
|
Chris@49
|
1590 break;
|
Chris@49
|
1591
|
Chris@49
|
1592 default:
|
Chris@49
|
1593 arma_stop("eig_gen(): parameter 'side' is invalid");
|
Chris@49
|
1594 return false;
|
Chris@49
|
1595 }
|
Chris@49
|
1596
|
Chris@49
|
1597 Mat<eT> A(X.get_ref());
|
Chris@49
|
1598 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" );
|
Chris@49
|
1599
|
Chris@49
|
1600 if(A.is_empty())
|
Chris@49
|
1601 {
|
Chris@49
|
1602 eigval.reset();
|
Chris@49
|
1603 l_eigvec.reset();
|
Chris@49
|
1604 r_eigvec.reset();
|
Chris@49
|
1605 return true;
|
Chris@49
|
1606 }
|
Chris@49
|
1607
|
Chris@49
|
1608 const uword A_n_rows = A.n_rows;
|
Chris@49
|
1609
|
Chris@49
|
1610 eigval.set_size(A_n_rows);
|
Chris@49
|
1611
|
Chris@49
|
1612 l_eigvec.set_size(A_n_rows, A_n_rows);
|
Chris@49
|
1613 r_eigvec.set_size(A_n_rows, A_n_rows);
|
Chris@49
|
1614
|
Chris@49
|
1615 blas_int N = blas_int(A_n_rows);
|
Chris@49
|
1616 blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*N) );
|
Chris@49
|
1617 blas_int info = 0;
|
Chris@49
|
1618
|
Chris@49
|
1619 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
1620 podarray<T> rwork( static_cast<uword>(2*N) );
|
Chris@49
|
1621
|
Chris@49
|
1622 arma_extra_debug_print("lapack::cx_geev()");
|
Chris@49
|
1623 lapack::cx_geev(&jobvl, &jobvr, &N, A.memptr(), &N, eigval.memptr(), l_eigvec.memptr(), &N, r_eigvec.memptr(), &N, work.memptr(), &lwork, rwork.memptr(), &info);
|
Chris@49
|
1624
|
Chris@49
|
1625 return (info == 0);
|
Chris@49
|
1626 }
|
Chris@49
|
1627 #else
|
Chris@49
|
1628 {
|
Chris@49
|
1629 arma_ignore(eigval);
|
Chris@49
|
1630 arma_ignore(l_eigvec);
|
Chris@49
|
1631 arma_ignore(r_eigvec);
|
Chris@49
|
1632 arma_ignore(X);
|
Chris@49
|
1633 arma_ignore(side);
|
Chris@49
|
1634 arma_stop("eig_gen(): use of LAPACK needs to be enabled");
|
Chris@49
|
1635 return false;
|
Chris@49
|
1636 }
|
Chris@49
|
1637 #endif
|
Chris@49
|
1638 }
|
Chris@49
|
1639
|
Chris@49
|
1640
|
Chris@49
|
1641
|
Chris@49
|
1642 template<typename eT, typename T1>
|
Chris@49
|
1643 inline
|
Chris@49
|
1644 bool
|
Chris@49
|
1645 auxlib::chol(Mat<eT>& out, const Base<eT,T1>& X)
|
Chris@49
|
1646 {
|
Chris@49
|
1647 arma_extra_debug_sigprint();
|
Chris@49
|
1648
|
Chris@49
|
1649 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1650 {
|
Chris@49
|
1651 out = X.get_ref();
|
Chris@49
|
1652
|
Chris@49
|
1653 arma_debug_check( (out.is_square() == false), "chol(): given matrix is not square" );
|
Chris@49
|
1654
|
Chris@49
|
1655 if(out.is_empty())
|
Chris@49
|
1656 {
|
Chris@49
|
1657 return true;
|
Chris@49
|
1658 }
|
Chris@49
|
1659
|
Chris@49
|
1660 const uword out_n_rows = out.n_rows;
|
Chris@49
|
1661
|
Chris@49
|
1662 char uplo = 'U';
|
Chris@49
|
1663 blas_int n = out_n_rows;
|
Chris@49
|
1664 blas_int info = 0;
|
Chris@49
|
1665
|
Chris@49
|
1666 lapack::potrf(&uplo, &n, out.memptr(), &n, &info);
|
Chris@49
|
1667
|
Chris@49
|
1668 for(uword col=0; col<out_n_rows; ++col)
|
Chris@49
|
1669 {
|
Chris@49
|
1670 eT* colptr = out.colptr(col);
|
Chris@49
|
1671
|
Chris@49
|
1672 for(uword row=(col+1); row < out_n_rows; ++row)
|
Chris@49
|
1673 {
|
Chris@49
|
1674 colptr[row] = eT(0);
|
Chris@49
|
1675 }
|
Chris@49
|
1676 }
|
Chris@49
|
1677
|
Chris@49
|
1678 return (info == 0);
|
Chris@49
|
1679 }
|
Chris@49
|
1680 #else
|
Chris@49
|
1681 {
|
Chris@49
|
1682 arma_ignore(out);
|
Chris@49
|
1683 arma_ignore(X);
|
Chris@49
|
1684
|
Chris@49
|
1685 arma_stop("chol(): use of LAPACK needs to be enabled");
|
Chris@49
|
1686 return false;
|
Chris@49
|
1687 }
|
Chris@49
|
1688 #endif
|
Chris@49
|
1689 }
|
Chris@49
|
1690
|
Chris@49
|
1691
|
Chris@49
|
1692
|
Chris@49
|
1693 template<typename eT, typename T1>
|
Chris@49
|
1694 inline
|
Chris@49
|
1695 bool
|
Chris@49
|
1696 auxlib::qr(Mat<eT>& Q, Mat<eT>& R, const Base<eT,T1>& X)
|
Chris@49
|
1697 {
|
Chris@49
|
1698 arma_extra_debug_sigprint();
|
Chris@49
|
1699
|
Chris@49
|
1700 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1701 {
|
Chris@49
|
1702 R = X.get_ref();
|
Chris@49
|
1703
|
Chris@49
|
1704 const uword R_n_rows = R.n_rows;
|
Chris@49
|
1705 const uword R_n_cols = R.n_cols;
|
Chris@49
|
1706
|
Chris@49
|
1707 if(R.is_empty())
|
Chris@49
|
1708 {
|
Chris@49
|
1709 Q.eye(R_n_rows, R_n_rows);
|
Chris@49
|
1710 return true;
|
Chris@49
|
1711 }
|
Chris@49
|
1712
|
Chris@49
|
1713 blas_int m = static_cast<blas_int>(R_n_rows);
|
Chris@49
|
1714 blas_int n = static_cast<blas_int>(R_n_cols);
|
Chris@49
|
1715 blas_int lwork = 0;
|
Chris@49
|
1716 blas_int lwork_min = (std::max)(blas_int(1), (std::max)(m,n)); // take into account requirements of geqrf() _and_ orgqr()/ungqr()
|
Chris@49
|
1717 blas_int k = (std::min)(m,n);
|
Chris@49
|
1718 blas_int info = 0;
|
Chris@49
|
1719
|
Chris@49
|
1720 podarray<eT> tau( static_cast<uword>(k) );
|
Chris@49
|
1721
|
Chris@49
|
1722 eT work_query[2];
|
Chris@49
|
1723 blas_int lwork_query = -1;
|
Chris@49
|
1724
|
Chris@49
|
1725 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), &work_query[0], &lwork_query, &info);
|
Chris@49
|
1726
|
Chris@49
|
1727 if(info == 0)
|
Chris@49
|
1728 {
|
Chris@49
|
1729 const blas_int lwork_proposed = static_cast<blas_int>( access::tmp_real(work_query[0]) );
|
Chris@49
|
1730
|
Chris@49
|
1731 lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min;
|
Chris@49
|
1732 }
|
Chris@49
|
1733 else
|
Chris@49
|
1734 {
|
Chris@49
|
1735 return false;
|
Chris@49
|
1736 }
|
Chris@49
|
1737
|
Chris@49
|
1738 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
1739
|
Chris@49
|
1740 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info);
|
Chris@49
|
1741
|
Chris@49
|
1742 Q.set_size(R_n_rows, R_n_rows);
|
Chris@49
|
1743
|
Chris@49
|
1744 arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) );
|
Chris@49
|
1745
|
Chris@49
|
1746 //
|
Chris@49
|
1747 // construct R
|
Chris@49
|
1748
|
Chris@49
|
1749 for(uword col=0; col < R_n_cols; ++col)
|
Chris@49
|
1750 {
|
Chris@49
|
1751 for(uword row=(col+1); row < R_n_rows; ++row)
|
Chris@49
|
1752 {
|
Chris@49
|
1753 R.at(row,col) = eT(0);
|
Chris@49
|
1754 }
|
Chris@49
|
1755 }
|
Chris@49
|
1756
|
Chris@49
|
1757
|
Chris@49
|
1758 if( (is_float<eT>::value == true) || (is_double<eT>::value == true) )
|
Chris@49
|
1759 {
|
Chris@49
|
1760 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info);
|
Chris@49
|
1761 }
|
Chris@49
|
1762 else
|
Chris@49
|
1763 if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) )
|
Chris@49
|
1764 {
|
Chris@49
|
1765 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info);
|
Chris@49
|
1766 }
|
Chris@49
|
1767
|
Chris@49
|
1768 return (info == 0);
|
Chris@49
|
1769 }
|
Chris@49
|
1770 #else
|
Chris@49
|
1771 {
|
Chris@49
|
1772 arma_ignore(Q);
|
Chris@49
|
1773 arma_ignore(R);
|
Chris@49
|
1774 arma_ignore(X);
|
Chris@49
|
1775 arma_stop("qr(): use of LAPACK needs to be enabled");
|
Chris@49
|
1776 return false;
|
Chris@49
|
1777 }
|
Chris@49
|
1778 #endif
|
Chris@49
|
1779 }
|
Chris@49
|
1780
|
Chris@49
|
1781
|
Chris@49
|
1782
|
Chris@49
|
1783 template<typename eT, typename T1>
|
Chris@49
|
1784 inline
|
Chris@49
|
1785 bool
|
Chris@49
|
1786 auxlib::qr_econ(Mat<eT>& Q, Mat<eT>& R, const Base<eT,T1>& X)
|
Chris@49
|
1787 {
|
Chris@49
|
1788 arma_extra_debug_sigprint();
|
Chris@49
|
1789
|
Chris@49
|
1790 // This function implements a memory-efficient QR for a non-square X that has dimensions m x n.
|
Chris@49
|
1791 // This basically discards the basis for the null-space.
|
Chris@49
|
1792 //
|
Chris@49
|
1793 // if m <= n: (use standard routine)
|
Chris@49
|
1794 // Q[m,m]*R[m,n] = X[m,n]
|
Chris@49
|
1795 // geqrf Needs A[m,n]: Uses R
|
Chris@49
|
1796 // orgqr Needs A[m,m]: Uses Q
|
Chris@49
|
1797 // otherwise: (memory-efficient routine)
|
Chris@49
|
1798 // Q[m,n]*R[n,n] = X[m,n]
|
Chris@49
|
1799 // geqrf Needs A[m,n]: Uses Q
|
Chris@49
|
1800 // geqrf Needs A[m,n]: Uses Q
|
Chris@49
|
1801
|
Chris@49
|
1802 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1803 {
|
Chris@49
|
1804 if(is_Mat<T1>::value == true)
|
Chris@49
|
1805 {
|
Chris@49
|
1806 const unwrap<T1> tmp(X.get_ref());
|
Chris@49
|
1807 const Mat<eT>& M = tmp.M;
|
Chris@49
|
1808
|
Chris@49
|
1809 if(M.n_rows < M.n_cols)
|
Chris@49
|
1810 {
|
Chris@49
|
1811 return auxlib::qr(Q, R, X);
|
Chris@49
|
1812 }
|
Chris@49
|
1813 }
|
Chris@49
|
1814
|
Chris@49
|
1815 Q = X.get_ref();
|
Chris@49
|
1816
|
Chris@49
|
1817 const uword Q_n_rows = Q.n_rows;
|
Chris@49
|
1818 const uword Q_n_cols = Q.n_cols;
|
Chris@49
|
1819
|
Chris@49
|
1820 if( Q_n_rows <= Q_n_cols )
|
Chris@49
|
1821 {
|
Chris@49
|
1822 return auxlib::qr(Q, R, Q);
|
Chris@49
|
1823 }
|
Chris@49
|
1824
|
Chris@49
|
1825 if(Q.is_empty())
|
Chris@49
|
1826 {
|
Chris@49
|
1827 Q.set_size(Q_n_rows, 0 );
|
Chris@49
|
1828 R.set_size(0, Q_n_cols);
|
Chris@49
|
1829 return true;
|
Chris@49
|
1830 }
|
Chris@49
|
1831
|
Chris@49
|
1832 blas_int m = static_cast<blas_int>(Q_n_rows);
|
Chris@49
|
1833 blas_int n = static_cast<blas_int>(Q_n_cols);
|
Chris@49
|
1834 blas_int lwork = 0;
|
Chris@49
|
1835 blas_int lwork_min = (std::max)(blas_int(1), (std::max)(m,n)); // take into account requirements of geqrf() _and_ orgqr()/ungqr()
|
Chris@49
|
1836 blas_int k = (std::min)(m,n);
|
Chris@49
|
1837 blas_int info = 0;
|
Chris@49
|
1838
|
Chris@49
|
1839 podarray<eT> tau( static_cast<uword>(k) );
|
Chris@49
|
1840
|
Chris@49
|
1841 eT work_query[2];
|
Chris@49
|
1842 blas_int lwork_query = -1;
|
Chris@49
|
1843
|
Chris@49
|
1844 lapack::geqrf(&m, &n, Q.memptr(), &m, tau.memptr(), &work_query[0], &lwork_query, &info);
|
Chris@49
|
1845
|
Chris@49
|
1846 if(info == 0)
|
Chris@49
|
1847 {
|
Chris@49
|
1848 const blas_int lwork_proposed = static_cast<blas_int>( access::tmp_real(work_query[0]) );
|
Chris@49
|
1849
|
Chris@49
|
1850 lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min;
|
Chris@49
|
1851 }
|
Chris@49
|
1852 else
|
Chris@49
|
1853 {
|
Chris@49
|
1854 return false;
|
Chris@49
|
1855 }
|
Chris@49
|
1856
|
Chris@49
|
1857 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
1858
|
Chris@49
|
1859 lapack::geqrf(&m, &n, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info);
|
Chris@49
|
1860
|
Chris@49
|
1861 // Q now has the elements on and above the diagonal of the array
|
Chris@49
|
1862 // contain the min(M,N)-by-N upper trapezoidal matrix Q
|
Chris@49
|
1863 // (Q is upper triangular if m >= n);
|
Chris@49
|
1864 // the elements below the diagonal, with the array TAU,
|
Chris@49
|
1865 // represent the orthogonal matrix Q as a product of min(m,n) elementary reflectors.
|
Chris@49
|
1866
|
Chris@49
|
1867 R.set_size(Q_n_cols, Q_n_cols);
|
Chris@49
|
1868
|
Chris@49
|
1869 //
|
Chris@49
|
1870 // construct R
|
Chris@49
|
1871
|
Chris@49
|
1872 for(uword col=0; col < Q_n_cols; ++col)
|
Chris@49
|
1873 {
|
Chris@49
|
1874 for(uword row=0; row <= col; ++row)
|
Chris@49
|
1875 {
|
Chris@49
|
1876 R.at(row,col) = Q.at(row,col);
|
Chris@49
|
1877 }
|
Chris@49
|
1878
|
Chris@49
|
1879 for(uword row=(col+1); row < Q_n_cols; ++row)
|
Chris@49
|
1880 {
|
Chris@49
|
1881 R.at(row,col) = eT(0);
|
Chris@49
|
1882 }
|
Chris@49
|
1883 }
|
Chris@49
|
1884
|
Chris@49
|
1885 if( (is_float<eT>::value == true) || (is_double<eT>::value == true) )
|
Chris@49
|
1886 {
|
Chris@49
|
1887 lapack::orgqr(&m, &n, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info);
|
Chris@49
|
1888 }
|
Chris@49
|
1889 else
|
Chris@49
|
1890 if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) )
|
Chris@49
|
1891 {
|
Chris@49
|
1892 lapack::ungqr(&m, &n, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info);
|
Chris@49
|
1893 }
|
Chris@49
|
1894
|
Chris@49
|
1895 return (info == 0);
|
Chris@49
|
1896 }
|
Chris@49
|
1897 #else
|
Chris@49
|
1898 {
|
Chris@49
|
1899 arma_ignore(Q);
|
Chris@49
|
1900 arma_ignore(R);
|
Chris@49
|
1901 arma_ignore(X);
|
Chris@49
|
1902 arma_stop("qr_econ(): use of LAPACK needs to be enabled");
|
Chris@49
|
1903 return false;
|
Chris@49
|
1904 }
|
Chris@49
|
1905 #endif
|
Chris@49
|
1906 }
|
Chris@49
|
1907
|
Chris@49
|
1908
|
Chris@49
|
1909
|
Chris@49
|
1910 template<typename eT, typename T1>
|
Chris@49
|
1911 inline
|
Chris@49
|
1912 bool
|
Chris@49
|
1913 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X, uword& X_n_rows, uword& X_n_cols)
|
Chris@49
|
1914 {
|
Chris@49
|
1915 arma_extra_debug_sigprint();
|
Chris@49
|
1916
|
Chris@49
|
1917 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1918 {
|
Chris@49
|
1919 Mat<eT> A(X.get_ref());
|
Chris@49
|
1920
|
Chris@49
|
1921 X_n_rows = A.n_rows;
|
Chris@49
|
1922 X_n_cols = A.n_cols;
|
Chris@49
|
1923
|
Chris@49
|
1924 if(A.is_empty())
|
Chris@49
|
1925 {
|
Chris@49
|
1926 S.reset();
|
Chris@49
|
1927 return true;
|
Chris@49
|
1928 }
|
Chris@49
|
1929
|
Chris@49
|
1930 Mat<eT> U(1, 1);
|
Chris@49
|
1931 Mat<eT> V(1, A.n_cols);
|
Chris@49
|
1932
|
Chris@49
|
1933 char jobu = 'N';
|
Chris@49
|
1934 char jobvt = 'N';
|
Chris@49
|
1935
|
Chris@49
|
1936 blas_int m = A.n_rows;
|
Chris@49
|
1937 blas_int n = A.n_cols;
|
Chris@49
|
1938 blas_int min_mn = (std::min)(m,n);
|
Chris@49
|
1939 blas_int lda = A.n_rows;
|
Chris@49
|
1940 blas_int ldu = U.n_rows;
|
Chris@49
|
1941 blas_int ldvt = V.n_rows;
|
Chris@49
|
1942 blas_int lwork = 0;
|
Chris@49
|
1943 blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) );
|
Chris@49
|
1944 blas_int info = 0;
|
Chris@49
|
1945
|
Chris@49
|
1946 S.set_size( static_cast<uword>(min_mn) );
|
Chris@49
|
1947
|
Chris@49
|
1948 eT work_query[2];
|
Chris@49
|
1949 blas_int lwork_query = -1;
|
Chris@49
|
1950
|
Chris@49
|
1951 lapack::gesvd<eT>
|
Chris@49
|
1952 (
|
Chris@49
|
1953 &jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info
|
Chris@49
|
1954 );
|
Chris@49
|
1955
|
Chris@49
|
1956 if(info == 0)
|
Chris@49
|
1957 {
|
Chris@49
|
1958 const blas_int lwork_proposed = static_cast<blas_int>( work_query[0] );
|
Chris@49
|
1959
|
Chris@49
|
1960 lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min;
|
Chris@49
|
1961
|
Chris@49
|
1962 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
1963
|
Chris@49
|
1964 lapack::gesvd<eT>
|
Chris@49
|
1965 (
|
Chris@49
|
1966 &jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, &info
|
Chris@49
|
1967 );
|
Chris@49
|
1968 }
|
Chris@49
|
1969
|
Chris@49
|
1970 return (info == 0);
|
Chris@49
|
1971 }
|
Chris@49
|
1972 #else
|
Chris@49
|
1973 {
|
Chris@49
|
1974 arma_ignore(S);
|
Chris@49
|
1975 arma_ignore(X);
|
Chris@49
|
1976 arma_ignore(X_n_rows);
|
Chris@49
|
1977 arma_ignore(X_n_cols);
|
Chris@49
|
1978 arma_stop("svd(): use of LAPACK needs to be enabled");
|
Chris@49
|
1979 return false;
|
Chris@49
|
1980 }
|
Chris@49
|
1981 #endif
|
Chris@49
|
1982 }
|
Chris@49
|
1983
|
Chris@49
|
1984
|
Chris@49
|
1985
|
Chris@49
|
1986 template<typename T, typename T1>
|
Chris@49
|
1987 inline
|
Chris@49
|
1988 bool
|
Chris@49
|
1989 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X, uword& X_n_rows, uword& X_n_cols)
|
Chris@49
|
1990 {
|
Chris@49
|
1991 arma_extra_debug_sigprint();
|
Chris@49
|
1992
|
Chris@49
|
1993 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
1994 {
|
Chris@49
|
1995 typedef std::complex<T> eT;
|
Chris@49
|
1996
|
Chris@49
|
1997 Mat<eT> A(X.get_ref());
|
Chris@49
|
1998
|
Chris@49
|
1999 X_n_rows = A.n_rows;
|
Chris@49
|
2000 X_n_cols = A.n_cols;
|
Chris@49
|
2001
|
Chris@49
|
2002 if(A.is_empty())
|
Chris@49
|
2003 {
|
Chris@49
|
2004 S.reset();
|
Chris@49
|
2005 return true;
|
Chris@49
|
2006 }
|
Chris@49
|
2007
|
Chris@49
|
2008 Mat<eT> U(1, 1);
|
Chris@49
|
2009 Mat<eT> V(1, A.n_cols);
|
Chris@49
|
2010
|
Chris@49
|
2011 char jobu = 'N';
|
Chris@49
|
2012 char jobvt = 'N';
|
Chris@49
|
2013
|
Chris@49
|
2014 blas_int m = A.n_rows;
|
Chris@49
|
2015 blas_int n = A.n_cols;
|
Chris@49
|
2016 blas_int min_mn = (std::min)(m,n);
|
Chris@49
|
2017 blas_int lda = A.n_rows;
|
Chris@49
|
2018 blas_int ldu = U.n_rows;
|
Chris@49
|
2019 blas_int ldvt = V.n_rows;
|
Chris@49
|
2020 blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*min_mn+(std::max)(m,n) ) );
|
Chris@49
|
2021 blas_int info = 0;
|
Chris@49
|
2022
|
Chris@49
|
2023 S.set_size( static_cast<uword>(min_mn) );
|
Chris@49
|
2024
|
Chris@49
|
2025 podarray<eT> work( static_cast<uword>(lwork ) );
|
Chris@49
|
2026 podarray< T> rwork( static_cast<uword>(5*min_mn) );
|
Chris@49
|
2027
|
Chris@49
|
2028 // let gesvd_() calculate the optimum size of the workspace
|
Chris@49
|
2029 blas_int lwork_tmp = -1;
|
Chris@49
|
2030
|
Chris@49
|
2031 lapack::cx_gesvd<T>
|
Chris@49
|
2032 (
|
Chris@49
|
2033 &jobu, &jobvt,
|
Chris@49
|
2034 &m, &n,
|
Chris@49
|
2035 A.memptr(), &lda,
|
Chris@49
|
2036 S.memptr(),
|
Chris@49
|
2037 U.memptr(), &ldu,
|
Chris@49
|
2038 V.memptr(), &ldvt,
|
Chris@49
|
2039 work.memptr(), &lwork_tmp,
|
Chris@49
|
2040 rwork.memptr(),
|
Chris@49
|
2041 &info
|
Chris@49
|
2042 );
|
Chris@49
|
2043
|
Chris@49
|
2044 if(info == 0)
|
Chris@49
|
2045 {
|
Chris@49
|
2046 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
|
Chris@49
|
2047 if(proposed_lwork > lwork)
|
Chris@49
|
2048 {
|
Chris@49
|
2049 lwork = proposed_lwork;
|
Chris@49
|
2050 work.set_size( static_cast<uword>(lwork) );
|
Chris@49
|
2051 }
|
Chris@49
|
2052
|
Chris@49
|
2053 lapack::cx_gesvd<T>
|
Chris@49
|
2054 (
|
Chris@49
|
2055 &jobu, &jobvt,
|
Chris@49
|
2056 &m, &n,
|
Chris@49
|
2057 A.memptr(), &lda,
|
Chris@49
|
2058 S.memptr(),
|
Chris@49
|
2059 U.memptr(), &ldu,
|
Chris@49
|
2060 V.memptr(), &ldvt,
|
Chris@49
|
2061 work.memptr(), &lwork,
|
Chris@49
|
2062 rwork.memptr(),
|
Chris@49
|
2063 &info
|
Chris@49
|
2064 );
|
Chris@49
|
2065 }
|
Chris@49
|
2066
|
Chris@49
|
2067 return (info == 0);
|
Chris@49
|
2068 }
|
Chris@49
|
2069 #else
|
Chris@49
|
2070 {
|
Chris@49
|
2071 arma_ignore(S);
|
Chris@49
|
2072 arma_ignore(X);
|
Chris@49
|
2073 arma_ignore(X_n_rows);
|
Chris@49
|
2074 arma_ignore(X_n_cols);
|
Chris@49
|
2075
|
Chris@49
|
2076 arma_stop("svd(): use of LAPACK needs to be enabled");
|
Chris@49
|
2077 return false;
|
Chris@49
|
2078 }
|
Chris@49
|
2079 #endif
|
Chris@49
|
2080 }
|
Chris@49
|
2081
|
Chris@49
|
2082
|
Chris@49
|
2083
|
Chris@49
|
2084 template<typename eT, typename T1>
|
Chris@49
|
2085 inline
|
Chris@49
|
2086 bool
|
Chris@49
|
2087 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X)
|
Chris@49
|
2088 {
|
Chris@49
|
2089 arma_extra_debug_sigprint();
|
Chris@49
|
2090
|
Chris@49
|
2091 uword junk;
|
Chris@49
|
2092 return auxlib::svd(S, X, junk, junk);
|
Chris@49
|
2093 }
|
Chris@49
|
2094
|
Chris@49
|
2095
|
Chris@49
|
2096
|
Chris@49
|
2097 template<typename T, typename T1>
|
Chris@49
|
2098 inline
|
Chris@49
|
2099 bool
|
Chris@49
|
2100 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X)
|
Chris@49
|
2101 {
|
Chris@49
|
2102 arma_extra_debug_sigprint();
|
Chris@49
|
2103
|
Chris@49
|
2104 uword junk;
|
Chris@49
|
2105 return auxlib::svd(S, X, junk, junk);
|
Chris@49
|
2106 }
|
Chris@49
|
2107
|
Chris@49
|
2108
|
Chris@49
|
2109
|
Chris@49
|
2110 template<typename eT, typename T1>
|
Chris@49
|
2111 inline
|
Chris@49
|
2112 bool
|
Chris@49
|
2113 auxlib::svd(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X)
|
Chris@49
|
2114 {
|
Chris@49
|
2115 arma_extra_debug_sigprint();
|
Chris@49
|
2116
|
Chris@49
|
2117 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
2118 {
|
Chris@49
|
2119 Mat<eT> A(X.get_ref());
|
Chris@49
|
2120
|
Chris@49
|
2121 if(A.is_empty())
|
Chris@49
|
2122 {
|
Chris@49
|
2123 U.eye(A.n_rows, A.n_rows);
|
Chris@49
|
2124 S.reset();
|
Chris@49
|
2125 V.eye(A.n_cols, A.n_cols);
|
Chris@49
|
2126 return true;
|
Chris@49
|
2127 }
|
Chris@49
|
2128
|
Chris@49
|
2129 U.set_size(A.n_rows, A.n_rows);
|
Chris@49
|
2130 V.set_size(A.n_cols, A.n_cols);
|
Chris@49
|
2131
|
Chris@49
|
2132 char jobu = 'A';
|
Chris@49
|
2133 char jobvt = 'A';
|
Chris@49
|
2134
|
Chris@49
|
2135 blas_int m = blas_int(A.n_rows);
|
Chris@49
|
2136 blas_int n = blas_int(A.n_cols);
|
Chris@49
|
2137 blas_int min_mn = (std::min)(m,n);
|
Chris@49
|
2138 blas_int lda = blas_int(A.n_rows);
|
Chris@49
|
2139 blas_int ldu = blas_int(U.n_rows);
|
Chris@49
|
2140 blas_int ldvt = blas_int(V.n_rows);
|
Chris@49
|
2141 blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) );
|
Chris@49
|
2142 blas_int lwork = 0;
|
Chris@49
|
2143 blas_int info = 0;
|
Chris@49
|
2144
|
Chris@49
|
2145 S.set_size( static_cast<uword>(min_mn) );
|
Chris@49
|
2146
|
Chris@49
|
2147 // let gesvd_() calculate the optimum size of the workspace
|
Chris@49
|
2148 eT work_query[2];
|
Chris@49
|
2149 blas_int lwork_query = -1;
|
Chris@49
|
2150
|
Chris@49
|
2151 lapack::gesvd<eT>
|
Chris@49
|
2152 (
|
Chris@49
|
2153 &jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info
|
Chris@49
|
2154 );
|
Chris@49
|
2155
|
Chris@49
|
2156 if(info == 0)
|
Chris@49
|
2157 {
|
Chris@49
|
2158 const blas_int lwork_proposed = static_cast<blas_int>( work_query[0] );
|
Chris@49
|
2159
|
Chris@49
|
2160 lwork = (lwork_proposed > lwork_min) ? lwork_proposed : lwork_min;
|
Chris@49
|
2161
|
Chris@49
|
2162 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
2163
|
Chris@49
|
2164 lapack::gesvd<eT>
|
Chris@49
|
2165 (
|
Chris@49
|
2166 &jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, &info
|
Chris@49
|
2167 );
|
Chris@49
|
2168
|
Chris@49
|
2169 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done
|
Chris@49
|
2170 }
|
Chris@49
|
2171
|
Chris@49
|
2172 return (info == 0);
|
Chris@49
|
2173 }
|
Chris@49
|
2174 #else
|
Chris@49
|
2175 {
|
Chris@49
|
2176 arma_ignore(U);
|
Chris@49
|
2177 arma_ignore(S);
|
Chris@49
|
2178 arma_ignore(V);
|
Chris@49
|
2179 arma_ignore(X);
|
Chris@49
|
2180 arma_stop("svd(): use of LAPACK needs to be enabled");
|
Chris@49
|
2181 return false;
|
Chris@49
|
2182 }
|
Chris@49
|
2183 #endif
|
Chris@49
|
2184 }
|
Chris@49
|
2185
|
Chris@49
|
2186
|
Chris@49
|
2187
|
Chris@49
|
2188 template<typename T, typename T1>
|
Chris@49
|
2189 inline
|
Chris@49
|
2190 bool
|
Chris@49
|
2191 auxlib::svd(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X)
|
Chris@49
|
2192 {
|
Chris@49
|
2193 arma_extra_debug_sigprint();
|
Chris@49
|
2194
|
Chris@49
|
2195 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
2196 {
|
Chris@49
|
2197 typedef std::complex<T> eT;
|
Chris@49
|
2198
|
Chris@49
|
2199 Mat<eT> A(X.get_ref());
|
Chris@49
|
2200
|
Chris@49
|
2201 if(A.is_empty())
|
Chris@49
|
2202 {
|
Chris@49
|
2203 U.eye(A.n_rows, A.n_rows);
|
Chris@49
|
2204 S.reset();
|
Chris@49
|
2205 V.eye(A.n_cols, A.n_cols);
|
Chris@49
|
2206 return true;
|
Chris@49
|
2207 }
|
Chris@49
|
2208
|
Chris@49
|
2209 U.set_size(A.n_rows, A.n_rows);
|
Chris@49
|
2210 V.set_size(A.n_cols, A.n_cols);
|
Chris@49
|
2211
|
Chris@49
|
2212 char jobu = 'A';
|
Chris@49
|
2213 char jobvt = 'A';
|
Chris@49
|
2214
|
Chris@49
|
2215 blas_int m = blas_int(A.n_rows);
|
Chris@49
|
2216 blas_int n = blas_int(A.n_cols);
|
Chris@49
|
2217 blas_int min_mn = (std::min)(m,n);
|
Chris@49
|
2218 blas_int lda = blas_int(A.n_rows);
|
Chris@49
|
2219 blas_int ldu = blas_int(U.n_rows);
|
Chris@49
|
2220 blas_int ldvt = blas_int(V.n_rows);
|
Chris@49
|
2221 blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*min_mn + (std::max)(m,n) ) );
|
Chris@49
|
2222 blas_int info = 0;
|
Chris@49
|
2223
|
Chris@49
|
2224 S.set_size( static_cast<uword>(min_mn) );
|
Chris@49
|
2225
|
Chris@49
|
2226 podarray<eT> work( static_cast<uword>(lwork ) );
|
Chris@49
|
2227 podarray<T> rwork( static_cast<uword>(5*min_mn) );
|
Chris@49
|
2228
|
Chris@49
|
2229 // let gesvd_() calculate the optimum size of the workspace
|
Chris@49
|
2230 blas_int lwork_tmp = -1;
|
Chris@49
|
2231 lapack::cx_gesvd<T>
|
Chris@49
|
2232 (
|
Chris@49
|
2233 &jobu, &jobvt,
|
Chris@49
|
2234 &m, &n,
|
Chris@49
|
2235 A.memptr(), &lda,
|
Chris@49
|
2236 S.memptr(),
|
Chris@49
|
2237 U.memptr(), &ldu,
|
Chris@49
|
2238 V.memptr(), &ldvt,
|
Chris@49
|
2239 work.memptr(), &lwork_tmp,
|
Chris@49
|
2240 rwork.memptr(),
|
Chris@49
|
2241 &info
|
Chris@49
|
2242 );
|
Chris@49
|
2243
|
Chris@49
|
2244 if(info == 0)
|
Chris@49
|
2245 {
|
Chris@49
|
2246 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
|
Chris@49
|
2247
|
Chris@49
|
2248 if(proposed_lwork > lwork)
|
Chris@49
|
2249 {
|
Chris@49
|
2250 lwork = proposed_lwork;
|
Chris@49
|
2251 work.set_size( static_cast<uword>(lwork) );
|
Chris@49
|
2252 }
|
Chris@49
|
2253
|
Chris@49
|
2254 lapack::cx_gesvd<T>
|
Chris@49
|
2255 (
|
Chris@49
|
2256 &jobu, &jobvt,
|
Chris@49
|
2257 &m, &n,
|
Chris@49
|
2258 A.memptr(), &lda,
|
Chris@49
|
2259 S.memptr(),
|
Chris@49
|
2260 U.memptr(), &ldu,
|
Chris@49
|
2261 V.memptr(), &ldvt,
|
Chris@49
|
2262 work.memptr(), &lwork,
|
Chris@49
|
2263 rwork.memptr(),
|
Chris@49
|
2264 &info
|
Chris@49
|
2265 );
|
Chris@49
|
2266
|
Chris@49
|
2267 op_htrans::apply(V,V); // op_htrans will work out that an in-place transpose can be done
|
Chris@49
|
2268 }
|
Chris@49
|
2269
|
Chris@49
|
2270 return (info == 0);
|
Chris@49
|
2271 }
|
Chris@49
|
2272 #else
|
Chris@49
|
2273 {
|
Chris@49
|
2274 arma_ignore(U);
|
Chris@49
|
2275 arma_ignore(S);
|
Chris@49
|
2276 arma_ignore(V);
|
Chris@49
|
2277 arma_ignore(X);
|
Chris@49
|
2278 arma_stop("svd(): use of LAPACK needs to be enabled");
|
Chris@49
|
2279 return false;
|
Chris@49
|
2280 }
|
Chris@49
|
2281 #endif
|
Chris@49
|
2282 }
|
Chris@49
|
2283
|
Chris@49
|
2284
|
Chris@49
|
2285
|
Chris@49
|
2286 template<typename eT, typename T1>
|
Chris@49
|
2287 inline
|
Chris@49
|
2288 bool
|
Chris@49
|
2289 auxlib::svd_econ(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X, const char mode)
|
Chris@49
|
2290 {
|
Chris@49
|
2291 arma_extra_debug_sigprint();
|
Chris@49
|
2292
|
Chris@49
|
2293 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
2294 {
|
Chris@49
|
2295 Mat<eT> A(X.get_ref());
|
Chris@49
|
2296
|
Chris@49
|
2297 blas_int m = blas_int(A.n_rows);
|
Chris@49
|
2298 blas_int n = blas_int(A.n_cols);
|
Chris@49
|
2299 blas_int min_mn = (std::min)(m,n);
|
Chris@49
|
2300 blas_int lda = blas_int(A.n_rows);
|
Chris@49
|
2301
|
Chris@49
|
2302 S.set_size( static_cast<uword>(min_mn) );
|
Chris@49
|
2303
|
Chris@49
|
2304 blas_int ldu = 0;
|
Chris@49
|
2305 blas_int ldvt = 0;
|
Chris@49
|
2306
|
Chris@49
|
2307 char jobu;
|
Chris@49
|
2308 char jobvt;
|
Chris@49
|
2309
|
Chris@49
|
2310 switch(mode)
|
Chris@49
|
2311 {
|
Chris@49
|
2312 case 'l':
|
Chris@49
|
2313 jobu = 'S';
|
Chris@49
|
2314 jobvt = 'N';
|
Chris@49
|
2315
|
Chris@49
|
2316 ldu = m;
|
Chris@49
|
2317 ldvt = 1;
|
Chris@49
|
2318
|
Chris@49
|
2319 U.set_size( static_cast<uword>(ldu), static_cast<uword>(min_mn) );
|
Chris@49
|
2320 V.reset();
|
Chris@49
|
2321
|
Chris@49
|
2322 break;
|
Chris@49
|
2323
|
Chris@49
|
2324
|
Chris@49
|
2325 case 'r':
|
Chris@49
|
2326 jobu = 'N';
|
Chris@49
|
2327 jobvt = 'S';
|
Chris@49
|
2328
|
Chris@49
|
2329 ldu = 1;
|
Chris@49
|
2330 ldvt = (std::min)(m,n);
|
Chris@49
|
2331
|
Chris@49
|
2332 U.reset();
|
Chris@49
|
2333 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
|
Chris@49
|
2334
|
Chris@49
|
2335 break;
|
Chris@49
|
2336
|
Chris@49
|
2337
|
Chris@49
|
2338 case 'b':
|
Chris@49
|
2339 jobu = 'S';
|
Chris@49
|
2340 jobvt = 'S';
|
Chris@49
|
2341
|
Chris@49
|
2342 ldu = m;
|
Chris@49
|
2343 ldvt = (std::min)(m,n);
|
Chris@49
|
2344
|
Chris@49
|
2345 U.set_size( static_cast<uword>(ldu), static_cast<uword>(min_mn) );
|
Chris@49
|
2346 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n ) );
|
Chris@49
|
2347
|
Chris@49
|
2348 break;
|
Chris@49
|
2349
|
Chris@49
|
2350
|
Chris@49
|
2351 default:
|
Chris@49
|
2352 U.reset();
|
Chris@49
|
2353 S.reset();
|
Chris@49
|
2354 V.reset();
|
Chris@49
|
2355 return false;
|
Chris@49
|
2356 }
|
Chris@49
|
2357
|
Chris@49
|
2358
|
Chris@49
|
2359 if(A.is_empty())
|
Chris@49
|
2360 {
|
Chris@49
|
2361 U.eye();
|
Chris@49
|
2362 S.reset();
|
Chris@49
|
2363 V.eye();
|
Chris@49
|
2364 return true;
|
Chris@49
|
2365 }
|
Chris@49
|
2366
|
Chris@49
|
2367
|
Chris@49
|
2368 blas_int lwork = 3 * ( (std::max)(blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ) );
|
Chris@49
|
2369 blas_int info = 0;
|
Chris@49
|
2370
|
Chris@49
|
2371
|
Chris@49
|
2372 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
2373
|
Chris@49
|
2374 // let gesvd_() calculate the optimum size of the workspace
|
Chris@49
|
2375 blas_int lwork_tmp = -1;
|
Chris@49
|
2376
|
Chris@49
|
2377 lapack::gesvd<eT>
|
Chris@49
|
2378 (
|
Chris@49
|
2379 &jobu, &jobvt,
|
Chris@49
|
2380 &m, &n,
|
Chris@49
|
2381 A.memptr(), &lda,
|
Chris@49
|
2382 S.memptr(),
|
Chris@49
|
2383 U.memptr(), &ldu,
|
Chris@49
|
2384 V.memptr(), &ldvt,
|
Chris@49
|
2385 work.memptr(), &lwork_tmp,
|
Chris@49
|
2386 &info
|
Chris@49
|
2387 );
|
Chris@49
|
2388
|
Chris@49
|
2389 if(info == 0)
|
Chris@49
|
2390 {
|
Chris@49
|
2391 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
|
Chris@49
|
2392 if(proposed_lwork > lwork)
|
Chris@49
|
2393 {
|
Chris@49
|
2394 lwork = proposed_lwork;
|
Chris@49
|
2395 work.set_size( static_cast<uword>(lwork) );
|
Chris@49
|
2396 }
|
Chris@49
|
2397
|
Chris@49
|
2398 lapack::gesvd<eT>
|
Chris@49
|
2399 (
|
Chris@49
|
2400 &jobu, &jobvt,
|
Chris@49
|
2401 &m, &n,
|
Chris@49
|
2402 A.memptr(), &lda,
|
Chris@49
|
2403 S.memptr(),
|
Chris@49
|
2404 U.memptr(), &ldu,
|
Chris@49
|
2405 V.memptr(), &ldvt,
|
Chris@49
|
2406 work.memptr(), &lwork,
|
Chris@49
|
2407 &info
|
Chris@49
|
2408 );
|
Chris@49
|
2409
|
Chris@49
|
2410 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done
|
Chris@49
|
2411 }
|
Chris@49
|
2412
|
Chris@49
|
2413 return (info == 0);
|
Chris@49
|
2414 }
|
Chris@49
|
2415 #else
|
Chris@49
|
2416 {
|
Chris@49
|
2417 arma_ignore(U);
|
Chris@49
|
2418 arma_ignore(S);
|
Chris@49
|
2419 arma_ignore(V);
|
Chris@49
|
2420 arma_ignore(X);
|
Chris@49
|
2421 arma_ignore(mode);
|
Chris@49
|
2422 arma_stop("svd(): use of LAPACK needs to be enabled");
|
Chris@49
|
2423 return false;
|
Chris@49
|
2424 }
|
Chris@49
|
2425 #endif
|
Chris@49
|
2426 }
|
Chris@49
|
2427
|
Chris@49
|
2428
|
Chris@49
|
2429
|
Chris@49
|
2430 template<typename T, typename T1>
|
Chris@49
|
2431 inline
|
Chris@49
|
2432 bool
|
Chris@49
|
2433 auxlib::svd_econ(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X, const char mode)
|
Chris@49
|
2434 {
|
Chris@49
|
2435 arma_extra_debug_sigprint();
|
Chris@49
|
2436
|
Chris@49
|
2437 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
2438 {
|
Chris@49
|
2439 typedef std::complex<T> eT;
|
Chris@49
|
2440
|
Chris@49
|
2441 Mat<eT> A(X.get_ref());
|
Chris@49
|
2442
|
Chris@49
|
2443 blas_int m = blas_int(A.n_rows);
|
Chris@49
|
2444 blas_int n = blas_int(A.n_cols);
|
Chris@49
|
2445 blas_int min_mn = (std::min)(m,n);
|
Chris@49
|
2446 blas_int lda = blas_int(A.n_rows);
|
Chris@49
|
2447
|
Chris@49
|
2448 S.set_size( static_cast<uword>(min_mn) );
|
Chris@49
|
2449
|
Chris@49
|
2450 blas_int ldu = 0;
|
Chris@49
|
2451 blas_int ldvt = 0;
|
Chris@49
|
2452
|
Chris@49
|
2453 char jobu;
|
Chris@49
|
2454 char jobvt;
|
Chris@49
|
2455
|
Chris@49
|
2456 switch(mode)
|
Chris@49
|
2457 {
|
Chris@49
|
2458 case 'l':
|
Chris@49
|
2459 jobu = 'S';
|
Chris@49
|
2460 jobvt = 'N';
|
Chris@49
|
2461
|
Chris@49
|
2462 ldu = m;
|
Chris@49
|
2463 ldvt = 1;
|
Chris@49
|
2464
|
Chris@49
|
2465 U.set_size( static_cast<uword>(ldu), static_cast<uword>(min_mn) );
|
Chris@49
|
2466 V.reset();
|
Chris@49
|
2467
|
Chris@49
|
2468 break;
|
Chris@49
|
2469
|
Chris@49
|
2470
|
Chris@49
|
2471 case 'r':
|
Chris@49
|
2472 jobu = 'N';
|
Chris@49
|
2473 jobvt = 'S';
|
Chris@49
|
2474
|
Chris@49
|
2475 ldu = 1;
|
Chris@49
|
2476 ldvt = (std::min)(m,n);
|
Chris@49
|
2477
|
Chris@49
|
2478 U.reset();
|
Chris@49
|
2479 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
|
Chris@49
|
2480
|
Chris@49
|
2481 break;
|
Chris@49
|
2482
|
Chris@49
|
2483
|
Chris@49
|
2484 case 'b':
|
Chris@49
|
2485 jobu = 'S';
|
Chris@49
|
2486 jobvt = 'S';
|
Chris@49
|
2487
|
Chris@49
|
2488 ldu = m;
|
Chris@49
|
2489 ldvt = (std::min)(m,n);
|
Chris@49
|
2490
|
Chris@49
|
2491 U.set_size( static_cast<uword>(ldu), static_cast<uword>(min_mn) );
|
Chris@49
|
2492 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
|
Chris@49
|
2493
|
Chris@49
|
2494 break;
|
Chris@49
|
2495
|
Chris@49
|
2496
|
Chris@49
|
2497 default:
|
Chris@49
|
2498 U.reset();
|
Chris@49
|
2499 S.reset();
|
Chris@49
|
2500 V.reset();
|
Chris@49
|
2501 return false;
|
Chris@49
|
2502 }
|
Chris@49
|
2503
|
Chris@49
|
2504
|
Chris@49
|
2505 if(A.is_empty())
|
Chris@49
|
2506 {
|
Chris@49
|
2507 U.eye();
|
Chris@49
|
2508 S.reset();
|
Chris@49
|
2509 V.eye();
|
Chris@49
|
2510 return true;
|
Chris@49
|
2511 }
|
Chris@49
|
2512
|
Chris@49
|
2513
|
Chris@49
|
2514 blas_int lwork = 3 * ( (std::max)(blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ) );
|
Chris@49
|
2515 blas_int info = 0;
|
Chris@49
|
2516
|
Chris@49
|
2517
|
Chris@49
|
2518 podarray<eT> work( static_cast<uword>(lwork ) );
|
Chris@49
|
2519 podarray<T> rwork( static_cast<uword>(5*min_mn) );
|
Chris@49
|
2520
|
Chris@49
|
2521 // let gesvd_() calculate the optimum size of the workspace
|
Chris@49
|
2522 blas_int lwork_tmp = -1;
|
Chris@49
|
2523
|
Chris@49
|
2524 lapack::cx_gesvd<T>
|
Chris@49
|
2525 (
|
Chris@49
|
2526 &jobu, &jobvt,
|
Chris@49
|
2527 &m, &n,
|
Chris@49
|
2528 A.memptr(), &lda,
|
Chris@49
|
2529 S.memptr(),
|
Chris@49
|
2530 U.memptr(), &ldu,
|
Chris@49
|
2531 V.memptr(), &ldvt,
|
Chris@49
|
2532 work.memptr(), &lwork_tmp,
|
Chris@49
|
2533 rwork.memptr(),
|
Chris@49
|
2534 &info
|
Chris@49
|
2535 );
|
Chris@49
|
2536
|
Chris@49
|
2537 if(info == 0)
|
Chris@49
|
2538 {
|
Chris@49
|
2539 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
|
Chris@49
|
2540 if(proposed_lwork > lwork)
|
Chris@49
|
2541 {
|
Chris@49
|
2542 lwork = proposed_lwork;
|
Chris@49
|
2543 work.set_size( static_cast<uword>(lwork) );
|
Chris@49
|
2544 }
|
Chris@49
|
2545
|
Chris@49
|
2546 lapack::cx_gesvd<T>
|
Chris@49
|
2547 (
|
Chris@49
|
2548 &jobu, &jobvt,
|
Chris@49
|
2549 &m, &n,
|
Chris@49
|
2550 A.memptr(), &lda,
|
Chris@49
|
2551 S.memptr(),
|
Chris@49
|
2552 U.memptr(), &ldu,
|
Chris@49
|
2553 V.memptr(), &ldvt,
|
Chris@49
|
2554 work.memptr(), &lwork,
|
Chris@49
|
2555 rwork.memptr(),
|
Chris@49
|
2556 &info
|
Chris@49
|
2557 );
|
Chris@49
|
2558
|
Chris@49
|
2559 op_htrans::apply(V,V); // op_strans will work out that an in-place transpose can be done
|
Chris@49
|
2560 }
|
Chris@49
|
2561
|
Chris@49
|
2562 return (info == 0);
|
Chris@49
|
2563 }
|
Chris@49
|
2564 #else
|
Chris@49
|
2565 {
|
Chris@49
|
2566 arma_ignore(U);
|
Chris@49
|
2567 arma_ignore(S);
|
Chris@49
|
2568 arma_ignore(V);
|
Chris@49
|
2569 arma_ignore(X);
|
Chris@49
|
2570 arma_ignore(mode);
|
Chris@49
|
2571 arma_stop("svd(): use of LAPACK needs to be enabled");
|
Chris@49
|
2572 return false;
|
Chris@49
|
2573 }
|
Chris@49
|
2574 #endif
|
Chris@49
|
2575 }
|
Chris@49
|
2576
|
Chris@49
|
2577
|
Chris@49
|
2578
|
Chris@49
|
2579 template<typename eT, typename T1>
|
Chris@49
|
2580 inline
|
Chris@49
|
2581 bool
|
Chris@49
|
2582 auxlib::svd_dc(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X)
|
Chris@49
|
2583 {
|
Chris@49
|
2584 arma_extra_debug_sigprint();
|
Chris@49
|
2585
|
Chris@49
|
2586 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
2587 {
|
Chris@49
|
2588 Mat<eT> A(X.get_ref());
|
Chris@49
|
2589
|
Chris@49
|
2590 if(A.is_empty())
|
Chris@49
|
2591 {
|
Chris@49
|
2592 U.eye(A.n_rows, A.n_rows);
|
Chris@49
|
2593 S.reset();
|
Chris@49
|
2594 V.eye(A.n_cols, A.n_cols);
|
Chris@49
|
2595 return true;
|
Chris@49
|
2596 }
|
Chris@49
|
2597
|
Chris@49
|
2598 U.set_size(A.n_rows, A.n_rows);
|
Chris@49
|
2599 V.set_size(A.n_cols, A.n_cols);
|
Chris@49
|
2600
|
Chris@49
|
2601 char jobz = 'A';
|
Chris@49
|
2602
|
Chris@49
|
2603 blas_int m = blas_int(A.n_rows);
|
Chris@49
|
2604 blas_int n = blas_int(A.n_cols);
|
Chris@49
|
2605 blas_int min_mn = (std::min)(m,n);
|
Chris@49
|
2606 blas_int lda = blas_int(A.n_rows);
|
Chris@49
|
2607 blas_int ldu = blas_int(U.n_rows);
|
Chris@49
|
2608 blas_int ldvt = blas_int(V.n_rows);
|
Chris@49
|
2609 blas_int lwork = 3 * ( 3*min_mn*min_mn + (std::max)( (std::max)(m,n), 4*min_mn*min_mn + 4*min_mn ) );
|
Chris@49
|
2610 blas_int info = 0;
|
Chris@49
|
2611
|
Chris@49
|
2612 S.set_size( static_cast<uword>(min_mn) );
|
Chris@49
|
2613
|
Chris@49
|
2614 podarray<eT> work( static_cast<uword>(lwork ) );
|
Chris@49
|
2615 podarray<blas_int> iwork( static_cast<uword>(8*min_mn) );
|
Chris@49
|
2616
|
Chris@49
|
2617 lapack::gesdd<eT>
|
Chris@49
|
2618 (
|
Chris@49
|
2619 &jobz, &m, &n,
|
Chris@49
|
2620 A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt,
|
Chris@49
|
2621 work.memptr(), &lwork, iwork.memptr(), &info
|
Chris@49
|
2622 );
|
Chris@49
|
2623
|
Chris@49
|
2624 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done
|
Chris@49
|
2625
|
Chris@49
|
2626 return (info == 0);
|
Chris@49
|
2627 }
|
Chris@49
|
2628 #else
|
Chris@49
|
2629 {
|
Chris@49
|
2630 arma_ignore(U);
|
Chris@49
|
2631 arma_ignore(S);
|
Chris@49
|
2632 arma_ignore(V);
|
Chris@49
|
2633 arma_ignore(X);
|
Chris@49
|
2634 arma_stop("svd(): use of LAPACK needs to be enabled");
|
Chris@49
|
2635 return false;
|
Chris@49
|
2636 }
|
Chris@49
|
2637 #endif
|
Chris@49
|
2638 }
|
Chris@49
|
2639
|
Chris@49
|
2640
|
Chris@49
|
2641
|
Chris@49
|
2642 template<typename T, typename T1>
|
Chris@49
|
2643 inline
|
Chris@49
|
2644 bool
|
Chris@49
|
2645 auxlib::svd_dc(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X)
|
Chris@49
|
2646 {
|
Chris@49
|
2647 arma_extra_debug_sigprint();
|
Chris@49
|
2648
|
Chris@49
|
2649 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
2650 {
|
Chris@49
|
2651 typedef std::complex<T> eT;
|
Chris@49
|
2652
|
Chris@49
|
2653 Mat<eT> A(X.get_ref());
|
Chris@49
|
2654
|
Chris@49
|
2655 if(A.is_empty())
|
Chris@49
|
2656 {
|
Chris@49
|
2657 U.eye(A.n_rows, A.n_rows);
|
Chris@49
|
2658 S.reset();
|
Chris@49
|
2659 V.eye(A.n_cols, A.n_cols);
|
Chris@49
|
2660 return true;
|
Chris@49
|
2661 }
|
Chris@49
|
2662
|
Chris@49
|
2663 U.set_size(A.n_rows, A.n_rows);
|
Chris@49
|
2664 V.set_size(A.n_cols, A.n_cols);
|
Chris@49
|
2665
|
Chris@49
|
2666 char jobz = 'A';
|
Chris@49
|
2667
|
Chris@49
|
2668 blas_int m = blas_int(A.n_rows);
|
Chris@49
|
2669 blas_int n = blas_int(A.n_cols);
|
Chris@49
|
2670 blas_int min_mn = (std::min)(m,n);
|
Chris@49
|
2671 blas_int lda = blas_int(A.n_rows);
|
Chris@49
|
2672 blas_int ldu = blas_int(U.n_rows);
|
Chris@49
|
2673 blas_int ldvt = blas_int(V.n_rows);
|
Chris@49
|
2674 blas_int lwork = 3 * (min_mn*min_mn + 2*min_mn + (std::max)(m,n));
|
Chris@49
|
2675 blas_int info = 0;
|
Chris@49
|
2676
|
Chris@49
|
2677 S.set_size( static_cast<uword>(min_mn) );
|
Chris@49
|
2678
|
Chris@49
|
2679 podarray<eT> work( static_cast<uword>(lwork ) );
|
Chris@49
|
2680 podarray<T> rwork( static_cast<uword>(5*min_mn*min_mn + 7*min_mn) );
|
Chris@49
|
2681 podarray<blas_int> iwork( static_cast<uword>(8*min_mn ) );
|
Chris@49
|
2682
|
Chris@49
|
2683 lapack::cx_gesdd<T>
|
Chris@49
|
2684 (
|
Chris@49
|
2685 &jobz, &m, &n,
|
Chris@49
|
2686 A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt,
|
Chris@49
|
2687 work.memptr(), &lwork, rwork.memptr(), iwork.memptr(), &info
|
Chris@49
|
2688 );
|
Chris@49
|
2689
|
Chris@49
|
2690 op_htrans::apply(V,V); // op_htrans will work out that an in-place transpose can be done
|
Chris@49
|
2691
|
Chris@49
|
2692 return (info == 0);
|
Chris@49
|
2693 }
|
Chris@49
|
2694 #else
|
Chris@49
|
2695 {
|
Chris@49
|
2696 arma_ignore(U);
|
Chris@49
|
2697 arma_ignore(S);
|
Chris@49
|
2698 arma_ignore(V);
|
Chris@49
|
2699 arma_ignore(X);
|
Chris@49
|
2700 arma_stop("svd(): use of LAPACK needs to be enabled");
|
Chris@49
|
2701 return false;
|
Chris@49
|
2702 }
|
Chris@49
|
2703 #endif
|
Chris@49
|
2704 }
|
Chris@49
|
2705
|
Chris@49
|
2706
|
Chris@49
|
2707
|
Chris@49
|
2708 //! Solve a system of linear equations.
|
Chris@49
|
2709 //! Assumes that A.n_rows = A.n_cols and B.n_rows = A.n_rows
|
Chris@49
|
2710 template<typename eT, typename T1>
|
Chris@49
|
2711 inline
|
Chris@49
|
2712 bool
|
Chris@49
|
2713 auxlib::solve(Mat<eT>& out, Mat<eT>& A, const Base<eT,T1>& X, const bool slow)
|
Chris@49
|
2714 {
|
Chris@49
|
2715 arma_extra_debug_sigprint();
|
Chris@49
|
2716
|
Chris@49
|
2717 bool status = false;
|
Chris@49
|
2718
|
Chris@49
|
2719 const uword A_n_rows = A.n_rows;
|
Chris@49
|
2720
|
Chris@49
|
2721 if( (A_n_rows <= 4) && (slow == false) )
|
Chris@49
|
2722 {
|
Chris@49
|
2723 Mat<eT> A_inv;
|
Chris@49
|
2724
|
Chris@49
|
2725 status = auxlib::inv_noalias_tinymat(A_inv, A, A_n_rows);
|
Chris@49
|
2726
|
Chris@49
|
2727 if(status == true)
|
Chris@49
|
2728 {
|
Chris@49
|
2729 const unwrap_check<T1> Y( X.get_ref(), out );
|
Chris@49
|
2730 const Mat<eT>& B = Y.M;
|
Chris@49
|
2731
|
Chris@49
|
2732 const uword B_n_rows = B.n_rows;
|
Chris@49
|
2733 const uword B_n_cols = B.n_cols;
|
Chris@49
|
2734
|
Chris@49
|
2735 arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given objects must be the same" );
|
Chris@49
|
2736
|
Chris@49
|
2737 if(A.is_empty() || B.is_empty())
|
Chris@49
|
2738 {
|
Chris@49
|
2739 out.zeros(A.n_cols, B_n_cols);
|
Chris@49
|
2740 return true;
|
Chris@49
|
2741 }
|
Chris@49
|
2742
|
Chris@49
|
2743 out.set_size(A_n_rows, B_n_cols);
|
Chris@49
|
2744
|
Chris@49
|
2745 gemm_emul<false,false,false,false>::apply(out, A_inv, B);
|
Chris@49
|
2746
|
Chris@49
|
2747 return true;
|
Chris@49
|
2748 }
|
Chris@49
|
2749 }
|
Chris@49
|
2750
|
Chris@49
|
2751 if( (A_n_rows > 4) || (status == false) )
|
Chris@49
|
2752 {
|
Chris@49
|
2753 out = X.get_ref();
|
Chris@49
|
2754
|
Chris@49
|
2755 const uword B_n_rows = out.n_rows;
|
Chris@49
|
2756 const uword B_n_cols = out.n_cols;
|
Chris@49
|
2757
|
Chris@49
|
2758 arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given objects must be the same" );
|
Chris@49
|
2759
|
Chris@49
|
2760 if(A.is_empty() || out.is_empty())
|
Chris@49
|
2761 {
|
Chris@49
|
2762 out.zeros(A.n_cols, B_n_cols);
|
Chris@49
|
2763 return true;
|
Chris@49
|
2764 }
|
Chris@49
|
2765
|
Chris@49
|
2766 #if defined(ARMA_USE_ATLAS)
|
Chris@49
|
2767 {
|
Chris@49
|
2768 podarray<int> ipiv(A_n_rows + 2); // +2 for paranoia: old versions of Atlas might be trashing memory
|
Chris@49
|
2769
|
Chris@49
|
2770 int info = atlas::clapack_gesv<eT>(atlas::CblasColMajor, A_n_rows, B_n_cols, A.memptr(), A_n_rows, ipiv.memptr(), out.memptr(), A_n_rows);
|
Chris@49
|
2771
|
Chris@49
|
2772 return (info == 0);
|
Chris@49
|
2773 }
|
Chris@49
|
2774 #elif defined(ARMA_USE_LAPACK)
|
Chris@49
|
2775 {
|
Chris@49
|
2776 blas_int n = blas_int(A_n_rows); // assuming A is square
|
Chris@49
|
2777 blas_int lda = blas_int(A_n_rows);
|
Chris@49
|
2778 blas_int ldb = blas_int(A_n_rows);
|
Chris@49
|
2779 blas_int nrhs = blas_int(B_n_cols);
|
Chris@49
|
2780 blas_int info = 0;
|
Chris@49
|
2781
|
Chris@49
|
2782 podarray<blas_int> ipiv(A_n_rows + 2); // +2 for paranoia: some versions of Lapack might be trashing memory
|
Chris@49
|
2783
|
Chris@49
|
2784 arma_extra_debug_print("lapack::gesv()");
|
Chris@49
|
2785 lapack::gesv<eT>(&n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info);
|
Chris@49
|
2786
|
Chris@49
|
2787 arma_extra_debug_print("lapack::gesv() -- finished");
|
Chris@49
|
2788
|
Chris@49
|
2789 return (info == 0);
|
Chris@49
|
2790 }
|
Chris@49
|
2791 #else
|
Chris@49
|
2792 {
|
Chris@49
|
2793 arma_stop("solve(): use of ATLAS or LAPACK needs to be enabled");
|
Chris@49
|
2794 return false;
|
Chris@49
|
2795 }
|
Chris@49
|
2796 #endif
|
Chris@49
|
2797 }
|
Chris@49
|
2798
|
Chris@49
|
2799 return true;
|
Chris@49
|
2800 }
|
Chris@49
|
2801
|
Chris@49
|
2802
|
Chris@49
|
2803
|
Chris@49
|
2804 //! Solve an over-determined system.
|
Chris@49
|
2805 //! Assumes that A.n_rows > A.n_cols and B.n_rows = A.n_rows
|
Chris@49
|
2806 template<typename eT, typename T1>
|
Chris@49
|
2807 inline
|
Chris@49
|
2808 bool
|
Chris@49
|
2809 auxlib::solve_od(Mat<eT>& out, Mat<eT>& A, const Base<eT,T1>& X)
|
Chris@49
|
2810 {
|
Chris@49
|
2811 arma_extra_debug_sigprint();
|
Chris@49
|
2812
|
Chris@49
|
2813 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
2814 {
|
Chris@49
|
2815 Mat<eT> tmp = X.get_ref();
|
Chris@49
|
2816
|
Chris@49
|
2817 const uword A_n_rows = A.n_rows;
|
Chris@49
|
2818 const uword A_n_cols = A.n_cols;
|
Chris@49
|
2819
|
Chris@49
|
2820 const uword B_n_rows = tmp.n_rows;
|
Chris@49
|
2821 const uword B_n_cols = tmp.n_cols;
|
Chris@49
|
2822
|
Chris@49
|
2823 arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given objects must be the same" );
|
Chris@49
|
2824
|
Chris@49
|
2825 out.set_size(A_n_cols, B_n_cols);
|
Chris@49
|
2826
|
Chris@49
|
2827 if(A.is_empty() || tmp.is_empty())
|
Chris@49
|
2828 {
|
Chris@49
|
2829 out.zeros();
|
Chris@49
|
2830 return true;
|
Chris@49
|
2831 }
|
Chris@49
|
2832
|
Chris@49
|
2833 char trans = 'N';
|
Chris@49
|
2834
|
Chris@49
|
2835 blas_int m = blas_int(A_n_rows);
|
Chris@49
|
2836 blas_int n = blas_int(A_n_cols);
|
Chris@49
|
2837 blas_int lda = blas_int(A_n_rows);
|
Chris@49
|
2838 blas_int ldb = blas_int(A_n_rows);
|
Chris@49
|
2839 blas_int nrhs = blas_int(B_n_cols);
|
Chris@49
|
2840 blas_int lwork = 3 * ( (std::max)(blas_int(1), n + (std::max)(n, nrhs)) );
|
Chris@49
|
2841 blas_int info = 0;
|
Chris@49
|
2842
|
Chris@49
|
2843 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
2844
|
Chris@49
|
2845 // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems
|
Chris@49
|
2846 arma_extra_debug_print("lapack::gels()");
|
Chris@49
|
2847 lapack::gels<eT>( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, work.memptr(), &lwork, &info );
|
Chris@49
|
2848
|
Chris@49
|
2849 arma_extra_debug_print("lapack::gels() -- finished");
|
Chris@49
|
2850
|
Chris@49
|
2851 for(uword col=0; col<B_n_cols; ++col)
|
Chris@49
|
2852 {
|
Chris@49
|
2853 arrayops::copy( out.colptr(col), tmp.colptr(col), A_n_cols );
|
Chris@49
|
2854 }
|
Chris@49
|
2855
|
Chris@49
|
2856 return (info == 0);
|
Chris@49
|
2857 }
|
Chris@49
|
2858 #else
|
Chris@49
|
2859 {
|
Chris@49
|
2860 arma_ignore(out);
|
Chris@49
|
2861 arma_ignore(A);
|
Chris@49
|
2862 arma_ignore(X);
|
Chris@49
|
2863 arma_stop("solve(): use of LAPACK needs to be enabled");
|
Chris@49
|
2864 return false;
|
Chris@49
|
2865 }
|
Chris@49
|
2866 #endif
|
Chris@49
|
2867 }
|
Chris@49
|
2868
|
Chris@49
|
2869
|
Chris@49
|
2870
|
Chris@49
|
2871 //! Solve an under-determined system.
|
Chris@49
|
2872 //! Assumes that A.n_rows < A.n_cols and B.n_rows = A.n_rows
|
Chris@49
|
2873 template<typename eT, typename T1>
|
Chris@49
|
2874 inline
|
Chris@49
|
2875 bool
|
Chris@49
|
2876 auxlib::solve_ud(Mat<eT>& out, Mat<eT>& A, const Base<eT,T1>& X)
|
Chris@49
|
2877 {
|
Chris@49
|
2878 arma_extra_debug_sigprint();
|
Chris@49
|
2879
|
Chris@49
|
2880 // TODO: this function provides the same results as Octave 3.4.2.
|
Chris@49
|
2881 // TODO: however, these results are different than Matlab 7.12.0.635.
|
Chris@49
|
2882 // TODO: figure out whether both Octave and Matlab are correct, or only one of them
|
Chris@49
|
2883
|
Chris@49
|
2884 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
2885 {
|
Chris@49
|
2886 const unwrap<T1> Y( X.get_ref() );
|
Chris@49
|
2887 const Mat<eT>& B = Y.M;
|
Chris@49
|
2888
|
Chris@49
|
2889 const uword A_n_rows = A.n_rows;
|
Chris@49
|
2890 const uword A_n_cols = A.n_cols;
|
Chris@49
|
2891
|
Chris@49
|
2892 const uword B_n_rows = B.n_rows;
|
Chris@49
|
2893 const uword B_n_cols = B.n_cols;
|
Chris@49
|
2894
|
Chris@49
|
2895 arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given objects must be the same" );
|
Chris@49
|
2896
|
Chris@49
|
2897 // B could be an alias of "out", hence we need to check whether B is empty before setting the size of "out"
|
Chris@49
|
2898 if(A.is_empty() || B.is_empty())
|
Chris@49
|
2899 {
|
Chris@49
|
2900 out.zeros(A_n_cols, B_n_cols);
|
Chris@49
|
2901 return true;
|
Chris@49
|
2902 }
|
Chris@49
|
2903
|
Chris@49
|
2904 char trans = 'N';
|
Chris@49
|
2905
|
Chris@49
|
2906 blas_int m = blas_int(A_n_rows);
|
Chris@49
|
2907 blas_int n = blas_int(A_n_cols);
|
Chris@49
|
2908 blas_int lda = blas_int(A_n_rows);
|
Chris@49
|
2909 blas_int ldb = blas_int(A_n_cols);
|
Chris@49
|
2910 blas_int nrhs = blas_int(B_n_cols);
|
Chris@49
|
2911 blas_int lwork = 3 * ( (std::max)(blas_int(1), m + (std::max)(m,nrhs)) );
|
Chris@49
|
2912 blas_int info = 0;
|
Chris@49
|
2913
|
Chris@49
|
2914 Mat<eT> tmp(A_n_cols, B_n_cols);
|
Chris@49
|
2915 tmp.zeros();
|
Chris@49
|
2916
|
Chris@49
|
2917 for(uword col=0; col<B_n_cols; ++col)
|
Chris@49
|
2918 {
|
Chris@49
|
2919 eT* tmp_colmem = tmp.colptr(col);
|
Chris@49
|
2920
|
Chris@49
|
2921 arrayops::copy( tmp_colmem, B.colptr(col), B_n_rows );
|
Chris@49
|
2922
|
Chris@49
|
2923 for(uword row=B_n_rows; row<A_n_cols; ++row)
|
Chris@49
|
2924 {
|
Chris@49
|
2925 tmp_colmem[row] = eT(0);
|
Chris@49
|
2926 }
|
Chris@49
|
2927 }
|
Chris@49
|
2928
|
Chris@49
|
2929 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
2930
|
Chris@49
|
2931 // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems
|
Chris@49
|
2932 arma_extra_debug_print("lapack::gels()");
|
Chris@49
|
2933 lapack::gels<eT>( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, work.memptr(), &lwork, &info );
|
Chris@49
|
2934
|
Chris@49
|
2935 arma_extra_debug_print("lapack::gels() -- finished");
|
Chris@49
|
2936
|
Chris@49
|
2937 out.set_size(A_n_cols, B_n_cols);
|
Chris@49
|
2938
|
Chris@49
|
2939 for(uword col=0; col<B_n_cols; ++col)
|
Chris@49
|
2940 {
|
Chris@49
|
2941 arrayops::copy( out.colptr(col), tmp.colptr(col), A_n_cols );
|
Chris@49
|
2942 }
|
Chris@49
|
2943
|
Chris@49
|
2944 return (info == 0);
|
Chris@49
|
2945 }
|
Chris@49
|
2946 #else
|
Chris@49
|
2947 {
|
Chris@49
|
2948 arma_ignore(out);
|
Chris@49
|
2949 arma_ignore(A);
|
Chris@49
|
2950 arma_ignore(X);
|
Chris@49
|
2951 arma_stop("solve(): use of LAPACK needs to be enabled");
|
Chris@49
|
2952 return false;
|
Chris@49
|
2953 }
|
Chris@49
|
2954 #endif
|
Chris@49
|
2955 }
|
Chris@49
|
2956
|
Chris@49
|
2957
|
Chris@49
|
2958
|
Chris@49
|
2959 //
|
Chris@49
|
2960 // solve_tr
|
Chris@49
|
2961
|
Chris@49
|
2962 template<typename eT>
|
Chris@49
|
2963 inline
|
Chris@49
|
2964 bool
|
Chris@49
|
2965 auxlib::solve_tr(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const uword layout)
|
Chris@49
|
2966 {
|
Chris@49
|
2967 arma_extra_debug_sigprint();
|
Chris@49
|
2968
|
Chris@49
|
2969 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
2970 {
|
Chris@49
|
2971 if(A.is_empty() || B.is_empty())
|
Chris@49
|
2972 {
|
Chris@49
|
2973 out.zeros(A.n_cols, B.n_cols);
|
Chris@49
|
2974 return true;
|
Chris@49
|
2975 }
|
Chris@49
|
2976
|
Chris@49
|
2977 out = B;
|
Chris@49
|
2978
|
Chris@49
|
2979 char uplo = (layout == 0) ? 'U' : 'L';
|
Chris@49
|
2980 char trans = 'N';
|
Chris@49
|
2981 char diag = 'N';
|
Chris@49
|
2982 blas_int n = blas_int(A.n_rows);
|
Chris@49
|
2983 blas_int nrhs = blas_int(B.n_cols);
|
Chris@49
|
2984 blas_int info = 0;
|
Chris@49
|
2985
|
Chris@49
|
2986 lapack::trtrs<eT>(&uplo, &trans, &diag, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info);
|
Chris@49
|
2987
|
Chris@49
|
2988 return (info == 0);
|
Chris@49
|
2989 }
|
Chris@49
|
2990 #else
|
Chris@49
|
2991 {
|
Chris@49
|
2992 arma_ignore(out);
|
Chris@49
|
2993 arma_ignore(A);
|
Chris@49
|
2994 arma_ignore(B);
|
Chris@49
|
2995 arma_ignore(layout);
|
Chris@49
|
2996 arma_stop("solve(): use of LAPACK needs to be enabled");
|
Chris@49
|
2997 return false;
|
Chris@49
|
2998 }
|
Chris@49
|
2999 #endif
|
Chris@49
|
3000 }
|
Chris@49
|
3001
|
Chris@49
|
3002
|
Chris@49
|
3003
|
Chris@49
|
3004 //
|
Chris@49
|
3005 // Schur decomposition
|
Chris@49
|
3006
|
Chris@49
|
3007 template<typename eT>
|
Chris@49
|
3008 inline
|
Chris@49
|
3009 bool
|
Chris@49
|
3010 auxlib::schur_dec(Mat<eT>& Z, Mat<eT>& T, const Mat<eT>& A)
|
Chris@49
|
3011 {
|
Chris@49
|
3012 arma_extra_debug_sigprint();
|
Chris@49
|
3013
|
Chris@49
|
3014 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
3015 {
|
Chris@49
|
3016 arma_debug_check( (A.is_square() == false), "schur_dec(): given matrix is not square" );
|
Chris@49
|
3017
|
Chris@49
|
3018 if(A.is_empty())
|
Chris@49
|
3019 {
|
Chris@49
|
3020 Z.reset();
|
Chris@49
|
3021 T.reset();
|
Chris@49
|
3022 return true;
|
Chris@49
|
3023 }
|
Chris@49
|
3024
|
Chris@49
|
3025 const uword A_n_rows = A.n_rows;
|
Chris@49
|
3026
|
Chris@49
|
3027 Z.set_size(A_n_rows, A_n_rows);
|
Chris@49
|
3028 T = A;
|
Chris@49
|
3029
|
Chris@49
|
3030 char jobvs = 'V'; // get Schur vectors (Z)
|
Chris@49
|
3031 char sort = 'N'; // do not sort eigenvalues/vectors
|
Chris@49
|
3032 blas_int* select = 0; // pointer to sorting function
|
Chris@49
|
3033 blas_int n = blas_int(A_n_rows);
|
Chris@49
|
3034 blas_int sdim = 0; // output for sorting
|
Chris@49
|
3035 blas_int lwork = 3 * ( (std::max)(blas_int(1), 3*n) );
|
Chris@49
|
3036 blas_int info = 0;
|
Chris@49
|
3037
|
Chris@49
|
3038 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
3039 podarray<blas_int> bwork(A_n_rows);
|
Chris@49
|
3040
|
Chris@49
|
3041 podarray<eT> wr(A_n_rows); // output for eigenvalues
|
Chris@49
|
3042 podarray<eT> wi(A_n_rows); // output for eigenvalues
|
Chris@49
|
3043
|
Chris@49
|
3044 lapack::gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, wr.memptr(), wi.memptr(), Z.memptr(), &n, work.memptr(), &lwork, bwork.memptr(), &info);
|
Chris@49
|
3045
|
Chris@49
|
3046 return (info == 0);
|
Chris@49
|
3047 }
|
Chris@49
|
3048 #else
|
Chris@49
|
3049 {
|
Chris@49
|
3050 arma_ignore(Z);
|
Chris@49
|
3051 arma_ignore(T);
|
Chris@49
|
3052 arma_ignore(A);
|
Chris@49
|
3053
|
Chris@49
|
3054 arma_stop("schur_dec(): use of LAPACK needs to be enabled");
|
Chris@49
|
3055 return false;
|
Chris@49
|
3056 }
|
Chris@49
|
3057 #endif
|
Chris@49
|
3058 }
|
Chris@49
|
3059
|
Chris@49
|
3060
|
Chris@49
|
3061
|
Chris@49
|
3062 template<typename cT>
|
Chris@49
|
3063 inline
|
Chris@49
|
3064 bool
|
Chris@49
|
3065 auxlib::schur_dec(Mat<std::complex<cT> >& Z, Mat<std::complex<cT> >& T, const Mat<std::complex<cT> >& A)
|
Chris@49
|
3066 {
|
Chris@49
|
3067 arma_extra_debug_sigprint();
|
Chris@49
|
3068
|
Chris@49
|
3069 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
3070 {
|
Chris@49
|
3071 arma_debug_check( (A.is_square() == false), "schur_dec(): matrix A is not square" );
|
Chris@49
|
3072
|
Chris@49
|
3073 if(A.is_empty())
|
Chris@49
|
3074 {
|
Chris@49
|
3075 Z.reset();
|
Chris@49
|
3076 T.reset();
|
Chris@49
|
3077 return true;
|
Chris@49
|
3078 }
|
Chris@49
|
3079
|
Chris@49
|
3080 typedef std::complex<cT> eT;
|
Chris@49
|
3081
|
Chris@49
|
3082 const uword A_n_rows = A.n_rows;
|
Chris@49
|
3083
|
Chris@49
|
3084 Z.set_size(A_n_rows, A_n_rows);
|
Chris@49
|
3085 T = A;
|
Chris@49
|
3086
|
Chris@49
|
3087 char jobvs = 'V'; // get Schur vectors (Z)
|
Chris@49
|
3088 char sort = 'N'; // do not sort eigenvalues/vectors
|
Chris@49
|
3089 blas_int* select = 0; // pointer to sorting function
|
Chris@49
|
3090 blas_int n = blas_int(A_n_rows);
|
Chris@49
|
3091 blas_int sdim = 0; // output for sorting
|
Chris@49
|
3092 blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*n) );
|
Chris@49
|
3093 blas_int info = 0;
|
Chris@49
|
3094
|
Chris@49
|
3095 podarray<eT> work( static_cast<uword>(lwork) );
|
Chris@49
|
3096 podarray<blas_int> bwork(A_n_rows);
|
Chris@49
|
3097
|
Chris@49
|
3098 podarray<eT> w(A_n_rows); // output for eigenvalues
|
Chris@49
|
3099 podarray<cT> rwork(A_n_rows);
|
Chris@49
|
3100
|
Chris@49
|
3101 lapack::cx_gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, w.memptr(), Z.memptr(), &n, work.memptr(), &lwork, rwork.memptr(), bwork.memptr(), &info);
|
Chris@49
|
3102
|
Chris@49
|
3103 return (info == 0);
|
Chris@49
|
3104 }
|
Chris@49
|
3105 #else
|
Chris@49
|
3106 {
|
Chris@49
|
3107 arma_ignore(Z);
|
Chris@49
|
3108 arma_ignore(T);
|
Chris@49
|
3109 arma_ignore(A);
|
Chris@49
|
3110
|
Chris@49
|
3111 arma_stop("schur_dec(): use of LAPACK needs to be enabled");
|
Chris@49
|
3112 return false;
|
Chris@49
|
3113 }
|
Chris@49
|
3114 #endif
|
Chris@49
|
3115 }
|
Chris@49
|
3116
|
Chris@49
|
3117
|
Chris@49
|
3118
|
Chris@49
|
3119 //
|
Chris@49
|
3120 // syl (solution of the Sylvester equation AX + XB = C)
|
Chris@49
|
3121
|
Chris@49
|
3122 template<typename eT>
|
Chris@49
|
3123 inline
|
Chris@49
|
3124 bool
|
Chris@49
|
3125 auxlib::syl(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& B, const Mat<eT>& C)
|
Chris@49
|
3126 {
|
Chris@49
|
3127 arma_extra_debug_sigprint();
|
Chris@49
|
3128
|
Chris@49
|
3129 arma_debug_check
|
Chris@49
|
3130 (
|
Chris@49
|
3131 (A.is_square() == false) || (B.is_square() == false),
|
Chris@49
|
3132 "syl(): given matrix is not square"
|
Chris@49
|
3133 );
|
Chris@49
|
3134
|
Chris@49
|
3135 arma_debug_check
|
Chris@49
|
3136 (
|
Chris@49
|
3137 (C.n_rows != A.n_rows) || (C.n_cols != B.n_cols),
|
Chris@49
|
3138 "syl(): matrices are not conformant"
|
Chris@49
|
3139 );
|
Chris@49
|
3140
|
Chris@49
|
3141 if(A.is_empty() || B.is_empty() || C.is_empty())
|
Chris@49
|
3142 {
|
Chris@49
|
3143 X.reset();
|
Chris@49
|
3144 return true;
|
Chris@49
|
3145 }
|
Chris@49
|
3146
|
Chris@49
|
3147 #if defined(ARMA_USE_LAPACK)
|
Chris@49
|
3148 {
|
Chris@49
|
3149 Mat<eT> Z1, Z2, T1, T2;
|
Chris@49
|
3150
|
Chris@49
|
3151 const bool status_sd1 = auxlib::schur_dec(Z1, T1, A);
|
Chris@49
|
3152 const bool status_sd2 = auxlib::schur_dec(Z2, T2, B);
|
Chris@49
|
3153
|
Chris@49
|
3154 if( (status_sd1 == false) || (status_sd2 == false) )
|
Chris@49
|
3155 {
|
Chris@49
|
3156 return false;
|
Chris@49
|
3157 }
|
Chris@49
|
3158
|
Chris@49
|
3159 char trana = 'N';
|
Chris@49
|
3160 char tranb = 'N';
|
Chris@49
|
3161 blas_int isgn = +1;
|
Chris@49
|
3162 blas_int m = blas_int(T1.n_rows);
|
Chris@49
|
3163 blas_int n = blas_int(T2.n_cols);
|
Chris@49
|
3164
|
Chris@49
|
3165 eT scale = eT(0);
|
Chris@49
|
3166 blas_int info = 0;
|
Chris@49
|
3167
|
Chris@49
|
3168 Mat<eT> Y = trans(Z1) * C * Z2;
|
Chris@49
|
3169
|
Chris@49
|
3170 lapack::trsyl<eT>(&trana, &tranb, &isgn, &m, &n, T1.memptr(), &m, T2.memptr(), &n, Y.memptr(), &m, &scale, &info);
|
Chris@49
|
3171
|
Chris@49
|
3172 //Y /= scale;
|
Chris@49
|
3173 Y /= (-scale);
|
Chris@49
|
3174
|
Chris@49
|
3175 X = Z1 * Y * trans(Z2);
|
Chris@49
|
3176
|
Chris@49
|
3177 return (info >= 0);
|
Chris@49
|
3178 }
|
Chris@49
|
3179 #else
|
Chris@49
|
3180 {
|
Chris@49
|
3181 arma_stop("syl(): use of LAPACK needs to be enabled");
|
Chris@49
|
3182 return false;
|
Chris@49
|
3183 }
|
Chris@49
|
3184 #endif
|
Chris@49
|
3185 }
|
Chris@49
|
3186
|
Chris@49
|
3187
|
Chris@49
|
3188
|
Chris@49
|
3189 //
|
Chris@49
|
3190 // lyap (solution of the continuous Lyapunov equation AX + XA^H + Q = 0)
|
Chris@49
|
3191
|
Chris@49
|
3192 template<typename eT>
|
Chris@49
|
3193 inline
|
Chris@49
|
3194 bool
|
Chris@49
|
3195 auxlib::lyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
|
Chris@49
|
3196 {
|
Chris@49
|
3197 arma_extra_debug_sigprint();
|
Chris@49
|
3198
|
Chris@49
|
3199 arma_debug_check( (A.is_square() == false), "lyap(): matrix A is not square");
|
Chris@49
|
3200 arma_debug_check( (Q.is_square() == false), "lyap(): matrix Q is not square");
|
Chris@49
|
3201 arma_debug_check( (A.n_rows != Q.n_rows), "lyap(): matrices A and Q have different dimensions");
|
Chris@49
|
3202
|
Chris@49
|
3203 Mat<eT> htransA;
|
Chris@49
|
3204 op_htrans::apply_noalias(htransA, A);
|
Chris@49
|
3205
|
Chris@49
|
3206 const Mat<eT> mQ = -Q;
|
Chris@49
|
3207
|
Chris@49
|
3208 return auxlib::syl(X, A, htransA, mQ);
|
Chris@49
|
3209 }
|
Chris@49
|
3210
|
Chris@49
|
3211
|
Chris@49
|
3212
|
Chris@49
|
3213 //
|
Chris@49
|
3214 // dlyap (solution of the discrete Lyapunov equation AXA^H - X + Q = 0)
|
Chris@49
|
3215
|
Chris@49
|
3216 template<typename eT>
|
Chris@49
|
3217 inline
|
Chris@49
|
3218 bool
|
Chris@49
|
3219 auxlib::dlyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
|
Chris@49
|
3220 {
|
Chris@49
|
3221 arma_extra_debug_sigprint();
|
Chris@49
|
3222
|
Chris@49
|
3223 arma_debug_check( (A.is_square() == false), "dlyap(): matrix A is not square");
|
Chris@49
|
3224 arma_debug_check( (Q.is_square() == false), "dlyap(): matrix Q is not square");
|
Chris@49
|
3225 arma_debug_check( (A.n_rows != Q.n_rows), "dlyap(): matrices A and Q have different dimensions");
|
Chris@49
|
3226
|
Chris@49
|
3227 const Col<eT> vecQ = reshape(Q, Q.n_elem, 1);
|
Chris@49
|
3228
|
Chris@49
|
3229 const Mat<eT> M = eye< Mat<eT> >(Q.n_elem, Q.n_elem) - kron(conj(A), A);
|
Chris@49
|
3230
|
Chris@49
|
3231 Col<eT> vecX;
|
Chris@49
|
3232
|
Chris@49
|
3233 const bool status = solve(vecX, M, vecQ);
|
Chris@49
|
3234
|
Chris@49
|
3235 if(status == true)
|
Chris@49
|
3236 {
|
Chris@49
|
3237 X = reshape(vecX, Q.n_rows, Q.n_cols);
|
Chris@49
|
3238 return true;
|
Chris@49
|
3239 }
|
Chris@49
|
3240 else
|
Chris@49
|
3241 {
|
Chris@49
|
3242 X.reset();
|
Chris@49
|
3243 return false;
|
Chris@49
|
3244 }
|
Chris@49
|
3245 }
|
Chris@49
|
3246
|
Chris@49
|
3247
|
Chris@49
|
3248
|
Chris@49
|
3249 //! @}
|