function [dPsi, M, SO_time] = cuttingPlaneFull(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores)
%
% [dPsi, M, SO_time] = cuttingPlaneFull(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores)
%
% Build one cutting-plane constraint: run the separation oracle (SO) on a
% batch of points, accumulate their PSI rows into an n*n matrix S, and hand
% S to CPGRADIENT.
%
%   k           = k parameter passed through to the SO
%   X           = d*n data matrix (n = size(X, 2))
%   W           = d*d PSD metric
%   Ypos        = cell-array of relevant results for each point
%   Yneg        = cell-array of irrelevant results for each point; if empty,
%                 the negatives of point j are every index in 1:n except j
%                 and Ypos{j}
%   batchSize   = number of points to use in the constraint batch
%   SAMPLES     = indices of valid points to include in the batch; only the
%                 first min(batchSize, numel(SAMPLES)) entries are used
%   ClassScores = structure for synthetic constraints, or empty.  When
%                 non-empty, the fields classes, Y, Ypos{c} and Yneg{c} are
%                 read.  Ypos{c} presumably is a length-n indicator vector
%                 (it is passed to find and point i is zeroed) -- confirm
%                 against the caller.  In this mode batchSize is recomputed
%                 from the classes.
%
%   dPsi    = dPsi vector for this batch, as returned by CPGRADIENT
%   M       = mean loss on this batch (0 when the batch is empty)
%   SO_time = time spent in separation oracle
%
% Uses the globals SO, PSI, DISTANCE and CPGRADIENT, which must hold
% function handles:
%   D = DISTANCE(W, X);  [yi, li] = SO(j, D, pos, neg, k);
%   row = PSI(j, yi', n, pos, neg);  dPsi = CPGRADIENT(X, S, batchSize)

    global SO PSI DISTANCE CPGRADIENT;

    % Copy the global handles into ordinary locals.  The parfor loops below
    % then transfer them to workers as broadcast variables instead of
    % relying on worker-side globals.
    soFn  = SO;
    psiFn = PSI;

    n = size(X, 2);
    D = DISTANCE(W, X);

    M       = 0;
    SO_time = 0;
    S       = zeros(n);
    dIndex  = sub2ind([n n], 1:n, 1:n);   % linear indices of diag(S)

    if isempty(ClassScores)
        % Per-point constraints.  Clamp to the samples actually available so
        % that TS has exactly one row per entry of batchIdx.  Otherwise
        % S(SAMPLES,:) = TS fails whenever numel(SAMPLES) ~= batchSize.
        nBatch   = min(batchSize, numel(SAMPLES));
        batchIdx = SAMPLES(1:nBatch);
        TS       = zeros(nBatch, n);

        parfor i = 1:nBatch
            j = batchIdx(i);

            % Points without relevant results contribute nothing.
            if isempty(Ypos{j})
                continue;
            end
            if isempty(Yneg)
                % Construct a negative set: everything except j and its positives
                Ynegative = setdiff((1:n)', [j ; Ypos{j}]);
            else
                Ynegative = Yneg{j};
            end

            SO_start = tic();
            [yi, li] = soFn(j, D, Ypos{j}, Ynegative, k);
            SO_time  = SO_time + toc(SO_start);

            % Mean is taken over the requested batchSize, as before.
            M       = M + li / batchSize;
            TS(i,:) = psiFn(j, yi', n, Ypos{j}, Ynegative);
        end

        % Reconstruct the S matrix from TS
        S(batchIdx,:) = TS;
        S(:,batchIdx) = S(:,batchIdx) + TS';
        S(dIndex)     = S(dIndex) - sum(TS, 1);
    else
        % Do it class-wise for efficiency: every point of a class shares the
        % same negative set and, apart from itself, the same positive set.
        batchSize = 0;
        for c_idx = 1:numel(ClassScores.classes)
            c        = ClassScores.classes(c_idx);
            points   = find(ClassScores.Y == c);
            classNeg = find(ClassScores.Yneg{c_idx});
            classPos = ClassScores.Ypos{c_idx};

            % A point alone in its class has no positives once it is removed.
            if numel(points) <= 1
                continue;
            end

            batchSize = batchSize + numel(points);
            TS = zeros(numel(points), n);

            parfor x = 1:numel(points)
                i  = points(x);
                yl = classPos;
                yl(i) = 0;                 % a point is not its own positive
                pointPos = find(yl);

                SO_start = tic();
                [yi, li] = soFn(i, D, pointPos, classNeg, k);
                SO_time  = SO_time + toc(SO_start);

                M       = M + li;
                TS(x,:) = psiFn(i, yi', n, pointPos, classNeg);
            end

            S(points,:) = S(points,:) + TS;
            S(:,points) = S(:,points) + TS';
            S(dIndex)   = S(dIndex) - sum(TS, 1);
        end

        % Avoid 0/0 = NaN when every class was skipped.
        if batchSize > 0
            M = M / batchSize;
        end
    end

    dPsi = CPGRADIENT(X, S, batchSize);
end