comparison toolboxes/distance_learning/mlr/cuttingPlane/cuttingPlaneFull.m @ 0:cc4b1211e677 tip

initial commit to HG from Changeset: 646 (e263d8a21543) added further path and more save "camirversion.m"
author Daniel Wolff
date Fri, 19 Aug 2016 13:07:06 +0200
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:cc4b1211e677
function [dPsi, M, SO_time] = cuttingPlaneFull(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores)
%
% [dPsi, M, SO_time] = cuttingPlaneFull(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores)
%
% Accumulates the cutting-plane constraint for one batch of query points by
% calling the separation oracle (global SO) once per point, and returns the
% resulting gradient and mean loss.
%
%   k           = k parameter for the SO (passed through unchanged)
%   X           = d*n data matrix; n = size(X, 2) is the number of points
%   W           = d*d PSD metric, passed to DISTANCE(W, X)
%   Ypos        = cell-array of relevant results for each point
%                 (row or column index vectors)
%   Yneg        = cell-array of irrelevant results for each point, or empty
%                 to treat every point other than i and Ypos{i} as irrelevant
%   batchSize   = number of points to use in the constraint batch; also the
%                 normaliser applied to M and dPsi
%   SAMPLES     = indices of valid points to include in the batch
%   ClassScores = (optional, default []) struct with fields
%                   classes - list of class labels
%                   Y       - label of each point, compared against classes(j)
%                   Ypos{j} - indicator vector over the n points for class j
%                   Yneg{j} - indicator vector over the n points for class j
%                 When non-empty, Ypos, Yneg and SAMPLES are ignored and
%                 every point of every class with more than one member is
%                 processed.
%
%   dPsi    = CPGRADIENT(X, S) / batchSize for the accumulated batch matrix S
%   M       = sum of the SO losses divided by batchSize
%   SO_time = total time (seconds, via tic/toc) spent inside SO
%
% The globals SO, PSI, DISTANCE and CPGRADIENT must be set by the caller.

    global SO PSI DISTANCE CPGRADIENT;

    % The documented 7-argument form omits ClassScores.
    if nargin < 8
        ClassScores = [];
    end

    n = size(X, 2);
    D = DISTANCE(W, X);

    M = 0;
    S = zeros(n);
    % Linear indices of the diagonal entries of the n-by-n matrix S.
    dIndex = sub2ind([n n], 1:n, 1:n);

    SO_time = 0;

    if isempty(ClassScores)
        % Visit the first batchSize entries of SAMPLES (or all of them if fewer).
        for t = 1:min(batchSize, length(SAMPLES))
            i = SAMPLES(t);

            if isempty(Ypos{i})
                continue;
            end

            if isempty(Yneg)
                % Construct a negative set: every point except i and its
                % positives. (:) makes the concatenation work for row vectors.
                Ynegative = setdiff((1:n)', [i; Ypos{i}(:)]);
            else
                Ynegative = Yneg{i};
            end

            SO_start = tic();
            [yi, li] = SO(i, D, Ypos{i}, Ynegative, k);
            SO_time = SO_time + toc(SO_start);

            M = M + li / batchSize;
            snew = PSI(i, yi', n, Ypos{i}, Ynegative);
            % Add snew to row i and column i, then remove the double-counted
            % contribution from the diagonal.
            S(i,:) = S(i,:) + snew';
            S(:,i) = S(:,i) + snew;
            S(dIndex) = S(dIndex) - snew';
        end
    else
        % Do it class-wise for efficiency: all points in a class share the
        % same positive/negative indicator vectors.
        for j = 1:length(ClassScores.classes)
            c = ClassScores.classes(j);
            points = find(ClassScores.Y == c);

            % A point alone in its class has no positives other than itself.
            if length(points) <= 1
                continue;
            end

            classNeg = find(ClassScores.Yneg{j});
            yp = ClassScores.Ypos{j};

            for x = 1:length(points)
                i = points(x);

                % Temporarily exclude i from its own positive set, restoring
                % whatever value was there afterwards.
                savedValue = yp(i);
                yp(i) = 0;
                classPos = find(yp);

                SO_start = tic();
                [yi, li] = SO(i, D, classPos, classNeg, k);
                SO_time = SO_time + toc(SO_start);

                M = M + li / batchSize;
                snew = PSI(i, yi', n, classPos, classNeg);
                S(i,:) = S(i,:) + snew';
                S(:,i) = S(:,i) + snew;
                S(dIndex) = S(dIndex) - snew';

                yp(i) = savedValue;
            end
        end
    end

    dPsi = CPGRADIENT(X, S) / batchSize;

end