max@0
|
1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2008-2012 Conrad Sanderson
|
max@0
|
3 // Copyright (C) 2009 Edmund Highcock
|
max@0
|
4 // Copyright (C) 2011 James Sanders
|
max@0
|
5 // Copyright (C) 2011 Stanislav Funiak
|
max@0
|
6 //
|
max@0
|
7 // This file is part of the Armadillo C++ library.
|
max@0
|
8 // It is provided without any warranty of fitness
|
max@0
|
9 // for any purpose. You can redistribute this file
|
max@0
|
10 // and/or modify it under the terms of the GNU
|
max@0
|
11 // Lesser General Public License (LGPL) as published
|
max@0
|
12 // by the Free Software Foundation, either version 3
|
max@0
|
13 // of the License or (at your option) any later version.
|
max@0
|
14 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
15
|
max@0
|
16
|
max@0
|
17 //! \addtogroup auxlib
|
max@0
|
18 //! @{
|
max@0
|
19
|
max@0
|
20
|
max@0
|
21
|
max@0
|
22 //! immediate matrix inverse
|
max@0
|
23 template<typename eT, typename T1>
|
max@0
|
24 inline
|
max@0
|
25 bool
|
max@0
|
26 auxlib::inv(Mat<eT>& out, const Base<eT,T1>& X, const bool slow)
|
max@0
|
27 {
|
max@0
|
28 arma_extra_debug_sigprint();
|
max@0
|
29
|
max@0
|
30 out = X.get_ref();
|
max@0
|
31
|
max@0
|
32 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
|
max@0
|
33
|
max@0
|
34 bool status = false;
|
max@0
|
35
|
max@0
|
36 const uword N = out.n_rows;
|
max@0
|
37
|
max@0
|
38 if( (N <= 4) && (slow == false) )
|
max@0
|
39 {
|
max@0
|
40 status = auxlib::inv_inplace_tinymat(out, N);
|
max@0
|
41 }
|
max@0
|
42
|
max@0
|
43 if( (N > 4) || (status == false) )
|
max@0
|
44 {
|
max@0
|
45 status = auxlib::inv_inplace_lapack(out);
|
max@0
|
46 }
|
max@0
|
47
|
max@0
|
48 return status;
|
max@0
|
49 }
|
max@0
|
50
|
max@0
|
51
|
max@0
|
52
|
max@0
|
53 template<typename eT>
|
max@0
|
54 inline
|
max@0
|
55 bool
|
max@0
|
56 auxlib::inv(Mat<eT>& out, const Mat<eT>& X, const bool slow)
|
max@0
|
57 {
|
max@0
|
58 arma_extra_debug_sigprint();
|
max@0
|
59
|
max@0
|
60 arma_debug_check( (X.is_square() == false), "inv(): given matrix is not square" );
|
max@0
|
61
|
max@0
|
62 bool status = false;
|
max@0
|
63
|
max@0
|
64 const uword N = X.n_rows;
|
max@0
|
65
|
max@0
|
66 if( (N <= 4) && (slow == false) )
|
max@0
|
67 {
|
max@0
|
68 status = (&out != &X) ? auxlib::inv_noalias_tinymat(out, X, N) : auxlib::inv_inplace_tinymat(out, N);
|
max@0
|
69 }
|
max@0
|
70
|
max@0
|
71 if( (N > 4) || (status == false) )
|
max@0
|
72 {
|
max@0
|
73 out = X;
|
max@0
|
74 status = auxlib::inv_inplace_lapack(out);
|
max@0
|
75 }
|
max@0
|
76
|
max@0
|
77 return status;
|
max@0
|
78 }
|
max@0
|
79
|
max@0
|
80
|
max@0
|
81
|
max@0
|
82 template<typename eT>
|
max@0
|
83 inline
|
max@0
|
84 bool
|
max@0
|
85 auxlib::inv_noalias_tinymat(Mat<eT>& out, const Mat<eT>& X, const uword N)
|
max@0
|
86 {
|
max@0
|
87 arma_extra_debug_sigprint();
|
max@0
|
88
|
max@0
|
89 bool det_ok = true;
|
max@0
|
90
|
max@0
|
91 out.set_size(N,N);
|
max@0
|
92
|
max@0
|
93 switch(N)
|
max@0
|
94 {
|
max@0
|
95 case 1:
|
max@0
|
96 {
|
max@0
|
97 out[0] = eT(1) / X[0];
|
max@0
|
98 };
|
max@0
|
99 break;
|
max@0
|
100
|
max@0
|
101 case 2:
|
max@0
|
102 {
|
max@0
|
103 const eT* Xm = X.memptr();
|
max@0
|
104
|
max@0
|
105 const eT a = Xm[pos<0,0>::n2];
|
max@0
|
106 const eT b = Xm[pos<0,1>::n2];
|
max@0
|
107 const eT c = Xm[pos<1,0>::n2];
|
max@0
|
108 const eT d = Xm[pos<1,1>::n2];
|
max@0
|
109
|
max@0
|
110 const eT tmp_det = (a*d - b*c);
|
max@0
|
111
|
max@0
|
112 if(tmp_det != eT(0))
|
max@0
|
113 {
|
max@0
|
114 eT* outm = out.memptr();
|
max@0
|
115
|
max@0
|
116 outm[pos<0,0>::n2] = d / tmp_det;
|
max@0
|
117 outm[pos<0,1>::n2] = -b / tmp_det;
|
max@0
|
118 outm[pos<1,0>::n2] = -c / tmp_det;
|
max@0
|
119 outm[pos<1,1>::n2] = a / tmp_det;
|
max@0
|
120 }
|
max@0
|
121 else
|
max@0
|
122 {
|
max@0
|
123 det_ok = false;
|
max@0
|
124 }
|
max@0
|
125 };
|
max@0
|
126 break;
|
max@0
|
127
|
max@0
|
128 case 3:
|
max@0
|
129 {
|
max@0
|
130 const eT* X_col0 = X.colptr(0);
|
max@0
|
131 const eT a11 = X_col0[0];
|
max@0
|
132 const eT a21 = X_col0[1];
|
max@0
|
133 const eT a31 = X_col0[2];
|
max@0
|
134
|
max@0
|
135 const eT* X_col1 = X.colptr(1);
|
max@0
|
136 const eT a12 = X_col1[0];
|
max@0
|
137 const eT a22 = X_col1[1];
|
max@0
|
138 const eT a32 = X_col1[2];
|
max@0
|
139
|
max@0
|
140 const eT* X_col2 = X.colptr(2);
|
max@0
|
141 const eT a13 = X_col2[0];
|
max@0
|
142 const eT a23 = X_col2[1];
|
max@0
|
143 const eT a33 = X_col2[2];
|
max@0
|
144
|
max@0
|
145 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13);
|
max@0
|
146
|
max@0
|
147 if(tmp_det != eT(0))
|
max@0
|
148 {
|
max@0
|
149 eT* out_col0 = out.colptr(0);
|
max@0
|
150 out_col0[0] = (a33*a22 - a32*a23) / tmp_det;
|
max@0
|
151 out_col0[1] = -(a33*a21 - a31*a23) / tmp_det;
|
max@0
|
152 out_col0[2] = (a32*a21 - a31*a22) / tmp_det;
|
max@0
|
153
|
max@0
|
154 eT* out_col1 = out.colptr(1);
|
max@0
|
155 out_col1[0] = -(a33*a12 - a32*a13) / tmp_det;
|
max@0
|
156 out_col1[1] = (a33*a11 - a31*a13) / tmp_det;
|
max@0
|
157 out_col1[2] = -(a32*a11 - a31*a12) / tmp_det;
|
max@0
|
158
|
max@0
|
159 eT* out_col2 = out.colptr(2);
|
max@0
|
160 out_col2[0] = (a23*a12 - a22*a13) / tmp_det;
|
max@0
|
161 out_col2[1] = -(a23*a11 - a21*a13) / tmp_det;
|
max@0
|
162 out_col2[2] = (a22*a11 - a21*a12) / tmp_det;
|
max@0
|
163 }
|
max@0
|
164 else
|
max@0
|
165 {
|
max@0
|
166 det_ok = false;
|
max@0
|
167 }
|
max@0
|
168 };
|
max@0
|
169 break;
|
max@0
|
170
|
max@0
|
171 case 4:
|
max@0
|
172 {
|
max@0
|
173 const eT tmp_det = det(X);
|
max@0
|
174
|
max@0
|
175 if(tmp_det != eT(0))
|
max@0
|
176 {
|
max@0
|
177 const eT* Xm = X.memptr();
|
max@0
|
178 eT* outm = out.memptr();
|
max@0
|
179
|
max@0
|
180 outm[pos<0,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
181 outm[pos<1,0>::n4] = ( Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
182 outm[pos<2,0>::n4] = ( Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
183 outm[pos<3,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
|
max@0
|
184
|
max@0
|
185 outm[pos<0,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
186 outm[pos<1,1>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
187 outm[pos<2,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
188 outm[pos<3,1>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
|
max@0
|
189
|
max@0
|
190 outm[pos<0,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
191 outm[pos<1,2>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
192 outm[pos<2,2>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
193 outm[pos<3,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] ) / tmp_det;
|
max@0
|
194
|
max@0
|
195 outm[pos<0,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
|
max@0
|
196 outm[pos<1,3>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
|
max@0
|
197 outm[pos<2,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] ) / tmp_det;
|
max@0
|
198 outm[pos<3,3>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] ) / tmp_det;
|
max@0
|
199 }
|
max@0
|
200 else
|
max@0
|
201 {
|
max@0
|
202 det_ok = false;
|
max@0
|
203 }
|
max@0
|
204 };
|
max@0
|
205 break;
|
max@0
|
206
|
max@0
|
207 default:
|
max@0
|
208 ;
|
max@0
|
209 }
|
max@0
|
210
|
max@0
|
211 return det_ok;
|
max@0
|
212 }
|
max@0
|
213
|
max@0
|
214
|
max@0
|
215
|
max@0
|
216 template<typename eT>
|
max@0
|
217 inline
|
max@0
|
218 bool
|
max@0
|
219 auxlib::inv_inplace_tinymat(Mat<eT>& X, const uword N)
|
max@0
|
220 {
|
max@0
|
221 arma_extra_debug_sigprint();
|
max@0
|
222
|
max@0
|
223 bool det_ok = true;
|
max@0
|
224
|
max@0
|
225 // for more info, see:
|
max@0
|
226 // http://www.dr-lex.34sp.com/random/matrix_inv.html
|
max@0
|
227 // http://www.cvl.iis.u-tokyo.ac.jp/~miyazaki/tech/teche23.html
|
max@0
|
228 // http://www.euclideanspace.com/maths/algebra/matrix/functions/inverse/fourD/index.htm
|
max@0
|
229 // http://www.geometrictools.com//LibFoundation/Mathematics/Wm4Matrix4.inl
|
max@0
|
230
|
max@0
|
231 switch(N)
|
max@0
|
232 {
|
max@0
|
233 case 1:
|
max@0
|
234 {
|
max@0
|
235 X[0] = eT(1) / X[0];
|
max@0
|
236 };
|
max@0
|
237 break;
|
max@0
|
238
|
max@0
|
239 case 2:
|
max@0
|
240 {
|
max@0
|
241 const eT a = X[pos<0,0>::n2];
|
max@0
|
242 const eT b = X[pos<0,1>::n2];
|
max@0
|
243 const eT c = X[pos<1,0>::n2];
|
max@0
|
244 const eT d = X[pos<1,1>::n2];
|
max@0
|
245
|
max@0
|
246 const eT tmp_det = (a*d - b*c);
|
max@0
|
247
|
max@0
|
248 if(tmp_det != eT(0))
|
max@0
|
249 {
|
max@0
|
250 X[pos<0,0>::n2] = d / tmp_det;
|
max@0
|
251 X[pos<0,1>::n2] = -b / tmp_det;
|
max@0
|
252 X[pos<1,0>::n2] = -c / tmp_det;
|
max@0
|
253 X[pos<1,1>::n2] = a / tmp_det;
|
max@0
|
254 }
|
max@0
|
255 else
|
max@0
|
256 {
|
max@0
|
257 det_ok = false;
|
max@0
|
258 }
|
max@0
|
259 };
|
max@0
|
260 break;
|
max@0
|
261
|
max@0
|
262 case 3:
|
max@0
|
263 {
|
max@0
|
264 eT* X_col0 = X.colptr(0);
|
max@0
|
265 eT* X_col1 = X.colptr(1);
|
max@0
|
266 eT* X_col2 = X.colptr(2);
|
max@0
|
267
|
max@0
|
268 const eT a11 = X_col0[0];
|
max@0
|
269 const eT a21 = X_col0[1];
|
max@0
|
270 const eT a31 = X_col0[2];
|
max@0
|
271
|
max@0
|
272 const eT a12 = X_col1[0];
|
max@0
|
273 const eT a22 = X_col1[1];
|
max@0
|
274 const eT a32 = X_col1[2];
|
max@0
|
275
|
max@0
|
276 const eT a13 = X_col2[0];
|
max@0
|
277 const eT a23 = X_col2[1];
|
max@0
|
278 const eT a33 = X_col2[2];
|
max@0
|
279
|
max@0
|
280 const eT tmp_det = a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13);
|
max@0
|
281
|
max@0
|
282 if(tmp_det != eT(0))
|
max@0
|
283 {
|
max@0
|
284 X_col0[0] = (a33*a22 - a32*a23) / tmp_det;
|
max@0
|
285 X_col0[1] = -(a33*a21 - a31*a23) / tmp_det;
|
max@0
|
286 X_col0[2] = (a32*a21 - a31*a22) / tmp_det;
|
max@0
|
287
|
max@0
|
288 X_col1[0] = -(a33*a12 - a32*a13) / tmp_det;
|
max@0
|
289 X_col1[1] = (a33*a11 - a31*a13) / tmp_det;
|
max@0
|
290 X_col1[2] = -(a32*a11 - a31*a12) / tmp_det;
|
max@0
|
291
|
max@0
|
292 X_col2[0] = (a23*a12 - a22*a13) / tmp_det;
|
max@0
|
293 X_col2[1] = -(a23*a11 - a21*a13) / tmp_det;
|
max@0
|
294 X_col2[2] = (a22*a11 - a21*a12) / tmp_det;
|
max@0
|
295 }
|
max@0
|
296 else
|
max@0
|
297 {
|
max@0
|
298 det_ok = false;
|
max@0
|
299 }
|
max@0
|
300 };
|
max@0
|
301 break;
|
max@0
|
302
|
max@0
|
303 case 4:
|
max@0
|
304 {
|
max@0
|
305 const eT tmp_det = det(X);
|
max@0
|
306
|
max@0
|
307 if(tmp_det != eT(0))
|
max@0
|
308 {
|
max@0
|
309 const Mat<eT> A(X);
|
max@0
|
310
|
max@0
|
311 const eT* Am = A.memptr();
|
max@0
|
312 eT* Xm = X.memptr();
|
max@0
|
313
|
max@0
|
314 Xm[pos<0,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] - Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] + Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
315 Xm[pos<1,0>::n4] = ( Am[pos<1,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
316 Xm[pos<2,0>::n4] = ( Am[pos<1,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<1,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<1,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<1,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
317 Xm[pos<3,0>::n4] = ( Am[pos<1,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<1,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<1,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<1,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] + Am[pos<1,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<1,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
|
max@0
|
318
|
max@0
|
319 Xm[pos<0,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] + Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] - Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
320 Xm[pos<1,1>::n4] = ( Am[pos<0,2>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
321 Xm[pos<2,1>::n4] = ( Am[pos<0,3>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<2,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<2,3>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
322 Xm[pos<3,1>::n4] = ( Am[pos<0,1>::n4]*Am[pos<2,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<2,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,2>::n4]*Am[pos<2,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<2,2>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<2,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<2,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
|
max@0
|
323
|
max@0
|
324 Xm[pos<0,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] + Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
325 Xm[pos<1,2>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
326 Xm[pos<2,2>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<3,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<3,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,3>::n4] ) / tmp_det;
|
max@0
|
327 Xm[pos<3,2>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<3,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<3,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<3,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<3,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<3,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<3,2>::n4] ) / tmp_det;
|
max@0
|
328
|
max@0
|
329 Xm[pos<0,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] - Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] + Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] + Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] - Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
|
max@0
|
330 Xm[pos<1,3>::n4] = ( Am[pos<0,2>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] + Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] - Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,2>::n4] - Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] + Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
|
max@0
|
331 Xm[pos<2,3>::n4] = ( Am[pos<0,3>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] - Am[pos<0,1>::n4]*Am[pos<1,3>::n4]*Am[pos<2,0>::n4] - Am[pos<0,3>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] + Am[pos<0,0>::n4]*Am[pos<1,3>::n4]*Am[pos<2,1>::n4] + Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,3>::n4] - Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,3>::n4] ) / tmp_det;
|
max@0
|
332 Xm[pos<3,3>::n4] = ( Am[pos<0,1>::n4]*Am[pos<1,2>::n4]*Am[pos<2,0>::n4] - Am[pos<0,2>::n4]*Am[pos<1,1>::n4]*Am[pos<2,0>::n4] + Am[pos<0,2>::n4]*Am[pos<1,0>::n4]*Am[pos<2,1>::n4] - Am[pos<0,0>::n4]*Am[pos<1,2>::n4]*Am[pos<2,1>::n4] - Am[pos<0,1>::n4]*Am[pos<1,0>::n4]*Am[pos<2,2>::n4] + Am[pos<0,0>::n4]*Am[pos<1,1>::n4]*Am[pos<2,2>::n4] ) / tmp_det;
|
max@0
|
333 }
|
max@0
|
334 else
|
max@0
|
335 {
|
max@0
|
336 det_ok = false;
|
max@0
|
337 }
|
max@0
|
338 };
|
max@0
|
339 break;
|
max@0
|
340
|
max@0
|
341 default:
|
max@0
|
342 ;
|
max@0
|
343 }
|
max@0
|
344
|
max@0
|
345 return det_ok;
|
max@0
|
346 }
|
max@0
|
347
|
max@0
|
348
|
max@0
|
349
|
max@0
|
350 template<typename eT>
|
max@0
|
351 inline
|
max@0
|
352 bool
|
max@0
|
353 auxlib::inv_inplace_lapack(Mat<eT>& out)
|
max@0
|
354 {
|
max@0
|
355 arma_extra_debug_sigprint();
|
max@0
|
356
|
max@0
|
357 if(out.is_empty())
|
max@0
|
358 {
|
max@0
|
359 return true;
|
max@0
|
360 }
|
max@0
|
361
|
max@0
|
362 #if defined(ARMA_USE_ATLAS)
|
max@0
|
363 {
|
max@0
|
364 podarray<int> ipiv(out.n_rows);
|
max@0
|
365
|
max@0
|
366 int info = atlas::clapack_getrf(atlas::CblasColMajor, out.n_rows, out.n_cols, out.memptr(), out.n_rows, ipiv.memptr());
|
max@0
|
367
|
max@0
|
368 if(info == 0)
|
max@0
|
369 {
|
max@0
|
370 info = atlas::clapack_getri(atlas::CblasColMajor, out.n_rows, out.memptr(), out.n_rows, ipiv.memptr());
|
max@0
|
371 }
|
max@0
|
372
|
max@0
|
373 return (info == 0);
|
max@0
|
374 }
|
max@0
|
375 #elif defined(ARMA_USE_LAPACK)
|
max@0
|
376 {
|
max@0
|
377 blas_int n_rows = out.n_rows;
|
max@0
|
378 blas_int n_cols = out.n_cols;
|
max@0
|
379 blas_int info = 0;
|
max@0
|
380
|
max@0
|
381 podarray<blas_int> ipiv(out.n_rows);
|
max@0
|
382
|
max@0
|
383 // 84 was empirically found -- it is the maximum value suggested by LAPACK (as provided by ATLAS v3.6)
|
max@0
|
384 // based on tests with various matrix types on 32-bit and 64-bit machines
|
max@0
|
385 //
|
max@0
|
386 // the "work" array is deliberately long so that a secondary (time-consuming)
|
max@0
|
387 // memory allocation is avoided, if possible
|
max@0
|
388
|
max@0
|
389 blas_int work_len = (std::max)(blas_int(1), n_rows*84);
|
max@0
|
390 podarray<eT> work( static_cast<uword>(work_len) );
|
max@0
|
391
|
max@0
|
392 lapack::getrf(&n_rows, &n_cols, out.memptr(), &n_rows, ipiv.memptr(), &info);
|
max@0
|
393
|
max@0
|
394 if(info == 0)
|
max@0
|
395 {
|
max@0
|
396 // query for optimum size of work_len
|
max@0
|
397
|
max@0
|
398 blas_int work_len_tmp = -1;
|
max@0
|
399 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &work_len_tmp, &info);
|
max@0
|
400
|
max@0
|
401 if(info == 0)
|
max@0
|
402 {
|
max@0
|
403 blas_int proposed_work_len = static_cast<blas_int>(access::tmp_real(work[0]));
|
max@0
|
404
|
max@0
|
405 // if necessary, allocate more memory
|
max@0
|
406 if(work_len < proposed_work_len)
|
max@0
|
407 {
|
max@0
|
408 work_len = proposed_work_len;
|
max@0
|
409 work.set_size( static_cast<uword>(work_len) );
|
max@0
|
410 }
|
max@0
|
411 }
|
max@0
|
412
|
max@0
|
413 lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &work_len, &info);
|
max@0
|
414 }
|
max@0
|
415
|
max@0
|
416 return (info == 0);
|
max@0
|
417 }
|
max@0
|
418 #else
|
max@0
|
419 {
|
max@0
|
420 arma_ignore(out);
|
max@0
|
421 arma_stop("inv(): use of ATLAS or LAPACK needs to be enabled");
|
max@0
|
422 return false;
|
max@0
|
423 }
|
max@0
|
424 #endif
|
max@0
|
425 }
|
max@0
|
426
|
max@0
|
427
|
max@0
|
428
|
max@0
|
429 template<typename eT, typename T1>
|
max@0
|
430 inline
|
max@0
|
431 bool
|
max@0
|
432 auxlib::inv_tr(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
|
max@0
|
433 {
|
max@0
|
434 arma_extra_debug_sigprint();
|
max@0
|
435
|
max@0
|
436 out = X.get_ref();
|
max@0
|
437
|
max@0
|
438 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
|
max@0
|
439
|
max@0
|
440 if(out.is_empty())
|
max@0
|
441 {
|
max@0
|
442 return true;
|
max@0
|
443 }
|
max@0
|
444
|
max@0
|
445 bool status;
|
max@0
|
446
|
max@0
|
447 #if defined(ARMA_USE_LAPACK)
|
max@0
|
448 {
|
max@0
|
449 char uplo = (layout == 0) ? 'U' : 'L';
|
max@0
|
450 char diag = 'N';
|
max@0
|
451 blas_int n = blas_int(out.n_rows);
|
max@0
|
452 blas_int info = 0;
|
max@0
|
453
|
max@0
|
454 lapack::trtri(&uplo, &diag, &n, out.memptr(), &n, &info);
|
max@0
|
455
|
max@0
|
456 status = (info == 0);
|
max@0
|
457 }
|
max@0
|
458 #else
|
max@0
|
459 {
|
max@0
|
460 arma_ignore(layout);
|
max@0
|
461 arma_stop("inv(): use of LAPACK needs to be enabled");
|
max@0
|
462 status = false;
|
max@0
|
463 }
|
max@0
|
464 #endif
|
max@0
|
465
|
max@0
|
466
|
max@0
|
467 if(status == true)
|
max@0
|
468 {
|
max@0
|
469 if(layout == 0)
|
max@0
|
470 {
|
max@0
|
471 // upper triangular
|
max@0
|
472 out = trimatu(out);
|
max@0
|
473 }
|
max@0
|
474 else
|
max@0
|
475 {
|
max@0
|
476 // lower triangular
|
max@0
|
477 out = trimatl(out);
|
max@0
|
478 }
|
max@0
|
479 }
|
max@0
|
480
|
max@0
|
481 return status;
|
max@0
|
482 }
|
max@0
|
483
|
max@0
|
484
|
max@0
|
485
|
max@0
|
486 template<typename eT, typename T1>
|
max@0
|
487 inline
|
max@0
|
488 bool
|
max@0
|
489 auxlib::inv_sym(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
|
max@0
|
490 {
|
max@0
|
491 arma_extra_debug_sigprint();
|
max@0
|
492
|
max@0
|
493 out = X.get_ref();
|
max@0
|
494
|
max@0
|
495 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
|
max@0
|
496
|
max@0
|
497 if(out.is_empty())
|
max@0
|
498 {
|
max@0
|
499 return true;
|
max@0
|
500 }
|
max@0
|
501
|
max@0
|
502 bool status;
|
max@0
|
503
|
max@0
|
504 #if defined(ARMA_USE_LAPACK)
|
max@0
|
505 {
|
max@0
|
506 char uplo = (layout == 0) ? 'U' : 'L';
|
max@0
|
507 blas_int n = blas_int(out.n_rows);
|
max@0
|
508 blas_int lwork = n*n; // TODO: use lwork = -1 to determine optimal size
|
max@0
|
509 blas_int info = 0;
|
max@0
|
510
|
max@0
|
511 podarray<blas_int> ipiv;
|
max@0
|
512 ipiv.set_size(out.n_rows);
|
max@0
|
513
|
max@0
|
514 podarray<eT> work;
|
max@0
|
515 work.set_size( uword(lwork) );
|
max@0
|
516
|
max@0
|
517 lapack::sytrf(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &lwork, &info);
|
max@0
|
518
|
max@0
|
519 status = (info == 0);
|
max@0
|
520
|
max@0
|
521 if(status == true)
|
max@0
|
522 {
|
max@0
|
523 lapack::sytri(&uplo, &n, out.memptr(), &n, ipiv.memptr(), work.memptr(), &info);
|
max@0
|
524
|
max@0
|
525 out = (layout == 0) ? symmatu(out) : symmatl(out);
|
max@0
|
526
|
max@0
|
527 status = (info == 0);
|
max@0
|
528 }
|
max@0
|
529 }
|
max@0
|
530 #else
|
max@0
|
531 {
|
max@0
|
532 arma_ignore(layout);
|
max@0
|
533 arma_stop("inv(): use of LAPACK needs to be enabled");
|
max@0
|
534 status = false;
|
max@0
|
535 }
|
max@0
|
536 #endif
|
max@0
|
537
|
max@0
|
538 return status;
|
max@0
|
539 }
|
max@0
|
540
|
max@0
|
541
|
max@0
|
542
|
max@0
|
543 template<typename eT, typename T1>
|
max@0
|
544 inline
|
max@0
|
545 bool
|
max@0
|
546 auxlib::inv_sympd(Mat<eT>& out, const Base<eT,T1>& X, const uword layout)
|
max@0
|
547 {
|
max@0
|
548 arma_extra_debug_sigprint();
|
max@0
|
549
|
max@0
|
550 out = X.get_ref();
|
max@0
|
551
|
max@0
|
552 arma_debug_check( (out.is_square() == false), "inv(): given matrix is not square" );
|
max@0
|
553
|
max@0
|
554 if(out.is_empty())
|
max@0
|
555 {
|
max@0
|
556 return true;
|
max@0
|
557 }
|
max@0
|
558
|
max@0
|
559 bool status;
|
max@0
|
560
|
max@0
|
561 #if defined(ARMA_USE_LAPACK)
|
max@0
|
562 {
|
max@0
|
563 char uplo = (layout == 0) ? 'U' : 'L';
|
max@0
|
564 blas_int n = blas_int(out.n_rows);
|
max@0
|
565 blas_int info = 0;
|
max@0
|
566
|
max@0
|
567 lapack::potrf(&uplo, &n, out.memptr(), &n, &info);
|
max@0
|
568
|
max@0
|
569 status = (info == 0);
|
max@0
|
570
|
max@0
|
571 if(status == true)
|
max@0
|
572 {
|
max@0
|
573 lapack::potri(&uplo, &n, out.memptr(), &n, &info);
|
max@0
|
574
|
max@0
|
575 out = (layout == 0) ? symmatu(out) : symmatl(out);
|
max@0
|
576
|
max@0
|
577 status = (info == 0);
|
max@0
|
578 }
|
max@0
|
579 }
|
max@0
|
580 #else
|
max@0
|
581 {
|
max@0
|
582 arma_ignore(layout);
|
max@0
|
583 arma_stop("inv(): use of LAPACK needs to be enabled");
|
max@0
|
584 status = false;
|
max@0
|
585 }
|
max@0
|
586 #endif
|
max@0
|
587
|
max@0
|
588 return status;
|
max@0
|
589 }
|
max@0
|
590
|
max@0
|
591
|
max@0
|
592
|
max@0
|
593 template<typename eT, typename T1>
|
max@0
|
594 inline
|
max@0
|
595 eT
|
max@0
|
596 auxlib::det(const Base<eT,T1>& X, const bool slow)
|
max@0
|
597 {
|
max@0
|
598 const unwrap<T1> tmp(X.get_ref());
|
max@0
|
599 const Mat<eT>& A = tmp.M;
|
max@0
|
600
|
max@0
|
601 arma_debug_check( (A.is_square() == false), "det(): matrix is not square" );
|
max@0
|
602
|
max@0
|
603 const bool make_copy = (is_Mat<T1>::value == true) ? true : false;
|
max@0
|
604
|
max@0
|
605 if(slow == false)
|
max@0
|
606 {
|
max@0
|
607 const uword N = A.n_rows;
|
max@0
|
608
|
max@0
|
609 switch(N)
|
max@0
|
610 {
|
max@0
|
611 case 0:
|
max@0
|
612 case 1:
|
max@0
|
613 case 2:
|
max@0
|
614 return auxlib::det_tinymat(A, N);
|
max@0
|
615 break;
|
max@0
|
616
|
max@0
|
617 case 3:
|
max@0
|
618 case 4:
|
max@0
|
619 {
|
max@0
|
620 const eT tmp_det = auxlib::det_tinymat(A, N);
|
max@0
|
621 return (tmp_det != eT(0)) ? tmp_det : auxlib::det_lapack(A, make_copy);
|
max@0
|
622 }
|
max@0
|
623 break;
|
max@0
|
624
|
max@0
|
625 default:
|
max@0
|
626 return auxlib::det_lapack(A, make_copy);
|
max@0
|
627 }
|
max@0
|
628 }
|
max@0
|
629 else
|
max@0
|
630 {
|
max@0
|
631 return auxlib::det_lapack(A, make_copy);
|
max@0
|
632 }
|
max@0
|
633 }
|
max@0
|
634
|
max@0
|
635
|
max@0
|
636
|
max@0
|
637 template<typename eT>
|
max@0
|
638 inline
|
max@0
|
639 eT
|
max@0
|
640 auxlib::det_tinymat(const Mat<eT>& X, const uword N)
|
max@0
|
641 {
|
max@0
|
642 arma_extra_debug_sigprint();
|
max@0
|
643
|
max@0
|
644 switch(N)
|
max@0
|
645 {
|
max@0
|
646 case 0:
|
max@0
|
647 return eT(1);
|
max@0
|
648 break;
|
max@0
|
649
|
max@0
|
650 case 1:
|
max@0
|
651 return X[0];
|
max@0
|
652 break;
|
max@0
|
653
|
max@0
|
654 case 2:
|
max@0
|
655 {
|
max@0
|
656 const eT* Xm = X.memptr();
|
max@0
|
657
|
max@0
|
658 return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] );
|
max@0
|
659 }
|
max@0
|
660 break;
|
max@0
|
661
|
max@0
|
662 case 3:
|
max@0
|
663 {
|
max@0
|
664 // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2);
|
max@0
|
665 // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0);
|
max@0
|
666 // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1);
|
max@0
|
667 // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2);
|
max@0
|
668 // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0);
|
max@0
|
669 // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1);
|
max@0
|
670 // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6);
|
max@0
|
671
|
max@0
|
672 const eT* a_col0 = X.colptr(0);
|
max@0
|
673 const eT a11 = a_col0[0];
|
max@0
|
674 const eT a21 = a_col0[1];
|
max@0
|
675 const eT a31 = a_col0[2];
|
max@0
|
676
|
max@0
|
677 const eT* a_col1 = X.colptr(1);
|
max@0
|
678 const eT a12 = a_col1[0];
|
max@0
|
679 const eT a22 = a_col1[1];
|
max@0
|
680 const eT a32 = a_col1[2];
|
max@0
|
681
|
max@0
|
682 const eT* a_col2 = X.colptr(2);
|
max@0
|
683 const eT a13 = a_col2[0];
|
max@0
|
684 const eT a23 = a_col2[1];
|
max@0
|
685 const eT a33 = a_col2[2];
|
max@0
|
686
|
max@0
|
687 return ( a11*(a33*a22 - a32*a23) - a21*(a33*a12-a32*a13) + a31*(a23*a12 - a22*a13) );
|
max@0
|
688 }
|
max@0
|
689 break;
|
max@0
|
690
|
max@0
|
691 case 4:
|
max@0
|
692 {
|
max@0
|
693 const eT* Xm = X.memptr();
|
max@0
|
694
|
max@0
|
695 const eT val = \
|
max@0
|
696 Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \
|
max@0
|
697 - Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \
|
max@0
|
698 - Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \
|
max@0
|
699 + Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \
|
max@0
|
700 + Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \
|
max@0
|
701 - Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \
|
max@0
|
702 - Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \
|
max@0
|
703 + Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \
|
max@0
|
704 + Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \
|
max@0
|
705 - Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \
|
max@0
|
706 - Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \
|
max@0
|
707 + Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \
|
max@0
|
708 + Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \
|
max@0
|
709 - Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \
|
max@0
|
710 - Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \
|
max@0
|
711 + Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \
|
max@0
|
712 + Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \
|
max@0
|
713 - Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \
|
max@0
|
714 - Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \
|
max@0
|
715 + Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \
|
max@0
|
716 + Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \
|
max@0
|
717 - Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \
|
max@0
|
718 - Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \
|
max@0
|
719 + Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \
|
max@0
|
720 ;
|
max@0
|
721
|
max@0
|
722 return val;
|
max@0
|
723 }
|
max@0
|
724 break;
|
max@0
|
725
|
max@0
|
726 default:
|
max@0
|
727 return eT(0);
|
max@0
|
728 ;
|
max@0
|
729 }
|
max@0
|
730 }
|
max@0
|
731
|
max@0
|
732
|
max@0
|
733
|
max@0
|
734 //! immediate determinant of a matrix using ATLAS or LAPACK
|
max@0
|
735 template<typename eT>
|
max@0
|
736 inline
|
max@0
|
737 eT
|
max@0
|
738 auxlib::det_lapack(const Mat<eT>& X, const bool make_copy)
|
max@0
|
739 {
|
max@0
|
740 arma_extra_debug_sigprint();
|
max@0
|
741
|
max@0
|
742 Mat<eT> X_copy;
|
max@0
|
743
|
max@0
|
744 if(make_copy == true)
|
max@0
|
745 {
|
max@0
|
746 X_copy = X;
|
max@0
|
747 }
|
max@0
|
748
|
max@0
|
749 Mat<eT>& tmp = (make_copy == true) ? X_copy : const_cast< Mat<eT>& >(X);
|
max@0
|
750
|
max@0
|
751 if(tmp.is_empty())
|
max@0
|
752 {
|
max@0
|
753 return eT(1);
|
max@0
|
754 }
|
max@0
|
755
|
max@0
|
756
|
max@0
|
757 #if defined(ARMA_USE_ATLAS)
|
max@0
|
758 {
|
max@0
|
759 podarray<int> ipiv(tmp.n_rows);
|
max@0
|
760
|
max@0
|
761 //const int info =
|
max@0
|
762 atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr());
|
max@0
|
763
|
max@0
|
764 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
|
max@0
|
765 eT val = tmp.at(0,0);
|
max@0
|
766 for(uword i=1; i < tmp.n_rows; ++i)
|
max@0
|
767 {
|
max@0
|
768 val *= tmp.at(i,i);
|
max@0
|
769 }
|
max@0
|
770
|
max@0
|
771 int sign = +1;
|
max@0
|
772 for(uword i=0; i < tmp.n_rows; ++i)
|
max@0
|
773 {
|
max@0
|
774 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0
|
max@0
|
775 {
|
max@0
|
776 sign *= -1;
|
max@0
|
777 }
|
max@0
|
778 }
|
max@0
|
779
|
max@0
|
780 return ( (sign < 0) ? -val : val );
|
max@0
|
781 }
|
max@0
|
782 #elif defined(ARMA_USE_LAPACK)
|
max@0
|
783 {
|
max@0
|
784 podarray<blas_int> ipiv(tmp.n_rows);
|
max@0
|
785
|
max@0
|
786 blas_int info = 0;
|
max@0
|
787 blas_int n_rows = blas_int(tmp.n_rows);
|
max@0
|
788 blas_int n_cols = blas_int(tmp.n_cols);
|
max@0
|
789
|
max@0
|
790 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info);
|
max@0
|
791
|
max@0
|
792 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
|
max@0
|
793 eT val = tmp.at(0,0);
|
max@0
|
794 for(uword i=1; i < tmp.n_rows; ++i)
|
max@0
|
795 {
|
max@0
|
796 val *= tmp.at(i,i);
|
max@0
|
797 }
|
max@0
|
798
|
max@0
|
799 blas_int sign = +1;
|
max@0
|
800 for(uword i=0; i < tmp.n_rows; ++i)
|
max@0
|
801 {
|
max@0
|
802 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1
|
max@0
|
803 {
|
max@0
|
804 sign *= -1;
|
max@0
|
805 }
|
max@0
|
806 }
|
max@0
|
807
|
max@0
|
808 return ( (sign < 0) ? -val : val );
|
max@0
|
809 }
|
max@0
|
810 #else
|
max@0
|
811 {
|
max@0
|
812 arma_ignore(X);
|
max@0
|
813 arma_ignore(make_copy);
|
max@0
|
814 arma_ignore(tmp);
|
max@0
|
815 arma_stop("det(): use of ATLAS or LAPACK needs to be enabled");
|
max@0
|
816 return eT(0);
|
max@0
|
817 }
|
max@0
|
818 #endif
|
max@0
|
819 }
|
max@0
|
820
|
max@0
|
821
|
max@0
|
822
|
max@0
|
823 //! immediate log determinant of a matrix using ATLAS or LAPACK
|
max@0
|
824 template<typename eT, typename T1>
|
max@0
|
825 inline
|
max@0
|
826 bool
|
max@0
|
827 auxlib::log_det(eT& out_val, typename get_pod_type<eT>::result& out_sign, const Base<eT,T1>& X)
|
max@0
|
828 {
|
max@0
|
829 arma_extra_debug_sigprint();
|
max@0
|
830
|
max@0
|
831 typedef typename get_pod_type<eT>::result T;
|
max@0
|
832
|
max@0
|
833 #if defined(ARMA_USE_ATLAS)
|
max@0
|
834 {
|
max@0
|
835 Mat<eT> tmp(X.get_ref());
|
max@0
|
836 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" );
|
max@0
|
837
|
max@0
|
838 if(tmp.is_empty())
|
max@0
|
839 {
|
max@0
|
840 out_val = eT(0);
|
max@0
|
841 out_sign = T(1);
|
max@0
|
842 return true;
|
max@0
|
843 }
|
max@0
|
844
|
max@0
|
845 podarray<int> ipiv(tmp.n_rows);
|
max@0
|
846
|
max@0
|
847 const int info = atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr());
|
max@0
|
848
|
max@0
|
849 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
|
max@0
|
850
|
max@0
|
851 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1;
|
max@0
|
852 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) );
|
max@0
|
853
|
max@0
|
854 for(uword i=1; i < tmp.n_rows; ++i)
|
max@0
|
855 {
|
max@0
|
856 const eT x = tmp.at(i,i);
|
max@0
|
857
|
max@0
|
858 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1;
|
max@0
|
859 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
|
max@0
|
860 }
|
max@0
|
861
|
max@0
|
862 for(uword i=0; i < tmp.n_rows; ++i)
|
max@0
|
863 {
|
max@0
|
864 if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0
|
max@0
|
865 {
|
max@0
|
866 sign *= -1;
|
max@0
|
867 }
|
max@0
|
868 }
|
max@0
|
869
|
max@0
|
870 out_val = val;
|
max@0
|
871 out_sign = T(sign);
|
max@0
|
872
|
max@0
|
873 return (info == 0);
|
max@0
|
874 }
|
max@0
|
875 #elif defined(ARMA_USE_LAPACK)
|
max@0
|
876 {
|
max@0
|
877 Mat<eT> tmp(X.get_ref());
|
max@0
|
878 arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix is not square" );
|
max@0
|
879
|
max@0
|
880 if(tmp.is_empty())
|
max@0
|
881 {
|
max@0
|
882 out_val = eT(0);
|
max@0
|
883 out_sign = T(1);
|
max@0
|
884 return true;
|
max@0
|
885 }
|
max@0
|
886
|
max@0
|
887 podarray<blas_int> ipiv(tmp.n_rows);
|
max@0
|
888
|
max@0
|
889 blas_int info = 0;
|
max@0
|
890 blas_int n_rows = blas_int(tmp.n_rows);
|
max@0
|
891 blas_int n_cols = blas_int(tmp.n_cols);
|
max@0
|
892
|
max@0
|
893 lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info);
|
max@0
|
894
|
max@0
|
895 // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero
|
max@0
|
896
|
max@0
|
897 sword sign = (is_complex<eT>::value == false) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1;
|
max@0
|
898 eT val = (is_complex<eT>::value == false) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) );
|
max@0
|
899
|
max@0
|
900 for(uword i=1; i < tmp.n_rows; ++i)
|
max@0
|
901 {
|
max@0
|
902 const eT x = tmp.at(i,i);
|
max@0
|
903
|
max@0
|
904 sign *= (is_complex<eT>::value == false) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1;
|
max@0
|
905 val += (is_complex<eT>::value == false) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x);
|
max@0
|
906 }
|
max@0
|
907
|
max@0
|
908 for(uword i=0; i < tmp.n_rows; ++i)
|
max@0
|
909 {
|
max@0
|
910 if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1
|
max@0
|
911 {
|
max@0
|
912 sign *= -1;
|
max@0
|
913 }
|
max@0
|
914 }
|
max@0
|
915
|
max@0
|
916 out_val = val;
|
max@0
|
917 out_sign = T(sign);
|
max@0
|
918
|
max@0
|
919 return (info == 0);
|
max@0
|
920 }
|
max@0
|
921 #else
|
max@0
|
922 {
|
max@0
|
923 out_val = eT(0);
|
max@0
|
924 out_sign = T(0);
|
max@0
|
925
|
max@0
|
926 arma_stop("log_det(): use of ATLAS or LAPACK needs to be enabled");
|
max@0
|
927
|
max@0
|
928 return false;
|
max@0
|
929 }
|
max@0
|
930 #endif
|
max@0
|
931 }
|
max@0
|
932
|
max@0
|
933
|
max@0
|
934
|
max@0
|
935 //! immediate LU decomposition of a matrix using ATLAS or LAPACK
|
max@0
|
936 template<typename eT, typename T1>
|
max@0
|
937 inline
|
max@0
|
938 bool
|
max@0
|
939 auxlib::lu(Mat<eT>& L, Mat<eT>& U, podarray<blas_int>& ipiv, const Base<eT,T1>& X)
|
max@0
|
940 {
|
max@0
|
941 arma_extra_debug_sigprint();
|
max@0
|
942
|
max@0
|
943 U = X.get_ref();
|
max@0
|
944
|
max@0
|
945 const uword U_n_rows = U.n_rows;
|
max@0
|
946 const uword U_n_cols = U.n_cols;
|
max@0
|
947
|
max@0
|
948 if(U.is_empty())
|
max@0
|
949 {
|
max@0
|
950 L.set_size(U_n_rows, 0);
|
max@0
|
951 U.set_size(0, U_n_cols);
|
max@0
|
952 ipiv.reset();
|
max@0
|
953 return true;
|
max@0
|
954 }
|
max@0
|
955
|
max@0
|
956 #if defined(ARMA_USE_ATLAS) || defined(ARMA_USE_LAPACK)
|
max@0
|
957 {
|
max@0
|
958 bool status;
|
max@0
|
959
|
max@0
|
960 #if defined(ARMA_USE_ATLAS)
|
max@0
|
961 {
|
max@0
|
962 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) );
|
max@0
|
963
|
max@0
|
964 int info = atlas::clapack_getrf(atlas::CblasColMajor, U_n_rows, U_n_cols, U.memptr(), U_n_rows, ipiv.memptr());
|
max@0
|
965
|
max@0
|
966 status = (info == 0);
|
max@0
|
967 }
|
max@0
|
968 #elif defined(ARMA_USE_LAPACK)
|
max@0
|
969 {
|
max@0
|
970 ipiv.set_size( (std::min)(U_n_rows, U_n_cols) );
|
max@0
|
971
|
max@0
|
972 blas_int info = 0;
|
max@0
|
973
|
max@0
|
974 blas_int n_rows = U_n_rows;
|
max@0
|
975 blas_int n_cols = U_n_cols;
|
max@0
|
976
|
max@0
|
977
|
max@0
|
978 lapack::getrf(&n_rows, &n_cols, U.memptr(), &n_rows, ipiv.memptr(), &info);
|
max@0
|
979
|
max@0
|
980 // take into account that Fortran counts from 1
|
max@0
|
981 arrayops::inplace_minus(ipiv.memptr(), blas_int(1), ipiv.n_elem);
|
max@0
|
982
|
max@0
|
983 status = (info == 0);
|
max@0
|
984 }
|
max@0
|
985 #endif
|
max@0
|
986
|
max@0
|
987 L.copy_size(U);
|
max@0
|
988
|
max@0
|
989 for(uword col=0; col < U_n_cols; ++col)
|
max@0
|
990 {
|
max@0
|
991 for(uword row=0; (row < col) && (row < U_n_rows); ++row)
|
max@0
|
992 {
|
max@0
|
993 L.at(row,col) = eT(0);
|
max@0
|
994 }
|
max@0
|
995
|
max@0
|
996 if( L.in_range(col,col) == true )
|
max@0
|
997 {
|
max@0
|
998 L.at(col,col) = eT(1);
|
max@0
|
999 }
|
max@0
|
1000
|
max@0
|
1001 for(uword row = (col+1); row < U_n_rows; ++row)
|
max@0
|
1002 {
|
max@0
|
1003 L.at(row,col) = U.at(row,col);
|
max@0
|
1004 U.at(row,col) = eT(0);
|
max@0
|
1005 }
|
max@0
|
1006 }
|
max@0
|
1007
|
max@0
|
1008 return status;
|
max@0
|
1009 }
|
max@0
|
1010 #else
|
max@0
|
1011 {
|
max@0
|
1012 arma_stop("lu(): use of ATLAS or LAPACK needs to be enabled");
|
max@0
|
1013
|
max@0
|
1014 return false;
|
max@0
|
1015 }
|
max@0
|
1016 #endif
|
max@0
|
1017 }
|
max@0
|
1018
|
max@0
|
1019
|
max@0
|
1020
|
max@0
|
1021 template<typename eT, typename T1>
|
max@0
|
1022 inline
|
max@0
|
1023 bool
|
max@0
|
1024 auxlib::lu(Mat<eT>& L, Mat<eT>& U, Mat<eT>& P, const Base<eT,T1>& X)
|
max@0
|
1025 {
|
max@0
|
1026 arma_extra_debug_sigprint();
|
max@0
|
1027
|
max@0
|
1028 podarray<blas_int> ipiv1;
|
max@0
|
1029 const bool status = auxlib::lu(L, U, ipiv1, X);
|
max@0
|
1030
|
max@0
|
1031 if(status == true)
|
max@0
|
1032 {
|
max@0
|
1033 if(U.is_empty())
|
max@0
|
1034 {
|
max@0
|
1035 // L and U have been already set to the correct empty matrices
|
max@0
|
1036 P.eye(L.n_rows, L.n_rows);
|
max@0
|
1037 return true;
|
max@0
|
1038 }
|
max@0
|
1039
|
max@0
|
1040 const uword n = ipiv1.n_elem;
|
max@0
|
1041 const uword P_rows = U.n_rows;
|
max@0
|
1042
|
max@0
|
1043 podarray<blas_int> ipiv2(P_rows);
|
max@0
|
1044
|
max@0
|
1045 const blas_int* ipiv1_mem = ipiv1.memptr();
|
max@0
|
1046 blas_int* ipiv2_mem = ipiv2.memptr();
|
max@0
|
1047
|
max@0
|
1048 for(uword i=0; i<P_rows; ++i)
|
max@0
|
1049 {
|
max@0
|
1050 ipiv2_mem[i] = blas_int(i);
|
max@0
|
1051 }
|
max@0
|
1052
|
max@0
|
1053 for(uword i=0; i<n; ++i)
|
max@0
|
1054 {
|
max@0
|
1055 const uword k = static_cast<uword>(ipiv1_mem[i]);
|
max@0
|
1056
|
max@0
|
1057 if( ipiv2_mem[i] != ipiv2_mem[k] )
|
max@0
|
1058 {
|
max@0
|
1059 std::swap( ipiv2_mem[i], ipiv2_mem[k] );
|
max@0
|
1060 }
|
max@0
|
1061 }
|
max@0
|
1062
|
max@0
|
1063 P.zeros(P_rows, P_rows);
|
max@0
|
1064
|
max@0
|
1065 for(uword row=0; row<P_rows; ++row)
|
max@0
|
1066 {
|
max@0
|
1067 P.at(row, static_cast<uword>(ipiv2_mem[row])) = eT(1);
|
max@0
|
1068 }
|
max@0
|
1069
|
max@0
|
1070 if(L.n_cols > U.n_rows)
|
max@0
|
1071 {
|
max@0
|
1072 L.shed_cols(U.n_rows, L.n_cols-1);
|
max@0
|
1073 }
|
max@0
|
1074
|
max@0
|
1075 if(U.n_rows > L.n_cols)
|
max@0
|
1076 {
|
max@0
|
1077 U.shed_rows(L.n_cols, U.n_rows-1);
|
max@0
|
1078 }
|
max@0
|
1079 }
|
max@0
|
1080
|
max@0
|
1081 return status;
|
max@0
|
1082 }
|
max@0
|
1083
|
max@0
|
1084
|
max@0
|
1085
|
max@0
|
1086 template<typename eT, typename T1>
|
max@0
|
1087 inline
|
max@0
|
1088 bool
|
max@0
|
1089 auxlib::lu(Mat<eT>& L, Mat<eT>& U, const Base<eT,T1>& X)
|
max@0
|
1090 {
|
max@0
|
1091 arma_extra_debug_sigprint();
|
max@0
|
1092
|
max@0
|
1093 podarray<blas_int> ipiv1;
|
max@0
|
1094 const bool status = auxlib::lu(L, U, ipiv1, X);
|
max@0
|
1095
|
max@0
|
1096 if(status == true)
|
max@0
|
1097 {
|
max@0
|
1098 if(U.is_empty())
|
max@0
|
1099 {
|
max@0
|
1100 // L and U have been already set to the correct empty matrices
|
max@0
|
1101 return true;
|
max@0
|
1102 }
|
max@0
|
1103
|
max@0
|
1104 const uword n = ipiv1.n_elem;
|
max@0
|
1105 const uword P_rows = U.n_rows;
|
max@0
|
1106
|
max@0
|
1107 podarray<blas_int> ipiv2(P_rows);
|
max@0
|
1108
|
max@0
|
1109 const blas_int* ipiv1_mem = ipiv1.memptr();
|
max@0
|
1110 blas_int* ipiv2_mem = ipiv2.memptr();
|
max@0
|
1111
|
max@0
|
1112 for(uword i=0; i<P_rows; ++i)
|
max@0
|
1113 {
|
max@0
|
1114 ipiv2_mem[i] = blas_int(i);
|
max@0
|
1115 }
|
max@0
|
1116
|
max@0
|
1117 for(uword i=0; i<n; ++i)
|
max@0
|
1118 {
|
max@0
|
1119 const uword k = static_cast<uword>(ipiv1_mem[i]);
|
max@0
|
1120
|
max@0
|
1121 if( ipiv2_mem[i] != ipiv2_mem[k] )
|
max@0
|
1122 {
|
max@0
|
1123 std::swap( ipiv2_mem[i], ipiv2_mem[k] );
|
max@0
|
1124 L.swap_rows( static_cast<uword>(ipiv2_mem[i]), static_cast<uword>(ipiv2_mem[k]) );
|
max@0
|
1125 }
|
max@0
|
1126 }
|
max@0
|
1127
|
max@0
|
1128 if(L.n_cols > U.n_rows)
|
max@0
|
1129 {
|
max@0
|
1130 L.shed_cols(U.n_rows, L.n_cols-1);
|
max@0
|
1131 }
|
max@0
|
1132
|
max@0
|
1133 if(U.n_rows > L.n_cols)
|
max@0
|
1134 {
|
max@0
|
1135 U.shed_rows(L.n_cols, U.n_rows-1);
|
max@0
|
1136 }
|
max@0
|
1137 }
|
max@0
|
1138
|
max@0
|
1139 return status;
|
max@0
|
1140 }
|
max@0
|
1141
|
max@0
|
1142
|
max@0
|
1143
|
max@0
|
1144 //! immediate eigenvalues of a symmetric real matrix using LAPACK
|
max@0
|
1145 template<typename eT, typename T1>
|
max@0
|
1146 inline
|
max@0
|
1147 bool
|
max@0
|
1148 auxlib::eig_sym(Col<eT>& eigval, const Base<eT,T1>& X)
|
max@0
|
1149 {
|
max@0
|
1150 arma_extra_debug_sigprint();
|
max@0
|
1151
|
max@0
|
1152 #if defined(ARMA_USE_LAPACK)
|
max@0
|
1153 {
|
max@0
|
1154 Mat<eT> A(X.get_ref());
|
max@0
|
1155
|
max@0
|
1156 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not square");
|
max@0
|
1157
|
max@0
|
1158 if(A.is_empty())
|
max@0
|
1159 {
|
max@0
|
1160 eigval.reset();
|
max@0
|
1161 return true;
|
max@0
|
1162 }
|
max@0
|
1163
|
max@0
|
1164 // rudimentary "better-than-nothing" test for symmetry
|
max@0
|
1165 //arma_debug_check( (A.at(A.n_rows-1, 0) != A.at(0, A.n_cols-1)), "auxlib::eig(): given matrix is not symmetric" );
|
max@0
|
1166
|
max@0
|
1167 char jobz = 'N';
|
max@0
|
1168 char uplo = 'U';
|
max@0
|
1169
|
max@0
|
1170 blas_int n_rows = A.n_rows;
|
max@0
|
1171 blas_int lwork = (std::max)(blas_int(1), 3*n_rows-1);
|
max@0
|
1172
|
max@0
|
1173 eigval.set_size( static_cast<uword>(n_rows) );
|
max@0
|
1174 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
1175
|
max@0
|
1176 blas_int info;
|
max@0
|
1177
|
max@0
|
1178 arma_extra_debug_print("lapack::syev()");
|
max@0
|
1179 lapack::syev(&jobz, &uplo, &n_rows, A.memptr(), &n_rows, eigval.memptr(), work.memptr(), &lwork, &info);
|
max@0
|
1180
|
max@0
|
1181 return (info == 0);
|
max@0
|
1182 }
|
max@0
|
1183 #else
|
max@0
|
1184 {
|
max@0
|
1185 arma_ignore(eigval);
|
max@0
|
1186 arma_ignore(X);
|
max@0
|
1187 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
|
max@0
|
1188 return false;
|
max@0
|
1189 }
|
max@0
|
1190 #endif
|
max@0
|
1191 }
|
max@0
|
1192
|
max@0
|
1193
|
max@0
|
1194
|
max@0
|
1195 //! immediate eigenvalues of a hermitian complex matrix using LAPACK
|
max@0
|
1196 template<typename T, typename T1>
|
max@0
|
1197 inline
|
max@0
|
1198 bool
|
max@0
|
1199 auxlib::eig_sym(Col<T>& eigval, const Base<std::complex<T>,T1>& X)
|
max@0
|
1200 {
|
max@0
|
1201 arma_extra_debug_sigprint();
|
max@0
|
1202
|
max@0
|
1203 typedef typename std::complex<T> eT;
|
max@0
|
1204
|
max@0
|
1205 #if defined(ARMA_USE_LAPACK)
|
max@0
|
1206 {
|
max@0
|
1207 Mat<eT> A(X.get_ref());
|
max@0
|
1208 arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix is not hermitian");
|
max@0
|
1209
|
max@0
|
1210 if(A.is_empty())
|
max@0
|
1211 {
|
max@0
|
1212 eigval.reset();
|
max@0
|
1213 return true;
|
max@0
|
1214 }
|
max@0
|
1215
|
max@0
|
1216 char jobz = 'N';
|
max@0
|
1217 char uplo = 'U';
|
max@0
|
1218
|
max@0
|
1219 blas_int n_rows = A.n_rows;
|
max@0
|
1220 blas_int lda = A.n_rows;
|
max@0
|
1221 blas_int lwork = (std::max)(blas_int(1), 2*n_rows - 1); // TODO: automatically find best size of lwork
|
max@0
|
1222
|
max@0
|
1223 eigval.set_size( static_cast<uword>(n_rows) );
|
max@0
|
1224
|
max@0
|
1225 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
1226 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows - 2)) );
|
max@0
|
1227
|
max@0
|
1228 blas_int info;
|
max@0
|
1229
|
max@0
|
1230 arma_extra_debug_print("lapack::heev()");
|
max@0
|
1231 lapack::heev(&jobz, &uplo, &n_rows, A.memptr(), &lda, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info);
|
max@0
|
1232
|
max@0
|
1233 return (info == 0);
|
max@0
|
1234 }
|
max@0
|
1235 #else
|
max@0
|
1236 {
|
max@0
|
1237 arma_ignore(eigval);
|
max@0
|
1238 arma_ignore(X);
|
max@0
|
1239 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
|
max@0
|
1240 return false;
|
max@0
|
1241 }
|
max@0
|
1242 #endif
|
max@0
|
1243 }
|
max@0
|
1244
|
max@0
|
1245
|
max@0
|
1246
|
max@0
|
1247 //! immediate eigenvalues and eigenvectors of a symmetric real matrix using LAPACK
|
max@0
|
1248 template<typename eT, typename T1>
|
max@0
|
1249 inline
|
max@0
|
1250 bool
|
max@0
|
1251 auxlib::eig_sym(Col<eT>& eigval, Mat<eT>& eigvec, const Base<eT,T1>& X)
|
max@0
|
1252 {
|
max@0
|
1253 arma_extra_debug_sigprint();
|
max@0
|
1254
|
max@0
|
1255 #if defined(ARMA_USE_LAPACK)
|
max@0
|
1256 {
|
max@0
|
1257 eigvec = X.get_ref();
|
max@0
|
1258
|
max@0
|
1259 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not square" );
|
max@0
|
1260
|
max@0
|
1261 if(eigvec.is_empty())
|
max@0
|
1262 {
|
max@0
|
1263 eigval.reset();
|
max@0
|
1264 eigvec.reset();
|
max@0
|
1265 return true;
|
max@0
|
1266 }
|
max@0
|
1267
|
max@0
|
1268 // rudimentary "better-than-nothing" test for symmetry
|
max@0
|
1269 //arma_debug_check( (A.at(A.n_rows-1, 0) != A.at(0, A.n_cols-1)), "auxlib::eig(): given matrix is not symmetric" );
|
max@0
|
1270
|
max@0
|
1271 char jobz = 'V';
|
max@0
|
1272 char uplo = 'U';
|
max@0
|
1273
|
max@0
|
1274 blas_int n_rows = eigvec.n_rows;
|
max@0
|
1275 blas_int lwork = (std::max)(blas_int(1), 3*n_rows-1);
|
max@0
|
1276
|
max@0
|
1277 eigval.set_size( static_cast<uword>(n_rows) );
|
max@0
|
1278 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
1279
|
max@0
|
1280 blas_int info;
|
max@0
|
1281
|
max@0
|
1282 arma_extra_debug_print("lapack::syev()");
|
max@0
|
1283 lapack::syev(&jobz, &uplo, &n_rows, eigvec.memptr(), &n_rows, eigval.memptr(), work.memptr(), &lwork, &info);
|
max@0
|
1284
|
max@0
|
1285 return (info == 0);
|
max@0
|
1286 }
|
max@0
|
1287 #else
|
max@0
|
1288 {
|
max@0
|
1289 arma_ignore(eigval);
|
max@0
|
1290 arma_ignore(eigvec);
|
max@0
|
1291 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
|
max@0
|
1292
|
max@0
|
1293 return false;
|
max@0
|
1294 }
|
max@0
|
1295 #endif
|
max@0
|
1296 }
|
max@0
|
1297
|
max@0
|
1298
|
max@0
|
1299
|
max@0
|
1300 //! immediate eigenvalues and eigenvectors of a hermitian complex matrix using LAPACK
|
max@0
|
1301 template<typename T, typename T1>
|
max@0
|
1302 inline
|
max@0
|
1303 bool
|
max@0
|
1304 auxlib::eig_sym(Col<T>& eigval, Mat< std::complex<T> >& eigvec, const Base<std::complex<T>,T1>& X)
|
max@0
|
1305 {
|
max@0
|
1306 arma_extra_debug_sigprint();
|
max@0
|
1307
|
max@0
|
1308 typedef typename std::complex<T> eT;
|
max@0
|
1309
|
max@0
|
1310 #if defined(ARMA_USE_LAPACK)
|
max@0
|
1311 {
|
max@0
|
1312 eigvec = X.get_ref();
|
max@0
|
1313
|
max@0
|
1314 arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix is not hermitian" );
|
max@0
|
1315
|
max@0
|
1316 if(eigvec.is_empty())
|
max@0
|
1317 {
|
max@0
|
1318 eigval.reset();
|
max@0
|
1319 eigvec.reset();
|
max@0
|
1320 return true;
|
max@0
|
1321 }
|
max@0
|
1322
|
max@0
|
1323 char jobz = 'V';
|
max@0
|
1324 char uplo = 'U';
|
max@0
|
1325
|
max@0
|
1326 blas_int n_rows = eigvec.n_rows;
|
max@0
|
1327 blas_int lda = eigvec.n_rows;
|
max@0
|
1328 blas_int lwork = (std::max)(blas_int(1), 2*n_rows - 1); // TODO: automatically find best size of lwork
|
max@0
|
1329
|
max@0
|
1330 eigval.set_size( static_cast<uword>(n_rows) );
|
max@0
|
1331
|
max@0
|
1332 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
1333 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows - 2)) );
|
max@0
|
1334
|
max@0
|
1335 blas_int info;
|
max@0
|
1336
|
max@0
|
1337 arma_extra_debug_print("lapack::heev()");
|
max@0
|
1338 lapack::heev(&jobz, &uplo, &n_rows, eigvec.memptr(), &lda, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info);
|
max@0
|
1339
|
max@0
|
1340 return (info == 0);
|
max@0
|
1341 }
|
max@0
|
1342 #else
|
max@0
|
1343 {
|
max@0
|
1344 arma_ignore(eigval);
|
max@0
|
1345 arma_ignore(eigvec);
|
max@0
|
1346 arma_ignore(X);
|
max@0
|
1347 arma_stop("eig_sym(): use of LAPACK needs to be enabled");
|
max@0
|
1348 return false;
|
max@0
|
1349 }
|
max@0
|
1350 #endif
|
max@0
|
1351 }
|
max@0
|
1352
|
max@0
|
1353
|
max@0
|
1354
|
max@0
|
1355 //! Eigenvalues and eigenvectors of a general square real matrix using LAPACK.
|
max@0
|
1356 //! The argument 'side' specifies which eigenvectors should be calculated
|
max@0
|
1357 //! (see code for mode details).
|
max@0
|
1358 template<typename T, typename T1>
|
max@0
|
1359 inline
|
max@0
|
1360 bool
|
max@0
|
1361 auxlib::eig_gen
|
max@0
|
1362 (
|
max@0
|
1363 Col< std::complex<T> >& eigval,
|
max@0
|
1364 Mat<T>& l_eigvec,
|
max@0
|
1365 Mat<T>& r_eigvec,
|
max@0
|
1366 const Base<T,T1>& X,
|
max@0
|
1367 const char side
|
max@0
|
1368 )
|
max@0
|
1369 {
|
max@0
|
1370 arma_extra_debug_sigprint();
|
max@0
|
1371
|
max@0
|
1372 #if defined(ARMA_USE_LAPACK)
|
max@0
|
1373 {
|
max@0
|
1374 char jobvl;
|
max@0
|
1375 char jobvr;
|
max@0
|
1376
|
max@0
|
1377 switch(side)
|
max@0
|
1378 {
|
max@0
|
1379 case 'l': // left
|
max@0
|
1380 jobvl = 'V';
|
max@0
|
1381 jobvr = 'N';
|
max@0
|
1382 break;
|
max@0
|
1383
|
max@0
|
1384 case 'r': // right
|
max@0
|
1385 jobvl = 'N';
|
max@0
|
1386 jobvr = 'V';
|
max@0
|
1387 break;
|
max@0
|
1388
|
max@0
|
1389 case 'b': // both
|
max@0
|
1390 jobvl = 'V';
|
max@0
|
1391 jobvr = 'V';
|
max@0
|
1392 break;
|
max@0
|
1393
|
max@0
|
1394 case 'n': // neither
|
max@0
|
1395 jobvl = 'N';
|
max@0
|
1396 jobvr = 'N';
|
max@0
|
1397 break;
|
max@0
|
1398
|
max@0
|
1399 default:
|
max@0
|
1400 arma_stop("eig_gen(): parameter 'side' is invalid");
|
max@0
|
1401 return false;
|
max@0
|
1402 }
|
max@0
|
1403
|
max@0
|
1404 Mat<T> A(X.get_ref());
|
max@0
|
1405 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" );
|
max@0
|
1406
|
max@0
|
1407 if(A.is_empty())
|
max@0
|
1408 {
|
max@0
|
1409 eigval.reset();
|
max@0
|
1410 l_eigvec.reset();
|
max@0
|
1411 r_eigvec.reset();
|
max@0
|
1412 return true;
|
max@0
|
1413 }
|
max@0
|
1414
|
max@0
|
1415 uword A_n_rows = A.n_rows;
|
max@0
|
1416
|
max@0
|
1417 blas_int n_rows = A_n_rows;
|
max@0
|
1418 blas_int lda = A_n_rows;
|
max@0
|
1419 blas_int lwork = (std::max)(blas_int(1), 4*n_rows); // TODO: automatically find best size of lwork
|
max@0
|
1420
|
max@0
|
1421 eigval.set_size(A_n_rows);
|
max@0
|
1422 l_eigvec.set_size(A_n_rows, A_n_rows);
|
max@0
|
1423 r_eigvec.set_size(A_n_rows, A_n_rows);
|
max@0
|
1424
|
max@0
|
1425 podarray<T> work( static_cast<uword>(lwork) );
|
max@0
|
1426 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows)) );
|
max@0
|
1427
|
max@0
|
1428 podarray<T> wr(A_n_rows);
|
max@0
|
1429 podarray<T> wi(A_n_rows);
|
max@0
|
1430
|
max@0
|
1431 Mat<T> A_copy = A;
|
max@0
|
1432 blas_int info;
|
max@0
|
1433
|
max@0
|
1434 arma_extra_debug_print("lapack::geev()");
|
max@0
|
1435 lapack::geev(&jobvl, &jobvr, &n_rows, A_copy.memptr(), &lda, wr.memptr(), wi.memptr(), l_eigvec.memptr(), &n_rows, r_eigvec.memptr(), &n_rows, work.memptr(), &lwork, &info);
|
max@0
|
1436
|
max@0
|
1437
|
max@0
|
1438 eigval.set_size(A_n_rows);
|
max@0
|
1439 for(uword i=0; i<A_n_rows; ++i)
|
max@0
|
1440 {
|
max@0
|
1441 eigval[i] = std::complex<T>(wr[i], wi[i]);
|
max@0
|
1442 }
|
max@0
|
1443
|
max@0
|
1444 return (info == 0);
|
max@0
|
1445 }
|
max@0
|
1446 #else
|
max@0
|
1447 {
|
max@0
|
1448 arma_ignore(eigval);
|
max@0
|
1449 arma_ignore(l_eigvec);
|
max@0
|
1450 arma_ignore(r_eigvec);
|
max@0
|
1451 arma_ignore(X);
|
max@0
|
1452 arma_ignore(side);
|
max@0
|
1453 arma_stop("eig_gen(): use of LAPACK needs to be enabled");
|
max@0
|
1454 return false;
|
max@0
|
1455 }
|
max@0
|
1456 #endif
|
max@0
|
1457 }
|
max@0
|
1458
|
max@0
|
1459
|
max@0
|
1460
|
max@0
|
1461
|
max@0
|
1462
|
max@0
|
1463 //! Eigenvalues and eigenvectors of a general square complex matrix using LAPACK
|
max@0
|
1464 //! The argument 'side' specifies which eigenvectors should be calculated
|
max@0
|
1465 //! (see code for mode details).
|
max@0
|
1466 template<typename T, typename T1>
|
max@0
|
1467 inline
|
max@0
|
1468 bool
|
max@0
|
1469 auxlib::eig_gen
|
max@0
|
1470 (
|
max@0
|
1471 Col< std::complex<T> >& eigval,
|
max@0
|
1472 Mat< std::complex<T> >& l_eigvec,
|
max@0
|
1473 Mat< std::complex<T> >& r_eigvec,
|
max@0
|
1474 const Base< std::complex<T>, T1 >& X,
|
max@0
|
1475 const char side
|
max@0
|
1476 )
|
max@0
|
1477 {
|
max@0
|
1478 arma_extra_debug_sigprint();
|
max@0
|
1479
|
max@0
|
1480 typedef typename std::complex<T> eT;
|
max@0
|
1481
|
max@0
|
1482 #if defined(ARMA_USE_LAPACK)
|
max@0
|
1483 {
|
max@0
|
1484 char jobvl;
|
max@0
|
1485 char jobvr;
|
max@0
|
1486
|
max@0
|
1487 switch(side)
|
max@0
|
1488 {
|
max@0
|
1489 case 'l': // left
|
max@0
|
1490 jobvl = 'V';
|
max@0
|
1491 jobvr = 'N';
|
max@0
|
1492 break;
|
max@0
|
1493
|
max@0
|
1494 case 'r': // right
|
max@0
|
1495 jobvl = 'N';
|
max@0
|
1496 jobvr = 'V';
|
max@0
|
1497 break;
|
max@0
|
1498
|
max@0
|
1499 case 'b': // both
|
max@0
|
1500 jobvl = 'V';
|
max@0
|
1501 jobvr = 'V';
|
max@0
|
1502 break;
|
max@0
|
1503
|
max@0
|
1504 case 'n': // neither
|
max@0
|
1505 jobvl = 'N';
|
max@0
|
1506 jobvr = 'N';
|
max@0
|
1507 break;
|
max@0
|
1508
|
max@0
|
1509 default:
|
max@0
|
1510 arma_stop("eig_gen(): parameter 'side' is invalid");
|
max@0
|
1511 return false;
|
max@0
|
1512 }
|
max@0
|
1513
|
max@0
|
1514 Mat<eT> A(X.get_ref());
|
max@0
|
1515 arma_debug_check( (A.is_square() == false), "eig_gen(): given matrix is not square" );
|
max@0
|
1516
|
max@0
|
1517 if(A.is_empty())
|
max@0
|
1518 {
|
max@0
|
1519 eigval.reset();
|
max@0
|
1520 l_eigvec.reset();
|
max@0
|
1521 r_eigvec.reset();
|
max@0
|
1522 return true;
|
max@0
|
1523 }
|
max@0
|
1524
|
max@0
|
1525 uword A_n_rows = A.n_rows;
|
max@0
|
1526
|
max@0
|
1527 blas_int n_rows = A_n_rows;
|
max@0
|
1528 blas_int lda = A_n_rows;
|
max@0
|
1529 blas_int lwork = (std::max)(blas_int(1), 4*n_rows); // TODO: automatically find best size of lwork
|
max@0
|
1530
|
max@0
|
1531 eigval.set_size(A_n_rows);
|
max@0
|
1532 l_eigvec.set_size(A_n_rows, A_n_rows);
|
max@0
|
1533 r_eigvec.set_size(A_n_rows, A_n_rows);
|
max@0
|
1534
|
max@0
|
1535 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
1536 podarray<T> rwork( static_cast<uword>((std::max)(blas_int(1), 3*n_rows)) ); // was 2,3
|
max@0
|
1537
|
max@0
|
1538 blas_int info;
|
max@0
|
1539
|
max@0
|
1540 arma_extra_debug_print("lapack::cx_geev()");
|
max@0
|
1541 lapack::cx_geev(&jobvl, &jobvr, &n_rows, A.memptr(), &lda, eigval.memptr(), l_eigvec.memptr(), &n_rows, r_eigvec.memptr(), &n_rows, work.memptr(), &lwork, rwork.memptr(), &info);
|
max@0
|
1542
|
max@0
|
1543 return (info == 0);
|
max@0
|
1544 }
|
max@0
|
1545 #else
|
max@0
|
1546 {
|
max@0
|
1547 arma_ignore(eigval);
|
max@0
|
1548 arma_ignore(l_eigvec);
|
max@0
|
1549 arma_ignore(r_eigvec);
|
max@0
|
1550 arma_ignore(X);
|
max@0
|
1551 arma_ignore(side);
|
max@0
|
1552 arma_stop("eig_gen(): use of LAPACK needs to be enabled");
|
max@0
|
1553 return false;
|
max@0
|
1554 }
|
max@0
|
1555 #endif
|
max@0
|
1556 }
|
max@0
|
1557
|
max@0
|
1558
|
max@0
|
1559
|
max@0
|
1560 template<typename eT, typename T1>
|
max@0
|
1561 inline
|
max@0
|
1562 bool
|
max@0
|
1563 auxlib::chol(Mat<eT>& out, const Base<eT,T1>& X)
|
max@0
|
1564 {
|
max@0
|
1565 arma_extra_debug_sigprint();
|
max@0
|
1566
|
max@0
|
1567 #if defined(ARMA_USE_LAPACK)
|
max@0
|
1568 {
|
max@0
|
1569 out = X.get_ref();
|
max@0
|
1570
|
max@0
|
1571 arma_debug_check( (out.is_square() == false), "chol(): given matrix is not square" );
|
max@0
|
1572
|
max@0
|
1573 if(out.is_empty())
|
max@0
|
1574 {
|
max@0
|
1575 return true;
|
max@0
|
1576 }
|
max@0
|
1577
|
max@0
|
1578 const uword out_n_rows = out.n_rows;
|
max@0
|
1579
|
max@0
|
1580 char uplo = 'U';
|
max@0
|
1581 blas_int n = out_n_rows;
|
max@0
|
1582 blas_int info;
|
max@0
|
1583
|
max@0
|
1584 lapack::potrf(&uplo, &n, out.memptr(), &n, &info);
|
max@0
|
1585
|
max@0
|
1586 for(uword col=0; col<out_n_rows; ++col)
|
max@0
|
1587 {
|
max@0
|
1588 eT* colptr = out.colptr(col);
|
max@0
|
1589
|
max@0
|
1590 for(uword row=(col+1); row < out_n_rows; ++row)
|
max@0
|
1591 {
|
max@0
|
1592 colptr[row] = eT(0);
|
max@0
|
1593 }
|
max@0
|
1594 }
|
max@0
|
1595
|
max@0
|
1596 return (info == 0);
|
max@0
|
1597 }
|
max@0
|
1598 #else
|
max@0
|
1599 {
|
max@0
|
1600 arma_ignore(out);
|
max@0
|
1601 arma_stop("chol(): use of LAPACK needs to be enabled");
|
max@0
|
1602 return false;
|
max@0
|
1603 }
|
max@0
|
1604 #endif
|
max@0
|
1605 }
|
max@0
|
1606
|
max@0
|
1607
|
max@0
|
1608
|
max@0
|
1609 template<typename eT, typename T1>
|
max@0
|
1610 inline
|
max@0
|
1611 bool
|
max@0
|
1612 auxlib::qr(Mat<eT>& Q, Mat<eT>& R, const Base<eT,T1>& X)
|
max@0
|
1613 {
|
max@0
|
1614 arma_extra_debug_sigprint();
|
max@0
|
1615
|
max@0
|
1616 #if defined(ARMA_USE_LAPACK)
|
max@0
|
1617 {
|
max@0
|
1618 R = X.get_ref();
|
max@0
|
1619
|
max@0
|
1620 const uword R_n_rows = R.n_rows;
|
max@0
|
1621 const uword R_n_cols = R.n_cols;
|
max@0
|
1622
|
max@0
|
1623 if(R.is_empty())
|
max@0
|
1624 {
|
max@0
|
1625 Q.eye(R_n_rows, R_n_rows);
|
max@0
|
1626 return true;
|
max@0
|
1627 }
|
max@0
|
1628
|
max@0
|
1629 blas_int m = static_cast<blas_int>(R_n_rows);
|
max@0
|
1630 blas_int n = static_cast<blas_int>(R_n_cols);
|
max@0
|
1631 blas_int work_len = (std::max)(blas_int(1),n);
|
max@0
|
1632 blas_int work_len_tmp;
|
max@0
|
1633 blas_int k = (std::min)(m,n);
|
max@0
|
1634 blas_int info;
|
max@0
|
1635
|
max@0
|
1636 podarray<eT> tau( static_cast<uword>(k) );
|
max@0
|
1637 podarray<eT> work( static_cast<uword>(work_len) );
|
max@0
|
1638
|
max@0
|
1639 // query for the optimum value of work_len
|
max@0
|
1640 work_len_tmp = -1;
|
max@0
|
1641 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info);
|
max@0
|
1642
|
max@0
|
1643 if(info == 0)
|
max@0
|
1644 {
|
max@0
|
1645 work_len = static_cast<blas_int>(access::tmp_real(work[0]));
|
max@0
|
1646 work.set_size( static_cast<uword>(work_len) );
|
max@0
|
1647 }
|
max@0
|
1648
|
max@0
|
1649 lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info);
|
max@0
|
1650
|
max@0
|
1651 Q.set_size(R_n_rows, R_n_rows);
|
max@0
|
1652
|
max@0
|
1653 arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) );
|
max@0
|
1654
|
max@0
|
1655 //
|
max@0
|
1656 // construct R
|
max@0
|
1657
|
max@0
|
1658 for(uword col=0; col < R_n_cols; ++col)
|
max@0
|
1659 {
|
max@0
|
1660 for(uword row=(col+1); row < R_n_rows; ++row)
|
max@0
|
1661 {
|
max@0
|
1662 R.at(row,col) = eT(0);
|
max@0
|
1663 }
|
max@0
|
1664 }
|
max@0
|
1665
|
max@0
|
1666
|
max@0
|
1667 if( (is_float<eT>::value == true) || (is_double<eT>::value == true) )
|
max@0
|
1668 {
|
max@0
|
1669 // query for the optimum value of work_len
|
max@0
|
1670 work_len_tmp = -1;
|
max@0
|
1671 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info);
|
max@0
|
1672
|
max@0
|
1673 if(info == 0)
|
max@0
|
1674 {
|
max@0
|
1675 work_len = static_cast<blas_int>(access::tmp_real(work[0]));
|
max@0
|
1676 work.set_size( static_cast<uword>(work_len) );
|
max@0
|
1677 }
|
max@0
|
1678
|
max@0
|
1679 lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info);
|
max@0
|
1680 }
|
max@0
|
1681 else
|
max@0
|
1682 if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) )
|
max@0
|
1683 {
|
max@0
|
1684 // query for the optimum value of work_len
|
max@0
|
1685 work_len_tmp = -1;
|
max@0
|
1686 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len_tmp, &info);
|
max@0
|
1687
|
max@0
|
1688 if(info == 0)
|
max@0
|
1689 {
|
max@0
|
1690 work_len = static_cast<blas_int>(access::tmp_real(work[0]));
|
max@0
|
1691 work.set_size( static_cast<uword>(work_len) );
|
max@0
|
1692 }
|
max@0
|
1693
|
max@0
|
1694 lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &work_len, &info);
|
max@0
|
1695 }
|
max@0
|
1696
|
max@0
|
1697 return (info == 0);
|
max@0
|
1698 }
|
max@0
|
1699 #else
|
max@0
|
1700 {
|
max@0
|
1701 arma_ignore(Q);
|
max@0
|
1702 arma_ignore(R);
|
max@0
|
1703 arma_ignore(X);
|
max@0
|
1704 arma_stop("qr(): use of LAPACK needs to be enabled");
|
max@0
|
1705 return false;
|
max@0
|
1706 }
|
max@0
|
1707 #endif
|
max@0
|
1708 }
|
max@0
|
1709
|
max@0
|
1710
|
max@0
|
1711
|
max@0
|
1712 template<typename eT, typename T1>
|
max@0
|
1713 inline
|
max@0
|
1714 bool
|
max@0
|
1715 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X, uword& X_n_rows, uword& X_n_cols)
|
max@0
|
1716 {
|
max@0
|
1717 arma_extra_debug_sigprint();
|
max@0
|
1718
|
max@0
|
1719 #if defined(ARMA_USE_LAPACK)
|
max@0
|
1720 {
|
max@0
|
1721 Mat<eT> A(X.get_ref());
|
max@0
|
1722
|
max@0
|
1723 X_n_rows = A.n_rows;
|
max@0
|
1724 X_n_cols = A.n_cols;
|
max@0
|
1725
|
max@0
|
1726 if(A.is_empty())
|
max@0
|
1727 {
|
max@0
|
1728 S.reset();
|
max@0
|
1729 return true;
|
max@0
|
1730 }
|
max@0
|
1731
|
max@0
|
1732 Mat<eT> U(1, 1);
|
max@0
|
1733 Mat<eT> V(1, A.n_cols);
|
max@0
|
1734
|
max@0
|
1735 char jobu = 'N';
|
max@0
|
1736 char jobvt = 'N';
|
max@0
|
1737
|
max@0
|
1738 blas_int m = A.n_rows;
|
max@0
|
1739 blas_int n = A.n_cols;
|
max@0
|
1740 blas_int lda = A.n_rows;
|
max@0
|
1741 blas_int ldu = U.n_rows;
|
max@0
|
1742 blas_int ldvt = V.n_rows;
|
max@0
|
1743 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
|
max@0
|
1744 blas_int info;
|
max@0
|
1745
|
max@0
|
1746 S.set_size( static_cast<uword>((std::min)(m, n)) );
|
max@0
|
1747
|
max@0
|
1748 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
1749
|
max@0
|
1750
|
max@0
|
1751 // let gesvd_() calculate the optimum size of the workspace
|
max@0
|
1752 blas_int lwork_tmp = -1;
|
max@0
|
1753
|
max@0
|
1754 lapack::gesvd<eT>
|
max@0
|
1755 (
|
max@0
|
1756 &jobu, &jobvt,
|
max@0
|
1757 &m,&n,
|
max@0
|
1758 A.memptr(), &lda,
|
max@0
|
1759 S.memptr(),
|
max@0
|
1760 U.memptr(), &ldu,
|
max@0
|
1761 V.memptr(), &ldvt,
|
max@0
|
1762 work.memptr(), &lwork_tmp,
|
max@0
|
1763 &info
|
max@0
|
1764 );
|
max@0
|
1765
|
max@0
|
1766 if(info == 0)
|
max@0
|
1767 {
|
max@0
|
1768 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
|
max@0
|
1769
|
max@0
|
1770 if(proposed_lwork > lwork)
|
max@0
|
1771 {
|
max@0
|
1772 lwork = proposed_lwork;
|
max@0
|
1773 work.set_size( static_cast<uword>(lwork) );
|
max@0
|
1774 }
|
max@0
|
1775
|
max@0
|
1776 lapack::gesvd<eT>
|
max@0
|
1777 (
|
max@0
|
1778 &jobu, &jobvt,
|
max@0
|
1779 &m, &n,
|
max@0
|
1780 A.memptr(), &lda,
|
max@0
|
1781 S.memptr(),
|
max@0
|
1782 U.memptr(), &ldu,
|
max@0
|
1783 V.memptr(), &ldvt,
|
max@0
|
1784 work.memptr(), &lwork,
|
max@0
|
1785 &info
|
max@0
|
1786 );
|
max@0
|
1787 }
|
max@0
|
1788
|
max@0
|
1789 return (info == 0);
|
max@0
|
1790 }
|
max@0
|
1791 #else
|
max@0
|
1792 {
|
max@0
|
1793 arma_ignore(S);
|
max@0
|
1794 arma_ignore(X);
|
max@0
|
1795 arma_ignore(X_n_rows);
|
max@0
|
1796 arma_ignore(X_n_cols);
|
max@0
|
1797 arma_stop("svd(): use of LAPACK needs to be enabled");
|
max@0
|
1798 return false;
|
max@0
|
1799 }
|
max@0
|
1800 #endif
|
max@0
|
1801 }
|
max@0
|
1802
|
max@0
|
1803
|
max@0
|
1804
|
max@0
|
1805 template<typename T, typename T1>
|
max@0
|
1806 inline
|
max@0
|
1807 bool
|
max@0
|
1808 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X, uword& X_n_rows, uword& X_n_cols)
|
max@0
|
1809 {
|
max@0
|
1810 arma_extra_debug_sigprint();
|
max@0
|
1811
|
max@0
|
1812 typedef std::complex<T> eT;
|
max@0
|
1813
|
max@0
|
1814 #if defined(ARMA_USE_LAPACK)
|
max@0
|
1815 {
|
max@0
|
1816 Mat<eT> A(X.get_ref());
|
max@0
|
1817
|
max@0
|
1818 X_n_rows = A.n_rows;
|
max@0
|
1819 X_n_cols = A.n_cols;
|
max@0
|
1820
|
max@0
|
1821 if(A.is_empty())
|
max@0
|
1822 {
|
max@0
|
1823 S.reset();
|
max@0
|
1824 return true;
|
max@0
|
1825 }
|
max@0
|
1826
|
max@0
|
1827 Mat<eT> U(1, 1);
|
max@0
|
1828 Mat<eT> V(1, A.n_cols);
|
max@0
|
1829
|
max@0
|
1830 char jobu = 'N';
|
max@0
|
1831 char jobvt = 'N';
|
max@0
|
1832
|
max@0
|
1833 blas_int m = A.n_rows;
|
max@0
|
1834 blas_int n = A.n_cols;
|
max@0
|
1835 blas_int lda = A.n_rows;
|
max@0
|
1836 blas_int ldu = U.n_rows;
|
max@0
|
1837 blas_int ldvt = V.n_rows;
|
max@0
|
1838 blas_int lwork = 2 * (std::max)(blas_int(1), 2*(std::min)(m,n)+(std::max)(m,n) );
|
max@0
|
1839 blas_int info;
|
max@0
|
1840
|
max@0
|
1841 S.set_size( static_cast<uword>((std::min)(m,n)) );
|
max@0
|
1842
|
max@0
|
1843 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
1844 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) );
|
max@0
|
1845
|
max@0
|
1846 // let gesvd_() calculate the optimum size of the workspace
|
max@0
|
1847 blas_int lwork_tmp = -1;
|
max@0
|
1848
|
max@0
|
1849 lapack::cx_gesvd<T>
|
max@0
|
1850 (
|
max@0
|
1851 &jobu, &jobvt,
|
max@0
|
1852 &m, &n,
|
max@0
|
1853 A.memptr(), &lda,
|
max@0
|
1854 S.memptr(),
|
max@0
|
1855 U.memptr(), &ldu,
|
max@0
|
1856 V.memptr(), &ldvt,
|
max@0
|
1857 work.memptr(), &lwork_tmp,
|
max@0
|
1858 rwork.memptr(),
|
max@0
|
1859 &info
|
max@0
|
1860 );
|
max@0
|
1861
|
max@0
|
1862 if(info == 0)
|
max@0
|
1863 {
|
max@0
|
1864 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
|
max@0
|
1865 if(proposed_lwork > lwork)
|
max@0
|
1866 {
|
max@0
|
1867 lwork = proposed_lwork;
|
max@0
|
1868 work.set_size( static_cast<uword>(lwork) );
|
max@0
|
1869 }
|
max@0
|
1870
|
max@0
|
1871 lapack::cx_gesvd<T>
|
max@0
|
1872 (
|
max@0
|
1873 &jobu, &jobvt,
|
max@0
|
1874 &m, &n,
|
max@0
|
1875 A.memptr(), &lda,
|
max@0
|
1876 S.memptr(),
|
max@0
|
1877 U.memptr(), &ldu,
|
max@0
|
1878 V.memptr(), &ldvt,
|
max@0
|
1879 work.memptr(), &lwork,
|
max@0
|
1880 rwork.memptr(),
|
max@0
|
1881 &info
|
max@0
|
1882 );
|
max@0
|
1883 }
|
max@0
|
1884
|
max@0
|
1885 return (info == 0);
|
max@0
|
1886 }
|
max@0
|
1887 #else
|
max@0
|
1888 {
|
max@0
|
1889 arma_ignore(S);
|
max@0
|
1890 arma_ignore(X);
|
max@0
|
1891 arma_ignore(X_n_rows);
|
max@0
|
1892 arma_ignore(X_n_cols);
|
max@0
|
1893
|
max@0
|
1894 arma_stop("svd(): use of LAPACK needs to be enabled");
|
max@0
|
1895 return false;
|
max@0
|
1896 }
|
max@0
|
1897 #endif
|
max@0
|
1898 }
|
max@0
|
1899
|
max@0
|
1900
|
max@0
|
1901
|
max@0
|
1902 template<typename eT, typename T1>
|
max@0
|
1903 inline
|
max@0
|
1904 bool
|
max@0
|
1905 auxlib::svd(Col<eT>& S, const Base<eT,T1>& X)
|
max@0
|
1906 {
|
max@0
|
1907 arma_extra_debug_sigprint();
|
max@0
|
1908
|
max@0
|
1909 uword junk;
|
max@0
|
1910 return auxlib::svd(S, X, junk, junk);
|
max@0
|
1911 }
|
max@0
|
1912
|
max@0
|
1913
|
max@0
|
1914
|
max@0
|
1915 template<typename T, typename T1>
|
max@0
|
1916 inline
|
max@0
|
1917 bool
|
max@0
|
1918 auxlib::svd(Col<T>& S, const Base<std::complex<T>, T1>& X)
|
max@0
|
1919 {
|
max@0
|
1920 arma_extra_debug_sigprint();
|
max@0
|
1921
|
max@0
|
1922 uword junk;
|
max@0
|
1923 return auxlib::svd(S, X, junk, junk);
|
max@0
|
1924 }
|
max@0
|
1925
|
max@0
|
1926
|
max@0
|
1927
|
max@0
|
1928 template<typename eT, typename T1>
|
max@0
|
1929 inline
|
max@0
|
1930 bool
|
max@0
|
1931 auxlib::svd(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X)
|
max@0
|
1932 {
|
max@0
|
1933 arma_extra_debug_sigprint();
|
max@0
|
1934
|
max@0
|
1935 #if defined(ARMA_USE_LAPACK)
|
max@0
|
1936 {
|
max@0
|
1937 Mat<eT> A(X.get_ref());
|
max@0
|
1938
|
max@0
|
1939 if(A.is_empty())
|
max@0
|
1940 {
|
max@0
|
1941 U.eye(A.n_rows, A.n_rows);
|
max@0
|
1942 S.reset();
|
max@0
|
1943 V.eye(A.n_cols, A.n_cols);
|
max@0
|
1944 return true;
|
max@0
|
1945 }
|
max@0
|
1946
|
max@0
|
1947 U.set_size(A.n_rows, A.n_rows);
|
max@0
|
1948 V.set_size(A.n_cols, A.n_cols);
|
max@0
|
1949
|
max@0
|
1950 char jobu = 'A';
|
max@0
|
1951 char jobvt = 'A';
|
max@0
|
1952
|
max@0
|
1953 blas_int m = A.n_rows;
|
max@0
|
1954 blas_int n = A.n_cols;
|
max@0
|
1955 blas_int lda = A.n_rows;
|
max@0
|
1956 blas_int ldu = U.n_rows;
|
max@0
|
1957 blas_int ldvt = V.n_rows;
|
max@0
|
1958 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
|
max@0
|
1959 blas_int info;
|
max@0
|
1960
|
max@0
|
1961
|
max@0
|
1962 S.set_size( static_cast<uword>((std::min)(m,n)) );
|
max@0
|
1963 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
1964
|
max@0
|
1965 // let gesvd_() calculate the optimum size of the workspace
|
max@0
|
1966 blas_int lwork_tmp = -1;
|
max@0
|
1967
|
max@0
|
1968 lapack::gesvd<eT>
|
max@0
|
1969 (
|
max@0
|
1970 &jobu, &jobvt,
|
max@0
|
1971 &m, &n,
|
max@0
|
1972 A.memptr(), &lda,
|
max@0
|
1973 S.memptr(),
|
max@0
|
1974 U.memptr(), &ldu,
|
max@0
|
1975 V.memptr(), &ldvt,
|
max@0
|
1976 work.memptr(), &lwork_tmp,
|
max@0
|
1977 &info
|
max@0
|
1978 );
|
max@0
|
1979
|
max@0
|
1980 if(info == 0)
|
max@0
|
1981 {
|
max@0
|
1982 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
|
max@0
|
1983 if(proposed_lwork > lwork)
|
max@0
|
1984 {
|
max@0
|
1985 lwork = proposed_lwork;
|
max@0
|
1986 work.set_size( static_cast<uword>(lwork) );
|
max@0
|
1987 }
|
max@0
|
1988
|
max@0
|
1989 lapack::gesvd<eT>
|
max@0
|
1990 (
|
max@0
|
1991 &jobu, &jobvt,
|
max@0
|
1992 &m, &n,
|
max@0
|
1993 A.memptr(), &lda,
|
max@0
|
1994 S.memptr(),
|
max@0
|
1995 U.memptr(), &ldu,
|
max@0
|
1996 V.memptr(), &ldvt,
|
max@0
|
1997 work.memptr(), &lwork,
|
max@0
|
1998 &info
|
max@0
|
1999 );
|
max@0
|
2000
|
max@0
|
2001 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done
|
max@0
|
2002 }
|
max@0
|
2003
|
max@0
|
2004 return (info == 0);
|
max@0
|
2005 }
|
max@0
|
2006 #else
|
max@0
|
2007 {
|
max@0
|
2008 arma_ignore(U);
|
max@0
|
2009 arma_ignore(S);
|
max@0
|
2010 arma_ignore(V);
|
max@0
|
2011 arma_ignore(X);
|
max@0
|
2012 arma_stop("svd(): use of LAPACK needs to be enabled");
|
max@0
|
2013 return false;
|
max@0
|
2014 }
|
max@0
|
2015 #endif
|
max@0
|
2016 }
|
max@0
|
2017
|
max@0
|
2018
|
max@0
|
2019
|
max@0
|
2020 template<typename T, typename T1>
|
max@0
|
2021 inline
|
max@0
|
2022 bool
|
max@0
|
2023 auxlib::svd(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X)
|
max@0
|
2024 {
|
max@0
|
2025 arma_extra_debug_sigprint();
|
max@0
|
2026
|
max@0
|
2027 typedef std::complex<T> eT;
|
max@0
|
2028
|
max@0
|
2029 #if defined(ARMA_USE_LAPACK)
|
max@0
|
2030 {
|
max@0
|
2031 Mat<eT> A(X.get_ref());
|
max@0
|
2032
|
max@0
|
2033 if(A.is_empty())
|
max@0
|
2034 {
|
max@0
|
2035 U.eye(A.n_rows, A.n_rows);
|
max@0
|
2036 S.reset();
|
max@0
|
2037 V.eye(A.n_cols, A.n_cols);
|
max@0
|
2038 return true;
|
max@0
|
2039 }
|
max@0
|
2040
|
max@0
|
2041 U.set_size(A.n_rows, A.n_rows);
|
max@0
|
2042 V.set_size(A.n_cols, A.n_cols);
|
max@0
|
2043
|
max@0
|
2044 char jobu = 'A';
|
max@0
|
2045 char jobvt = 'A';
|
max@0
|
2046
|
max@0
|
2047 blas_int m = A.n_rows;
|
max@0
|
2048 blas_int n = A.n_cols;
|
max@0
|
2049 blas_int lda = A.n_rows;
|
max@0
|
2050 blas_int ldu = U.n_rows;
|
max@0
|
2051 blas_int ldvt = V.n_rows;
|
max@0
|
2052 blas_int lwork = 2 * (std::max)(blas_int(1), 2*(std::min)(m,n)+(std::max)(m,n) );
|
max@0
|
2053 blas_int info;
|
max@0
|
2054
|
max@0
|
2055 S.set_size( static_cast<uword>((std::min)(m,n)) );
|
max@0
|
2056
|
max@0
|
2057 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
2058 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) );
|
max@0
|
2059
|
max@0
|
2060 // let gesvd_() calculate the optimum size of the workspace
|
max@0
|
2061 blas_int lwork_tmp = -1;
|
max@0
|
2062 lapack::cx_gesvd<T>
|
max@0
|
2063 (
|
max@0
|
2064 &jobu, &jobvt,
|
max@0
|
2065 &m, &n,
|
max@0
|
2066 A.memptr(), &lda,
|
max@0
|
2067 S.memptr(),
|
max@0
|
2068 U.memptr(), &ldu,
|
max@0
|
2069 V.memptr(), &ldvt,
|
max@0
|
2070 work.memptr(), &lwork_tmp,
|
max@0
|
2071 rwork.memptr(),
|
max@0
|
2072 &info
|
max@0
|
2073 );
|
max@0
|
2074
|
max@0
|
2075 if(info == 0)
|
max@0
|
2076 {
|
max@0
|
2077 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
|
max@0
|
2078 if(proposed_lwork > lwork)
|
max@0
|
2079 {
|
max@0
|
2080 lwork = proposed_lwork;
|
max@0
|
2081 work.set_size( static_cast<uword>(lwork) );
|
max@0
|
2082 }
|
max@0
|
2083
|
max@0
|
2084 lapack::cx_gesvd<T>
|
max@0
|
2085 (
|
max@0
|
2086 &jobu, &jobvt,
|
max@0
|
2087 &m, &n,
|
max@0
|
2088 A.memptr(), &lda,
|
max@0
|
2089 S.memptr(),
|
max@0
|
2090 U.memptr(), &ldu,
|
max@0
|
2091 V.memptr(), &ldvt,
|
max@0
|
2092 work.memptr(), &lwork,
|
max@0
|
2093 rwork.memptr(),
|
max@0
|
2094 &info
|
max@0
|
2095 );
|
max@0
|
2096
|
max@0
|
2097 op_htrans::apply(V,V); // op_htrans will work out that an in-place transpose can be done
|
max@0
|
2098 }
|
max@0
|
2099
|
max@0
|
2100 return (info == 0);
|
max@0
|
2101 }
|
max@0
|
2102 #else
|
max@0
|
2103 {
|
max@0
|
2104 arma_ignore(U);
|
max@0
|
2105 arma_ignore(S);
|
max@0
|
2106 arma_ignore(V);
|
max@0
|
2107 arma_ignore(X);
|
max@0
|
2108 arma_stop("svd(): use of LAPACK needs to be enabled");
|
max@0
|
2109 return false;
|
max@0
|
2110 }
|
max@0
|
2111 #endif
|
max@0
|
2112
|
max@0
|
2113 }
|
max@0
|
2114
|
max@0
|
2115
|
max@0
|
2116
|
max@0
|
2117 template<typename eT, typename T1>
|
max@0
|
2118 inline
|
max@0
|
2119 bool
|
max@0
|
2120 auxlib::svd_econ(Mat<eT>& U, Col<eT>& S, Mat<eT>& V, const Base<eT,T1>& X, const char mode)
|
max@0
|
2121 {
|
max@0
|
2122 arma_extra_debug_sigprint();
|
max@0
|
2123
|
max@0
|
2124 #if defined(ARMA_USE_LAPACK)
|
max@0
|
2125 {
|
max@0
|
2126 Mat<eT> A(X.get_ref());
|
max@0
|
2127
|
max@0
|
2128 blas_int m = A.n_rows;
|
max@0
|
2129 blas_int n = A.n_cols;
|
max@0
|
2130 blas_int lda = A.n_rows;
|
max@0
|
2131
|
max@0
|
2132 S.set_size( static_cast<uword>((std::min)(m,n)) );
|
max@0
|
2133
|
max@0
|
2134 blas_int ldu = 0;
|
max@0
|
2135 blas_int ldvt = 0;
|
max@0
|
2136
|
max@0
|
2137 char jobu;
|
max@0
|
2138 char jobvt;
|
max@0
|
2139
|
max@0
|
2140 switch(mode)
|
max@0
|
2141 {
|
max@0
|
2142 case 'l':
|
max@0
|
2143 jobu = 'S';
|
max@0
|
2144 jobvt = 'N';
|
max@0
|
2145
|
max@0
|
2146 ldu = m;
|
max@0
|
2147 ldvt = 1;
|
max@0
|
2148
|
max@0
|
2149 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
|
max@0
|
2150 V.reset();
|
max@0
|
2151
|
max@0
|
2152 break;
|
max@0
|
2153
|
max@0
|
2154
|
max@0
|
2155 case 'r':
|
max@0
|
2156 jobu = 'N';
|
max@0
|
2157 jobvt = 'S';
|
max@0
|
2158
|
max@0
|
2159 ldu = 1;
|
max@0
|
2160 ldvt = (std::min)(m,n);
|
max@0
|
2161
|
max@0
|
2162 U.reset();
|
max@0
|
2163 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
|
max@0
|
2164
|
max@0
|
2165 break;
|
max@0
|
2166
|
max@0
|
2167
|
max@0
|
2168 case 'b':
|
max@0
|
2169 jobu = 'S';
|
max@0
|
2170 jobvt = 'S';
|
max@0
|
2171
|
max@0
|
2172 ldu = m;
|
max@0
|
2173 ldvt = (std::min)(m,n);
|
max@0
|
2174
|
max@0
|
2175 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
|
max@0
|
2176 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
|
max@0
|
2177
|
max@0
|
2178 break;
|
max@0
|
2179
|
max@0
|
2180
|
max@0
|
2181 default:
|
max@0
|
2182 U.reset();
|
max@0
|
2183 S.reset();
|
max@0
|
2184 V.reset();
|
max@0
|
2185 return false;
|
max@0
|
2186 }
|
max@0
|
2187
|
max@0
|
2188
|
max@0
|
2189 if(A.is_empty())
|
max@0
|
2190 {
|
max@0
|
2191 U.eye();
|
max@0
|
2192 S.reset();
|
max@0
|
2193 V.eye();
|
max@0
|
2194 return true;
|
max@0
|
2195 }
|
max@0
|
2196
|
max@0
|
2197
|
max@0
|
2198 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
|
max@0
|
2199 blas_int info = 0;
|
max@0
|
2200
|
max@0
|
2201
|
max@0
|
2202 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
2203
|
max@0
|
2204 // let gesvd_() calculate the optimum size of the workspace
|
max@0
|
2205 blas_int lwork_tmp = -1;
|
max@0
|
2206
|
max@0
|
2207 lapack::gesvd<eT>
|
max@0
|
2208 (
|
max@0
|
2209 &jobu, &jobvt,
|
max@0
|
2210 &m, &n,
|
max@0
|
2211 A.memptr(), &lda,
|
max@0
|
2212 S.memptr(),
|
max@0
|
2213 U.memptr(), &ldu,
|
max@0
|
2214 V.memptr(), &ldvt,
|
max@0
|
2215 work.memptr(), &lwork_tmp,
|
max@0
|
2216 &info
|
max@0
|
2217 );
|
max@0
|
2218
|
max@0
|
2219 if(info == 0)
|
max@0
|
2220 {
|
max@0
|
2221 blas_int proposed_lwork = static_cast<blas_int>(work[0]);
|
max@0
|
2222 if(proposed_lwork > lwork)
|
max@0
|
2223 {
|
max@0
|
2224 lwork = proposed_lwork;
|
max@0
|
2225 work.set_size( static_cast<uword>(lwork) );
|
max@0
|
2226 }
|
max@0
|
2227
|
max@0
|
2228 lapack::gesvd<eT>
|
max@0
|
2229 (
|
max@0
|
2230 &jobu, &jobvt,
|
max@0
|
2231 &m, &n,
|
max@0
|
2232 A.memptr(), &lda,
|
max@0
|
2233 S.memptr(),
|
max@0
|
2234 U.memptr(), &ldu,
|
max@0
|
2235 V.memptr(), &ldvt,
|
max@0
|
2236 work.memptr(), &lwork,
|
max@0
|
2237 &info
|
max@0
|
2238 );
|
max@0
|
2239
|
max@0
|
2240 op_strans::apply(V,V); // op_strans will work out that an in-place transpose can be done
|
max@0
|
2241 }
|
max@0
|
2242
|
max@0
|
2243 return (info == 0);
|
max@0
|
2244 }
|
max@0
|
2245 #else
|
max@0
|
2246 {
|
max@0
|
2247 arma_ignore(U);
|
max@0
|
2248 arma_ignore(S);
|
max@0
|
2249 arma_ignore(V);
|
max@0
|
2250 arma_ignore(X);
|
max@0
|
2251 arma_ignore(mode);
|
max@0
|
2252 arma_stop("svd(): use of LAPACK needs to be enabled");
|
max@0
|
2253 return false;
|
max@0
|
2254 }
|
max@0
|
2255 #endif
|
max@0
|
2256 }
|
max@0
|
2257
|
max@0
|
2258
|
max@0
|
2259
|
max@0
|
2260 template<typename T, typename T1>
|
max@0
|
2261 inline
|
max@0
|
2262 bool
|
max@0
|
2263 auxlib::svd_econ(Mat< std::complex<T> >& U, Col<T>& S, Mat< std::complex<T> >& V, const Base< std::complex<T>, T1>& X, const char mode)
|
max@0
|
2264 {
|
max@0
|
2265 arma_extra_debug_sigprint();
|
max@0
|
2266
|
max@0
|
2267 typedef std::complex<T> eT;
|
max@0
|
2268
|
max@0
|
2269 #if defined(ARMA_USE_LAPACK)
|
max@0
|
2270 {
|
max@0
|
2271 Mat<eT> A(X.get_ref());
|
max@0
|
2272
|
max@0
|
2273 blas_int m = A.n_rows;
|
max@0
|
2274 blas_int n = A.n_cols;
|
max@0
|
2275 blas_int lda = A.n_rows;
|
max@0
|
2276
|
max@0
|
2277 S.set_size( static_cast<uword>((std::min)(m,n)) );
|
max@0
|
2278
|
max@0
|
2279 blas_int ldu = 0;
|
max@0
|
2280 blas_int ldvt = 0;
|
max@0
|
2281
|
max@0
|
2282 char jobu;
|
max@0
|
2283 char jobvt;
|
max@0
|
2284
|
max@0
|
2285 switch(mode)
|
max@0
|
2286 {
|
max@0
|
2287 case 'l':
|
max@0
|
2288 jobu = 'S';
|
max@0
|
2289 jobvt = 'N';
|
max@0
|
2290
|
max@0
|
2291 ldu = m;
|
max@0
|
2292 ldvt = 1;
|
max@0
|
2293
|
max@0
|
2294 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
|
max@0
|
2295 V.reset();
|
max@0
|
2296
|
max@0
|
2297 break;
|
max@0
|
2298
|
max@0
|
2299
|
max@0
|
2300 case 'r':
|
max@0
|
2301 jobu = 'N';
|
max@0
|
2302 jobvt = 'S';
|
max@0
|
2303
|
max@0
|
2304 ldu = 1;
|
max@0
|
2305 ldvt = (std::min)(m,n);
|
max@0
|
2306
|
max@0
|
2307 U.reset();
|
max@0
|
2308 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
|
max@0
|
2309
|
max@0
|
2310 break;
|
max@0
|
2311
|
max@0
|
2312
|
max@0
|
2313 case 'b':
|
max@0
|
2314 jobu = 'S';
|
max@0
|
2315 jobvt = 'S';
|
max@0
|
2316
|
max@0
|
2317 ldu = m;
|
max@0
|
2318 ldvt = (std::min)(m,n);
|
max@0
|
2319
|
max@0
|
2320 U.set_size( static_cast<uword>(ldu), static_cast<uword>((std::min)(m,n)) );
|
max@0
|
2321 V.set_size( static_cast<uword>(ldvt), static_cast<uword>(n) );
|
max@0
|
2322
|
max@0
|
2323 break;
|
max@0
|
2324
|
max@0
|
2325
|
max@0
|
2326 default:
|
max@0
|
2327 U.reset();
|
max@0
|
2328 S.reset();
|
max@0
|
2329 V.reset();
|
max@0
|
2330 return false;
|
max@0
|
2331 }
|
max@0
|
2332
|
max@0
|
2333
|
max@0
|
2334 if(A.is_empty())
|
max@0
|
2335 {
|
max@0
|
2336 U.eye();
|
max@0
|
2337 S.reset();
|
max@0
|
2338 V.eye();
|
max@0
|
2339 return true;
|
max@0
|
2340 }
|
max@0
|
2341
|
max@0
|
2342
|
max@0
|
2343 blas_int lwork = 2 * (std::max)(blas_int(1), (std::max)( (3*(std::min)(m,n) + (std::max)(m,n)), 5*(std::min)(m,n) ) );
|
max@0
|
2344 blas_int info = 0;
|
max@0
|
2345
|
max@0
|
2346
|
max@0
|
2347 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
2348 podarray<T> rwork( static_cast<uword>(5*(std::min)(m,n)) );
|
max@0
|
2349
|
max@0
|
2350 // let gesvd_() calculate the optimum size of the workspace
|
max@0
|
2351 blas_int lwork_tmp = -1;
|
max@0
|
2352
|
max@0
|
2353 lapack::cx_gesvd<T>
|
max@0
|
2354 (
|
max@0
|
2355 &jobu, &jobvt,
|
max@0
|
2356 &m, &n,
|
max@0
|
2357 A.memptr(), &lda,
|
max@0
|
2358 S.memptr(),
|
max@0
|
2359 U.memptr(), &ldu,
|
max@0
|
2360 V.memptr(), &ldvt,
|
max@0
|
2361 work.memptr(), &lwork_tmp,
|
max@0
|
2362 rwork.memptr(),
|
max@0
|
2363 &info
|
max@0
|
2364 );
|
max@0
|
2365
|
max@0
|
2366 if(info == 0)
|
max@0
|
2367 {
|
max@0
|
2368 blas_int proposed_lwork = static_cast<blas_int>(real(work[0]));
|
max@0
|
2369 if(proposed_lwork > lwork)
|
max@0
|
2370 {
|
max@0
|
2371 lwork = proposed_lwork;
|
max@0
|
2372 work.set_size( static_cast<uword>(lwork) );
|
max@0
|
2373 }
|
max@0
|
2374
|
max@0
|
2375 lapack::cx_gesvd<T>
|
max@0
|
2376 (
|
max@0
|
2377 &jobu, &jobvt,
|
max@0
|
2378 &m, &n,
|
max@0
|
2379 A.memptr(), &lda,
|
max@0
|
2380 S.memptr(),
|
max@0
|
2381 U.memptr(), &ldu,
|
max@0
|
2382 V.memptr(), &ldvt,
|
max@0
|
2383 work.memptr(), &lwork,
|
max@0
|
2384 rwork.memptr(),
|
max@0
|
2385 &info
|
max@0
|
2386 );
|
max@0
|
2387
|
max@0
|
2388 op_htrans::apply(V,V); // op_strans will work out that an in-place transpose can be done
|
max@0
|
2389 }
|
max@0
|
2390
|
max@0
|
2391 return (info == 0);
|
max@0
|
2392 }
|
max@0
|
2393 #else
|
max@0
|
2394 {
|
max@0
|
2395 arma_ignore(U);
|
max@0
|
2396 arma_ignore(S);
|
max@0
|
2397 arma_ignore(V);
|
max@0
|
2398 arma_ignore(X);
|
max@0
|
2399 arma_ignore(mode);
|
max@0
|
2400 arma_stop("svd(): use of LAPACK needs to be enabled");
|
max@0
|
2401 return false;
|
max@0
|
2402 }
|
max@0
|
2403 #endif
|
max@0
|
2404 }
|
max@0
|
2405
|
max@0
|
2406
|
max@0
|
2407
|
max@0
|
2408 //! Solve a system of linear equations.
|
max@0
|
2409 //! Assumes that A.n_rows = A.n_cols and B.n_rows = A.n_rows
|
max@0
|
2410 template<typename eT>
|
max@0
|
2411 inline
|
max@0
|
2412 bool
|
max@0
|
2413 auxlib::solve(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B, const bool slow)
|
max@0
|
2414 {
|
max@0
|
2415 arma_extra_debug_sigprint();
|
max@0
|
2416
|
max@0
|
2417 if(A.is_empty() || B.is_empty())
|
max@0
|
2418 {
|
max@0
|
2419 out.zeros(A.n_cols, B.n_cols);
|
max@0
|
2420 return true;
|
max@0
|
2421 }
|
max@0
|
2422 else
|
max@0
|
2423 {
|
max@0
|
2424 const uword A_n_rows = A.n_rows;
|
max@0
|
2425
|
max@0
|
2426 bool status = false;
|
max@0
|
2427
|
max@0
|
2428 if( (A_n_rows <= 4) && (slow == false) )
|
max@0
|
2429 {
|
max@0
|
2430 Mat<eT> A_inv;
|
max@0
|
2431
|
max@0
|
2432 status = auxlib::inv_noalias_tinymat(A_inv, A, A_n_rows);
|
max@0
|
2433
|
max@0
|
2434 if(status == true)
|
max@0
|
2435 {
|
max@0
|
2436 out.set_size(A_n_rows, B.n_cols);
|
max@0
|
2437
|
max@0
|
2438 gemm_emul<false,false,false,false>::apply(out, A_inv, B);
|
max@0
|
2439
|
max@0
|
2440 return true;
|
max@0
|
2441 }
|
max@0
|
2442 }
|
max@0
|
2443
|
max@0
|
2444 if( (A_n_rows > 4) || (status == false) )
|
max@0
|
2445 {
|
max@0
|
2446 #if defined(ARMA_USE_ATLAS)
|
max@0
|
2447 {
|
max@0
|
2448 podarray<int> ipiv(A_n_rows);
|
max@0
|
2449
|
max@0
|
2450 out = B;
|
max@0
|
2451
|
max@0
|
2452 int info = atlas::clapack_gesv<eT>(atlas::CblasColMajor, A_n_rows, B.n_cols, A.memptr(), A_n_rows, ipiv.memptr(), out.memptr(), A_n_rows);
|
max@0
|
2453
|
max@0
|
2454 return (info == 0);
|
max@0
|
2455 }
|
max@0
|
2456 #elif defined(ARMA_USE_LAPACK)
|
max@0
|
2457 {
|
max@0
|
2458 blas_int n = A_n_rows;
|
max@0
|
2459 blas_int lda = A_n_rows;
|
max@0
|
2460 blas_int ldb = A_n_rows;
|
max@0
|
2461 blas_int nrhs = B.n_cols;
|
max@0
|
2462 blas_int info;
|
max@0
|
2463
|
max@0
|
2464 podarray<blas_int> ipiv(A_n_rows);
|
max@0
|
2465
|
max@0
|
2466 out = B;
|
max@0
|
2467
|
max@0
|
2468 lapack::gesv<eT>(&n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info);
|
max@0
|
2469
|
max@0
|
2470 return (info == 0);
|
max@0
|
2471 }
|
max@0
|
2472 #else
|
max@0
|
2473 {
|
max@0
|
2474 arma_stop("solve(): use of ATLAS or LAPACK needs to be enabled");
|
max@0
|
2475 return false;
|
max@0
|
2476 }
|
max@0
|
2477 #endif
|
max@0
|
2478 }
|
max@0
|
2479 }
|
max@0
|
2480
|
max@0
|
2481 return true;
|
max@0
|
2482 }
|
max@0
|
2483
|
max@0
|
2484
|
max@0
|
2485
|
max@0
|
2486 //! Solve an over-determined system.
|
max@0
|
2487 //! Assumes that A.n_rows > A.n_cols and B.n_rows = A.n_rows
|
max@0
|
2488 template<typename eT>
|
max@0
|
2489 inline
|
max@0
|
2490 bool
|
max@0
|
2491 auxlib::solve_od(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B)
|
max@0
|
2492 {
|
max@0
|
2493 arma_extra_debug_sigprint();
|
max@0
|
2494
|
max@0
|
2495 #if defined(ARMA_USE_LAPACK)
|
max@0
|
2496 {
|
max@0
|
2497 if(A.is_empty() || B.is_empty())
|
max@0
|
2498 {
|
max@0
|
2499 out.zeros(A.n_cols, B.n_cols);
|
max@0
|
2500 return true;
|
max@0
|
2501 }
|
max@0
|
2502
|
max@0
|
2503 char trans = 'N';
|
max@0
|
2504
|
max@0
|
2505 blas_int m = A.n_rows;
|
max@0
|
2506 blas_int n = A.n_cols;
|
max@0
|
2507 blas_int lda = A.n_rows;
|
max@0
|
2508 blas_int ldb = A.n_rows;
|
max@0
|
2509 blas_int nrhs = B.n_cols;
|
max@0
|
2510 blas_int lwork = n + (std::max)(n, nrhs);
|
max@0
|
2511 blas_int info;
|
max@0
|
2512
|
max@0
|
2513 Mat<eT> tmp = B;
|
max@0
|
2514
|
max@0
|
2515 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
2516
|
max@0
|
2517 arma_extra_debug_print("lapack::gels()");
|
max@0
|
2518
|
max@0
|
2519 // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems
|
max@0
|
2520
|
max@0
|
2521 lapack::gels<eT>
|
max@0
|
2522 (
|
max@0
|
2523 &trans, &m, &n, &nrhs,
|
max@0
|
2524 A.memptr(), &lda,
|
max@0
|
2525 tmp.memptr(), &ldb,
|
max@0
|
2526 work.memptr(), &lwork,
|
max@0
|
2527 &info
|
max@0
|
2528 );
|
max@0
|
2529
|
max@0
|
2530 arma_extra_debug_print("lapack::gels() -- finished");
|
max@0
|
2531
|
max@0
|
2532 out.set_size(A.n_cols, B.n_cols);
|
max@0
|
2533
|
max@0
|
2534 for(uword col=0; col<B.n_cols; ++col)
|
max@0
|
2535 {
|
max@0
|
2536 arrayops::copy( out.colptr(col), tmp.colptr(col), A.n_cols );
|
max@0
|
2537 }
|
max@0
|
2538
|
max@0
|
2539 return (info == 0);
|
max@0
|
2540 }
|
max@0
|
2541 #else
|
max@0
|
2542 {
|
max@0
|
2543 arma_ignore(out);
|
max@0
|
2544 arma_ignore(A);
|
max@0
|
2545 arma_ignore(B);
|
max@0
|
2546 arma_stop("solve(): use of LAPACK needs to be enabled");
|
max@0
|
2547 return false;
|
max@0
|
2548 }
|
max@0
|
2549 #endif
|
max@0
|
2550 }
|
max@0
|
2551
|
max@0
|
2552
|
max@0
|
2553
|
max@0
|
2554 //! Solve an under-determined system.
|
max@0
|
2555 //! Assumes that A.n_rows < A.n_cols and B.n_rows = A.n_rows
|
max@0
|
2556 template<typename eT>
|
max@0
|
2557 inline
|
max@0
|
2558 bool
|
max@0
|
2559 auxlib::solve_ud(Mat<eT>& out, Mat<eT>& A, const Mat<eT>& B)
|
max@0
|
2560 {
|
max@0
|
2561 arma_extra_debug_sigprint();
|
max@0
|
2562
|
max@0
|
2563 #if defined(ARMA_USE_LAPACK)
|
max@0
|
2564 {
|
max@0
|
2565 if(A.is_empty() || B.is_empty())
|
max@0
|
2566 {
|
max@0
|
2567 out.zeros(A.n_cols, B.n_cols);
|
max@0
|
2568 return true;
|
max@0
|
2569 }
|
max@0
|
2570
|
max@0
|
2571 char trans = 'N';
|
max@0
|
2572
|
max@0
|
2573 blas_int m = A.n_rows;
|
max@0
|
2574 blas_int n = A.n_cols;
|
max@0
|
2575 blas_int lda = A.n_rows;
|
max@0
|
2576 blas_int ldb = A.n_cols;
|
max@0
|
2577 blas_int nrhs = B.n_cols;
|
max@0
|
2578 blas_int lwork = m + (std::max)(m,nrhs);
|
max@0
|
2579 blas_int info;
|
max@0
|
2580
|
max@0
|
2581
|
max@0
|
2582 Mat<eT> tmp;
|
max@0
|
2583 tmp.zeros(A.n_cols, B.n_cols);
|
max@0
|
2584
|
max@0
|
2585 for(uword col=0; col<B.n_cols; ++col)
|
max@0
|
2586 {
|
max@0
|
2587 eT* tmp_colmem = tmp.colptr(col);
|
max@0
|
2588
|
max@0
|
2589 arrayops::copy( tmp_colmem, B.colptr(col), B.n_rows );
|
max@0
|
2590
|
max@0
|
2591 for(uword row=B.n_rows; row<A.n_cols; ++row)
|
max@0
|
2592 {
|
max@0
|
2593 tmp_colmem[row] = eT(0);
|
max@0
|
2594 }
|
max@0
|
2595 }
|
max@0
|
2596
|
max@0
|
2597 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
2598
|
max@0
|
2599 arma_extra_debug_print("lapack::gels()");
|
max@0
|
2600
|
max@0
|
2601 // NOTE: the dgels() function in the lapack library supplied by ATLAS 3.6 seems to have problems
|
max@0
|
2602
|
max@0
|
2603 lapack::gels<eT>
|
max@0
|
2604 (
|
max@0
|
2605 &trans, &m, &n, &nrhs,
|
max@0
|
2606 A.memptr(), &lda,
|
max@0
|
2607 tmp.memptr(), &ldb,
|
max@0
|
2608 work.memptr(), &lwork,
|
max@0
|
2609 &info
|
max@0
|
2610 );
|
max@0
|
2611
|
max@0
|
2612 arma_extra_debug_print("lapack::gels() -- finished");
|
max@0
|
2613
|
max@0
|
2614 out.set_size(A.n_cols, B.n_cols);
|
max@0
|
2615
|
max@0
|
2616 for(uword col=0; col<B.n_cols; ++col)
|
max@0
|
2617 {
|
max@0
|
2618 arrayops::copy( out.colptr(col), tmp.colptr(col), A.n_cols );
|
max@0
|
2619 }
|
max@0
|
2620
|
max@0
|
2621 return (info == 0);
|
max@0
|
2622 }
|
max@0
|
2623 #else
|
max@0
|
2624 {
|
max@0
|
2625 arma_ignore(out);
|
max@0
|
2626 arma_ignore(A);
|
max@0
|
2627 arma_ignore(B);
|
max@0
|
2628 arma_stop("solve(): use of LAPACK needs to be enabled");
|
max@0
|
2629 return false;
|
max@0
|
2630 }
|
max@0
|
2631 #endif
|
max@0
|
2632 }
|
max@0
|
2633
|
max@0
|
2634
|
max@0
|
2635
|
max@0
|
2636 //
|
max@0
|
2637 // solve_tr
|
max@0
|
2638
|
max@0
|
2639 template<typename eT>
|
max@0
|
2640 inline
|
max@0
|
2641 bool
|
max@0
|
2642 auxlib::solve_tr(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const uword layout)
|
max@0
|
2643 {
|
max@0
|
2644 arma_extra_debug_sigprint();
|
max@0
|
2645
|
max@0
|
2646 #if defined(ARMA_USE_LAPACK)
|
max@0
|
2647 {
|
max@0
|
2648 if(A.is_empty() || B.is_empty())
|
max@0
|
2649 {
|
max@0
|
2650 out.zeros(A.n_cols, B.n_cols);
|
max@0
|
2651 return true;
|
max@0
|
2652 }
|
max@0
|
2653
|
max@0
|
2654 out = B;
|
max@0
|
2655
|
max@0
|
2656 char uplo = (layout == 0) ? 'U' : 'L';
|
max@0
|
2657 char trans = 'N';
|
max@0
|
2658 char diag = 'N';
|
max@0
|
2659 blas_int n = blas_int(A.n_rows);
|
max@0
|
2660 blas_int nrhs = blas_int(B.n_cols);
|
max@0
|
2661 blas_int info = 0;
|
max@0
|
2662
|
max@0
|
2663 lapack::trtrs<eT>(&uplo, &trans, &diag, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info);
|
max@0
|
2664
|
max@0
|
2665 return (info == 0);
|
max@0
|
2666 }
|
max@0
|
2667 #else
|
max@0
|
2668 {
|
max@0
|
2669 arma_ignore(out);
|
max@0
|
2670 arma_ignore(A);
|
max@0
|
2671 arma_ignore(B);
|
max@0
|
2672 arma_ignore(layout);
|
max@0
|
2673 arma_stop("solve(): use of LAPACK needs to be enabled");
|
max@0
|
2674 return false;
|
max@0
|
2675 }
|
max@0
|
2676 #endif
|
max@0
|
2677 }
|
max@0
|
2678
|
max@0
|
2679
|
max@0
|
2680
|
max@0
|
2681 //
|
max@0
|
2682 // Schur decomposition
|
max@0
|
2683
|
max@0
|
2684 template<typename eT>
|
max@0
|
2685 inline
|
max@0
|
2686 bool
|
max@0
|
2687 auxlib::schur_dec(Mat<eT>& Z, Mat<eT>& T, const Mat<eT>& A)
|
max@0
|
2688 {
|
max@0
|
2689 arma_extra_debug_sigprint();
|
max@0
|
2690
|
max@0
|
2691 #if defined(ARMA_USE_LAPACK)
|
max@0
|
2692 {
|
max@0
|
2693 arma_debug_check( (A.is_square() == false), "schur_dec(): given matrix is not square" );
|
max@0
|
2694
|
max@0
|
2695 if(A.is_empty())
|
max@0
|
2696 {
|
max@0
|
2697 Z.reset();
|
max@0
|
2698 T.reset();
|
max@0
|
2699 return true;
|
max@0
|
2700 }
|
max@0
|
2701
|
max@0
|
2702 const uword A_n_rows = A.n_rows;
|
max@0
|
2703
|
max@0
|
2704 char jobvs = 'V'; // get Schur vectors (Z)
|
max@0
|
2705 char sort = 'N'; // do not sort eigenvalues/vectors
|
max@0
|
2706 blas_int* select = 0; // pointer to sorting function
|
max@0
|
2707 blas_int n = blas_int(A_n_rows);
|
max@0
|
2708 blas_int sdim = 0; // output for sorting
|
max@0
|
2709
|
max@0
|
2710 blas_int lwork = 3 * n; // workspace must be at least 3 * n (if set to -1, optimal size is output in work(0) and nothing else is done
|
max@0
|
2711
|
max@0
|
2712 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
2713 podarray<blas_int> bwork(A_n_rows);
|
max@0
|
2714
|
max@0
|
2715 blas_int info = 0;
|
max@0
|
2716
|
max@0
|
2717 Z.set_size(A_n_rows, A_n_rows);
|
max@0
|
2718 T = A;
|
max@0
|
2719
|
max@0
|
2720 podarray<eT> wr(A_n_rows); // output for eigenvalues
|
max@0
|
2721 podarray<eT> wi(A_n_rows); // output for eigenvalues
|
max@0
|
2722
|
max@0
|
2723 lapack::gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, wr.memptr(), wi.memptr(), Z.memptr(), &n, work.memptr(), &lwork, bwork.memptr(), &info);
|
max@0
|
2724
|
max@0
|
2725 return (info == 0);
|
max@0
|
2726 }
|
max@0
|
2727 #else
|
max@0
|
2728 {
|
max@0
|
2729 arma_ignore(Z);
|
max@0
|
2730 arma_ignore(T);
|
max@0
|
2731 arma_stop("schur_dec(): use of LAPACK needs to be enabled");
|
max@0
|
2732 return false;
|
max@0
|
2733 }
|
max@0
|
2734 #endif
|
max@0
|
2735 }
|
max@0
|
2736
|
max@0
|
2737
|
max@0
|
2738
|
max@0
|
2739 template<typename cT>
|
max@0
|
2740 inline
|
max@0
|
2741 bool
|
max@0
|
2742 auxlib::schur_dec(Mat<std::complex<cT> >& Z, Mat<std::complex<cT> >& T, const Mat<std::complex<cT> >& A)
|
max@0
|
2743 {
|
max@0
|
2744 arma_extra_debug_sigprint();
|
max@0
|
2745
|
max@0
|
2746 #if defined(ARMA_USE_LAPACK)
|
max@0
|
2747 {
|
max@0
|
2748 arma_debug_check( (A.is_square() == false), "schur_dec(): matrix A is not square" );
|
max@0
|
2749
|
max@0
|
2750 if(A.is_empty())
|
max@0
|
2751 {
|
max@0
|
2752 Z.reset();
|
max@0
|
2753 T.reset();
|
max@0
|
2754 return true;
|
max@0
|
2755 }
|
max@0
|
2756
|
max@0
|
2757 typedef std::complex<cT> eT;
|
max@0
|
2758
|
max@0
|
2759 const uword A_n_rows = A.n_rows;
|
max@0
|
2760
|
max@0
|
2761 char jobvs = 'V'; // get Schur vectors (Z)
|
max@0
|
2762 char sort = 'N'; // do not sort eigenvalues/vectors
|
max@0
|
2763 blas_int* select = 0; // pointer to sorting function
|
max@0
|
2764 blas_int n = blas_int(A_n_rows);
|
max@0
|
2765 blas_int sdim = 0; // output for sorting
|
max@0
|
2766
|
max@0
|
2767 blas_int lwork = 3 * n; // workspace must be at least 3 * n (if set to -1, optimal size is output in work(0) and nothing else is done
|
max@0
|
2768
|
max@0
|
2769 podarray<eT> work( static_cast<uword>(lwork) );
|
max@0
|
2770 podarray<blas_int> bwork(A_n_rows);
|
max@0
|
2771
|
max@0
|
2772 blas_int info = 0;
|
max@0
|
2773
|
max@0
|
2774 Z.set_size(A_n_rows, A_n_rows);
|
max@0
|
2775 T = A;
|
max@0
|
2776
|
max@0
|
2777 podarray<eT> w(A_n_rows); // output for eigenvalues
|
max@0
|
2778 podarray<cT> rwork(A_n_rows);
|
max@0
|
2779
|
max@0
|
2780 lapack::cx_gees(&jobvs, &sort, select, &n, T.memptr(), &n, &sdim, w.memptr(), Z.memptr(), &n, work.memptr(), &lwork, rwork.memptr(), bwork.memptr(), &info);
|
max@0
|
2781
|
max@0
|
2782 return (info == 0);
|
max@0
|
2783 }
|
max@0
|
2784 #else
|
max@0
|
2785 {
|
max@0
|
2786 arma_ignore(Z);
|
max@0
|
2787 arma_ignore(T);
|
max@0
|
2788 arma_stop("schur_dec(): use of LAPACK needs to be enabled");
|
max@0
|
2789 return false;
|
max@0
|
2790 }
|
max@0
|
2791 #endif
|
max@0
|
2792 }
|
max@0
|
2793
|
max@0
|
2794
|
max@0
|
2795
|
max@0
|
2796 //
|
max@0
|
2797 // syl (solution of the Sylvester equation AX + XB = C)
|
max@0
|
2798
|
max@0
|
2799 template<typename eT>
|
max@0
|
2800 inline
|
max@0
|
2801 bool
|
max@0
|
2802 auxlib::syl(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& B, const Mat<eT>& C)
|
max@0
|
2803 {
|
max@0
|
2804 arma_extra_debug_sigprint();
|
max@0
|
2805
|
max@0
|
2806 arma_debug_check
|
max@0
|
2807 (
|
max@0
|
2808 (A.is_square() == false) || (B.is_square() == false),
|
max@0
|
2809 "syl(): given matrix is not square"
|
max@0
|
2810 );
|
max@0
|
2811
|
max@0
|
2812 arma_debug_check
|
max@0
|
2813 (
|
max@0
|
2814 (C.n_rows != A.n_rows) || (C.n_cols != B.n_cols),
|
max@0
|
2815 "syl(): matrices are not conformant"
|
max@0
|
2816 );
|
max@0
|
2817
|
max@0
|
2818 if(A.is_empty() || B.is_empty() || C.is_empty())
|
max@0
|
2819 {
|
max@0
|
2820 X.reset();
|
max@0
|
2821 return true;
|
max@0
|
2822 }
|
max@0
|
2823
|
max@0
|
2824 #if defined(ARMA_USE_LAPACK)
|
max@0
|
2825 {
|
max@0
|
2826 Mat<eT> Z1, Z2, T1, T2;
|
max@0
|
2827
|
max@0
|
2828 const bool status_sd1 = auxlib::schur_dec(Z1, T1, A);
|
max@0
|
2829 const bool status_sd2 = auxlib::schur_dec(Z2, T2, B);
|
max@0
|
2830
|
max@0
|
2831 if( (status_sd1 == false) || (status_sd2 == false) )
|
max@0
|
2832 {
|
max@0
|
2833 return false;
|
max@0
|
2834 }
|
max@0
|
2835
|
max@0
|
2836 char trana = 'N';
|
max@0
|
2837 char tranb = 'N';
|
max@0
|
2838 blas_int isgn = +1;
|
max@0
|
2839 blas_int m = blas_int(T1.n_rows);
|
max@0
|
2840 blas_int n = blas_int(T2.n_cols);
|
max@0
|
2841
|
max@0
|
2842 eT scale = eT(0);
|
max@0
|
2843 blas_int info = 0;
|
max@0
|
2844
|
max@0
|
2845 Mat<eT> Y = trans(Z1) * C * Z2;
|
max@0
|
2846
|
max@0
|
2847 lapack::trsyl<eT>(&trana, &tranb, &isgn, &m, &n, T1.memptr(), &m, T2.memptr(), &n, Y.memptr(), &m, &scale, &info);
|
max@0
|
2848
|
max@0
|
2849 //Y /= scale;
|
max@0
|
2850 Y /= (-scale);
|
max@0
|
2851
|
max@0
|
2852 X = Z1 * Y * trans(Z2);
|
max@0
|
2853
|
max@0
|
2854 return (info >= 0);
|
max@0
|
2855 }
|
max@0
|
2856 #else
|
max@0
|
2857 {
|
max@0
|
2858 arma_stop("syl(): use of LAPACK needs to be enabled");
|
max@0
|
2859 return false;
|
max@0
|
2860 }
|
max@0
|
2861 #endif
|
max@0
|
2862 }
|
max@0
|
2863
|
max@0
|
2864
|
max@0
|
2865
|
max@0
|
2866 //
|
max@0
|
2867 // lyap (solution of the continuous Lyapunov equation AX + XA^H + Q = 0)
|
max@0
|
2868
|
max@0
|
2869 template<typename eT>
|
max@0
|
2870 inline
|
max@0
|
2871 bool
|
max@0
|
2872 auxlib::lyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
|
max@0
|
2873 {
|
max@0
|
2874 arma_extra_debug_sigprint();
|
max@0
|
2875
|
max@0
|
2876 arma_debug_check( (A.is_square() == false), "lyap(): matrix A is not square");
|
max@0
|
2877 arma_debug_check( (Q.is_square() == false), "lyap(): matrix Q is not square");
|
max@0
|
2878 arma_debug_check( (A.n_rows != Q.n_rows), "lyap(): matrices A and Q have different dimensions");
|
max@0
|
2879
|
max@0
|
2880 Mat<eT> htransA;
|
max@0
|
2881 op_htrans::apply_noalias(htransA, A);
|
max@0
|
2882
|
max@0
|
2883 const Mat<eT> mQ = -Q;
|
max@0
|
2884
|
max@0
|
2885 return auxlib::syl(X, A, htransA, mQ);
|
max@0
|
2886 }
|
max@0
|
2887
|
max@0
|
2888
|
max@0
|
2889
|
max@0
|
2890 //
|
max@0
|
2891 // dlyap (solution of the discrete Lyapunov equation AXA^H - X + Q = 0)
|
max@0
|
2892
|
max@0
|
2893 template<typename eT>
|
max@0
|
2894 inline
|
max@0
|
2895 bool
|
max@0
|
2896 auxlib::dlyap(Mat<eT>& X, const Mat<eT>& A, const Mat<eT>& Q)
|
max@0
|
2897 {
|
max@0
|
2898 arma_extra_debug_sigprint();
|
max@0
|
2899
|
max@0
|
2900 arma_debug_check( (A.is_square() == false), "dlyap(): matrix A is not square");
|
max@0
|
2901 arma_debug_check( (Q.is_square() == false), "dlyap(): matrix Q is not square");
|
max@0
|
2902 arma_debug_check( (A.n_rows != Q.n_rows), "dlyap(): matrices A and Q have different dimensions");
|
max@0
|
2903
|
max@0
|
2904 const Col<eT> vecQ = reshape(Q, Q.n_elem, 1);
|
max@0
|
2905
|
max@0
|
2906 const Mat<eT> M = eye< Mat<eT> >(Q.n_elem, Q.n_elem) - kron(conj(A), A);
|
max@0
|
2907
|
max@0
|
2908 Col<eT> vecX;
|
max@0
|
2909
|
max@0
|
2910 const bool status = solve(vecX, M, vecQ);
|
max@0
|
2911
|
max@0
|
2912 if(status == true)
|
max@0
|
2913 {
|
max@0
|
2914 X = reshape(vecX, Q.n_rows, Q.n_cols);
|
max@0
|
2915 return true;
|
max@0
|
2916 }
|
max@0
|
2917 else
|
max@0
|
2918 {
|
max@0
|
2919 X.reset();
|
max@0
|
2920 return false;
|
max@0
|
2921 }
|
max@0
|
2922 }
|
max@0
|
2923
|
max@0
|
2924
|
max@0
|
2925
|
max@0
|
2926 //! @}
|