Mercurial > hg > smallbox
comparison DL/RLS-DLA/SMALL_rlsdla.m @ 85:fd1c32cda22c
Comments to small_rlsdla.m, removed unfinished work.
author | Ivan <ivan.damnjanovic@eecs.qmul.ac.uk> |
---|---|
date | Tue, 05 Apr 2011 17:03:26 +0100 |
parents | a02503d91c8d |
children | 04cce72a4dc8 |
comparison
equal
deleted
inserted
replaced
84:67aae1283973 | 85:fd1c32cda22c |
---|---|
1 function Dictionary = SMALL_rlsdla(X, params) | 1 function Dictionary = SMALL_rlsdla(X, params) |
2 | 2 %% Recursive Least Squares Dictionary Learning Algorithm |
3 | 3 % |
4 | 4 % D = SMALL_rlsdla(X, params) - runs RLS-DLA algorithm for |
5 | 5 % training signals specified as columns of matrix X with parameters |
6 | 6 % specified in params structure returning the learned dictionary D. |
7 % | |
8 % Fields in params structure: | |
9 % Required: | |
10 % 'Tdata' / 'Edata' sparse-coding target | |
11 % 'initdict' / 'dictsize' initial dictionary / dictionary size | |
12 % | |
13 % Optional (default values in parentheses): | |
14 % 'codemode' 'sparsity' or 'error' ('sparsity') | |
15 % 'maxatoms' max # of atoms in error sparse-coding (none) | |
16 % 'forgettingMode' 'fix' - fix forgetting factor, | |
17 % other modes are not implemented in | |
18 % this version(exponential etc.) | |
19 % 'forgettingFactor' for 'fix' mode (default is 1) | |
20 % 'show_dict' shows dictionary after # of | |
21 % iterations specified (less then 100 | |
22 % can make it running slow). In this | |
23 % version it assumes that it is image | |
24 % dictionary and atoms size is 8x8 | |
25 % | |
26 % - RLS-DLA - Skretting, K.; Engan, K.; , "Recursive Least Squares | |
27 % Dictionary Learning Algorithm," Signal Processing, IEEE Transactions on, | |
28 % vol.58, no.4, pp.2121-2130, April 2010 | |
29 % | |
30 | |
31 | |
32 % Centre for Digital Music, Queen Mary, University of London. | |
33 % This file copyright 2011 Ivan Damnjanovic. | |
34 % | |
35 % This program is free software; you can redistribute it and/or | |
36 % modify it under the terms of the GNU General Public License as | |
37 % published by the Free Software Foundation; either version 2 of the | |
38 % License, or (at your option) any later version. See the file | |
39 % COPYING included with this distribution for more information. | |
40 % | |
41 %% | |
7 | 42 |
8 CODE_SPARSITY = 1; | 43 CODE_SPARSITY = 1; |
9 CODE_ERROR = 2; | 44 CODE_ERROR = 2; |
10 | 45 |
11 | 46 |
43 else | 78 else |
44 maxatoms = -1; | 79 maxatoms = -1; |
45 end | 80 end |
46 | 81 |
47 | 82 |
83 | |
84 | |
85 % determine dictionary size % | |
86 | |
87 if (isfield(params,'initdict')) | |
88 if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) | |
89 dictsize = length(params.initdict); | |
90 else | |
91 dictsize = size(params.initdict,2); | |
92 end | |
93 end | |
94 if (isfield(params,'dictsize')) % this superceedes the size determined by initdict | |
95 dictsize = params.dictsize; | |
96 end | |
97 | |
98 if (size(X,2) < dictsize) | |
99 error('Number of training signals is smaller than number of atoms to train'); | |
100 end | |
101 | |
102 | |
103 % initialize the dictionary % | |
104 | |
105 if (isfield(params,'initdict')) | |
106 if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) | |
107 D = X(:,params.initdict(1:dictsize)); | |
108 else | |
109 if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize) | |
110 error('Invalid initial dictionary'); | |
111 end | |
112 D = params.initdict(:,1:dictsize); | |
113 end | |
114 else | |
115 data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen | |
116 perm = randperm(length(data_ids)); | |
117 D = X(:,data_ids(perm(1:dictsize))); | |
118 end | |
119 | |
120 | |
121 % normalize the dictionary % | |
122 | |
123 D = normcols(D); | |
124 | |
125 % show dictonary every specified number of iterations | |
126 | |
127 if (isfield(params,'show_dict')) | |
128 show_dictionary=1; | |
129 show_iter=params.show_dict; | |
130 else | |
131 show_dictionary=0; | |
132 show_iter=0; | |
133 end | |
134 | |
135 if (show_dictionary) | |
136 dictimg = showdict(D,[8 8],round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); | |
137 figure(2); imshow(imresize(dictimg,2,'nearest')); | |
138 end | |
48 % Forgetting factor | 139 % Forgetting factor |
49 | 140 |
50 if (isfield(params,'forgettingMode')) | 141 if (isfield(params,'forgettingMode')) |
51 switch lower(params.forgettingMode) | 142 switch lower(params.forgettingMode) |
52 case 'fix' | 143 case 'fix' |
62 lambda=params.forgettingFactor; | 153 lambda=params.forgettingFactor; |
63 else | 154 else |
64 lambda=1; | 155 lambda=1; |
65 end | 156 end |
66 | 157 |
67 % determine dictionary size % | |
68 | |
69 if (isfield(params,'initdict')) | |
70 if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) | |
71 dictsize = length(params.initdict); | |
72 else | |
73 dictsize = size(params.initdict,2); | |
74 end | |
75 end | |
76 if (isfield(params,'dictsize')) % this superceedes the size determined by initdict | |
77 dictsize = params.dictsize; | |
78 end | |
79 | |
80 if (size(X,2) < dictsize) | |
81 error('Number of training signals is smaller than number of atoms to train'); | |
82 end | |
83 | |
84 | |
85 % initialize the dictionary % | |
86 | |
87 if (isfield(params,'initdict')) | |
88 if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:)))) | |
89 D = X(:,params.initdict(1:dictsize)); | |
90 else | |
91 if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize) | |
92 error('Invalid initial dictionary'); | |
93 end | |
94 D = params.initdict(:,1:dictsize); | |
95 end | |
96 else | |
97 data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen | |
98 perm = randperm(length(data_ids)); | |
99 D = X(:,data_ids(perm(1:dictsize))); | |
100 end | |
101 | |
102 | |
103 % normalize the dictionary % | |
104 | |
105 D = normcols(D); | |
106 | |
107 % Training data | 158 % Training data |
108 | 159 |
109 data=X; | 160 data=X; |
110 cnt=size(data,2); | 161 cnt=size(data,2); |
162 | |
111 % | 163 % |
112 | 164 |
113 C=(100000*thresh)*eye(dictsize); | 165 C=(100000*thresh)*eye(dictsize); |
114 w=zeros(dictsize,1); | 166 w=zeros(dictsize,1); |
115 u=zeros(dictsize,1); | 167 u=zeros(dictsize,1); |
116 | 168 |
117 | 169 |
118 for i = 1:cnt | 170 for i = 1:cnt |
119 | 171 |
120 if (codemode == CODE_SPARSITY) | 172 if (codemode == CODE_SPARSITY) |
121 w = ompmex(D,data(:,i),[],thresh,'checkdict','off'); | 173 w = omp2(D,data(:,i),[],thresh,'checkdict','off'); |
122 else | 174 else |
123 w = omp2(D,data(:,i),[],thresh,'maxatoms',maxatoms, 'checkdict','off'); | 175 w = omp2(D,data(:,i),[],thresh,'maxatoms',maxatoms, 'checkdict','off'); |
124 end | 176 end |
125 | 177 |
126 spind=find(w); | 178 spind=find(w); |
138 | 190 |
139 D = D + (alfa * residual) * u'; | 191 D = D + (alfa * residual) * u'; |
140 | 192 |
141 | 193 |
142 C = C - (alfa * u)* u'; | 194 C = C - (alfa * u)* u'; |
143 | 195 if (show_dictionary &&(mod(i,show_iter)==0)) |
196 dictimg = showdict(D,[8 8],... | |
197 round(sqrt(size(D,2))),round(sqrt(size(D,2))),'lines','highcontrast'); | |
198 figure(2); imshow(imresize(dictimg,2,'nearest')); | |
199 pause(0.02); | |
200 end | |
144 end | 201 end |
145 | 202 |
146 Dictionary = D; | 203 Dictionary = D; |
147 | 204 |
148 end | 205 end |