Chris@49
|
1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2012 Conrad Sanderson
|
Chris@49
|
3 // Copyright (C) 2012 Ryan Curtin
|
Chris@49
|
4 //
|
Chris@49
|
5 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
8
|
Chris@49
|
9
|
Chris@49
|
10 //! \addtogroup fn_dot
|
Chris@49
|
11 //! @{
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 template<typename T1, typename T2>
|
Chris@49
|
15 arma_inline
|
Chris@49
|
16 arma_warn_unused
|
Chris@49
|
17 typename
|
Chris@49
|
18 enable_if2
|
Chris@49
|
19 <
|
Chris@49
|
20 is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value,
|
Chris@49
|
21 typename T1::elem_type
|
Chris@49
|
22 >::result
|
Chris@49
|
23 dot
|
Chris@49
|
24 (
|
Chris@49
|
25 const T1& A,
|
Chris@49
|
26 const T2& B
|
Chris@49
|
27 )
|
Chris@49
|
28 {
|
Chris@49
|
29 arma_extra_debug_sigprint();
|
Chris@49
|
30
|
Chris@49
|
31 return op_dot::apply(A,B);
|
Chris@49
|
32 }
|
Chris@49
|
33
|
Chris@49
|
34
|
Chris@49
|
35
|
Chris@49
|
36 template<typename T1, typename T2>
|
Chris@49
|
37 arma_inline
|
Chris@49
|
38 arma_warn_unused
|
Chris@49
|
39 typename
|
Chris@49
|
40 enable_if2
|
Chris@49
|
41 <
|
Chris@49
|
42 is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value,
|
Chris@49
|
43 typename T1::elem_type
|
Chris@49
|
44 >::result
|
Chris@49
|
45 norm_dot
|
Chris@49
|
46 (
|
Chris@49
|
47 const T1& A,
|
Chris@49
|
48 const T2& B
|
Chris@49
|
49 )
|
Chris@49
|
50 {
|
Chris@49
|
51 arma_extra_debug_sigprint();
|
Chris@49
|
52
|
Chris@49
|
53 return op_norm_dot::apply(A,B);
|
Chris@49
|
54 }
|
Chris@49
|
55
|
Chris@49
|
56
|
Chris@49
|
57
|
Chris@49
|
58 //
|
Chris@49
|
59 // cdot
|
Chris@49
|
60
|
Chris@49
|
61
|
Chris@49
|
62
|
Chris@49
|
63 template<typename T1, typename T2>
|
Chris@49
|
64 arma_inline
|
Chris@49
|
65 arma_warn_unused
|
Chris@49
|
66 typename
|
Chris@49
|
67 enable_if2
|
Chris@49
|
68 <
|
Chris@49
|
69 is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value && is_not_complex<typename T1::elem_type>::value,
|
Chris@49
|
70 typename T1::elem_type
|
Chris@49
|
71 >::result
|
Chris@49
|
72 cdot
|
Chris@49
|
73 (
|
Chris@49
|
74 const T1& A,
|
Chris@49
|
75 const T2& B
|
Chris@49
|
76 )
|
Chris@49
|
77 {
|
Chris@49
|
78 arma_extra_debug_sigprint();
|
Chris@49
|
79
|
Chris@49
|
80 return op_dot::apply(A,B);
|
Chris@49
|
81 }
|
Chris@49
|
82
|
Chris@49
|
83
|
Chris@49
|
84
|
Chris@49
|
85
|
Chris@49
|
86 template<typename T1, typename T2>
|
Chris@49
|
87 arma_inline
|
Chris@49
|
88 arma_warn_unused
|
Chris@49
|
89 typename
|
Chris@49
|
90 enable_if2
|
Chris@49
|
91 <
|
Chris@49
|
92 is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value && is_complex<typename T1::elem_type>::value,
|
Chris@49
|
93 typename T1::elem_type
|
Chris@49
|
94 >::result
|
Chris@49
|
95 cdot
|
Chris@49
|
96 (
|
Chris@49
|
97 const T1& A,
|
Chris@49
|
98 const T2& B
|
Chris@49
|
99 )
|
Chris@49
|
100 {
|
Chris@49
|
101 arma_extra_debug_sigprint();
|
Chris@49
|
102
|
Chris@49
|
103 return op_cdot::apply(A,B);
|
Chris@49
|
104 }
|
Chris@49
|
105
|
Chris@49
|
106
|
Chris@49
|
107
|
Chris@49
|
108 // convert dot(htrans(x), y) to cdot(x,y)
|
Chris@49
|
109
|
Chris@49
|
110 template<typename T1, typename T2>
|
Chris@49
|
111 arma_inline
|
Chris@49
|
112 arma_warn_unused
|
Chris@49
|
113 typename
|
Chris@49
|
114 enable_if2
|
Chris@49
|
115 <
|
Chris@49
|
116 is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value && is_complex<typename T1::elem_type>::value,
|
Chris@49
|
117 typename T1::elem_type
|
Chris@49
|
118 >::result
|
Chris@49
|
119 dot
|
Chris@49
|
120 (
|
Chris@49
|
121 const Op<T1, op_htrans>& A,
|
Chris@49
|
122 const T2& B
|
Chris@49
|
123 )
|
Chris@49
|
124 {
|
Chris@49
|
125 arma_extra_debug_sigprint();
|
Chris@49
|
126
|
Chris@49
|
127 return cdot(A.m, B);
|
Chris@49
|
128 }
|
Chris@49
|
129
|
Chris@49
|
130
|
Chris@49
|
131
|
Chris@49
|
132 //
|
Chris@49
|
133 // for sparse matrices
|
Chris@49
|
134 //
|
Chris@49
|
135
|
Chris@49
|
136
|
Chris@49
|
137
|
Chris@49
|
138 namespace priv
|
Chris@49
|
139 {
|
Chris@49
|
140
|
Chris@49
|
141 template<typename T1, typename T2>
|
Chris@49
|
142 arma_hot
|
Chris@49
|
143 inline
|
Chris@49
|
144 typename T1::elem_type
|
Chris@49
|
145 dot_helper(const SpProxy<T1>& pa, const SpProxy<T2>& pb)
|
Chris@49
|
146 {
|
Chris@49
|
147 typedef typename T1::elem_type eT;
|
Chris@49
|
148
|
Chris@49
|
149 // Iterate over both objects and see when they are the same
|
Chris@49
|
150 eT result = eT(0);
|
Chris@49
|
151
|
Chris@49
|
152 typename SpProxy<T1>::const_iterator_type a_it = pa.begin();
|
Chris@49
|
153 typename SpProxy<T1>::const_iterator_type a_end = pa.end();
|
Chris@49
|
154
|
Chris@49
|
155 typename SpProxy<T2>::const_iterator_type b_it = pb.begin();
|
Chris@49
|
156 typename SpProxy<T2>::const_iterator_type b_end = pb.end();
|
Chris@49
|
157
|
Chris@49
|
158 while((a_it != a_end) && (b_it != b_end))
|
Chris@49
|
159 {
|
Chris@49
|
160 if(a_it == b_it)
|
Chris@49
|
161 {
|
Chris@49
|
162 result += (*a_it) * (*b_it);
|
Chris@49
|
163
|
Chris@49
|
164 ++a_it;
|
Chris@49
|
165 ++b_it;
|
Chris@49
|
166 }
|
Chris@49
|
167 else if((a_it.col() < b_it.col()) || ((a_it.col() == b_it.col()) && (a_it.row() < b_it.row())))
|
Chris@49
|
168 {
|
Chris@49
|
169 // a_it is "behind"
|
Chris@49
|
170 ++a_it;
|
Chris@49
|
171 }
|
Chris@49
|
172 else
|
Chris@49
|
173 {
|
Chris@49
|
174 // b_it is "behind"
|
Chris@49
|
175 ++b_it;
|
Chris@49
|
176 }
|
Chris@49
|
177 }
|
Chris@49
|
178
|
Chris@49
|
179 return result;
|
Chris@49
|
180 }
|
Chris@49
|
181
|
Chris@49
|
182 }
|
Chris@49
|
183
|
Chris@49
|
184
|
Chris@49
|
185
|
Chris@49
|
186 //! dot product of two sparse objects
|
Chris@49
|
187 template<typename T1, typename T2>
|
Chris@49
|
188 arma_warn_unused
|
Chris@49
|
189 arma_hot
|
Chris@49
|
190 inline
|
Chris@49
|
191 typename
|
Chris@49
|
192 enable_if2
|
Chris@49
|
193 <(is_arma_sparse_type<T1>::value) && (is_arma_sparse_type<T2>::value) && (is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
|
Chris@49
|
194 typename T1::elem_type
|
Chris@49
|
195 >::result
|
Chris@49
|
196 dot
|
Chris@49
|
197 (
|
Chris@49
|
198 const T1& x,
|
Chris@49
|
199 const T2& y
|
Chris@49
|
200 )
|
Chris@49
|
201 {
|
Chris@49
|
202 arma_extra_debug_sigprint();
|
Chris@49
|
203
|
Chris@49
|
204 const SpProxy<T1> pa(x);
|
Chris@49
|
205 const SpProxy<T2> pb(y);
|
Chris@49
|
206
|
Chris@49
|
207 arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "dot()");
|
Chris@49
|
208
|
Chris@49
|
209 typedef typename T1::elem_type eT;
|
Chris@49
|
210
|
Chris@49
|
211 typedef typename SpProxy<T1>::stored_type pa_Q_type;
|
Chris@49
|
212 typedef typename SpProxy<T2>::stored_type pb_Q_type;
|
Chris@49
|
213
|
Chris@49
|
214 if(
|
Chris@49
|
215 ( (SpProxy<T1>::must_use_iterator == false) && (SpProxy<T2>::must_use_iterator == false) )
|
Chris@49
|
216 && ( (is_SpMat<pa_Q_type>::value == true ) && (is_SpMat<pb_Q_type>::value == true ) )
|
Chris@49
|
217 )
|
Chris@49
|
218 {
|
Chris@49
|
219 const unwrap_spmat<pa_Q_type> tmp_a(pa.Q);
|
Chris@49
|
220 const unwrap_spmat<pb_Q_type> tmp_b(pb.Q);
|
Chris@49
|
221
|
Chris@49
|
222 const SpMat<eT>& A = tmp_a.M;
|
Chris@49
|
223 const SpMat<eT>& B = tmp_b.M;
|
Chris@49
|
224
|
Chris@49
|
225 if( &A == &B )
|
Chris@49
|
226 {
|
Chris@49
|
227 // We can do it directly!
|
Chris@49
|
228 return op_dot::direct_dot_arma(A.n_nonzero, A.values, A.values);
|
Chris@49
|
229 }
|
Chris@49
|
230 else
|
Chris@49
|
231 {
|
Chris@49
|
232 return priv::dot_helper(pa,pb);
|
Chris@49
|
233 }
|
Chris@49
|
234 }
|
Chris@49
|
235 else
|
Chris@49
|
236 {
|
Chris@49
|
237 return priv::dot_helper(pa,pb);
|
Chris@49
|
238 }
|
Chris@49
|
239 }
|
Chris@49
|
240
|
Chris@49
|
241
|
Chris@49
|
242
|
Chris@49
|
243 //! dot product of one dense and one sparse object
|
Chris@49
|
244 template<typename T1, typename T2>
|
Chris@49
|
245 arma_warn_unused
|
Chris@49
|
246 arma_hot
|
Chris@49
|
247 inline
|
Chris@49
|
248 typename
|
Chris@49
|
249 enable_if2
|
Chris@49
|
250 <(is_arma_type<T1>::value) && (is_arma_sparse_type<T2>::value) && (is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
|
Chris@49
|
251 typename T1::elem_type
|
Chris@49
|
252 >::result
|
Chris@49
|
253 dot
|
Chris@49
|
254 (
|
Chris@49
|
255 const T1& x,
|
Chris@49
|
256 const T2& y
|
Chris@49
|
257 )
|
Chris@49
|
258 {
|
Chris@49
|
259 arma_extra_debug_sigprint();
|
Chris@49
|
260
|
Chris@49
|
261 const Proxy<T1> pa(x);
|
Chris@49
|
262 const SpProxy<T2> pb(y);
|
Chris@49
|
263
|
Chris@49
|
264 arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "dot()");
|
Chris@49
|
265
|
Chris@49
|
266 typedef typename T1::elem_type eT;
|
Chris@49
|
267
|
Chris@49
|
268 eT result = eT(0);
|
Chris@49
|
269
|
Chris@49
|
270 typename SpProxy<T2>::const_iterator_type it = pb.begin();
|
Chris@49
|
271 typename SpProxy<T2>::const_iterator_type it_end = pb.end();
|
Chris@49
|
272
|
Chris@49
|
273 // prefer_at_accessor won't save us operations
|
Chris@49
|
274 while(it != it_end)
|
Chris@49
|
275 {
|
Chris@49
|
276 result += (*it) * pa.at(it.row(), it.col());
|
Chris@49
|
277 ++it;
|
Chris@49
|
278 }
|
Chris@49
|
279
|
Chris@49
|
280 return result;
|
Chris@49
|
281 }
|
Chris@49
|
282
|
Chris@49
|
283
|
Chris@49
|
284
|
Chris@49
|
285 //! dot product of one sparse and one dense object
|
Chris@49
|
286 template<typename T1, typename T2>
|
Chris@49
|
287 arma_warn_unused
|
Chris@49
|
288 arma_hot
|
Chris@49
|
289 inline
|
Chris@49
|
290 typename
|
Chris@49
|
291 enable_if2
|
Chris@49
|
292 <(is_arma_sparse_type<T1>::value) && (is_arma_type<T2>::value) && (is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
|
Chris@49
|
293 typename T1::elem_type
|
Chris@49
|
294 >::result
|
Chris@49
|
295 dot
|
Chris@49
|
296 (
|
Chris@49
|
297 const T1& x,
|
Chris@49
|
298 const T2& y
|
Chris@49
|
299 )
|
Chris@49
|
300 {
|
Chris@49
|
301 arma_extra_debug_sigprint();
|
Chris@49
|
302
|
Chris@49
|
303 // this is commutative
|
Chris@49
|
304 return dot(y, x);
|
Chris@49
|
305 }
|
Chris@49
|
306
|
Chris@49
|
307
|
Chris@49
|
308
|
Chris@49
|
309 //! @}
|