comparison toolboxes/distance_learning/mlr/mlr_demo.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function mlr_demo()
% MLR_DEMO  Train and evaluate an MLR metric on the Wine data set.
%
%   Loads the Wine data, z-scores every feature, draws a random 80/20
%   training/test split, learns a metric with mlr_train using the 'map'
%   loss, prints the mlr_test results for the identity metric and for the
%   learned metric, and scatter-plots both with drawData.
%
%   Takes no inputs and returns nothing; results are printed and plotted.

    disp('Loading Wine data');
    % Load into a struct so the variables this demo depends on are explicit
    % instead of being injected into the workspace by a bare "load".
    data = load('Wine');
    X = data.X;
    Y = data.Y;

    % z-score the input dimensions (features are the rows of X, so zscore
    % is applied to the transpose and the result transposed back)
    disp('z-scoring features');
    X = zscore(X')';

    [d, n] = size(X);

    % Generate a random training/test split
    disp('Generating a 80/20 training/test split');
    P = randperm(n);
    nTrain = floor(0.8 * n);
    trainIdx = P(1:nTrain);
    testIdx = P((nTrain + 1):end);

    Xtrain = X(:, trainIdx);
    Ytrain = Y(trainIdx);
    Xtest = X(:, testIdx);
    Ytest = Y(testIdx);

    % Optimize W for mean average precision (the 'map' loss)
    C = 1e-2;
    fprintf('Training with C=%.2e, Delta=mAP\n', C);
    [W, ~, Diagnostics] = mlr_train(Xtrain, Ytrain, C, 'map');
    % [W, ~, Diagnostics] = mlr_train_primal(Xtrain, Ytrain, C, 'map');

    % The calls below are deliberately left without a trailing semicolon so
    % that MATLAB echoes their results to the command window.
    % NOTE(review): the second argument (3) is presumably the number of
    % nearest neighbours used by mlr_test -- confirm against mlr_test.
    disp('Test performance in the native (normalized) metric');
    mlr_test(eye(d), 3, Xtrain, Ytrain, Xtest, Ytest)

    disp('Test performance with MLR metric');
    mlr_test(W, 3, Xtrain, Ytrain, Xtest, Ytest)

    % Scatter-plot
    figure;
    subplot(1,2,1), drawData(eye(d), Xtrain, Ytrain, Xtest, Ytest), title('Native metric (z-scored)');
    subplot(1,2,2), drawData(W, Xtrain, Ytrain, Xtest, Ytest), title('Learned metric (MLR-mAP)');

    Diagnostics

end
41
42
function drawData(W, Xtrain, Ytrain, Xtest, Ytest)
% DRAWDATA  Scatter-plot training and test points in a 2-D projection under W.
%
%   W       metric; a column vector is expanded to a diagonal matrix
%   Xtrain  training points, one per column
%   Ytrain  training labels (compared against integer class indices)
%   Xtest   test points, one per column
%   Ytest   test labels
%
%   The points are embedded using the two leading eigen-directions of the
%   Gram matrix Z' * W * Z, where Z = [Xtrain Xtest], and drawn into the
%   current axes ('+' for training points, 'o' for test points).

    n = length(Ytrain);

    if size(W, 2) == 1
        W = diag(W);
    end

    % PCA the learned metric
    Z = [Xtrain Xtest];
    A = Z' * W * Z;
    % Symmetrize to remove round-off asymmetry so eig returns real output.
    A = (A + A') / 2;

    [v, d] = eig(A);

    % eig does not promise any particular eigenvalue order (and is usually
    % ascending for symmetric input), so pick the two largest explicitly.
    [lambda, order] = sort(real(diag(d)), 'descend');
    % Clamp tiny negative eigenvalues caused by round-off; their square
    % root would otherwise make the coordinates complex.
    lambda = max(lambda(1:2), 0);
    L = diag(sqrt(lambda)) * v(:, order(1:2))';

    % Draw training points
    hold on;
    trmarkers = {'b+', 'r+', 'g+'};
    tsmarkers = {'bo', 'ro', 'go'};
    nMarkers = numel(trmarkers);
    for i = min(Ytrain):max(Ytrain)
        % Cycle through the marker styles so labels outside 1..3 do not
        % index past the end of the marker lists; labels 1..3 map to
        % markers 1..3 exactly as before.
        k = mod(i - 1, nMarkers) + 1;

        points = find(Ytrain == i);
        scatter(L(1,points), L(2,points), trmarkers{k});

        % Test points follow the n training columns in Z.
        points = n + find(Ytest == i);
        scatter(L(1,points), L(2,points), tsmarkers{k});
    end
    legend({'Training', 'Test'});
end