max@0
|
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2008-2011 Conrad Sanderson
|
max@0
|
3 //
|
max@0
|
4 // This file is part of the Armadillo C++ library.
|
max@0
|
5 // It is provided without any warranty of fitness
|
max@0
|
6 // for any purpose. You can redistribute this file
|
max@0
|
7 // and/or modify it under the terms of the GNU
|
max@0
|
8 // Lesser General Public License (LGPL) as published
|
max@0
|
9 // by the Free Software Foundation, either version 3
|
max@0
|
10 // of the License or (at your option) any later version.
|
max@0
|
11 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
12
|
max@0
|
13
|
max@0
|
14
|
max@0
|
15 //! \addtogroup op_reshape
|
max@0
|
16 //! @{
|
max@0
|
17
|
max@0
|
18
|
max@0
|
19
|
max@0
|
20 template<typename T1>
|
max@0
|
21 inline
|
max@0
|
22 void
|
max@0
|
23 op_reshape::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_reshape>& in)
|
max@0
|
24 {
|
max@0
|
25 arma_extra_debug_sigprint();
|
max@0
|
26
|
max@0
|
27 typedef typename T1::elem_type eT;
|
max@0
|
28
|
max@0
|
29 const unwrap<T1> tmp(in.m);
|
max@0
|
30 const Mat<eT>& A = tmp.M;
|
max@0
|
31
|
max@0
|
32 const uword in_n_rows = in.aux_uword_a;
|
max@0
|
33 const uword in_n_cols = in.aux_uword_b;
|
max@0
|
34 const uword in_dim = in.aux_uword_c;
|
max@0
|
35
|
max@0
|
36 const uword in_n_elem = in_n_rows * in_n_cols;
|
max@0
|
37
|
max@0
|
38 if(A.n_elem == in_n_elem)
|
max@0
|
39 {
|
max@0
|
40 if(in_dim == 0)
|
max@0
|
41 {
|
max@0
|
42 if(&out != &A)
|
max@0
|
43 {
|
max@0
|
44 out.set_size(in_n_rows, in_n_cols);
|
max@0
|
45 arrayops::copy( out.memptr(), A.memptr(), out.n_elem );
|
max@0
|
46 }
|
max@0
|
47 else // &out == &A, i.e. inplace resize
|
max@0
|
48 {
|
max@0
|
49 const bool same_size = ( (out.n_rows == in_n_rows) && (out.n_cols == in_n_cols) );
|
max@0
|
50
|
max@0
|
51 if(same_size == false)
|
max@0
|
52 {
|
max@0
|
53 arma_debug_check
|
max@0
|
54 (
|
max@0
|
55 (out.mem_state == 3),
|
max@0
|
56 "reshape(): size can't be changed as template based size specification is in use"
|
max@0
|
57 );
|
max@0
|
58
|
max@0
|
59 access::rw(out.n_rows) = in_n_rows;
|
max@0
|
60 access::rw(out.n_cols) = in_n_cols;
|
max@0
|
61 }
|
max@0
|
62 }
|
max@0
|
63 }
|
max@0
|
64 else
|
max@0
|
65 {
|
max@0
|
66 unwrap_check< Mat<eT> > tmp(A, out);
|
max@0
|
67 const Mat<eT>& B = tmp.M;
|
max@0
|
68
|
max@0
|
69 out.set_size(in_n_rows, in_n_cols);
|
max@0
|
70
|
max@0
|
71 eT* out_mem = out.memptr();
|
max@0
|
72 uword i = 0;
|
max@0
|
73
|
max@0
|
74 const uword B_n_rows = B.n_rows;
|
max@0
|
75 const uword B_n_cols = B.n_cols;
|
max@0
|
76
|
max@0
|
77 for(uword row=0; row<B_n_rows; ++row)
|
max@0
|
78 {
|
max@0
|
79 for(uword col=0; col<B_n_cols; ++col)
|
max@0
|
80 {
|
max@0
|
81 out_mem[i] = B.at(row,col);
|
max@0
|
82 ++i;
|
max@0
|
83 }
|
max@0
|
84 }
|
max@0
|
85
|
max@0
|
86 }
|
max@0
|
87 }
|
max@0
|
88 else
|
max@0
|
89 {
|
max@0
|
90 const unwrap_check< Mat<eT> > tmp(A, out);
|
max@0
|
91 const Mat<eT>& B = tmp.M;
|
max@0
|
92
|
max@0
|
93 const uword n_elem_to_copy = (std::min)(B.n_elem, in_n_elem);
|
max@0
|
94
|
max@0
|
95 out.set_size(in_n_rows, in_n_cols);
|
max@0
|
96
|
max@0
|
97 eT* out_mem = out.memptr();
|
max@0
|
98
|
max@0
|
99 if(in_dim == 0)
|
max@0
|
100 {
|
max@0
|
101 arrayops::copy( out_mem, B.memptr(), n_elem_to_copy );
|
max@0
|
102 }
|
max@0
|
103 else
|
max@0
|
104 {
|
max@0
|
105 uword row = 0;
|
max@0
|
106 uword col = 0;
|
max@0
|
107
|
max@0
|
108 const uword B_n_cols = B.n_cols;
|
max@0
|
109
|
max@0
|
110 for(uword i=0; i<n_elem_to_copy; ++i)
|
max@0
|
111 {
|
max@0
|
112 out_mem[i] = B.at(row,col);
|
max@0
|
113
|
max@0
|
114 ++col;
|
max@0
|
115
|
max@0
|
116 if(col >= B_n_cols)
|
max@0
|
117 {
|
max@0
|
118 col = 0;
|
max@0
|
119 ++row;
|
max@0
|
120 }
|
max@0
|
121 }
|
max@0
|
122 }
|
max@0
|
123
|
max@0
|
124 for(uword i=n_elem_to_copy; i<in_n_elem; ++i)
|
max@0
|
125 {
|
max@0
|
126 out_mem[i] = eT(0);
|
max@0
|
127 }
|
max@0
|
128
|
max@0
|
129 }
|
max@0
|
130 }
|
max@0
|
131
|
max@0
|
132
|
max@0
|
133
|
max@0
|
134 template<typename T1>
|
max@0
|
135 inline
|
max@0
|
136 void
|
max@0
|
137 op_reshape::apply(Cube<typename T1::elem_type>& out, const OpCube<T1,op_reshape>& in)
|
max@0
|
138 {
|
max@0
|
139 arma_extra_debug_sigprint();
|
max@0
|
140
|
max@0
|
141 typedef typename T1::elem_type eT;
|
max@0
|
142
|
max@0
|
143 const unwrap_cube<T1> tmp(in.m);
|
max@0
|
144 const Cube<eT>& A = tmp.M;
|
max@0
|
145
|
max@0
|
146 const uword in_n_rows = in.aux_uword_a;
|
max@0
|
147 const uword in_n_cols = in.aux_uword_b;
|
max@0
|
148 const uword in_n_slices = in.aux_uword_c;
|
max@0
|
149 const uword in_dim = in.aux_uword_d;
|
max@0
|
150
|
max@0
|
151 const uword in_n_elem = in_n_rows * in_n_cols * in_n_slices;
|
max@0
|
152
|
max@0
|
153 if(A.n_elem == in_n_elem)
|
max@0
|
154 {
|
max@0
|
155 if(in_dim == 0)
|
max@0
|
156 {
|
max@0
|
157 if(&out != &A)
|
max@0
|
158 {
|
max@0
|
159 out.set_size(in_n_rows, in_n_cols, in_n_slices);
|
max@0
|
160 arrayops::copy( out.memptr(), A.memptr(), out.n_elem );
|
max@0
|
161 }
|
max@0
|
162 else // &out == &A, i.e. inplace resize
|
max@0
|
163 {
|
max@0
|
164 const bool same_size = ( (out.n_rows == in_n_rows) && (out.n_cols == in_n_cols) && (out.n_slices == in_n_slices) );
|
max@0
|
165
|
max@0
|
166 if(same_size == false)
|
max@0
|
167 {
|
max@0
|
168 arma_debug_check
|
max@0
|
169 (
|
max@0
|
170 (out.mem_state == 3),
|
max@0
|
171 "reshape(): size can't be changed as template based size specification is in use"
|
max@0
|
172 );
|
max@0
|
173
|
max@0
|
174 out.delete_mat();
|
max@0
|
175
|
max@0
|
176 access::rw(out.n_rows) = in_n_rows;
|
max@0
|
177 access::rw(out.n_cols) = in_n_cols;
|
max@0
|
178 access::rw(out.n_elem_slice) = in_n_rows * in_n_cols;
|
max@0
|
179 access::rw(out.n_slices) = in_n_slices;
|
max@0
|
180
|
max@0
|
181 out.create_mat();
|
max@0
|
182 }
|
max@0
|
183 }
|
max@0
|
184 }
|
max@0
|
185 else
|
max@0
|
186 {
|
max@0
|
187 unwrap_cube_check< Cube<eT> > tmp(A, out);
|
max@0
|
188 const Cube<eT>& B = tmp.M;
|
max@0
|
189
|
max@0
|
190 out.set_size(in_n_rows, in_n_cols, in_n_slices);
|
max@0
|
191
|
max@0
|
192 eT* out_mem = out.memptr();
|
max@0
|
193 uword i = 0;
|
max@0
|
194
|
max@0
|
195 const uword B_n_rows = B.n_rows;
|
max@0
|
196 const uword B_n_cols = B.n_cols;
|
max@0
|
197 const uword B_n_slices = B.n_slices;
|
max@0
|
198
|
max@0
|
199 for(uword slice=0; slice<B_n_slices; ++slice)
|
max@0
|
200 {
|
max@0
|
201 for(uword row=0; row<B_n_rows; ++row)
|
max@0
|
202 {
|
max@0
|
203 for(uword col=0; col<B_n_cols; ++col)
|
max@0
|
204 {
|
max@0
|
205 out_mem[i] = B.at(row,col,slice);
|
max@0
|
206 ++i;
|
max@0
|
207 }
|
max@0
|
208 }
|
max@0
|
209 }
|
max@0
|
210
|
max@0
|
211 }
|
max@0
|
212 }
|
max@0
|
213 else
|
max@0
|
214 {
|
max@0
|
215 const unwrap_cube_check< Cube<eT> > tmp(A, out);
|
max@0
|
216 const Cube<eT>& B = tmp.M;
|
max@0
|
217
|
max@0
|
218 const uword n_elem_to_copy = (std::min)(B.n_elem, in_n_elem);
|
max@0
|
219
|
max@0
|
220 out.set_size(in_n_rows, in_n_cols, in_n_slices);
|
max@0
|
221
|
max@0
|
222 eT* out_mem = out.memptr();
|
max@0
|
223
|
max@0
|
224 if(in_dim == 0)
|
max@0
|
225 {
|
max@0
|
226 arrayops::copy( out_mem, B.memptr(), n_elem_to_copy );
|
max@0
|
227 }
|
max@0
|
228 else
|
max@0
|
229 {
|
max@0
|
230 uword row = 0;
|
max@0
|
231 uword col = 0;
|
max@0
|
232 uword slice = 0;
|
max@0
|
233
|
max@0
|
234 const uword B_n_rows = B.n_rows;
|
max@0
|
235 const uword B_n_cols = B.n_cols;
|
max@0
|
236
|
max@0
|
237 for(uword i=0; i<n_elem_to_copy; ++i)
|
max@0
|
238 {
|
max@0
|
239 out_mem[i] = B.at(row,col,slice);
|
max@0
|
240
|
max@0
|
241 ++col;
|
max@0
|
242
|
max@0
|
243 if(col >= B_n_cols)
|
max@0
|
244 {
|
max@0
|
245 col = 0;
|
max@0
|
246 ++row;
|
max@0
|
247
|
max@0
|
248 if(row >= B_n_rows)
|
max@0
|
249 {
|
max@0
|
250 row = 0;
|
max@0
|
251 ++slice;
|
max@0
|
252 }
|
max@0
|
253 }
|
max@0
|
254 }
|
max@0
|
255 }
|
max@0
|
256
|
max@0
|
257 for(uword i=n_elem_to_copy; i<in_n_elem; ++i)
|
max@0
|
258 {
|
max@0
|
259 out_mem[i] = eT(0);
|
max@0
|
260 }
|
max@0
|
261
|
max@0
|
262 }
|
max@0
|
263 }
|
max@0
|
264
|
max@0
|
265
|
max@0
|
266
|
max@0
|
267 //! @}
|