Mercurial > hg > camir-aes2014
diff toolboxes/distance_learning/mlr/mlr_demo.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/distance_learning/mlr/mlr_demo.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,70 @@ +function mlr_demo() + + display('Loading Wine data'); + load Wine; + + % z-score the input dimensions + display('z-scoring features'); + X = zscore(X')'; + + [d,n] = size(X); + + % Generate a random training/test split + display('Generating a 80/20 training/test split'); + P = randperm(n); + Xtrain = X(:,P(1:floor(0.8 * n))); + Ytrain = Y(P(1:floor(0.8*n))); + Xtest = X(:,P((1+floor(0.8*n)):end)); + Ytest = Y(P((1+floor(0.8*n)):end)); + + + % Optimize W for AUC + C = 1e-2; + display(sprintf('Training with C=%.2e, Delta=mAP', C)); + [W, Xi, Diagnostics] = mlr_train(Xtrain, Ytrain, C, 'map'); +% [W, Xi, Diagnostics] = mlr_train_primal(Xtrain, Ytrain, C, 'map'); + + display('Test performance in the native (normalized) metric'); + mlr_test(eye(d), 3, Xtrain, Ytrain, Xtest, Ytest) + + display('Test performance with MLR metric'); + mlr_test(W, 3, Xtrain, Ytrain, Xtest, Ytest) + + % Scatter-plot + figure; + subplot(1,2,1), drawData(eye(d), Xtrain, Ytrain, Xtest, Ytest), title('Native metric (z-scored)'); + subplot(1,2,2), drawData(W, Xtrain, Ytrain, Xtest, Ytest), title('Learned metric (MLR-mAP)'); + + Diagnostics + +end + + +function drawData(W, Xtrain, Ytrain, Xtest, Ytest); + + n = length(Ytrain); + m = length(Ytest); + + if size(W,2) == 1 + W = diag(W); + end + % PCA the learned metric + Z = [Xtrain Xtest]; + A = Z' * W * Z; + [v,d] = eig(A); + + L = (d.^0.5) * v'; + L = L(1:2,:); + + % Draw training points + hold on; + trmarkers = {'b+', 'r+', 'g+'}; + tsmarkers = {'bo', 'ro', 'go'}; + for i = min(Ytrain):max(Ytrain) + points = find(Ytrain == i); + scatter(L(1,points), L(2,points), trmarkers{i}); + points = n + find(Ytest == i); + scatter(L(1,points), L(2,points), tsmarkers{i}); + end + legend({'Training', 'Test'}); +end