comparison toolboxes/MIRtoolbox1.3.2/MIRToolbox/@mirclassify/mirclassify.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function c = mirclassify(a,da,t,dt,varargin)
%   c = mirclassify(test,features_test,train,features_train) classifies the
%       audio sequence(s) contained in the audio object test, along the
%       analytic feature(s) features_test, following the supervised
%       learning of a training set defined by the audio object train and
%       the corresponding analytic feature(s) features_train.
%           * The analytic feature(s) features_test should *not* be frame
%               decomposed. Frame-decomposed data should first be
%               summarized, using for instance mirmean or mirstd.
%           * Multiple analytic features have to be grouped into one array
%               of cells.
%       You can also integrate your own arrays of numbers computed outside
%           MIRtoolbox as part of the features. These arrays should be
%           given as matrices where each successive column is the analysis
%           of each successive file.
%   Example:
%       mirclassify(test, mfcc(test), train, mfcc(train))
%       mirclassify(test, {mfcc(test), centroid(test)}, ...
%                   train, {mfcc(train), centroid(train)})
%   Optional argument:
%       mirclassify(...,'Nearest') uses the minimum distance strategy.
%           (by default)
%       mirclassify(...,'Nearest',k) uses the k-nearest-neighbour strategy.
%           Default value: k = 1, corresponding to the minimum distance
%               strategy.
%           If k exceeds the number of training items, all training items
%               are used as neighbours.
%       mirclassify(...,'GMM',ng) uses a gaussian mixture model. Each class is
%           modeled by at most ng gaussians.
%           Default value: ng = 1.
%           Additionally, the type of mixture model can be specified,
%           using the set of values proposed in the gmm function: i.e.,
%           'spherical','diag','full' (default value) and 'ppca'.
%               (cf. help gmm)
%           Requires the Netlab toolbox.
%
%   Fields of the returned object:
%       labtraining, labtest : labels of the training and test sets
%       training, test       : feature matrices (one column per item)
%       nbobs                : number of training items
%       nbparam              : number of model parameters
%       classes              : predicted label for each test item
%       correct              : ratio of correctly classified test items,
%                              or NaN if the test set carries no labels

lab = get(t,'Label');
c.labtraining = lab;
rlab = get(a,'Label');
c.labtest = rlab;
[k,ncentres,covartype,kmiter,emiter,d,norml,mahl] = scanargin(varargin);
disp('Classifying...')

% --- Training feature matrix: one column per training item ---
if not(iscell(dt))
    dt = {dt};
end
lvt = length(get(t,'Data'));
vt = [];
for i = 1:length(dt)
    if isnumeric(dt{i})
        % User-supplied matrix: split it into one cell per column (file).
        d = cell(1,size(dt{i},2));
        for j = 1:size(dt{i},2)
            d{j} = dt{i}(:,j);
        end
    else
        d = get(dt{i},'Data');
    end
    vt = integrate(vt,d,lvt,norml);
    % NOTE(review): tests for class 'scalar' -- confirm this is the
    % intended class name for scalar features.
    if isa(dt{i},'scalar')
        m = mode(dt{i});
        if not(isempty(m))
            vt = integrate(vt,m,lvt,norml);
        end
    end
end
c.training = vt;
dim = size(vt,1);

% --- Test feature matrix: one column per test item ---
if not(iscell(da))
    da = {da};
end
lva = length(get(a,'Data'));
va = [];
for i = 1:length(da)
    if isnumeric(da{i})
        d = cell(1,size(da{i},2));
        for j = 1:size(da{i},2)
            d{j} = da{i}(:,j);
        end
    else
        d = get(da{i},'Data');
    end
    va = integrate(va,d,lva,norml);
    if isa(da{i},'scalar')
        m = mode(da{i});
        if not(isempty(m))
            va = integrate(va,m,lva,norml);
        end
    end
end
c.test = va;
c.nbobs = lvt;

