max@0: // Copyright (C) 2010-2011 NICTA (www.nicta.com.au) max@0: // Copyright (C) 2010-2011 Conrad Sanderson max@0: // max@0: // This file is part of the Armadillo C++ library. max@0: // It is provided without any warranty of fitness max@0: // for any purpose. You can redistribute this file max@0: // and/or modify it under the terms of the GNU max@0: // Lesser General Public License (LGPL) as published max@0: // by the Free Software Foundation, either version 3 max@0: // of the License or (at your option) any later version. max@0: // (see http://www.opensource.org/licenses for more info) max@0: max@0: max@0: //! \addtogroup fn_as_scalar max@0: //! @{ max@0: max@0: max@0: max@0: template max@0: struct as_scalar_redirect max@0: { max@0: template max@0: inline static typename T1::elem_type apply(const T1& X); max@0: }; max@0: max@0: max@0: max@0: template<> max@0: struct as_scalar_redirect<2> max@0: { max@0: template max@0: inline static typename T1::elem_type apply(const Glue& X); max@0: }; max@0: max@0: max@0: template<> max@0: struct as_scalar_redirect<3> max@0: { max@0: template max@0: inline static typename T1::elem_type apply(const Glue< Glue, T3, glue_times>& X); max@0: }; max@0: max@0: max@0: max@0: template max@0: template max@0: inline max@0: typename T1::elem_type max@0: as_scalar_redirect::apply(const T1& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: const unwrap tmp(X); max@0: const Mat& A = tmp.M; max@0: max@0: arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); max@0: max@0: return A.mem[0]; max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: typename T1::elem_type max@0: as_scalar_redirect<2>::apply(const Glue& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: // T1 must result in a matrix with one row max@0: // T2 must result in a matrix with one column max@0: max@0: const partial_unwrap tmp1(X.A); max@0: const partial_unwrap tmp2(X.B); max@0: max@0: const Mat& A = tmp1.M; max@0: const Mat& B = tmp2.M; max@0: max@0: const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; max@0: const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; max@0: max@0: const uword B_n_rows = (tmp2.do_trans == false) ? B.n_rows : B.n_cols; max@0: const uword B_n_cols = (tmp2.do_trans == false) ? B.n_cols : B.n_rows; max@0: max@0: const eT val = tmp1.get_val() * tmp2.get_val(); max@0: max@0: arma_debug_check( (A_n_rows != 1) || (B_n_cols != 1) || (A_n_cols != B_n_rows), "as_scalar(): incompatible dimensions" ); max@0: max@0: return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem); max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: typename T1::elem_type max@0: as_scalar_redirect<3>::apply(const Glue< Glue, T3, glue_times >& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: // T1 * T2 must result in a matrix with one row max@0: // T3 must result in a matrix with one column max@0: max@0: typedef typename strip_inv ::stored_type T2_stripped_1; max@0: typedef typename strip_diagmat::stored_type T2_stripped_2; max@0: max@0: const strip_inv strip1(X.A.B); max@0: const strip_diagmat strip2(strip1.M); max@0: max@0: const bool tmp2_do_inv = strip1.do_inv; max@0: const bool tmp2_do_diagmat = strip2.do_diagmat; max@0: max@0: if(tmp2_do_diagmat == false) max@0: { max@0: const Mat tmp(X); max@0: max@0: arma_debug_check( (tmp.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); max@0: max@0: return tmp[0]; max@0: } max@0: else max@0: { max@0: const partial_unwrap tmp1(X.A.A); max@0: const partial_unwrap tmp2(strip2.M); max@0: const partial_unwrap tmp3(X.B); max@0: max@0: const Mat& A = tmp1.M; max@0: const Mat& B = tmp2.M; max@0: const Mat& C = tmp3.M; max@0: max@0: const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; max@0: const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; max@0: max@0: const bool B_is_vec = B.is_vec(); max@0: max@0: const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols ); max@0: const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows ); max@0: max@0: const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols; max@0: const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows; max@0: max@0: const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val(); max@0: max@0: arma_debug_check max@0: ( max@0: (A_n_rows != 1) || max@0: (C_n_cols != 1) || max@0: (A_n_cols != B_n_rows) || max@0: (B_n_cols != C_n_rows) max@0: , max@0: "as_scalar(): incompatible dimensions" max@0: ); max@0: max@0: max@0: if(B_is_vec == true) max@0: { max@0: if(tmp2_do_inv == true) max@0: { max@0: return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem); max@0: } max@0: else max@0: { max@0: return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem); max@0: } max@0: } max@0: else max@0: { max@0: if(tmp2_do_inv == true) max@0: { max@0: return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem); max@0: } max@0: else max@0: { max@0: return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem); max@0: } max@0: } max@0: } max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: typename T1::elem_type max@0: as_scalar_diag(const Base& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: const unwrap tmp(X.get_ref()); max@0: const Mat& A = tmp.M; max@0: max@0: arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); max@0: max@0: return A.mem[0]; max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: typename T1::elem_type max@0: as_scalar_diag(const Glue< Glue, T3, glue_times >& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: // T1 * T2 must result in a matrix with one row max@0: // T3 must result in a matrix with one column max@0: max@0: typedef typename strip_diagmat::stored_type T2_stripped; max@0: max@0: const strip_diagmat strip(X.A.B); max@0: max@0: const partial_unwrap tmp1(X.A.A); max@0: const partial_unwrap tmp2(strip.M); max@0: const partial_unwrap tmp3(X.B); max@0: max@0: const Mat& A = tmp1.M; max@0: const Mat& B = tmp2.M; max@0: const Mat& C = tmp3.M; max@0: max@0: max@0: const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; max@0: const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; max@0: max@0: const bool B_is_vec = B.is_vec(); max@0: max@0: const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols ); max@0: const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows ); max@0: max@0: const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols; max@0: const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows; max@0: max@0: const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val(); max@0: max@0: arma_debug_check max@0: ( max@0: (A_n_rows != 1) || max@0: (C_n_cols != 1) || max@0: (A_n_cols != B_n_rows) || max@0: (B_n_cols != C_n_rows) max@0: , max@0: "as_scalar(): incompatible dimensions" max@0: ); max@0: max@0: max@0: if(B_is_vec == true) max@0: { max@0: return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem); max@0: } max@0: else max@0: { max@0: return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem); max@0: } max@0: } max@0: max@0: max@0: max@0: template max@0: arma_inline max@0: arma_warn_unused max@0: typename T1::elem_type max@0: as_scalar(const Glue& X, const typename arma_not_cx::result* junk = 0) max@0: { max@0: arma_extra_debug_sigprint(); max@0: arma_ignore(junk); max@0: max@0: if(is_glue_times_diag::value == false) max@0: { max@0: const sword N_mat = 1 + depth_lhs< glue_times, Glue >::num; max@0: max@0: arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat); max@0: max@0: return as_scalar_redirect::apply(X); max@0: } max@0: else max@0: { max@0: return as_scalar_diag(X); max@0: } max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: arma_warn_unused max@0: typename T1::elem_type max@0: as_scalar(const Base& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: const unwrap tmp(X.get_ref()); max@0: const Mat& A = tmp.M; max@0: max@0: arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); max@0: max@0: return A.mem[0]; max@0: } max@0: max@0: max@0: max@0: template max@0: arma_inline max@0: arma_warn_unused max@0: typename T1::elem_type max@0: as_scalar(const eOp& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: return -(as_scalar(X.P.Q)); max@0: } max@0: max@0: max@0: max@0: template max@0: inline max@0: arma_warn_unused max@0: typename T1::elem_type max@0: as_scalar(const BaseCube& X) max@0: { max@0: arma_extra_debug_sigprint(); max@0: max@0: typedef typename T1::elem_type eT; max@0: max@0: const unwrap_cube tmp(X.get_ref()); max@0: const Cube& A = tmp.M; max@0: max@0: arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); max@0: max@0: return A.mem[0]; max@0: } max@0: max@0: max@0: max@0: template max@0: arma_inline max@0: arma_warn_unused max@0: const typename arma_scalar_only::result & max@0: as_scalar(const T& x) max@0: { max@0: return x; max@0: } max@0: max@0: max@0: max@0: //! @}