function mlr_demo()
% MLR_DEMO  Demo: learn a metric with mlr_train on the Wine data and plot it.
%
%   Loads X and Y from Wine.mat, z-scores each feature (row of X), draws
%   a random 80/20 training/test split, learns a metric W with mlr_train
%   using the 'map' loss, prints mlr_test results for the native
%   (identity) metric and for W, and scatter-plots both embeddings side
%   by side. The split comes from randperm, so results vary between runs.
%
%   Requires mlr_train and mlr_test on the path, and zscore (Statistics
%   and Machine Learning Toolbox).

    disp('Loading Wine data');
    % Load into a struct rather than poofing variables into the function
    % workspace. Wine.mat is expected to provide X (features, one example
    % per column) and Y (labels) -- TODO confirm against the data file.
    data = load('Wine');
    X = data.X;
    Y = data.Y;

    % z-score the input dimensions (zscore works down columns, hence the
    % transposes: each row of X is one feature).
    disp('z-scoring features');
    X = zscore(X')';

    [d, n] = size(X);

    % Generate a random training/test split.
    disp('Generating a 80/20 training/test split');
    P = randperm(n);
    nTrain = floor(0.8 * n);
    Xtrain = X(:, P(1:nTrain));
    Ytrain = Y(P(1:nTrain));
    Xtest  = X(:, P(nTrain+1:end));
    Ytest  = Y(P(nTrain+1:end));

    % Optimize W for mAP.
    C = 1e-2;
    fprintf('Training with C=%.2e, Delta=mAP\n', C);
    [W, Xi, Diagnostics] = mlr_train(Xtrain, Ytrain, C, 'map'); %#ok<ASGLU>
    % [W, Xi, Diagnostics] = mlr_train_primal(Xtrain, Ytrain, C, 'map');

    % The second argument (3) is passed straight through to mlr_test
    % (presumably a neighbourhood size -- confirm in mlr_test). Results
    % are left unsuppressed so they print.
    disp('Test performance in the native (normalized) metric');
    mlr_test(eye(d), 3, Xtrain, Ytrain, Xtest, Ytest)

    disp('Test performance with MLR metric');
    mlr_test(W, 3, Xtrain, Ytrain, Xtest, Ytest)

    % Scatter-plot both metrics side by side.
    figure;
    subplot(1, 2, 1), drawData(eye(d), Xtrain, Ytrain, Xtest, Ytest), title('Native metric (z-scored)');
    subplot(1, 2, 2), drawData(W, Xtrain, Ytrain, Xtest, Ytest), title('Learned metric (MLR-mAP)');

    % Echo the training diagnostics returned by mlr_train.
    Diagnostics
end


function drawData(W, Xtrain, Ytrain, Xtest, Ytest)
% DRAWDATA  Scatter-plot training and test points in a 2-D embedding of metric W.
%
%   drawData(W, Xtrain, Ytrain, Xtest, Ytest) stacks all points as the
%   columns of Z = [Xtrain Xtest], forms the Gram matrix A = Z' * W * Z,
%   and embeds every point with the two leading eigenvectors of A scaled
%   by the square roots of their eigenvalues (PCA under the metric W).
%   Training points are drawn with '+' markers and test points with 'o'
%   markers, one color per class label. Leaves 'hold on' set.
%
%   W              d-by-d metric matrix, or a column vector that is used
%                  as the diagonal of the metric.
%   Xtrain, Xtest  feature matrices with one point per column.
%   Ytrain, Ytest  class labels for the columns of Xtrain / Xtest.

    n = length(Ytrain);

    % A column-vector W is a diagonal metric.
    if size(W, 2) == 1
        W = diag(W);
    end

    % PCA the learned metric via the Gram matrix of all points.
    Z = [Xtrain Xtest];
    A = Z' * W * Z;
    % Round-off can leave A slightly asymmetric; symmetrizing keeps eig on
    % the symmetric path (real eigenvalues, orthonormal eigenvectors).
    A = (A + A') / 2;
    [V, D] = eig(A);

    % eig returns eigenvalues in ASCENDING order, so the two leading
    % components must be selected explicitly. Near-zero eigenvalues can
    % come out slightly negative; clamp them so sqrt stays real.
    [lambda, order] = sort(diag(D), 'descend');
    top = order(1:2);
    L = diag(sqrt(max(lambda(1:2), 0))) * V(:, top)';

    % Draw one '+' series (training) and one 'o' series (test) per class.
    % Colors cycle if there are more classes than colors.
    colors = 'brg';
    classes = unique([Ytrain(:); Ytest(:)]);
    hold on;
    for c = 1:numel(classes)
        color = colors(mod(c - 1, numel(colors)) + 1);
        points = find(Ytrain == classes(c));
        scatter(L(1, points), L(2, points), [color '+']);
        % Test points follow the n training points in the columns of L.
        points = n + find(Ytest == classes(c));
        scatter(L(1, points), L(2, points), [color 'o']);
    end
    % The first two series are the first class's training ('+') and test
    % ('o') points, so these labels explain the marker shapes.
    legend({'Training', 'Test'});
end