Chris@49
|
1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2012 Conrad Sanderson
|
Chris@49
|
3 //
|
Chris@49
|
4 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
7
|
Chris@49
|
8
|
Chris@49
|
9
|
Chris@49
|
10 #ifdef ARMA_USE_BLAS
|
Chris@49
|
11
|
Chris@49
|
12
|
Chris@49
|
13 //! \namespace blas namespace for BLAS functions
|
Chris@49
|
14 namespace blas
|
Chris@49
|
15 {
|
Chris@49
|
16
|
Chris@49
|
17
|
Chris@49
|
18 template<typename eT>
|
Chris@49
|
19 inline
|
Chris@49
|
20 void
|
Chris@49
|
21 gemv(const char* transA, const blas_int* m, const blas_int* n, const eT* alpha, const eT* A, const blas_int* ldA, const eT* x, const blas_int* incx, const eT* beta, eT* y, const blas_int* incy)
|
Chris@49
|
22 {
|
Chris@49
|
23 arma_type_check((is_supported_blas_type<eT>::value == false));
|
Chris@49
|
24
|
Chris@49
|
25 if(is_float<eT>::value == true)
|
Chris@49
|
26 {
|
Chris@49
|
27 typedef float T;
|
Chris@49
|
28 arma_fortran(arma_sgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
|
Chris@49
|
29 }
|
Chris@49
|
30 else
|
Chris@49
|
31 if(is_double<eT>::value == true)
|
Chris@49
|
32 {
|
Chris@49
|
33 typedef double T;
|
Chris@49
|
34 arma_fortran(arma_dgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
|
Chris@49
|
35 }
|
Chris@49
|
36 else
|
Chris@49
|
37 if(is_supported_complex_float<eT>::value == true)
|
Chris@49
|
38 {
|
Chris@49
|
39 typedef std::complex<float> T;
|
Chris@49
|
40 arma_fortran(arma_cgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
|
Chris@49
|
41 }
|
Chris@49
|
42 else
|
Chris@49
|
43 if(is_supported_complex_double<eT>::value == true)
|
Chris@49
|
44 {
|
Chris@49
|
45 typedef std::complex<double> T;
|
Chris@49
|
46 arma_fortran(arma_zgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
|
Chris@49
|
47 }
|
Chris@49
|
48
|
Chris@49
|
49 }
|
Chris@49
|
50
|
Chris@49
|
51
|
Chris@49
|
52
|
Chris@49
|
53 template<typename eT>
|
Chris@49
|
54 inline
|
Chris@49
|
55 void
|
Chris@49
|
56 gemm(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const eT* alpha, const eT* A, const blas_int* ldA, const eT* B, const blas_int* ldB, const eT* beta, eT* C, const blas_int* ldC)
|
Chris@49
|
57 {
|
Chris@49
|
58 arma_type_check((is_supported_blas_type<eT>::value == false));
|
Chris@49
|
59
|
Chris@49
|
60 if(is_float<eT>::value == true)
|
Chris@49
|
61 {
|
Chris@49
|
62 typedef float T;
|
Chris@49
|
63 arma_fortran(arma_sgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
|
Chris@49
|
64 }
|
Chris@49
|
65 else
|
Chris@49
|
66 if(is_double<eT>::value == true)
|
Chris@49
|
67 {
|
Chris@49
|
68 typedef double T;
|
Chris@49
|
69 arma_fortran(arma_dgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
|
Chris@49
|
70 }
|
Chris@49
|
71 else
|
Chris@49
|
72 if(is_supported_complex_float<eT>::value == true)
|
Chris@49
|
73 {
|
Chris@49
|
74 typedef std::complex<float> T;
|
Chris@49
|
75 arma_fortran(arma_cgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
|
Chris@49
|
76 }
|
Chris@49
|
77 else
|
Chris@49
|
78 if(is_supported_complex_double<eT>::value == true)
|
Chris@49
|
79 {
|
Chris@49
|
80 typedef std::complex<double> T;
|
Chris@49
|
81 arma_fortran(arma_zgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
|
Chris@49
|
82 }
|
Chris@49
|
83
|
Chris@49
|
84 }
|
Chris@49
|
85
|
Chris@49
|
86
|
Chris@49
|
87
|
Chris@49
|
88 template<typename eT>
|
Chris@49
|
89 inline
|
Chris@49
|
90 eT
|
Chris@49
|
91 dot(const uword n_elem, const eT* x, const eT* y)
|
Chris@49
|
92 {
|
Chris@49
|
93 arma_type_check((is_supported_blas_type<eT>::value == false));
|
Chris@49
|
94
|
Chris@49
|
95 if(is_float<eT>::value == true)
|
Chris@49
|
96 {
|
Chris@49
|
97 #if defined(ARMA_BLAS_SDOT_BUG)
|
Chris@49
|
98 {
|
Chris@49
|
99 if(n_elem == 0) { return eT(0); }
|
Chris@49
|
100
|
Chris@49
|
101 const char trans = 'T';
|
Chris@49
|
102
|
Chris@49
|
103 const blas_int m = blas_int(n_elem);
|
Chris@49
|
104 const blas_int n = 1;
|
Chris@49
|
105 //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1);
|
Chris@49
|
106 const blas_int inc = 1;
|
Chris@49
|
107
|
Chris@49
|
108 const eT alpha = eT(1);
|
Chris@49
|
109 const eT beta = eT(0);
|
Chris@49
|
110
|
Chris@49
|
111 eT result[2]; // paranoia: using two elements instead of one
|
Chris@49
|
112
|
Chris@49
|
113 //blas::gemv(&trans, &m, &n, &alpha, x, &lda, y, &inc, &beta, &result[0], &inc);
|
Chris@49
|
114 blas::gemv(&trans, &m, &n, &alpha, x, &m, y, &inc, &beta, &result[0], &inc);
|
Chris@49
|
115
|
Chris@49
|
116 return result[0];
|
Chris@49
|
117 }
|
Chris@49
|
118 #else
|
Chris@49
|
119 {
|
Chris@49
|
120 blas_int n = blas_int(n_elem);
|
Chris@49
|
121 blas_int inc = 1;
|
Chris@49
|
122
|
Chris@49
|
123 typedef float T;
|
Chris@49
|
124 return arma_fortran(arma_sdot)(&n, (const T*)x, &inc, (const T*)y, &inc);
|
Chris@49
|
125 }
|
Chris@49
|
126 #endif
|
Chris@49
|
127 }
|
Chris@49
|
128 else
|
Chris@49
|
129 if(is_double<eT>::value == true)
|
Chris@49
|
130 {
|
Chris@49
|
131 blas_int n = blas_int(n_elem);
|
Chris@49
|
132 blas_int inc = 1;
|
Chris@49
|
133
|
Chris@49
|
134 typedef double T;
|
Chris@49
|
135 return arma_fortran(arma_ddot)(&n, (const T*)x, &inc, (const T*)y, &inc);
|
Chris@49
|
136 }
|
Chris@49
|
137 else
|
Chris@49
|
138 if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) )
|
Chris@49
|
139 {
|
Chris@49
|
140 if(n_elem == 0) { return eT(0); }
|
Chris@49
|
141
|
Chris@49
|
142 // using gemv() workaround due to compatibility issues with cdotu() and zdotu()
|
Chris@49
|
143
|
Chris@49
|
144 const char trans = 'T';
|
Chris@49
|
145
|
Chris@49
|
146 const blas_int m = blas_int(n_elem);
|
Chris@49
|
147 const blas_int n = 1;
|
Chris@49
|
148 //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1);
|
Chris@49
|
149 const blas_int inc = 1;
|
Chris@49
|
150
|
Chris@49
|
151 const eT alpha = eT(1);
|
Chris@49
|
152 const eT beta = eT(0);
|
Chris@49
|
153
|
Chris@49
|
154 eT result[2]; // paranoia: using two elements instead of one
|
Chris@49
|
155
|
Chris@49
|
156 //blas::gemv(&trans, &m, &n, &alpha, x, &lda, y, &inc, &beta, &result[0], &inc);
|
Chris@49
|
157 blas::gemv(&trans, &m, &n, &alpha, x, &m, y, &inc, &beta, &result[0], &inc);
|
Chris@49
|
158
|
Chris@49
|
159 return result[0];
|
Chris@49
|
160 }
|
Chris@49
|
161 else
|
Chris@49
|
162 {
|
Chris@49
|
163 return eT(0);
|
Chris@49
|
164 }
|
Chris@49
|
165 }
|
Chris@49
|
166 }
|
Chris@49
|
167
|
Chris@49
|
168
|
Chris@49
|
169 #endif
|