Mercurial > hg > camir-aes2014
comparison toolboxes/distance_learning/mlr/separationOracle/separationOracleMAP.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [Y, Loss] = separationOracleMAP(q, D, pos, neg, k)
%
% [Y,Loss] = separationOracleMAP(q, D, pos, neg, k)
%
%   q   = index of the query point
%   D   = the current distance matrix; only column q is read
%         (full() is applied to it, so it may be sparse)
%   pos = indices of relevant results for q
%   neg = indices of irrelevant results for q
%   k   = length of the list to consider (unused in MAP)
%
%   Y    = n-by-1 ranking of the documents in pos and neg
%          (n = numel(pos) + numel(neg)) corresponding to the maximally
%          violated constraint: Y(r) is the document placed at rank r.
%   Loss = the loss for Y, in this case 1 - AP(Y)
%
% Documents are scored by -D(:, q), so closer documents score higher. The
% relevant and irrelevant documents are each sorted by score. A dynamic
% program over (rank b, relevant document a) then chooses how to
% interleave the two sorted lists.
%
% Raises an error if pos is empty. If neg is empty, the relevant documents
% are returned in score order (AP is 1 for every such ordering, so Loss is 0).
%
% NOTE(review): the DP calls [m, p] = cummax(x) with two outputs and
% transposes m. MATLAB's built-in cummax returns only the running maximum,
% so this presumably resolves to a project-provided cummax that returns
% row vectors. Confirm before running elsewhere.

    % Sort the positive documents in descending order of score
    ScorePos = - D(pos, q);
    [Vpos, Ipos] = sort(full(ScorePos'), 'descend');
    Ipos = pos(Ipos);

    % Sort the negative documents in descending order of score
    ScoreNeg = - D(neg, q);
    [Vneg, Ineg] = sort(full(ScoreNeg'), 'descend');
    Ineg = neg(Ineg);

    numPos = length(pos);
    numNeg = length(neg);
    n = numPos + numNeg;

    if numPos == 0
        error('separationOracleMAP:noRelevant', ...
              'At least one relevant result (pos) is required.');
    end

    if numNeg == 0
        % With nothing irrelevant there is only one interleaving. Return it
        % directly: the DP's score normalization divides by numNeg and
        % would otherwise operate on Inf/NaN values.
        Y = Ipos(:);
        Loss = 1 - AP(Y, pos);
        return;
    end

    % Pre-generate the precision scores:
    % H(b, a) = 1 / ((a-1) - b) = -1 / (b - a + 1) for b >= a, 0 above the diagonal
    H = tril(1./bsxfun(@minus, 0:(numPos-1), (1:n)'));

    % Padded cumulative Vneg: pcVneg(j+1) = sum(Vneg(1:j))
    pcVneg = cumsum([0 Vneg]);

    % Add the discriminant-score term for placing relevant doc a at rank b
    H = H + scoreChangeMatrix(Vpos, Vneg, n, pcVneg);

    % Back-pointers: P(b, a) = rank chosen for relevant doc a-1 when doc a is at rank b
    P = zeros(size(H));

    for a = 2:numPos
        % Presumably m(j) / p(j) = running max of H(1:j, a-1) and its row index
        [m, p] = cummax(H(:, a-1));
        % Doc a at rank b (b >= a) extends the best placement of doc a-1 above rank b
        H(a:n, a) = H(a:n, a) + (a-1)/a .* m(a-1:n-1)';
        P(a+1:n, a) = p(a:n-1);
        % At b = a the first a-1 ranks are all relevant, so doc a-1 is at rank a-1
        P(a, a) = a-1;
    end

    % Reconstruct the ranking from the DP table, starting from the best
    % rank for the last relevant document and following back-pointers
    Y = nan(n, 1);
    [m, p] = max(H(:, numPos));
    Y(p) = Ipos(numPos);

    for a = numPos:-1:2
        p = P(p, a);
        Y(p) = Ipos(a-1);
    end

    % Remaining ranks are filled with the irrelevant documents, in score order
    Y(isnan(Y)) = Ineg;

    % Compute loss for this list
    Loss = 1 - AP(Y, pos);
end
75 | |
function C = scoreChangeMatrix(Vpos, Vneg, n, pcVneg)
% C = scoreChangeMatrix(Vpos, Vneg, n, pcVneg)
%
% Builds the n-by-numPos matrix of discriminant-score terms: C(b, a) is the
% normalized score term for placing the a'th relevant document at rank b.
%
%   Vpos   = 1-by-numPos scores of the relevant documents
%   Vneg   = scores of the irrelevant documents (only its length is used here)
%   n      = number of ranks (rows of C); the oracle passes numPos + numNeg
%   pcVneg = padded cumulative sum of the negative scores, cumsum([0 Vneg]),
%            so pcVneg(j+1) is the total score of the first j negatives
%
% Derivation: if the a'th relevant document sits at rank b, with the a-1
% earlier relevant documents above it, then b - a negatives precede it and
% numNeg - (b - a) follow it. The change in score is proportional to
%
%   sum_{negative j} (Vpos(a) - Vneg(j)) * y_{aj}
%     = (numNeg - 2*(b - a)) * Vpos(a) - pcVneg(end) + 2 * pcVneg(b - a + 1)
%
% The code uses R = b - a + 1 in BOTH places. For the cumulative term this
% is exactly right, because of the padding. For the count term it yields
% (numNeg - 2*(b - a + 1)) * Vpos(a). That shifts every entry of column a
% by the same constant, -2*Vpos(a) / (a*numNeg).
% NOTE(review): presumably harmless for the oracle's argmax; confirm
% before "fixing" it.
%
% Column a is divided by a * numNeg, so numNeg must be positive.
% Infeasible entries (b < a) are set to -Inf.
% Valid range of a: 1:numPos. Valid range of b: a:n.

    numNeg = length(Vneg);
    numPos = length(Vpos);

    % R(b, a) = b - a + 1
    R = bsxfun(@plus, 1 - (1:numPos), (1:n)');

    % Count term: (numNeg - 2*R(b, a)) * Vpos(a)
    C = bsxfun(@times, Vpos, numNeg - 2 * R);

    % Clamp R into the valid index range of pcVneg. Out-of-range entries
    % correspond to placements that cannot occur.
    R = min(max(R, 1), length(pcVneg));

    % Indexing a vector with a vector returns the orientation of the
    % indexed vector, not of the index. This matters when R is a single
    % column (numPos == 1). reshape restores R's shape in every case.
    C = C + 2 * reshape(pcVneg(R), size(R)) - pcVneg(end);

    % Normalize column a by a * numNeg
    C = bsxfun(@rdivide, C, (1:numPos) * numNeg);

    % -Inf out the infeasible regions (relevant doc a cannot sit above rank a)
    infeasible = bsxfun(@gt, 1:numPos, (1:n)');
    C(infeasible) = -Inf;
end
124 | |
function x = AP(Y, pos)
% x = AP(Y, pos)
%
% Average precision of the ranking Y (a column of document indices, best
% first) with respect to the relevant set pos: the mean, over the ranks
% holding a relevant document, of the precision at that rank.
% Returns NaN when Y contains no relevant document.

    isRelevant = ismember(Y, pos);

    % Number of relevant documents within the top k, for every cutoff k
    hitsAtK = cumsum(isRelevant)';

    % Precision at every cutoff k = 1..length(Y)
    precAtK = hitsAtK ./ (1:length(Y));

    % Keep only the cutoffs that land on a relevant document, then average
    x = mean(precAtK(isRelevant));
end