Chris@49
|
1 // Copyright (C) 2008-2010 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2010 Conrad Sanderson
|
Chris@49
|
3 //
|
Chris@49
|
4 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
7
|
Chris@49
|
8
|
Chris@49
|
9 //! \addtogroup op_dotext
|
Chris@49
|
10 //! @{
|
Chris@49
|
11
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 template<typename eT>
|
Chris@49
|
15 inline
|
Chris@49
|
16 eT
|
Chris@49
|
17 op_dotext::direct_rowvec_mat_colvec
|
Chris@49
|
18 (
|
Chris@49
|
19 const eT* A_mem,
|
Chris@49
|
20 const Mat<eT>& B,
|
Chris@49
|
21 const eT* C_mem
|
Chris@49
|
22 )
|
Chris@49
|
23 {
|
Chris@49
|
24 arma_extra_debug_sigprint();
|
Chris@49
|
25
|
Chris@49
|
26 const uword cost_AB = B.n_cols;
|
Chris@49
|
27 const uword cost_BC = B.n_rows;
|
Chris@49
|
28
|
Chris@49
|
29 if(cost_AB <= cost_BC)
|
Chris@49
|
30 {
|
Chris@49
|
31 podarray<eT> tmp(B.n_cols);
|
Chris@49
|
32
|
Chris@49
|
33 for(uword col=0; col<B.n_cols; ++col)
|
Chris@49
|
34 {
|
Chris@49
|
35 const eT* B_coldata = B.colptr(col);
|
Chris@49
|
36
|
Chris@49
|
37 eT val = eT(0);
|
Chris@49
|
38 for(uword i=0; i<B.n_rows; ++i)
|
Chris@49
|
39 {
|
Chris@49
|
40 val += A_mem[i] * B_coldata[i];
|
Chris@49
|
41 }
|
Chris@49
|
42
|
Chris@49
|
43 tmp[col] = val;
|
Chris@49
|
44 }
|
Chris@49
|
45
|
Chris@49
|
46 return op_dot::direct_dot(B.n_cols, tmp.mem, C_mem);
|
Chris@49
|
47 }
|
Chris@49
|
48 else
|
Chris@49
|
49 {
|
Chris@49
|
50 podarray<eT> tmp(B.n_rows);
|
Chris@49
|
51
|
Chris@49
|
52 for(uword row=0; row<B.n_rows; ++row)
|
Chris@49
|
53 {
|
Chris@49
|
54 eT val = eT(0);
|
Chris@49
|
55 for(uword col=0; col<B.n_cols; ++col)
|
Chris@49
|
56 {
|
Chris@49
|
57 val += B.at(row,col) * C_mem[col];
|
Chris@49
|
58 }
|
Chris@49
|
59
|
Chris@49
|
60 tmp[row] = val;
|
Chris@49
|
61 }
|
Chris@49
|
62
|
Chris@49
|
63 return op_dot::direct_dot(B.n_rows, A_mem, tmp.mem);
|
Chris@49
|
64 }
|
Chris@49
|
65
|
Chris@49
|
66
|
Chris@49
|
67 }
|
Chris@49
|
68
|
Chris@49
|
69
|
Chris@49
|
70
|
Chris@49
|
71 template<typename eT>
|
Chris@49
|
72 inline
|
Chris@49
|
73 eT
|
Chris@49
|
74 op_dotext::direct_rowvec_transmat_colvec
|
Chris@49
|
75 (
|
Chris@49
|
76 const eT* A_mem,
|
Chris@49
|
77 const Mat<eT>& B,
|
Chris@49
|
78 const eT* C_mem
|
Chris@49
|
79 )
|
Chris@49
|
80 {
|
Chris@49
|
81 arma_extra_debug_sigprint();
|
Chris@49
|
82
|
Chris@49
|
83 const uword cost_AB = B.n_rows;
|
Chris@49
|
84 const uword cost_BC = B.n_cols;
|
Chris@49
|
85
|
Chris@49
|
86 if(cost_AB <= cost_BC)
|
Chris@49
|
87 {
|
Chris@49
|
88 podarray<eT> tmp(B.n_rows);
|
Chris@49
|
89
|
Chris@49
|
90 for(uword row=0; row<B.n_rows; ++row)
|
Chris@49
|
91 {
|
Chris@49
|
92 eT val = eT(0);
|
Chris@49
|
93
|
Chris@49
|
94 for(uword i=0; i<B.n_cols; ++i)
|
Chris@49
|
95 {
|
Chris@49
|
96 val += A_mem[i] * B.at(row,i);
|
Chris@49
|
97 }
|
Chris@49
|
98
|
Chris@49
|
99 tmp[row] = val;
|
Chris@49
|
100 }
|
Chris@49
|
101
|
Chris@49
|
102 return op_dot::direct_dot(B.n_rows, tmp.mem, C_mem);
|
Chris@49
|
103 }
|
Chris@49
|
104 else
|
Chris@49
|
105 {
|
Chris@49
|
106 podarray<eT> tmp(B.n_cols);
|
Chris@49
|
107
|
Chris@49
|
108 for(uword col=0; col<B.n_cols; ++col)
|
Chris@49
|
109 {
|
Chris@49
|
110 const eT* B_coldata = B.colptr(col);
|
Chris@49
|
111
|
Chris@49
|
112 eT val = eT(0);
|
Chris@49
|
113
|
Chris@49
|
114 for(uword i=0; i<B.n_rows; ++i)
|
Chris@49
|
115 {
|
Chris@49
|
116 val += B_coldata[i] * C_mem[i];
|
Chris@49
|
117 }
|
Chris@49
|
118
|
Chris@49
|
119 tmp[col] = val;
|
Chris@49
|
120 }
|
Chris@49
|
121
|
Chris@49
|
122 return op_dot::direct_dot(B.n_cols, A_mem, tmp.mem);
|
Chris@49
|
123 }
|
Chris@49
|
124
|
Chris@49
|
125
|
Chris@49
|
126 }
|
Chris@49
|
127
|
Chris@49
|
128
|
Chris@49
|
129
|
Chris@49
|
130 template<typename eT>
|
Chris@49
|
131 inline
|
Chris@49
|
132 eT
|
Chris@49
|
133 op_dotext::direct_rowvec_diagmat_colvec
|
Chris@49
|
134 (
|
Chris@49
|
135 const eT* A_mem,
|
Chris@49
|
136 const Mat<eT>& B,
|
Chris@49
|
137 const eT* C_mem
|
Chris@49
|
138 )
|
Chris@49
|
139 {
|
Chris@49
|
140 arma_extra_debug_sigprint();
|
Chris@49
|
141
|
Chris@49
|
142 eT val = eT(0);
|
Chris@49
|
143
|
Chris@49
|
144 for(uword i=0; i<B.n_rows; ++i)
|
Chris@49
|
145 {
|
Chris@49
|
146 val += A_mem[i] * B.at(i,i) * C_mem[i];
|
Chris@49
|
147 }
|
Chris@49
|
148
|
Chris@49
|
149 return val;
|
Chris@49
|
150 }
|
Chris@49
|
151
|
Chris@49
|
152
|
Chris@49
|
153
|
Chris@49
|
154 template<typename eT>
|
Chris@49
|
155 inline
|
Chris@49
|
156 eT
|
Chris@49
|
157 op_dotext::direct_rowvec_invdiagmat_colvec
|
Chris@49
|
158 (
|
Chris@49
|
159 const eT* A_mem,
|
Chris@49
|
160 const Mat<eT>& B,
|
Chris@49
|
161 const eT* C_mem
|
Chris@49
|
162 )
|
Chris@49
|
163 {
|
Chris@49
|
164 arma_extra_debug_sigprint();
|
Chris@49
|
165
|
Chris@49
|
166 eT val = eT(0);
|
Chris@49
|
167
|
Chris@49
|
168 for(uword i=0; i<B.n_rows; ++i)
|
Chris@49
|
169 {
|
Chris@49
|
170 val += (A_mem[i] * C_mem[i]) / B.at(i,i);
|
Chris@49
|
171 }
|
Chris@49
|
172
|
Chris@49
|
173 return val;
|
Chris@49
|
174 }
|
Chris@49
|
175
|
Chris@49
|
176
|
Chris@49
|
177
|
Chris@49
|
178 template<typename eT>
|
Chris@49
|
179 inline
|
Chris@49
|
180 eT
|
Chris@49
|
181 op_dotext::direct_rowvec_invdiagvec_colvec
|
Chris@49
|
182 (
|
Chris@49
|
183 const eT* A_mem,
|
Chris@49
|
184 const Mat<eT>& B,
|
Chris@49
|
185 const eT* C_mem
|
Chris@49
|
186 )
|
Chris@49
|
187 {
|
Chris@49
|
188 arma_extra_debug_sigprint();
|
Chris@49
|
189
|
Chris@49
|
190 const eT* B_mem = B.mem;
|
Chris@49
|
191
|
Chris@49
|
192 eT val = eT(0);
|
Chris@49
|
193
|
Chris@49
|
194 for(uword i=0; i<B.n_elem; ++i)
|
Chris@49
|
195 {
|
Chris@49
|
196 val += (A_mem[i] * C_mem[i]) / B_mem[i];
|
Chris@49
|
197 }
|
Chris@49
|
198
|
Chris@49
|
199 return val;
|
Chris@49
|
200 }
|
Chris@49
|
201
|
Chris@49
|
202
|
Chris@49
|
203
|
Chris@49
|
204 //! @}
|