annotate toolboxes/distance_learning/mlr/rmlr_demo.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function rmlr_demo()
% rmlr_demo  Compare R-MLR and MLR metric learning on the Wine data set
% padded with correlated Gaussian noise features.
%
%   Steps:
%     1. Load X (features, one column per point) and Y (labels) from Wine.
%     2. Append NOISEDIM noise features drawn with a random covariance.
%     3. z-score every feature.
%     4. Make a random 80/20 train/test split.
%     5. Learn a metric with rmlr_train and with mlr_train.
%     6. Print mlr_test results for the identity, R-MLR and MLR metrics.
%     7. Plot the 2-D embeddings and the learned W matrices.
%   Results vary between runs because the noise and the split are random.

disp('Loading Wine data');
% Load into a struct rather than creating variables implicitly in the
% function workspace. The rest of the demo relies on fields X and Y.
data = load('Wine');
X = data.X;
Y = data.Y;

noisedim = 96;
[d, n] = size(X);
d = d + noisedim;

% Random covariance Sigma = G'*G (symmetric PSD). Each noise column is
% sqrtm(Sigma)*randn, so it has covariance Sigma.
% The matrix is named Sigma, not "var", to avoid shadowing the built-in var().
G = randn(noisedim);
Sigma = G' * G;
noise = sqrtm(Sigma) * randn(noisedim, n);
X = [X; noise];

% z-score the input dimensions (zscore works column-wise, hence the transposes)
disp('z-scoring features');
X = zscore(X')';

disp('Generating a 80/20 training/test split');
P = randperm(n);
ntrain = floor(0.8 * n);
trainIdx = P(1:ntrain);
testIdx = P(ntrain+1:end);
Xtrain = X(:, trainIdx);
Ytrain = Y(trainIdx);
Xtest = X(:, testIdx);
Ytest = Y(testIdx);

C = 1e2;
lam = 0.5;

fprintf('Training with C=%.2e, Delta=MAP\n', C);
% Learn a metric with R-MLR. The positional arguments after 'map' are passed
% through unchanged from the original demo; see rmlr_train for their meaning.
[W_rmlr, ~, Diagnostics_rmlr] = rmlr_train(Xtrain, Ytrain, C, 'map', 3, 1, 0, 0, lam);

% Learn a metric with MLR (the second output is not used here).
[W_mlr, ~, Diagnostics_mlr] = mlr_train(Xtrain, Ytrain, C, 'map');

% The mlr_test calls are left unsuppressed so their results are printed.
disp('Test performance in the native (normalized) metric');
mlr_test(eye(d), 3, Xtrain, Ytrain, Xtest, Ytest)

disp('Test performance with R-MLR metric');
mlr_test(W_rmlr, 3, Xtrain, Ytrain, Xtest, Ytest)

disp('Test performance with MLR metric');
mlr_test(W_mlr, 3, Xtrain, Ytrain, Xtest, Ytest)

% Scatter-plot of the 2-D embeddings under each metric
figure;
subplot(1,3,1), drawData(eye(d), Xtrain, Ytrain, Xtest, Ytest), title('Native metric (z-scored)');
subplot(1,3,2), drawData(W_mlr, Xtrain, Ytrain, Xtest, Ytest), title('Learned metric (MLR)');
subplot(1,3,3), drawData(W_rmlr, Xtrain, Ytrain, Xtest, Ytest), title('Learned metric (RMLR)');

% Visualize the learned metric matrices
figure;
subplot(121), imagesc(W_mlr), title('W: MLR');
subplot(122), imagesc(W_rmlr), title('W: RMLR');

% Print the training diagnostics (intentionally unsuppressed)
Diagnostics_rmlr
Diagnostics_mlr

end
wolffd@0 58
wolffd@0 59
function drawData(W, Xtrain, Ytrain, Xtest, Ytest)
% drawData  Scatter-plot training and test points in the 2-D embedding
% induced by the metric W.
%
%   drawData(W, Xtrain, Ytrain, Xtest, Ytest)
%
%   W       - metric matrix, or a column vector taken as the diagonal of a
%             diagonal metric.
%   Xtrain  - training points, one column per point.
%   Ytrain  - training labels (one per column of Xtrain).
%   Xtest   - test points, one column per point.
%   Ytest   - test labels (one per column of Xtest).
%
%   The embedding eigendecomposes the Gram matrix A = Z'*W*Z, where
%   Z = [Xtrain Xtest], and keeps the two components with the LARGEST
%   eigenvalues. Training points are drawn with '+' and test points with
%   'o'. Each class gets its own colour; colours cycle if there are more
%   than three classes. Draws into the current axes with "hold on".

n = numel(Ytrain);

% A column-vector W describes a diagonal metric.
if size(W, 2) == 1
    W = diag(W);
end

% PCA the learned metric: eigendecompose the Gram matrix of all points.
Z = [Xtrain Xtest];
A = Z' * W * Z;
% Symmetrize. Rounding in Z'*W*Z can make A slightly non-symmetric, which
% sends eig to the general (unordered, possibly complex) solver. For the
% quadratic form this is equivalent to using (W + W')/2.
A = (A + A') / 2;
[V, D] = eig(A);

% eig does not return the leading components first (the symmetric solver
% sorts ascending), so explicitly select the two largest eigenvalues.
[lambda, order] = sort(real(diag(D)), 'descend');
% Clamp tiny negative eigenvalues caused by round-off so sqrt stays real.
lambda = max(lambda(1:2), 0);
L = diag(sqrt(lambda)) * V(:, order(1:2))';

% Draw training ('+') and test ('o') points, one colour per class.
hold on;
colours = {'b', 'r', 'g'};
labels = unique([Ytrain(:); Ytest(:)])';
for k = 1:numel(labels)
    c = colours{mod(k - 1, numel(colours)) + 1};
    points = find(Ytrain == labels(k));
    scatter(L(1,points), L(2,points), [c '+']);
    % Test points follow the n training points in the columns of L.
    points = n + find(Ytest == labels(k));
    scatter(L(1,points), L(2,points), [c 'o']);
end
legend({'Training', 'Test'});
end