Mercurial > hg > camir-aes2014
comparison toolboxes/distance_learning/mlr/rmlr_demo.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children | |
comparison legend: equal, deleted, inserted, replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function rmlr_demo()
%RMLR_DEMO Compare R-MLR and MLR metric learning on the Wine data plus noise.
%   Loads X (one feature per row, one sample per column) and Y (labels)
%   from Wine.mat. Appends NOISEDIM rows of correlated Gaussian noise and
%   z-scores every feature. Splits the samples 80/20 into training and
%   test sets, then learns one metric with rmlr_train and one with
%   mlr_train.
%   Prints mlr_test results for:
%     - the native (identity) metric,
%     - the R-MLR metric,
%     - the MLR metric.
%   Finally it plots a 2-D embedding of the data under each metric and
%   images of both learned W matrices, and displays both Diagnostics
%   structures.

disp('Loading Wine data');
% Load into a struct rather than creating X and Y implicitly in the
% function workspace, so both names are explicit variables.
wine = load('Wine');
X = wine.X;
Y = wine.Y;

noisedim = 96;
[d, n] = size(X);
d = d + noisedim;

% Create a random covariance matrix Sigma = R'*R and draw correlated noise.
% (Named Sigma rather than "var" so the built-in var() is not shadowed.)
R = randn(noisedim);
Sigma = R' * R;
noise = sqrtm(Sigma) * randn(noisedim, n);
X = [X; noise];

% z-score the input dimensions (each row of X is one feature)
disp('z-scoring features');
X = zscore(X')';

disp('Generating a 80/20 training/test split');
ntrain = floor(0.8 * n);
P = randperm(n);
trainIdx = P(1:ntrain);
testIdx = P(ntrain+1:end);
Xtrain = X(:, trainIdx);
Ytrain = Y(trainIdx);
Xtest = X(:, testIdx);
Ytest = Y(testIdx);

C = 1e2;
lam = 0.5;

fprintf('Training with C=%.2e, Delta=MAP\n', C);
% learn metric with R-MLR
% NOTE(review): the meaning of the positional arguments 3, 1, 0, 0 is
% defined by rmlr_train's signature - confirm there.
[W_rmlr, ~, Diagnostics_rmlr] = rmlr_train(Xtrain, Ytrain, C, 'map', 3, 1, 0, 0, lam);

% learn metric with MLR
[W_mlr, ~, Diagnostics_mlr] = mlr_train(Xtrain, Ytrain, C, 'map');

% Results are intentionally left unsuppressed (no semicolon) so they print.
disp('Test performance in the native (normalized) metric');
mlr_test(eye(d), 3, Xtrain, Ytrain, Xtest, Ytest)

disp('Test performance with R-MLR metric');
mlr_test(W_rmlr, 3, Xtrain, Ytrain, Xtest, Ytest)

disp('Test performance with MLR metric');
mlr_test(W_mlr, 3, Xtrain, Ytrain, Xtest, Ytest)

% Scatter-plot of a 2-D embedding under each metric
figure;
subplot(1,3,1), drawData(eye(d), Xtrain, Ytrain, Xtest, Ytest), title('Native metric (z-scored)');
subplot(1,3,2), drawData(W_mlr, Xtrain, Ytrain, Xtest, Ytest), title('Learned metric (MLR)');
subplot(1,3,3), drawData(W_rmlr, Xtrain, Ytrain, Xtest, Ytest), title('Learned metric (RMLR)');

% Images of the learned metric matrices
figure;
subplot(121), imagesc(W_mlr), title('W: MLR');
subplot(122), imagesc(W_rmlr), title('W: RMLR');

Diagnostics_rmlr
Diagnostics_mlr

end
58 | |
59 | |
function drawData(W, Xtrain, Ytrain, Xtest, Ytest)
%DRAWDATA Scatter-plot training and test points in 2-D under metric W.
%   W is either a square matrix or a column vector. A column vector is
%   expanded with diag() into a diagonal metric.
%   Xtrain and Xtest hold one sample per column; Ytrain and Ytest hold the
%   corresponding labels.
%   The embedding comes from the Gram matrix A = Z'*W*Z of all points,
%   Z = [Xtrain Xtest]. Its two LARGEST eigenpairs give the coordinates
%   L = sqrt(Lambda) * V'.
%   No re-centering is done. NOTE(review): this assumes the rows of Z are
%   already mean-centred; rmlr_demo passes z-scored data.
%   Training points are drawn with '+' markers and test points with 'o'
%   markers, with one colour per class (colours cycle if there are more
%   than three classes).

n = length(Ytrain);

if size(W, 2) == 1
    W = diag(W);
end

% PCA the learned metric via the Gram matrix of all points.
Z = [Xtrain Xtest];
A = Z' * W * Z;
% Symmetrize: rounding can leave A slightly asymmetric. In that case eig
% uses the general solver, which may return complex, unsorted results.
A = (A + A') / 2;
[V, D] = eig(A);

% eig returns symmetric eigenvalues in ascending order, so select the two
% largest explicitly (taking the first two would pick the smallest,
% near-zero directions).
[lambda, order] = sort(diag(D), 'descend');
top = order(1:2);
% Clamp tiny negative eigenvalues caused by rounding so sqrt stays real.
lambda = max(lambda(1:2), 0);
L = diag(sqrt(lambda)) * V(:, top)';

% Draw training ('+') and test ('o') points, one colour per class.
hold on;
trmarkers = {'b+', 'r+', 'g+'};
tsmarkers = {'bo', 'ro', 'go'};
labels = unique([Ytrain(:); Ytest(:)]);
for k = 1:numel(labels)
    % Cycle through the marker styles instead of indexing them by label,
    % so labels need not be exactly 1..3.
    style = mod(k - 1, numel(trmarkers)) + 1;
    points = find(Ytrain == labels(k));
    scatter(L(1, points), L(2, points), trmarkers{style});
    % Test points follow the n training points in the columns of L.
    points = n + find(Ytest == labels(k));
    scatter(L(1, points), L(2, points), tsmarkers{style});
end
legend({'Training', 'Test'});
hold off;
end