annotate toolboxes/MIRtoolbox1.3.2/MIRToolbox/@mirclassify/mirclassify.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function c = mirclassify(a,da,t,dt,varargin)
% c = mirclassify(test,features_test,train,features_train) classifies the
% audio sequence(s) contained in the audio object test, along the
% analytic feature(s) features_test, following the supervised
% learning of a training set defined by the audio object train and
% the corresponding analytic feature(s) features_train.
%   * The analytic feature(s) features_test should *not* be frame
%     decomposed. Frame-decomposed data should first be
%     summarized, using for instance mirmean or mirstd.
%   * Multiple analytic features have to be grouped into one array
%     of cells.
% You can also integrate your own arrays of numbers computed outside
% MIRtoolbox as part of the features. These arrays should be
% given as matrices where each successive column is the analysis
% of each successive file.
% Example:
%   mirclassify(test, mfcc(test), train, mfcc(train))
%   mirclassify(test, {mfcc(test), centroid(test)}, ...
%               train, {mfcc(train), centroid(train)})
% Optional arguments:
%   mirclassify(...,'Nearest') uses the minimum distance strategy
%       (by default).
%   mirclassify(...,'Nearest',k) uses the k-nearest-neighbour strategy.
%       Default value: k = 1, corresponding to the minimum distance
%       strategy. If k exceeds the number of training files, all
%       training files are used as neighbours.
%   mirclassify(...,'GMM',ng) uses a gaussian mixture model. Each class is
%       modeled by at most ng gaussians.
%       Default value: ng = 1.
%       Additionally, the type of mixture model can be specified,
%       using the set of values proposed in the gmm function: i.e.,
%       'spherical','diag','full' (default value) and 'ppca'.
%       (cf. help gmm)
%       Requires the Netlab toolbox.
% When the norml flag returned by scanargin is nonzero, every feature
% dimension is standardized with the mean and standard deviation of the
% TRAINING files, and that same transform is applied to the test files.
%
% The returned mirclassify object holds the fields labtraining, labtest,
% training, test, nbobs, nbparam, classes and correct (NaN when the test
% object carries no labels).

lab = get(t,'Label');
c.labtraining = lab;
rlab = get(a,'Label');
c.labtest = rlab;
[k,ncentres,covartype,kmiter,emiter,d,norml,mahl] = scanargin(varargin);
disp('Classifying...')

% Feature matrices: one row per feature dimension, one column per file.
% They are built unnormalized so that the test set can be mapped with the
% training statistics (normalizing each set on its own would, e.g., turn a
% single test file into an all-zero vector).
lvt = length(get(t,'Data'));
vt = collectfeatures(dt,lvt);
lva = length(get(a,'Data'));
va = collectfeatures(da,lva);
if norml
    [vt,va] = normalizefeatures(vt,va);
end

