Mercurial > hg > camir-aes2014
diff toolboxes/distance_learning/mlr/cuttingPlane/cuttingPlaneFull.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/distance_learning/mlr/cuttingPlane/cuttingPlaneFull.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,96 @@ +function [dPsi, M, SO_time] = cuttingPlaneFull(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores) +% +% [dPsi, M, SO_time] = cuttingPlaneFull(k, X, W, Yp, Yn, batchSize, SAMPLES, ClassScores) +% +% k = k parameter for the SO +% X = d*n data matrix +% W = d*d PSD metric +% Yp = cell-array of relevant results for each point +% Yn = cell-array of irrelevant results for each point +% batchSize = number of points to use in the constraint batch +% SAMPLES = indices of valid points to include in the batch +% ClassScores = structure for synthetic constraints +% +% dPsi = dPsi vector for this batch +% M = mean loss on this batch +% SO_time = time spent in separation oracle + + global SO PSI DISTANCE CPGRADIENT; + + [d,n,m] = size(X); + D = DISTANCE(W, X); + + M = 0; + S = zeros(n); + dIndex = sub2ind([n n], 1:n, 1:n); + + SO_time = 0; + + if isempty(ClassScores) + TS = zeros(batchSize, n); + parfor i = 1:batchSize + if i <= length(SAMPLES) + j = SAMPLES(i); + + if isempty(Ypos{j}) + continue; + end + if isempty(Yneg) + % Construct a negative set + Ynegative = setdiff((1:n)', [j ; Ypos{j}]); + else + Ynegative = Yneg{j}; + end + SO_start = tic(); + [yi, li] = SO(j, D, Ypos{j}, Ynegative, k); + SO_time = SO_time + toc(SO_start); + + M = M + li /batchSize; + TS(i,:) = PSI(j, yi', n, Ypos{j}, Ynegative); + end + end + + % Reconstruct the S matrix from TS + S(SAMPLES,:) = TS; + S(:,SAMPLES) = S(:,SAMPLES) + TS'; + S(dIndex) = S(dIndex) - sum(TS, 1); + else + + % Do it class-wise for efficiency + batchSize = 0; + for j = 1:length(ClassScores.classes) + c = ClassScores.classes(j); + points = find(ClassScores.Y == c); + + Yneg = find(ClassScores.Yneg{j}); + yp = ClassScores.Ypos{j}; + + if length(points) <= 1 + continue; + end + + batchSize = batchSize + length(points); + TS = zeros(length(points), n); + parfor x = 1:length(points) + i = points(x); + yl = yp; + yl(i) = 0; + Ypos = find(yl); + SO_start = tic(); + [yi, li] = SO(i, D, Ypos, Yneg, k); + SO_time = SO_time + toc(SO_start); + + M = M + li; + TS(x,:) = PSI(i, yi', n, Ypos, Yneg); + end + + S(points,:) = S(points,:) + TS; + S(:,points) = S(:,points) + TS'; + S(dIndex) = S(dIndex) - sum(TS, 1); + end + M = M / batchSize; + end + + dPsi = CPGRADIENT(X, S, batchSize); + +end