toolboxes/MIRtoolbox1.3.2/MIRToolbox/@mirclassify/mirclassify.m
function c = mirclassify(a,da,t,dt,varargin)
%   c = mirclassify(test,features_test,train,features_train) classifies the
%       audio sequence(s) contained in the audio object test, based on the
%       analytic feature(s) features_test, using supervised learning on a
%       training set defined by the audio object train and the
%       corresponding analytic feature(s) features_train.
%           * The analytic feature(s) features_test should *not* be frame 
%               decomposed. Frame-decomposed data should first be 
%               summarized, using for instance mirmean or mirstd.
%           * Multiple analytic features have to be grouped into one cell
%               array.
%       You can also include your own arrays of numbers, computed outside
%           MIRtoolbox, among the features. These arrays should be given
%           as matrices in which each successive column corresponds to
%           each successive audio file.
%   Example:
%       mirclassify(test, mfcc(test), train, mfcc(train))
%       mirclassify(test, {mfcc(test), centroid(test)}, ...
%                train, {mfcc(train), centroid(train)})
%   Optional argument:
%       mirclassify(...,'Nearest') uses the minimum distance strategy
%           (default).
%       mirclassify(...,'Nearest',k) uses the k-nearest-neighbour strategy.
%           Default value: k = 1, corresponding to the minimum distance
%               strategy.
%       mirclassify(...,'GMM',ng) uses a Gaussian mixture model. Each class is
%           modeled by at most ng Gaussians.
%           Default value: ng = 1.
%           Additionally, the type of mixture model can be specified,
%           using the set of values accepted by the gmm function: i.e.,
%           'spherical', 'diag', 'full' (default value) and 'ppca'.
%               (cf. help gmm)
%           Requires the Netlab toolbox.
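%   Further examples (a sketch based on the options documented above,
%       assuming test and train are miraudio objects):
%       mirclassify(test, mfcc(test), train, mfcc(train), 'Nearest', 5)
%       mirclassify(test, mfcc(test), train, mfcc(train), 'GMM', 3, 'diag')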

lab = get(t,'Label');
c.labtraining = lab;
rlab = get(a,'Label');
c.labtest = rlab;
[k,ncentres,covartype,kmiter,emiter,d,norml,mahl] = scanargin(varargin);
disp('Classifying...')
if not(iscell(dt))
    dt = {dt};
end
lvt = length(get(t,'Data'));
vt = [];
for i = 1:length(dt)
    if isnumeric(dt{i})
        d = cell(1,size(dt{i},2));
        for j = 1:size(dt{i},2)
            d{j} = dt{i}(:,j);
        end
    else
        d = get(dt{i},'Data');
    end
    vt = integrate(vt,d,lvt,norml);
    if isa(dt{i},'scalar')
        m = mode(dt{i});
        if not(isempty(m))
            vt = integrate(vt,m,lvt,norml);
        end
    end
end
c.training = vt;
dim = size(vt,1);
if not(iscell(da))
    da = {da};
end
lva = length(get(a,'Data'));
va = [];
for i = 1:length(da)
    if isnumeric(da{i})
        d = cell(1,size(da{i},2));
        for j = 1:size(da{i},2)
            d{j} = da{i}(:,j);
        end
    else
        d = get(da{i},'Data');
    end
    va = integrate(va,d,lva,norml);
    if isa(da{i},'scalar')
        m = mode(da{i});
        if not(isempty(m))
            va = integrate(va,m,lva,norml);
        end
    end
end
c.test = va;
c.nbobs = lvt;
totva = [vt va];
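% Covariance of the pooled training and test features, used as the
% Mahalanobis metric in the distance computation below.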
mahl = cov(totva');
if k                % k-Nearest Neighbour
    c.nbparam = lvt;
    for l = 1:lva
        [sv,idx] = sort(distance(va(:,l),vt,d,mahl));
        labs = cell(0); % Class labels
        founds = [];    % Number of found elements in each class
        for i = idx(1:k)
            labi = lab{i};
            found = 0;
            for j = 1:length(labs)
                if isequal(labi,labs{j})
                    found = j;
                end
            end
            if found
                founds(found) = founds(found)+1;
            else
                labs{end+1} = labi;
                founds(end+1) = 1;
            end
        end
        [b,ib] = max(founds);
        c.classes{l} = labs{ib};
    end
elseif ncentres     % Gaussian Mixture Model
    labs = cell(0);    % Class labels
    founds = cell(0);  % Elements associated to each label.
    for i = 1:lvt
        labi = lab{i};
        found = 0;
        for j = 1:length(labs)
            if isequal(labi,labs{j})
                founds{j}(end+1) = i;
                found = 1;
            end
        end
        if not(found)
            labs{end+1} = labi;
            founds{end+1} = i;
        end
    end
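    % Netlab optimisation settings (foptions vector): display flag,
    % convergence tolerances and iteration limits for netlabgmminit and gmmem.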
    options      = zeros(1, 18);
    options(2:3) = 1e-4;
    options(4)   = 1e-6;
    options(16)  = 1e-8;
    options(17)  = 0.1;
    options(1)   = 0; % Display warnings only (1 prints error values, -1 displays nothing)
    c.nbparam = 0;
    OK = 0;
    while not(OK)
        OK = 1;
        for i = 1:length(labs)
            options(14)  = kmiter;
            try
                mix{i} = gmm(dim,ncentres,covartype);
            catch
                error('ERROR IN MIRCLASSIFY: Netlab toolbox not installed.');
            end
            mix{i} = netlabgmminit(mix{i},vt(:,founds{i})',options);
            options(5)   = 1;
            options(14)  = emiter;
            try
                mix{i} = gmmem(mix{i},vt(:,founds{i})',options);
                c.nbparam = c.nbparam + ...
                    length(mix{i}.centres(:)) + length(mix{i}.covars(:));
            catch
                err = lasterr;
                warning('WARNING IN MIRCLASSIFY: Problem when calling GMMEM:');
                disp(err);
                disp('Let us try again...');
                OK = 0;
            end
        end    
    end
    pr = zeros(lva,length(labs));
    for i = 1:length(labs)
        prior = length(founds{i})/lvt;
        pr(:,i) = prior * gmmprob(mix{i},va');
        %c.post{i} = gmmpost(mix{i},va');
    end
    [mm,ib] = max(pr');
    for i = 1:lva
        c.classes{i} = labs{ib(i)};
    end
end
if isempty(rlab)
    c.correct = NaN;
else
    correct = 0;
    for i = 1:lva
        if isequal(c.classes{i},rlab{i})
            correct = correct + 1;
        end
    end
    c.correct = correct / lva;
end        
c = class(c,'mirclassify');


function vt = integrate(vt,v,lvt,norml)
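% Appends the feature values of each audio file (one column per file) as
% additional rows under vt; if norml is set, each feature dimension is
% first standardised to zero mean and unit variance across files.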
vtl = [];
for l = 1:lvt
    vl = v{l};
    if iscell(vl)
        vl = vl{1};
    end
    if iscell(vl)
        vl = vl{1};
    end
    if size(vl,2) > 1
        mirerror('MIRCLASSIFY','The analytic features guiding the classification should not be frame-decomposed.');
    end
    vtl(:,l) = vl;
end
if norml
    dnom = repmat(std(vtl,0,2),[1 size(vtl,2)]);
    dnom = dnom + (dnom == 0);  % In order to avoid division by 0
    vtl = (vtl - repmat(mean(vtl,2),[1 size(vtl,2)])) ./ dnom;
end
vt(end+1:end+size(vtl,1),:) = vtl;


function [k,ncentres,covartype,kmiter,emiter,d,norml,mahl] = scanargin(v)
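% Parses the optional arguments ('Nearest' with an optional k, 'GMM' with
% an optional number of Gaussians and covariance type, or a bare numeric k)
% and returns the classifier settings together with their default values.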
k = 1;
d = 0;
i = 1;
ncentres = 0;
covartype = 'full';
kmiter = 10;
emiter = 100;
norml = 1;
mahl = 1;
while i <= length(v)
    arg = v{i};
    if ischar(arg) && strcmpi(arg,'Nearest')
        k = 1;
        if length(v)>i && isnumeric(v{i+1})
            i = i+1;
            k = v{i};
        end
    elseif ischar(arg) && strcmpi(arg,'GMM')
        k = 0;
        ncentres = 1;
        if length(v)>i
            if isnumeric(v{i+1})
                i = i+1;
                ncentres = v{i};
                if length(v)>i && ischar(v{i+1})
                    i = i+1;
                    covartype = v{i};
                end
            elseif ischar(v{i+1})
                i = i+1;
                covartype = v{i};
                if length(v)>i && isnumeric(v{i+1})
                    i = i+1;
                    ncentres = v{i};
                end
            end                
        end
    elseif isnumeric(arg)
        k = v{i};
    else
        error('ERROR IN MIRCLASSIFY: Syntax error. See help mirclassify.');
    end    
    i = i+1;
end


function y = distance(a,t,d,mahl)
% Mahalanobis distance between the test observation a and each training
% observation (each column of t), using mahl as covariance matrix.

if det(mahl) > 0  % more generally, a condition-number test could be used
    lham = inv(mahl);
else
    lham = pinv(mahl);  % pseudo-inverse if the covariance matrix is singular
end
y = zeros(1,size(t,2));
for i = 1:size(t,2)
    y(i) = sqrt((a - t(:,i))'*lham*(a - t(:,i)));
end
%y = sqrt(sum(repmat(a,[1,size(t,2)])-t,1).^2);