max@0
|
1 // Copyright (C) 2010-2011 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2010-2011 Conrad Sanderson
|
max@0
|
3 // Copyright (C) 2011 Ryan Curtin
|
max@0
|
4 //
|
max@0
|
5 // This file is part of the Armadillo C++ library.
|
max@0
|
6 // It is provided without any warranty of fitness
|
max@0
|
7 // for any purpose. You can redistribute this file
|
max@0
|
8 // and/or modify it under the terms of the GNU
|
max@0
|
9 // Lesser General Public License (LGPL) as published
|
max@0
|
10 // by the Free Software Foundation, either version 3
|
max@0
|
11 // of the License or (at your option) any later version.
|
max@0
|
12 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
13
|
max@0
|
14
|
max@0
|
15 //! \addtogroup op_trimat
|
max@0
|
16 //! @{
|
max@0
|
17
|
max@0
|
18
|
max@0
|
19
|
max@0
|
20 template<typename eT>
|
max@0
|
21 inline
|
max@0
|
22 void
|
max@0
|
23 op_trimat::fill_zeros(Mat<eT>& out, const bool upper)
|
max@0
|
24 {
|
max@0
|
25 arma_extra_debug_sigprint();
|
max@0
|
26
|
max@0
|
27 const uword N = out.n_rows;
|
max@0
|
28
|
max@0
|
29 if(upper)
|
max@0
|
30 {
|
max@0
|
31 // upper triangular: set all elements below the diagonal to zero
|
max@0
|
32
|
max@0
|
33 for(uword i=0; i<N; ++i)
|
max@0
|
34 {
|
max@0
|
35 eT* data = out.colptr(i);
|
max@0
|
36
|
max@0
|
37 arrayops::inplace_set( &data[i+1], eT(0), (N-(i+1)) );
|
max@0
|
38 }
|
max@0
|
39 }
|
max@0
|
40 else
|
max@0
|
41 {
|
max@0
|
42 // lower triangular: set all elements above the diagonal to zero
|
max@0
|
43
|
max@0
|
44 for(uword i=1; i<N; ++i)
|
max@0
|
45 {
|
max@0
|
46 eT* data = out.colptr(i);
|
max@0
|
47
|
max@0
|
48 arrayops::inplace_set( data, eT(0), i );
|
max@0
|
49 }
|
max@0
|
50 }
|
max@0
|
51 }
|
max@0
|
52
|
max@0
|
53
|
max@0
|
54
|
max@0
|
55 template<typename T1>
|
max@0
|
56 inline
|
max@0
|
57 void
|
max@0
|
58 op_trimat::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_trimat>& in)
|
max@0
|
59 {
|
max@0
|
60 arma_extra_debug_sigprint();
|
max@0
|
61
|
max@0
|
62 typedef typename T1::elem_type eT;
|
max@0
|
63
|
max@0
|
64 const unwrap<T1> tmp(in.m);
|
max@0
|
65 const Mat<eT>& A = tmp.M;
|
max@0
|
66
|
max@0
|
67 arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square" );
|
max@0
|
68
|
max@0
|
69 const uword N = A.n_rows;
|
max@0
|
70 const bool upper = (in.aux_uword_a == 0);
|
max@0
|
71
|
max@0
|
72 if(&out != &A)
|
max@0
|
73 {
|
max@0
|
74 out.copy_size(A);
|
max@0
|
75
|
max@0
|
76 if(upper)
|
max@0
|
77 {
|
max@0
|
78 // upper triangular: copy the diagonal and the elements above the diagonal
|
max@0
|
79 for(uword i=0; i<N; ++i)
|
max@0
|
80 {
|
max@0
|
81 const eT* A_data = A.colptr(i);
|
max@0
|
82 eT* out_data = out.colptr(i);
|
max@0
|
83
|
max@0
|
84 arrayops::copy( out_data, A_data, i+1 );
|
max@0
|
85 }
|
max@0
|
86 }
|
max@0
|
87 else
|
max@0
|
88 {
|
max@0
|
89 // lower triangular: copy the diagonal and the elements below the diagonal
|
max@0
|
90 for(uword i=0; i<N; ++i)
|
max@0
|
91 {
|
max@0
|
92 const eT* A_data = A.colptr(i);
|
max@0
|
93 eT* out_data = out.colptr(i);
|
max@0
|
94
|
max@0
|
95 arrayops::copy( &out_data[i], &A_data[i], N-i );
|
max@0
|
96 }
|
max@0
|
97 }
|
max@0
|
98 }
|
max@0
|
99
|
max@0
|
100 op_trimat::fill_zeros(out, upper);
|
max@0
|
101 }
|
max@0
|
102
|
max@0
|
103
|
max@0
|
104
|
max@0
|
105 template<typename T1>
|
max@0
|
106 inline
|
max@0
|
107 void
|
max@0
|
108 op_trimat::apply(Mat<typename T1::elem_type>& out, const Op<Op<T1, op_htrans>, op_trimat>& in)
|
max@0
|
109 {
|
max@0
|
110 arma_extra_debug_sigprint();
|
max@0
|
111
|
max@0
|
112 typedef typename T1::elem_type eT;
|
max@0
|
113
|
max@0
|
114 const unwrap<T1> tmp(in.m.m);
|
max@0
|
115 const Mat<eT>& A = tmp.M;
|
max@0
|
116
|
max@0
|
117 const bool upper = (in.aux_uword_a == 0);
|
max@0
|
118
|
max@0
|
119 op_trimat::apply_htrans(out, A, upper);
|
max@0
|
120 }
|
max@0
|
121
|
max@0
|
122
|
max@0
|
123
|
max@0
|
124 template<typename eT>
|
max@0
|
125 inline
|
max@0
|
126 void
|
max@0
|
127 op_trimat::apply_htrans
|
max@0
|
128 (
|
max@0
|
129 Mat<eT>& out,
|
max@0
|
130 const Mat<eT>& A,
|
max@0
|
131 const bool upper,
|
max@0
|
132 const typename arma_not_cx<eT>::result* junk
|
max@0
|
133 )
|
max@0
|
134 {
|
max@0
|
135 arma_extra_debug_sigprint();
|
max@0
|
136 arma_ignore(junk);
|
max@0
|
137
|
max@0
|
138 // This specialisation is for trimatl(trans(X)) = trans(trimatu(X)) and also
|
max@0
|
139 // trimatu(trans(X)) = trans(trimatl(X)). We want to avoid the creation of an
|
max@0
|
140 // extra temporary.
|
max@0
|
141
|
max@0
|
142 // It doesn't matter if the input and output matrices are the same; we will
|
max@0
|
143 // pull data from the upper or lower triangular to the lower or upper
|
max@0
|
144 // triangular (respectively) and then set the rest to 0, so overwriting issues
|
max@0
|
145 // aren't present.
|
max@0
|
146
|
max@0
|
147 arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square" );
|
max@0
|
148
|
max@0
|
149 const uword N = A.n_rows;
|
max@0
|
150
|
max@0
|
151 if(&out != &A)
|
max@0
|
152 {
|
max@0
|
153 out.copy_size(A);
|
max@0
|
154 }
|
max@0
|
155
|
max@0
|
156 // We can't really get away with any array copy operations here,
|
max@0
|
157 // unfortunately...
|
max@0
|
158
|
max@0
|
159 if(upper)
|
max@0
|
160 {
|
max@0
|
161 // Upper triangular: but since we're transposing, we're taking the lower
|
max@0
|
162 // triangular and putting it in the upper half.
|
max@0
|
163 for(uword row = 0; row < N; ++row)
|
max@0
|
164 {
|
max@0
|
165 eT* out_colptr = out.colptr(row);
|
max@0
|
166
|
max@0
|
167 for(uword col = 0; col <= row; ++col)
|
max@0
|
168 {
|
max@0
|
169 //out.at(col, row) = A.at(row, col);
|
max@0
|
170 out_colptr[col] = A.at(row, col);
|
max@0
|
171 }
|
max@0
|
172 }
|
max@0
|
173 }
|
max@0
|
174 else
|
max@0
|
175 {
|
max@0
|
176 // Lower triangular: but since we're transposing, we're taking the upper
|
max@0
|
177 // triangular and putting it in the lower half.
|
max@0
|
178 for(uword row = 0; row < N; ++row)
|
max@0
|
179 {
|
max@0
|
180 for(uword col = row; col < N; ++col)
|
max@0
|
181 {
|
max@0
|
182 out.at(col, row) = A.at(row, col);
|
max@0
|
183 }
|
max@0
|
184 }
|
max@0
|
185 }
|
max@0
|
186
|
max@0
|
187 op_trimat::fill_zeros(out, upper);
|
max@0
|
188 }
|
max@0
|
189
|
max@0
|
190
|
max@0
|
191
|
max@0
|
192 template<typename eT>
|
max@0
|
193 inline
|
max@0
|
194 void
|
max@0
|
195 op_trimat::apply_htrans
|
max@0
|
196 (
|
max@0
|
197 Mat<eT>& out,
|
max@0
|
198 const Mat<eT>& A,
|
max@0
|
199 const bool upper,
|
max@0
|
200 const typename arma_cx_only<eT>::result* junk
|
max@0
|
201 )
|
max@0
|
202 {
|
max@0
|
203 arma_extra_debug_sigprint();
|
max@0
|
204 arma_ignore(junk);
|
max@0
|
205
|
max@0
|
206 arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square" );
|
max@0
|
207
|
max@0
|
208 const uword N = A.n_rows;
|
max@0
|
209
|
max@0
|
210 if(&out != &A)
|
max@0
|
211 {
|
max@0
|
212 out.copy_size(A);
|
max@0
|
213 }
|
max@0
|
214
|
max@0
|
215 if(upper)
|
max@0
|
216 {
|
max@0
|
217 // Upper triangular: but since we're transposing, we're taking the lower
|
max@0
|
218 // triangular and putting it in the upper half.
|
max@0
|
219 for(uword row = 0; row < N; ++row)
|
max@0
|
220 {
|
max@0
|
221 eT* out_colptr = out.colptr(row);
|
max@0
|
222
|
max@0
|
223 for(uword col = 0; col <= row; ++col)
|
max@0
|
224 {
|
max@0
|
225 //out.at(col, row) = std::conj( A.at(row, col) );
|
max@0
|
226 out_colptr[col] = std::conj( A.at(row, col) );
|
max@0
|
227 }
|
max@0
|
228 }
|
max@0
|
229 }
|
max@0
|
230 else
|
max@0
|
231 {
|
max@0
|
232 // Lower triangular: but since we're transposing, we're taking the upper
|
max@0
|
233 // triangular and putting it in the lower half.
|
max@0
|
234 for(uword row = 0; row < N; ++row)
|
max@0
|
235 {
|
max@0
|
236 for(uword col = row; col < N; ++col)
|
max@0
|
237 {
|
max@0
|
238 out.at(col, row) = std::conj( A.at(row, col) );
|
max@0
|
239 }
|
max@0
|
240 }
|
max@0
|
241 }
|
max@0
|
242
|
max@0
|
243 op_trimat::fill_zeros(out, upper);
|
max@0
|
244 }
|
max@0
|
245
|
max@0
|
246
|
max@0
|
247
|
max@0
|
248 //! @}
|