Mercurial > hg > segmenter-vamp-plugin
diff armadillo-2.4.4/include/armadillo_bits/glue_times_meat.hpp @ 0:8b6102e2a9b0
Armadillo Library
author | maxzanoni76 <max.zanoni@eecs.qmul.ac.uk> |
---|---|
date | Wed, 11 Apr 2012 09:27:06 +0100 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/armadillo-2.4.4/include/armadillo_bits/glue_times_meat.hpp Wed Apr 11 09:27:06 2012 +0100 @@ -0,0 +1,694 @@ +// Copyright (C) 2008-2011 NICTA (www.nicta.com.au) +// Copyright (C) 2008-2011 Conrad Sanderson +// +// This file is part of the Armadillo C++ library. +// It is provided without any warranty of fitness +// for any purpose. You can redistribute this file +// and/or modify it under the terms of the GNU +// Lesser General Public License (LGPL) as published +// by the Free Software Foundation, either version 3 +// of the License or (at your option) any later version. +// (see http://www.opensource.org/licenses for more info) + + +//! \addtogroup glue_times +//! @{ + + + +template<uword N> +template<typename T1, typename T2> +inline +void +glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const partial_unwrap_check<T1> tmp1(X.A, out); + const partial_unwrap_check<T2> tmp2(X.B, out); + + const Mat<eT>& A = tmp1.M; + const Mat<eT>& B = tmp2.M; + + const bool do_trans_A = tmp1.do_trans; + const bool do_trans_B = tmp2.do_trans; + + const bool use_alpha = tmp1.do_times || tmp2.do_times; + const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0); + + glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha); + } + + + +template<typename T1, typename T2, typename T3> +inline +void +glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // there is exactly 3 objects + // hence we can safely expand X as X.A.A, X.A.B and X.B + + const partial_unwrap_check<T1> tmp1(X.A.A, out); + const partial_unwrap_check<T2> tmp2(X.A.B, out); + const partial_unwrap_check<T3> tmp3(X.B, out); + + const Mat<eT>& A = tmp1.M; + const Mat<eT>& B = tmp2.M; + const Mat<eT>& C = tmp3.M; + + const bool do_trans_A = tmp1.do_trans; + const bool do_trans_B = tmp2.do_trans; + const bool do_trans_C = tmp3.do_trans; + + const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times; + const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val()) : eT(0); + + glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha); + } + + + +template<typename T1, typename T2, typename T3, typename T4> +inline +void +glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // there is exactly 4 objects + // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B + + const partial_unwrap_check<T1> tmp1(X.A.A.A, out); + const partial_unwrap_check<T2> tmp2(X.A.A.B, out); + const partial_unwrap_check<T3> tmp3(X.A.B, out); + const partial_unwrap_check<T4> tmp4(X.B, out); + + const Mat<eT>& A = tmp1.M; + const Mat<eT>& B = tmp2.M; + const Mat<eT>& C = tmp3.M; + const Mat<eT>& D = tmp4.M; + + const bool do_trans_A = tmp1.do_trans; + const bool do_trans_B = tmp2.do_trans; + const bool do_trans_C = tmp3.do_trans; + const bool do_trans_D = tmp4.do_trans; + + const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times || tmp4.do_times; + const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val() * tmp4.get_val()) : eT(0); + + glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha); + } + + + +template<typename T1, typename T2> +inline +void +glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num; + + arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat); + + glue_times_redirect<N_mat>::apply(out, X); + } + + + +template<typename T1> +inline +void +glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_check<T1> tmp(X, out); + const Mat<eT>& B = tmp.M; + + arma_debug_assert_mul_size(out, B, "matrix multiplication"); + + const uword out_n_rows = out.n_rows; + const uword out_n_cols = out.n_cols; + + if(out_n_cols == B.n_cols) + { + // size of resulting matrix is the same as 'out' + + podarray<eT> tmp(out_n_cols); + + eT* tmp_rowdata = tmp.memptr(); + + for(uword row=0; row < out_n_rows; ++row) + { + tmp.copy_row(out, row); + + for(uword col=0; col < out_n_cols; ++col) + { + out.at(row,col) = op_dot::direct_dot( out_n_cols, tmp_rowdata, B.colptr(col) ); + } + } + + } + else + { + const Mat<eT> tmp(out); + glue_times::apply(out, tmp, B, eT(1), false, false, false); + } + + } + + + +template<typename T1, typename T2> +arma_hot +inline +void +glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const sword sign) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const partial_unwrap_check<T1> tmp1(X.A, out); + const partial_unwrap_check<T2> tmp2(X.B, out); + + const Mat<eT>& A = tmp1.M; + const Mat<eT>& B = tmp2.M; + const eT alpha = tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ); + + const bool do_trans_A = tmp1.do_trans; + const bool do_trans_B = tmp2.do_trans; + const bool use_alpha = tmp1.do_times || tmp2.do_times || (sign < sword(0)); + + arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication"); + + const uword result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; + const uword result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; + + arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "addition"); + + if(out.n_elem > 0) + { + if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) + { + if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); + } + else + if(B.n_cols == 1) + { + gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); + } + else + { + gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1)); + } + } + else + if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) + { + if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); + } + else + if(B.n_cols == 1) + { + gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); + } + else + { + gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1)); + } + } + else + if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) + { + if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) + { + gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); + } + else + if(B.n_cols == 1) + { + gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); + } + else + { + gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1)); + } + } + else + if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) + { + if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) + { + gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); + } + else + if(B.n_cols == 1) + { + gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); + } + else + { + gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1)); + } + } + else + if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) + { + if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); + } + else + if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); + } + else + { + gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1)); + } + } + else + if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) + { + if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); + } + else + if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); + } + else + { + gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1)); + } + } + else + if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) + { + if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) + { + gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); + } + else + if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); + } + else + { + gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1)); + } + } + else + if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) + { + if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) + { + gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); + } + else + if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); + } + else + { + gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1)); + } + } + } + + + } + + + +template<typename eT> +arma_inline +uword +glue_times::mul_storage_cost(const Mat<eT>& A, const Mat<eT>& B, const bool do_trans_A, const bool do_trans_B) + { + const uword final_A_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; + const uword final_B_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; + + return final_A_n_rows * final_B_n_cols; + } + + + +template<typename eT> +arma_hot +inline +void +glue_times::apply + ( + Mat<eT>& out, + const Mat<eT>& A, + const Mat<eT>& B, + const eT alpha, + const bool do_trans_A, + const bool do_trans_B, + const bool use_alpha + ) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication"); + + const uword final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; + const uword final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; + + out.set_size(final_n_rows, final_n_cols); + + if( (A.n_elem > 0) && (B.n_elem > 0) ) + { + if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) + { + if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); + } + else + if(B.n_cols == 1) + { + gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); + } + else + { + gemm<false, false, false, false>::apply(out, A, B); + } + } + else + if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) + { + if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); + } + else + if(B.n_cols == 1) + { + gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); + } + else + { + gemm<false, false, true, false>::apply(out, A, B, alpha); + } + } + else + if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) + { + if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) + { + gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); + } + else + if(B.n_cols == 1) + { + gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); + } + else + { + gemm<true, false, false, false>::apply(out, A, B); + } + } + else + if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) + { + if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) + { + gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); + } + else + if(B.n_cols == 1) + { + gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha); + } + else + { + gemm<true, false, true, false>::apply(out, A, B, alpha); + } + } + else + if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) + { + if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); + } + else + if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); + } + else + { + gemm<false, true, false, false>::apply(out, A, B); + } + } + else + if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) + { + if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); + } + else + if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); + } + else + { + gemm<false, true, true, false>::apply(out, A, B, alpha); + } + } + else + if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) + { + if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) + { + gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); + } + else + if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); + } + else + { + gemm<true, true, false, false>::apply(out, A, B); + } + } + else + if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) + { + if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) + { + gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); + } + else + if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) + { + gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha); + } + else + { + gemm<true, true, true, false>::apply(out, A, B, alpha); + } + } + } + else + { + out.zeros(); + } + } + + + +template<typename eT> +inline +void +glue_times::apply + ( + Mat<eT>& out, + const Mat<eT>& A, + const Mat<eT>& B, + const Mat<eT>& C, + const eT alpha, + const bool do_trans_A, + const bool do_trans_B, + const bool do_trans_C, + const bool use_alpha + ) + { + arma_extra_debug_sigprint(); + + Mat<eT> tmp; + + if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) ) + { + // out = (A*B)*C + glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha); + glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false ); + } + else + { + // out = A*(B*C) + glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha); + glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false ); + } + } + + + +template<typename eT> +inline +void +glue_times::apply + ( + Mat<eT>& out, + const Mat<eT>& A, + const Mat<eT>& B, + const Mat<eT>& C, + const Mat<eT>& D, + const eT alpha, + const bool do_trans_A, + const bool do_trans_B, + const bool do_trans_C, + const bool do_trans_D, + const bool use_alpha + ) + { + arma_extra_debug_sigprint(); + + Mat<eT> tmp; + + if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) ) + { + // out = (A*B*C)*D + glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha); + + glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false); + } + else + { + // out = A*(B*C*D) + glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha); + + glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false); + } + } + + + +// +// glue_times_diag + + +template<typename T1, typename T2> +arma_hot +inline +void +glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const strip_diagmat<T1> S1(X.A); + const strip_diagmat<T2> S2(X.B); + + typedef typename strip_diagmat<T1>::stored_type T1_stripped; + typedef typename strip_diagmat<T2>::stored_type T2_stripped; + + if( (S1.do_diagmat == true) && (S2.do_diagmat == false) ) + { + const diagmat_proxy_check<T1_stripped> A(S1.M, out); + + const unwrap_check<T2> tmp(X.B, out); + const Mat<eT>& B = tmp.M; + + arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiplication"); + + out.set_size(A.n_elem, B.n_cols); + + for(uword col=0; col<B.n_cols; ++col) + { + eT* out_coldata = out.colptr(col); + const eT* B_coldata = B.colptr(col); + + for(uword row=0; row<B.n_rows; ++row) + { + out_coldata[row] = A[row] * B_coldata[row]; + } + } + } + else + if( (S1.do_diagmat == false) && (S2.do_diagmat == true) ) + { + const unwrap_check<T1> tmp(X.A, out); + const Mat<eT>& A = tmp.M; + + const diagmat_proxy_check<T2_stripped> B(S2.M, out); + + arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiplication"); + + out.set_size(A.n_rows, B.n_elem); + + for(uword col=0; col<A.n_cols; ++col) + { + const eT val = B[col]; + + eT* out_coldata = out.colptr(col); + const eT* A_coldata = A.colptr(col); + + for(uword row=0; row<A.n_rows; ++row) + { + out_coldata[row] = A_coldata[row] * val; + } + } + } + else + if( (S1.do_diagmat == true) && (S2.do_diagmat == true) ) + { + const diagmat_proxy_check<T1_stripped> A(S1.M, out); + const diagmat_proxy_check<T2_stripped> B(S2.M, out); + + arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiplication"); + + out.zeros(A.n_elem, A.n_elem); + + for(uword i=0; i<A.n_elem; ++i) + { + out.at(i,i) = A[i] * B[i]; + } + } + } + + + +//! @}