Chris@49
|
1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2012 Conrad Sanderson
|
Chris@49
|
3 //
|
Chris@49
|
4 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
7
|
Chris@49
|
8
|
Chris@49
|
9 //! \addtogroup op_sort
|
Chris@49
|
10 //! @{
|
Chris@49
|
11
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 template<typename eT>
|
Chris@49
|
15 class arma_ascend_sort_helper
|
Chris@49
|
16 {
|
Chris@49
|
17 public:
|
Chris@49
|
18
|
Chris@49
|
19 arma_inline
|
Chris@49
|
20 bool
|
Chris@49
|
21 operator() (eT a, eT b) const
|
Chris@49
|
22 {
|
Chris@49
|
23 return (a < b);
|
Chris@49
|
24 }
|
Chris@49
|
25 };
|
Chris@49
|
26
|
Chris@49
|
27
|
Chris@49
|
28
|
Chris@49
|
29 template<typename eT>
|
Chris@49
|
30 class arma_descend_sort_helper
|
Chris@49
|
31 {
|
Chris@49
|
32 public:
|
Chris@49
|
33
|
Chris@49
|
34 arma_inline
|
Chris@49
|
35 bool
|
Chris@49
|
36 operator() (eT a, eT b) const
|
Chris@49
|
37 {
|
Chris@49
|
38 return (a > b);
|
Chris@49
|
39 }
|
Chris@49
|
40 };
|
Chris@49
|
41
|
Chris@49
|
42
|
Chris@49
|
43
|
Chris@49
|
44 template<typename T>
|
Chris@49
|
45 class arma_ascend_sort_helper< std::complex<T> >
|
Chris@49
|
46 {
|
Chris@49
|
47 public:
|
Chris@49
|
48
|
Chris@49
|
49 typedef typename std::complex<T> eT;
|
Chris@49
|
50
|
Chris@49
|
51 inline
|
Chris@49
|
52 bool
|
Chris@49
|
53 operator() (const eT& a, const eT& b) const
|
Chris@49
|
54 {
|
Chris@49
|
55 return (std::abs(a) < std::abs(b));
|
Chris@49
|
56 }
|
Chris@49
|
57 };
|
Chris@49
|
58
|
Chris@49
|
59
|
Chris@49
|
60
|
Chris@49
|
61 template<typename T>
|
Chris@49
|
62 class arma_descend_sort_helper< std::complex<T> >
|
Chris@49
|
63 {
|
Chris@49
|
64 public:
|
Chris@49
|
65
|
Chris@49
|
66 typedef typename std::complex<T> eT;
|
Chris@49
|
67
|
Chris@49
|
68 inline
|
Chris@49
|
69 bool
|
Chris@49
|
70 operator() (const eT& a, const eT& b) const
|
Chris@49
|
71 {
|
Chris@49
|
72 return (std::abs(a) > std::abs(b));
|
Chris@49
|
73 }
|
Chris@49
|
74 };
|
Chris@49
|
75
|
Chris@49
|
76
|
Chris@49
|
77
|
Chris@49
|
78 template<typename eT>
|
Chris@49
|
79 inline
|
Chris@49
|
80 void
|
Chris@49
|
81 op_sort::direct_sort(eT* X, const uword n_elem, const uword sort_type)
|
Chris@49
|
82 {
|
Chris@49
|
83 arma_extra_debug_sigprint();
|
Chris@49
|
84
|
Chris@49
|
85 if(sort_type == 0)
|
Chris@49
|
86 {
|
Chris@49
|
87 arma_ascend_sort_helper<eT> comparator;
|
Chris@49
|
88
|
Chris@49
|
89 std::sort(&X[0], &X[n_elem], comparator);
|
Chris@49
|
90 }
|
Chris@49
|
91 else
|
Chris@49
|
92 {
|
Chris@49
|
93 arma_descend_sort_helper<eT> comparator;
|
Chris@49
|
94
|
Chris@49
|
95 std::sort(&X[0], &X[n_elem], comparator);
|
Chris@49
|
96 }
|
Chris@49
|
97 }
|
Chris@49
|
98
|
Chris@49
|
99
|
Chris@49
|
100
|
Chris@49
|
101 template<typename eT>
|
Chris@49
|
102 inline
|
Chris@49
|
103 void
|
Chris@49
|
104 op_sort::direct_sort_ascending(eT* X, const uword n_elem)
|
Chris@49
|
105 {
|
Chris@49
|
106 arma_extra_debug_sigprint();
|
Chris@49
|
107
|
Chris@49
|
108 arma_ascend_sort_helper<eT> comparator;
|
Chris@49
|
109
|
Chris@49
|
110 std::sort(&X[0], &X[n_elem], comparator);
|
Chris@49
|
111 }
|
Chris@49
|
112
|
Chris@49
|
113
|
Chris@49
|
114
|
Chris@49
|
115 template<typename eT>
|
Chris@49
|
116 inline
|
Chris@49
|
117 void
|
Chris@49
|
118 op_sort::copy_row(eT* X, const Mat<eT>& A, const uword row)
|
Chris@49
|
119 {
|
Chris@49
|
120 const uword N = A.n_cols;
|
Chris@49
|
121
|
Chris@49
|
122 uword i,j;
|
Chris@49
|
123
|
Chris@49
|
124 for(i=0, j=1; j<N; i+=2, j+=2)
|
Chris@49
|
125 {
|
Chris@49
|
126 X[i] = A.at(row,i);
|
Chris@49
|
127 X[j] = A.at(row,j);
|
Chris@49
|
128 }
|
Chris@49
|
129
|
Chris@49
|
130 if(i < N)
|
Chris@49
|
131 {
|
Chris@49
|
132 X[i] = A.at(row,i);
|
Chris@49
|
133 }
|
Chris@49
|
134 }
|
Chris@49
|
135
|
Chris@49
|
136
|
Chris@49
|
137
|
Chris@49
|
138 template<typename eT>
|
Chris@49
|
139 inline
|
Chris@49
|
140 void
|
Chris@49
|
141 op_sort::copy_row(Mat<eT>& A, const eT* X, const uword row)
|
Chris@49
|
142 {
|
Chris@49
|
143 const uword N = A.n_cols;
|
Chris@49
|
144
|
Chris@49
|
145 uword i,j;
|
Chris@49
|
146
|
Chris@49
|
147 for(i=0, j=1; j<N; i+=2, j+=2)
|
Chris@49
|
148 {
|
Chris@49
|
149 A.at(row,i) = X[i];
|
Chris@49
|
150 A.at(row,j) = X[j];
|
Chris@49
|
151 }
|
Chris@49
|
152
|
Chris@49
|
153 if(i < N)
|
Chris@49
|
154 {
|
Chris@49
|
155 A.at(row,i) = X[i];
|
Chris@49
|
156 }
|
Chris@49
|
157 }
|
Chris@49
|
158
|
Chris@49
|
159
|
Chris@49
|
160
|
Chris@49
|
161 template<typename T1>
|
Chris@49
|
162 inline
|
Chris@49
|
163 void
|
Chris@49
|
164 op_sort::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_sort>& in)
|
Chris@49
|
165 {
|
Chris@49
|
166 arma_extra_debug_sigprint();
|
Chris@49
|
167
|
Chris@49
|
168 typedef typename T1::elem_type eT;
|
Chris@49
|
169
|
Chris@49
|
170 const unwrap_check<T1> tmp(in.m, out);
|
Chris@49
|
171 const Mat<eT>& X = tmp.M;
|
Chris@49
|
172
|
Chris@49
|
173 const uword sort_type = in.aux_uword_a;
|
Chris@49
|
174 const uword dim = in.aux_uword_b;
|
Chris@49
|
175
|
Chris@49
|
176 arma_debug_check( (sort_type > 1), "sort(): incorrect usage. sort_type must be 0 or 1");
|
Chris@49
|
177 arma_debug_check( (dim > 1), "sort(): incorrect usage. dim must be 0 or 1" );
|
Chris@49
|
178 arma_debug_check( (X.is_finite() == false), "sort(): given object has non-finite elements" );
|
Chris@49
|
179
|
Chris@49
|
180 if( (X.n_rows * X.n_cols) <= 1 )
|
Chris@49
|
181 {
|
Chris@49
|
182 out = X;
|
Chris@49
|
183 return;
|
Chris@49
|
184 }
|
Chris@49
|
185
|
Chris@49
|
186
|
Chris@49
|
187 if(dim == 0) // sort the contents of each column
|
Chris@49
|
188 {
|
Chris@49
|
189 arma_extra_debug_print("op_sort::apply(), dim = 0");
|
Chris@49
|
190
|
Chris@49
|
191 out = X;
|
Chris@49
|
192
|
Chris@49
|
193 const uword n_rows = out.n_rows;
|
Chris@49
|
194 const uword n_cols = out.n_cols;
|
Chris@49
|
195
|
Chris@49
|
196 for(uword col=0; col < n_cols; ++col)
|
Chris@49
|
197 {
|
Chris@49
|
198 op_sort::direct_sort( out.colptr(col), n_rows, sort_type );
|
Chris@49
|
199 }
|
Chris@49
|
200 }
|
Chris@49
|
201 else
|
Chris@49
|
202 if(dim == 1) // sort the contents of each row
|
Chris@49
|
203 {
|
Chris@49
|
204 if(X.n_rows == 1) // a row vector
|
Chris@49
|
205 {
|
Chris@49
|
206 arma_extra_debug_print("op_sort::apply(), dim = 1, vector specific");
|
Chris@49
|
207
|
Chris@49
|
208 out = X;
|
Chris@49
|
209 op_sort::direct_sort(out.memptr(), out.n_elem, sort_type);
|
Chris@49
|
210 }
|
Chris@49
|
211 else // not a row vector
|
Chris@49
|
212 {
|
Chris@49
|
213 arma_extra_debug_print("op_sort::apply(), dim = 1, generic");
|
Chris@49
|
214
|
Chris@49
|
215 out.copy_size(X);
|
Chris@49
|
216
|
Chris@49
|
217 const uword n_rows = out.n_rows;
|
Chris@49
|
218 const uword n_cols = out.n_cols;
|
Chris@49
|
219
|
Chris@49
|
220 podarray<eT> tmp_array(n_cols);
|
Chris@49
|
221
|
Chris@49
|
222 for(uword row=0; row < n_rows; ++row)
|
Chris@49
|
223 {
|
Chris@49
|
224 op_sort::copy_row(tmp_array.memptr(), X, row);
|
Chris@49
|
225
|
Chris@49
|
226 op_sort::direct_sort( tmp_array.memptr(), n_cols, sort_type );
|
Chris@49
|
227
|
Chris@49
|
228 op_sort::copy_row(out, tmp_array.memptr(), row);
|
Chris@49
|
229 }
|
Chris@49
|
230 }
|
Chris@49
|
231 }
|
Chris@49
|
232
|
Chris@49
|
233 }
|
Chris@49
|
234
|
Chris@49
|
235
|
Chris@49
|
236 //! @}
|