diff toolboxes/MIRtoolbox1.3.2/MIRToolbox/@mirclassify/mirclassify.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/MIRtoolbox1.3.2/MIRToolbox/@mirclassify/mirclassify.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,270 @@
+function c = mirclassify(a,da,t,dt,varargin)
+%   c = mirclassify(test,features_test,train,features_train) classifies the
+%       audio sequence(s) contained in the audio object test, along the
+%       analytic feature(s) features_test, following the supervised
+%       learning of a training set defined by the audio object train and
+%       the corresponding analytic feature(s) features_train.
+%           * The analytic feature(s) features_test should *not* be frame
+%               decomposed. Frame-decomposed data should first be
+%               summarized, using for instance mirmean or mirstd.
+%           * Multiple analytic features have to be grouped into one array
+%               of cells.
+%       You can also integrate your own arrays of numbers computed outside
+%           MIRtoolbox as part of the features. These arrays should be
+%           given as matrices where each successive column is the analysis
+%           of each successive file.
+%   Example:
+%       mirclassify(test, mfcc(test), train, mfcc(train))
+%       mirclassify(test, {mfcc(test), centroid(test)}, ...
+%                train, {mfcc(train), centroid(train)})
+%   Optional argument:
+%       mirclassify(...,'Nearest') uses the minimum distance strategy
+%           (default).
+%       mirclassify(...,'Nearest',k) uses the k-nearest-neighbour strategy.
+%           Default value: k = 1, i.e. the minimum distance strategy.
+%           k is capped at the number of training sequences.
+%       mirclassify(...,'GMM',ng) uses a gaussian mixture model. Each class is
+%           modeled by at most ng gaussians. Default value: ng = 1.
+%           Additionally, the type of mixture model can be specified, using
+%           the set of values proposed in the gmm function: 'spherical',
+%           'diag', 'full' (default value) and 'ppca' (cf. help gmm).
+%           Requires the Netlab toolbox.
+%   Output: c is a mirclassify object holding the labels, feature matrices,
+%       predicted classes and rate of correct answers (NaN if test unlabelled).
+
+lab = get(t,'Label');
+c.labtraining = lab;
+rlab = get(a,'Label');
+c.labtest = rlab;
+[k,ncentres,covartype,kmiter,emiter,d,norml,mahl] = scanargin(varargin);
+disp('Classifying...')
+if not(iscell(dt))
+    dt = {dt};
+end
+lvt = length(get(t,'Data'));   % number of training sequences
+vt = [];                       % training features, one column per sequence
+for i = 1:length(dt)
+    if isnumeric(dt{i})        % raw matrix: split it into one cell per column
+        d = cell(1,size(dt{i},2));
+        for j = 1:size(dt{i},2)
+            d{j} = dt{i}(:,j);
+        end
+    else
+        d = get(dt{i},'Data');
+    end
+    vt = integrate(vt,d,lvt,norml);
+    if isa(dt{i},'scalar')
+        m = mode(dt{i});
+        if not(isempty(m))
+            vt = integrate(vt,m,lvt,norml);
+        end
+    end
+end
+c.training = vt;
+dim = size(vt,1);
+if not(iscell(da))
+    da = {da};
+end
+lva = length(get(a,'Data'));   % number of test sequences
+va = [];                       % test features, one column per sequence
+for i = 1:length(da)
+    if isnumeric(da{i})        % raw matrix: split it into one cell per column
+        d = cell(1,size(da{i},2));
+        for j = 1:size(da{i},2)
+            d{j} = da{i}(:,j);
+        end
+    else
+        d = get(da{i},'Data');
+    end
+    va = integrate(va,d,lva,norml);
+    if isa(da{i},'scalar')
+        m = mode(da{i});
+        if not(isempty(m))
+            va = integrate(va,m,lva,norml);
+        end
+    end
+end
+c.test = va;
+c.nbobs = lvt;
+totva = [vt va];
+mahl = cov(totva');   % metric estimated on training and test data together
+if k                % k-Nearest Neighbour
+    c.nbparam = lvt;
+    for l = 1:lva
+        [sv,idx] = sort(distance(va(:,l),vt,d,mahl));
+        labs = cell(0); % Class labels
+        founds = [];    % Number of found elements in each class
+        for i = idx(1:min(k,lvt))   % no more neighbours than training items
+            labi = lab{i};
+            found = 0;
+            for j = 1:length(labs)
+                if isequal(labi,labs{j})
+                    found = j;
+                end
+            end
+            if found
+                founds(found) = founds(found)+1;
+            else
+                labs{end+1} = labi;
+                founds(end+1) = 1;
+            end
+        end
+        [b ib] = max(founds);   % majority vote; a tie goes to the label met first
+        c.classes{l} = labs{ib};
+    end
+elseif ncentres     % Gaussian Mixture Model
+    labs = cell(0);    % Class labels
+    founds = cell(0);  % Elements associated to each label.
+    for i = 1:lvt
+        labi = lab{i};
+        found = 0;
+        for j = 1:length(labs)
+            if isequal(labi,labs{j})
+                founds{j}(end+1) = i;
+                found = 1;
+            end
+        end
+        if not(found)
+            labs{end+1} = labi;
+            founds{end+1} = i;
+        end
+    end
+    options      = zeros(1, 18);
+    options(2:3) = 1e-4;
+    options(4)   = 1e-6;
+    options(16)  = 1e-8;
+    options(17)  = 0.1;
+    options(1)   = 0; %Prints out error values, -1 else
+    OK = 0;
+    while not(OK)   % a pass is repeated for all classes until gmmem succeeds
+        OK = 1;
+        c.nbparam = 0;   % recount on every pass, else retries are counted twice
+        for i = 1:length(labs)
+            options(14)  = kmiter;
+            try
+                mix{i} = gmm(dim,ncentres,covartype);
+            catch
+                error('ERROR IN CLASSIFY: Netlab toolbox not installed.');
+            end
+            mix{i} = netlabgmminit(mix{i},vt(:,founds{i})',options);
+            options(5)   = 1;
+            options(14)  = emiter;
+            try
+                mix{i} = gmmem(mix{i},vt(:,founds{i})',options);
+                c.nbparam = c.nbparam + ...
+                    length(mix{i}.centres(:)) + length(mix{i}.covars(:));
+            catch
+                err = lasterr;
+                warning('WARNING IN CLASSIFY: Problem when calling GMMEM:');
+                disp(err);
+                disp('Let us try again...');
+                OK = 0;
+            end
+        end
+    end
+    pr = zeros(lva,length(labs));
+    for i = 1:length(labs)
+        prior = length(founds{i})/lvt;   % class frequency in the training set
+        pr(:,i) = prior * gmmprob(mix{i},va');
+        %c.post{i} = gmmpost(mix{i},va');
+    end
+    [mm ib] = max(pr');
+    for i = 1:lva
+        c.classes{i} = labs{ib(i)};
+    end
+end
+if isempty(rlab)
+    c.correct = NaN;
+else
+    correct = 0;
+    for i = 1:lva
+        if isequal(c.classes{i},rlab{i})
+            correct = correct + 1;
+        end
+    end
+    c.correct = correct / lva;
+end
+c = class(c,'mirclassify');
+
+
+function vt = integrate(vt,v,lvt,norml)
+% Appends one feature (cell v, one entry per sequence) as new rows of vt.
+block = [];
+for col = 1:lvt
+    item = v{col};
+    for depth = 1:2   % strip at most two levels of cell nesting
+        if iscell(item)
+            item = item{1};
+        end
+    end
+    if size(item,2) > 1
+        mirerror('MIRCLASSIFY','The analytic features guiding the classification should not be frame-decomposed.');
+    end
+    block(:,col) = item;
+end
+if norml   % centre and scale each row across the sequences
+    sigma = std(block,0,2);
+    sigma(sigma == 0) = 1;  % In order to avoid division by 0
+    block = (block - repmat(mean(block,2),[1 size(block,2)])) ./ repmat(sigma,[1 size(block,2)]);
+end
+vt(end+1:end+size(block,1),:) = block;
+
+
+function [k,ncentres,covartype,kmiter,emiter,d,norml,mahl] = scanargin(v)
+% Parses the optional arguments of mirclassify, given as the cell array v.
+%   'Nearest' [,k]       k-nearest-neighbour strategy (k = 1 if omitted).
+%   'GMM' [,ng] [,type]  mixture model: k = 0, ng centres (1 if omitted);
+%                        ng and type may follow in either order.
+%   numeric value        taken directly as k.
+% Anything else is a syntax error. The other outputs keep fixed values.
+kmiter = 10;   emiter = 100;
+d = 0;   norml = 1;   mahl = 1;
+k = 1;   ncentres = 0;   covartype = 'full';
+n = length(v);
+pos = 1;
+while pos <= n
+    token = v{pos};
+    if isnumeric(token)
+        k = token;
+    elseif ischar(token) && strcmpi(token,'Nearest')
+        k = 1;
+        if pos < n && isnumeric(v{pos+1})
+            pos = pos+1;
+            k = v{pos};
+        end
+    elseif ischar(token) && strcmpi(token,'GMM')
+        k = 0;
+        ncentres = 1;
+        gotnum = false;    % each optional follower is accepted only once
+        gottype = false;
+        while pos < n
+            follower = v{pos+1};
+            if isnumeric(follower) && not(gotnum)
+                ncentres = follower;
+                gotnum = true;
+            elseif ischar(follower) && not(gottype)
+                covartype = follower;
+                gottype = true;
+            else
+                break
+            end
+            pos = pos+1;
+        end
+    else
+        error('ERROR IN MIRCLASSIFY: Syntax error. See help mirclassify.');
+    end
+    pos = pos+1;
+end
+
+
+function y = distance(a,t,d,mahl)
+% Mahalanobis distance from column a to each column of t (d is unused).
+if det(mahl) > 0  % more generally, uses cond
+    lham = inv(mahl);    % the metric is loop-invariant: invert only once
+else
+    lham = pinv(mahl);   % det(mahl) <= 0: fall back to the pseudo-inverse
+end
+y = zeros(1,size(t,2));  % row vector; also defines y when t has no column
+for i = 1:size(t,2)
+    y(i) = sqrt((a - t(:,i))'*lham*(a - t(:,i)));
+end
\ No newline at end of file