function Perf = mlr_test(W, test_k, Xtrain, Ytrain, Xtest, Ytest)
% Perf = mlr_test(W, test_k, Xtrain, Ytrain, Xtest, Ytest)
%
%   W       = d-by-d positive semi-definite matrix
%   test_k  = vector of k-values to use for KNN/Prec@k/NDCG
%   Xtrain  = d-by-n matrix of training data
%   Ytrain  = n-by-1 vector of training labels
%             OR
%             n-by-2 cell array where
%               Y{q,1} contains relevant indices (in 1..n) for point q
%               Y{q,2} contains irrelevant indices (in 1..n) for point q
%   Xtest   = d-by-m matrix of testing data
%   Ytest   = m-by-1 vector of testing labels, or m-by-2 cell array
%
%   If Ytest is omitted, leave-one-out validation is run on the training
%   set instead; in that case a supplied Xtest is read as a list (or
%   logical mask) of training indices to evaluate on.
%
% The output structure Perf contains the mean score for:
%   AUC, KNN, Prec@k, MAP, MRR, NDCG,
% as well as the effective dimensionality of W, and
% the best-performing k-value for KNN, Prec@k, and NDCG.
%
% KNN is only computed when Y are labels (not cell arrays). Queries whose
% ranking has no relevant or no irrelevant result are excluded from the
% AUC, MAP, MRR, Prec@k and NDCG averages.

Perf = struct( ...
    'AUC',            [], ...
    'KNN',            [], ...
    'PrecAtK',        [], ...
    'MAP',            [], ...
    'MRR',            [], ...
    'NDCG',           [], ...
    'dimensionality', [], ...
    'KNNk',           [], ...
    'PrecAtKk',       [], ...
    'NDCGk',          [] ...
);

[d, nTrain, nKernel] = size(Xtrain);

% Compute dimensionality of the learned metric
Perf.dimensionality = mlr_test_dimension(W, nTrain, nKernel);

% The scorers iterate "for k = test_k"; force a row so a column vector of
% k-values is not consumed as a single (vector-valued) iteration.
test_k = test_k(:)';

if nargin > 5
    % Knock out the points with no labels
    if ~iscell(Ytest)
        Ibad = find(isnan(Ytrain));
        Xtrain(:,Ibad,:) = inf;
    end

    % Build the distance matrix (only the ranking I is used for scoring)
    [D, I] = mlr_test_distance(W, Xtrain, Xtest); %#ok<ASGLU>
else
    % Leave-one-out validation

    if nargin > 4
        % In this case, Xtest is a subset of training indices to test on
        testRange = Xtest;
    else
        testRange = 1:nTrain;
    end
    if islogical(testRange)
        testRange = find(testRange);
    end
    testRange = testRange(:)';

    Xtest = Xtrain(:,testRange,:);
    if iscell(Ytrain)
        % Keep both columns (relevant / irrelevant) of the cell array;
        % linear indexing would keep only the first column.
        Ytest = Ytrain(testRange,:);
    else
        Ytest = Ytrain(testRange);
    end

    % compute self-distance (only the ranking I is used for scoring)
    [D, I] = mlr_test_distance(W, Xtrain, Xtest); %#ok<ASGLU>

    % Clear out each query's self-link. Locate it by index instead of
    % assuming it sorted first: an exact duplicate (or round-off) can tie
    % with or beat the self-distance of 0. Every column of I is a
    % permutation of 1:nTrain, so exactly one entry per column is removed.
    isSelf = bsxfun(@eq, I, testRange);
    I = reshape(I(~isSelf), nTrain - 1, numel(testRange));
end

% Never ask for more neighbours than a ranking contains (leave-one-out
% rankings are one shorter than the training set).
test_k = min(test_k, size(I, 1));

nTest = size(I, 2);

