// Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
// Copyright (C) 2008-2012 Conrad Sanderson
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.



#ifdef ARMA_USE_BLAS


//! \namespace blas namespace for BLAS functions
namespace blas
  {

  //! Type-dispatching wrapper around the Fortran BLAS ?gemv routines.
  //!
  //! Selects arma_sgemv / arma_dgemv / arma_cgemv / arma_zgemv according to
  //! the element type eT and forwards every argument unchanged.
  //! If none of the four type traits matches, no BLAS routine is called.
  //!
  //! NOTE(review): argument meaning (transA, m, n, alpha, A, ldA, x, incx,
  //! beta, y, incy) presumably follows the standard BLAS xGEMV interface --
  //! confirm against the BLAS reference.
  template<typename eT>
  inline
  void
  gemv(const char* transA, const blas_int* m, const blas_int* n, const eT* alpha, const eT* A, const blas_int* ldA, const eT* x, const blas_int* incx, const eT* beta, eT* y, const blas_int* incy)
    {
    arma_type_check((is_supported_blas_type<eT>::value == false));

    // All branches are compiled for every eT (plain "if", not a compile-time
    // switch), hence the C-style pointer casts to the branch's concrete type T.
    // Only the branch whose trait is true for eT is executed.
    if(is_float<eT>::value == true)
      {
      typedef float T;
      arma_fortran(arma_sgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
      }
    else
    if(is_double<eT>::value == true)
      {
      typedef double T;
      arma_fortran(arma_dgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
      }
    else
    if(is_supported_complex_float<eT>::value == true)
      {
      typedef std::complex<float> T;
      arma_fortran(arma_cgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
      }
    else
    if(is_supported_complex_double<eT>::value == true)
      {
      typedef std::complex<double> T;
      arma_fortran(arma_zgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
      }

    }



  //! Type-dispatching wrapper around the Fortran BLAS ?gemm routines.
  //!
  //! Selects arma_sgemm / arma_dgemm / arma_cgemm / arma_zgemm according to
  //! the element type eT and forwards every argument unchanged.
  //! If none of the four type traits matches, no BLAS routine is called.
  //!
  //! NOTE(review): argument meaning (transA, transB, m, n, k, alpha, A, ldA,
  //! B, ldB, beta, C, ldC) presumably follows the standard BLAS xGEMM
  //! interface -- confirm against the BLAS reference.
  template<typename eT>
  inline
  void
  gemm(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const eT* alpha, const eT* A, const blas_int* ldA, const eT* B, const blas_int* ldB, const eT* beta, eT* C, const blas_int* ldC)
    {
    arma_type_check((is_supported_blas_type<eT>::value == false));

    // Same dispatch scheme as gemv(): every branch is compiled for every eT,
    // so pointers are cast to the concrete type T of the selected routine.
    if(is_float<eT>::value == true)
      {
      typedef float T;
      arma_fortran(arma_sgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
      }
    else
    if(is_double<eT>::value == true)
      {
      typedef double T;
      arma_fortran(arma_dgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
      }
    else
    if(is_supported_complex_float<eT>::value == true)
      {
      typedef std::complex<float> T;
      arma_fortran(arma_cgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
      }
    else
    if(is_supported_complex_double<eT>::value == true)
      {
      typedef std::complex<double> T;
      arma_fortran(arma_zgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
      }

    }



  //! Dot product of two contiguous arrays (unit stride) of n_elem elements.
  //!
  //! - float:   calls arma_sdot, unless ARMA_BLAS_SDOT_BUG is defined, in which
  //!            case the gemv() workaround below is used instead.
  //! - double:  calls arma_ddot.
  //! - complex: always uses the gemv() workaround (see comment in that branch).
  //! - any other eT: returns eT(0).
  //!
  //! The gemv() workaround passes x as an (n_elem x 1) matrix with trans = 'T'
  //! and y as the vector, with alpha = 1 and beta = 0; by the BLAS gemv
  //! convention this should yield the single value sum(x[i] * y[i]) without
  //! conjugation. Those branches return eT(0) directly when n_elem == 0.
  //!
  //! NOTE(review): n_elem is converted to blas_int without a range check.
  template<typename eT>
  inline
  eT
  dot(const uword n_elem, const eT* x, const eT* y)
    {
    arma_type_check((is_supported_blas_type<eT>::value == false));

    if(is_float<eT>::value == true)
      {
      #if defined(ARMA_BLAS_SDOT_BUG)
        {
        if(n_elem == 0)  { return eT(0); }

        const char trans = 'T';

        const blas_int m   = blas_int(n_elem);
        const blas_int n   = 1;
        //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1);
        const blas_int inc = 1;

        const eT alpha = eT(1);
        const eT beta  = eT(0);

        eT result[2];  // paranoia: using two elements instead of one

        //blas::gemv(&trans, &m, &n, &alpha, x, &lda, y, &inc, &beta, &result[0], &inc);
        // m doubles as the leading dimension of x (n_elem > 0 is guaranteed here)
        blas::gemv(&trans, &m, &n, &alpha, x, &m, y, &inc, &beta, &result[0], &inc);

        return result[0];
        }
      #else
        {
        blas_int n   = blas_int(n_elem);
        blas_int inc = 1;

        typedef float T;
        return arma_fortran(arma_sdot)(&n, (const T*)x, &inc, (const T*)y, &inc);
        }
      #endif
      }
    else
    if(is_double<eT>::value == true)
      {
      blas_int n   = blas_int(n_elem);
      blas_int inc = 1;

      typedef double T;
      return arma_fortran(arma_ddot)(&n, (const T*)x, &inc, (const T*)y, &inc);
      }
    else
    if( (is_supported_complex_float<eT>::value == true) || (is_supported_complex_double<eT>::value == true) )
      {
      if(n_elem == 0)  { return eT(0); }

      // using gemv() workaround due to compatibility issues with cdotu() and zdotu()

      const char trans = 'T';

      const blas_int m   = blas_int(n_elem);
      const blas_int n   = 1;
      //const blas_int lda = (n_elem > 0) ? blas_int(n_elem) : blas_int(1);
      const blas_int inc = 1;

      const eT alpha = eT(1);
      const eT beta  = eT(0);

      eT result[2];  // paranoia: using two elements instead of one

      //blas::gemv(&trans, &m, &n, &alpha, x, &lda, y, &inc, &beta, &result[0], &inc);
      // m doubles as the leading dimension of x (n_elem > 0 is guaranteed here)
      blas::gemv(&trans, &m, &n, &alpha, x, &m, y, &inc, &beta, &result[0], &inc);

      return result[0];
      }
    else
      {
      return eT(0);
      }
    }

  }


#endif