Mercurial > hg > camir-aes2014
comparison toolboxes/MIRtoolbox1.3.2/MIRToolbox/@mirclassify/mirclassify.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function c = mirclassify(a,da,t,dt,varargin)
%   c = mirclassify(test,features_test,train,features_train) classifies the
%       audio sequence(s) contained in the audio object test, along the
%       analytic feature(s) features_test, following the supervised
%       learning of a training set defined by the audio object train and
%       the corresponding analytic feature(s) features_train.
%           * The analytic feature(s) features_test should *not* be frame
%               decomposed. Frame-decomposed data should first be
%               summarized, using for instance mirmean or mirstd.
%           * Multiple analytic features have to be grouped into one array
%               of cells.
%       You can also integrate your own arrays of numbers computed outside
%           MIRtoolbox as part of the features. These arrays should be
%           given as matrices where each successive column is the analysis
%           of each successive file.
%   Example:
%       mirclassify(test, mfcc(test), train, mfcc(train))
%       mirclassify(test, {mfcc(test), centroid(test)}, ...
%                   train, {mfcc(train), centroid(train)})
%   Optional argument:
%       mirclassify(...,'Nearest') uses the minimum distance strategy.
%           (by default)
%       mirclassify(...,'Nearest',k) uses the k-nearest-neighbour strategy.
%           Default value: k = 1, corresponding to the minimum distance
%               strategy.
%           If k exceeds the number of training observations, all the
%               training observations are used.
%       mirclassify(...,'GMM',ng) uses a gaussian mixture model. Each class is
%           modeled by at most ng gaussians.
%           Default value: ng = 1.
%           Additionally, the type of mixture model can be specified,
%           using the set of value proposed in the gmm function: i.e.,
%           'spherical','diag','full' (default value) and 'ppca'.
%               (cf. help gmm)
%           Requires the Netlab toolbox.
%
%   Fields stored in the returned mirclassify object:
%       labtraining, labtest : labels of the training and test sets
%       training, test       : feature matrices (one column per file)
%       nbobs                : number of training observations
%       nbparam              : number of model parameters
%       classes              : label attributed to each test file
%       correct              : ratio of correctly classified test files,
%                              NaN if the test set carries no label.

lab = get(t,'Label');
c.labtraining = lab;
rlab = get(a,'Label');
c.labtest = rlab;
% d and mahl are returned by scanargin for interface reasons only: d is
% merely forwarded to distance, and mahl is overwritten below by the
% covariance of the pooled feature vectors.
[k,ncentres,covartype,kmiter,emiter,d,norml,mahl] = scanargin(varargin);
disp('Classifying...')

% Feature matrices: one row per feature dimension, one column per file.
lvt = length(get(t,'Data'));
vt = collectfeatures(dt,lvt,norml);
c.training = vt;
dim = size(vt,1);
lva = length(get(a,'Data'));
va = collectfeatures(da,lva,norml);
c.test = va;
c.nbobs = lvt;

