function Dictionary = SMALL_rlsdla(X, params)
%SMALL_RLSDLA  Single-pass recursive dictionary update over the columns of X.
%
%   Dictionary = SMALL_rlsdla(X, params) visits the columns of X once, in
%   order. Each column is sparse-coded against the current dictionary and
%   the resulting coefficients drive a rank-one correction of both the
%   dictionary D and an auxiliary square matrix C.
%   NOTE(review): the name and the update equations suggest this is the
%   recursive least squares dictionary learning algorithm (RLS-DLA) --
%   confirm against the toolbox documentation.
%
%   Inputs
%     X       Training signals, one signal per column.
%     params  Struct with the following fields:
%       codemode          (optional) 'sparsity' or 'error', case-insensitive.
%                         When absent, the mode is 'sparsity' if Tdata is
%                         present, otherwise 'error' if Edata is present.
%       Tdata             Threshold handed to ompmex in sparsity mode.
%       Edata             Threshold handed to omp2mex in error mode.
%       maxatoms          (optional) Only read in error mode; forwarded to
%                         omp2mex. Defaults to -1.
%       forgettingMode    (optional) Only 'fix' is implemented.
%       forgettingFactor  (optional) lambda; defaults to 1. When it is not
%                         1, C is multiplied by 1/lambda before each update.
%       initdict          (optional) Either a vector of whole numbers, used
%                         as column indices into X, or a matrix whose first
%                         dictsize columns become the initial dictionary.
%       dictsize          (optional) Number of atoms. Takes precedence over
%                         the size implied by initdict. Required when
%                         initdict is absent.
%
%   Output
%     Dictionary  The dictionary after the last column of X was processed.
%                 Columns are normalised (normcols) once before the loop;
%                 they are not renormalised afterwards.
%
%   Errors are raised for an unknown coding mode, a missing coding target,
%   an unsupported forgetting mode, an undeterminable dictionary size, too
%   few training signals, and an initial dictionary that is inconsistent
%   with X or dictsize.
%
%   Depends on ompmex, omp2mex, normcols, colnorms_squared and iswhole,
%   which are defined outside this file.

    CODE_SPARSITY = 1;
    CODE_ERROR = 2;

    % ---- sparse-coding mode and its threshold -------------------------
    if isfield(params, 'codemode')
        switch lower(params.codemode)
            case 'sparsity'
                if ~isfield(params, 'Tdata')
                    error('Data sparse-coding target not specified');
                end
                codemode = CODE_SPARSITY;
                thresh = params.Tdata;
            case 'error'
                if ~isfield(params, 'Edata')
                    error('Data sparse-coding target not specified');
                end
                codemode = CODE_ERROR;
                thresh = params.Edata;
            otherwise
                error('Invalid coding mode specified');
        end
    elseif isfield(params, 'Tdata')
        codemode = CODE_SPARSITY;
        thresh = params.Tdata;
    elseif isfield(params, 'Edata')
        codemode = CODE_ERROR;
        thresh = params.Edata;
    else
        error('Data sparse-coding target not specified');
    end

    % ---- max number of atoms (only honoured in error mode) ------------
    if (codemode == CODE_ERROR && isfield(params, 'maxatoms'))
        maxatoms = params.maxatoms;
    else
        maxatoms = -1;
    end

    % ---- forgetting factor --------------------------------------------
    if isfield(params, 'forgettingMode')
        switch lower(params.forgettingMode)
            case 'fix'
                if isfield(params, 'forgettingFactor')
                    lambda = params.forgettingFactor;
                else
                    lambda = 1;
                end
            otherwise
                error('This mode is still not implemented');
        end
    elseif isfield(params, 'forgettingFactor')
        lambda = params.forgettingFactor;
    else
        lambda = 1;
    end

    % ---- dictionary size ----------------------------------------------
    hasInitDict = isfield(params, 'initdict');
    % A vector of whole numbers is treated as a list of column indices
    % into X; anything else is treated as an explicit dictionary matrix.
    initIsIndexList = hasInitDict ...
        && any(size(params.initdict) == 1) ...
        && all(iswhole(params.initdict(:)));

    if isfield(params, 'dictsize')
        % An explicit dictsize supersedes the size implied by initdict.
        dictsize = params.dictsize;
    elseif hasInitDict
        if initIsIndexList
            dictsize = length(params.initdict);
        else
            dictsize = size(params.initdict, 2);
        end
    else
        % Without this guard dictsize would be undefined below.
        error('Dictionary size not specified');
    end

    if (size(X, 2) < dictsize)
        error('Number of training signals is smaller than number of atoms to train');
    end

    % ---- initial dictionary -------------------------------------------
    if hasInitDict
        if initIsIndexList
            if (length(params.initdict) < dictsize)
                error('Invalid initial dictionary');
            end
            D = X(:, params.initdict(1:dictsize));
        else
            if (size(params.initdict, 1) ~= size(X, 1) || size(params.initdict, 2) < dictsize)
                error('Invalid initial dictionary');
            end
            D = params.initdict(:, 1:dictsize);
        end
    else
        % Pick dictsize distinct signals at random, skipping (near-)zero ones.
        data_ids = find(colnorms_squared(X) > 1e-6);
        if (length(data_ids) < dictsize)
            error('Number of non-zero training signals is smaller than number of atoms to train');
        end
        perm = randperm(length(data_ids));
        D = X(:, data_ids(perm(1:dictsize)));
    end

    % Normalise the dictionary columns once, before training starts.
    D = normcols(D);

    % Auxiliary matrix for the recursive update, started as a scaled identity.
    C = (100000 * thresh) * eye(dictsize);

    % ---- one pass over the training signals ---------------------------
    for i = 1:size(X, 2)

        x = X(:, i);

        % Sparse-code the current signal with the current dictionary.
        % NOTE(review): the trailing positional arguments are passed through
        % unchanged; their meaning is defined by ompmex / omp2mex.
        if (codemode == CODE_SPARSITY)
            w = ompmex(D, x, [], [], thresh, 1, -1, 0);
        else
            w = omp2mex(D, x, [], [], [], thresh, 0, -1, maxatoms, 0);
        end

        % Indices of the non-zero coefficients.
        spind = find(w);

        % Representation error, computed before D is modified.
        residual = x - D * w;

        % Apply the forgetting factor to C.
        if (lambda ~= 1)
            C = C * (1 / lambda);
        end

        % u = C * w, using only the columns that meet a non-zero coefficient.
        u = C(:, spind) * w(spind);

        alfa = 1 / (1 + w' * u);

        % Rank-one corrections of the dictionary and of C.
        D = D + (alfa * residual) * u';
        C = C - (alfa * u) * u';

    end

    Dictionary = D;

end