Chris@49
|
1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2012 Conrad Sanderson
|
Chris@49
|
3 //
|
Chris@49
|
4 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
7
|
Chris@49
|
8
|
Chris@49
|
9
|
Chris@49
|
10 //! \addtogroup op_reshape
|
Chris@49
|
11 //! @{
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14
|
Chris@49
|
15 template<typename T1>
|
Chris@49
|
16 inline
|
Chris@49
|
17 void
|
Chris@49
|
18 op_reshape::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_reshape>& in)
|
Chris@49
|
19 {
|
Chris@49
|
20 arma_extra_debug_sigprint();
|
Chris@49
|
21
|
Chris@49
|
22 typedef typename T1::elem_type eT;
|
Chris@49
|
23
|
Chris@49
|
24 const unwrap<T1> A_tmp(in.m);
|
Chris@49
|
25 const Mat<eT>& A = A_tmp.M;
|
Chris@49
|
26
|
Chris@49
|
27 const bool is_alias = (&out == &A);
|
Chris@49
|
28
|
Chris@49
|
29 const uword in_n_rows = in.aux_uword_a;
|
Chris@49
|
30 const uword in_n_cols = in.aux_uword_b;
|
Chris@49
|
31 const uword in_dim = in.aux_uword_c;
|
Chris@49
|
32
|
Chris@49
|
33 const uword in_n_elem = in_n_rows * in_n_cols;
|
Chris@49
|
34
|
Chris@49
|
35 if(A.n_elem == in_n_elem)
|
Chris@49
|
36 {
|
Chris@49
|
37 if(in_dim == 0)
|
Chris@49
|
38 {
|
Chris@49
|
39 if(is_alias == false)
|
Chris@49
|
40 {
|
Chris@49
|
41 out.set_size(in_n_rows, in_n_cols);
|
Chris@49
|
42 arrayops::copy( out.memptr(), A.memptr(), out.n_elem );
|
Chris@49
|
43 }
|
Chris@49
|
44 else // &out == &A, i.e. inplace resize
|
Chris@49
|
45 {
|
Chris@49
|
46 const bool same_size = ( (out.n_rows == in_n_rows) && (out.n_cols == in_n_cols) );
|
Chris@49
|
47
|
Chris@49
|
48 if(same_size == false)
|
Chris@49
|
49 {
|
Chris@49
|
50 arma_debug_check
|
Chris@49
|
51 (
|
Chris@49
|
52 (out.mem_state == 3),
|
Chris@49
|
53 "reshape(): size can't be changed as template based size specification is in use"
|
Chris@49
|
54 );
|
Chris@49
|
55
|
Chris@49
|
56 access::rw(out.n_rows) = in_n_rows;
|
Chris@49
|
57 access::rw(out.n_cols) = in_n_cols;
|
Chris@49
|
58 }
|
Chris@49
|
59 }
|
Chris@49
|
60 }
|
Chris@49
|
61 else
|
Chris@49
|
62 {
|
Chris@49
|
63 unwrap_check< Mat<eT> > B_tmp(A, is_alias);
|
Chris@49
|
64 const Mat<eT>& B = B_tmp.M;
|
Chris@49
|
65
|
Chris@49
|
66 out.set_size(in_n_rows, in_n_cols);
|
Chris@49
|
67
|
Chris@49
|
68 eT* out_mem = out.memptr();
|
Chris@49
|
69 uword i = 0;
|
Chris@49
|
70
|
Chris@49
|
71 const uword B_n_rows = B.n_rows;
|
Chris@49
|
72 const uword B_n_cols = B.n_cols;
|
Chris@49
|
73
|
Chris@49
|
74 for(uword row=0; row<B_n_rows; ++row)
|
Chris@49
|
75 for(uword col=0; col<B_n_cols; ++col)
|
Chris@49
|
76 {
|
Chris@49
|
77 out_mem[i] = B.at(row,col);
|
Chris@49
|
78 ++i;
|
Chris@49
|
79 }
|
Chris@49
|
80 }
|
Chris@49
|
81 }
|
Chris@49
|
82 else
|
Chris@49
|
83 {
|
Chris@49
|
84 const unwrap_check< Mat<eT> > B_tmp(A, is_alias);
|
Chris@49
|
85 const Mat<eT>& B = B_tmp.M;
|
Chris@49
|
86
|
Chris@49
|
87 const uword n_elem_to_copy = (std::min)(B.n_elem, in_n_elem);
|
Chris@49
|
88
|
Chris@49
|
89 out.set_size(in_n_rows, in_n_cols);
|
Chris@49
|
90
|
Chris@49
|
91 eT* out_mem = out.memptr();
|
Chris@49
|
92
|
Chris@49
|
93 if(in_dim == 0)
|
Chris@49
|
94 {
|
Chris@49
|
95 arrayops::copy( out_mem, B.memptr(), n_elem_to_copy );
|
Chris@49
|
96 }
|
Chris@49
|
97 else
|
Chris@49
|
98 {
|
Chris@49
|
99 uword row = 0;
|
Chris@49
|
100 uword col = 0;
|
Chris@49
|
101
|
Chris@49
|
102 const uword B_n_cols = B.n_cols;
|
Chris@49
|
103
|
Chris@49
|
104 for(uword i=0; i<n_elem_to_copy; ++i)
|
Chris@49
|
105 {
|
Chris@49
|
106 out_mem[i] = B.at(row,col);
|
Chris@49
|
107
|
Chris@49
|
108 ++col;
|
Chris@49
|
109
|
Chris@49
|
110 if(col >= B_n_cols)
|
Chris@49
|
111 {
|
Chris@49
|
112 col = 0;
|
Chris@49
|
113 ++row;
|
Chris@49
|
114 }
|
Chris@49
|
115 }
|
Chris@49
|
116 }
|
Chris@49
|
117
|
Chris@49
|
118 for(uword i=n_elem_to_copy; i<in_n_elem; ++i)
|
Chris@49
|
119 {
|
Chris@49
|
120 out_mem[i] = eT(0);
|
Chris@49
|
121 }
|
Chris@49
|
122
|
Chris@49
|
123 }
|
Chris@49
|
124 }
|
Chris@49
|
125
|
Chris@49
|
126
|
Chris@49
|
127
|
Chris@49
|
128 template<typename T1>
|
Chris@49
|
129 inline
|
Chris@49
|
130 void
|
Chris@49
|
131 op_reshape::apply(Cube<typename T1::elem_type>& out, const OpCube<T1,op_reshape>& in)
|
Chris@49
|
132 {
|
Chris@49
|
133 arma_extra_debug_sigprint();
|
Chris@49
|
134
|
Chris@49
|
135 typedef typename T1::elem_type eT;
|
Chris@49
|
136
|
Chris@49
|
137 const unwrap_cube<T1> A_tmp(in.m);
|
Chris@49
|
138 const Cube<eT>& A = A_tmp.M;
|
Chris@49
|
139
|
Chris@49
|
140 const uword in_n_rows = in.aux_uword_a;
|
Chris@49
|
141 const uword in_n_cols = in.aux_uword_b;
|
Chris@49
|
142 const uword in_n_slices = in.aux_uword_c;
|
Chris@49
|
143 const uword in_dim = in.aux_uword_d;
|
Chris@49
|
144
|
Chris@49
|
145 const uword in_n_elem = in_n_rows * in_n_cols * in_n_slices;
|
Chris@49
|
146
|
Chris@49
|
147 if(A.n_elem == in_n_elem)
|
Chris@49
|
148 {
|
Chris@49
|
149 if(in_dim == 0)
|
Chris@49
|
150 {
|
Chris@49
|
151 if(&out != &A)
|
Chris@49
|
152 {
|
Chris@49
|
153 out.set_size(in_n_rows, in_n_cols, in_n_slices);
|
Chris@49
|
154 arrayops::copy( out.memptr(), A.memptr(), out.n_elem );
|
Chris@49
|
155 }
|
Chris@49
|
156 else // &out == &A, i.e. inplace resize
|
Chris@49
|
157 {
|
Chris@49
|
158 const bool same_size = ( (out.n_rows == in_n_rows) && (out.n_cols == in_n_cols) && (out.n_slices == in_n_slices) );
|
Chris@49
|
159
|
Chris@49
|
160 if(same_size == false)
|
Chris@49
|
161 {
|
Chris@49
|
162 arma_debug_check
|
Chris@49
|
163 (
|
Chris@49
|
164 (out.mem_state == 3),
|
Chris@49
|
165 "reshape(): size can't be changed as template based size specification is in use"
|
Chris@49
|
166 );
|
Chris@49
|
167
|
Chris@49
|
168 out.delete_mat();
|
Chris@49
|
169
|
Chris@49
|
170 access::rw(out.n_rows) = in_n_rows;
|
Chris@49
|
171 access::rw(out.n_cols) = in_n_cols;
|
Chris@49
|
172 access::rw(out.n_elem_slice) = in_n_rows * in_n_cols;
|
Chris@49
|
173 access::rw(out.n_slices) = in_n_slices;
|
Chris@49
|
174
|
Chris@49
|
175 out.create_mat();
|
Chris@49
|
176 }
|
Chris@49
|
177 }
|
Chris@49
|
178 }
|
Chris@49
|
179 else
|
Chris@49
|
180 {
|
Chris@49
|
181 unwrap_cube_check< Cube<eT> > B_tmp(A, out);
|
Chris@49
|
182 const Cube<eT>& B = B_tmp.M;
|
Chris@49
|
183
|
Chris@49
|
184 out.set_size(in_n_rows, in_n_cols, in_n_slices);
|
Chris@49
|
185
|
Chris@49
|
186 eT* out_mem = out.memptr();
|
Chris@49
|
187 uword i = 0;
|
Chris@49
|
188
|
Chris@49
|
189 const uword B_n_rows = B.n_rows;
|
Chris@49
|
190 const uword B_n_cols = B.n_cols;
|
Chris@49
|
191 const uword B_n_slices = B.n_slices;
|
Chris@49
|
192
|
Chris@49
|
193 for(uword slice=0; slice<B_n_slices; ++slice)
|
Chris@49
|
194 for(uword row=0; row<B_n_rows; ++row)
|
Chris@49
|
195 for(uword col=0; col<B_n_cols; ++col)
|
Chris@49
|
196 {
|
Chris@49
|
197 out_mem[i] = B.at(row,col,slice);
|
Chris@49
|
198 ++i;
|
Chris@49
|
199 }
|
Chris@49
|
200 }
|
Chris@49
|
201 }
|
Chris@49
|
202 else
|
Chris@49
|
203 {
|
Chris@49
|
204 const unwrap_cube_check< Cube<eT> > B_tmp(A, out);
|
Chris@49
|
205 const Cube<eT>& B = B_tmp.M;
|
Chris@49
|
206
|
Chris@49
|
207 const uword n_elem_to_copy = (std::min)(B.n_elem, in_n_elem);
|
Chris@49
|
208
|
Chris@49
|
209 out.set_size(in_n_rows, in_n_cols, in_n_slices);
|
Chris@49
|
210
|
Chris@49
|
211 eT* out_mem = out.memptr();
|
Chris@49
|
212
|
Chris@49
|
213 if(in_dim == 0)
|
Chris@49
|
214 {
|
Chris@49
|
215 arrayops::copy( out_mem, B.memptr(), n_elem_to_copy );
|
Chris@49
|
216 }
|
Chris@49
|
217 else
|
Chris@49
|
218 {
|
Chris@49
|
219 uword row = 0;
|
Chris@49
|
220 uword col = 0;
|
Chris@49
|
221 uword slice = 0;
|
Chris@49
|
222
|
Chris@49
|
223 const uword B_n_rows = B.n_rows;
|
Chris@49
|
224 const uword B_n_cols = B.n_cols;
|
Chris@49
|
225
|
Chris@49
|
226 for(uword i=0; i<n_elem_to_copy; ++i)
|
Chris@49
|
227 {
|
Chris@49
|
228 out_mem[i] = B.at(row,col,slice);
|
Chris@49
|
229
|
Chris@49
|
230 ++col;
|
Chris@49
|
231
|
Chris@49
|
232 if(col >= B_n_cols)
|
Chris@49
|
233 {
|
Chris@49
|
234 col = 0;
|
Chris@49
|
235 ++row;
|
Chris@49
|
236
|
Chris@49
|
237 if(row >= B_n_rows)
|
Chris@49
|
238 {
|
Chris@49
|
239 row = 0;
|
Chris@49
|
240 ++slice;
|
Chris@49
|
241 }
|
Chris@49
|
242 }
|
Chris@49
|
243 }
|
Chris@49
|
244 }
|
Chris@49
|
245
|
Chris@49
|
246 for(uword i=n_elem_to_copy; i<in_n_elem; ++i)
|
Chris@49
|
247 {
|
Chris@49
|
248 out_mem[i] = eT(0);
|
Chris@49
|
249 }
|
Chris@49
|
250
|
Chris@49
|
251 }
|
Chris@49
|
252 }
|
Chris@49
|
253
|
Chris@49
|
254
|
Chris@49
|
255
|
Chris@49
|
256 //! @}
|