% Covariance of training and test vectors pooled together, used by the
% nearest-neighbour distance.
totva = [vt va];
mahl = cov(totva');

if k % k-Nearest Neighbour
    c.nbparam = lvt;
    for l = 1:lva
        [sv,idx] = sort(distance(va(:,l),vt,d,mahl));
        % Never ask for more neighbours than there are training items.
        kk = min(k,length(idx));
        labs = cell(0);     % Class labels
        founds = [];        % Number of found elements in each class
        for i = idx(1:kk)
            labi = lab{i};
            found = 0;
            for j = 1:length(labs)
                if isequal(labi,labs{j})
                    found = j;
                end
            end
            if found
                founds(found) = founds(found)+1;
            else
                labs{end+1} = labi;
                founds(end+1) = 1;
            end
        end
        % Majority vote; on a tie the class met first (closest) wins.
        [b,ib] = max(founds);
        c.classes{l} = labs{ib};
    end
elseif ncentres % Gaussian Mixture Model
    labs = cell(0);     % Class labels
    founds = cell(0);   % Elements associated to each label.
    for i = 1:lvt
        labi = lab{i};
        found = 0;
        for j = 1:length(labs)
            if isequal(labi,labs{j})
                founds{j}(end+1) = i;
                found = 1;
            end
        end
        if not(found)
            labs{end+1} = labi;
            founds{end+1} = i;
        end
    end
    options      = zeros(1, 18);
    options(2:3) = 1e-4;
    options(4)   = 1e-6;
    options(16)  = 1e-8;
    options(17)  = 0.1;
    options(1)   = 0; %Prints out error values, -1 else
    OK = 0;
    % NOTE(review): the training is restarted for as long as gmmem fails;
    % there is no upper bound on the number of attempts.
    while not(OK)
        OK = 1;
        % Restart the parameter count on each attempt, so that a failed
        % pass does not inflate the total.
        c.nbparam = 0;
        for i = 1:length(labs)
            options(14) = kmiter;
            try
                mix{i} = gmm(dim,ncentres,covartype);
            catch
                error('ERROR IN CLASSIFY: Netlab toolbox not installed.');
            end
            mix{i} = netlabgmminit(mix{i},vt(:,founds{i})',options);
            options(5) = 1;
            options(14) = emiter;
            try
                mix{i} = gmmem(mix{i},vt(:,founds{i})',options);
                c.nbparam = c.nbparam + ...
                    length(mix{i}.centres(:)) + length(mix{i}.covars(:));
            catch
                err = lasterr;
                warning('WARNING IN CLASSIFY: Problem when calling GMMEM:');
                disp(err);
                disp('Let us try again...');
                OK = 0;
            end
        end
    end
    % pr(n,i): prior of class i times likelihood of test file n under the
    % mixture of class i.
    pr = zeros(lva,length(labs));
    for i = 1:length(labs)
        prior = length(founds{i})/lvt;
        pr(:,i) = prior * gmmprob(mix{i},va');
        %c.post{i} = gmmpost(mix{i},va');
    end
    % Maximise explicitly along the class dimension: max(pr') would
    % collapse to a single value when there is only one class.
    [mm,ib] = max(pr,[],2);
    for i = 1:lva
        c.classes{i} = labs{ib(i)};
    end
end
if isempty(rlab)
    c.correct = NaN;
else
    correct = 0;
    for i = 1:lva
        if isequal(c.classes{i},rlab{i})
            correct = correct + 1;
        end
    end
    c.correct = correct / lva;
end
c = class(c,'mirclassify');


function v = collectfeatures(feats,nobs,norml)
% Stacks the features of nobs files into one matrix, one column per file.
%   feats is either a single feature or a cell array of features; each
%   feature is either a numeric matrix (one column per file) or an object
%   whose 'Data' field is read through get.
if not(iscell(feats))
    feats = {feats};
end
v = [];
for i = 1:length(feats)
    if isnumeric(feats{i})
        data = cell(1,size(feats{i},2));
        for j = 1:size(feats{i},2)
            data{j} = feats{i}(:,j);
        end
    else
        data = get(feats{i},'Data');
    end
    v = integrate(v,data,nobs,norml);
    % NOTE(review): the class name tested here is 'scalar'; confirm that
    % it is not meant to be 'mirscalar'.
    if isa(feats{i},'scalar')
        m = mode(feats{i});
        if not(isempty(m))
            v = integrate(v,m,nobs,norml);
        end
    end
end
188 | |
189 | |
function vt = integrate(vt,v,lvt,norml)
% Appends the feature described by v to the feature matrix vt.
%   v is a cell array holding, for each of the lvt files, one column
%   vector (possibly wrapped in up to two further levels of cells).
%   The columns are gathered into a block with one column per file; if
%   norml is true, each row of the block is centred and divided by its
%   standard deviation across files. The block is finally added below the
%   rows already present in vt.
block = [];
for file = 1:lvt
    item = v{file};
    % Peel off up to two nested cell layers.
    for depth = 1:2
        if iscell(item)
            item = item{1};
        end
    end
    if size(item,2) > 1
        mirerror('MIRCLASSIFY','The analytic features guiding the classification should not be frame-decomposed.');
    end
    block(:,file) = item;
end
if norml
    nfiles = size(block,2);
    spread = repmat(std(block,0,2),[1 nfiles]);
    spread = spread + (spread == 0);  % In order to avoid division by 0
    centre = repmat(mean(block,2),[1 nfiles]);
    block = (block - centre) ./ spread;
end
nrows = size(block,1);
vt(end+1:end+nrows,:) = block;
211 | |
212 | |
function [k,ncentres,covartype,kmiter,emiter,d,norml,mahl] = scanargin(v)
% Parses the optional arguments of mirclassify, given as the cell array v.
%   Recognised forms:
%       'Nearest'        -> k = 1
%       'Nearest',k      -> k-nearest-neighbour with the given k
%       'GMM'            -> k = 0, ncentres = 1
%       'GMM',ng         -> ncentres = ng
%       'GMM',ng,type    -> ncentres = ng, covartype = type
%       'GMM',type       -> covartype = type
%       'GMM',type,ng    -> covartype = type, ncentres = ng
%       a number alone   -> k
%   Anything else raises a syntax error. The remaining outputs always
%   carry their default values.

% Default values
k = 1;
ncentres = 0;
covartype = 'full';
kmiter = 10;
emiter = 100;
d = 0;
norml = 1;
mahl = 1;

nargs = length(v);
pos = 1;
while pos <= nargs
    token = v{pos};
    isname = ischar(token);
    if isname && strcmpi(token,'Nearest')
        k = 1;
        if pos < nargs && isnumeric(v{pos+1})
            pos = pos+1;
            k = v{pos};
        end
    elseif isname && strcmpi(token,'GMM')
        k = 0;
        ncentres = 1;
        if pos < nargs
            follower = v{pos+1};
            if isnumeric(follower)
                % Number of gaussians first, optionally followed by the
                % covariance type.
                pos = pos+1;
                ncentres = follower;
                if pos < nargs && ischar(v{pos+1})
                    pos = pos+1;
                    covartype = v{pos};
                end
            elseif ischar(follower)
                % Covariance type first, optionally followed by the
                % number of gaussians.
                pos = pos+1;
                covartype = follower;
                if pos < nargs && isnumeric(v{pos+1})
                    pos = pos+1;
                    ncentres = v{pos};
                end
            end
        end
    elseif isnumeric(token)
        k = token;
    else
        error('ERROR IN MIRCLASSIFY: Syntax error. See help mirclassify.');
    end
    pos = pos+1;
end
258 | |
259 | |
function y = distance(a,t,d,mahl)
% Distance between the column vector a and each column of t, weighted by
% the inverse of the matrix mahl.
%   y is a row vector with one value per column of t; it is empty when t
%   has no column.
%   d is not used; it is kept so that existing calls remain valid.

% The weighting matrix does not depend on the column of t: invert it once,
% outside the loop.
if det(mahl) > 0  % more generally, uses cond
    lham = inv(mahl);
else
    lham = pinv(mahl);
end
n = size(t,2);
y = zeros(1,n);     % Also defines the output when t is empty
for i = 1:n
    delta = a - t(:,i);
    y(i) = sqrt(delta'*lham*delta);
end
%y = sqrt(sum(repmat(a,[1,size(t,2)])-t,1).^2);