Chris@49: // Copyright (C) 2008-2012 NICTA (www.nicta.com.au) Chris@49: // Copyright (C) 2008-2012 Conrad Sanderson Chris@49: // Chris@49: // This Source Code Form is subject to the terms of the Mozilla Public Chris@49: // License, v. 2.0. If a copy of the MPL was not distributed with this Chris@49: // file, You can obtain one at http://mozilla.org/MPL/2.0/. Chris@49: Chris@49: Chris@49: //! \addtogroup diagmat_proxy Chris@49: //! @{ Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy_default Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef typename T1::elem_type elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: inline Chris@49: diagmat_proxy_default(const T1& X) Chris@49: : P ( X ) Chris@49: , P_is_vec( (resolves_to_vector::value) || (P.get_n_rows() == 1) || (P.get_n_cols() == 1) ) Chris@49: , P_is_col( T1::is_col || (P.get_n_cols() == 1) ) Chris@49: , n_elem ( P_is_vec ? P.get_n_elem() : (std::min)(P.get_n_elem(), P.get_n_rows()) ) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: arma_debug_check Chris@49: ( Chris@49: (P_is_vec == false) && (P.get_n_rows() != P.get_n_cols()), Chris@49: "diagmat(): only vectors and square matrices are accepted" Chris@49: ); Chris@49: } Chris@49: Chris@49: Chris@49: arma_inline Chris@49: elem_type Chris@49: operator[](const uword i) const Chris@49: { Chris@49: if(Proxy::prefer_at_accessor == false) Chris@49: { Chris@49: return P_is_vec ? P[i] : P.at(i,i); Chris@49: } Chris@49: else Chris@49: { Chris@49: if(P_is_vec) Chris@49: { Chris@49: return (P_is_col) ? P.at(i,0) : P.at(0,i); Chris@49: } Chris@49: else Chris@49: { Chris@49: return P.at(i,i); Chris@49: } Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: arma_inline Chris@49: elem_type Chris@49: at(const uword row, const uword col) const Chris@49: { Chris@49: if(row == col) Chris@49: { Chris@49: if(Proxy::prefer_at_accessor == false) Chris@49: { Chris@49: return (P_is_vec) ? P[row] : P.at(row,row); Chris@49: } Chris@49: else Chris@49: { Chris@49: if(P_is_vec) Chris@49: { Chris@49: return (P_is_col) ? P.at(row,0) : P.at(0,row); Chris@49: } Chris@49: else Chris@49: { Chris@49: return P.at(row,row); Chris@49: } Chris@49: } Chris@49: } Chris@49: else Chris@49: { Chris@49: return elem_type(0); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: const Proxy P; Chris@49: const bool P_is_vec; Chris@49: const bool P_is_col; Chris@49: const uword n_elem; Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy_fixed Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef typename T1::elem_type elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: inline Chris@49: diagmat_proxy_fixed(const T1& X) Chris@49: : P(X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: arma_debug_check Chris@49: ( Chris@49: (P_is_vec == false) && (T1::n_rows != T1::n_cols), Chris@49: "diagmat(): only vectors and square matrices are accepted" Chris@49: ); Chris@49: } Chris@49: Chris@49: Chris@49: arma_inline Chris@49: elem_type Chris@49: operator[](const uword i) const Chris@49: { Chris@49: return (P_is_vec) ? P[i] : P.at(i,i); Chris@49: } Chris@49: Chris@49: Chris@49: arma_inline Chris@49: elem_type Chris@49: at(const uword row, const uword col) const Chris@49: { Chris@49: if(row == col) Chris@49: { Chris@49: return (P_is_vec) ? P[row] : P.at(row,row); Chris@49: } Chris@49: else Chris@49: { Chris@49: return elem_type(0); Chris@49: } Chris@49: } Chris@49: Chris@49: const T1& P; Chris@49: Chris@49: static const bool P_is_vec = (T1::n_rows == 1) || (T1::n_cols == 1); Chris@49: static const uword n_elem = P_is_vec ? T1::n_elem : ( (T1::n_elem < T1::n_rows) ? T1::n_elem : T1::n_rows ); Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: struct diagmat_proxy_redirect {}; Chris@49: Chris@49: template Chris@49: struct diagmat_proxy_redirect { typedef diagmat_proxy_default result; }; Chris@49: Chris@49: template Chris@49: struct diagmat_proxy_redirect { typedef diagmat_proxy_fixed result; }; Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy : public diagmat_proxy_redirect::value >::result Chris@49: { Chris@49: public: Chris@49: inline diagmat_proxy(const T1& X) Chris@49: : diagmat_proxy_redirect< T1, is_Mat_fixed::value >::result(X) Chris@49: { Chris@49: } Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy< Mat > Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef eT elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: inline Chris@49: diagmat_proxy(const Mat& X) Chris@49: : P ( X ) Chris@49: , P_is_vec( (X.n_rows == 1) || (X.n_cols == 1) ) Chris@49: , n_elem ( P_is_vec ? X.n_elem : (std::min)(X.n_elem, X.n_rows) ) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: arma_debug_check Chris@49: ( Chris@49: (P_is_vec == false) && (P.n_rows != P.n_cols), Chris@49: "diagmat(): only vectors and square matrices are accepted" Chris@49: ); Chris@49: } Chris@49: Chris@49: arma_inline elem_type operator[] (const uword i) const { return P_is_vec ? P[i] : P.at(i,i); } Chris@49: arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? ( P_is_vec ? P[row] : P.at(row,row) ) : elem_type(0); } Chris@49: Chris@49: const Mat& P; Chris@49: const bool P_is_vec; Chris@49: const uword n_elem; Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy< Row > Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef eT elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: Chris@49: inline Chris@49: diagmat_proxy(const Row& X) Chris@49: : P(X) Chris@49: , n_elem(X.n_elem) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: } Chris@49: Chris@49: arma_inline elem_type operator[] (const uword i) const { return P[i]; } Chris@49: arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } Chris@49: Chris@49: static const bool P_is_vec = true; Chris@49: Chris@49: const Row& P; Chris@49: const uword n_elem; Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy< Col > Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef eT elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: Chris@49: inline Chris@49: diagmat_proxy(const Col& X) Chris@49: : P(X) Chris@49: , n_elem(X.n_elem) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: } Chris@49: Chris@49: arma_inline elem_type operator[] (const uword i) const { return P[i]; } Chris@49: arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } Chris@49: Chris@49: static const bool P_is_vec = true; Chris@49: Chris@49: const Col& P; Chris@49: const uword n_elem; Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy< subview_row > Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef eT elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: Chris@49: inline Chris@49: diagmat_proxy(const subview_row& X) Chris@49: : P(X) Chris@49: , n_elem(X.n_elem) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: } Chris@49: Chris@49: arma_inline elem_type operator[] (const uword i) const { return P[i]; } Chris@49: arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } Chris@49: Chris@49: static const bool P_is_vec = true; Chris@49: Chris@49: const subview_row& P; Chris@49: const uword n_elem; Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy< subview_col > Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef eT elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: Chris@49: inline Chris@49: diagmat_proxy(const subview_col& X) Chris@49: : P(X) Chris@49: , n_elem(X.n_elem) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: } Chris@49: Chris@49: arma_inline elem_type operator[] (const uword i) const { return P[i]; } Chris@49: arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } Chris@49: Chris@49: static const bool P_is_vec = true; Chris@49: Chris@49: const subview_col& P; Chris@49: const uword n_elem; Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: // Chris@49: // Chris@49: // Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy_check_default Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef typename T1::elem_type elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: inline Chris@49: diagmat_proxy_check_default(const T1& X, const Mat&) Chris@49: : P(X) Chris@49: , P_is_vec( (resolves_to_vector::value) || (P.n_rows == 1) || (P.n_cols == 1) ) Chris@49: , n_elem( P_is_vec ? P.n_elem : (std::min)(P.n_elem, P.n_rows) ) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: arma_debug_check Chris@49: ( Chris@49: (P_is_vec == false) && (P.n_rows != P.n_cols), Chris@49: "diagmat(): only vectors and square matrices are accepted" Chris@49: ); Chris@49: } Chris@49: Chris@49: arma_inline elem_type operator[] (const uword i) const { return P_is_vec ? P[i] : P.at(i,i); } Chris@49: arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? ( P_is_vec ? P[row] : P.at(row,row) ) : elem_type(0); } Chris@49: Chris@49: const Mat P; Chris@49: const bool P_is_vec; Chris@49: const uword n_elem; Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy_check_fixed Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: typedef typename T1::elem_type elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: inline Chris@49: diagmat_proxy_check_fixed(const T1& X, const Mat& out) Chris@49: : P( const_cast(X.memptr()), T1::n_rows, T1::n_cols, (&X == &out), false ) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: arma_debug_check Chris@49: ( Chris@49: (P_is_vec == false) && (T1::n_rows != T1::n_cols), Chris@49: "diagmat(): only vectors and square matrices are accepted" Chris@49: ); Chris@49: } Chris@49: Chris@49: Chris@49: arma_inline eT operator[] (const uword i) const { return P_is_vec ? P[i] : P.at(i,i); } Chris@49: arma_inline eT at (const uword row, const uword col) const { return (row == col) ? ( P_is_vec ? P[row] : P.at(row,row) ) : elem_type(0); } Chris@49: Chris@49: const Mat P; // TODO: why not just store X directly as T1& ? test with fixed size vectors and matrices Chris@49: Chris@49: static const bool P_is_vec = (T1::n_rows == 1) || (T1::n_cols == 1); Chris@49: static const uword n_elem = P_is_vec ? T1::n_elem : ( (T1::n_elem < T1::n_rows) ? T1::n_elem : T1::n_rows ); Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: struct diagmat_proxy_check_redirect {}; Chris@49: Chris@49: template Chris@49: struct diagmat_proxy_check_redirect { typedef diagmat_proxy_check_default result; }; Chris@49: Chris@49: template Chris@49: struct diagmat_proxy_check_redirect { typedef diagmat_proxy_check_fixed result; }; Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy_check : public diagmat_proxy_check_redirect::value >::result Chris@49: { Chris@49: public: Chris@49: inline diagmat_proxy_check(const T1& X, const Mat& out) Chris@49: : diagmat_proxy_check_redirect< T1, is_Mat_fixed::value >::result(X, out) Chris@49: { Chris@49: } Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy_check< Mat > Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef eT elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: Chris@49: inline Chris@49: diagmat_proxy_check(const Mat& X, const Mat& out) Chris@49: : P_local ( (&X == &out) ? new Mat(X) : 0 ) Chris@49: , P ( (&X == &out) ? (*P_local) : X ) Chris@49: , P_is_vec( (P.n_rows == 1) || (P.n_cols == 1) ) Chris@49: , n_elem ( P_is_vec ? P.n_elem : (std::min)(P.n_elem, P.n_rows) ) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: arma_debug_check Chris@49: ( Chris@49: (P_is_vec == false) && (P.n_rows != P.n_cols), Chris@49: "diagmat(): only vectors and square matrices are accepted" Chris@49: ); Chris@49: } Chris@49: Chris@49: inline ~diagmat_proxy_check() Chris@49: { Chris@49: if(P_local) { delete P_local; } Chris@49: } Chris@49: Chris@49: arma_inline elem_type operator[] (const uword i) const { return P_is_vec ? P[i] : P.at(i,i); } Chris@49: arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? ( P_is_vec ? P[row] : P.at(row,row) ) : elem_type(0); } Chris@49: Chris@49: const Mat* P_local; Chris@49: const Mat& P; Chris@49: const bool P_is_vec; Chris@49: const uword n_elem; Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy_check< Row > Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef eT elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: inline Chris@49: diagmat_proxy_check(const Row& X, const Mat& out) Chris@49: : P_local ( (&X == reinterpret_cast*>(&out)) ? new Row(X) : 0 ) Chris@49: , P ( (&X == reinterpret_cast*>(&out)) ? (*P_local) : X ) Chris@49: , n_elem (X.n_elem) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: } Chris@49: Chris@49: inline ~diagmat_proxy_check() Chris@49: { Chris@49: if(P_local) { delete P_local; } Chris@49: } Chris@49: Chris@49: arma_inline elem_type operator[] (const uword i) const { return P[i]; } Chris@49: arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } Chris@49: Chris@49: static const bool P_is_vec = true; Chris@49: Chris@49: const Row* P_local; Chris@49: const Row& P; Chris@49: const uword n_elem; Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy_check< Col > Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef eT elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: inline Chris@49: diagmat_proxy_check(const Col& X, const Mat& out) Chris@49: : P_local ( (&X == reinterpret_cast*>(&out)) ? new Col(X) : 0 ) Chris@49: , P ( (&X == reinterpret_cast*>(&out)) ? (*P_local) : X ) Chris@49: , n_elem (X.n_elem) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: } Chris@49: Chris@49: inline ~diagmat_proxy_check() Chris@49: { Chris@49: if(P_local) { delete P_local; } Chris@49: } Chris@49: Chris@49: arma_inline elem_type operator[] (const uword i) const { return P[i]; } Chris@49: arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } Chris@49: Chris@49: static const bool P_is_vec = true; Chris@49: Chris@49: const Col* P_local; Chris@49: const Col& P; Chris@49: const uword n_elem; Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy_check< subview_row > Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef eT elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: inline Chris@49: diagmat_proxy_check(const subview_row& X, const Mat&) Chris@49: : P ( X ) Chris@49: , n_elem ( X.n_elem ) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: } Chris@49: Chris@49: arma_inline elem_type operator[] (const uword i) const { return P[i]; } Chris@49: arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } Chris@49: Chris@49: static const bool P_is_vec = true; Chris@49: Chris@49: const Row P; Chris@49: const uword n_elem; Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: class diagmat_proxy_check< subview_col > Chris@49: { Chris@49: public: Chris@49: Chris@49: typedef eT elem_type; Chris@49: typedef typename get_pod_type::result pod_type; Chris@49: Chris@49: inline Chris@49: diagmat_proxy_check(const subview_col& X, const Mat& out) Chris@49: : P ( const_cast(X.colptr(0)), X.n_rows, (&(X.m) == &out), false ) Chris@49: , n_elem( X.n_elem ) Chris@49: //, X_ref ( X ) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: } Chris@49: Chris@49: arma_inline elem_type operator[] (const uword i) const { return P[i]; } Chris@49: arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } Chris@49: Chris@49: static const bool P_is_vec = true; Chris@49: Chris@49: const Col P; Chris@49: const uword n_elem; Chris@49: Chris@49: //const subview_col& X_ref; // prevents the compiler from potentially deleting X before we're done with it Chris@49: }; Chris@49: Chris@49: Chris@49: Chris@49: //! @}