annotate armadillo-2.4.4/include/armadillo_bits/fn_as_scalar.hpp @ 18:8d046a9d36aa slimline

Back out rev 13:ac07c60aa798. Like an idiot, I committed a whole pile of unrelated changes in the guise of a single typo fix. Will re-commit in stages
author Chris Cannam
date Thu, 10 May 2012 10:45:44 +0100
parents 8b6102e2a9b0
children
rev   line source
max@0 1 // Copyright (C) 2010-2011 NICTA (www.nicta.com.au)
max@0 2 // Copyright (C) 2010-2011 Conrad Sanderson
max@0 3 //
max@0 4 // This file is part of the Armadillo C++ library.
max@0 5 // It is provided without any warranty of fitness
max@0 6 // for any purpose. You can redistribute this file
max@0 7 // and/or modify it under the terms of the GNU
max@0 8 // Lesser General Public License (LGPL) as published
max@0 9 // by the Free Software Foundation, either version 3
max@0 10 // of the License or (at your option) any later version.
max@0 11 // (see http://www.opensource.org/licenses for more info)
max@0 12
max@0 13
max@0 14 //! \addtogroup fn_as_scalar
max@0 15 //! @{
max@0 16
max@0 17
max@0 18
max@0 19 template<uword N>
max@0 20 struct as_scalar_redirect
max@0 21 {
max@0 22 template<typename T1>
max@0 23 inline static typename T1::elem_type apply(const T1& X);
max@0 24 };
max@0 25
max@0 26
max@0 27
max@0 28 template<>
max@0 29 struct as_scalar_redirect<2>
max@0 30 {
max@0 31 template<typename T1, typename T2>
max@0 32 inline static typename T1::elem_type apply(const Glue<T1,T2,glue_times>& X);
max@0 33 };
max@0 34
max@0 35
max@0 36 template<>
max@0 37 struct as_scalar_redirect<3>
max@0 38 {
max@0 39 template<typename T1, typename T2, typename T3>
max@0 40 inline static typename T1::elem_type apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times>& X);
max@0 41 };
max@0 42
max@0 43
max@0 44
max@0 45 template<uword N>
max@0 46 template<typename T1>
max@0 47 inline
max@0 48 typename T1::elem_type
max@0 49 as_scalar_redirect<N>::apply(const T1& X)
max@0 50 {
max@0 51 arma_extra_debug_sigprint();
max@0 52
max@0 53 typedef typename T1::elem_type eT;
max@0 54
max@0 55 const unwrap<T1> tmp(X);
max@0 56 const Mat<eT>& A = tmp.M;
max@0 57
max@0 58 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
max@0 59
max@0 60 return A.mem[0];
max@0 61 }
max@0 62
max@0 63
max@0 64
max@0 65 template<typename T1, typename T2>
max@0 66 inline
max@0 67 typename T1::elem_type
max@0 68 as_scalar_redirect<2>::apply(const Glue<T1, T2, glue_times>& X)
max@0 69 {
max@0 70 arma_extra_debug_sigprint();
max@0 71
max@0 72 typedef typename T1::elem_type eT;
max@0 73
max@0 74 // T1 must result in a matrix with one row
max@0 75 // T2 must result in a matrix with one column
max@0 76
max@0 77 const partial_unwrap<T1> tmp1(X.A);
max@0 78 const partial_unwrap<T2> tmp2(X.B);
max@0 79
max@0 80 const Mat<eT>& A = tmp1.M;
max@0 81 const Mat<eT>& B = tmp2.M;
max@0 82
max@0 83 const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
max@0 84 const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
max@0 85
max@0 86 const uword B_n_rows = (tmp2.do_trans == false) ? B.n_rows : B.n_cols;
max@0 87 const uword B_n_cols = (tmp2.do_trans == false) ? B.n_cols : B.n_rows;
max@0 88
max@0 89 const eT val = tmp1.get_val() * tmp2.get_val();
max@0 90
max@0 91 arma_debug_check( (A_n_rows != 1) || (B_n_cols != 1) || (A_n_cols != B_n_rows), "as_scalar(): incompatible dimensions" );
max@0 92
max@0 93 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem);
max@0 94 }
max@0 95
max@0 96
max@0 97
max@0 98 template<typename T1, typename T2, typename T3>
max@0 99 inline
max@0 100 typename T1::elem_type
max@0 101 as_scalar_redirect<3>::apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times >& X)
max@0 102 {
max@0 103 arma_extra_debug_sigprint();
max@0 104
max@0 105 typedef typename T1::elem_type eT;
max@0 106
max@0 107 // T1 * T2 must result in a matrix with one row
max@0 108 // T3 must result in a matrix with one column
max@0 109
max@0 110 typedef typename strip_inv <T2 >::stored_type T2_stripped_1;
max@0 111 typedef typename strip_diagmat<T2_stripped_1>::stored_type T2_stripped_2;
max@0 112
max@0 113 const strip_inv <T2> strip1(X.A.B);
max@0 114 const strip_diagmat<T2_stripped_1> strip2(strip1.M);
max@0 115
max@0 116 const bool tmp2_do_inv = strip1.do_inv;
max@0 117 const bool tmp2_do_diagmat = strip2.do_diagmat;
max@0 118
max@0 119 if(tmp2_do_diagmat == false)
max@0 120 {
max@0 121 const Mat<eT> tmp(X);
max@0 122
max@0 123 arma_debug_check( (tmp.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
max@0 124
max@0 125 return tmp[0];
max@0 126 }
max@0 127 else
max@0 128 {
max@0 129 const partial_unwrap<T1> tmp1(X.A.A);
max@0 130 const partial_unwrap<T2_stripped_2> tmp2(strip2.M);
max@0 131 const partial_unwrap<T3> tmp3(X.B);
max@0 132
max@0 133 const Mat<eT>& A = tmp1.M;
max@0 134 const Mat<eT>& B = tmp2.M;
max@0 135 const Mat<eT>& C = tmp3.M;
max@0 136
max@0 137 const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
max@0 138 const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
max@0 139
max@0 140 const bool B_is_vec = B.is_vec();
max@0 141
max@0 142 const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
max@0 143 const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
max@0 144
max@0 145 const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
max@0 146 const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
max@0 147
max@0 148 const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val();
max@0 149
max@0 150 arma_debug_check
max@0 151 (
max@0 152 (A_n_rows != 1) ||
max@0 153 (C_n_cols != 1) ||
max@0 154 (A_n_cols != B_n_rows) ||
max@0 155 (B_n_cols != C_n_rows)
max@0 156 ,
max@0 157 "as_scalar(): incompatible dimensions"
max@0 158 );
max@0 159
max@0 160
max@0 161 if(B_is_vec == true)
max@0 162 {
max@0 163 if(tmp2_do_inv == true)
max@0 164 {
max@0 165 return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem);
max@0 166 }
max@0 167 else
max@0 168 {
max@0 169 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
max@0 170 }
max@0 171 }
max@0 172 else
max@0 173 {
max@0 174 if(tmp2_do_inv == true)
max@0 175 {
max@0 176 return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem);
max@0 177 }
max@0 178 else
max@0 179 {
max@0 180 return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
max@0 181 }
max@0 182 }
max@0 183 }
max@0 184 }
max@0 185
max@0 186
max@0 187
max@0 188 template<typename T1>
max@0 189 inline
max@0 190 typename T1::elem_type
max@0 191 as_scalar_diag(const Base<typename T1::elem_type,T1>& X)
max@0 192 {
max@0 193 arma_extra_debug_sigprint();
max@0 194
max@0 195 typedef typename T1::elem_type eT;
max@0 196
max@0 197 const unwrap<T1> tmp(X.get_ref());
max@0 198 const Mat<eT>& A = tmp.M;
max@0 199
max@0 200 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
max@0 201
max@0 202 return A.mem[0];
max@0 203 }
max@0 204
max@0 205
max@0 206
max@0 207 template<typename T1, typename T2, typename T3>
max@0 208 inline
max@0 209 typename T1::elem_type
max@0 210 as_scalar_diag(const Glue< Glue<T1, T2, glue_times_diag>, T3, glue_times >& X)
max@0 211 {
max@0 212 arma_extra_debug_sigprint();
max@0 213
max@0 214 typedef typename T1::elem_type eT;
max@0 215
max@0 216 // T1 * T2 must result in a matrix with one row
max@0 217 // T3 must result in a matrix with one column
max@0 218
max@0 219 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
max@0 220
max@0 221 const strip_diagmat<T2> strip(X.A.B);
max@0 222
max@0 223 const partial_unwrap<T1> tmp1(X.A.A);
max@0 224 const partial_unwrap<T2_stripped> tmp2(strip.M);
max@0 225 const partial_unwrap<T3> tmp3(X.B);
max@0 226
max@0 227 const Mat<eT>& A = tmp1.M;
max@0 228 const Mat<eT>& B = tmp2.M;
max@0 229 const Mat<eT>& C = tmp3.M;
max@0 230
max@0 231
max@0 232 const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
max@0 233 const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
max@0 234
max@0 235 const bool B_is_vec = B.is_vec();
max@0 236
max@0 237 const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
max@0 238 const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
max@0 239
max@0 240 const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
max@0 241 const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
max@0 242
max@0 243 const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val();
max@0 244
max@0 245 arma_debug_check
max@0 246 (
max@0 247 (A_n_rows != 1) ||
max@0 248 (C_n_cols != 1) ||
max@0 249 (A_n_cols != B_n_rows) ||
max@0 250 (B_n_cols != C_n_rows)
max@0 251 ,
max@0 252 "as_scalar(): incompatible dimensions"
max@0 253 );
max@0 254
max@0 255
max@0 256 if(B_is_vec == true)
max@0 257 {
max@0 258 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
max@0 259 }
max@0 260 else
max@0 261 {
max@0 262 return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
max@0 263 }
max@0 264 }
max@0 265
max@0 266
max@0 267
max@0 268 template<typename T1, typename T2>
max@0 269 arma_inline
max@0 270 arma_warn_unused
max@0 271 typename T1::elem_type
max@0 272 as_scalar(const Glue<T1, T2, glue_times>& X, const typename arma_not_cx<typename T1::elem_type>::result* junk = 0)
max@0 273 {
max@0 274 arma_extra_debug_sigprint();
max@0 275 arma_ignore(junk);
max@0 276
max@0 277 if(is_glue_times_diag<T1>::value == false)
max@0 278 {
max@0 279 const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
max@0 280
max@0 281 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
max@0 282
max@0 283 return as_scalar_redirect<N_mat>::apply(X);
max@0 284 }
max@0 285 else
max@0 286 {
max@0 287 return as_scalar_diag(X);
max@0 288 }
max@0 289 }
max@0 290
max@0 291
max@0 292
max@0 293 template<typename T1>
max@0 294 inline
max@0 295 arma_warn_unused
max@0 296 typename T1::elem_type
max@0 297 as_scalar(const Base<typename T1::elem_type,T1>& X)
max@0 298 {
max@0 299 arma_extra_debug_sigprint();
max@0 300
max@0 301 typedef typename T1::elem_type eT;
max@0 302
max@0 303 const unwrap<T1> tmp(X.get_ref());
max@0 304 const Mat<eT>& A = tmp.M;
max@0 305
max@0 306 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
max@0 307
max@0 308 return A.mem[0];
max@0 309 }
max@0 310
max@0 311
max@0 312
max@0 313 template<typename T1>
max@0 314 arma_inline
max@0 315 arma_warn_unused
max@0 316 typename T1::elem_type
max@0 317 as_scalar(const eOp<T1, eop_neg>& X)
max@0 318 {
max@0 319 arma_extra_debug_sigprint();
max@0 320
max@0 321 return -(as_scalar(X.P.Q));
max@0 322 }
max@0 323
max@0 324
max@0 325
max@0 326 template<typename T1>
max@0 327 inline
max@0 328 arma_warn_unused
max@0 329 typename T1::elem_type
max@0 330 as_scalar(const BaseCube<typename T1::elem_type,T1>& X)
max@0 331 {
max@0 332 arma_extra_debug_sigprint();
max@0 333
max@0 334 typedef typename T1::elem_type eT;
max@0 335
max@0 336 const unwrap_cube<T1> tmp(X.get_ref());
max@0 337 const Cube<eT>& A = tmp.M;
max@0 338
max@0 339 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
max@0 340
max@0 341 return A.mem[0];
max@0 342 }
max@0 343
max@0 344
max@0 345
max@0 346 template<typename T>
max@0 347 arma_inline
max@0 348 arma_warn_unused
max@0 349 const typename arma_scalar_only<T>::result &
max@0 350 as_scalar(const T& x)
max@0 351 {
max@0 352 return x;
max@0 353 }
max@0 354
max@0 355
max@0 356
max@0 357 //! @}