max@0
|
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2008-2011 Conrad Sanderson
|
max@0
|
3 //
|
max@0
|
4 // This file is part of the Armadillo C++ library.
|
max@0
|
5 // It is provided without any warranty of fitness
|
max@0
|
6 // for any purpose. You can redistribute this file
|
max@0
|
7 // and/or modify it under the terms of the GNU
|
max@0
|
8 // Lesser General Public License (LGPL) as published
|
max@0
|
9 // by the Free Software Foundation, either version 3
|
max@0
|
10 // of the License or (at your option) any later version.
|
max@0
|
11 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
12
|
max@0
|
13
|
max@0
|
14 //! \addtogroup gemv
|
max@0
|
15 //! @{
|
max@0
|
16
|
max@0
|
17
|
max@0
|
18
|
max@0
|
19 //! for tiny square matrices, size <= 4x4
|
max@0
|
20 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
|
max@0
|
21 class gemv_emul_tinysq
|
max@0
|
22 {
|
max@0
|
23 public:
|
max@0
|
24
|
max@0
|
25
|
max@0
|
26 template<const uword row, const uword col>
|
max@0
|
27 struct pos
|
max@0
|
28 {
|
max@0
|
29 static const uword n2 = (do_trans_A == false) ? (row + col*2) : (col + row*2);
|
max@0
|
30 static const uword n3 = (do_trans_A == false) ? (row + col*3) : (col + row*3);
|
max@0
|
31 static const uword n4 = (do_trans_A == false) ? (row + col*4) : (col + row*4);
|
max@0
|
32 };
|
max@0
|
33
|
max@0
|
34
|
max@0
|
35
|
max@0
|
36 template<typename eT, const uword i>
|
max@0
|
37 arma_hot
|
max@0
|
38 arma_inline
|
max@0
|
39 static
|
max@0
|
40 void
|
max@0
|
41 assign(eT* y, const eT acc, const eT alpha, const eT beta)
|
max@0
|
42 {
|
max@0
|
43 if(use_beta == false)
|
max@0
|
44 {
|
max@0
|
45 y[i] = (use_alpha == false) ? acc : alpha*acc;
|
max@0
|
46 }
|
max@0
|
47 else
|
max@0
|
48 {
|
max@0
|
49 const eT tmp = y[i];
|
max@0
|
50
|
max@0
|
51 y[i] = beta*tmp + ( (use_alpha == false) ? acc : alpha*acc );
|
max@0
|
52 }
|
max@0
|
53 }
|
max@0
|
54
|
max@0
|
55
|
max@0
|
56
|
max@0
|
57 template<typename eT>
|
max@0
|
58 arma_hot
|
max@0
|
59 inline
|
max@0
|
60 static
|
max@0
|
61 void
|
max@0
|
62 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
|
max@0
|
63 {
|
max@0
|
64 arma_extra_debug_sigprint();
|
max@0
|
65
|
max@0
|
66 const eT* Am = A.memptr();
|
max@0
|
67
|
max@0
|
68 switch(A.n_rows)
|
max@0
|
69 {
|
max@0
|
70 case 1:
|
max@0
|
71 {
|
max@0
|
72 const eT acc = Am[0] * x[0];
|
max@0
|
73
|
max@0
|
74 assign<eT, 0>(y, acc, alpha, beta);
|
max@0
|
75 }
|
max@0
|
76 break;
|
max@0
|
77
|
max@0
|
78
|
max@0
|
79 case 2:
|
max@0
|
80 {
|
max@0
|
81 const eT x0 = x[0];
|
max@0
|
82 const eT x1 = x[1];
|
max@0
|
83
|
max@0
|
84 const eT acc0 = Am[pos<0,0>::n2]*x0 + Am[pos<0,1>::n2]*x1;
|
max@0
|
85 const eT acc1 = Am[pos<1,0>::n2]*x0 + Am[pos<1,1>::n2]*x1;
|
max@0
|
86
|
max@0
|
87 assign<eT, 0>(y, acc0, alpha, beta);
|
max@0
|
88 assign<eT, 1>(y, acc1, alpha, beta);
|
max@0
|
89 }
|
max@0
|
90 break;
|
max@0
|
91
|
max@0
|
92
|
max@0
|
93 case 3:
|
max@0
|
94 {
|
max@0
|
95 const eT x0 = x[0];
|
max@0
|
96 const eT x1 = x[1];
|
max@0
|
97 const eT x2 = x[2];
|
max@0
|
98
|
max@0
|
99 const eT acc0 = Am[pos<0,0>::n3]*x0 + Am[pos<0,1>::n3]*x1 + Am[pos<0,2>::n3]*x2;
|
max@0
|
100 const eT acc1 = Am[pos<1,0>::n3]*x0 + Am[pos<1,1>::n3]*x1 + Am[pos<1,2>::n3]*x2;
|
max@0
|
101 const eT acc2 = Am[pos<2,0>::n3]*x0 + Am[pos<2,1>::n3]*x1 + Am[pos<2,2>::n3]*x2;
|
max@0
|
102
|
max@0
|
103 assign<eT, 0>(y, acc0, alpha, beta);
|
max@0
|
104 assign<eT, 1>(y, acc1, alpha, beta);
|
max@0
|
105 assign<eT, 2>(y, acc2, alpha, beta);
|
max@0
|
106 }
|
max@0
|
107 break;
|
max@0
|
108
|
max@0
|
109
|
max@0
|
110 case 4:
|
max@0
|
111 {
|
max@0
|
112 const eT x0 = x[0];
|
max@0
|
113 const eT x1 = x[1];
|
max@0
|
114 const eT x2 = x[2];
|
max@0
|
115 const eT x3 = x[3];
|
max@0
|
116
|
max@0
|
117 const eT acc0 = Am[pos<0,0>::n4]*x0 + Am[pos<0,1>::n4]*x1 + Am[pos<0,2>::n4]*x2 + Am[pos<0,3>::n4]*x3;
|
max@0
|
118 const eT acc1 = Am[pos<1,0>::n4]*x0 + Am[pos<1,1>::n4]*x1 + Am[pos<1,2>::n4]*x2 + Am[pos<1,3>::n4]*x3;
|
max@0
|
119 const eT acc2 = Am[pos<2,0>::n4]*x0 + Am[pos<2,1>::n4]*x1 + Am[pos<2,2>::n4]*x2 + Am[pos<2,3>::n4]*x3;
|
max@0
|
120 const eT acc3 = Am[pos<3,0>::n4]*x0 + Am[pos<3,1>::n4]*x1 + Am[pos<3,2>::n4]*x2 + Am[pos<3,3>::n4]*x3;
|
max@0
|
121
|
max@0
|
122 assign<eT, 0>(y, acc0, alpha, beta);
|
max@0
|
123 assign<eT, 1>(y, acc1, alpha, beta);
|
max@0
|
124 assign<eT, 2>(y, acc2, alpha, beta);
|
max@0
|
125 assign<eT, 3>(y, acc3, alpha, beta);
|
max@0
|
126 }
|
max@0
|
127 break;
|
max@0
|
128
|
max@0
|
129
|
max@0
|
130 default:
|
max@0
|
131 ;
|
max@0
|
132 }
|
max@0
|
133 }
|
max@0
|
134
|
max@0
|
135 };
|
max@0
|
136
|
max@0
|
137
|
max@0
|
138
|
max@0
|
139 //! \brief
|
max@0
|
140 //! Partial emulation of ATLAS/BLAS gemv().
|
max@0
|
141 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose)
|
max@0
|
142
|
max@0
|
143 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
|
max@0
|
144 class gemv_emul_large
|
max@0
|
145 {
|
max@0
|
146 public:
|
max@0
|
147
|
max@0
|
148 template<typename eT>
|
max@0
|
149 arma_hot
|
max@0
|
150 inline
|
max@0
|
151 static
|
max@0
|
152 void
|
max@0
|
153 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
|
max@0
|
154 {
|
max@0
|
155 arma_extra_debug_sigprint();
|
max@0
|
156
|
max@0
|
157 const uword A_n_rows = A.n_rows;
|
max@0
|
158 const uword A_n_cols = A.n_cols;
|
max@0
|
159
|
max@0
|
160 if(do_trans_A == false)
|
max@0
|
161 {
|
max@0
|
162 if(A_n_rows == 1)
|
max@0
|
163 {
|
max@0
|
164 const eT acc = op_dot::direct_dot_arma(A_n_cols, A.mem, x);
|
max@0
|
165
|
max@0
|
166 if( (use_alpha == false) && (use_beta == false) )
|
max@0
|
167 {
|
max@0
|
168 y[0] = acc;
|
max@0
|
169 }
|
max@0
|
170 else
|
max@0
|
171 if( (use_alpha == true) && (use_beta == false) )
|
max@0
|
172 {
|
max@0
|
173 y[0] = alpha * acc;
|
max@0
|
174 }
|
max@0
|
175 else
|
max@0
|
176 if( (use_alpha == false) && (use_beta == true) )
|
max@0
|
177 {
|
max@0
|
178 y[0] = acc + beta*y[0];
|
max@0
|
179 }
|
max@0
|
180 else
|
max@0
|
181 if( (use_alpha == true) && (use_beta == true) )
|
max@0
|
182 {
|
max@0
|
183 y[0] = alpha*acc + beta*y[0];
|
max@0
|
184 }
|
max@0
|
185 }
|
max@0
|
186 else
|
max@0
|
187 for(uword row=0; row < A_n_rows; ++row)
|
max@0
|
188 {
|
max@0
|
189 eT acc = eT(0);
|
max@0
|
190
|
max@0
|
191 for(uword i=0; i < A_n_cols; ++i)
|
max@0
|
192 {
|
max@0
|
193 acc += A.at(row,i) * x[i];
|
max@0
|
194 }
|
max@0
|
195
|
max@0
|
196 if( (use_alpha == false) && (use_beta == false) )
|
max@0
|
197 {
|
max@0
|
198 y[row] = acc;
|
max@0
|
199 }
|
max@0
|
200 else
|
max@0
|
201 if( (use_alpha == true) && (use_beta == false) )
|
max@0
|
202 {
|
max@0
|
203 y[row] = alpha * acc;
|
max@0
|
204 }
|
max@0
|
205 else
|
max@0
|
206 if( (use_alpha == false) && (use_beta == true) )
|
max@0
|
207 {
|
max@0
|
208 y[row] = acc + beta*y[row];
|
max@0
|
209 }
|
max@0
|
210 else
|
max@0
|
211 if( (use_alpha == true) && (use_beta == true) )
|
max@0
|
212 {
|
max@0
|
213 y[row] = alpha*acc + beta*y[row];
|
max@0
|
214 }
|
max@0
|
215 }
|
max@0
|
216 }
|
max@0
|
217 else
|
max@0
|
218 if(do_trans_A == true)
|
max@0
|
219 {
|
max@0
|
220 for(uword col=0; col < A_n_cols; ++col)
|
max@0
|
221 {
|
max@0
|
222 // col is interpreted as row when storing the results in 'y'
|
max@0
|
223
|
max@0
|
224
|
max@0
|
225 // const eT* A_coldata = A.colptr(col);
|
max@0
|
226 //
|
max@0
|
227 // eT acc = eT(0);
|
max@0
|
228 // for(uword row=0; row < A_n_rows; ++row)
|
max@0
|
229 // {
|
max@0
|
230 // acc += A_coldata[row] * x[row];
|
max@0
|
231 // }
|
max@0
|
232
|
max@0
|
233 const eT acc = op_dot::direct_dot_arma(A_n_rows, A.colptr(col), x);
|
max@0
|
234
|
max@0
|
235 if( (use_alpha == false) && (use_beta == false) )
|
max@0
|
236 {
|
max@0
|
237 y[col] = acc;
|
max@0
|
238 }
|
max@0
|
239 else
|
max@0
|
240 if( (use_alpha == true) && (use_beta == false) )
|
max@0
|
241 {
|
max@0
|
242 y[col] = alpha * acc;
|
max@0
|
243 }
|
max@0
|
244 else
|
max@0
|
245 if( (use_alpha == false) && (use_beta == true) )
|
max@0
|
246 {
|
max@0
|
247 y[col] = acc + beta*y[col];
|
max@0
|
248 }
|
max@0
|
249 else
|
max@0
|
250 if( (use_alpha == true) && (use_beta == true) )
|
max@0
|
251 {
|
max@0
|
252 y[col] = alpha*acc + beta*y[col];
|
max@0
|
253 }
|
max@0
|
254
|
max@0
|
255 }
|
max@0
|
256 }
|
max@0
|
257 }
|
max@0
|
258
|
max@0
|
259 };
|
max@0
|
260
|
max@0
|
261
|
max@0
|
262
|
max@0
|
263 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
|
max@0
|
264 class gemv_emul
|
max@0
|
265 {
|
max@0
|
266 public:
|
max@0
|
267
|
max@0
|
268 template<typename eT>
|
max@0
|
269 arma_hot
|
max@0
|
270 inline
|
max@0
|
271 static
|
max@0
|
272 void
|
max@0
|
273 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx<eT>::result* junk = 0 )
|
max@0
|
274 {
|
max@0
|
275 arma_extra_debug_sigprint();
|
max@0
|
276 arma_ignore(junk);
|
max@0
|
277
|
max@0
|
278 const uword A_n_rows = A.n_rows;
|
max@0
|
279 const uword A_n_cols = A.n_cols;
|
max@0
|
280
|
max@0
|
281 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) )
|
max@0
|
282 {
|
max@0
|
283 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta);
|
max@0
|
284 }
|
max@0
|
285 else
|
max@0
|
286 {
|
max@0
|
287 gemv_emul_large<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta);
|
max@0
|
288 }
|
max@0
|
289 }
|
max@0
|
290
|
max@0
|
291
|
max@0
|
292
|
max@0
|
293 template<typename eT>
|
max@0
|
294 arma_hot
|
max@0
|
295 inline
|
max@0
|
296 static
|
max@0
|
297 void
|
max@0
|
298 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_cx_only<eT>::result* junk = 0 )
|
max@0
|
299 {
|
max@0
|
300 arma_extra_debug_sigprint();
|
max@0
|
301
|
max@0
|
302 Mat<eT> tmp_A;
|
max@0
|
303
|
max@0
|
304 if(do_trans_A)
|
max@0
|
305 {
|
max@0
|
306 op_htrans::apply_noalias(tmp_A, A);
|
max@0
|
307 }
|
max@0
|
308
|
max@0
|
309 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
|
max@0
|
310
|
max@0
|
311 const uword AA_n_rows = AA.n_rows;
|
max@0
|
312 const uword AA_n_cols = AA.n_cols;
|
max@0
|
313
|
max@0
|
314 if( (AA_n_rows <= 4) && (AA_n_rows == AA_n_cols) )
|
max@0
|
315 {
|
max@0
|
316 gemv_emul_tinysq<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta);
|
max@0
|
317 }
|
max@0
|
318 else
|
max@0
|
319 {
|
max@0
|
320 gemv_emul_large<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta);
|
max@0
|
321 }
|
max@0
|
322 }
|
max@0
|
323 };
|
max@0
|
324
|
max@0
|
325
|
max@0
|
326
|
max@0
|
327 //! \brief
|
max@0
|
328 //! Wrapper for ATLAS/BLAS gemv function, using template arguments to control the arguments passed to gemv.
|
max@0
|
329 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose)
|
max@0
|
330
|
max@0
|
331 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
|
max@0
|
332 class gemv
|
max@0
|
333 {
|
max@0
|
334 public:
|
max@0
|
335
|
max@0
|
336 template<typename eT>
|
max@0
|
337 inline
|
max@0
|
338 static
|
max@0
|
339 void
|
max@0
|
340 apply_blas_type( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
|
max@0
|
341 {
|
max@0
|
342 arma_extra_debug_sigprint();
|
max@0
|
343
|
max@0
|
344 if(A.n_elem <= 64u)
|
max@0
|
345 {
|
max@0
|
346 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
|
max@0
|
347 }
|
max@0
|
348 else
|
max@0
|
349 {
|
max@0
|
350 #if defined(ARMA_USE_ATLAS)
|
max@0
|
351 {
|
max@0
|
352 arma_extra_debug_print("atlas::cblas_gemv()");
|
max@0
|
353
|
max@0
|
354 atlas::cblas_gemv<eT>
|
max@0
|
355 (
|
max@0
|
356 atlas::CblasColMajor,
|
max@0
|
357 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
|
max@0
|
358 A.n_rows,
|
max@0
|
359 A.n_cols,
|
max@0
|
360 (use_alpha) ? alpha : eT(1),
|
max@0
|
361 A.mem,
|
max@0
|
362 A.n_rows,
|
max@0
|
363 x,
|
max@0
|
364 1,
|
max@0
|
365 (use_beta) ? beta : eT(0),
|
max@0
|
366 y,
|
max@0
|
367 1
|
max@0
|
368 );
|
max@0
|
369 }
|
max@0
|
370 #elif defined(ARMA_USE_BLAS)
|
max@0
|
371 {
|
max@0
|
372 arma_extra_debug_print("blas::gemv()");
|
max@0
|
373
|
max@0
|
374 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
|
max@0
|
375 const blas_int m = A.n_rows;
|
max@0
|
376 const blas_int n = A.n_cols;
|
max@0
|
377 const eT local_alpha = (use_alpha) ? alpha : eT(1);
|
max@0
|
378 //const blas_int lda = A.n_rows;
|
max@0
|
379 const blas_int inc = 1;
|
max@0
|
380 const eT local_beta = (use_beta) ? beta : eT(0);
|
max@0
|
381
|
max@0
|
382 arma_extra_debug_print( arma_boost::format("blas::gemv(): trans_A = %c") % trans_A );
|
max@0
|
383
|
max@0
|
384 blas::gemv<eT>
|
max@0
|
385 (
|
max@0
|
386 &trans_A,
|
max@0
|
387 &m,
|
max@0
|
388 &n,
|
max@0
|
389 &local_alpha,
|
max@0
|
390 A.mem,
|
max@0
|
391 &m, // lda
|
max@0
|
392 x,
|
max@0
|
393 &inc,
|
max@0
|
394 &local_beta,
|
max@0
|
395 y,
|
max@0
|
396 &inc
|
max@0
|
397 );
|
max@0
|
398 }
|
max@0
|
399 #else
|
max@0
|
400 {
|
max@0
|
401 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
|
max@0
|
402 }
|
max@0
|
403 #endif
|
max@0
|
404 }
|
max@0
|
405
|
max@0
|
406 }
|
max@0
|
407
|
max@0
|
408
|
max@0
|
409
|
max@0
|
410 template<typename eT>
|
max@0
|
411 arma_inline
|
max@0
|
412 static
|
max@0
|
413 void
|
max@0
|
414 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
|
max@0
|
415 {
|
max@0
|
416 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
|
max@0
|
417 }
|
max@0
|
418
|
max@0
|
419
|
max@0
|
420
|
max@0
|
421 arma_inline
|
max@0
|
422 static
|
max@0
|
423 void
|
max@0
|
424 apply
|
max@0
|
425 (
|
max@0
|
426 float* y,
|
max@0
|
427 const Mat<float>& A,
|
max@0
|
428 const float* x,
|
max@0
|
429 const float alpha = float(1),
|
max@0
|
430 const float beta = float(0)
|
max@0
|
431 )
|
max@0
|
432 {
|
max@0
|
433 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
|
max@0
|
434 }
|
max@0
|
435
|
max@0
|
436
|
max@0
|
437
|
max@0
|
438 arma_inline
|
max@0
|
439 static
|
max@0
|
440 void
|
max@0
|
441 apply
|
max@0
|
442 (
|
max@0
|
443 double* y,
|
max@0
|
444 const Mat<double>& A,
|
max@0
|
445 const double* x,
|
max@0
|
446 const double alpha = double(1),
|
max@0
|
447 const double beta = double(0)
|
max@0
|
448 )
|
max@0
|
449 {
|
max@0
|
450 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
|
max@0
|
451 }
|
max@0
|
452
|
max@0
|
453
|
max@0
|
454
|
max@0
|
455 arma_inline
|
max@0
|
456 static
|
max@0
|
457 void
|
max@0
|
458 apply
|
max@0
|
459 (
|
max@0
|
460 std::complex<float>* y,
|
max@0
|
461 const Mat< std::complex<float > >& A,
|
max@0
|
462 const std::complex<float>* x,
|
max@0
|
463 const std::complex<float> alpha = std::complex<float>(1),
|
max@0
|
464 const std::complex<float> beta = std::complex<float>(0)
|
max@0
|
465 )
|
max@0
|
466 {
|
max@0
|
467 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
|
max@0
|
468 }
|
max@0
|
469
|
max@0
|
470
|
max@0
|
471
|
max@0
|
472 arma_inline
|
max@0
|
473 static
|
max@0
|
474 void
|
max@0
|
475 apply
|
max@0
|
476 (
|
max@0
|
477 std::complex<double>* y,
|
max@0
|
478 const Mat< std::complex<double> >& A,
|
max@0
|
479 const std::complex<double>* x,
|
max@0
|
480 const std::complex<double> alpha = std::complex<double>(1),
|
max@0
|
481 const std::complex<double> beta = std::complex<double>(0)
|
max@0
|
482 )
|
max@0
|
483 {
|
max@0
|
484 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
|
max@0
|
485 }
|
max@0
|
486
|
max@0
|
487
|
max@0
|
488
|
max@0
|
489 };
|
max@0
|
490
|
max@0
|
491
|
max@0
|
492 //! @}
|