comparison DL/RLS-DLA/SMALL_rlsdla.m @ 85:fd1c32cda22c

Comments to small_rlsdla.m, removed unfinished work.
author Ivan <ivan.damnjanovic@eecs.qmul.ac.uk>
date Tue, 05 Apr 2011 17:03:26 +0100
parents a02503d91c8d
children 04cce72a4dc8
comparison
equal deleted inserted replaced
84:67aae1283973 85:fd1c32cda22c
1 function Dictionary = SMALL_rlsdla(X, params) 1 function Dictionary = SMALL_rlsdla(X, params)
2 2 %% Recursive Least Squares Dictionary Learning Algorithm
3 3 %
4 4 % D = SMALL_rlsdla(X, params) - runs RLS-DLA algorithm for
5 5 % training signals specified as columns of matrix X with parameters
6 6 % specified in params structure returning the learned dictionary D.
7 %
8 % Fields in params structure:
9 % Required:
10 % 'Tdata' / 'Edata' sparse-coding target
11 % 'initdict' / 'dictsize' initial dictionary / dictionary size
12 %
13 % Optional (default values in parentheses):
14 % 'codemode' 'sparsity' or 'error' ('sparsity')
15 % 'maxatoms' max # of atoms in error sparse-coding (none)
16 % 'forgettingMode' 'fix' - fix forgetting factor,
17 % other modes are not implemented in
18 % this version(exponential etc.)
19 % 'forgettingFactor' for 'fix' mode (default is 1)
20 % 'show_dict' shows dictionary after # of
21 % iterations specified (less than 100
22 % can make it run slowly). In this
23 % version it assumes that it is an image
24 % dictionary and the atom size is 8x8
25 %
26 % - RLS-DLA - Skretting, K.; Engan, K.; , "Recursive Least Squares
27 % Dictionary Learning Algorithm," Signal Processing, IEEE Transactions on,
28 % vol.58, no.4, pp.2121-2130, April 2010
29 %
30
31
32 % Centre for Digital Music, Queen Mary, University of London.
33 % This file copyright 2011 Ivan Damnjanovic.
34 %
35 % This program is free software; you can redistribute it and/or
36 % modify it under the terms of the GNU General Public License as
37 % published by the Free Software Foundation; either version 2 of the
38 % License, or (at your option) any later version. See the file
39 % COPYING included with this distribution for more information.
40 %
41 %%
7 42
8 CODE_SPARSITY = 1; 43 CODE_SPARSITY = 1;
9 CODE_ERROR = 2; 44 CODE_ERROR = 2;
10 45
11 46
43 else 78 else
44 maxatoms = -1; 79 maxatoms = -1;
45 end 80 end
46 81
47 82
83
84
85 % determine dictionary size %
86
87 if (isfield(params,'initdict'))
88 if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
89 dictsize = length(params.initdict);
90 else
91 dictsize = size(params.initdict,2);
92 end
93 end
94 if (isfield(params,'dictsize')) % this supersedes the size determined by initdict
95 dictsize = params.dictsize;
96 end
97
98 if (size(X,2) < dictsize)
99 error('Number of training signals is smaller than number of atoms to train');
100 end
101
102
103 % initialize the dictionary %
104
105 if (isfield(params,'initdict'))
106 if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
107 D = X(:,params.initdict(1:dictsize));
108 else
109 if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
110 error('Invalid initial dictionary');
111 end
112 D = params.initdict(:,1:dictsize);
113 end
114 else
115 data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen
116 perm = randperm(length(data_ids));
117 D = X(:,data_ids(perm(1:dictsize)));
118 end
119
120
121 % normalize the dictionary %
122
123 D = normcols(D);
124
125 % show dictionary every specified number of iterations
126
127 if (isfield(params,'show_dict'))
128 show_dictionary=1;
129 show_iter=params.show_dict;
130 else
131 show_dictionary=0;
132 show_iter=0;
133 end
134
135 if (show_dictionary)
136 dictimg = showdict(D,[8 8],round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast');
137 figure(2); imshow(imresize(dictimg,2,'nearest'));
138 end
48 % Forgetting factor 139 % Forgetting factor
49 140
50 if (isfield(params,'forgettingMode')) 141 if (isfield(params,'forgettingMode'))
51 switch lower(params.forgettingMode) 142 switch lower(params.forgettingMode)
52 case 'fix' 143 case 'fix'
62 lambda=params.forgettingFactor; 153 lambda=params.forgettingFactor;
63 else 154 else
64 lambda=1; 155 lambda=1;
65 end 156 end
66 157
67 % determine dictionary size %
68
69 if (isfield(params,'initdict'))
70 if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
71 dictsize = length(params.initdict);
72 else
73 dictsize = size(params.initdict,2);
74 end
75 end
76 if (isfield(params,'dictsize')) % this supersedes the size determined by initdict
77 dictsize = params.dictsize;
78 end
79
80 if (size(X,2) < dictsize)
81 error('Number of training signals is smaller than number of atoms to train');
82 end
83
84
85 % initialize the dictionary %
86
87 if (isfield(params,'initdict'))
88 if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
89 D = X(:,params.initdict(1:dictsize));
90 else
91 if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
92 error('Invalid initial dictionary');
93 end
94 D = params.initdict(:,1:dictsize);
95 end
96 else
97 data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen
98 perm = randperm(length(data_ids));
99 D = X(:,data_ids(perm(1:dictsize)));
100 end
101
102
103 % normalize the dictionary %
104
105 D = normcols(D);
106
107 % Training data 158 % Training data
108 159
109 data=X; 160 data=X;
110 cnt=size(data,2); 161 cnt=size(data,2);
162
111 % 163 %
112 164
113 C=(100000*thresh)*eye(dictsize); 165 C=(100000*thresh)*eye(dictsize);
114 w=zeros(dictsize,1); 166 w=zeros(dictsize,1);
115 u=zeros(dictsize,1); 167 u=zeros(dictsize,1);
116 168
117 169
118 for i = 1:cnt 170 for i = 1:cnt
119 171
120 if (codemode == CODE_SPARSITY) 172 if (codemode == CODE_SPARSITY)
121 w = ompmex(D,data(:,i),[],thresh,'checkdict','off'); 173 w = omp2(D,data(:,i),[],thresh,'checkdict','off');
122 else 174 else
123 w = omp2(D,data(:,i),[],thresh,'maxatoms',maxatoms, 'checkdict','off'); 175 w = omp2(D,data(:,i),[],thresh,'maxatoms',maxatoms, 'checkdict','off');
124 end 176 end
125 177
126 spind=find(w); 178 spind=find(w);
138 190
139 D = D + (alfa * residual) * u'; 191 D = D + (alfa * residual) * u';
140 192
141 193
142 C = C - (alfa * u)* u'; 194 C = C - (alfa * u)* u';
143 195 if (show_dictionary &&(mod(i,show_iter)==0))
196 dictimg = showdict(D,[8 8],...
197 round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast');
198 figure(2); imshow(imresize(dictimg,2,'nearest'));
199 pause(0.02);
200 end
144 end 201 end
145 202
146 Dictionary = D; 203 Dictionary = D;
147 204
148 end 205 end