% Covariance of all (training + test) observations, used by the distance
% computation. This overrides the value returned by scanargin.
totva = [vt va];
mahl = cov(totva');

if k        % k-Nearest Neighbour
    c.nbparam = lvt;
    % Never ask for more neighbours than there are training items;
    % idx(1:k) would otherwise fail with an out-of-range index.
    kk = min(k,lvt);
    for l = 1:lva
        [sv,idx] = sort(distance(va(:,l),vt,d,mahl));
        labs = cell(0); % Class labels
        founds = [];    % Number of found elements in each class
        for i = idx(1:kk)
            labi = lab{i};
            found = 0;
            for j = 1:length(labs)
                if isequal(labi,labs{j})
                    found = j;
                end
            end
            if found
                founds(found) = founds(found)+1;
            else
                labs{end+1} = labi;
                founds(end+1) = 1;
            end
        end
        % Majority vote; on ties max returns the first label encountered.
        [b ib] = max(founds);
        c.classes{l} = labs{ib};
    end
elseif ncentres     % Gaussian Mixture Model
    labs = cell(0);   % Class labels
    founds = cell(0); % Elements associated to each label.
    for i = 1:lvt
        labi = lab{i};
        found = 0;
        for j = 1:length(labs)
            if isequal(labi,labs{j})
                founds{j}(end+1) = i;
                found = 1;
            end
        end
        if not(found)
            labs{end+1} = labi;
            founds{end+1} = i;
        end
    end
    options      = zeros(1, 18);
    options(2:3) = 1e-4;
    options(4)   = 1e-6;
    options(16)  = 1e-8;
    options(17)  = 0.1;
    options(1)   = 0; %Prints out error values, -1 else
    OK = 0;
    while not(OK)
        OK = 1;
        % Reset the parameter count on every attempt: all class models
        % are retrained after a failure, so the count must not
        % accumulate across retries.
        c.nbparam = 0;
        for i = 1:length(labs)
            options(14) = kmiter;
            try
                mix{i} = gmm(dim,ncentres,covartype);
            catch
                error('ERROR IN CLASSIFY: Netlab toolbox not installed.');
            end
            % NOTE(review): options(5) is still 0 for the very first
            % initialisation and 1 for all later ones -- confirm whether
            % this asymmetry is intended.
            mix{i} = netlabgmminit(mix{i},vt(:,founds{i})',options);
            options(5) = 1;
            options(14) = emiter;
            try
                mix{i} = gmmem(mix{i},vt(:,founds{i})',options);
                c.nbparam = c.nbparam + ...
                    length(mix{i}.centres(:)) + length(mix{i}.covars(:));
            catch
                err = lasterr;
                warning('WARNING IN CLASSIFY: Problem when calling GMMEM:');
                disp(err);
                disp('Let us try again...');
                OK = 0;
            end
        end
    end
    % Score of each class for each test item: class prior times the
    % probability given by the class mixture model.
    pr = zeros(lva,length(labs));
    for i = 1:length(labs)
        prior = length(founds{i})/lvt;
        pr(:,i) = prior * gmmprob(mix{i},va');
        %c.post{i} = gmmpost(mix{i},va');
    end
    [mm ib] = max(pr');
    for i = 1:lva
        c.classes{i} = labs{ib(i)};
    end
end

% Ratio of correct classifications, when test labels are available.
if isempty(rlab)
    c.correct = NaN;
else
    correct = 0;
    for i = 1:lva
        if isequal(c.classes{i},rlab{i})
            correct = correct + 1;
        end
    end
    c.correct = correct / lva;
end
c = class(c,'mirclassify');
188
189
function vt = integrate(vt,v,lvt,norml)
% Appends one feature, given as the cell array v with one entry per item,
% as new rows at the bottom of the feature matrix vt.
%   vt    : current feature matrix (one column per item), possibly empty
%   v     : cell array; v{l} holds the feature value(s) of item l, possibly
%           wrapped in up to two further levels of cells
%   lvt   : number of items to read from v
%   norml : if true, each new row is centred and divided by its standard
%           deviation across items
block = [];
for idx = 1:lvt
    item = v{idx};
    % Unwrap at most two levels of cell nesting.
    for depth = 1:2
        if iscell(item)
            item = item{1};
        end
    end
    if size(item,2) > 1
        mirerror('MIRCLASSIFY','The analytic features guiding the classification should not be frame-decomposed.');
    end
    block(:,idx) = item;
end
if norml
    ncols = size(block,2);
    sigma = std(block,0,2);
    sigma(sigma == 0) = 1;  % In order to avoid division by 0
    mu = mean(block,2);
    block = (block - repmat(mu,[1 ncols])) ./ repmat(sigma,[1 ncols]);
end
vt(end+1:end+size(block,1),:) = block;
211
212
function [k,ncentres,covartype,kmiter,emiter,d,norml,mahl] = scanargin(v)
% Parses the optional arguments of mirclassify (the cell array v).
%   'Nearest'[,k]                 : k-nearest-neighbour strategy (k = 1 if
%                                   no number follows)
%   'GMM'[,ng][,type] or
%   'GMM'[,type][,ng]             : gaussian mixture model; sets k to 0
%   a bare number                 : taken as k
% Anything else raises a syntax error. The remaining outputs (kmiter,
% emiter, d, norml, mahl) are always returned with their fixed defaults.

% Defaults
d = 0;
norml = 1;
mahl = 1;
kmiter = 10;
emiter = 100;
k = 1;
ncentres = 0;
covartype = 'full';

nargs = length(v);
pos = 1;
while pos <= nargs
    tok = v{pos};
    hasnext = nargs > pos;
    if ischar(tok) && strcmpi(tok,'Nearest')
        k = 1;
        if hasnext && isnumeric(v{pos+1})
            pos = pos+1;
            k = v{pos};
        end
    elseif ischar(tok) && strcmpi(tok,'GMM')
        k = 0;
        ncentres = 1;
        if hasnext
            if isnumeric(v{pos+1})
                % Number of gaussians first, optionally followed by the
                % covariance type.
                pos = pos+1;
                ncentres = v{pos};
                if nargs > pos && ischar(v{pos+1})
                    pos = pos+1;
                    covartype = v{pos};
                end
            elseif ischar(v{pos+1})
                % Covariance type first, optionally followed by the
                % number of gaussians.
                pos = pos+1;
                covartype = v{pos};
                if nargs > pos && isnumeric(v{pos+1})
                    pos = pos+1;
                    ncentres = v{pos};
                end
            end
        end
    elseif isnumeric(tok)
        k = tok;
    else
        error('ERROR IN MIRCLASSIFY: Syntax error. See help mirclassify.');
    end
    pos = pos+1;
end
258
259
function y = distance(a,t,d,mahl)
% Distance between the column vector a and each column of t, weighted by
% the inverse of the matrix mahl:
%       y(i) = sqrt((a - t(:,i))' * inv(mahl) * (a - t(:,i)))
%   a    : column vector (one observation)
%   t    : matrix with one observation per column
%   d    : unused; kept for interface compatibility
%   mahl : square weighting matrix (the caller passes a covariance matrix)
% Returns a 1-by-size(t,2) row vector (empty if t has no columns).

% The weighting matrix does not depend on the column of t, so invert it
% once. Fall back to the pseudo-inverse when the determinant is not
% strictly positive.
if det(mahl) > 0 % more generally, uses cond
    lham = inv(mahl);
else
    lham = pinv(mahl);
end

n = size(t,2);
y = zeros(1,n);     % also guarantees y is defined when t is empty
for i = 1:n
    delta = a - t(:,i);
    y(i) = sqrt(delta'*lham*delta);
end