Chris@49
|
1 // Copyright (C) 2008-2013 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2013 Conrad Sanderson
|
Chris@49
|
3 //
|
Chris@49
|
4 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
7
|
Chris@49
|
8
|
Chris@49
|
9 //! \addtogroup op_sum
|
Chris@49
|
10 //! @{
|
Chris@49
|
11
|
Chris@49
|
12 //! \brief
|
Chris@49
|
13 //! Immediate sum of elements of a matrix along a specified dimension (either rows or columns).
|
Chris@49
|
14 //! The result is stored in a dense matrix that has either one column or one row.
|
Chris@49
|
15 //! See the sum() function for more details.
|
Chris@49
|
16 template<typename T1>
|
Chris@49
|
17 arma_hot
|
Chris@49
|
18 inline
|
Chris@49
|
19 void
|
Chris@49
|
20 op_sum::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_sum>& in)
|
Chris@49
|
21 {
|
Chris@49
|
22 arma_extra_debug_sigprint();
|
Chris@49
|
23
|
Chris@49
|
24 typedef typename T1::elem_type eT;
|
Chris@49
|
25
|
Chris@49
|
26 const uword dim = in.aux_uword_a;
|
Chris@49
|
27 arma_debug_check( (dim > 1), "sum(): incorrect usage. dim must be 0 or 1");
|
Chris@49
|
28
|
Chris@49
|
29 const Proxy<T1> P(in.m);
|
Chris@49
|
30
|
Chris@49
|
31 typedef typename Proxy<T1>::stored_type P_stored_type;
|
Chris@49
|
32
|
Chris@49
|
33 const bool is_alias = P.is_alias(out);
|
Chris@49
|
34
|
Chris@49
|
35 if( (is_Mat<P_stored_type>::value == true) || is_alias )
|
Chris@49
|
36 {
|
Chris@49
|
37 const unwrap_check<P_stored_type> tmp(P.Q, is_alias);
|
Chris@49
|
38
|
Chris@49
|
39 const typename unwrap_check<P_stored_type>::stored_type& X = tmp.M;
|
Chris@49
|
40
|
Chris@49
|
41 const uword X_n_rows = X.n_rows;
|
Chris@49
|
42 const uword X_n_cols = X.n_cols;
|
Chris@49
|
43
|
Chris@49
|
44 if(dim == 0) // traverse across rows (i.e. find the sum in each column)
|
Chris@49
|
45 {
|
Chris@49
|
46 out.set_size(1, X_n_cols);
|
Chris@49
|
47
|
Chris@49
|
48 eT* out_mem = out.memptr();
|
Chris@49
|
49
|
Chris@49
|
50 for(uword col=0; col < X_n_cols; ++col)
|
Chris@49
|
51 {
|
Chris@49
|
52 out_mem[col] = arrayops::accumulate( X.colptr(col), X_n_rows );
|
Chris@49
|
53 }
|
Chris@49
|
54 }
|
Chris@49
|
55 else // traverse across columns (i.e. find the sum in each row)
|
Chris@49
|
56 {
|
Chris@49
|
57 out.set_size(X_n_rows, 1);
|
Chris@49
|
58
|
Chris@49
|
59 eT* out_mem = out.memptr();
|
Chris@49
|
60
|
Chris@49
|
61 for(uword row=0; row < X_n_rows; ++row)
|
Chris@49
|
62 {
|
Chris@49
|
63 eT val = eT(0);
|
Chris@49
|
64
|
Chris@49
|
65 uword i,j;
|
Chris@49
|
66 for(i=0, j=1; j < X_n_cols; i+=2, j+=2)
|
Chris@49
|
67 {
|
Chris@49
|
68 val += X.at(row,i);
|
Chris@49
|
69 val += X.at(row,j);
|
Chris@49
|
70 }
|
Chris@49
|
71
|
Chris@49
|
72 if(i < X_n_cols)
|
Chris@49
|
73 {
|
Chris@49
|
74 val += X.at(row,i);
|
Chris@49
|
75 }
|
Chris@49
|
76
|
Chris@49
|
77 out_mem[row] = val;
|
Chris@49
|
78 }
|
Chris@49
|
79 }
|
Chris@49
|
80 }
|
Chris@49
|
81 else
|
Chris@49
|
82 {
|
Chris@49
|
83 const uword P_n_rows = P.get_n_rows();
|
Chris@49
|
84 const uword P_n_cols = P.get_n_cols();
|
Chris@49
|
85
|
Chris@49
|
86 if(dim == 0) // traverse across rows (i.e. find the sum in each column)
|
Chris@49
|
87 {
|
Chris@49
|
88 out.set_size(1, P_n_cols);
|
Chris@49
|
89
|
Chris@49
|
90 eT* out_mem = out.memptr();
|
Chris@49
|
91
|
Chris@49
|
92 for(uword col=0; col < P_n_cols; ++col)
|
Chris@49
|
93 {
|
Chris@49
|
94 eT val = eT(0);
|
Chris@49
|
95
|
Chris@49
|
96 uword i,j;
|
Chris@49
|
97 for(i=0, j=1; j < P_n_rows; i+=2, j+=2)
|
Chris@49
|
98 {
|
Chris@49
|
99 val += P.at(i,col);
|
Chris@49
|
100 val += P.at(j,col);
|
Chris@49
|
101 }
|
Chris@49
|
102
|
Chris@49
|
103 if(i < P_n_rows)
|
Chris@49
|
104 {
|
Chris@49
|
105 val += P.at(i,col);
|
Chris@49
|
106 }
|
Chris@49
|
107
|
Chris@49
|
108 out_mem[col] = val;
|
Chris@49
|
109 }
|
Chris@49
|
110 }
|
Chris@49
|
111 else // traverse across columns (i.e. find the sum in each row)
|
Chris@49
|
112 {
|
Chris@49
|
113 out.set_size(P_n_rows, 1);
|
Chris@49
|
114
|
Chris@49
|
115 eT* out_mem = out.memptr();
|
Chris@49
|
116
|
Chris@49
|
117 for(uword row=0; row < P_n_rows; ++row)
|
Chris@49
|
118 {
|
Chris@49
|
119 eT val = eT(0);
|
Chris@49
|
120
|
Chris@49
|
121 uword i,j;
|
Chris@49
|
122 for(i=0, j=1; j < P_n_cols; i+=2, j+=2)
|
Chris@49
|
123 {
|
Chris@49
|
124 val += P.at(row,i);
|
Chris@49
|
125 val += P.at(row,j);
|
Chris@49
|
126 }
|
Chris@49
|
127
|
Chris@49
|
128 if(i < P_n_cols)
|
Chris@49
|
129 {
|
Chris@49
|
130 val += P.at(row,i);
|
Chris@49
|
131 }
|
Chris@49
|
132
|
Chris@49
|
133 out_mem[row] = val;
|
Chris@49
|
134 }
|
Chris@49
|
135 }
|
Chris@49
|
136 }
|
Chris@49
|
137 }
|
Chris@49
|
138
|
Chris@49
|
139
|
Chris@49
|
140
|
Chris@49
|
141 //! @}
|