max@0
|
1 // Copyright (C) 2010-2011 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2010-2011 Conrad Sanderson
|
max@0
|
3 //
|
max@0
|
4 // This file is part of the Armadillo C++ library.
|
max@0
|
5 // It is provided without any warranty of fitness
|
max@0
|
6 // for any purpose. You can redistribute this file
|
max@0
|
7 // and/or modify it under the terms of the GNU
|
max@0
|
8 // Lesser General Public License (LGPL) as published
|
max@0
|
9 // by the Free Software Foundation, either version 3
|
max@0
|
10 // of the License or (at your option) any later version.
|
max@0
|
11 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
12
|
max@0
|
13
|
max@0
|
14 //! \addtogroup fn_as_scalar
|
max@0
|
15 //! @{
|
max@0
|
16
|
max@0
|
17
|
max@0
|
18
|
max@0
|
19 template<uword N>
|
max@0
|
20 struct as_scalar_redirect
|
max@0
|
21 {
|
max@0
|
22 template<typename T1>
|
max@0
|
23 inline static typename T1::elem_type apply(const T1& X);
|
max@0
|
24 };
|
max@0
|
25
|
max@0
|
26
|
max@0
|
27
|
max@0
|
28 template<>
|
max@0
|
29 struct as_scalar_redirect<2>
|
max@0
|
30 {
|
max@0
|
31 template<typename T1, typename T2>
|
max@0
|
32 inline static typename T1::elem_type apply(const Glue<T1,T2,glue_times>& X);
|
max@0
|
33 };
|
max@0
|
34
|
max@0
|
35
|
max@0
|
36 template<>
|
max@0
|
37 struct as_scalar_redirect<3>
|
max@0
|
38 {
|
max@0
|
39 template<typename T1, typename T2, typename T3>
|
max@0
|
40 inline static typename T1::elem_type apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times>& X);
|
max@0
|
41 };
|
max@0
|
42
|
max@0
|
43
|
max@0
|
44
|
max@0
|
45 template<uword N>
|
max@0
|
46 template<typename T1>
|
max@0
|
47 inline
|
max@0
|
48 typename T1::elem_type
|
max@0
|
49 as_scalar_redirect<N>::apply(const T1& X)
|
max@0
|
50 {
|
max@0
|
51 arma_extra_debug_sigprint();
|
max@0
|
52
|
max@0
|
53 typedef typename T1::elem_type eT;
|
max@0
|
54
|
max@0
|
55 const unwrap<T1> tmp(X);
|
max@0
|
56 const Mat<eT>& A = tmp.M;
|
max@0
|
57
|
max@0
|
58 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
|
max@0
|
59
|
max@0
|
60 return A.mem[0];
|
max@0
|
61 }
|
max@0
|
62
|
max@0
|
63
|
max@0
|
64
|
max@0
|
65 template<typename T1, typename T2>
|
max@0
|
66 inline
|
max@0
|
67 typename T1::elem_type
|
max@0
|
68 as_scalar_redirect<2>::apply(const Glue<T1, T2, glue_times>& X)
|
max@0
|
69 {
|
max@0
|
70 arma_extra_debug_sigprint();
|
max@0
|
71
|
max@0
|
72 typedef typename T1::elem_type eT;
|
max@0
|
73
|
max@0
|
74 // T1 must result in a matrix with one row
|
max@0
|
75 // T2 must result in a matrix with one column
|
max@0
|
76
|
max@0
|
77 const partial_unwrap<T1> tmp1(X.A);
|
max@0
|
78 const partial_unwrap<T2> tmp2(X.B);
|
max@0
|
79
|
max@0
|
80 const Mat<eT>& A = tmp1.M;
|
max@0
|
81 const Mat<eT>& B = tmp2.M;
|
max@0
|
82
|
max@0
|
83 const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
|
max@0
|
84 const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
|
max@0
|
85
|
max@0
|
86 const uword B_n_rows = (tmp2.do_trans == false) ? B.n_rows : B.n_cols;
|
max@0
|
87 const uword B_n_cols = (tmp2.do_trans == false) ? B.n_cols : B.n_rows;
|
max@0
|
88
|
max@0
|
89 const eT val = tmp1.get_val() * tmp2.get_val();
|
max@0
|
90
|
max@0
|
91 arma_debug_check( (A_n_rows != 1) || (B_n_cols != 1) || (A_n_cols != B_n_rows), "as_scalar(): incompatible dimensions" );
|
max@0
|
92
|
max@0
|
93 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem);
|
max@0
|
94 }
|
max@0
|
95
|
max@0
|
96
|
max@0
|
97
|
max@0
|
98 template<typename T1, typename T2, typename T3>
|
max@0
|
99 inline
|
max@0
|
100 typename T1::elem_type
|
max@0
|
101 as_scalar_redirect<3>::apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times >& X)
|
max@0
|
102 {
|
max@0
|
103 arma_extra_debug_sigprint();
|
max@0
|
104
|
max@0
|
105 typedef typename T1::elem_type eT;
|
max@0
|
106
|
max@0
|
107 // T1 * T2 must result in a matrix with one row
|
max@0
|
108 // T3 must result in a matrix with one column
|
max@0
|
109
|
max@0
|
110 typedef typename strip_inv <T2 >::stored_type T2_stripped_1;
|
max@0
|
111 typedef typename strip_diagmat<T2_stripped_1>::stored_type T2_stripped_2;
|
max@0
|
112
|
max@0
|
113 const strip_inv <T2> strip1(X.A.B);
|
max@0
|
114 const strip_diagmat<T2_stripped_1> strip2(strip1.M);
|
max@0
|
115
|
max@0
|
116 const bool tmp2_do_inv = strip1.do_inv;
|
max@0
|
117 const bool tmp2_do_diagmat = strip2.do_diagmat;
|
max@0
|
118
|
max@0
|
119 if(tmp2_do_diagmat == false)
|
max@0
|
120 {
|
max@0
|
121 const Mat<eT> tmp(X);
|
max@0
|
122
|
max@0
|
123 arma_debug_check( (tmp.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
|
max@0
|
124
|
max@0
|
125 return tmp[0];
|
max@0
|
126 }
|
max@0
|
127 else
|
max@0
|
128 {
|
max@0
|
129 const partial_unwrap<T1> tmp1(X.A.A);
|
max@0
|
130 const partial_unwrap<T2_stripped_2> tmp2(strip2.M);
|
max@0
|
131 const partial_unwrap<T3> tmp3(X.B);
|
max@0
|
132
|
max@0
|
133 const Mat<eT>& A = tmp1.M;
|
max@0
|
134 const Mat<eT>& B = tmp2.M;
|
max@0
|
135 const Mat<eT>& C = tmp3.M;
|
max@0
|
136
|
max@0
|
137 const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
|
max@0
|
138 const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
|
max@0
|
139
|
max@0
|
140 const bool B_is_vec = B.is_vec();
|
max@0
|
141
|
max@0
|
142 const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
|
max@0
|
143 const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
|
max@0
|
144
|
max@0
|
145 const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
|
max@0
|
146 const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
|
max@0
|
147
|
max@0
|
148 const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val();
|
max@0
|
149
|
max@0
|
150 arma_debug_check
|
max@0
|
151 (
|
max@0
|
152 (A_n_rows != 1) ||
|
max@0
|
153 (C_n_cols != 1) ||
|
max@0
|
154 (A_n_cols != B_n_rows) ||
|
max@0
|
155 (B_n_cols != C_n_rows)
|
max@0
|
156 ,
|
max@0
|
157 "as_scalar(): incompatible dimensions"
|
max@0
|
158 );
|
max@0
|
159
|
max@0
|
160
|
max@0
|
161 if(B_is_vec == true)
|
max@0
|
162 {
|
max@0
|
163 if(tmp2_do_inv == true)
|
max@0
|
164 {
|
max@0
|
165 return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem);
|
max@0
|
166 }
|
max@0
|
167 else
|
max@0
|
168 {
|
max@0
|
169 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
|
max@0
|
170 }
|
max@0
|
171 }
|
max@0
|
172 else
|
max@0
|
173 {
|
max@0
|
174 if(tmp2_do_inv == true)
|
max@0
|
175 {
|
max@0
|
176 return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem);
|
max@0
|
177 }
|
max@0
|
178 else
|
max@0
|
179 {
|
max@0
|
180 return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
|
max@0
|
181 }
|
max@0
|
182 }
|
max@0
|
183 }
|
max@0
|
184 }
|
max@0
|
185
|
max@0
|
186
|
max@0
|
187
|
max@0
|
188 template<typename T1>
|
max@0
|
189 inline
|
max@0
|
190 typename T1::elem_type
|
max@0
|
191 as_scalar_diag(const Base<typename T1::elem_type,T1>& X)
|
max@0
|
192 {
|
max@0
|
193 arma_extra_debug_sigprint();
|
max@0
|
194
|
max@0
|
195 typedef typename T1::elem_type eT;
|
max@0
|
196
|
max@0
|
197 const unwrap<T1> tmp(X.get_ref());
|
max@0
|
198 const Mat<eT>& A = tmp.M;
|
max@0
|
199
|
max@0
|
200 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
|
max@0
|
201
|
max@0
|
202 return A.mem[0];
|
max@0
|
203 }
|
max@0
|
204
|
max@0
|
205
|
max@0
|
206
|
max@0
|
207 template<typename T1, typename T2, typename T3>
|
max@0
|
208 inline
|
max@0
|
209 typename T1::elem_type
|
max@0
|
210 as_scalar_diag(const Glue< Glue<T1, T2, glue_times_diag>, T3, glue_times >& X)
|
max@0
|
211 {
|
max@0
|
212 arma_extra_debug_sigprint();
|
max@0
|
213
|
max@0
|
214 typedef typename T1::elem_type eT;
|
max@0
|
215
|
max@0
|
216 // T1 * T2 must result in a matrix with one row
|
max@0
|
217 // T3 must result in a matrix with one column
|
max@0
|
218
|
max@0
|
219 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
|
max@0
|
220
|
max@0
|
221 const strip_diagmat<T2> strip(X.A.B);
|
max@0
|
222
|
max@0
|
223 const partial_unwrap<T1> tmp1(X.A.A);
|
max@0
|
224 const partial_unwrap<T2_stripped> tmp2(strip.M);
|
max@0
|
225 const partial_unwrap<T3> tmp3(X.B);
|
max@0
|
226
|
max@0
|
227 const Mat<eT>& A = tmp1.M;
|
max@0
|
228 const Mat<eT>& B = tmp2.M;
|
max@0
|
229 const Mat<eT>& C = tmp3.M;
|
max@0
|
230
|
max@0
|
231
|
max@0
|
232 const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
|
max@0
|
233 const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
|
max@0
|
234
|
max@0
|
235 const bool B_is_vec = B.is_vec();
|
max@0
|
236
|
max@0
|
237 const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
|
max@0
|
238 const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
|
max@0
|
239
|
max@0
|
240 const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
|
max@0
|
241 const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
|
max@0
|
242
|
max@0
|
243 const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val();
|
max@0
|
244
|
max@0
|
245 arma_debug_check
|
max@0
|
246 (
|
max@0
|
247 (A_n_rows != 1) ||
|
max@0
|
248 (C_n_cols != 1) ||
|
max@0
|
249 (A_n_cols != B_n_rows) ||
|
max@0
|
250 (B_n_cols != C_n_rows)
|
max@0
|
251 ,
|
max@0
|
252 "as_scalar(): incompatible dimensions"
|
max@0
|
253 );
|
max@0
|
254
|
max@0
|
255
|
max@0
|
256 if(B_is_vec == true)
|
max@0
|
257 {
|
max@0
|
258 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
|
max@0
|
259 }
|
max@0
|
260 else
|
max@0
|
261 {
|
max@0
|
262 return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
|
max@0
|
263 }
|
max@0
|
264 }
|
max@0
|
265
|
max@0
|
266
|
max@0
|
267
|
max@0
|
268 template<typename T1, typename T2>
|
max@0
|
269 arma_inline
|
max@0
|
270 arma_warn_unused
|
max@0
|
271 typename T1::elem_type
|
max@0
|
272 as_scalar(const Glue<T1, T2, glue_times>& X, const typename arma_not_cx<typename T1::elem_type>::result* junk = 0)
|
max@0
|
273 {
|
max@0
|
274 arma_extra_debug_sigprint();
|
max@0
|
275 arma_ignore(junk);
|
max@0
|
276
|
max@0
|
277 if(is_glue_times_diag<T1>::value == false)
|
max@0
|
278 {
|
max@0
|
279 const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
|
max@0
|
280
|
max@0
|
281 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
|
max@0
|
282
|
max@0
|
283 return as_scalar_redirect<N_mat>::apply(X);
|
max@0
|
284 }
|
max@0
|
285 else
|
max@0
|
286 {
|
max@0
|
287 return as_scalar_diag(X);
|
max@0
|
288 }
|
max@0
|
289 }
|
max@0
|
290
|
max@0
|
291
|
max@0
|
292
|
max@0
|
293 template<typename T1>
|
max@0
|
294 inline
|
max@0
|
295 arma_warn_unused
|
max@0
|
296 typename T1::elem_type
|
max@0
|
297 as_scalar(const Base<typename T1::elem_type,T1>& X)
|
max@0
|
298 {
|
max@0
|
299 arma_extra_debug_sigprint();
|
max@0
|
300
|
max@0
|
301 typedef typename T1::elem_type eT;
|
max@0
|
302
|
max@0
|
303 const unwrap<T1> tmp(X.get_ref());
|
max@0
|
304 const Mat<eT>& A = tmp.M;
|
max@0
|
305
|
max@0
|
306 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
|
max@0
|
307
|
max@0
|
308 return A.mem[0];
|
max@0
|
309 }
|
max@0
|
310
|
max@0
|
311
|
max@0
|
312
|
max@0
|
313 template<typename T1>
|
max@0
|
314 arma_inline
|
max@0
|
315 arma_warn_unused
|
max@0
|
316 typename T1::elem_type
|
max@0
|
317 as_scalar(const eOp<T1, eop_neg>& X)
|
max@0
|
318 {
|
max@0
|
319 arma_extra_debug_sigprint();
|
max@0
|
320
|
max@0
|
321 return -(as_scalar(X.P.Q));
|
max@0
|
322 }
|
max@0
|
323
|
max@0
|
324
|
max@0
|
325
|
max@0
|
326 template<typename T1>
|
max@0
|
327 inline
|
max@0
|
328 arma_warn_unused
|
max@0
|
329 typename T1::elem_type
|
max@0
|
330 as_scalar(const BaseCube<typename T1::elem_type,T1>& X)
|
max@0
|
331 {
|
max@0
|
332 arma_extra_debug_sigprint();
|
max@0
|
333
|
max@0
|
334 typedef typename T1::elem_type eT;
|
max@0
|
335
|
max@0
|
336 const unwrap_cube<T1> tmp(X.get_ref());
|
max@0
|
337 const Cube<eT>& A = tmp.M;
|
max@0
|
338
|
max@0
|
339 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
|
max@0
|
340
|
max@0
|
341 return A.mem[0];
|
max@0
|
342 }
|
max@0
|
343
|
max@0
|
344
|
max@0
|
345
|
max@0
|
346 template<typename T>
|
max@0
|
347 arma_inline
|
max@0
|
348 arma_warn_unused
|
max@0
|
349 const typename arma_scalar_only<T>::result &
|
max@0
|
350 as_scalar(const T& x)
|
max@0
|
351 {
|
max@0
|
352 return x;
|
max@0
|
353 }
|
max@0
|
354
|
max@0
|
355
|
max@0
|
356
|
max@0
|
357 //! @}
|