Mercurial > hg > segmenter-vamp-plugin
diff armadillo-2.4.4/include/armadillo_bits/fn_as_scalar.hpp @ 0:8b6102e2a9b0
Armadillo Library
author | maxzanoni76 <max.zanoni@eecs.qmul.ac.uk> |
---|---|
date | Wed, 11 Apr 2012 09:27:06 +0100 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/armadillo-2.4.4/include/armadillo_bits/fn_as_scalar.hpp Wed Apr 11 09:27:06 2012 +0100 @@ -0,0 +1,357 @@ +// Copyright (C) 2010-2011 NICTA (www.nicta.com.au) +// Copyright (C) 2010-2011 Conrad Sanderson +// +// This file is part of the Armadillo C++ library. +// It is provided without any warranty of fitness +// for any purpose. You can redistribute this file +// and/or modify it under the terms of the GNU +// Lesser General Public License (LGPL) as published +// by the Free Software Foundation, either version 3 +// of the License or (at your option) any later version. +// (see http://www.opensource.org/licenses for more info) + + +//! \addtogroup fn_as_scalar +//! @{ + + + +template<uword N> +struct as_scalar_redirect + { + template<typename T1> + inline static typename T1::elem_type apply(const T1& X); + }; + + + +template<> +struct as_scalar_redirect<2> + { + template<typename T1, typename T2> + inline static typename T1::elem_type apply(const Glue<T1,T2,glue_times>& X); + }; + + +template<> +struct as_scalar_redirect<3> + { + template<typename T1, typename T2, typename T3> + inline static typename T1::elem_type apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times>& X); + }; + + + +template<uword N> +template<typename T1> +inline +typename T1::elem_type +as_scalar_redirect<N>::apply(const T1& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap<T1> tmp(X); + const Mat<eT>& A = tmp.M; + + arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); + + return A.mem[0]; + } + + + +template<typename T1, typename T2> +inline +typename T1::elem_type +as_scalar_redirect<2>::apply(const Glue<T1, T2, glue_times>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // T1 must result in a matrix with one row + // T2 must result in a matrix with one column + + const partial_unwrap<T1> tmp1(X.A); + const partial_unwrap<T2> tmp2(X.B); + + const Mat<eT>& A = tmp1.M; + const Mat<eT>& B = tmp2.M; + + const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; + const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; + + const uword B_n_rows = (tmp2.do_trans == false) ? B.n_rows : B.n_cols; + const uword B_n_cols = (tmp2.do_trans == false) ? B.n_cols : B.n_rows; + + const eT val = tmp1.get_val() * tmp2.get_val(); + + arma_debug_check( (A_n_rows != 1) || (B_n_cols != 1) || (A_n_cols != B_n_rows), "as_scalar(): incompatible dimensions" ); + + return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem); + } + + + +template<typename T1, typename T2, typename T3> +inline +typename T1::elem_type +as_scalar_redirect<3>::apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times >& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // T1 * T2 must result in a matrix with one row + // T3 must result in a matrix with one column + + typedef typename strip_inv <T2 >::stored_type T2_stripped_1; + typedef typename strip_diagmat<T2_stripped_1>::stored_type T2_stripped_2; + + const strip_inv <T2> strip1(X.A.B); + const strip_diagmat<T2_stripped_1> strip2(strip1.M); + + const bool tmp2_do_inv = strip1.do_inv; + const bool tmp2_do_diagmat = strip2.do_diagmat; + + if(tmp2_do_diagmat == false) + { + const Mat<eT> tmp(X); + + arma_debug_check( (tmp.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); + + return tmp[0]; + } + else + { + const partial_unwrap<T1> tmp1(X.A.A); + const partial_unwrap<T2_stripped_2> tmp2(strip2.M); + const partial_unwrap<T3> tmp3(X.B); + + const Mat<eT>& A = tmp1.M; + const Mat<eT>& B = tmp2.M; + const Mat<eT>& C = tmp3.M; + + const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; + const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; + + const bool B_is_vec = B.is_vec(); + + const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols ); + const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows ); + + const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols; + const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows; + + const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val(); + + arma_debug_check + ( + (A_n_rows != 1) || + (C_n_cols != 1) || + (A_n_cols != B_n_rows) || + (B_n_cols != C_n_rows) + , + "as_scalar(): incompatible dimensions" + ); + + + if(B_is_vec == true) + { + if(tmp2_do_inv == true) + { + return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem); + } + else + { + return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem); + } + } + else + { + if(tmp2_do_inv == true) + { + return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem); + } + else + { + return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem); + } + } + } + } + + + +template<typename T1> +inline +typename T1::elem_type +as_scalar_diag(const Base<typename T1::elem_type,T1>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap<T1> tmp(X.get_ref()); + const Mat<eT>& A = tmp.M; + + arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); + + return A.mem[0]; + } + + + +template<typename T1, typename T2, typename T3> +inline +typename T1::elem_type +as_scalar_diag(const Glue< Glue<T1, T2, glue_times_diag>, T3, glue_times >& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + // T1 * T2 must result in a matrix with one row + // T3 must result in a matrix with one column + + typedef typename strip_diagmat<T2>::stored_type T2_stripped; + + const strip_diagmat<T2> strip(X.A.B); + + const partial_unwrap<T1> tmp1(X.A.A); + const partial_unwrap<T2_stripped> tmp2(strip.M); + const partial_unwrap<T3> tmp3(X.B); + + const Mat<eT>& A = tmp1.M; + const Mat<eT>& B = tmp2.M; + const Mat<eT>& C = tmp3.M; + + + const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols; + const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows; + + const bool B_is_vec = B.is_vec(); + + const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols ); + const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows ); + + const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols; + const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows; + + const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val(); + + arma_debug_check + ( + (A_n_rows != 1) || + (C_n_cols != 1) || + (A_n_cols != B_n_rows) || + (B_n_cols != C_n_rows) + , + "as_scalar(): incompatible dimensions" + ); + + + if(B_is_vec == true) + { + return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem); + } + else + { + return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem); + } + } + + + +template<typename T1, typename T2> +arma_inline +arma_warn_unused +typename T1::elem_type +as_scalar(const Glue<T1, T2, glue_times>& X, const typename arma_not_cx<typename T1::elem_type>::result* junk = 0) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(is_glue_times_diag<T1>::value == false) + { + const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num; + + arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat); + + return as_scalar_redirect<N_mat>::apply(X); + } + else + { + return as_scalar_diag(X); + } + } + + + +template<typename T1> +inline +arma_warn_unused +typename T1::elem_type +as_scalar(const Base<typename T1::elem_type,T1>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap<T1> tmp(X.get_ref()); + const Mat<eT>& A = tmp.M; + + arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); + + return A.mem[0]; + } + + + +template<typename T1> +arma_inline +arma_warn_unused +typename T1::elem_type +as_scalar(const eOp<T1, eop_neg>& X) + { + arma_extra_debug_sigprint(); + + return -(as_scalar(X.P.Q)); + } + + + +template<typename T1> +inline +arma_warn_unused +typename T1::elem_type +as_scalar(const BaseCube<typename T1::elem_type,T1>& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube<T1> tmp(X.get_ref()); + const Cube<eT>& A = tmp.M; + + arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" ); + + return A.mem[0]; + } + + + +template<typename T> +arma_inline +arma_warn_unused +const typename arma_scalar_only<T>::result & +as_scalar(const T& x) + { + return x; + } + + + +//! @}