Chris@49
|
1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2012 Conrad Sanderson
|
Chris@49
|
3 //
|
Chris@49
|
4 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
7
|
Chris@49
|
8
|
Chris@49
|
9 //! \addtogroup op_diagmat
|
Chris@49
|
10 //! @{
|
Chris@49
|
11
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 template<typename T1>
|
Chris@49
|
15 inline
|
Chris@49
|
16 void
|
Chris@49
|
17 op_diagmat::apply(Mat<typename T1::elem_type>& out, const Op<T1, op_diagmat>& X)
|
Chris@49
|
18 {
|
Chris@49
|
19 arma_extra_debug_sigprint();
|
Chris@49
|
20
|
Chris@49
|
21 typedef typename T1::elem_type eT;
|
Chris@49
|
22
|
Chris@49
|
23 const Proxy<T1> P(X.m);
|
Chris@49
|
24
|
Chris@49
|
25 const uword n_rows = P.get_n_rows();
|
Chris@49
|
26 const uword n_cols = P.get_n_cols();
|
Chris@49
|
27
|
Chris@49
|
28 const bool P_is_vec = (n_rows == 1) || (n_cols == 1);
|
Chris@49
|
29
|
Chris@49
|
30
|
Chris@49
|
31 if(P.is_alias(out) == false)
|
Chris@49
|
32 {
|
Chris@49
|
33 if(P_is_vec) // generate a diagonal matrix out of a vector
|
Chris@49
|
34 {
|
Chris@49
|
35 const uword N = (n_rows == 1) ? n_cols : n_rows;
|
Chris@49
|
36
|
Chris@49
|
37 out.zeros(N, N);
|
Chris@49
|
38
|
Chris@49
|
39 if(Proxy<T1>::prefer_at_accessor == false)
|
Chris@49
|
40 {
|
Chris@49
|
41 typename Proxy<T1>::ea_type P_ea = P.get_ea();
|
Chris@49
|
42
|
Chris@49
|
43 for(uword i=0; i < N; ++i) { out.at(i,i) = P_ea[i]; }
|
Chris@49
|
44 }
|
Chris@49
|
45 else
|
Chris@49
|
46 {
|
Chris@49
|
47 if(n_rows == 1)
|
Chris@49
|
48 {
|
Chris@49
|
49 for(uword i=0; i < N; ++i) { out.at(i,i) = P.at(0,i); }
|
Chris@49
|
50 }
|
Chris@49
|
51 else
|
Chris@49
|
52 {
|
Chris@49
|
53 for(uword i=0; i < N; ++i) { out.at(i,i) = P.at(i,0); }
|
Chris@49
|
54 }
|
Chris@49
|
55 }
|
Chris@49
|
56 }
|
Chris@49
|
57 else // generate a diagonal matrix out of a matrix
|
Chris@49
|
58 {
|
Chris@49
|
59 arma_debug_check( (n_rows != n_cols), "diagmat(): given matrix is not square" );
|
Chris@49
|
60
|
Chris@49
|
61 out.zeros(n_rows, n_rows);
|
Chris@49
|
62
|
Chris@49
|
63 for(uword i=0; i < n_rows; ++i) { out.at(i,i) = P.at(i,i); }
|
Chris@49
|
64 }
|
Chris@49
|
65 }
|
Chris@49
|
66 else // we have aliasing
|
Chris@49
|
67 {
|
Chris@49
|
68 if(P_is_vec) // generate a diagonal matrix out of a vector
|
Chris@49
|
69 {
|
Chris@49
|
70 const uword N = (n_rows == 1) ? n_cols : n_rows;
|
Chris@49
|
71
|
Chris@49
|
72 podarray<eT> tmp(N);
|
Chris@49
|
73 eT* tmp_mem = tmp.memptr();
|
Chris@49
|
74
|
Chris@49
|
75 if(Proxy<T1>::prefer_at_accessor == false)
|
Chris@49
|
76 {
|
Chris@49
|
77 typename Proxy<T1>::ea_type P_ea = P.get_ea();
|
Chris@49
|
78
|
Chris@49
|
79 for(uword i=0; i < N; ++i) { tmp_mem[i] = P_ea[i]; }
|
Chris@49
|
80 }
|
Chris@49
|
81 else
|
Chris@49
|
82 {
|
Chris@49
|
83 if(n_rows == 1)
|
Chris@49
|
84 {
|
Chris@49
|
85 for(uword i=0; i < N; ++i) { tmp_mem[i] = P.at(0,i); }
|
Chris@49
|
86 }
|
Chris@49
|
87 else
|
Chris@49
|
88 {
|
Chris@49
|
89 for(uword i=0; i < N; ++i) { tmp_mem[i] = P.at(i,0); }
|
Chris@49
|
90 }
|
Chris@49
|
91 }
|
Chris@49
|
92
|
Chris@49
|
93 out.zeros(N, N);
|
Chris@49
|
94
|
Chris@49
|
95 for(uword i=0; i < N; ++i) { out.at(i,i) = tmp_mem[i]; }
|
Chris@49
|
96 }
|
Chris@49
|
97 else // generate a diagonal matrix out of a matrix
|
Chris@49
|
98 {
|
Chris@49
|
99 arma_debug_check( (n_rows != n_cols), "diagmat(): given matrix is not square" );
|
Chris@49
|
100
|
Chris@49
|
101 if( (Proxy<T1>::has_subview == false) && (Proxy<T1>::fake_mat == false) )
|
Chris@49
|
102 {
|
Chris@49
|
103 // NOTE: we have aliasing and it's not due to a subview, hence we're assuming that the output matrix already has the correct size
|
Chris@49
|
104
|
Chris@49
|
105 for(uword i=0; i < n_rows; ++i)
|
Chris@49
|
106 {
|
Chris@49
|
107 const eT val = P.at(i,i);
|
Chris@49
|
108
|
Chris@49
|
109 arrayops::inplace_set(out.colptr(i), eT(0), n_rows);
|
Chris@49
|
110
|
Chris@49
|
111 out.at(i,i) = val;
|
Chris@49
|
112 }
|
Chris@49
|
113 }
|
Chris@49
|
114 else
|
Chris@49
|
115 {
|
Chris@49
|
116 podarray<eT> tmp(n_rows);
|
Chris@49
|
117 eT* tmp_mem = tmp.memptr();
|
Chris@49
|
118
|
Chris@49
|
119 for(uword i=0; i < n_rows; ++i) { tmp_mem[i] = P.at(i,i); }
|
Chris@49
|
120
|
Chris@49
|
121 out.zeros(n_rows, n_rows);
|
Chris@49
|
122
|
Chris@49
|
123 for(uword i=0; i < n_rows; ++i) { out.at(i,i) = tmp_mem[i]; }
|
Chris@49
|
124 }
|
Chris@49
|
125 }
|
Chris@49
|
126 }
|
Chris@49
|
127 }
|
Chris@49
|
128
|
Chris@49
|
129
|
Chris@49
|
130
|
Chris@49
|
131 //! @}
|