Mercurial > hg > camir-aes2014
comparison toolboxes/distance_learning/mlr/mlr_demo.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function mlr_demo()
% mlr_demo  Demonstrate metric learning to rank (MLR) on the Wine data set.
%
% Loads Wine.mat (expected to provide a d-by-n feature matrix X, one point
% per column, and a length-n label vector Y), z-scores the features, makes
% a random 80/20 training/test split, trains a metric with mlr_train using
% Delta = 'map' (mean average precision), prints mlr_test results for the
% native metric eye(d) and the learned metric W, plots both embeddings side
% by side, and finally displays the training Diagnostics.

disp('Loading Wine data');
% Load into a struct instead of creating variables implicitly, so it is
% explicit which variables the demo needs from Wine.mat.
data = load('Wine');
X = data.X;
Y = data.Y;

% z-score the input dimensions: zscore normalizes columns, while X stores
% one point per column, so transpose in and out to normalize each feature.
disp('z-scoring features');
X = zscore(X')';

[d, n] = size(X);

% Generate a random training/test split
disp('Generating a 80/20 training/test split');
P = randperm(n);
nTrain = floor(0.8 * n);
trainIdx = P(1:nTrain);
testIdx = P((nTrain + 1):end);
Xtrain = X(:, trainIdx);
Ytrain = Y(trainIdx);
Xtest = X(:, testIdx);
Ytest = Y(testIdx);

% Optimize W for mean average precision (Delta = 'map')
C = 1e-2;
fprintf('Training with C=%.2e, Delta=mAP\n', C);
[W, ~, Diagnostics] = mlr_train(Xtrain, Ytrain, C, 'map');
% [W, Xi, Diagnostics] = mlr_train_primal(Xtrain, Ytrain, C, 'map');

% mlr_test calls are intentionally not terminated with ';' so that their
% results are printed.
disp('Test performance in the native (normalized) metric');
mlr_test(eye(d), 3, Xtrain, Ytrain, Xtest, Ytest)

disp('Test performance with MLR metric');
mlr_test(W, 3, Xtrain, Ytrain, Xtest, Ytest)

% Scatter-plot
figure;
subplot(1,2,1), drawData(eye(d), Xtrain, Ytrain, Xtest, Ytest), title('Native metric (z-scored)');
subplot(1,2,2), drawData(W, Xtrain, Ytrain, Xtest, Ytest), title('Learned metric (MLR-mAP)');

Diagnostics

end
41 | |
42 | |
function drawData(W, Xtrain, Ytrain, Xtest, Ytest)
% drawData  Scatter-plot training and test points along the top two
% principal directions of the Gram matrix under the metric W.
%
%   W       - metric: a d-by-d matrix, or a d-by-1 vector that is expanded
%             with diag() into a diagonal metric.
%   Xtrain  - d-by-n training points (one point per column)
%   Ytrain  - length-n training labels
%   Xtest   - d-by-m test points (one point per column)
%   Ytest   - length-m test labels
%
% Draws into the current axes: training points with '+' markers, test
% points with 'o' markers, colour chosen per class label (cycling through
% blue/red/green). The axes' previous hold state is restored afterwards.

n = length(Ytrain);

if size(W,2) == 1
    W = diag(W);
end

% PCA the learned metric: eigendecompose the Gram matrix A = Z' * W * Z.
Z = [Xtrain Xtest];
A = Z' * W * Z;
% Round-off can make A slightly asymmetric; symmetrize so eig uses the
% symmetric solver and returns real eigenpairs.
A = (A + A') / 2;
[v, d] = eig(A);

% eig returns symmetric-matrix eigenvalues in ascending order, so sort them
% descending to keep the two LARGEST components. Clamp small negative
% eigenvalues (round-off on a rank-deficient A) so the square root is real.
[lambda, order] = sort(diag(d), 'descend');
k = min(2, numel(lambda));
lambda = max(lambda(1:k), 0);
L = diag(sqrt(lambda)) * v(:, order(1:k))';
% Pad with zero rows so L(2,:) exists even with fewer than two points.
L = [L; zeros(2 - k, size(L, 2))];

% Draw training and test points, one colour per class label
wasHeld = ishold;
hold on;
trmarkers = {'b+', 'r+', 'g+'};
tsmarkers = {'bo', 'ro', 'go'};
labels = unique([Ytrain(:); Ytest(:)]);
for c = 1:numel(labels)
    % Cycle through the marker styles so any label values/count work.
    style = 1 + mod(c - 1, numel(trmarkers));
    points = find(Ytrain == labels(c));
    scatter(L(1,points), L(2,points), trmarkers{style});
    % Test points follow the n training points in Z, hence the offset.
    points = n + find(Ytest == labels(c));
    scatter(L(1,points), L(2,points), tsmarkers{style});
end
legend({'Training', 'Test'});
if ~wasHeld
    hold off;
end
end