max@0
|
1 // Copyright (C) 2011 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2011 Conrad Sanderson
|
max@0
|
3 //
|
max@0
|
4 // This file is part of the Armadillo C++ library.
|
max@0
|
5 // It is provided without any warranty of fitness
|
max@0
|
6 // for any purpose. You can redistribute this file
|
max@0
|
7 // and/or modify it under the terms of the GNU
|
max@0
|
8 // Lesser General Public License (LGPL) as published
|
max@0
|
9 // by the Free Software Foundation, either version 3
|
max@0
|
10 // of the License or (at your option) any later version.
|
max@0
|
11 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
12
|
max@0
|
13
|
max@0
|
14 //! \addtogroup Gen
|
max@0
|
15 //! @{
|
max@0
|
16
|
max@0
|
17
|
max@0
|
18
|
max@0
|
19 template<typename eT, typename gen_type>
|
max@0
|
20 arma_inline
|
max@0
|
21 Gen<eT, gen_type>::Gen(const uword in_n_rows, const uword in_n_cols)
|
max@0
|
22 : n_rows(in_n_rows)
|
max@0
|
23 , n_cols(in_n_cols)
|
max@0
|
24 {
|
max@0
|
25 arma_extra_debug_sigprint();
|
max@0
|
26 }
|
max@0
|
27
|
max@0
|
28
|
max@0
|
29
|
max@0
|
30 template<typename eT, typename gen_type>
|
max@0
|
31 arma_inline
|
max@0
|
32 Gen<eT, gen_type>::~Gen()
|
max@0
|
33 {
|
max@0
|
34 arma_extra_debug_sigprint();
|
max@0
|
35 }
|
max@0
|
36
|
max@0
|
37
|
max@0
|
38
|
max@0
|
39 template<typename eT, typename gen_type>
|
max@0
|
40 arma_inline
|
max@0
|
41 eT
|
max@0
|
42 Gen<eT, gen_type>::generate()
|
max@0
|
43 {
|
max@0
|
44 if(is_same_type<gen_type, gen_ones_full>::value == true) { return eT(1); }
|
max@0
|
45 else if(is_same_type<gen_type, gen_zeros >::value == true) { return eT(0); }
|
max@0
|
46 else if(is_same_type<gen_type, gen_randu >::value == true) { return eT(eop_aux_randu<eT>()); }
|
max@0
|
47 else if(is_same_type<gen_type, gen_randn >::value == true) { return eT(eop_aux_randn<eT>()); }
|
max@0
|
48 else { return eT(); }
|
max@0
|
49 }
|
max@0
|
50
|
max@0
|
51
|
max@0
|
52
|
max@0
|
53 template<typename eT, typename gen_type>
|
max@0
|
54 arma_inline
|
max@0
|
55 eT
|
max@0
|
56 Gen<eT, gen_type>::operator[](const uword i) const
|
max@0
|
57 {
|
max@0
|
58 if(is_same_type<gen_type, gen_ones_diag>::value == true)
|
max@0
|
59 {
|
max@0
|
60 return ((i % n_rows) == (i / n_rows)) ? eT(1) : eT(0);
|
max@0
|
61 }
|
max@0
|
62 else
|
max@0
|
63 {
|
max@0
|
64 return Gen<eT, gen_type>::generate();
|
max@0
|
65 }
|
max@0
|
66 }
|
max@0
|
67
|
max@0
|
68
|
max@0
|
69
|
max@0
|
70 template<typename eT, typename gen_type>
|
max@0
|
71 arma_inline
|
max@0
|
72 eT
|
max@0
|
73 Gen<eT, gen_type>::at(const uword row, const uword col) const
|
max@0
|
74 {
|
max@0
|
75 if(is_same_type<gen_type, gen_ones_diag>::value == true)
|
max@0
|
76 {
|
max@0
|
77 return (row == col) ? eT(1) : eT(0);
|
max@0
|
78 }
|
max@0
|
79 else
|
max@0
|
80 {
|
max@0
|
81 return Gen<eT, gen_type>::generate();
|
max@0
|
82 }
|
max@0
|
83 }
|
max@0
|
84
|
max@0
|
85
|
max@0
|
86
|
max@0
|
87 template<typename eT, typename gen_type>
|
max@0
|
88 inline
|
max@0
|
89 void
|
max@0
|
90 Gen<eT, gen_type>::apply(Mat<eT>& out) const
|
max@0
|
91 {
|
max@0
|
92 arma_extra_debug_sigprint();
|
max@0
|
93
|
max@0
|
94 // NOTE: we're assuming that the matrix has already been set to the correct size;
|
max@0
|
95 // this is done by either the Mat contructor or operator=()
|
max@0
|
96
|
max@0
|
97 if(is_same_type<gen_type, gen_ones_diag>::value == true) { out.eye(); }
|
max@0
|
98 else if(is_same_type<gen_type, gen_ones_full>::value == true) { out.ones(); }
|
max@0
|
99 else if(is_same_type<gen_type, gen_zeros >::value == true) { out.zeros(); }
|
max@0
|
100 else if(is_same_type<gen_type, gen_randu >::value == true) { out.randu(); }
|
max@0
|
101 else if(is_same_type<gen_type, gen_randn >::value == true) { out.randn(); }
|
max@0
|
102 }
|
max@0
|
103
|
max@0
|
104
|
max@0
|
105
|
max@0
|
106 template<typename eT, typename gen_type>
|
max@0
|
107 inline
|
max@0
|
108 void
|
max@0
|
109 Gen<eT, gen_type>::apply_inplace_plus(Mat<eT>& out) const
|
max@0
|
110 {
|
max@0
|
111 arma_extra_debug_sigprint();
|
max@0
|
112
|
max@0
|
113 arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "addition");
|
max@0
|
114
|
max@0
|
115
|
max@0
|
116 if(is_same_type<gen_type, gen_ones_diag>::value == true)
|
max@0
|
117 {
|
max@0
|
118 const uword N = (std::min)(n_rows, n_cols);
|
max@0
|
119
|
max@0
|
120 for(uword i=0; i<N; ++i)
|
max@0
|
121 {
|
max@0
|
122 out.at(i,i) += eT(1);
|
max@0
|
123 }
|
max@0
|
124 }
|
max@0
|
125 else
|
max@0
|
126 {
|
max@0
|
127 eT* out_mem = out.memptr();
|
max@0
|
128 const uword n_elem = out.n_elem;
|
max@0
|
129
|
max@0
|
130 uword i,j;
|
max@0
|
131
|
max@0
|
132 for(i=0, j=1; j<n_elem; i+=2, j+=2)
|
max@0
|
133 {
|
max@0
|
134 const eT tmp_i = Gen<eT, gen_type>::generate();
|
max@0
|
135 const eT tmp_j = Gen<eT, gen_type>::generate();
|
max@0
|
136
|
max@0
|
137 out_mem[i] += tmp_i;
|
max@0
|
138 out_mem[j] += tmp_j;
|
max@0
|
139 }
|
max@0
|
140
|
max@0
|
141 if(i < n_elem)
|
max@0
|
142 {
|
max@0
|
143 out_mem[i] += Gen<eT, gen_type>::generate();
|
max@0
|
144 }
|
max@0
|
145 }
|
max@0
|
146
|
max@0
|
147 }
|
max@0
|
148
|
max@0
|
149
|
max@0
|
150
|
max@0
|
151
|
max@0
|
152 template<typename eT, typename gen_type>
|
max@0
|
153 inline
|
max@0
|
154 void
|
max@0
|
155 Gen<eT, gen_type>::apply_inplace_minus(Mat<eT>& out) const
|
max@0
|
156 {
|
max@0
|
157 arma_extra_debug_sigprint();
|
max@0
|
158
|
max@0
|
159 arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "subtraction");
|
max@0
|
160
|
max@0
|
161
|
max@0
|
162 if(is_same_type<gen_type, gen_ones_diag>::value == true)
|
max@0
|
163 {
|
max@0
|
164 const uword N = (std::min)(n_rows, n_cols);
|
max@0
|
165
|
max@0
|
166 for(uword i=0; i<N; ++i)
|
max@0
|
167 {
|
max@0
|
168 out.at(i,i) -= eT(1);
|
max@0
|
169 }
|
max@0
|
170 }
|
max@0
|
171 else
|
max@0
|
172 {
|
max@0
|
173 eT* out_mem = out.memptr();
|
max@0
|
174 const uword n_elem = out.n_elem;
|
max@0
|
175
|
max@0
|
176 uword i,j;
|
max@0
|
177
|
max@0
|
178 for(i=0, j=1; j<n_elem; i+=2, j+=2)
|
max@0
|
179 {
|
max@0
|
180 const eT tmp_i = Gen<eT, gen_type>::generate();
|
max@0
|
181 const eT tmp_j = Gen<eT, gen_type>::generate();
|
max@0
|
182
|
max@0
|
183 out_mem[i] -= tmp_i;
|
max@0
|
184 out_mem[j] -= tmp_j;
|
max@0
|
185 }
|
max@0
|
186
|
max@0
|
187 if(i < n_elem)
|
max@0
|
188 {
|
max@0
|
189 out_mem[i] -= Gen<eT, gen_type>::generate();
|
max@0
|
190 }
|
max@0
|
191 }
|
max@0
|
192
|
max@0
|
193 }
|
max@0
|
194
|
max@0
|
195
|
max@0
|
196
|
max@0
|
197
|
max@0
|
198 template<typename eT, typename gen_type>
|
max@0
|
199 inline
|
max@0
|
200 void
|
max@0
|
201 Gen<eT, gen_type>::apply_inplace_schur(Mat<eT>& out) const
|
max@0
|
202 {
|
max@0
|
203 arma_extra_debug_sigprint();
|
max@0
|
204
|
max@0
|
205 arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "element-wise multiplication");
|
max@0
|
206
|
max@0
|
207
|
max@0
|
208 if(is_same_type<gen_type, gen_ones_diag>::value == true)
|
max@0
|
209 {
|
max@0
|
210 const uword N = (std::min)(n_rows, n_cols);
|
max@0
|
211
|
max@0
|
212 for(uword i=0; i<N; ++i)
|
max@0
|
213 {
|
max@0
|
214 for(uword row=0; row<i; ++row) { out.at(row,i) = eT(0); }
|
max@0
|
215 for(uword row=i+1; row<n_rows; ++row) { out.at(row,i) = eT(0); }
|
max@0
|
216 }
|
max@0
|
217 }
|
max@0
|
218 else
|
max@0
|
219 {
|
max@0
|
220 eT* out_mem = out.memptr();
|
max@0
|
221 const uword n_elem = out.n_elem;
|
max@0
|
222
|
max@0
|
223 uword i,j;
|
max@0
|
224
|
max@0
|
225 for(i=0, j=1; j<n_elem; i+=2, j+=2)
|
max@0
|
226 {
|
max@0
|
227 const eT tmp_i = Gen<eT, gen_type>::generate();
|
max@0
|
228 const eT tmp_j = Gen<eT, gen_type>::generate();
|
max@0
|
229
|
max@0
|
230 out_mem[i] *= tmp_i;
|
max@0
|
231 out_mem[j] *= tmp_j;
|
max@0
|
232 }
|
max@0
|
233
|
max@0
|
234 if(i < n_elem)
|
max@0
|
235 {
|
max@0
|
236 out_mem[i] *= Gen<eT, gen_type>::generate();
|
max@0
|
237 }
|
max@0
|
238 }
|
max@0
|
239
|
max@0
|
240 }
|
max@0
|
241
|
max@0
|
242
|
max@0
|
243
|
max@0
|
244
|
max@0
|
245 template<typename eT, typename gen_type>
|
max@0
|
246 inline
|
max@0
|
247 void
|
max@0
|
248 Gen<eT, gen_type>::apply_inplace_div(Mat<eT>& out) const
|
max@0
|
249 {
|
max@0
|
250 arma_extra_debug_sigprint();
|
max@0
|
251
|
max@0
|
252 arma_debug_assert_same_size(out.n_rows, out.n_cols, n_rows, n_cols, "element-wise division");
|
max@0
|
253
|
max@0
|
254
|
max@0
|
255 if(is_same_type<gen_type, gen_ones_diag>::value == true)
|
max@0
|
256 {
|
max@0
|
257 const uword N = (std::min)(n_rows, n_cols);
|
max@0
|
258
|
max@0
|
259 for(uword i=0; i<N; ++i)
|
max@0
|
260 {
|
max@0
|
261 const eT zero = eT(0);
|
max@0
|
262
|
max@0
|
263 for(uword row=0; row<i; ++row) { out.at(row,i) /= zero; }
|
max@0
|
264 for(uword row=i+1; row<n_rows; ++row) { out.at(row,i) /= zero; }
|
max@0
|
265 }
|
max@0
|
266 }
|
max@0
|
267 else
|
max@0
|
268 {
|
max@0
|
269 eT* out_mem = out.memptr();
|
max@0
|
270 const uword n_elem = out.n_elem;
|
max@0
|
271
|
max@0
|
272 uword i,j;
|
max@0
|
273
|
max@0
|
274 for(i=0, j=1; j<n_elem; i+=2, j+=2)
|
max@0
|
275 {
|
max@0
|
276 const eT tmp_i = Gen<eT, gen_type>::generate();
|
max@0
|
277 const eT tmp_j = Gen<eT, gen_type>::generate();
|
max@0
|
278
|
max@0
|
279 out_mem[i] /= tmp_i;
|
max@0
|
280 out_mem[j] /= tmp_j;
|
max@0
|
281 }
|
max@0
|
282
|
max@0
|
283 if(i < n_elem)
|
max@0
|
284 {
|
max@0
|
285 out_mem[i] /= Gen<eT, gen_type>::generate();
|
max@0
|
286 }
|
max@0
|
287 }
|
max@0
|
288
|
max@0
|
289 }
|
max@0
|
290
|
max@0
|
291
|
max@0
|
292
|
max@0
|
293
|
max@0
|
294 //! @}
|