function rmlr_demo()
%RMLR_DEMO Compare R-MLR and MLR metric learning on noise-augmented Wine data.
%
%   rmlr_demo() takes no arguments and returns nothing. It
%     1. loads X (features-by-samples) and Y (labels) from the Wine data file,
%     2. appends 96 correlated Gaussian noise features to X,
%     3. z-scores every feature,
%     4. draws a random 80/20 training/test split,
%     5. trains one metric with rmlr_train and one with mlr_train,
%     6. prints the mlr_test results for the identity, R-MLR and MLR metrics,
%     7. opens a figure with three scatter plots (via drawData) and a figure
%        showing both learned matrices, then prints both diagnostics structs.
%
%   The random number generator is not seeded here, so results differ from
%   run to run.

    disp('Loading Wine data');
    % Load into a struct so X and Y are ordinary, statically visible
    % variables instead of being created dynamically in this workspace.
    wine = load('Wine');
    X = wine.X;
    Y = wine.Y;

    noisedim = 96;
    [d, n] = size(X);
    d = d + noisedim;          % dimensionality after the noise is appended

    % Create a random symmetric positive semi-definite covariance matrix and
    % draw noise with that covariance. (Named noise_cov rather than "var" so
    % the builtin var function is not shadowed.)
    noise_cov = randn(noisedim);
    noise_cov = noise_cov' * noise_cov;
    noise = sqrtm(noise_cov) * randn(noisedim, n);
    X = [X; noise];

    % z-score the input dimensions (zscore works column-wise, hence the
    % transposes around it)
    disp('z-scoring features');
    X = zscore(X')';

    disp('Generating a 80/20 training/test split');
    P = randperm(n);
    ntrain = floor(0.8 * n);
    train_idx = P(1:ntrain);
    test_idx = P((ntrain + 1):end);

    Xtrain = X(:, train_idx);
    Ytrain = Y(train_idx);
    Xtest = X(:, test_idx);
    Ytest = Y(test_idx);

    C = 1e2;
    lam = 0.5;

    fprintf('Training with C=%.2e, Delta=MAP\n', C);

    % Learn metric with R-MLR. The positional options after the loss name
    % are passed through unchanged; see rmlr_train for their meaning.
    % The second output is not used by this demo.
    [W_rmlr, Xi_rmlr, Diagnostics_rmlr] = rmlr_train(Xtrain, Ytrain, C, 'map', 3, 1, 0, 0, lam);

    % Learn metric with MLR (second output likewise unused)
    [W_mlr, Xi_mlr, Diagnostics_mlr] = mlr_train(Xtrain, Ytrain, C, 'map');

    % Second argument of every mlr_test call below.
    % NOTE(review): presumably the neighbourhood size used for evaluation;
    % confirm against mlr_test.
    test_k = 3;

    % The mlr_test calls are deliberately left without a trailing semicolon
    % so that their results are echoed to the command window.
    disp('Test performance in the native (normalized) metric');
    mlr_test(eye(d), test_k, Xtrain, Ytrain, Xtest, Ytest)

    disp('Test performance with R-MLR metric');
    mlr_test(W_rmlr, test_k, Xtrain, Ytrain, Xtest, Ytest)

    disp('Test performance with MLR metric');
    mlr_test(W_mlr, test_k, Xtrain, Ytrain, Xtest, Ytest)

    % Scatter-plot
    figure;
    subplot(1,3,1), drawData(eye(d), Xtrain, Ytrain, Xtest, Ytest), title('Native metric (z-scored)');
    subplot(1,3,2), drawData(W_mlr, Xtrain, Ytrain, Xtest, Ytest), title('Learned metric (MLR)');
    subplot(1,3,3), drawData(W_rmlr, Xtrain, Ytrain, Xtest, Ytest), title('Learned metric (RMLR)');

    % Show the learned matrices side by side
    figure;
    subplot(121), imagesc(W_mlr), title('W: MLR');
    subplot(122), imagesc(W_rmlr), title('W: RMLR');

    % Echo the diagnostics structs (no semicolon on purpose)
    Diagnostics_rmlr
    Diagnostics_mlr

end
function drawData(W, Xtrain, Ytrain, Xtest, Ytest)
%DRAWDATA Scatter-plot training and test points in a 2-D projection under W.
%
%   drawData(W, Xtrain, Ytrain, Xtest, Ytest) draws into the current axes.
%
%   W       metric matrix, or a single column that is expanded with diag(W)
%   Xtrain  training data, one column per point
%   Ytrain  training labels, one per training column
%   Xtest   test data, one column per point
%   Ytest   test labels, one per test column
%
%   The points are embedded through the eigen-decomposition of the Gram
%   matrix Z'*W*Z with Z = [Xtrain Xtest], and the two components with the
%   largest eigenvalues are plotted. Training points use '+' markers and
%   test points use 'o' markers, colored by label.
%   Labels are expected to be integers; at least two points are required.

    n = length(Ytrain);

    % A single column is treated as the diagonal of the metric
    if size(W,2) == 1
        W = diag(W);
    end

    % PCA the learned metric
    Z = [Xtrain Xtest];
    A = Z' * W * Z;
    % Z'*W*Z is generally not exactly symmetric in floating point.
    % Symmetrizing keeps the decomposition real.
    A = (A + A') / 2;
    [V, D] = eig(A);

    % eig does not promise any particular eigenvalue order, so sort
    % explicitly and keep the two dominant components. Tiny negative
    % eigenvalues caused by round-off are clamped to zero so the square root
    % stays real.
    [evals, order] = sort(real(diag(D)), 'descend');
    top = order(1:2);
    scales = sqrt(max(evals(1:2), 0));
    L = diag(scales) * V(:, top)';

    % Draw training and test points, one color per label
    hold on;
    trmarkers = {'b+', 'r+', 'g+'};
    tsmarkers = {'bo', 'ro', 'go'};
    nstyles = numel(trmarkers);
    for label = min(Ytrain):max(Ytrain)
        % Labels 1..3 map to styles 1..3; other integer labels cycle
        % through the same styles instead of indexing out of range.
        style = mod(label - 1, nstyles) + 1;

        points = find(Ytrain == label);
        scatter(L(1,points), L(2,points), trmarkers{style});

        % Test points follow the n training columns in Z
        points = n + find(Ytest == label);
        scatter(L(1,points), L(2,points), tsmarkers{style});
    end
    legend({'Training', 'Test'});
end