max@0
|
1 // Copyright (C) 2008-2010 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2008-2010 Conrad Sanderson
|
max@0
|
3 //
|
max@0
|
4 // This file is part of the Armadillo C++ library.
|
max@0
|
5 // It is provided without any warranty of fitness
|
max@0
|
6 // for any purpose. You can redistribute this file
|
max@0
|
7 // and/or modify it under the terms of the GNU
|
max@0
|
8 // Lesser General Public License (LGPL) as published
|
max@0
|
9 // by the Free Software Foundation, either version 3
|
max@0
|
10 // of the License or (at your option) any later version.
|
max@0
|
11 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
12
|
max@0
|
13
|
max@0
|
14 //! \addtogroup op_dotext
|
max@0
|
15 //! @{
|
max@0
|
16
|
max@0
|
17
|
max@0
|
18
|
max@0
|
19 template<typename eT>
|
max@0
|
20 inline
|
max@0
|
21 eT
|
max@0
|
22 op_dotext::direct_rowvec_mat_colvec
|
max@0
|
23 (
|
max@0
|
24 const eT* A_mem,
|
max@0
|
25 const Mat<eT>& B,
|
max@0
|
26 const eT* C_mem
|
max@0
|
27 )
|
max@0
|
28 {
|
max@0
|
29 arma_extra_debug_sigprint();
|
max@0
|
30
|
max@0
|
31 const uword cost_AB = B.n_cols;
|
max@0
|
32 const uword cost_BC = B.n_rows;
|
max@0
|
33
|
max@0
|
34 if(cost_AB <= cost_BC)
|
max@0
|
35 {
|
max@0
|
36 podarray<eT> tmp(B.n_cols);
|
max@0
|
37
|
max@0
|
38 for(uword col=0; col<B.n_cols; ++col)
|
max@0
|
39 {
|
max@0
|
40 const eT* B_coldata = B.colptr(col);
|
max@0
|
41
|
max@0
|
42 eT val = eT(0);
|
max@0
|
43 for(uword i=0; i<B.n_rows; ++i)
|
max@0
|
44 {
|
max@0
|
45 val += A_mem[i] * B_coldata[i];
|
max@0
|
46 }
|
max@0
|
47
|
max@0
|
48 tmp[col] = val;
|
max@0
|
49 }
|
max@0
|
50
|
max@0
|
51 return op_dot::direct_dot(B.n_cols, tmp.mem, C_mem);
|
max@0
|
52 }
|
max@0
|
53 else
|
max@0
|
54 {
|
max@0
|
55 podarray<eT> tmp(B.n_rows);
|
max@0
|
56
|
max@0
|
57 for(uword row=0; row<B.n_rows; ++row)
|
max@0
|
58 {
|
max@0
|
59 eT val = eT(0);
|
max@0
|
60 for(uword col=0; col<B.n_cols; ++col)
|
max@0
|
61 {
|
max@0
|
62 val += B.at(row,col) * C_mem[col];
|
max@0
|
63 }
|
max@0
|
64
|
max@0
|
65 tmp[row] = val;
|
max@0
|
66 }
|
max@0
|
67
|
max@0
|
68 return op_dot::direct_dot(B.n_rows, A_mem, tmp.mem);
|
max@0
|
69 }
|
max@0
|
70
|
max@0
|
71
|
max@0
|
72 }
|
max@0
|
73
|
max@0
|
74
|
max@0
|
75
|
max@0
|
76 template<typename eT>
|
max@0
|
77 inline
|
max@0
|
78 eT
|
max@0
|
79 op_dotext::direct_rowvec_transmat_colvec
|
max@0
|
80 (
|
max@0
|
81 const eT* A_mem,
|
max@0
|
82 const Mat<eT>& B,
|
max@0
|
83 const eT* C_mem
|
max@0
|
84 )
|
max@0
|
85 {
|
max@0
|
86 arma_extra_debug_sigprint();
|
max@0
|
87
|
max@0
|
88 const uword cost_AB = B.n_rows;
|
max@0
|
89 const uword cost_BC = B.n_cols;
|
max@0
|
90
|
max@0
|
91 if(cost_AB <= cost_BC)
|
max@0
|
92 {
|
max@0
|
93 podarray<eT> tmp(B.n_rows);
|
max@0
|
94
|
max@0
|
95 for(uword row=0; row<B.n_rows; ++row)
|
max@0
|
96 {
|
max@0
|
97 eT val = eT(0);
|
max@0
|
98
|
max@0
|
99 for(uword i=0; i<B.n_cols; ++i)
|
max@0
|
100 {
|
max@0
|
101 val += A_mem[i] * B.at(row,i);
|
max@0
|
102 }
|
max@0
|
103
|
max@0
|
104 tmp[row] = val;
|
max@0
|
105 }
|
max@0
|
106
|
max@0
|
107 return op_dot::direct_dot(B.n_rows, tmp.mem, C_mem);
|
max@0
|
108 }
|
max@0
|
109 else
|
max@0
|
110 {
|
max@0
|
111 podarray<eT> tmp(B.n_cols);
|
max@0
|
112
|
max@0
|
113 for(uword col=0; col<B.n_cols; ++col)
|
max@0
|
114 {
|
max@0
|
115 const eT* B_coldata = B.colptr(col);
|
max@0
|
116
|
max@0
|
117 eT val = eT(0);
|
max@0
|
118
|
max@0
|
119 for(uword i=0; i<B.n_rows; ++i)
|
max@0
|
120 {
|
max@0
|
121 val += B_coldata[i] * C_mem[i];
|
max@0
|
122 }
|
max@0
|
123
|
max@0
|
124 tmp[col] = val;
|
max@0
|
125 }
|
max@0
|
126
|
max@0
|
127 return op_dot::direct_dot(B.n_cols, A_mem, tmp.mem);
|
max@0
|
128 }
|
max@0
|
129
|
max@0
|
130
|
max@0
|
131 }
|
max@0
|
132
|
max@0
|
133
|
max@0
|
134
|
max@0
|
135 template<typename eT>
|
max@0
|
136 inline
|
max@0
|
137 eT
|
max@0
|
138 op_dotext::direct_rowvec_diagmat_colvec
|
max@0
|
139 (
|
max@0
|
140 const eT* A_mem,
|
max@0
|
141 const Mat<eT>& B,
|
max@0
|
142 const eT* C_mem
|
max@0
|
143 )
|
max@0
|
144 {
|
max@0
|
145 arma_extra_debug_sigprint();
|
max@0
|
146
|
max@0
|
147 eT val = eT(0);
|
max@0
|
148
|
max@0
|
149 for(uword i=0; i<B.n_rows; ++i)
|
max@0
|
150 {
|
max@0
|
151 val += A_mem[i] * B.at(i,i) * C_mem[i];
|
max@0
|
152 }
|
max@0
|
153
|
max@0
|
154 return val;
|
max@0
|
155 }
|
max@0
|
156
|
max@0
|
157
|
max@0
|
158
|
max@0
|
159 template<typename eT>
|
max@0
|
160 inline
|
max@0
|
161 eT
|
max@0
|
162 op_dotext::direct_rowvec_invdiagmat_colvec
|
max@0
|
163 (
|
max@0
|
164 const eT* A_mem,
|
max@0
|
165 const Mat<eT>& B,
|
max@0
|
166 const eT* C_mem
|
max@0
|
167 )
|
max@0
|
168 {
|
max@0
|
169 arma_extra_debug_sigprint();
|
max@0
|
170
|
max@0
|
171 eT val = eT(0);
|
max@0
|
172
|
max@0
|
173 for(uword i=0; i<B.n_rows; ++i)
|
max@0
|
174 {
|
max@0
|
175 val += (A_mem[i] * C_mem[i]) / B.at(i,i);
|
max@0
|
176 }
|
max@0
|
177
|
max@0
|
178 return val;
|
max@0
|
179 }
|
max@0
|
180
|
max@0
|
181
|
max@0
|
182
|
max@0
|
183 template<typename eT>
|
max@0
|
184 inline
|
max@0
|
185 eT
|
max@0
|
186 op_dotext::direct_rowvec_invdiagvec_colvec
|
max@0
|
187 (
|
max@0
|
188 const eT* A_mem,
|
max@0
|
189 const Mat<eT>& B,
|
max@0
|
190 const eT* C_mem
|
max@0
|
191 )
|
max@0
|
192 {
|
max@0
|
193 arma_extra_debug_sigprint();
|
max@0
|
194
|
max@0
|
195 const eT* B_mem = B.mem;
|
max@0
|
196
|
max@0
|
197 eT val = eT(0);
|
max@0
|
198
|
max@0
|
199 for(uword i=0; i<B.n_elem; ++i)
|
max@0
|
200 {
|
max@0
|
201 val += (A_mem[i] * C_mem[i]) / B_mem[i];
|
max@0
|
202 }
|
max@0
|
203
|
max@0
|
204 return val;
|
max@0
|
205 }
|
max@0
|
206
|
max@0
|
207
|
max@0
|
208
|
max@0
|
209 //! @}
|