max@0
|
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2008-2011 Conrad Sanderson
|
max@0
|
3 //
|
max@0
|
4 // This file is part of the Armadillo C++ library.
|
max@0
|
5 // It is provided without any warranty of fitness
|
max@0
|
6 // for any purpose. You can redistribute this file
|
max@0
|
7 // and/or modify it under the terms of the GNU
|
max@0
|
8 // Lesser General Public License (LGPL) as published
|
max@0
|
9 // by the Free Software Foundation, either version 3
|
max@0
|
10 // of the License or (at your option) any later version.
|
max@0
|
11 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
12
|
max@0
|
13
|
max@0
|
14 //! \addtogroup glue_times
|
max@0
|
15 //! @{
|
max@0
|
16
|
max@0
|
17
|
max@0
|
18
|
max@0
|
19 template<uword N>
|
max@0
|
20 template<typename T1, typename T2>
|
max@0
|
21 inline
|
max@0
|
22 void
|
max@0
|
23 glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
|
max@0
|
24 {
|
max@0
|
25 arma_extra_debug_sigprint();
|
max@0
|
26
|
max@0
|
27 typedef typename T1::elem_type eT;
|
max@0
|
28
|
max@0
|
29 const partial_unwrap_check<T1> tmp1(X.A, out);
|
max@0
|
30 const partial_unwrap_check<T2> tmp2(X.B, out);
|
max@0
|
31
|
max@0
|
32 const Mat<eT>& A = tmp1.M;
|
max@0
|
33 const Mat<eT>& B = tmp2.M;
|
max@0
|
34
|
max@0
|
35 const bool do_trans_A = tmp1.do_trans;
|
max@0
|
36 const bool do_trans_B = tmp2.do_trans;
|
max@0
|
37
|
max@0
|
38 const bool use_alpha = tmp1.do_times || tmp2.do_times;
|
max@0
|
39 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
|
max@0
|
40
|
max@0
|
41 glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
|
max@0
|
42 }
|
max@0
|
43
|
max@0
|
44
|
max@0
|
45
|
max@0
|
46 template<typename T1, typename T2, typename T3>
|
max@0
|
47 inline
|
max@0
|
48 void
|
max@0
|
49 glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
|
max@0
|
50 {
|
max@0
|
51 arma_extra_debug_sigprint();
|
max@0
|
52
|
max@0
|
53 typedef typename T1::elem_type eT;
|
max@0
|
54
|
max@0
|
55 // there is exactly 3 objects
|
max@0
|
56 // hence we can safely expand X as X.A.A, X.A.B and X.B
|
max@0
|
57
|
max@0
|
58 const partial_unwrap_check<T1> tmp1(X.A.A, out);
|
max@0
|
59 const partial_unwrap_check<T2> tmp2(X.A.B, out);
|
max@0
|
60 const partial_unwrap_check<T3> tmp3(X.B, out);
|
max@0
|
61
|
max@0
|
62 const Mat<eT>& A = tmp1.M;
|
max@0
|
63 const Mat<eT>& B = tmp2.M;
|
max@0
|
64 const Mat<eT>& C = tmp3.M;
|
max@0
|
65
|
max@0
|
66 const bool do_trans_A = tmp1.do_trans;
|
max@0
|
67 const bool do_trans_B = tmp2.do_trans;
|
max@0
|
68 const bool do_trans_C = tmp3.do_trans;
|
max@0
|
69
|
max@0
|
70 const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times;
|
max@0
|
71 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val()) : eT(0);
|
max@0
|
72
|
max@0
|
73 glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
|
max@0
|
74 }
|
max@0
|
75
|
max@0
|
76
|
max@0
|
77
|
max@0
|
78 template<typename T1, typename T2, typename T3, typename T4>
|
max@0
|
79 inline
|
max@0
|
80 void
|
max@0
|
81 glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X)
|
max@0
|
82 {
|
max@0
|
83 arma_extra_debug_sigprint();
|
max@0
|
84
|
max@0
|
85 typedef typename T1::elem_type eT;
|
max@0
|
86
|
max@0
|
87 // there is exactly 4 objects
|
max@0
|
88 // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B
|
max@0
|
89
|
max@0
|
90 const partial_unwrap_check<T1> tmp1(X.A.A.A, out);
|
max@0
|
91 const partial_unwrap_check<T2> tmp2(X.A.A.B, out);
|
max@0
|
92 const partial_unwrap_check<T3> tmp3(X.A.B, out);
|
max@0
|
93 const partial_unwrap_check<T4> tmp4(X.B, out);
|
max@0
|
94
|
max@0
|
95 const Mat<eT>& A = tmp1.M;
|
max@0
|
96 const Mat<eT>& B = tmp2.M;
|
max@0
|
97 const Mat<eT>& C = tmp3.M;
|
max@0
|
98 const Mat<eT>& D = tmp4.M;
|
max@0
|
99
|
max@0
|
100 const bool do_trans_A = tmp1.do_trans;
|
max@0
|
101 const bool do_trans_B = tmp2.do_trans;
|
max@0
|
102 const bool do_trans_C = tmp3.do_trans;
|
max@0
|
103 const bool do_trans_D = tmp4.do_trans;
|
max@0
|
104
|
max@0
|
105 const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times || tmp4.do_times;
|
max@0
|
106 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val() * tmp4.get_val()) : eT(0);
|
max@0
|
107
|
max@0
|
108 glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha);
|
max@0
|
109 }
|
max@0
|
110
|
max@0
|
111
|
max@0
|
112
|
max@0
|
113 template<typename T1, typename T2>
|
max@0
|
114 inline
|
max@0
|
115 void
|
max@0
|
116 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
|
max@0
|
117 {
|
max@0
|
118 arma_extra_debug_sigprint();
|
max@0
|
119
|
max@0
|
120 typedef typename T1::elem_type eT;
|
max@0
|
121
|
max@0
|
122 const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
|
max@0
|
123
|
max@0
|
124 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
|
max@0
|
125
|
max@0
|
126 glue_times_redirect<N_mat>::apply(out, X);
|
max@0
|
127 }
|
max@0
|
128
|
max@0
|
129
|
max@0
|
130
|
max@0
|
131 template<typename T1>
|
max@0
|
132 inline
|
max@0
|
133 void
|
max@0
|
134 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
|
max@0
|
135 {
|
max@0
|
136 arma_extra_debug_sigprint();
|
max@0
|
137
|
max@0
|
138 typedef typename T1::elem_type eT;
|
max@0
|
139
|
max@0
|
140 const unwrap_check<T1> tmp(X, out);
|
max@0
|
141 const Mat<eT>& B = tmp.M;
|
max@0
|
142
|
max@0
|
143 arma_debug_assert_mul_size(out, B, "matrix multiplication");
|
max@0
|
144
|
max@0
|
145 const uword out_n_rows = out.n_rows;
|
max@0
|
146 const uword out_n_cols = out.n_cols;
|
max@0
|
147
|
max@0
|
148 if(out_n_cols == B.n_cols)
|
max@0
|
149 {
|
max@0
|
150 // size of resulting matrix is the same as 'out'
|
max@0
|
151
|
max@0
|
152 podarray<eT> tmp(out_n_cols);
|
max@0
|
153
|
max@0
|
154 eT* tmp_rowdata = tmp.memptr();
|
max@0
|
155
|
max@0
|
156 for(uword row=0; row < out_n_rows; ++row)
|
max@0
|
157 {
|
max@0
|
158 tmp.copy_row(out, row);
|
max@0
|
159
|
max@0
|
160 for(uword col=0; col < out_n_cols; ++col)
|
max@0
|
161 {
|
max@0
|
162 out.at(row,col) = op_dot::direct_dot( out_n_cols, tmp_rowdata, B.colptr(col) );
|
max@0
|
163 }
|
max@0
|
164 }
|
max@0
|
165
|
max@0
|
166 }
|
max@0
|
167 else
|
max@0
|
168 {
|
max@0
|
169 const Mat<eT> tmp(out);
|
max@0
|
170 glue_times::apply(out, tmp, B, eT(1), false, false, false);
|
max@0
|
171 }
|
max@0
|
172
|
max@0
|
173 }
|
max@0
|
174
|
max@0
|
175
|
max@0
|
176
|
max@0
|
177 template<typename T1, typename T2>
|
max@0
|
178 arma_hot
|
max@0
|
179 inline
|
max@0
|
180 void
|
max@0
|
181 glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const sword sign)
|
max@0
|
182 {
|
max@0
|
183 arma_extra_debug_sigprint();
|
max@0
|
184
|
max@0
|
185 typedef typename T1::elem_type eT;
|
max@0
|
186
|
max@0
|
187 const partial_unwrap_check<T1> tmp1(X.A, out);
|
max@0
|
188 const partial_unwrap_check<T2> tmp2(X.B, out);
|
max@0
|
189
|
max@0
|
190 const Mat<eT>& A = tmp1.M;
|
max@0
|
191 const Mat<eT>& B = tmp2.M;
|
max@0
|
192 const eT alpha = tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) );
|
max@0
|
193
|
max@0
|
194 const bool do_trans_A = tmp1.do_trans;
|
max@0
|
195 const bool do_trans_B = tmp2.do_trans;
|
max@0
|
196 const bool use_alpha = tmp1.do_times || tmp2.do_times || (sign < sword(0));
|
max@0
|
197
|
max@0
|
198 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
|
max@0
|
199
|
max@0
|
200 const uword result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
|
max@0
|
201 const uword result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
|
max@0
|
202
|
max@0
|
203 arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "addition");
|
max@0
|
204
|
max@0
|
205 if(out.n_elem > 0)
|
max@0
|
206 {
|
max@0
|
207 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
|
max@0
|
208 {
|
max@0
|
209 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
210 {
|
max@0
|
211 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
max@0
|
212 }
|
max@0
|
213 else
|
max@0
|
214 if(B.n_cols == 1)
|
max@0
|
215 {
|
max@0
|
216 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
max@0
|
217 }
|
max@0
|
218 else
|
max@0
|
219 {
|
max@0
|
220 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1));
|
max@0
|
221 }
|
max@0
|
222 }
|
max@0
|
223 else
|
max@0
|
224 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
|
max@0
|
225 {
|
max@0
|
226 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
227 {
|
max@0
|
228 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
max@0
|
229 }
|
max@0
|
230 else
|
max@0
|
231 if(B.n_cols == 1)
|
max@0
|
232 {
|
max@0
|
233 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
max@0
|
234 }
|
max@0
|
235 else
|
max@0
|
236 {
|
max@0
|
237 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1));
|
max@0
|
238 }
|
max@0
|
239 }
|
max@0
|
240 else
|
max@0
|
241 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
|
max@0
|
242 {
|
max@0
|
243 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
|
max@0
|
244 {
|
max@0
|
245 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
max@0
|
246 }
|
max@0
|
247 else
|
max@0
|
248 if(B.n_cols == 1)
|
max@0
|
249 {
|
max@0
|
250 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
max@0
|
251 }
|
max@0
|
252 else
|
max@0
|
253 {
|
max@0
|
254 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1));
|
max@0
|
255 }
|
max@0
|
256 }
|
max@0
|
257 else
|
max@0
|
258 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
|
max@0
|
259 {
|
max@0
|
260 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
|
max@0
|
261 {
|
max@0
|
262 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
max@0
|
263 }
|
max@0
|
264 else
|
max@0
|
265 if(B.n_cols == 1)
|
max@0
|
266 {
|
max@0
|
267 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
max@0
|
268 }
|
max@0
|
269 else
|
max@0
|
270 {
|
max@0
|
271 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1));
|
max@0
|
272 }
|
max@0
|
273 }
|
max@0
|
274 else
|
max@0
|
275 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
|
max@0
|
276 {
|
max@0
|
277 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
278 {
|
max@0
|
279 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
max@0
|
280 }
|
max@0
|
281 else
|
max@0
|
282 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
283 {
|
max@0
|
284 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
max@0
|
285 }
|
max@0
|
286 else
|
max@0
|
287 {
|
max@0
|
288 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1));
|
max@0
|
289 }
|
max@0
|
290 }
|
max@0
|
291 else
|
max@0
|
292 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
|
max@0
|
293 {
|
max@0
|
294 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
295 {
|
max@0
|
296 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
max@0
|
297 }
|
max@0
|
298 else
|
max@0
|
299 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
300 {
|
max@0
|
301 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
max@0
|
302 }
|
max@0
|
303 else
|
max@0
|
304 {
|
max@0
|
305 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1));
|
max@0
|
306 }
|
max@0
|
307 }
|
max@0
|
308 else
|
max@0
|
309 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
|
max@0
|
310 {
|
max@0
|
311 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
|
max@0
|
312 {
|
max@0
|
313 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
max@0
|
314 }
|
max@0
|
315 else
|
max@0
|
316 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
317 {
|
max@0
|
318 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
max@0
|
319 }
|
max@0
|
320 else
|
max@0
|
321 {
|
max@0
|
322 gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1));
|
max@0
|
323 }
|
max@0
|
324 }
|
max@0
|
325 else
|
max@0
|
326 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
|
max@0
|
327 {
|
max@0
|
328 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
|
max@0
|
329 {
|
max@0
|
330 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
max@0
|
331 }
|
max@0
|
332 else
|
max@0
|
333 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
334 {
|
max@0
|
335 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
max@0
|
336 }
|
max@0
|
337 else
|
max@0
|
338 {
|
max@0
|
339 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1));
|
max@0
|
340 }
|
max@0
|
341 }
|
max@0
|
342 }
|
max@0
|
343
|
max@0
|
344
|
max@0
|
345 }
|
max@0
|
346
|
max@0
|
347
|
max@0
|
348
|
max@0
|
349 template<typename eT>
|
max@0
|
350 arma_inline
|
max@0
|
351 uword
|
max@0
|
352 glue_times::mul_storage_cost(const Mat<eT>& A, const Mat<eT>& B, const bool do_trans_A, const bool do_trans_B)
|
max@0
|
353 {
|
max@0
|
354 const uword final_A_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
|
max@0
|
355 const uword final_B_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
|
max@0
|
356
|
max@0
|
357 return final_A_n_rows * final_B_n_cols;
|
max@0
|
358 }
|
max@0
|
359
|
max@0
|
360
|
max@0
|
361
|
max@0
|
362 template<typename eT>
|
max@0
|
363 arma_hot
|
max@0
|
364 inline
|
max@0
|
365 void
|
max@0
|
366 glue_times::apply
|
max@0
|
367 (
|
max@0
|
368 Mat<eT>& out,
|
max@0
|
369 const Mat<eT>& A,
|
max@0
|
370 const Mat<eT>& B,
|
max@0
|
371 const eT alpha,
|
max@0
|
372 const bool do_trans_A,
|
max@0
|
373 const bool do_trans_B,
|
max@0
|
374 const bool use_alpha
|
max@0
|
375 )
|
max@0
|
376 {
|
max@0
|
377 arma_extra_debug_sigprint();
|
max@0
|
378
|
max@0
|
379 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
|
max@0
|
380
|
max@0
|
381 const uword final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
|
max@0
|
382 const uword final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
|
max@0
|
383
|
max@0
|
384 out.set_size(final_n_rows, final_n_cols);
|
max@0
|
385
|
max@0
|
386 if( (A.n_elem > 0) && (B.n_elem > 0) )
|
max@0
|
387 {
|
max@0
|
388 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
|
max@0
|
389 {
|
max@0
|
390 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
391 {
|
max@0
|
392 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
|
max@0
|
393 }
|
max@0
|
394 else
|
max@0
|
395 if(B.n_cols == 1)
|
max@0
|
396 {
|
max@0
|
397 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
|
max@0
|
398 }
|
max@0
|
399 else
|
max@0
|
400 {
|
max@0
|
401 gemm<false, false, false, false>::apply(out, A, B);
|
max@0
|
402 }
|
max@0
|
403 }
|
max@0
|
404 else
|
max@0
|
405 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
|
max@0
|
406 {
|
max@0
|
407 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
408 {
|
max@0
|
409 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
|
max@0
|
410 }
|
max@0
|
411 else
|
max@0
|
412 if(B.n_cols == 1)
|
max@0
|
413 {
|
max@0
|
414 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
|
max@0
|
415 }
|
max@0
|
416 else
|
max@0
|
417 {
|
max@0
|
418 gemm<false, false, true, false>::apply(out, A, B, alpha);
|
max@0
|
419 }
|
max@0
|
420 }
|
max@0
|
421 else
|
max@0
|
422 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
|
max@0
|
423 {
|
max@0
|
424 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
|
max@0
|
425 {
|
max@0
|
426 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
|
max@0
|
427 }
|
max@0
|
428 else
|
max@0
|
429 if(B.n_cols == 1)
|
max@0
|
430 {
|
max@0
|
431 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
|
max@0
|
432 }
|
max@0
|
433 else
|
max@0
|
434 {
|
max@0
|
435 gemm<true, false, false, false>::apply(out, A, B);
|
max@0
|
436 }
|
max@0
|
437 }
|
max@0
|
438 else
|
max@0
|
439 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
|
max@0
|
440 {
|
max@0
|
441 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
|
max@0
|
442 {
|
max@0
|
443 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
|
max@0
|
444 }
|
max@0
|
445 else
|
max@0
|
446 if(B.n_cols == 1)
|
max@0
|
447 {
|
max@0
|
448 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
|
max@0
|
449 }
|
max@0
|
450 else
|
max@0
|
451 {
|
max@0
|
452 gemm<true, false, true, false>::apply(out, A, B, alpha);
|
max@0
|
453 }
|
max@0
|
454 }
|
max@0
|
455 else
|
max@0
|
456 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
|
max@0
|
457 {
|
max@0
|
458 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
459 {
|
max@0
|
460 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
|
max@0
|
461 }
|
max@0
|
462 else
|
max@0
|
463 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
464 {
|
max@0
|
465 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
|
max@0
|
466 }
|
max@0
|
467 else
|
max@0
|
468 {
|
max@0
|
469 gemm<false, true, false, false>::apply(out, A, B);
|
max@0
|
470 }
|
max@0
|
471 }
|
max@0
|
472 else
|
max@0
|
473 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
|
max@0
|
474 {
|
max@0
|
475 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
476 {
|
max@0
|
477 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
|
max@0
|
478 }
|
max@0
|
479 else
|
max@0
|
480 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
481 {
|
max@0
|
482 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
|
max@0
|
483 }
|
max@0
|
484 else
|
max@0
|
485 {
|
max@0
|
486 gemm<false, true, true, false>::apply(out, A, B, alpha);
|
max@0
|
487 }
|
max@0
|
488 }
|
max@0
|
489 else
|
max@0
|
490 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
|
max@0
|
491 {
|
max@0
|
492 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
|
max@0
|
493 {
|
max@0
|
494 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
|
max@0
|
495 }
|
max@0
|
496 else
|
max@0
|
497 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
498 {
|
max@0
|
499 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
|
max@0
|
500 }
|
max@0
|
501 else
|
max@0
|
502 {
|
max@0
|
503 gemm<true, true, false, false>::apply(out, A, B);
|
max@0
|
504 }
|
max@0
|
505 }
|
max@0
|
506 else
|
max@0
|
507 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
|
max@0
|
508 {
|
max@0
|
509 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
|
max@0
|
510 {
|
max@0
|
511 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
|
max@0
|
512 }
|
max@0
|
513 else
|
max@0
|
514 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
|
max@0
|
515 {
|
max@0
|
516 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
|
max@0
|
517 }
|
max@0
|
518 else
|
max@0
|
519 {
|
max@0
|
520 gemm<true, true, true, false>::apply(out, A, B, alpha);
|
max@0
|
521 }
|
max@0
|
522 }
|
max@0
|
523 }
|
max@0
|
524 else
|
max@0
|
525 {
|
max@0
|
526 out.zeros();
|
max@0
|
527 }
|
max@0
|
528 }
|
max@0
|
529
|
max@0
|
530
|
max@0
|
531
|
max@0
|
532 template<typename eT>
|
max@0
|
533 inline
|
max@0
|
534 void
|
max@0
|
535 glue_times::apply
|
max@0
|
536 (
|
max@0
|
537 Mat<eT>& out,
|
max@0
|
538 const Mat<eT>& A,
|
max@0
|
539 const Mat<eT>& B,
|
max@0
|
540 const Mat<eT>& C,
|
max@0
|
541 const eT alpha,
|
max@0
|
542 const bool do_trans_A,
|
max@0
|
543 const bool do_trans_B,
|
max@0
|
544 const bool do_trans_C,
|
max@0
|
545 const bool use_alpha
|
max@0
|
546 )
|
max@0
|
547 {
|
max@0
|
548 arma_extra_debug_sigprint();
|
max@0
|
549
|
max@0
|
550 Mat<eT> tmp;
|
max@0
|
551
|
max@0
|
552 if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) )
|
max@0
|
553 {
|
max@0
|
554 // out = (A*B)*C
|
max@0
|
555 glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
|
max@0
|
556 glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false );
|
max@0
|
557 }
|
max@0
|
558 else
|
max@0
|
559 {
|
max@0
|
560 // out = A*(B*C)
|
max@0
|
561 glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha);
|
max@0
|
562 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false );
|
max@0
|
563 }
|
max@0
|
564 }
|
max@0
|
565
|
max@0
|
566
|
max@0
|
567
|
max@0
|
568 template<typename eT>
|
max@0
|
569 inline
|
max@0
|
570 void
|
max@0
|
571 glue_times::apply
|
max@0
|
572 (
|
max@0
|
573 Mat<eT>& out,
|
max@0
|
574 const Mat<eT>& A,
|
max@0
|
575 const Mat<eT>& B,
|
max@0
|
576 const Mat<eT>& C,
|
max@0
|
577 const Mat<eT>& D,
|
max@0
|
578 const eT alpha,
|
max@0
|
579 const bool do_trans_A,
|
max@0
|
580 const bool do_trans_B,
|
max@0
|
581 const bool do_trans_C,
|
max@0
|
582 const bool do_trans_D,
|
max@0
|
583 const bool use_alpha
|
max@0
|
584 )
|
max@0
|
585 {
|
max@0
|
586 arma_extra_debug_sigprint();
|
max@0
|
587
|
max@0
|
588 Mat<eT> tmp;
|
max@0
|
589
|
max@0
|
590 if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) )
|
max@0
|
591 {
|
max@0
|
592 // out = (A*B*C)*D
|
max@0
|
593 glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
|
max@0
|
594
|
max@0
|
595 glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false);
|
max@0
|
596 }
|
max@0
|
597 else
|
max@0
|
598 {
|
max@0
|
599 // out = A*(B*C*D)
|
max@0
|
600 glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha);
|
max@0
|
601
|
max@0
|
602 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false);
|
max@0
|
603 }
|
max@0
|
604 }
|
max@0
|
605
|
max@0
|
606
|
max@0
|
607
|
max@0
|
608 //
|
max@0
|
609 // glue_times_diag
|
max@0
|
610
|
max@0
|
611
|
max@0
|
612 template<typename T1, typename T2>
|
max@0
|
613 arma_hot
|
max@0
|
614 inline
|
max@0
|
615 void
|
max@0
|
616 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X)
|
max@0
|
617 {
|
max@0
|
618 arma_extra_debug_sigprint();
|
max@0
|
619
|
max@0
|
620 typedef typename T1::elem_type eT;
|
max@0
|
621
|
max@0
|
622 const strip_diagmat<T1> S1(X.A);
|
max@0
|
623 const strip_diagmat<T2> S2(X.B);
|
max@0
|
624
|
max@0
|
625 typedef typename strip_diagmat<T1>::stored_type T1_stripped;
|
max@0
|
626 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
|
max@0
|
627
|
max@0
|
628 if( (S1.do_diagmat == true) && (S2.do_diagmat == false) )
|
max@0
|
629 {
|
max@0
|
630 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
|
max@0
|
631
|
max@0
|
632 const unwrap_check<T2> tmp(X.B, out);
|
max@0
|
633 const Mat<eT>& B = tmp.M;
|
max@0
|
634
|
max@0
|
635 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiplication");
|
max@0
|
636
|
max@0
|
637 out.set_size(A.n_elem, B.n_cols);
|
max@0
|
638
|
max@0
|
639 for(uword col=0; col<B.n_cols; ++col)
|
max@0
|
640 {
|
max@0
|
641 eT* out_coldata = out.colptr(col);
|
max@0
|
642 const eT* B_coldata = B.colptr(col);
|
max@0
|
643
|
max@0
|
644 for(uword row=0; row<B.n_rows; ++row)
|
max@0
|
645 {
|
max@0
|
646 out_coldata[row] = A[row] * B_coldata[row];
|
max@0
|
647 }
|
max@0
|
648 }
|
max@0
|
649 }
|
max@0
|
650 else
|
max@0
|
651 if( (S1.do_diagmat == false) && (S2.do_diagmat == true) )
|
max@0
|
652 {
|
max@0
|
653 const unwrap_check<T1> tmp(X.A, out);
|
max@0
|
654 const Mat<eT>& A = tmp.M;
|
max@0
|
655
|
max@0
|
656 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
|
max@0
|
657
|
max@0
|
658 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiplication");
|
max@0
|
659
|
max@0
|
660 out.set_size(A.n_rows, B.n_elem);
|
max@0
|
661
|
max@0
|
662 for(uword col=0; col<A.n_cols; ++col)
|
max@0
|
663 {
|
max@0
|
664 const eT val = B[col];
|
max@0
|
665
|
max@0
|
666 eT* out_coldata = out.colptr(col);
|
max@0
|
667 const eT* A_coldata = A.colptr(col);
|
max@0
|
668
|
max@0
|
669 for(uword row=0; row<A.n_rows; ++row)
|
max@0
|
670 {
|
max@0
|
671 out_coldata[row] = A_coldata[row] * val;
|
max@0
|
672 }
|
max@0
|
673 }
|
max@0
|
674 }
|
max@0
|
675 else
|
max@0
|
676 if( (S1.do_diagmat == true) && (S2.do_diagmat == true) )
|
max@0
|
677 {
|
max@0
|
678 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
|
max@0
|
679 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
|
max@0
|
680
|
max@0
|
681 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiplication");
|
max@0
|
682
|
max@0
|
683 out.zeros(A.n_elem, A.n_elem);
|
max@0
|
684
|
max@0
|
685 for(uword i=0; i<A.n_elem; ++i)
|
max@0
|
686 {
|
max@0
|
687 out.at(i,i) = A[i] * B[i];
|
max@0
|
688 }
|
max@0
|
689 }
|
max@0
|
690 }
|
max@0
|
691
|
max@0
|
692
|
max@0
|
693
|
max@0
|
694 //! @}
|