% NB: old-style class() requires the fields to be created in the same
% order every time; keep labtraining, labtest, training, test, nbobs,
% nbparam, classes, correct.
c.training = vt;
dim = size(vt,1);
c.test = va;
c.nbobs = lvt;
totva = [vt va];
mahl = cov(totva');

if k % k-Nearest Neighbour
    c.nbparam = lvt;
    c.classes = cell(1,lva);
    nneigh = min(k,lvt);   % cannot use more neighbours than training files
    for l = 1:lva
        [sv,idx] = sort(distance(va(:,l),vt,d,mahl));
        labs = cell(0);    % Distinct class labels among the neighbours
        founds = [];       % Number of neighbours carrying each label
        for i = idx(1:nneigh)
            labi = lab{i};
            found = 0;
            for j = 1:length(labs)
                if isequal(labi,labs{j})
                    found = j;
                    break
                end
            end
            if found
                founds(found) = founds(found)+1;
            else
                labs{end+1} = labi;
                founds(end+1) = 1;
            end
        end
        [b,ib] = max(founds);  % ties go to the label met first
        c.classes{l} = labs{ib};
    end

elseif ncentres % Gaussian Mixture Model
    labs = cell(0);    % Class labels
    founds = cell(0);  % Indices of the training files of each label
    for i = 1:lvt
        labi = lab{i};
        found = 0;
        for j = 1:length(labs)
            if isequal(labi,labs{j})
                founds{j}(end+1) = i;
                found = 1;
                break
            end
        end
        if not(found)
            labs{end+1} = labi;
            founds{end+1} = i;
        end
    end
    % Netlab options vector (cf. help foptions).
    options = zeros(1, 18);
    options(2:3) = 1e-4;
    options(4) = 1e-6;
    options(16) = 1e-8;
    options(17) = 0.1;
    options(1) = 0; %Prints out error values, -1 else
    mix = cell(1,length(labs));
    c.nbparam = 0;
    maxtries = 20;     % bound on the "try again" loop below
    ntries = 0;
    OK = 0;
    while not(OK)
        ntries = ntries + 1;
        if ntries > maxtries
            error(sprintf(['ERROR IN CLASSIFY: GMMEM failed %d times in a row. ',...
                           'Giving up.'],maxtries));
        end
        OK = 1;
        c.nbparam = 0;  % recount from scratch on each attempt
        for i = 1:length(labs)
            options(14) = kmiter;
            try
                mix{i} = gmm(dim,ncentres,covartype);
            catch
                error('ERROR IN CLASSIFY: Netlab toolbox not installed.');
            end
            % NOTE(review): options(5) is set to 1 below and carries over
            % to the initialisation of the following classes; confirm this
            % is intended.
            mix{i} = netlabgmminit(mix{i},vt(:,founds{i})',options);
            options(5) = 1;
            options(14) = emiter;
            try
                mix{i} = gmmem(mix{i},vt(:,founds{i})',options);
                c.nbparam = c.nbparam + ...
                    length(mix{i}.centres(:)) + length(mix{i}.covars(:));
            catch
                err = lasterr;
                warning('WARNING IN CLASSIFY: Problem when calling GMMEM:');
                disp(err);
                disp('Let us try again...');
                OK = 0;
            end
        end
    end
    % Posterior-proportional score: class prior times GMM likelihood.
    pr = zeros(lva,length(labs));
    for i = 1:length(labs)
        prior = length(founds{i})/lvt;
        pr(:,i) = prior * gmmprob(mix{i},va');
        %c.post{i} = gmmpost(mix{i},va');
    end
    [mm,ib] = max(pr,[],2);
    c.classes = cell(1,lva);
    for i = 1:lva
        c.classes{i} = labs{ib(i)};
    end

else
    error('ERROR IN MIRCLASSIFY: k must be positive when GMM is not requested. See help mirclassify.');
end

% Classification rate, when the test files are labelled.
if isempty(rlab)
    c.correct = NaN;
else
    correct = 0;
    for i = 1:lva
        if isequal(c.classes{i},rlab{i})
            correct = correct + 1;
        end
    end
    c.correct = correct / lva;
end
c = class(c,'mirclassify');


function v = collectfeatures(feats,nfiles)
% Stacks the given feature(s) into one matrix (rows = feature dimensions,
% columns = files), without any normalization.
%   feats  : one feature or a cell array of features; each is either a
%            numeric matrix (one column per file) or an object whose
%            'Data' is read with get().
%   nfiles : number of files (columns) to read.
% For objects of class 'scalar', the non-empty result of mode() is
% appended as additional rows.
if not(iscell(feats))
    feats = {feats};
end
v = [];
for i = 1:length(feats)
    if isnumeric(feats{i})
        data = num2cell(feats{i},1);  % one cell per column/file
    else
        data = get(feats{i},'Data');
    end
    v = integrate(v,data,nfiles,0);
    if isa(feats{i},'scalar')
        m = mode(feats{i});
        if not(isempty(m))
            v = integrate(v,m,nfiles,0);
        end
    end
end


function [vt,va] = normalizefeatures(vt,va)
% Standardizes each row of the training matrix vt and the test matrix va
% using the row means and standard deviations of vt only. Rows with zero
% training deviation are only centred (avoids division by 0).
mu = mean(vt,2);
sigma = std(vt,0,2);
sigma = sigma + (sigma == 0);
vt = (vt - repmat(mu,[1 size(vt,2)])) ./ repmat(sigma,[1 size(vt,2)]);
va = (va - repmat(mu,[1 size(va,2)])) ./ repmat(sigma,[1 size(va,2)]);
wolffd@0 188
wolffd@0 189
function vt = integrate(vt,v,lvt,norml)
% Appends one feature block to vt (rows = feature dimensions,
% columns = files) and returns the enlarged matrix.
%   vt    : matrix built so far (may be []).
%   v     : cell array with one entry per file; each entry is a column
%           vector, possibly wrapped in up to two levels of cells.
%   lvt   : number of entries of v to read.
%   norml : when nonzero, every new row is centred on its mean and divided
%           by its standard deviation (rows with zero deviation are only
%           centred).
% An entry with more than one column is rejected through mirerror.
block = [];
for col = 1:lvt
    x = v{col};
    for depth = 1:2   % peel off at most two levels of cell wrapping
        if iscell(x)
            x = x{1};
        end
    end
    if size(x,2) > 1
        mirerror('MIRCLASSIFY','The analytic features guiding the classification should not be frame-decomposed.');
    end
    block(:,col) = x;
end
if norml
    ncols = size(block,2);
    sigma = std(block,0,2);
    sigma = sigma + (sigma == 0);   % In order to avoid division by 0
    block = (block - repmat(mean(block,2),[1 ncols])) ./ repmat(sigma,[1 ncols]);
end
nrows = size(block,1);
vt(end+1:end+nrows,:) = block;
wolffd@0 211
wolffd@0 212
function [k,ncentres,covartype,kmiter,emiter,d,norml,mahl] = scanargin(v)
% Parses the optional arguments of mirclassify (v is the varargin cell).
%   'Nearest'[,k]           -> k-nearest-neighbour with k neighbours
%                              (default k = 1).
%   'GMM'[,ng][,covartype]  -> gaussian mixture with ng gaussians
%                              (default 1) and a Netlab covariance type
%                              (default 'full'); ng and covartype may come
%                              in either order.
%   k (numeric)             -> same as 'Nearest',k.
% A string following 'GMM' is taken as the covariance type only if it is
% one of 'spherical','diag','full','ppca' (case-insensitive). Any other
% string is parsed as the next option, so e.g. 'GMM','Nearest',3 is no
% longer misread as a covariance type.
% kmiter and emiter are iteration counts (10 and 100), and d, norml and
% mahl are fixed defaults (0, 1 and 1); none of them can be set from here.
k = 1;
d = 0;
ncentres = 0;
covartype = 'full';
kmiter = 10;
emiter = 100;
norml = 1;
mahl = 1;
covartypes = {'spherical','diag','full','ppca'};
iscovtype = @(x) ischar(x) && any(strcmpi(x,covartypes));
i = 1;
while i <= length(v)
    arg = v{i};
    if ischar(arg) && strcmpi(arg,'Nearest')
        k = 1;
        if length(v)>i && isnumeric(v{i+1})
            i = i+1;
            k = v{i};
        end
    elseif ischar(arg) && strcmpi(arg,'GMM')
        k = 0;
        ncentres = 1;
        if length(v)>i && isnumeric(v{i+1})
            i = i+1;
            ncentres = v{i};
            if length(v)>i && iscovtype(v{i+1})
                i = i+1;
                covartype = lower(v{i});
            end
        elseif length(v)>i && iscovtype(v{i+1})
            i = i+1;
            covartype = lower(v{i});
            if length(v)>i && isnumeric(v{i+1})
                i = i+1;
                ncentres = v{i};
            end
        end
    elseif isnumeric(arg)
        k = v{i};
    else
        error('ERROR IN MIRCLASSIFY: Syntax error. See help mirclassify.');
    end
    i = i+1;
end
wolffd@0 258
wolffd@0 259
function y = distance(a,t,d,mahl)
% Mahalanobis distance between the column vector a and every column of t.
%   a    : column vector (one observation).
%   t    : matrix whose columns are the reference observations (may have
%          zero columns, in which case y is empty).
%   d    : unused; kept for interface compatibility.
%   mahl : covariance matrix defining the metric. It is inverted once;
%          when it is singular or badly conditioned (rcond <= eps), its
%          pseudo-inverse is used instead.
% Returns a row vector y with y(i) = sqrt((a-t(:,i))'*inv(mahl)*(a-t(:,i))).
if rcond(mahl) > eps
    lham = inv(mahl);
else
    lham = pinv(mahl);
end
n = size(t,2);
diffs = repmat(a,[1 n]) - t;
q = sum(diffs .* (lham*diffs),1);
y = sqrt(max(q,0));   % clamp round-off negatives so the result stays real
%y = sqrt(sum(repmat(a,[1,size(t,2)])-t,1).^2);