annotate armadillo-3.900.4/include/armadillo_bits/fn_as_scalar.hpp @ 84:55a047986812 tip

Update library URI so as not to be document-local
author Chris Cannam
date Wed, 22 Apr 2020 14:21:57 +0100
parents 1ec0e2823891
children
rev   line source
Chris@49 1 // Copyright (C) 2010-2013 NICTA (www.nicta.com.au)
Chris@49 2 // Copyright (C) 2010-2013 Conrad Sanderson
Chris@49 3 //
Chris@49 4 // This Source Code Form is subject to the terms of the Mozilla Public
Chris@49 5 // License, v. 2.0. If a copy of the MPL was not distributed with this
Chris@49 6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
Chris@49 7
Chris@49 8
Chris@49 9 //! \addtogroup fn_as_scalar
Chris@49 10 //! @{
Chris@49 11
Chris@49 12
Chris@49 13
Chris@49 14 template<uword N>
Chris@49 15 struct as_scalar_redirect
Chris@49 16 {
Chris@49 17 template<typename T1>
Chris@49 18 inline static typename T1::elem_type apply(const T1& X);
Chris@49 19 };
Chris@49 20
Chris@49 21
Chris@49 22
Chris@49 23 template<>
Chris@49 24 struct as_scalar_redirect<2>
Chris@49 25 {
Chris@49 26 template<typename T1, typename T2>
Chris@49 27 inline static typename T1::elem_type apply(const Glue<T1,T2,glue_times>& X);
Chris@49 28 };
Chris@49 29
Chris@49 30
Chris@49 31 template<>
Chris@49 32 struct as_scalar_redirect<3>
Chris@49 33 {
Chris@49 34 template<typename T1, typename T2, typename T3>
Chris@49 35 inline static typename T1::elem_type apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times>& X);
Chris@49 36 };
Chris@49 37
Chris@49 38
Chris@49 39
Chris@49 40 template<uword N>
Chris@49 41 template<typename T1>
Chris@49 42 inline
Chris@49 43 typename T1::elem_type
Chris@49 44 as_scalar_redirect<N>::apply(const T1& X)
Chris@49 45 {
Chris@49 46 arma_extra_debug_sigprint();
Chris@49 47
Chris@49 48 // typedef typename T1::elem_type eT;
Chris@49 49
Chris@49 50 const Proxy<T1> P(X);
Chris@49 51
Chris@49 52 arma_debug_check( (P.get_n_elem() != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
Chris@49 53
Chris@49 54 return (Proxy<T1>::prefer_at_accessor == true) ? P.at(0,0) : P[0];
Chris@49 55 }
Chris@49 56
Chris@49 57
Chris@49 58
Chris@49 59 template<typename T1, typename T2>
Chris@49 60 inline
Chris@49 61 typename T1::elem_type
Chris@49 62 as_scalar_redirect<2>::apply(const Glue<T1, T2, glue_times>& X)
Chris@49 63 {
Chris@49 64 arma_extra_debug_sigprint();
Chris@49 65
Chris@49 66 typedef typename T1::elem_type eT;
Chris@49 67
Chris@49 68 // T1 must result in a matrix with one row
Chris@49 69 // T2 must result in a matrix with one column
Chris@49 70
Chris@49 71 const bool has_all_mat = is_Mat<T1>::value && is_Mat<T2>::value;
Chris@49 72 const bool prefer_at_accessor = Proxy<T1>::prefer_at_accessor || Proxy<T2>::prefer_at_accessor;
Chris@49 73
Chris@49 74 const bool do_partial_unwrap = has_all_mat || prefer_at_accessor;
Chris@49 75
Chris@49 76 if(do_partial_unwrap == true)
Chris@49 77 {
Chris@49 78 const partial_unwrap<T1> tmp1(X.A);
Chris@49 79 const partial_unwrap<T2> tmp2(X.B);
Chris@49 80
Chris@49 81 typedef typename partial_unwrap<T1>::stored_type TA;
Chris@49 82 typedef typename partial_unwrap<T2>::stored_type TB;
Chris@49 83
Chris@49 84 const TA& A = tmp1.M;
Chris@49 85 const TB& B = tmp2.M;
Chris@49 86
Chris@49 87 const uword A_n_rows = (tmp1.do_trans == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols);
Chris@49 88 const uword A_n_cols = (tmp1.do_trans == false) ? (TA::is_col ? 1 : A.n_cols) : (TA::is_row ? 1 : A.n_rows);
Chris@49 89
Chris@49 90 const uword B_n_rows = (tmp2.do_trans == false) ? (TB::is_row ? 1 : B.n_rows) : (TB::is_col ? 1 : B.n_cols);
Chris@49 91 const uword B_n_cols = (tmp2.do_trans == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows);
Chris@49 92
Chris@49 93 arma_debug_check( (A_n_rows != 1) || (B_n_cols != 1) || (A_n_cols != B_n_rows), "as_scalar(): incompatible dimensions" );
Chris@49 94
Chris@49 95 const eT val = op_dot::direct_dot(A.n_elem, A.memptr(), B.memptr());
Chris@49 96
Chris@49 97 return (tmp1.do_times || tmp2.do_times) ? (val * tmp1.get_val() * tmp2.get_val()) : val;
Chris@49 98 }
Chris@49 99 else
Chris@49 100 {
Chris@49 101 const Proxy<T1> PA(X.A);
Chris@49 102 const Proxy<T2> PB(X.B);
Chris@49 103
Chris@49 104 arma_debug_check
Chris@49 105 (
Chris@49 106 (PA.get_n_rows() != 1) || (PB.get_n_cols() != 1) || (PA.get_n_cols() != PB.get_n_rows()),
Chris@49 107 "as_scalar(): incompatible dimensions"
Chris@49 108 );
Chris@49 109
Chris@49 110 return op_dot::apply_proxy(PA,PB);
Chris@49 111 }
Chris@49 112 }
Chris@49 113
Chris@49 114
Chris@49 115
Chris@49 116 template<typename T1, typename T2, typename T3>
Chris@49 117 inline
Chris@49 118 typename T1::elem_type
Chris@49 119 as_scalar_redirect<3>::apply(const Glue< Glue<T1, T2, glue_times>, T3, glue_times >& X)
Chris@49 120 {
Chris@49 121 arma_extra_debug_sigprint();
Chris@49 122
Chris@49 123 typedef typename T1::elem_type eT;
Chris@49 124
Chris@49 125 // T1 * T2 must result in a matrix with one row
Chris@49 126 // T3 must result in a matrix with one column
Chris@49 127
Chris@49 128 typedef typename strip_inv <T2 >::stored_type T2_stripped_1;
Chris@49 129 typedef typename strip_diagmat<T2_stripped_1>::stored_type T2_stripped_2;
Chris@49 130
Chris@49 131 const strip_inv <T2> strip1(X.A.B);
Chris@49 132 const strip_diagmat<T2_stripped_1> strip2(strip1.M);
Chris@49 133
Chris@49 134 const bool tmp2_do_inv = strip1.do_inv;
Chris@49 135 const bool tmp2_do_diagmat = strip2.do_diagmat;
Chris@49 136
Chris@49 137 if(tmp2_do_diagmat == false)
Chris@49 138 {
Chris@49 139 const Mat<eT> tmp(X);
Chris@49 140
Chris@49 141 arma_debug_check( (tmp.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
Chris@49 142
Chris@49 143 return tmp[0];
Chris@49 144 }
Chris@49 145 else
Chris@49 146 {
Chris@49 147 const partial_unwrap<T1> tmp1(X.A.A);
Chris@49 148 const partial_unwrap<T2_stripped_2> tmp2(strip2.M);
Chris@49 149 const partial_unwrap<T3> tmp3(X.B);
Chris@49 150
Chris@49 151 const Mat<eT>& A = tmp1.M;
Chris@49 152 const Mat<eT>& B = tmp2.M;
Chris@49 153 const Mat<eT>& C = tmp3.M;
Chris@49 154
Chris@49 155 const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
Chris@49 156 const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
Chris@49 157
Chris@49 158 const bool B_is_vec = B.is_vec();
Chris@49 159
Chris@49 160 const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
Chris@49 161 const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
Chris@49 162
Chris@49 163 const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
Chris@49 164 const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
Chris@49 165
Chris@49 166 const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val();
Chris@49 167
Chris@49 168 arma_debug_check
Chris@49 169 (
Chris@49 170 (A_n_rows != 1) ||
Chris@49 171 (C_n_cols != 1) ||
Chris@49 172 (A_n_cols != B_n_rows) ||
Chris@49 173 (B_n_cols != C_n_rows)
Chris@49 174 ,
Chris@49 175 "as_scalar(): incompatible dimensions"
Chris@49 176 );
Chris@49 177
Chris@49 178
Chris@49 179 if(B_is_vec == true)
Chris@49 180 {
Chris@49 181 if(tmp2_do_inv == true)
Chris@49 182 {
Chris@49 183 return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem);
Chris@49 184 }
Chris@49 185 else
Chris@49 186 {
Chris@49 187 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
Chris@49 188 }
Chris@49 189 }
Chris@49 190 else
Chris@49 191 {
Chris@49 192 if(tmp2_do_inv == true)
Chris@49 193 {
Chris@49 194 return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem);
Chris@49 195 }
Chris@49 196 else
Chris@49 197 {
Chris@49 198 return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
Chris@49 199 }
Chris@49 200 }
Chris@49 201 }
Chris@49 202 }
Chris@49 203
Chris@49 204
Chris@49 205
Chris@49 206 template<typename T1>
Chris@49 207 inline
Chris@49 208 typename T1::elem_type
Chris@49 209 as_scalar_diag(const Base<typename T1::elem_type,T1>& X)
Chris@49 210 {
Chris@49 211 arma_extra_debug_sigprint();
Chris@49 212
Chris@49 213 typedef typename T1::elem_type eT;
Chris@49 214
Chris@49 215 const unwrap<T1> tmp(X.get_ref());
Chris@49 216 const Mat<eT>& A = tmp.M;
Chris@49 217
Chris@49 218 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
Chris@49 219
Chris@49 220 return A.mem[0];
Chris@49 221 }
Chris@49 222
Chris@49 223
Chris@49 224
Chris@49 225 template<typename T1, typename T2, typename T3>
Chris@49 226 inline
Chris@49 227 typename T1::elem_type
Chris@49 228 as_scalar_diag(const Glue< Glue<T1, T2, glue_times_diag>, T3, glue_times >& X)
Chris@49 229 {
Chris@49 230 arma_extra_debug_sigprint();
Chris@49 231
Chris@49 232 typedef typename T1::elem_type eT;
Chris@49 233
Chris@49 234 // T1 * T2 must result in a matrix with one row
Chris@49 235 // T3 must result in a matrix with one column
Chris@49 236
Chris@49 237 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
Chris@49 238
Chris@49 239 const strip_diagmat<T2> strip(X.A.B);
Chris@49 240
Chris@49 241 const partial_unwrap<T1> tmp1(X.A.A);
Chris@49 242 const partial_unwrap<T2_stripped> tmp2(strip.M);
Chris@49 243 const partial_unwrap<T3> tmp3(X.B);
Chris@49 244
Chris@49 245 const Mat<eT>& A = tmp1.M;
Chris@49 246 const Mat<eT>& B = tmp2.M;
Chris@49 247 const Mat<eT>& C = tmp3.M;
Chris@49 248
Chris@49 249
Chris@49 250 const uword A_n_rows = (tmp1.do_trans == false) ? A.n_rows : A.n_cols;
Chris@49 251 const uword A_n_cols = (tmp1.do_trans == false) ? A.n_cols : A.n_rows;
Chris@49 252
Chris@49 253 const bool B_is_vec = B.is_vec();
Chris@49 254
Chris@49 255 const uword B_n_rows = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_rows : B.n_cols );
Chris@49 256 const uword B_n_cols = (B_is_vec == true) ? B.n_elem : ( (tmp2.do_trans == false) ? B.n_cols : B.n_rows );
Chris@49 257
Chris@49 258 const uword C_n_rows = (tmp3.do_trans == false) ? C.n_rows : C.n_cols;
Chris@49 259 const uword C_n_cols = (tmp3.do_trans == false) ? C.n_cols : C.n_rows;
Chris@49 260
Chris@49 261 const eT val = tmp1.get_val() * tmp2.get_val() * tmp3.get_val();
Chris@49 262
Chris@49 263 arma_debug_check
Chris@49 264 (
Chris@49 265 (A_n_rows != 1) ||
Chris@49 266 (C_n_cols != 1) ||
Chris@49 267 (A_n_cols != B_n_rows) ||
Chris@49 268 (B_n_cols != C_n_rows)
Chris@49 269 ,
Chris@49 270 "as_scalar(): incompatible dimensions"
Chris@49 271 );
Chris@49 272
Chris@49 273
Chris@49 274 if(B_is_vec == true)
Chris@49 275 {
Chris@49 276 return val * op_dot::direct_dot(A.n_elem, A.mem, B.mem, C.mem);
Chris@49 277 }
Chris@49 278 else
Chris@49 279 {
Chris@49 280 return val * op_dotext::direct_rowvec_diagmat_colvec(A.mem, B, C.mem);
Chris@49 281 }
Chris@49 282 }
Chris@49 283
Chris@49 284
Chris@49 285
Chris@49 286 template<typename T1, typename T2>
Chris@49 287 arma_inline
Chris@49 288 arma_warn_unused
Chris@49 289 typename T1::elem_type
Chris@49 290 as_scalar(const Glue<T1, T2, glue_times>& X, const typename arma_not_cx<typename T1::elem_type>::result* junk = 0)
Chris@49 291 {
Chris@49 292 arma_extra_debug_sigprint();
Chris@49 293 arma_ignore(junk);
Chris@49 294
Chris@49 295 if(is_glue_times_diag<T1>::value == false)
Chris@49 296 {
Chris@49 297 const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
Chris@49 298
Chris@49 299 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
Chris@49 300
Chris@49 301 return as_scalar_redirect<N_mat>::apply(X);
Chris@49 302 }
Chris@49 303 else
Chris@49 304 {
Chris@49 305 return as_scalar_diag(X);
Chris@49 306 }
Chris@49 307 }
Chris@49 308
Chris@49 309
Chris@49 310
Chris@49 311 template<typename T1>
Chris@49 312 inline
Chris@49 313 arma_warn_unused
Chris@49 314 typename T1::elem_type
Chris@49 315 as_scalar(const Base<typename T1::elem_type,T1>& X)
Chris@49 316 {
Chris@49 317 arma_extra_debug_sigprint();
Chris@49 318
Chris@49 319 // typedef typename T1::elem_type eT;
Chris@49 320
Chris@49 321 const Proxy<T1> P(X.get_ref());
Chris@49 322
Chris@49 323 arma_debug_check( (P.get_n_elem() != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
Chris@49 324
Chris@49 325 return (Proxy<T1>::prefer_at_accessor == true) ? P.at(0,0) : P[0];
Chris@49 326 }
Chris@49 327
Chris@49 328
Chris@49 329 // ensure the following two functions are aware of each other
Chris@49 330 template<typename T1, typename eop_type> inline arma_warn_unused typename T1::elem_type as_scalar(const eOp<T1, eop_type>& X);
Chris@49 331 template<typename T1, typename T2, typename eglue_type> inline arma_warn_unused typename T1::elem_type as_scalar(const eGlue<T1, T2, eglue_type>& X);
Chris@49 332
Chris@49 333
Chris@49 334
Chris@49 335 template<typename T1, typename eop_type>
Chris@49 336 inline
Chris@49 337 arma_warn_unused
Chris@49 338 typename T1::elem_type
Chris@49 339 as_scalar(const eOp<T1, eop_type>& X)
Chris@49 340 {
Chris@49 341 arma_extra_debug_sigprint();
Chris@49 342
Chris@49 343 typedef typename T1::elem_type eT;
Chris@49 344
Chris@49 345 const eT val = as_scalar(X.P.Q);
Chris@49 346
Chris@49 347 return eop_core<eop_type>::process(val, X.aux);
Chris@49 348 }
Chris@49 349
Chris@49 350
Chris@49 351
Chris@49 352 template<typename T1, typename T2, typename eglue_type>
Chris@49 353 inline
Chris@49 354 arma_warn_unused
Chris@49 355 typename T1::elem_type
Chris@49 356 as_scalar(const eGlue<T1, T2, eglue_type>& X)
Chris@49 357 {
Chris@49 358 arma_extra_debug_sigprint();
Chris@49 359
Chris@49 360 typedef typename T1::elem_type eT;
Chris@49 361
Chris@49 362 const eT a = as_scalar(X.P1.Q);
Chris@49 363 const eT b = as_scalar(X.P2.Q);
Chris@49 364
Chris@49 365 // the optimiser will keep only one return statement
Chris@49 366
Chris@49 367 if(is_same_type<eglue_type, eglue_plus >::value == true) { return a + b; }
Chris@49 368 else if(is_same_type<eglue_type, eglue_minus>::value == true) { return a - b; }
Chris@49 369 else if(is_same_type<eglue_type, eglue_div >::value == true) { return a / b; }
Chris@49 370 else if(is_same_type<eglue_type, eglue_schur>::value == true) { return a * b; }
Chris@49 371 }
Chris@49 372
Chris@49 373
Chris@49 374
Chris@49 375 template<typename T1>
Chris@49 376 inline
Chris@49 377 arma_warn_unused
Chris@49 378 typename T1::elem_type
Chris@49 379 as_scalar(const BaseCube<typename T1::elem_type,T1>& X)
Chris@49 380 {
Chris@49 381 arma_extra_debug_sigprint();
Chris@49 382
Chris@49 383 // typedef typename T1::elem_type eT;
Chris@49 384
Chris@49 385 const ProxyCube<T1> P(X.get_ref());
Chris@49 386
Chris@49 387 arma_debug_check( (P.get_n_elem() != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
Chris@49 388
Chris@49 389 return (ProxyCube<T1>::prefer_at_accessor == true) ? P.at(0,0,0) : P[0];
Chris@49 390 }
Chris@49 391
Chris@49 392
Chris@49 393
Chris@49 394 template<typename T>
Chris@49 395 arma_inline
Chris@49 396 arma_warn_unused
Chris@49 397 const typename arma_scalar_only<T>::result &
Chris@49 398 as_scalar(const T& x)
Chris@49 399 {
Chris@49 400 return x;
Chris@49 401 }
Chris@49 402
Chris@49 403
Chris@49 404
Chris@49 405 template<typename T1>
Chris@49 406 arma_inline
Chris@49 407 arma_warn_unused
Chris@49 408 typename T1::elem_type
Chris@49 409 as_scalar(const SpBase<typename T1::elem_type, T1>& X)
Chris@49 410 {
Chris@49 411 typedef typename T1::elem_type eT;
Chris@49 412
Chris@49 413 const unwrap_spmat<T1> tmp(X.get_ref());
Chris@49 414 const SpMat<eT>& A = tmp.M;
Chris@49 415
Chris@49 416 arma_debug_check( (A.n_elem != 1), "as_scalar(): expression doesn't evaluate to exactly one element" );
Chris@49 417
Chris@49 418 return A.at(0,0);
Chris@49 419 }
Chris@49 420
Chris@49 421
Chris@49 422
Chris@49 423 //! @}