Chris@49
|
1 // Copyright (C) 2010-2012 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2010-2012 Conrad Sanderson
|
Chris@49
|
3 // Copyright (C) 2011 Ryan Curtin
|
Chris@49
|
4 //
|
Chris@49
|
5 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
8
|
Chris@49
|
9
|
Chris@49
|
10 //! \addtogroup op_trimat
|
Chris@49
|
11 //! @{
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14
|
Chris@49
|
15 template<typename eT>
|
Chris@49
|
16 inline
|
Chris@49
|
17 void
|
Chris@49
|
18 op_trimat::fill_zeros(Mat<eT>& out, const bool upper)
|
Chris@49
|
19 {
|
Chris@49
|
20 arma_extra_debug_sigprint();
|
Chris@49
|
21
|
Chris@49
|
22 const uword N = out.n_rows;
|
Chris@49
|
23
|
Chris@49
|
24 if(upper)
|
Chris@49
|
25 {
|
Chris@49
|
26 // upper triangular: set all elements below the diagonal to zero
|
Chris@49
|
27
|
Chris@49
|
28 for(uword i=0; i<N; ++i)
|
Chris@49
|
29 {
|
Chris@49
|
30 eT* data = out.colptr(i);
|
Chris@49
|
31
|
Chris@49
|
32 arrayops::inplace_set( &data[i+1], eT(0), (N-(i+1)) );
|
Chris@49
|
33 }
|
Chris@49
|
34 }
|
Chris@49
|
35 else
|
Chris@49
|
36 {
|
Chris@49
|
37 // lower triangular: set all elements above the diagonal to zero
|
Chris@49
|
38
|
Chris@49
|
39 for(uword i=1; i<N; ++i)
|
Chris@49
|
40 {
|
Chris@49
|
41 eT* data = out.colptr(i);
|
Chris@49
|
42
|
Chris@49
|
43 arrayops::inplace_set( data, eT(0), i );
|
Chris@49
|
44 }
|
Chris@49
|
45 }
|
Chris@49
|
46 }
|
Chris@49
|
47
|
Chris@49
|
48
|
Chris@49
|
49
|
Chris@49
|
50 template<typename T1>
|
Chris@49
|
51 inline
|
Chris@49
|
52 void
|
Chris@49
|
53 op_trimat::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_trimat>& in)
|
Chris@49
|
54 {
|
Chris@49
|
55 arma_extra_debug_sigprint();
|
Chris@49
|
56
|
Chris@49
|
57 typedef typename T1::elem_type eT;
|
Chris@49
|
58
|
Chris@49
|
59 const unwrap<T1> tmp(in.m);
|
Chris@49
|
60 const Mat<eT>& A = tmp.M;
|
Chris@49
|
61
|
Chris@49
|
62 arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square" );
|
Chris@49
|
63
|
Chris@49
|
64 const uword N = A.n_rows;
|
Chris@49
|
65 const bool upper = (in.aux_uword_a == 0);
|
Chris@49
|
66
|
Chris@49
|
67 if(&out != &A)
|
Chris@49
|
68 {
|
Chris@49
|
69 out.copy_size(A);
|
Chris@49
|
70
|
Chris@49
|
71 if(upper)
|
Chris@49
|
72 {
|
Chris@49
|
73 // upper triangular: copy the diagonal and the elements above the diagonal
|
Chris@49
|
74 for(uword i=0; i<N; ++i)
|
Chris@49
|
75 {
|
Chris@49
|
76 const eT* A_data = A.colptr(i);
|
Chris@49
|
77 eT* out_data = out.colptr(i);
|
Chris@49
|
78
|
Chris@49
|
79 arrayops::copy( out_data, A_data, i+1 );
|
Chris@49
|
80 }
|
Chris@49
|
81 }
|
Chris@49
|
82 else
|
Chris@49
|
83 {
|
Chris@49
|
84 // lower triangular: copy the diagonal and the elements below the diagonal
|
Chris@49
|
85 for(uword i=0; i<N; ++i)
|
Chris@49
|
86 {
|
Chris@49
|
87 const eT* A_data = A.colptr(i);
|
Chris@49
|
88 eT* out_data = out.colptr(i);
|
Chris@49
|
89
|
Chris@49
|
90 arrayops::copy( &out_data[i], &A_data[i], N-i );
|
Chris@49
|
91 }
|
Chris@49
|
92 }
|
Chris@49
|
93 }
|
Chris@49
|
94
|
Chris@49
|
95 op_trimat::fill_zeros(out, upper);
|
Chris@49
|
96 }
|
Chris@49
|
97
|
Chris@49
|
98
|
Chris@49
|
99
|
Chris@49
|
100 template<typename T1>
|
Chris@49
|
101 inline
|
Chris@49
|
102 void
|
Chris@49
|
103 op_trimat::apply(Mat<typename T1::elem_type>& out, const Op<Op<T1, op_htrans>, op_trimat>& in)
|
Chris@49
|
104 {
|
Chris@49
|
105 arma_extra_debug_sigprint();
|
Chris@49
|
106
|
Chris@49
|
107 typedef typename T1::elem_type eT;
|
Chris@49
|
108
|
Chris@49
|
109 const unwrap<T1> tmp(in.m.m);
|
Chris@49
|
110 const Mat<eT>& A = tmp.M;
|
Chris@49
|
111
|
Chris@49
|
112 const bool upper = (in.aux_uword_a == 0);
|
Chris@49
|
113
|
Chris@49
|
114 op_trimat::apply_htrans(out, A, upper);
|
Chris@49
|
115 }
|
Chris@49
|
116
|
Chris@49
|
117
|
Chris@49
|
118
|
Chris@49
|
119 template<typename eT>
|
Chris@49
|
120 inline
|
Chris@49
|
121 void
|
Chris@49
|
122 op_trimat::apply_htrans
|
Chris@49
|
123 (
|
Chris@49
|
124 Mat<eT>& out,
|
Chris@49
|
125 const Mat<eT>& A,
|
Chris@49
|
126 const bool upper,
|
Chris@49
|
127 const typename arma_not_cx<eT>::result* junk
|
Chris@49
|
128 )
|
Chris@49
|
129 {
|
Chris@49
|
130 arma_extra_debug_sigprint();
|
Chris@49
|
131 arma_ignore(junk);
|
Chris@49
|
132
|
Chris@49
|
133 // This specialisation is for trimatl(trans(X)) = trans(trimatu(X)) and also
|
Chris@49
|
134 // trimatu(trans(X)) = trans(trimatl(X)). We want to avoid the creation of an
|
Chris@49
|
135 // extra temporary.
|
Chris@49
|
136
|
Chris@49
|
137 // It doesn't matter if the input and output matrices are the same; we will
|
Chris@49
|
138 // pull data from the upper or lower triangular to the lower or upper
|
Chris@49
|
139 // triangular (respectively) and then set the rest to 0, so overwriting issues
|
Chris@49
|
140 // aren't present.
|
Chris@49
|
141
|
Chris@49
|
142 arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square" );
|
Chris@49
|
143
|
Chris@49
|
144 const uword N = A.n_rows;
|
Chris@49
|
145
|
Chris@49
|
146 if(&out != &A)
|
Chris@49
|
147 {
|
Chris@49
|
148 out.copy_size(A);
|
Chris@49
|
149 }
|
Chris@49
|
150
|
Chris@49
|
151 // We can't really get away with any array copy operations here,
|
Chris@49
|
152 // unfortunately...
|
Chris@49
|
153
|
Chris@49
|
154 if(upper)
|
Chris@49
|
155 {
|
Chris@49
|
156 // Upper triangular: but since we're transposing, we're taking the lower
|
Chris@49
|
157 // triangular and putting it in the upper half.
|
Chris@49
|
158 for(uword row = 0; row < N; ++row)
|
Chris@49
|
159 {
|
Chris@49
|
160 eT* out_colptr = out.colptr(row);
|
Chris@49
|
161
|
Chris@49
|
162 for(uword col = 0; col <= row; ++col)
|
Chris@49
|
163 {
|
Chris@49
|
164 //out.at(col, row) = A.at(row, col);
|
Chris@49
|
165 out_colptr[col] = A.at(row, col);
|
Chris@49
|
166 }
|
Chris@49
|
167 }
|
Chris@49
|
168 }
|
Chris@49
|
169 else
|
Chris@49
|
170 {
|
Chris@49
|
171 // Lower triangular: but since we're transposing, we're taking the upper
|
Chris@49
|
172 // triangular and putting it in the lower half.
|
Chris@49
|
173 for(uword row = 0; row < N; ++row)
|
Chris@49
|
174 {
|
Chris@49
|
175 for(uword col = row; col < N; ++col)
|
Chris@49
|
176 {
|
Chris@49
|
177 out.at(col, row) = A.at(row, col);
|
Chris@49
|
178 }
|
Chris@49
|
179 }
|
Chris@49
|
180 }
|
Chris@49
|
181
|
Chris@49
|
182 op_trimat::fill_zeros(out, upper);
|
Chris@49
|
183 }
|
Chris@49
|
184
|
Chris@49
|
185
|
Chris@49
|
186
|
Chris@49
|
187 template<typename eT>
|
Chris@49
|
188 inline
|
Chris@49
|
189 void
|
Chris@49
|
190 op_trimat::apply_htrans
|
Chris@49
|
191 (
|
Chris@49
|
192 Mat<eT>& out,
|
Chris@49
|
193 const Mat<eT>& A,
|
Chris@49
|
194 const bool upper,
|
Chris@49
|
195 const typename arma_cx_only<eT>::result* junk
|
Chris@49
|
196 )
|
Chris@49
|
197 {
|
Chris@49
|
198 arma_extra_debug_sigprint();
|
Chris@49
|
199 arma_ignore(junk);
|
Chris@49
|
200
|
Chris@49
|
201 arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square" );
|
Chris@49
|
202
|
Chris@49
|
203 const uword N = A.n_rows;
|
Chris@49
|
204
|
Chris@49
|
205 if(&out != &A)
|
Chris@49
|
206 {
|
Chris@49
|
207 out.copy_size(A);
|
Chris@49
|
208 }
|
Chris@49
|
209
|
Chris@49
|
210 if(upper)
|
Chris@49
|
211 {
|
Chris@49
|
212 // Upper triangular: but since we're transposing, we're taking the lower
|
Chris@49
|
213 // triangular and putting it in the upper half.
|
Chris@49
|
214 for(uword row = 0; row < N; ++row)
|
Chris@49
|
215 {
|
Chris@49
|
216 eT* out_colptr = out.colptr(row);
|
Chris@49
|
217
|
Chris@49
|
218 for(uword col = 0; col <= row; ++col)
|
Chris@49
|
219 {
|
Chris@49
|
220 //out.at(col, row) = std::conj( A.at(row, col) );
|
Chris@49
|
221 out_colptr[col] = std::conj( A.at(row, col) );
|
Chris@49
|
222 }
|
Chris@49
|
223 }
|
Chris@49
|
224 }
|
Chris@49
|
225 else
|
Chris@49
|
226 {
|
Chris@49
|
227 // Lower triangular: but since we're transposing, we're taking the upper
|
Chris@49
|
228 // triangular and putting it in the lower half.
|
Chris@49
|
229 for(uword row = 0; row < N; ++row)
|
Chris@49
|
230 {
|
Chris@49
|
231 for(uword col = row; col < N; ++col)
|
Chris@49
|
232 {
|
Chris@49
|
233 out.at(col, row) = std::conj( A.at(row, col) );
|
Chris@49
|
234 }
|
Chris@49
|
235 }
|
Chris@49
|
236 }
|
Chris@49
|
237
|
Chris@49
|
238 op_trimat::fill_zeros(out, upper);
|
Chris@49
|
239 }
|
Chris@49
|
240
|
Chris@49
|
241
|
Chris@49
|
242
|
Chris@49
|
243 //! @}
|