% Compute label agreement
if ~iscell(Ytest)
    Ytest  = Ytest(:);
    % reshape guards the degenerate case where I is a single row
    Labels = reshape(Ytrain(I), size(I));
    Agree  = bsxfun(@eq, Ytest', Labels);

    % We only compute KNN error if Y are labels
    [Perf.KNN, Perf.KNNk] = mlr_test_knn(Labels, Ytest, test_k);
else
    Agree = zeros(size(I));
    for i = 1:nTest
        Agree(:,i) = ismember(I(:,i), Ytest{i,1});
    end
end

% Drop queries with no relevant or no irrelevant results: AUC divides by
% both counts and MAP/MRR are undefined without a relevant result.
Agree = reduceAgreement(Agree);

% Compute AUC score
Perf.AUC = mlr_test_auc(Agree);

% Compute MAP score
Perf.MAP = mlr_test_map(Agree);

% Compute MRR score
Perf.MRR = mlr_test_mrr(Agree);

% Compute prec@k
[Perf.PrecAtK, Perf.PrecAtKk] = mlr_test_preck(Agree, test_k);

% Compute NDCG score
[Perf.NDCG, Perf.NDCGk] = mlr_test_ndcg(Agree, test_k);

end

function [D,I] = mlr_test_distance(W, Xtrain, Xtest)
% [D, I] = mlr_test_distance(W, Xtrain, Xtest)
%
% Rank the training points by distance (under the metric W) to each test
% point.
%
%   D = nTrain-by-nTest distances, each column sorted ascending, so that
%       D(r,j) is the distance from test point j to training point I(r,j)
%   I = nTrain-by-nTest training indices, nearest first
%
% W layouts (from the original design notes):
%   Raw:                     W = []
%
%   Linear, full:            W = d-by-d
%   Single Kernel, full:     W = n-by-n
%   MKL, full:               W = n-by-n-by-m
%
%   Linear, diagonal:        W = d-by-1
%   Single Kernel, diagonal: W = n-by-1
%   MKL, diag:               W = n-by-m
%   MKL, diag-off-diag:      W = m-by-m-by-n
%
% NOTE(review): only W = [], size(W,1:2) == [d d] and
% size(W,1:2) == [d nKernel] are dispatched below; any other layout
% raises an error.

[d, nTrain, nKernel] = size(Xtrain);
nTest = size(Xtest, 2);

if isempty(W)
    % W = [] => native euclidean distances
    D = mlr_test_distance_raw(Xtrain, Xtest);

elseif size(W,1) == d && size(W,2) == d
    % We're in a full-projection case
    D = setDistanceFullMKL([Xtrain Xtest], W, nTrain + (1:nTest), 1:nTrain);

elseif size(W,1) == d && size(W,2) == nKernel
    % We're in a simple diagonal case
    D = setDistanceDiagMKL([Xtrain Xtest], W, nTrain + (1:nTest), 1:nTrain);

else
    error('Cannot determine metric mode.');
end

% Keep only the train-vs-test block; full() in case the helper returned a
% sparse matrix.
D = full(D(1:nTrain, nTrain + (1:nTest)));

% Return the sorted distances together with the ranking so that D and I
% index the same neighbours.
[D, I] = sort(D, 1);
end



function dimension = mlr_test_dimension(W, nTrain, nKernel)
% dimension = mlr_test_dimension(W, nTrain, nKernel)
%
% Effective dimensionality of the learned metric W: the smallest number of
% spectral components, taken largest first, whose combined magnitude is at
% least 95% of the total.
%
%   W       = metric. If size(W,1) == size(W,2), each slice W(:,:,i) for
%             i = 1..nKernel is symmetrized and its eigenvalue magnitudes
%             are used; otherwise the entries of W are used directly
%             (diagonal weights).
%   nTrain  = unused; kept for interface compatibility
%   nKernel = number of slices of a square W to examine
%
%   dimension = effective dimension, or 0 if none is found (e.g. W empty
%               or all zeros)
%
% W layouts (from the original design notes):
%   Raw:                     W = []
%
%   Linear, full:            W = d-by-d
%   Single Kernel, full:     W = n-by-n
%   MKL, full:               W = n-by-n-by-m
%
%   Linear, diagonal:        W = d-by-1
%   Single Kernel, diagonal: W = n-by-1
%   MKL, diag:               W = n-by-m
%   MKL, diag-off-diag:      W = m-by-m-by-n

if size(W,1) == size(W,2)
    spectrum = [];
    for i = 1:nKernel
        % symmetrize before eig to absorb numerical asymmetry
        ev = eig(0.5 * (W(:,:,i) + W(:,:,i)'));
        spectrum = [spectrum ; abs(real(ev))]; %#ok<AGROW>
    end
else
    spectrum = W(:);
end

% Count from the largest component down. eig returns eigenvalues in
% ascending order, so without this sort the count would start from the
% smallest components and overstate the dimension.
spectrum = sort(spectrum, 'descend');

mass = cumsum(spectrum) / sum(spectrum);
dimension = find(mass >= 0.95, 1);
if isempty(dimension)
    dimension = 0;
end
end

function [NDCG, NDCGk] = mlr_test_ndcg(Agree, test_k)
% [NDCG, NDCGk] = mlr_test_ndcg(Agree, test_k)
%
% Normalized discounted cumulative gain, maximized over the cutoffs in
% test_k. For a cutoff k, rank r (1 <= r <= k) is weighted by
% 1/log2(max(2, r)), ranks beyond k get weight 0, and the weights are
% scaled to sum to 1, so a query whose top k results are all relevant
% scores 1.
%
%   Agree  = rank-by-query relevance matrix (nonzero = relevant)
%   test_k = candidate cutoffs (each must be <= size(Agree, 1))
%
%   NDCG   = best mean score over queries
%   NDCGk  = first cutoff in test_k attaining NDCG (0 if none)

nTrain = size(Agree, 1);

NDCG  = -Inf;
NDCGk = 0;
for k = test_k(:)'
    % Rebuild the weights from scratch for every cutoff. Reusing a vector
    % that was normalized on a previous iteration mixes scaled and unscaled
    % weights, and leaves stale weights past k when k decreases.
    Discount = zeros(1, nTrain);
    Discount(1:k) = 1 ./ log2(max(2, 1:k));
    Discount = Discount / sum(Discount);

    b = mean(Discount * Agree);
    if b > NDCG
        NDCG = b;
        NDCGk = k;
    end
end
end

function [PrecAtK, PrecAtKk] = mlr_test_preck(Agree, test_k)
% [PrecAtK, PrecAtKk] = mlr_test_preck(Agree, test_k)
%
% Precision at k, maximized over the candidate cutoffs in test_k.
%
%   Agree    = rank-by-query relevance matrix (nonzero = relevant)
%   test_k   = candidate cutoffs (each must be <= size(Agree, 1))
%
%   PrecAtK  = best mean precision over queries (-Inf if no cutoff scores)
%   PrecAtKk = first cutoff attaining PrecAtK (0 if none)

bestScore  = -Inf;
bestCutoff = 0;

for cutoff = test_k
    topRows   = Agree(1:cutoff, :);
    perQuery  = mean(topRows, 1);     % fraction relevant in each query's top results
    candidate = mean(perQuery);

    % strict comparison: ties keep the earlier cutoff
    if candidate > bestScore
        bestScore  = candidate;
        bestCutoff = cutoff;
    end
end

PrecAtK  = bestScore;
PrecAtKk = bestCutoff;
end

function [KNN, KNNk] = mlr_test_knn(Labels, Ytest, test_k)
% [KNN, KNNk] = mlr_test_knn(Labels, Ytest, test_k)
%
% k-nearest-neighbour classification accuracy by majority vote,
% maximized over the candidate neighbourhood sizes in test_k.
%
%   Labels = rank-by-query matrix of neighbour labels
%   Ytest  = column vector of true labels, one per query
%   test_k = candidate neighbourhood sizes
%
%   KNN    = best accuracy (-Inf if no size scores)
%   KNNk   = first size attaining KNN (0 if none)
%
% FIXME (2012-02-07, Brian McFee <bmcfee@cs.ucsd.edu>):
%   the vote should discount NaN labels.

truth    = Ytest';
bestAcc  = -Inf;
bestSize = 0;

for kk = test_k
    % majority label per query; mode breaks ties toward the smallest label
    votes    = mode(Labels(1:kk, :), 1);
    accuracy = mean(votes == truth);

    if accuracy > bestAcc
        bestAcc  = accuracy;
        bestSize = kk;
    end
end

KNN  = bestAcc;
KNNk = bestSize;
end

function MAP = mlr_test_map(Agree)
% MAP = mlr_test_map(Agree)
%
% Mean average precision over queries.
%
%   Agree = rank-by-query relevance matrix (nonzero = relevant)
%
%   MAP   = mean over queries of the average precision at each relevant
%           rank. A query with no relevant result yields 0/0 = NaN, which
%           propagates into MAP.

ranks     = (1:size(Agree, 1))';
precision = bsxfun(@rdivide, cumsum(Agree, 1), ranks);   % precision at each rank

hitPrecision = sum(precision .* Agree, 1);
nRelevant    = sum(Agree, 1);

MAP = mean(hitPrecision ./ nRelevant);
end

function MRR = mlr_test_mrr(Agree)
% MRR = mlr_test_mrr(Agree)
%
% Mean reciprocal rank of the first relevant result.
%
%   Agree = rank-by-query relevance matrix (nonzero = relevant)
%
%   MRR   = mean over queries of 1 / (rank of first relevant result).
%           A query with no relevant result contributes 0. The previous
%           1/find(...) formulation produced [] there and emptied the whole
%           score. MRR is NaN when Agree has no columns.

isRelevant = (Agree ~= 0);

% On a logical column, max returns true plus the first true position, or
% false (and position 1) when the column has no true entries.
[hasRelevant, firstRank] = max(isRelevant, [], 1);

MRR = mean(double(hasRelevant) ./ firstRank);
end

function AUC = mlr_test_auc(Agree)
% AUC = mlr_test_auc(Agree)
%
% Area under the query-averaged ROC curve.
%
%   Agree = rank-by-query relevance matrix (nonzero = relevant)
%
% Per query, the true/false positive rates after each rank are the running
% counts of relevant/irrelevant results divided by their totals. Both rates
% are averaged across queries, and the area is accumulated as rectangles of
% width (step in mean FPR) and height (mean TPR at that rank).

hits   = cumsum(Agree, 1);
misses = cumsum(~Agree, 1);

meanTPR = mean(bsxfun(@rdivide, hits,   hits(end, :)),   2);
meanFPR = mean(bsxfun(@rdivide, misses, misses(end, :)), 2);

widths = diff([0; meanFPR])';
AUC    = widths * meanTPR;
end


function D = mlr_test_distance_raw(Xtrain, Xtest)
% D = mlr_test_distance_raw(Xtrain, Xtest)
%
% Unweighted (identity-metric) distances between test and training points,
% summed over every slice of the third dimension of Xtrain/Xtest.
%
% D is whatever setDistanceDiag returns for the stacked [Xtrain Xtest]
% data. NOTE(review): the caller treats it as indexable by
% (train index, nTrain + test index); confirm against setDistanceDiag.

[d, nTrain, nKernel] = size(Xtrain);
nTest = size(Xtest, 2);

unitWeights = ones(d, 1);
testIdx     = nTrain + (1:nTest);
trainIdx    = 1:nTrain;

% Not in kernel mode: accumulate the per-slice distances directly
D = 0;
for slice = 1:nKernel
    stacked = [Xtrain(:, :, slice) Xtest(:, :, slice)];
    D = D + setDistanceDiag(stacked, unitWeights, testIdx, trainIdx);
end
end

function A = reduceAgreement(Agree)
% A = reduceAgreement(Agree)
%
% Keep only the query columns of Agree that contain at least one relevant
% (nonzero) and at least one irrelevant (zero) entry. Ranking metrics such
% as AUC divide by both counts.

keep = (sum(Agree, 1) > 0) & (sum(~Agree, 1) > 0);
A = Agree(:, keep);
end