Chris@49: // Copyright (C) 2008-2013 NICTA (www.nicta.com.au) Chris@49: // Copyright (C) 2008-2013 Conrad Sanderson Chris@49: // Chris@49: // This Source Code Form is subject to the terms of the Mozilla Public Chris@49: // License, v. 2.0. If a copy of the MPL was not distributed with this Chris@49: // file, You can obtain one at http://mozilla.org/MPL/2.0/. Chris@49: Chris@49: Chris@49: //! \addtogroup glue_times Chris@49: //! @{ Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times_redirect2_helper::apply(Mat& out, const Glue& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const partial_unwrap_check tmp1(X.A, out); Chris@49: const partial_unwrap_check tmp2(X.B, out); Chris@49: Chris@49: const typename partial_unwrap_check::stored_type& A = tmp1.M; Chris@49: const typename partial_unwrap_check::stored_type& B = tmp2.M; Chris@49: Chris@49: const bool use_alpha = partial_unwrap_check::do_times || partial_unwrap_check::do_times; Chris@49: const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0); Chris@49: Chris@49: glue_times::apply Chris@49: < Chris@49: eT, Chris@49: partial_unwrap_check::do_trans, Chris@49: partial_unwrap_check::do_trans, Chris@49: (partial_unwrap_check::do_times || partial_unwrap_check::do_times) Chris@49: > Chris@49: (out, A, B, alpha); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times_redirect2_helper::apply(Mat& out, const Glue& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: if(strip_inv::do_inv == false) Chris@49: { Chris@49: const partial_unwrap_check tmp1(X.A, out); Chris@49: const partial_unwrap_check tmp2(X.B, out); Chris@49: Chris@49: const typename partial_unwrap_check::stored_type& A = tmp1.M; Chris@49: const typename partial_unwrap_check::stored_type& B = tmp2.M; Chris@49: Chris@49: const bool use_alpha = partial_unwrap_check::do_times || partial_unwrap_check::do_times; Chris@49: const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0); Chris@49: Chris@49: glue_times::apply Chris@49: < Chris@49: eT, Chris@49: partial_unwrap_check::do_trans, Chris@49: partial_unwrap_check::do_trans, Chris@49: (partial_unwrap_check::do_times || partial_unwrap_check::do_times) Chris@49: > Chris@49: (out, A, B, alpha); Chris@49: } Chris@49: else Chris@49: { Chris@49: arma_extra_debug_print("glue_times_redirect<2>::apply(): detected inv(A)*B"); Chris@49: Chris@49: const strip_inv A_strip(X.A); Chris@49: Chris@49: Mat A = A_strip.M; Chris@49: Chris@49: arma_debug_check( (A.is_square() == false), "inv(): given matrix is not square" ); Chris@49: Chris@49: const unwrap_check B_tmp(X.B, out); Chris@49: const Mat& B = B_tmp.M; Chris@49: Chris@49: glue_solve::solve_direct( out, A, B, A_strip.slow ); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times_redirect::apply(Mat& out, const Glue& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const partial_unwrap_check tmp1(X.A, out); Chris@49: const partial_unwrap_check tmp2(X.B, out); Chris@49: Chris@49: const typename partial_unwrap_check::stored_type& A = tmp1.M; Chris@49: const typename partial_unwrap_check::stored_type& B = tmp2.M; Chris@49: Chris@49: const bool use_alpha = partial_unwrap_check::do_times || partial_unwrap_check::do_times; Chris@49: const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0); Chris@49: Chris@49: glue_times::apply Chris@49: < Chris@49: eT, Chris@49: partial_unwrap_check::do_trans, Chris@49: partial_unwrap_check::do_trans, Chris@49: (partial_unwrap_check::do_times || partial_unwrap_check::do_times) Chris@49: > Chris@49: (out, A, B, alpha); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times_redirect<2>::apply(Mat& out, const Glue& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: glue_times_redirect2_helper< is_supported_blas_type::value >::apply(out, X); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times_redirect<3>::apply(Mat& out, const Glue< Glue, T3, glue_times>& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: // TODO: investigate detecting inv(A)*B*C and replacing with solve(A,B)*C Chris@49: // TODO: investigate detecting A*inv(B)*C and replacing with A*solve(B,C) Chris@49: Chris@49: // there is exactly 3 objects Chris@49: // hence we can safely expand X as X.A.A, X.A.B and X.B Chris@49: Chris@49: const partial_unwrap_check tmp1(X.A.A, out); Chris@49: const partial_unwrap_check tmp2(X.A.B, out); Chris@49: const partial_unwrap_check tmp3(X.B, out); Chris@49: Chris@49: const typename partial_unwrap_check::stored_type& A = tmp1.M; Chris@49: const typename partial_unwrap_check::stored_type& B = tmp2.M; Chris@49: const typename partial_unwrap_check::stored_type& C = tmp3.M; Chris@49: Chris@49: const bool use_alpha = partial_unwrap_check::do_times || partial_unwrap_check::do_times || partial_unwrap_check::do_times; Chris@49: const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val()) : eT(0); Chris@49: Chris@49: glue_times::apply Chris@49: < Chris@49: eT, Chris@49: partial_unwrap_check::do_trans, Chris@49: partial_unwrap_check::do_trans, Chris@49: partial_unwrap_check::do_trans, Chris@49: (partial_unwrap_check::do_times || partial_unwrap_check::do_times || partial_unwrap_check::do_times) Chris@49: > Chris@49: (out, A, B, C, alpha); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times_redirect<4>::apply(Mat& out, const Glue< Glue< Glue, T3, glue_times>, T4, glue_times>& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: // there is exactly 4 objects Chris@49: // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B Chris@49: Chris@49: const partial_unwrap_check tmp1(X.A.A.A, out); Chris@49: const partial_unwrap_check tmp2(X.A.A.B, out); Chris@49: const partial_unwrap_check tmp3(X.A.B, out); Chris@49: const partial_unwrap_check tmp4(X.B, out); Chris@49: Chris@49: const typename partial_unwrap_check::stored_type& A = tmp1.M; Chris@49: const typename partial_unwrap_check::stored_type& B = tmp2.M; Chris@49: const typename partial_unwrap_check::stored_type& C = tmp3.M; Chris@49: const typename partial_unwrap_check::stored_type& D = tmp4.M; Chris@49: Chris@49: const bool use_alpha = partial_unwrap_check::do_times || partial_unwrap_check::do_times || partial_unwrap_check::do_times || partial_unwrap_check::do_times; Chris@49: const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val() * tmp4.get_val()) : eT(0); Chris@49: Chris@49: glue_times::apply Chris@49: < Chris@49: eT, Chris@49: partial_unwrap_check::do_trans, Chris@49: partial_unwrap_check::do_trans, Chris@49: partial_unwrap_check::do_trans, Chris@49: partial_unwrap_check::do_trans, Chris@49: (partial_unwrap_check::do_times || partial_unwrap_check::do_times || partial_unwrap_check::do_times || partial_unwrap_check::do_times) Chris@49: > Chris@49: (out, A, B, C, D, alpha); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times::apply(Mat& out, const Glue& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: const sword N_mat = 1 + depth_lhs< glue_times, Glue >::num; Chris@49: Chris@49: arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat); Chris@49: Chris@49: glue_times_redirect::apply(out, X); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times::apply_inplace(Mat& out, const T1& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const unwrap_check B_tmp(X, out); Chris@49: const Mat& B = B_tmp.M; Chris@49: Chris@49: arma_debug_assert_mul_size(out, B, "matrix multiplication"); Chris@49: Chris@49: const uword out_n_rows = out.n_rows; Chris@49: const uword out_n_cols = out.n_cols; Chris@49: Chris@49: if(out_n_cols == B.n_cols) Chris@49: { Chris@49: // size of resulting matrix is the same as 'out' Chris@49: Chris@49: podarray tmp(out_n_cols); Chris@49: Chris@49: eT* tmp_rowdata = tmp.memptr(); Chris@49: Chris@49: for(uword row=0; row < out_n_rows; ++row) Chris@49: { Chris@49: tmp.copy_row(out, row); Chris@49: Chris@49: for(uword col=0; col < out_n_cols; ++col) Chris@49: { Chris@49: out.at(row,col) = op_dot::direct_dot( out_n_cols, tmp_rowdata, B.colptr(col) ); Chris@49: } Chris@49: } Chris@49: Chris@49: } Chris@49: else Chris@49: { Chris@49: const Mat tmp(out); Chris@49: Chris@49: glue_times::apply(out, tmp, B, eT(1)); Chris@49: } Chris@49: Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times::apply_inplace_plus(Mat& out, const Glue& X, const sword sign) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const partial_unwrap_check tmp1(X.A, out); Chris@49: const partial_unwrap_check tmp2(X.B, out); Chris@49: Chris@49: typedef typename partial_unwrap_check::stored_type TA; Chris@49: typedef typename partial_unwrap_check::stored_type TB; Chris@49: Chris@49: const TA& A = tmp1.M; Chris@49: const TB& B = tmp2.M; Chris@49: Chris@49: const bool do_trans_A = partial_unwrap_check::do_trans; Chris@49: const bool do_trans_B = partial_unwrap_check::do_trans; Chris@49: Chris@49: const bool use_alpha = partial_unwrap_check::do_times || partial_unwrap_check::do_times || (sign < sword(0)); Chris@49: const eT alpha = use_alpha ? ( tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ) ) : eT(0); Chris@49: Chris@49: arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication"); Chris@49: Chris@49: const uword result_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols); Chris@49: const uword result_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows); Chris@49: Chris@49: arma_debug_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "addition"); Chris@49: Chris@49: if(out.n_elem > 0) Chris@49: { Chris@49: if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) Chris@49: { Chris@49: if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: if( (B.n_cols == 1) || (TB::is_col) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B, alpha, eT(1)); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) Chris@49: { Chris@49: if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: if( (B.n_cols == 1) || (TB::is_col) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B, alpha, eT(1)); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) Chris@49: { Chris@49: if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: if( (B.n_cols == 1) || (TB::is_col) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B, alpha, eT(1)); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) Chris@49: { Chris@49: if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: if( (B.n_cols == 1) || (TB::is_col) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B, alpha, eT(1)); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) Chris@49: { Chris@49: if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B, alpha, eT(1)); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) Chris@49: { Chris@49: if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B, alpha, eT(1)); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) Chris@49: { Chris@49: if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B, alpha, eT(1)); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) Chris@49: { Chris@49: if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B, alpha, eT(1)); Chris@49: } Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_inline Chris@49: uword Chris@49: glue_times::mul_storage_cost(const TA& A, const TB& B) Chris@49: { Chris@49: const uword final_A_n_rows = (do_trans_A == false) ? ( TA::is_row ? 1 : A.n_rows ) : ( TA::is_col ? 1 : A.n_cols ); Chris@49: const uword final_B_n_cols = (do_trans_B == false) ? ( TB::is_col ? 1 : B.n_cols ) : ( TB::is_row ? 1 : B.n_rows ); Chris@49: Chris@49: return final_A_n_rows * final_B_n_cols; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: < Chris@49: typename eT, Chris@49: const bool do_trans_A, Chris@49: const bool do_trans_B, Chris@49: const bool use_alpha, Chris@49: typename TA, Chris@49: typename TB Chris@49: > Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times::apply Chris@49: ( Chris@49: Mat& out, Chris@49: const TA& A, Chris@49: const TB& B, Chris@49: const eT alpha Chris@49: ) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: //arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication"); Chris@49: arma_debug_assert_trans_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); Chris@49: Chris@49: const uword final_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols); Chris@49: const uword final_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows); Chris@49: Chris@49: out.set_size(final_n_rows, final_n_cols); Chris@49: Chris@49: if( (A.n_elem > 0) && (B.n_elem > 0) ) Chris@49: { Chris@49: if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) Chris@49: { Chris@49: if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr()); Chris@49: } Chris@49: else Chris@49: if( (B.n_cols == 1) || (TB::is_col) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr()); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) Chris@49: { Chris@49: if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr(), alpha); Chris@49: } Chris@49: else Chris@49: if( (B.n_cols == 1) || (TB::is_col) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr(), alpha); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B, alpha); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) Chris@49: { Chris@49: if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr()); Chris@49: } Chris@49: else Chris@49: if( (B.n_cols == 1) || (TB::is_col) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr()); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) Chris@49: { Chris@49: if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr(), alpha); Chris@49: } Chris@49: else Chris@49: if( (B.n_cols == 1) || (TB::is_col) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr(), alpha); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B, alpha); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) Chris@49: { Chris@49: if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr()); Chris@49: } Chris@49: else Chris@49: if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr()); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) Chris@49: { Chris@49: if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr(), alpha); Chris@49: } Chris@49: else Chris@49: if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr(), alpha); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B, alpha); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) Chris@49: { Chris@49: if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr()); Chris@49: } Chris@49: else Chris@49: if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr()); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B); Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) Chris@49: { Chris@49: if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), B, A.memptr(), alpha); Chris@49: } Chris@49: else Chris@49: if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex::value == false) ) Chris@49: { Chris@49: gemv::apply(out.memptr(), A, B.memptr(), alpha); Chris@49: } Chris@49: else Chris@49: { Chris@49: gemm::apply(out, A, B, alpha); Chris@49: } Chris@49: } Chris@49: } Chris@49: else Chris@49: { Chris@49: out.zeros(); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: < Chris@49: typename eT, Chris@49: const bool do_trans_A, Chris@49: const bool do_trans_B, Chris@49: const bool do_trans_C, Chris@49: const bool use_alpha, Chris@49: typename TA, Chris@49: typename TB, Chris@49: typename TC Chris@49: > Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times::apply Chris@49: ( Chris@49: Mat& out, Chris@49: const TA& A, Chris@49: const TB& B, Chris@49: const TC& C, Chris@49: const eT alpha Chris@49: ) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: Mat tmp; Chris@49: Chris@49: const uword storage_cost_AB = glue_times::mul_storage_cost(A, B); Chris@49: const uword storage_cost_BC = glue_times::mul_storage_cost(B, C); Chris@49: Chris@49: if(storage_cost_AB <= storage_cost_BC) Chris@49: { Chris@49: // out = (A*B)*C Chris@49: Chris@49: glue_times::apply(tmp, A, B, alpha); Chris@49: glue_times::apply(out, tmp, C, eT(0)); Chris@49: } Chris@49: else Chris@49: { Chris@49: // out = A*(B*C) Chris@49: Chris@49: glue_times::apply(tmp, B, C, alpha); Chris@49: glue_times::apply(out, A, tmp, eT(0)); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: < Chris@49: typename eT, Chris@49: const bool do_trans_A, Chris@49: const bool do_trans_B, Chris@49: const bool do_trans_C, Chris@49: const bool do_trans_D, Chris@49: const bool use_alpha, Chris@49: typename TA, Chris@49: typename TB, Chris@49: typename TC, Chris@49: typename TD Chris@49: > Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times::apply Chris@49: ( Chris@49: Mat& out, Chris@49: const TA& A, Chris@49: const TB& B, Chris@49: const TC& C, Chris@49: const TD& D, Chris@49: const eT alpha Chris@49: ) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: Mat tmp; Chris@49: Chris@49: const uword storage_cost_AC = glue_times::mul_storage_cost(A, C); Chris@49: const uword storage_cost_BD = glue_times::mul_storage_cost(B, D); Chris@49: Chris@49: if(storage_cost_AC <= storage_cost_BD) Chris@49: { Chris@49: // out = (A*B*C)*D Chris@49: Chris@49: glue_times::apply(tmp, A, B, C, alpha); Chris@49: Chris@49: glue_times::apply(out, tmp, D, eT(0)); Chris@49: } Chris@49: else Chris@49: { Chris@49: // out = A*(B*C*D) Chris@49: Chris@49: glue_times::apply(tmp, B, C, D, alpha); Chris@49: Chris@49: glue_times::apply(out, A, tmp, eT(0)); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: // Chris@49: // glue_times_diag Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: glue_times_diag::apply(Mat& out, const Glue& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const strip_diagmat S1(X.A); Chris@49: const strip_diagmat S2(X.B); Chris@49: Chris@49: typedef typename strip_diagmat::stored_type T1_stripped; Chris@49: typedef typename strip_diagmat::stored_type T2_stripped; Chris@49: Chris@49: if( (strip_diagmat::do_diagmat == true) && (strip_diagmat::do_diagmat == false) ) Chris@49: { Chris@49: const diagmat_proxy_check A(S1.M, out); Chris@49: Chris@49: const unwrap_check tmp(X.B, out); Chris@49: const Mat& B = tmp.M; Chris@49: Chris@49: const uword A_n_elem = A.n_elem; Chris@49: const uword B_n_rows = B.n_rows; Chris@49: const uword B_n_cols = B.n_cols; Chris@49: Chris@49: arma_debug_assert_mul_size(A_n_elem, A_n_elem, B_n_rows, B_n_cols, "matrix multiplication"); Chris@49: Chris@49: out.set_size(A_n_elem, B_n_cols); Chris@49: Chris@49: for(uword col=0; col < B_n_cols; ++col) Chris@49: { Chris@49: eT* out_coldata = out.colptr(col); Chris@49: const eT* B_coldata = B.colptr(col); Chris@49: Chris@49: uword i,j; Chris@49: for(i=0, j=1; j < B_n_rows; i+=2, j+=2) Chris@49: { Chris@49: eT tmp_i = A[i]; Chris@49: eT tmp_j = A[j]; Chris@49: Chris@49: tmp_i *= B_coldata[i]; Chris@49: tmp_j *= B_coldata[j]; Chris@49: Chris@49: out_coldata[i] = tmp_i; Chris@49: out_coldata[j] = tmp_j; Chris@49: } Chris@49: Chris@49: if(i < B_n_rows) Chris@49: { Chris@49: out_coldata[i] = A[i] * B_coldata[i]; Chris@49: } Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (strip_diagmat::do_diagmat == false) && (strip_diagmat::do_diagmat == true) ) Chris@49: { Chris@49: const unwrap_check tmp(X.A, out); Chris@49: const Mat& A = tmp.M; Chris@49: Chris@49: const diagmat_proxy_check B(S2.M, out); Chris@49: Chris@49: const uword A_n_rows = A.n_rows; Chris@49: const uword A_n_cols = A.n_cols; Chris@49: const uword B_n_elem = B.n_elem; Chris@49: Chris@49: arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_elem, B_n_elem, "matrix multiplication"); Chris@49: Chris@49: out.set_size(A_n_rows, B_n_elem); Chris@49: Chris@49: for(uword col=0; col < A_n_cols; ++col) Chris@49: { Chris@49: const eT val = B[col]; Chris@49: Chris@49: eT* out_coldata = out.colptr(col); Chris@49: const eT* A_coldata = A.colptr(col); Chris@49: Chris@49: uword i,j; Chris@49: for(i=0, j=1; j < A_n_rows; i+=2, j+=2) Chris@49: { Chris@49: const eT tmp_i = A_coldata[i] * val; Chris@49: const eT tmp_j = A_coldata[j] * val; Chris@49: Chris@49: out_coldata[i] = tmp_i; Chris@49: out_coldata[j] = tmp_j; Chris@49: } Chris@49: Chris@49: if(i < A_n_rows) Chris@49: { Chris@49: out_coldata[i] = A_coldata[i] * val; Chris@49: } Chris@49: } Chris@49: } Chris@49: else Chris@49: if( (strip_diagmat::do_diagmat == true) && (strip_diagmat::do_diagmat == true) ) Chris@49: { Chris@49: const diagmat_proxy_check A(S1.M, out); Chris@49: const diagmat_proxy_check B(S2.M, out); Chris@49: Chris@49: const uword A_n_elem = A.n_elem; Chris@49: const uword B_n_elem = B.n_elem; Chris@49: Chris@49: arma_debug_assert_mul_size(A_n_elem, A_n_elem, B_n_elem, B_n_elem, "matrix multiplication"); Chris@49: Chris@49: out.zeros(A_n_elem, A_n_elem); Chris@49: Chris@49: for(uword i=0; i < A_n_elem; ++i) Chris@49: { Chris@49: out.at(i,i) = A[i] * B[i]; Chris@49: } Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! @}