daniele@177
|
function DL = dl_ramirez(Problem,DL)
%% Dictionary learning with incoherent dictionary
%
% DL = dl_ramirez(Problem,DL) alternates a sparse coding step and a
% penalised dictionary update for a fixed number of iterations and stores
% the resulting dictionary in DL.D. Called without arguments it runs the
% built-in test (testdl_ramirez) and returns.
%
% INPUT
%   Problem.b           matrix of observed signals, one signal per column
%   DL.param.initdict   (optional) initial dictionary, or a vector of whole
%                       numbers selecting columns of Problem.b to be used
%                       as initial atoms
%   DL.param.dictsize   (optional) number of atoms; overrides the size
%                       implied by initdict
%   DL.param.zeta       (optional) coherence penalty factor, default 0.1
%   DL.param.eta        (optional) atoms norm penalty factor, default 0.1
%   DL.param.iternum    (optional) number of iterations, default 40
%   DL.param.show_dict  (optional) show the dictionary every show_dict
%                       iterations
%
% OUTPUT
%   DL.D                learned dictionary
%
% REFERENCE
% I. Ramirez, F. Lecumberry and G. Sapiro, Sparse modeling with universal
% priors and learned incoherent dictionaries.

%%
% Centre for Digital Music, Queen Mary, University of London.
% This file copyright 2011 Daniele Barchiesi.
%
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License as
% published by the Free Software Foundation; either version 2 of the
% License, or (at your option) any later version.  See the file
% COPYING included with this distribution for more information.

%% Test function
if ~nargin, testdl_ramirez; return; end

%% Parameters & Defaults
X = Problem.b;  %matrix of observed signals

% initdict may hold either atoms or column indices into X; it is treated
% as indices when it is a vector made of whole numbers only
hasInitDict = isfield(DL.param,'initdict');
initIsIndex = hasInitDict && any(size(DL.param.initdict)==1) ...
    && all(iswhole(DL.param.initdict(:)));

% determine dictionary size %
dictSize = [];
if hasInitDict  %if the dictionary has been initialised
    if initIsIndex
        dictSize = length(DL.param.initdict);
    else
        dictSize = size(DL.param.initdict,2);
    end
end
if isfield(DL.param,'dictsize')  %an explicit size takes precedence
    dictSize = DL.param.dictsize;
end
if isempty(dictSize)
    error('Dictionary size not specified: set DL.param.dictsize or DL.param.initdict');
end

