max@0: // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) max@0: // Copyright (C) 2008-2011 Conrad Sanderson max@0: // max@0: // This file is part of the Armadillo C++ library. max@0: // It is provided without any warranty of fitness max@0: // for any purpose. You can redistribute this file max@0: // and/or modify it under the terms of the GNU max@0: // Lesser General Public License (LGPL) as published max@0: // by the Free Software Foundation, either version 3 max@0: // of the License or (at your option) any later version. max@0: // (see http://www.opensource.org/licenses for more info) max@0: max@0: max@0: //! \addtogroup glue_times max@0: //! @{ max@0: max@0: max@0: max@0: template max@0: template max@0: inline max@0: void max@0: glue_times_redirect::apply(Mat& out, const Glue& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: const partial_unwrap_check tmp1(X.A, out); max@0: const partial_unwrap_check tmp2(X.B, out); max@0: max@0: const Mat& A = tmp1.M; max@0: const Mat& B = tmp2.M; max@0: max@0: const bool do_trans_A = tmp1.do_trans; max@0: const bool do_trans_B = tmp2.do_trans; max@0: max@0: const bool use_alpha = tmp1.do_times || tmp2.do_times; max@0: const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0); max@0: max@0: glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha); max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: void max@0: glue_times_redirect<3>::apply(Mat& out, const Glue< Glue, T3, glue_times>& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: // there is exactly 3 objects max@0: // hence we can safely expand X as X.A.A, X.A.B and X.B max@0: max@0: const partial_unwrap_check tmp1(X.A.A, out); max@0: const partial_unwrap_check tmp2(X.A.B, out); max@0: const partial_unwrap_check tmp3(X.B, out); max@0: max@0: const Mat& A = tmp1.M; max@0: const Mat& B = tmp2.M; max@0: const Mat& C = tmp3.M; max@0: max@0: const bool do_trans_A = tmp1.do_trans; max@0: const bool do_trans_B = tmp2.do_trans; max@0: const bool do_trans_C = tmp3.do_trans; max@0: max@0: const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times; max@0: const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val()) : eT(0); max@0: max@0: glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha); max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: void max@0: glue_times_redirect<4>::apply(Mat& out, const Glue< Glue< Glue, T3, glue_times>, T4, glue_times>& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: // there is exactly 4 objects max@0: // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B max@0: max@0: const partial_unwrap_check tmp1(X.A.A.A, out); max@0: const partial_unwrap_check tmp2(X.A.A.B, out); max@0: const partial_unwrap_check tmp3(X.A.B, out); max@0: const partial_unwrap_check tmp4(X.B, out); max@0: max@0: const Mat& A = tmp1.M; max@0: const Mat& B = tmp2.M; max@0: const Mat& C = tmp3.M; max@0: const Mat& D = tmp4.M; max@0: max@0: const bool do_trans_A = tmp1.do_trans; max@0: const bool do_trans_B = tmp2.do_trans; max@0: const bool do_trans_C = tmp3.do_trans; max@0: const bool do_trans_D = tmp4.do_trans; max@0: max@0: const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times || tmp4.do_times; max@0: const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val() * tmp4.get_val()) : eT(0); max@0: max@0: glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha); max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: void max@0: glue_times::apply(Mat& out, const Glue& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: const sword N_mat = 1 + depth_lhs< glue_times, Glue >::num; max@0: max@0: arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat); max@0: max@0: glue_times_redirect::apply(out, X); max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: void max@0: glue_times::apply_inplace(Mat& out, const T1& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: const unwrap_check tmp(X, out); max@0: const Mat& B = tmp.M; max@0: max@0: arma_debug_assert_mul_size(out, B, "matrix multiplication"); max@0: max@0: const uword out_n_rows = out.n_rows; max@0: const uword out_n_cols = out.n_cols; max@0: max@0: if(out_n_cols == B.n_cols) max@0: { max@0: // size of resulting matrix is the same as 'out' max@0: max@0: podarray tmp(out_n_cols); max@0: max@0: eT* tmp_rowdata = tmp.memptr(); max@0: max@0: for(uword row=0; row < out_n_rows; ++row) max@0: { max@0: tmp.copy_row(out, row); max@0: max@0: for(uword col=0; col < out_n_cols; ++col) max@0: { max@0: out.at(row,col) = op_dot::direct_dot( out_n_cols, tmp_rowdata, B.colptr(col) ); max@0: } max@0: } max@0: max@0: } max@0: else max@0: { max@0: const Mat tmp(out); max@0: glue_times::apply(out, tmp, B, eT(1), false, false, false); max@0: } max@0: max@0: } max@0: max@0: max@0: max@0: template max@0: arma_hot max@0: inline max@0: void max@0: glue_times::apply_inplace_plus(Mat& out, const Glue& X, const sword sign) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: const partial_unwrap_check tmp1(X.A, out); max@0: const partial_unwrap_check tmp2(X.B, out); max@0: max@0: const Mat& A = tmp1.M; max@0: const Mat& B = tmp2.M; max@0: const eT alpha = tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ); max@0: max@0: const bool do_trans_A = tmp1.do_trans; max@0: const bool do_trans_B = tmp2.do_trans; max@0: const bool use_alpha = tmp1.do_times || tmp2.do_times || (sign < sword(0)); max@0: max@0: arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication"); max@0: max@0: const uword result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; max@0: const uword result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; max@0: max@0: arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "addition"); max@0: max@0: if(out.n_elem > 0) max@0: { max@0: if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) max@0: { max@0: if( (A.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: if(B.n_cols == 1) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B, alpha, eT(1)); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) max@0: { max@0: if( (A.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: if(B.n_cols == 1) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B, alpha, eT(1)); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) max@0: { max@0: if( (A.n_cols == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: if(B.n_cols == 1) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B, alpha, eT(1)); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) max@0: { max@0: if( (A.n_cols == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: if(B.n_cols == 1) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B, alpha, eT(1)); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) max@0: { max@0: if( (A.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: if( (B.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B, alpha, eT(1)); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) max@0: { max@0: if( (A.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: if( (B.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B, alpha, eT(1)); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) max@0: { max@0: if( (A.n_cols == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: if( (B.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B, alpha, eT(1)); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) max@0: { max@0: if( (A.n_cols == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: if( (B.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B, alpha, eT(1)); max@0: } max@0: } max@0: } max@0: max@0: max@0: } max@0: max@0: max@0: max@0: template max@0: arma_inline max@0: uword max@0: glue_times::mul_storage_cost(const Mat& A, const Mat& B, const bool do_trans_A, const bool do_trans_B) max@0: { max@0: const uword final_A_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; max@0: const uword final_B_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; max@0: max@0: return final_A_n_rows * final_B_n_cols; max@0: } max@0: max@0: max@0: max@0: template max@0: arma_hot max@0: inline max@0: void max@0: glue_times::apply max@0: ( max@0: Mat& out, max@0: const Mat& A, max@0: const Mat& B, max@0: const eT alpha, max@0: const bool do_trans_A, max@0: const bool do_trans_B, max@0: const bool use_alpha max@0: ) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication"); max@0: max@0: const uword final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; max@0: const uword final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; max@0: max@0: out.set_size(final_n_rows, final_n_cols); max@0: max@0: if( (A.n_elem > 0) && (B.n_elem > 0) ) max@0: { max@0: if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) max@0: { max@0: if( (A.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr()); max@0: } max@0: else max@0: if(B.n_cols == 1) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr()); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) max@0: { max@0: if( (A.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr(), alpha); max@0: } max@0: else max@0: if(B.n_cols == 1) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr(), alpha); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B, alpha); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) max@0: { max@0: if( (A.n_cols == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr()); max@0: } max@0: else max@0: if(B.n_cols == 1) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr()); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) max@0: { max@0: if( (A.n_cols == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr(), alpha); max@0: } max@0: else max@0: if(B.n_cols == 1) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr(), alpha); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B, alpha); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) max@0: { max@0: if( (A.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr()); max@0: } max@0: else max@0: if( (B.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr()); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) max@0: { max@0: if( (A.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr(), alpha); max@0: } max@0: else max@0: if( (B.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr(), alpha); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B, alpha); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) max@0: { max@0: if( (A.n_cols == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr()); max@0: } max@0: else max@0: if( (B.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr()); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B); max@0: } max@0: } max@0: else max@0: if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) max@0: { max@0: if( (A.n_cols == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), B, A.memptr(), alpha); max@0: } max@0: else max@0: if( (B.n_rows == 1) && (is_complex::value == false) ) max@0: { max@0: gemv::apply(out.memptr(), A, B.memptr(), alpha); max@0: } max@0: else max@0: { max@0: gemm::apply(out, A, B, alpha); max@0: } max@0: } max@0: } max@0: else max@0: { max@0: out.zeros(); max@0: } max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: void max@0: glue_times::apply max@0: ( max@0: Mat& out, max@0: const Mat& A, max@0: const Mat& B, max@0: const Mat& C, max@0: const eT alpha, max@0: const bool do_trans_A, max@0: const bool do_trans_B, max@0: const bool do_trans_C, max@0: const bool use_alpha max@0: ) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: Mat tmp; max@0: max@0: if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) ) max@0: { max@0: // out = (A*B)*C max@0: glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha); max@0: glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false ); max@0: } max@0: else max@0: { max@0: // out = A*(B*C) max@0: glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha); max@0: glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false ); max@0: } max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: void max@0: glue_times::apply max@0: ( max@0: Mat& out, max@0: const Mat& A, max@0: const Mat& B, max@0: const Mat& C, max@0: const Mat& D, max@0: const eT alpha, max@0: const bool do_trans_A, max@0: const bool do_trans_B, max@0: const bool do_trans_C, max@0: const bool do_trans_D, max@0: const bool use_alpha max@0: ) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: Mat tmp; max@0: max@0: if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) ) max@0: { max@0: // out = (A*B*C)*D max@0: glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha); max@0: max@0: glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false); max@0: } max@0: else max@0: { max@0: // out = A*(B*C*D) max@0: glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha); max@0: max@0: glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false); max@0: } max@0: } max@0: max@0: max@0: max@0: // max@0: // glue_times_diag max@0: max@0: max@0: template max@0: arma_hot max@0: inline max@0: void max@0: glue_times_diag::apply(Mat& out, const Glue& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: const strip_diagmat S1(X.A); max@0: const strip_diagmat S2(X.B); max@0: max@0: typedef typename strip_diagmat::stored_type T1_stripped; max@0: typedef typename strip_diagmat::stored_type T2_stripped; max@0: max@0: if( (S1.do_diagmat == true) && (S2.do_diagmat == false) ) max@0: { max@0: const diagmat_proxy_check A(S1.M, out); max@0: max@0: const unwrap_check tmp(X.B, out); max@0: const Mat& B = tmp.M; max@0: max@0: arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiplication"); max@0: max@0: out.set_size(A.n_elem, B.n_cols); max@0: max@0: for(uword col=0; col tmp(X.A, out); max@0: const Mat& A = tmp.M; max@0: max@0: const diagmat_proxy_check B(S2.M, out); max@0: max@0: arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiplication"); max@0: max@0: out.set_size(A.n_rows, B.n_elem); max@0: max@0: for(uword col=0; col A(S1.M, out); max@0: const diagmat_proxy_check B(S2.M, out); max@0: max@0: arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiplication"); max@0: max@0: out.zeros(A.n_elem, A.n_elem); max@0: max@0: for(uword i=0; i