# HG changeset patch # User Daniele Barchiesi # Date 1331736172 0 # Node ID d50f5bdbe14c41825b64b07152652341584d72fe # Parent 759313488e7b83ccb7eda53ae3955096df98fe5b - Added SMALL_DL_test: simple DL showcase - Added dico_decorr_symmetric: improved version of INK-SVD decorrelation step - Debugged SMALL_learn, SMALLBoxInit and SMALL_two_step_DL diff -r 759313488e7b -r d50f5bdbe14c DL/two-step DL/SMALL_two_step_DL.m --- a/DL/two-step DL/SMALL_two_step_DL.m Tue Mar 13 17:33:20 2012 +0000 +++ b/DL/two-step DL/SMALL_two_step_DL.m Wed Mar 14 14:42:52 2012 +0000 @@ -98,6 +98,7 @@ % want sparse representation of training set, and in Problem.b1 in this % version of software we store the signal that needs to be represented % (for example the whole image) +global SMALL_path tmpTraining = Problem.b1; Problem.b1 = sig; @@ -144,4 +145,4 @@ Y(blockids) = sum(X(:,blockids).^2); end -end \ No newline at end of file +end diff -r 759313488e7b -r d50f5bdbe14c DL/two-step DL/dico_decorr_symetric.m --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/DL/two-step DL/dico_decorr_symetric.m Wed Mar 14 14:42:52 2012 +0000 @@ -0,0 +1,61 @@ +function dico = dico_decorr_symetric(dico, mu, amp) + %DICO_DECORR decorrelate a dictionary + % Parameters: + % dico: the dictionary + % mu: the coherence threshold + % amp: the amplitude coefficients, only used to decide which atom to + % project + % + % Result: + % dico: a dictionary close to the input one with coherence mu. + + eps = 1e-6; % define tolerance for normalisation term alpha + + % convert mu to the to the mean direction + theta = acos(mu)/2; + ctheta = cos(theta); + stheta = sin(theta); + + % compute atom weights + % if nargin > 2 + % rank = sum(amp.*amp, 2); + % else + % rank = randperm(length(dico)); + % end + + % several decorrelation iterations might be needed to reach global + % coherence mu. niter can be adjusted to needs. + niter = 1; + while max(max(abs(dico'*dico -eye(length(dico))))) > mu + 0.01 + % find pairs of high correlation atoms + colors = dico_color(dico, mu); + + % iterate on all pairs + nbColors = max(colors); + for c = 1:nbColors + index = find(colors==c); + if numel(index) == 2 + if dico(:,index(1))'*dico(:,index(2)) > 0 + %build the basis vectors + v1 = dico(:,index(1))+dico(:,index(2)); + v1 = v1/norm(v1); + v2 = dico(:,index(1))-dico(:,index(2)); + v2 = v2/norm(v2); + + dico(:,index(1)) = ctheta*v1+stheta*v2; + dico(:,index(2)) = ctheta*v1-stheta*v2; + else + v1 = dico(:,index(1))-dico(:,index(2)); + v1 = v1/norm(v1); + v2 = dico(:,index(1))+dico(:,index(2)); + v2 = v2/norm(v2); + + dico(:,index(1)) = ctheta*v1+stheta*v2; + dico(:,index(2)) = -ctheta*v1+stheta*v2; + end + end + end + niter = niter+1; + end +end + diff -r 759313488e7b -r d50f5bdbe14c SMALLboxInit.m --- a/SMALLboxInit.m Tue Mar 13 17:33:20 2012 +0000 +++ b/SMALLboxInit.m Wed Mar 14 14:42:52 2012 +0000 @@ -1,5 +1,5 @@ global SMALL_path; -SMALL_path=pwd; +SMALL_path=[fileparts(mfilename('fullpath')) filesep]; % SMALL_p = genpath(SMALL_path); % addpath(SMALL_p); diff -r 759313488e7b -r d50f5bdbe14c examples/SMALL_DL_test.m --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/examples/SMALL_DL_test.m Wed Mar 14 14:42:52 2012 +0000 @@ -0,0 +1,76 @@ +function SMALL_DL_test +clear, clc, close all +% Create a 2-dimensional dataset of points that are oriented in 3 +% directions on a x/y plane + +% +% Centre for Digital Music, Queen Mary, University of London. +% This file copyright 2012 Daniele Barchiesi. +% +% This program is free software; you can redistribute it and/or +% modify it under the terms of the GNU General Public License as +% published by the Free Software Foundation; either version 2 of the +% License, or (at your option) any later version. See the file +% COPYING included with this distribution for more information. + +nData = 10000; %number of data +theta = [pi/6 pi/3 4*pi/6]; %angles +nAngles = length(theta); %number of angles +Q = [cos(theta); sin(theta)]; %rotation matrix +X = Q*randmog(nAngles,nData); %training data + +% find principal directions using PCA +XXt = X*X'; %cross correlation matrix +[U ~] = svd(XXt); %svd of XXt + +scale = 3; %scale factor for plots +subplot(1,2,1), hold on +title('Principal Component Analysis') +scatter(X(1,:), X(2,:),'.'); %scatter training data +O = zeros(size(U)); %origin +quiver(O(1,1:2),O(2,1:2),scale*U(1,:),scale*U(2,:),... + 'LineWidth',2,'Color','k') %plot atoms +axis equal %scale axis + +subplot(1,2,2), hold on +title('K-SVD Dictionary') +scatter(X(1,:), X(2,:),'.'); +axis equal + +nAtoms = 3; %number of atoms in the dictionary +nIter = 1; %number of dictionary learning iterations +initDict = normc(randn(2,nAtoms)); %random initial dictionary +O = zeros(size(initDict)); %origin + +% apply dictionary learning algorithm +ksvd_params = struct('data',X,... %training data + 'Tdata',1,... %sparsity level + 'dictsize',nAtoms,... %number of atoms + 'initdict',initDict,...%initial dictionary + 'iternum',10); %number of iterations +DL = SMALL_init_DL('ksvd','ksvd',ksvd_params); %dictionary learning structure +DL.D = initDict; %copy initial dictionary in solution variable +problem = struct('b',X); %copy training data in problem structure + +xdata = DL.D(1,:); +ydata = DL.D(2,:); +qPlot = quiver(O(1,:),O(2,:),scale*initDict(1,:),scale*initDict(2,:),... + 'LineWidth',2,'Color','k','UDataSource','xdata','VDataSource','ydata'); + +for iIter=1:nIter + DL.ksvd_params.initdict = DL.D; + pause + DL = SMALL_learn(problem,DL); %learn dictionary + xdata = scale*DL.D(1,:); + ydata = scale*DL.D(2,:); + refreshdata(gcf,'caller'); +end + + +function X = randmog(m, n) +% RANDMOG - Generate mixture of Gaussians +s = [0.2 2]; +% Choose which Gaussian +G1 = (rand(m, n) < 0.9); +% Make them +X = (G1.*s(1) + (1-G1).*s(2)) .* randn(m,n); diff -r 759313488e7b -r d50f5bdbe14c util/SMALL_learn.m --- a/util/SMALL_learn.m Tue Mar 13 17:33:20 2012 +0000 +++ b/util/SMALL_learn.m Wed Mar 14 14:42:52 2012 +0000 @@ -18,6 +18,7 @@ % License, or (at your option) any later version. See the file % COPYING included with this distribution for more information. %% +global SMALL_path if (DL.profile) fprintf('\nStarting Dictionary Learning %s... \n', DL.name); end @@ -41,4 +42,4 @@ DL.D = full(D); end - \ No newline at end of file +