max@0
|
1 // Copyright (C) 2008-2010 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2008-2010 Conrad Sanderson
|
max@0
|
3 //
|
max@0
|
4 // This file is part of the Armadillo C++ library.
|
max@0
|
5 // It is provided without any warranty of fitness
|
max@0
|
6 // for any purpose. You can redistribute this file
|
max@0
|
7 // and/or modify it under the terms of the GNU
|
max@0
|
8 // Lesser General Public License (LGPL) as published
|
max@0
|
9 // by the Free Software Foundation, either version 3
|
max@0
|
10 // of the License or (at your option) any later version.
|
max@0
|
11 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
12
|
max@0
|
13
|
max@0
|
14 //! \addtogroup fn_trace
|
max@0
|
15 //! @{
|
max@0
|
16
|
max@0
|
17
|
max@0
|
18 //! Immediate trace (sum of diagonal elements) of a square dense matrix
|
max@0
|
19 template<typename T1>
|
max@0
|
20 inline
|
max@0
|
21 arma_warn_unused
|
max@0
|
22 typename T1::elem_type
|
max@0
|
23 trace(const Base<typename T1::elem_type,T1>& X)
|
max@0
|
24 {
|
max@0
|
25 arma_extra_debug_sigprint();
|
max@0
|
26
|
max@0
|
27 typedef typename T1::elem_type eT;
|
max@0
|
28
|
max@0
|
29 const Proxy<T1> A(X.get_ref());
|
max@0
|
30
|
max@0
|
31 arma_debug_check( (A.get_n_rows() != A.get_n_cols()), "trace(): matrix must be square sized" );
|
max@0
|
32
|
max@0
|
33 const uword N = A.get_n_rows();
|
max@0
|
34 eT val = eT(0);
|
max@0
|
35
|
max@0
|
36 for(uword i=0; i<N; ++i)
|
max@0
|
37 {
|
max@0
|
38 val += A.at(i,i);
|
max@0
|
39 }
|
max@0
|
40
|
max@0
|
41 return val;
|
max@0
|
42 }
|
max@0
|
43
|
max@0
|
44
|
max@0
|
45
|
max@0
|
46 template<typename T1>
|
max@0
|
47 inline
|
max@0
|
48 arma_warn_unused
|
max@0
|
49 typename T1::elem_type
|
max@0
|
50 trace(const Op<T1, op_diagmat>& X)
|
max@0
|
51 {
|
max@0
|
52 arma_extra_debug_sigprint();
|
max@0
|
53
|
max@0
|
54 typedef typename T1::elem_type eT;
|
max@0
|
55
|
max@0
|
56 const diagmat_proxy<T1> A(X.m);
|
max@0
|
57
|
max@0
|
58 const uword N = A.n_elem;
|
max@0
|
59
|
max@0
|
60 eT val = eT(0);
|
max@0
|
61
|
max@0
|
62 for(uword i=0; i<N; ++i)
|
max@0
|
63 {
|
max@0
|
64 val += A[i];
|
max@0
|
65 }
|
max@0
|
66
|
max@0
|
67 return val;
|
max@0
|
68 }
|
max@0
|
69
|
max@0
|
70
|
max@0
|
71 //! speedup for trace(A*B), where the result of A*B is a square sized matrix
|
max@0
|
72 template<typename T1, typename T2>
|
max@0
|
73 inline
|
max@0
|
74 arma_warn_unused
|
max@0
|
75 typename T1::elem_type
|
max@0
|
76 trace(const Glue<T1, T2, glue_times>& X)
|
max@0
|
77 {
|
max@0
|
78 arma_extra_debug_sigprint();
|
max@0
|
79
|
max@0
|
80 typedef typename T1::elem_type eT;
|
max@0
|
81
|
max@0
|
82 const unwrap<T1> tmp1(X.A);
|
max@0
|
83 const unwrap<T2> tmp2(X.B);
|
max@0
|
84
|
max@0
|
85 const Mat<eT>& A = tmp1.M;
|
max@0
|
86 const Mat<eT>& B = tmp2.M;
|
max@0
|
87
|
max@0
|
88 arma_debug_assert_mul_size(A, B, "matrix multiply");
|
max@0
|
89
|
max@0
|
90 arma_debug_check( (A.n_rows != B.n_cols), "trace(): matrix must be square sized" );
|
max@0
|
91
|
max@0
|
92 const uword N1 = A.n_rows;
|
max@0
|
93 const uword N2 = A.n_cols;
|
max@0
|
94 eT val = eT(0);
|
max@0
|
95
|
max@0
|
96 for(uword i=0; i<N1; ++i)
|
max@0
|
97 {
|
max@0
|
98 const eT* B_colmem = B.colptr(i);
|
max@0
|
99 eT acc = eT(0);
|
max@0
|
100
|
max@0
|
101 for(uword j=0; j<N2; ++j)
|
max@0
|
102 {
|
max@0
|
103 acc += A.at(i,j) * B_colmem[j];
|
max@0
|
104 }
|
max@0
|
105
|
max@0
|
106 val += acc;
|
max@0
|
107 }
|
max@0
|
108
|
max@0
|
109 return val;
|
max@0
|
110 }
|
max@0
|
111
|
max@0
|
112
|
max@0
|
113
|
max@0
|
114 //! @}
|