wolffd@0
|
function [Y, Loss] = separationOracleMAP(q, D, pos, neg, k)
%
% [Y,Loss] = separationOracleMAP(q, D, pos, neg, k)
%
%   q   = index of the query point
%   D   = the current distance matrix; only the column D(:,q) is read
%         (full() is applied to it, so D may be sparse)
%   pos = indices of relevant results for q
%   neg = indices of irrelevant results for q
%   k   = length of the list to consider (unused in MAP)
%
%   Y    is an n-by-1 ranking (n = length(pos) + length(neg)) of the
%        indices in pos and neg corresponding to the maximally violated
%        constraint: Y(i) is the index of the result placed at rank i.
%
%   Loss is the loss for Y, in this case, 1-AP(Y)
%
%   If neg is empty, Y is the relevant results sorted by score and Loss
%   is 0 (every all-relevant list has AP = 1). When neg is non-empty,
%   pos must be non-empty (the DP reads H(:,length(pos))).

    % First, sort the documents in descending order of W'Phi(q,x)
    % Phi = - (X(q) - X(x)) * (X(q) - X(x))'
    % i.e. the score of a document is its negated distance to q.

    % Sort the positive documents
    ScorePos        = - D(pos,q);
    [Vpos, Ipos]    = sort(full(ScorePos'), 'descend');
    Ipos            = pos(Ipos);

    % Sort the negative documents
    ScoreNeg        = - D(neg,q);
    [Vneg, Ineg]    = sort(full(ScoreNeg'), 'descend');
    Ineg            = neg(Ineg);

    numPos  = length(pos);
    numNeg  = length(neg);
    n       = numPos + numNeg;

    % With no irrelevant documents there is nothing to interleave: every
    % ranking has AP = 1. Returning early also avoids the division by
    % numNeg = 0 inside scoreChangeMatrix, which would fill H with Inf/NaN.
    if numNeg == 0
        Y    = Ipos(:);
        Loss = 0;
        return;
    end

    % Now, solve the DP for the interleaving

    % Pre-generate the precision scores:
    % H(b,a) = 1 / ((a-1) - b) for b >= a, and 0 above the diagonal
    H = tril(1./bsxfun(@minus, 0:(numPos-1), (1:n)'));

    % Padded cumulative Vneg: pcVneg(j+1) = sum of the j highest-scoring negatives
    pcVneg = cumsum([0 Vneg]);

    % Generate the discriminant scores
    H = H + scoreChangeMatrix(Vpos, Vneg, n, pcVneg);

    % Back-pointers: P(b,a) is the position of relevant document a-1 on the
    % best path that places relevant document a at position b
    P = zeros(size(H));

    % Now recurse
    for a = 2:numPos

        % Running maximum (and its location) down the previous column
        [m, p] = cummax(H(:,a-1));

        % The best point is the previous column, up to b-1.
        % Force it to a column so the update stays elementwise whatever
        % orientation cummax returns; a row here would broadcast into a
        % matrix under implicit expansion.
        prevBest    = m(a-1:n-1);
        H(a:n,a)    = H(a:n,a) + (a-1)/a .* prevBest(:);
        P(a+1:n,a)  = p(a:n-1);
        P(a,a)      = a-1;
    end

    % Now reconstruct the ranking from the DP table
    Y = nan(n, 1);
    [m, p] = max(H(:,numPos));
    Y(p) = Ipos(numPos);

    % Follow the back-pointers to place the remaining relevant documents
    for a = numPos:-1:2
        p = P(p,a);
        Y(p) = Ipos(a-1);
    end

    % The unfilled slots take the negatives in descending score order
    Y(isnan(Y)) = Ineg;

    % Compute loss for this list
    Loss = 1 - AP(Y, pos);
end
|
wolffd@0
|
75
|
wolffd@0
|
function C = scoreChangeMatrix(Vpos, Vneg, n, pcVneg)
% C = scoreChangeMatrix(Vpos, Vneg, n, pcVneg)
%
% Builds the n-by-numPos matrix of normalized score changes for placing
% the a'th relevant document (column a) at position b (row b).
%
%   Vpos   = scores of the relevant documents, one per column of C
%            (assumed to be a row vector, since it is broadcast across rows)
%   Vneg   = scores of the irrelevant documents (only its length is used)
%   n      = total number of documents (number of rows of C)
%   pcVneg = padded cumulative sum of the negative scores, cumsum([0 Vneg])
%
% Inserting the a'th relevant document at position b:
% there are (b - (a - 1)) negative docs before a
% and (numNeg - (b - (a - 1))) negative docs after.
%
% The change in score is proportional to:
%
%   sum_{negative j} (Vpos(a) - Vneg(j)) * y_{aj}
%
%   =   (numNeg - (b - (a - 1))) * Vpos(a)       # Negatives after a
%     - (cVneg(end) - cVneg(b - (a - 1)))        # Weight of negs after a
%     - (b - (a - 1)) * Vpos(a)                  # Negatives before a
%     + cVneg(b - (a - 1))                       # Weight of negs before a
%
% Rearranged:
%
%     (numNeg - 2 * (b - a + 1)) * Vpos(a)
%   - cVneg(end) + 2 * cVneg(b - a + 1)
%
% Column a is then divided by a * numNeg. Valid a: 1:numPos, valid b: a:n;
% cells with b < a are infeasible and come out as -Inf.

    numPos = length(Vpos);
    numNeg = length(Vneg);

    % offset(b,a) = b - a + 1
    offset = bsxfun(@minus, (1:n)' + 1, 1:numPos);

    % Count term: (numNeg - 2*(b - a + 1)) * Vpos(a)
    C = bsxfun(@times, Vpos, numNeg - 2 * offset);

    % Clamp the offsets into the valid index range of pcVneg
    idx = min(max(offset, 1), length(pcVneg));

    % Cumulative-weight term. Reshaping to size(idx) keeps a single-column
    % idx (numPos == 1) from inheriting pcVneg's row orientation.
    negWeight = reshape(pcVneg(idx), size(idx));
    C = C + 2 * negWeight - pcVneg(end);

    % Normalize column a by a * numNeg
    C = bsxfun(@rdivide, C, (1:numPos) * numNeg);

    % -Inf out the infeasible region (b < a)
    infeasible = bsxfun(@gt, 1:numPos, (1:n)');
    C(infeasible) = C(infeasible) - Inf;
end
|
wolffd@0
|
124
|
wolffd@0
|
function x = AP(Y, pos)
% x = AP(Y, pos)
%
% Average precision of the ranking Y with respect to the relevant set pos.
%
%   Y   = ranked list of document indices (row or column vector);
%         Y(i) is the document at rank i
%   pos = indices of relevant documents
%
%   x   = mean of Prec@i over the ranks i that hold a relevant document
%         (NaN when no element of Y is in pos, as mean([]) is NaN)

    % Work on a column so the arithmetic below does not depend on Y's
    % orientation: a row Y would otherwise broadcast Prec into an n-by-n
    % matrix, and Prec(rel) would pick the wrong elements.
    Y = Y(:);

    % Indicator for relevant documents
    rel = ismember(Y, pos);

    % Prec@k for all k
    Prec = cumsum(rel) ./ (1:numel(Y))';

    % Prec@k averaged over relevant positions
    x = mean(Prec(rel));
end
|