Chris@49: // Copyright (C) 2010-2013 NICTA (www.nicta.com.au) Chris@49: // Copyright (C) 2010-2013 Conrad Sanderson Chris@49: // Chris@49: // This Source Code Form is subject to the terms of the Mozilla Public Chris@49: // License, v. 2.0. If a copy of the MPL was not distributed with this Chris@49: // file, You can obtain one at http://mozilla.org/MPL/2.0/. Chris@49: Chris@49: Chris@49: //! \addtogroup fn_as_scalar Chris@49: //! @{ Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: struct as_scalar_redirect Chris@49: { Chris@49: template Chris@49: inline static typename T1::elem_type apply(const T1& X); Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template<> Chris@49: struct as_scalar_redirect<2> Chris@49: { Chris@49: template Chris@49: inline static typename T1::elem_type apply(const Glue& X); Chris@49: }; Chris@49: Chris@49: Chris@49: template<> Chris@49: struct as_scalar_redirect<3> Chris@49: { Chris@49: template Chris@49: inline static typename T1::elem_type apply(const Glue< Glue, T3, glue_times>& X); Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: template Chris@49: inline Chris@49: typename T1::elem_type Chris@49: as_scalar_redirect::apply(const T1& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: // typedef typename T1::elem_type eT; Chris@49: Chris@49: const Proxy P(X); Chris@49: Chris@49: arma_debug_check( (P.get_n_elem() != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); Chris@49: Chris@49: return (Proxy::prefer_at_accessor == true) ? P.at(0,0) : P[0]; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: typename T1::elem_type Chris@49: as_scalar_redirect<2>::apply(const Glue& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: // T1 must result in a matrix with one row Chris@49: // T2 must result in a matrix with one column Chris@49: Chris@49: const bool has_all_mat = is_Mat::value && is_Mat::value; Chris@49: const bool prefer_at_accessor = Proxy::prefer_at_accessor || Proxy::prefer_at_accessor; Chris@49: Chris@49: const bool do_partial_unwrap = has_all_mat || prefer_at_accessor; Chris@49: Chris@49: if(do_partial_unwrap == true) Chris@49: { Chris@49: const partial_unwrap tmp1(X.A); Chris@49: const partial_unwrap tmp2(X.B); Chris@49: Chris@49: typedef typename partial_unwrap::stored_type TA; Chris@49: typedef typename partial_unwrap::stored_type TB; Chris@49: Chris@49: const TA& A = tmp1.M; Chris@49: const TB& B = tmp2.M; Chris@49: Chris@49: const uword A_n_rows = (tmp1.do_trans == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols); Chris@49: const uword A_n_cols = (tmp1.do_trans == false) ? (TA::is_col ? 1 : A.n_cols) : (TA::is_row ? 1 : A.n_rows); Chris@49: Chris@49: const uword B_n_rows = (tmp2.do_trans == false) ? (TB::is_row ? 1 : B.n_rows) : (TB::is_col ? 1 : B.n_cols); Chris@49: const uword B_n_cols = (tmp2.do_trans == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows); Chris@49: Chris@49: arma_debug_check( (A_n_rows != 1) || (B_n_cols != 1) || (A_n_cols != B_n_rows), "as_scalar(): incompatible dimensions" ); Chris@49: Chris@49: const eT val = op_dot::direct_dot(A.n_elem, A.memptr(), B.memptr()); Chris@49: Chris@49: return (tmp1.do_times || tmp2.do_times) ? (val * tmp1.get_val() * tmp2.get_val()) : val; Chris@49: } Chris@49: else Chris@49: { Chris@49: const Proxy PA(X.A); Chris@49: const Proxy PB(X.B); Chris@49: Chris@49: arma_debug_check Chris@49: ( Chris@49: (PA.get_n_rows() != 1) || (PB.get_n_cols() != 1) || (PA.get_n_cols() != PB.get_n_rows()), Chris@49: "as_scalar(): incompatible dimensions" Chris@49: ); Chris@49: Chris@49: return op_dot::apply_proxy(PA,PB); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: typename T1::elem_type Chris@49: as_scalar_redirect<3>::apply(const Glue< Glue, T3, glue_times >& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: // T1 * T2 must result in a matrix with one row Chris@49: // T3 must result in a matrix with one column Chris@49: Chris@49: typedef typename strip_inv ::stored_type T2_stripped_1; Chris@49: typedef typename strip_diagmat::stored_type T2_stripped_2; Chris@49: Chris@49: const strip_inv strip1(X.A.B); Chris@49: const strip_diagmat strip2(strip1.M); Chris@49: Chris@49: const bool tmp2_do_inv = strip1.do_inv; Chris@49: const bool tmp2_do_diagmat = strip2.do_diagmat; Chris@49: Chris@49: if(tmp2_do_diagmat == false) Chris@49: { Chris@49: const Mat tmp(X); Chris@49: Chris@49: arma_debug_check( (tmp.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); Chris@49: Chris@49: return tmp[0]; Chris@49: } Chris@49: else Chris@49: { Chris@49: const partial_unwrap tmp1(X.A.A); Chris@49: const partial_unwrap tmp2(strip2.M); Chris@49: const partial_unwrap tmp3(X.B); Chris@49: Chris@49: const Mat& A = tmp1.M; Chris@49: const Mat& B = tmp2.M; Chris@49: const Mat& C = tmp3.M; Chris@49: Chris@49: const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; Chris@49: const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; Chris@49: Chris@49: const bool B_is_vec = B.is_vec(); Chris@49: Chris@49: const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols ); Chris@49: const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows ); Chris@49: Chris@49: const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols; Chris@49: const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows; Chris@49: Chris@49: const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val(); Chris@49: Chris@49: arma_debug_check Chris@49: ( Chris@49: (A_n_rows != 1) || Chris@49: (C_n_cols != 1) || Chris@49: (A_n_cols != B_n_rows) || Chris@49: (B_n_cols != C_n_rows) Chris@49: , Chris@49: "as_scalar(): incompatible dimensions" Chris@49: ); Chris@49: Chris@49: Chris@49: if(B_is_vec == true) Chris@49: { Chris@49: if(tmp2_do_inv == true) Chris@49: { Chris@49: return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem); Chris@49: } Chris@49: else Chris@49: { Chris@49: return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem); Chris@49: } Chris@49: } Chris@49: else Chris@49: { Chris@49: if(tmp2_do_inv == true) Chris@49: { Chris@49: return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem); Chris@49: } Chris@49: else Chris@49: { Chris@49: return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem); Chris@49: } Chris@49: } Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: typename T1::elem_type Chris@49: as_scalar_diag(const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const unwrap tmp(X.get_ref()); Chris@49: const Mat& A = tmp.M; Chris@49: Chris@49: arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); Chris@49: Chris@49: return A.mem[0]; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: typename T1::elem_type Chris@49: as_scalar_diag(const Glue< Glue, T3, glue_times >& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: // T1 * T2 must result in a matrix with one row Chris@49: // T3 must result in a matrix with one column Chris@49: Chris@49: typedef typename strip_diagmat::stored_type T2_stripped; Chris@49: Chris@49: const strip_diagmat strip(X.A.B); Chris@49: Chris@49: const partial_unwrap tmp1(X.A.A); Chris@49: const partial_unwrap tmp2(strip.M); Chris@49: const partial_unwrap tmp3(X.B); Chris@49: Chris@49: const Mat& A = tmp1.M; Chris@49: const Mat& B = tmp2.M; Chris@49: const Mat& C = tmp3.M; Chris@49: Chris@49: Chris@49: const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; Chris@49: const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; Chris@49: Chris@49: const bool B_is_vec = B.is_vec(); Chris@49: Chris@49: const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols ); Chris@49: const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows ); Chris@49: Chris@49: const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols; Chris@49: const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows; Chris@49: Chris@49: const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val(); Chris@49: Chris@49: arma_debug_check Chris@49: ( Chris@49: (A_n_rows != 1) || Chris@49: (C_n_cols != 1) || Chris@49: (A_n_cols != B_n_rows) || Chris@49: (B_n_cols != C_n_rows) Chris@49: , Chris@49: "as_scalar(): incompatible dimensions" Chris@49: ); Chris@49: Chris@49: Chris@49: if(B_is_vec == true) Chris@49: { Chris@49: return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem); Chris@49: } Chris@49: else Chris@49: { Chris@49: return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_inline Chris@49: arma_warn_unused Chris@49: typename T1::elem_type Chris@49: as_scalar(const Glue& X, const typename arma_not_cx::result* junk = 0) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: arma_ignore(junk); Chris@49: Chris@49: if(is_glue_times_diag::value == false) Chris@49: { Chris@49: const sword N_mat = 1 + depth_lhs< glue_times, Glue >::num; Chris@49: Chris@49: arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat); Chris@49: Chris@49: return as_scalar_redirect::apply(X); Chris@49: } Chris@49: else Chris@49: { Chris@49: return as_scalar_diag(X); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: arma_warn_unused Chris@49: typename T1::elem_type Chris@49: as_scalar(const Base& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: // typedef typename T1::elem_type eT; Chris@49: Chris@49: const Proxy P(X.get_ref()); Chris@49: Chris@49: arma_debug_check( (P.get_n_elem() != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); Chris@49: Chris@49: return (Proxy::prefer_at_accessor == true) ? P.at(0,0) : P[0]; Chris@49: } Chris@49: Chris@49: Chris@49: // ensure the following two functions are aware of each other Chris@49: template inline arma_warn_unused typename T1::elem_type as_scalar(const eOp& X); Chris@49: template inline arma_warn_unused typename T1::elem_type as_scalar(const eGlue& X); Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: arma_warn_unused Chris@49: typename T1::elem_type Chris@49: as_scalar(const eOp& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const eT val = as_scalar(X.P.Q); Chris@49: Chris@49: return eop_core::process(val, X.aux); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: arma_warn_unused Chris@49: typename T1::elem_type Chris@49: as_scalar(const eGlue& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const eT a = as_scalar(X.P1.Q); Chris@49: const eT b = as_scalar(X.P2.Q); Chris@49: Chris@49: // the optimiser will keep only one return statement Chris@49: Chris@49: if(is_same_type::value == true) { return a + b; } Chris@49: else if(is_same_type::value == true) { return a - b; } Chris@49: else if(is_same_type::value == true) { return a / b; } Chris@49: else if(is_same_type::value == true) { return a * b; } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: arma_warn_unused Chris@49: typename T1::elem_type Chris@49: as_scalar(const BaseCube& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: // typedef typename T1::elem_type eT; Chris@49: Chris@49: const ProxyCube P(X.get_ref()); Chris@49: Chris@49: arma_debug_check( (P.get_n_elem() != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); Chris@49: Chris@49: return (ProxyCube::prefer_at_accessor == true) ? P.at(0,0,0) : P[0]; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_inline Chris@49: arma_warn_unused Chris@49: const typename arma_scalar_only::result & Chris@49: as_scalar(const T& x) Chris@49: { Chris@49: return x; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_inline Chris@49: arma_warn_unused Chris@49: typename T1::elem_type Chris@49: as_scalar(const SpBase& X) Chris@49: { Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const unwrap_spmat tmp(X.get_ref()); Chris@49: const SpMat& A = tmp.M; Chris@49: Chris@49: arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); Chris@49: Chris@49: return A.at(0,0); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! @}