Mercurial > hg > smallbox
comparison DL/dl_ramirez.m @ 177:714fa7b8c1ad danieleb
added ramirez dl (to be completed) and MOCOD dictionary update
author | Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk> |
---|---|
date | Thu, 17 Nov 2011 11:18:25 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
176:d0645d5fca7d | 177:714fa7b8c1ad |
---|---|
function DL = dl_ramirez(Problem,DL)
%% Dictionary learning with incoherent dictionary
%
% REFERENCE
% I. Ramirez, F. Lecumberry and G. Sapiro, Sparse modeling with universal
% priors and learned incoherent dictionaries.

%%
% Centre for Digital Music, Queen Mary, University of London.
% This file copyright 2011 Daniele Barchiesi.
%
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License as
% published by the Free Software Foundation; either version 2 of the
% License, or (at your option) any later version. See the file
% COPYING included with this distribution for more information.

%% Usage
% DL = dl_ramirez(Problem,DL) alternates sparse coding (sparsecoding) and
% a coherence-penalised dictionary update (dictionaryupdate) for
% DL.param.iternum iterations and stores the learned dictionary in DL.D.
% Calling dl_ramirez with no arguments runs testdl_ramirez.
%
% Problem.b           matrix of training signals, one signal per column
% DL.param.initdict   (optional) initial dictionary matrix, or a vector of
%                     whole numbers used as column indices into Problem.b
% DL.param.dictsize   (optional) number of atoms; overrides the size implied
%                     by initdict. One of initdict/dictsize must be given.
% DL.param.zeta       coherence penalty factor (default 0.1)
% DL.param.eta        atom-norm penalty factor (default 0.1)
% DL.param.iternum    number of iterations (default 40)
% DL.param.show_dict  (optional) display the dictionary every show_dict
%                     iterations

%% Test function
if ~nargin, testdl_ramirez; return; end

%% Parameters & Defaults
X = Problem.b; %matrix of observed signals, one signal per column

% an empty initdict is treated the same as an absent one, both when sizing
% and when initialising the dictionary
hasInitDict = isfield(DL.param,'initdict') && ~isempty(DL.param.initdict);
initIsIndex = false;
if hasInitDict
    % a vector of whole numbers is interpreted as column indices into X
    initIsIndex = any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:)));
end

% determine dictionary size (explicit dictsize wins over initdict) %
if isfield(DL.param,'dictsize')
    dictSize = DL.param.dictsize;
elseif hasInitDict
    if initIsIndex
        dictSize = length(DL.param.initdict);
    else
        dictSize = size(DL.param.initdict,2);
    end
else
    error('Dictionary size undefined: set DL.param.dictsize or DL.param.initdict');
end

