wolffd@0
|
function [Y, Loss] = separationOracleKNN(q, D, pos, neg, k)
%
%   [Y,Loss]  = separationOracleKNN(q, D, pos, neg, k)
%
%   Separation oracle for the k-nearest-neighbor loss: finds the ranking
%   that maximizes (normalized discriminant score - KNN indicator) over
%   all precision levels 0, 1/k, ..., 1 in the top k.
%
%       q       = index of the query point
%       D       = the current distance matrix (may be sparse; the column
%                 D(:,q) is read for the rows listed in pos and neg)
%       pos     = indices of relevant results for q
%       neg     = indices of irrelevant results for q
%       k       = length of the list to consider; clipped to
%                 min(k, numel(pos), numel(neg))
%
%   Y is an n-by-1 ranking (n = numel(pos) + numel(neg)) corresponding
%   to the maximally violated constraint: Y(r) is the index (taken from
%   pos or neg) of the document placed at rank r.
%
%   Loss is the loss for Y. This is a 0/1 majority-vote loss, not
%   1-Prec@k: Loss = 0 if the selected precision level in the top k is
%   strictly greater than 1/2, and Loss = 1 otherwise.
%
%   Raises MLR:separationOracleKNN:inconsistentRanking if the positive
%   placements do not leave exactly numel(neg) free slots in Y.
%
%   NOTE(review): relies on the external helper binarysearch(), which is
%   not defined in this file. From its use below it presumably returns,
%   for each score in its first argument, how many entries of the
%   descending-sorted second argument rank ahead of it -- confirm.

    % First, sort the documents in descending order of W'Phi(q,x)
    % Phi = - (X(q) - X(x)) * (X(q) - X(x))'
    % (i.e. the score of a document is its negated distance to q)

    % Sort the positive documents
    ScorePos        = - D(pos,q);
    [Vpos, Ipos]    = sort(full(ScorePos'), 'descend');
    Ipos            = pos(Ipos);

    % Sort the negative documents
    ScoreNeg        = - D(neg,q);
    [Vneg, Ineg]    = sort(full(ScoreNeg'), 'descend');
    Ineg            = neg(Ineg);

    % Now, solve the DP for the interleaving

    numPos  = length(pos);
    numNeg  = length(neg);
    n       = numPos + numNeg;

    % Prefix sums of the sorted negative scores, used for the
    % discriminant updates below
    cVneg   = cumsum(Vneg);

    % If we don't have enough positive (or negative) examples, scale k down
    k       = min([k, numPos, numNeg]);

    % Algorithm:
    %   For each precision score in 0, 1/k, 2/k, ... 1
    %       Calculate maximum discriminant score for that precision level
    %
    % KNN(a) is true when precision level (a-1)/k is a strict majority.
    KNN             = (0:(1/k):1)' > 0.5;
    Discriminant    = zeros(k+1, 1);

    % NegsBefore(i,a): number of negatives ranked before the i'th
    % (sorted) positive when there are a-1 positives in the top k.
    NegsBefore      = zeros(numPos,k+1);

    % For 0 precision, all positives go after the first k negatives
    NegsBefore(:,1) = k + binarysearch(Vpos, Vneg(k+1:end));
    Discriminant(1) = Vpos * (numNeg - 2 * NegsBefore(:,1)) + numPos * cVneg(end) ...
                    - 2 * sum(cVneg(NegsBefore((NegsBefore(:,1) > 0),1)));

    % For precision (a-1)/k, move the (a-1)'th positive doc
    % in among the top k-(a-1) negative docs
    for a = 2:(k+1)
        NegsBefore(:,a) = NegsBefore(:,a-1);

        % We have a-1 positives, and k - (a-1) negatives
        NegsBefore(a-1, a) = binarysearch(Vpos(a-1), Vneg(1:(k-a+1)));

        % There were NegsBefore(a-1,a-1) negatives before positive (a-1)
        % Now there are NegsBefore(a-1,a)

        Discriminant(a) = Discriminant(a-1) ...
                        + 2 * (NegsBefore(a-1,a-1) - NegsBefore(a-1,a)) * Vpos(a-1);

        % cVneg cannot be indexed at 0, hence the guards
        if NegsBefore(a-1,a-1) > 0
            Discriminant(a) = Discriminant(a) + 2 * cVneg(NegsBefore(a-1,a-1));
        end
        if NegsBefore(a-1,a) > 0
            Discriminant(a) = Discriminant(a) - 2 * cVneg(NegsBefore(a-1,a));
        end
    end

    % Normalize discriminant scores
    Discriminant    = Discriminant / (numPos * numNeg);

    % Pick the precision level with the largest (score - KNN indicator);
    % only the arg-max x is used afterwards.
    [s, x]          = max(Discriminant - KNN);

    % Now we know that there are x-1 relevant docs in the max ranking
    % Construct Y from NegsBefore(:,x)

    % The i'th positive sits at rank i + (number of negatives before it)
    Y               = nan(n,1);
    Y((1:numPos)' + NegsBefore(:,x)) = Ipos;

    % The remaining (NaN) slots must be exactly the negatives. Fail with
    % a clear error instead of dropping into the interactive debugger.
    numFree         = sum(isnan(Y));
    if numFree ~= length(Ineg)
        error('MLR:separationOracleKNN:inconsistentRanking', ...
              'Inconsistent ranking: %d free slots for %d negative documents.', ...
              numFree, length(Ineg));
    end
    Y(isnan(Y))     = Ineg;

    % Compute loss for this list
    Loss            = 1 - KNN(x);
end
|
wolffd@0
|
98
|