if (size(X,2) < dictSize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary %
if hasInitDict && ~isempty(DL.param.initdict)
    if initIsIndex
        if (numel(DL.param.initdict) < dictSize)
            error('Invalid initial dictionary');
        end
        D = X(:,DL.param.initdict(1:dictSize));
    else
        if (size(DL.param.initdict,1)~=size(X,1) || size(DL.param.initdict,2)<dictSize)
            error('Invalid initial dictionary');
        end
        D = DL.param.initdict(:,1:dictSize);
    end
else
    % draw the atoms at random among the signals
    data_ids = find(colnorms_squared(X) > 1e-6);  % ensure no zero data elements are chosen
    if (numel(data_ids) < dictSize)
        error('Not enough non-zero training signals to initialise the dictionary');
    end
    perm = randperm(length(data_ids));
    D = X(:,data_ids(perm(1:dictSize)));
end

% coherence penalty factor
if isfield(DL.param,'zeta')
    zeta = DL.param.zeta;
else
    zeta = 0.1;
end

% atoms norm penalty factor
if isfield(DL.param,'eta')
    eta = DL.param.eta;
else
    eta = 0.1;
end

% number of iterations (default is 40) %
if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% show dictionary every specified number of iterations
if isfield(DL.param,'show_dict')
    show_dictionary = 1;
    show_iter = DL.param.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

%% Main Algorithm
A = D\X;  %initial coefficients: solve D*A = X with the backslash operator
for i = 1:iternum
    %Sparse Coding, warm-started from the previous coefficients
    A = sparsecoding(X,D,A);
    %Dictionary Update, starting from the previous dictionary
    D = dictionaryupdate(X,A,D,zeta,eta);

    if ((show_dictionary)&&(mod(i,show_iter)==0))
        % display the current dictionary on a roughly square grid
        gridSide = round(sqrt(size(D,2)));
        dictimg = SMALL_showdict(D,[8 8],gridSide,gridSide,'lines','highcontrast');
        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
        pause(0.02);
    end
end

DL.D = D;

end
|
daniele@177
|
118
|
daniele@177
|
function A = sparsecoding(X,D,Aprev)
%Sparse coding using a mixture of laplacians (MOL) as universal prior.
%
% A = sparsecoding(X,D,Aprev) estimates the MOL parameters kappa and beta
% from the first and second moments of the previous coefficients Aprev and
% the error variance from the residual X-D*Aprev. It then solves a
% succession of weighted l1 problems (see solvel1) whose weights are
% recomputed from the current solution at every pass.
%
%   X      observed signals, one per column
%   D      dictionary, one atom per column
%   Aprev  coefficients from the previous iteration, size(D,2)-by-size(X,2)
%   A      new coefficients, same size as Aprev

%parameters
K = size(D,2);  %number of atoms
M = size(X,2);  %number of signals

mu1 = mean(abs(Aprev(:)));              %first moment of distribution of Aprev
mu2 = (norm(Aprev(:))^2)/numel(Aprev);  %second moment of distribution of Aprev
% Method of moments for MOL(kappa,beta): E|a| = beta/(kappa-1) and
% E[a^2] = 2*beta^2/((kappa-1)*(kappa-2)). Eliminating beta gives the
% expression below, whose denominator involves mu1^2 (not mu2^2).
kappa = 2*(mu2-mu1^2)/(mu2-2*mu1^2);    %parameter kappa of the MOL distribution
beta = (kappa-1)*mu1;                   %parameter beta of the MOL distribution

E = X-D*Aprev;              %error term
sigmasq = mean(var(E));     %error variance
tau = 2*sigmasq*(kappa+1);  %sparsity factor

%solve a succession of subproblems to approximate the non-convex cost
%function
nIter = 10;        %number of iterations of surrogate subproblem
Psi = zeros(K,M);  %initialise solution of subproblem
for iIter=1:nIter
    Reg = 1./(abs(Psi) + beta);  %weights: small coefficients are penalised more
    Psi = solvel1(X,D,tau,Reg);
end
A = Psi;
end
|
daniele@177
|
145
|
daniele@177
|
function Psi = solvel1(X,D,tau,A)
%Solve one weighted l1 problem per signal with CVX.
%
%   X    signals, one per column
%   D    dictionary, one atom per column
%   tau  scalar multiplying the weighted l1 term
%   A    weights, one column per signal and one row per atom
%   Psi  minimisers, same size as A
%
% NOTE(review): the data term is the unsquared l2 norm of the residual;
% confirm that this is intended rather than a squared error.
[nAtoms, nSignals] = size(A);
Psi = zeros(nAtoms,nSignals);
for j = 1:nSignals
    weights = A(:,j);   %per-atom weights for the current signal
    signal = X(:,j);
    cvx_begin quiet
        variable coef(nAtoms)
        minimise (norm(signal-D*coef) + tau*norm(weights.*coef,1));
    cvx_end
    Psi(:,j) = coef;
end
end
|
daniele@177
|
157
|
daniele@177
|
function D = dictionaryupdate(X,A,Dprev,zeta,eta)
%Closed-form dictionary update with coherence (zeta) and atom-norm (eta)
%penalty factors, evaluated around the previous dictionary Dprev.
%
%   X      signals, one per column
%   A      coefficients, one column per signal
%   Dprev  previous dictionary
%   D      updated dictionary, same size as Dprev
gramPrev = Dprev'*Dprev;  %Gram matrix of the previous atoms
lhs = X*A' + 2*(zeta + eta)*Dprev;
rhs = A*A' + 2*zeta*gramPrev + 2*eta*diag(diag(gramPrev));
D = lhs/rhs;              %right division: solves D*rhs = lhs
end
|
daniele@177
|
161
|
daniele@177
|
162
|
daniele@177
|
163
|
daniele@177
|
function Y = colnorms_squared(X)
%Squared l2 norm of every column of X, returned as a 1-by-size(X,2) row.

% compute in blocks to conserve memory
Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    % sum along dimension 1 explicitly: without it a single-row X would be
    % summed along its columns and collapse to one scalar
    Y(blockids) = sum(X(:,blockids).^2,1);
end
end
|
daniele@177
|
173
|
daniele@177
|
function testdl_ramirez
%Smoke test: learn a dictionary from random data through SMALL_learn.
clc
N = 10;  %ambient dimension
K = 20;  %number of atoms
M = 30;  %number of observed signals
X = randn(N,M);           %observed signals
D = normcol(randn(N,K));  %initial dictionary
Problem.b = X;            %sparse representation problem
Problem.b1 = X;
DL = SMALL_init_DL('dl_ramirez');
% the whole parameter structure is set in one go
DL.param = struct('initdict',D,...
    'zeta',0.5,...
    'eta',0.5);
DL = SMALL_learn(Problem,DL);
end
|