if (size(X,2) < dictSize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary %
if hasInitDict
    if initIsIndex
        D = X(:,DL.param.initdict(1:dictSize));
    else
        if (size(DL.param.initdict,1)~=size(X,1) || size(DL.param.initdict,2)<dictSize)
            error('Invalid initial dictionary');
        end
        D = DL.param.initdict(:,1:dictSize);
    end
else
    data_ids = find(colnorms_squared(X) > 1e-6); % ensure no zero data elements are chosen
    if length(data_ids) < dictSize
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    D = X(:,data_ids(perm(1:dictSize)));
end

% coherence penalty factor
if isfield(DL.param,'zeta')
    zeta = DL.param.zeta;
else
    zeta = 0.1;
end

% atoms norm penalty factor
if isfield(DL.param,'eta')
    eta = DL.param.eta;
else
    eta = 0.1;
end

% number of iterations (default is 40) %
if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% show dictionary every specified number of iterations
if isfield(DL.param,'show_dict')
    show_dictionary=1;
    show_iter=DL.param.show_dict;
else
    show_dictionary=0;
    show_iter=0;
end

%% Main Algorithm
Dprev = D; %initial dictionary
% initial coefficients: mldivide solution of D*A = X (least squares for a
% rectangular D; this is not the minimum-norm pinv solution)
Aprev = D\X;
for i = 1:iternum
    %Sparse Coding
    A = sparsecoding(X,D,Aprev);
    %Dictionary Update
    D = dictionaryupdate(X,A,Dprev,zeta,eta);

    Dprev = D;
    Aprev = A;
    if ((show_dictionary)&&(mod(i,show_iter)==0))
        nTiles = round(sqrt(size(D,2)));
        dictimg = SMALL_showdict(D,[8 8],nTiles,nTiles,'lines','highcontrast');
        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
        pause(0.02);
    end
end

DL.D = D;

end
118 | |
function A = sparsecoding(X,D,Aprev,nIter)
%Sparse coding using a mixture of laplacians (MOL) as universal prior.
%
% A = sparsecoding(X,D,Aprev) estimates the MOL parameters (kappa, beta)
% from the moments of the previous coefficients Aprev, estimates the noise
% variance from the residual X - D*Aprev, and then approximates the
% non-convex MOL-penalised problem by a sequence of weighted L1 problems
% solved by solvel1.
%
% X      matrix of signals, one per column
% D      dictionary
% Aprev  previous coefficient matrix, size(D,2)-by-size(X,2)
% nIter  (optional) number of reweighting iterations (default 10)
% A      coefficient matrix, size(D,2)-by-size(X,2)

if nargin < 4 || isempty(nIter)
    nIter = 10; %number of iterations of surrogate subproblem
end

%parameters
K = size(D,2); %number of atoms
M = size(X,2); %number of signals

% Method-of-moments estimates of the MOL parameters. Under a MOL prior |a|
% has E|a| = beta/(kappa-1) and E[a^2] = 2*beta^2/((kappa-1)*(kappa-2));
% inverting these two relations gives kappa and beta below.
mu1 = mean(abs(Aprev(:)));             %first moment of |Aprev|
mu2 = (norm(Aprev(:))^2)/numel(Aprev); %second moment of Aprev
kappa = 2*(mu2-mu1^2)/(mu2-2*mu1^2);   %parameter kappa of the MOL distribution
beta = (kappa-1)*mu1;                  %parameter beta of the MOL distribution
% NOTE(review): these estimates are only meaningful when mu2 > 2*mu1^2
% (kappa > 2); otherwise kappa/beta may be negative or infinite and the
% reweighting below is ill-posed. No guard is applied here.

E = X-D*Aprev;             %error term
sigmasq = mean(var(E));    %error variance (mean of the per-column variances)
tau = 2*sigmasq*(kappa+1); %sparsity factor, derived for a squared-residual data term

%solve a succession of weighted L1 subproblems to approximate the
%non-convex cost function; the weights 1./(|Psi|+beta) are the derivative
%of log(|a|+beta) evaluated at the previous subproblem solution
Psi = zeros(K,M); %initialise solution of subproblem
for iIter=1:nIter
    Reg = 1./(abs(Psi) + beta);
    Psi = solvel1(X,D,tau,Reg);
end
A = Psi;
end
145 | |
function Psi = solvel1(X,D,tau,W)
%Weighted L1-regularised least squares, solved column by column with CVX.
%
% For each column m, Psi(:,m) minimises
%     ||X(:,m) - D*v||_2^2 + tau * ||W(:,m) .* v||_1
%
% X    matrix of signals, one per column
% D    dictionary
% tau  scalar weight of the L1 term
% W    per-coefficient weights; its size (K-by-M) fixes the size of Psi
% Requires the CVX toolbox (cvx_begin/cvx_end, sum_square).
[K, M] = size(W);
Psi = zeros(K,M);
for m=1:M
    cvx_begin quiet
        variable v(K)
        % squared residual, matching tau = 2*sigma^2*(kappa+1) in sparsecoding
        minimise (sum_square(X(:,m)-D*v) + tau*norm(W(:,m).*v,1));
    cvx_end
    Psi(:,m) = v;
end
end
157 | |
function D = dictionaryupdate(X,A,Dprev,zeta,eta)
%Closed-form dictionary update with a coherence penalty (weight zeta) and an
%atom-norm penalty (weight eta); the penalty terms use the previous
%dictionary Dprev in place of D.
%
% With G = Dprev'*Dprev (Gram matrix of the previous dictionary):
%   D = (X*A' + 2*(zeta+eta)*Dprev) / (A*A' + 2*zeta*G + 2*eta*diag(diag(G)))
gram = Dprev'*Dprev;
numer = X*A' + 2*(zeta + eta)*Dprev;
denom = A*A' + 2*zeta*gram + 2*eta*diag(diag(gram));
D = numer/denom;
end
161 | |
162 | |
163 | |
function Y = colnorms_squared(X)
%Squared Euclidean norm of every column of X, returned as a 1-by-size(X,2)
%row vector.
%
% Columns are processed in blocks of 2000 to limit the size of the
% temporary X.^2. The sum is taken explicitly along dimension 1 so that a
% single-row X yields its squared entries instead of the block total
% (plain sum() would reduce along the columns of a row vector).
nCols = size(X,2);
Y = zeros(1,nCols);
blocksize = 2000;
for i = 1:blocksize:nCols
    blockids = i : min(i+blocksize-1,nCols);
    Y(blockids) = sum(X(:,blockids).^2,1);
end
end
173 | |
function testdl_ramirez
%Smoke test run by dl_ramirez when called with no arguments: learns a
%20-atom dictionary from 30 random 10-dimensional signals through
%SMALL_init_DL / SMALL_learn.
clc
N = 10; %ambient dimension
K = 20; %number of atoms
M = 30; %number of observed signals
X = randn(N,M); %observed signals
D = normcol(randn(N,K)); %initial dictionary (normcol presumably normalises columns)
Problem.b = X; %sparse representation problem
Problem.b1 = X;
DL = SMALL_init_DL('dl_ramirez');
% assigning the whole struct replaces any DL.param set by SMALL_init_DL
DL.param = struct('initdict',D,...
    'zeta',0.5,...
    'eta',0.5);
DL = SMALL_learn(Problem,DL); %#ok<NASGU>
end