function [Y, Loss] = separationOracleKNN(q, D, pos, neg, k)
%
%   [Y, Loss] = separationOracleKNN(q, D, pos, neg, k)
%
%   q   = index of the query point
%   D   = the current distance matrix
%   pos = indices of relevant results for q
%   neg = indices of irrelevant results for q
%   k   = length of the list to consider
%
%   Y is a permutation of 1:n corresponding to the maximally
%   violated constraint
%
%   Loss is the loss for Y; here the binary KNN loss:
%   1 if Prec@k(Y) <= 1/2, and 0 otherwise

    % First, sort the documents in descending order of W'Phi(q,x)
    % Phi = - (X(q) - X(x)) * (X(q) - X(x))'

    % Sort the positive documents
    ScorePos     = - D(pos,q);
    [Vpos, Ipos] = sort(full(ScorePos'), 'descend');
    Ipos         = pos(Ipos);

    % Sort the negative documents
    ScoreNeg     = - D(neg,q);
    [Vneg, Ineg] = sort(full(ScoreNeg'), 'descend');
    Ineg         = neg(Ineg);

    % Now, solve the DP for the interleaving

    numPos = length(pos);
    numNeg = length(neg);
    n      = numPos + numNeg;

    cVpos = cumsum(Vpos);
    cVneg = cumsum(Vneg);

    % If we don't have enough positive (or negative) examples, scale k down
    k = min([k, numPos, numNeg]);

    % Algorithm:
    %   For each precision score in 0, 1/k, 2/k, ..., 1
    %       calculate the maximum discriminant score for that precision level
    %
    % KNN(a) is the gain at precision (a-1)/k: 1 if a majority of the
    % top k is relevant, 0 otherwise
    KNN          = (0:(1/k):1)' > 0.5;
    Discriminant = zeros(k+1, 1);
    NegsBefore   = zeros(numPos, k+1);

    % For 0 precision, all positives go after the first k negatives

    NegsBefore(:,1) = k + binarysearch(Vpos, Vneg(k+1:end));

    Discriminant(1) = Vpos * (numNeg - 2 * NegsBefore(:,1)) + numPos * cVneg(end) ...
                        - 2 * sum(cVneg(NegsBefore((NegsBefore(:,1) > 0), 1)));

    % For precision (a-1)/k, swap the (a-1)'th positive doc
    % into the top k-(a-1) negative docs

    for a = 2:(k+1)
        NegsBefore(:,a) = NegsBefore(:,a-1);

        % We have a-1 positives, and k - (a-1) negatives
        NegsBefore(a-1, a) = binarysearch(Vpos(a-1), Vneg(1:(k-a+1)));

        % There were NegsBefore(a-1,a-1) negatives before the (a-1)'th positive;
        % now there are NegsBefore(a-1,a)

        Discriminant(a) = Discriminant(a-1) ...
                            + 2 * (NegsBefore(a-1,a-1) - NegsBefore(a-1,a)) * Vpos(a-1);

        if NegsBefore(a-1,a-1) > 0
            Discriminant(a) = Discriminant(a) + 2 * cVneg(NegsBefore(a-1,a-1));
        end
        if NegsBefore(a-1,a) > 0
            Discriminant(a) = Discriminant(a) - 2 * cVneg(NegsBefore(a-1,a));
        end
    end

    % Normalize discriminant scores
    Discriminant = Discriminant / (numPos * numNeg);
    [s, x]       = max(Discriminant - KNN);

    % Now we know that there are x-1 relevant docs in the max ranking
    % Construct Y from NegsBefore(:,x)

    Y = nan * ones(n,1);
    Y((1:numPos)' + NegsBefore(:,x)) = Ipos;
    if sum(isnan(Y)) ~= length(Ineg)
        % Sanity check failed: the remaining slots do not match the
        % number of negative docs; drop into the debugger
        keyboard;
    end
    Y(isnan(Y)) = Ineg;

    % Compute loss for this list
    Loss = 1 - KNN(x);
end
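
% -------------------------------------------------------------------------
% Usage sketch (illustrative only, not part of the documented API). The
% data, the query/candidate split, and the use of pdist/squareform
% (Statistics Toolbox) as a stand-in for the distance matrix under the
% current metric are all hypothetical; the binarysearch helper called
% above must be compiled and on the path.
%
%   X         = randn(10, 5);               % 10 points in 5 dimensions
%   D         = squareform(pdist(X)).^2;    % stand-in distance matrix
%   q         = 1;                          % query index
%   pos       = [2 3 4];                    % indices of relevant results
%   neg       = 5:10;                       % indices of irrelevant results
%   k         = 3;                          % neighborhood size
%   [Y, Loss] = separationOracleKNN(q, D, pos, neg, k);
%   % Y    : ordering of the candidates realizing the most-violated constraint
%   % Loss : 0/1 KNN loss for Y (1 unless a majority of the top k is relevant)
% -------------------------------------------------------------------------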