comparison toolboxes/distance_learning/mlr/cuttingPlane/cuttingPlaneFull.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [dPsi, M, SO_time] = cuttingPlaneFull(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores)
%
% [dPsi, M, SO_time] = cuttingPlaneFull(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores)
%
% Builds one cutting-plane constraint: runs the separation oracle SO for
% every point in the batch, accumulates the resulting PSI rows into an
% n*n matrix S, and passes S to CPGRADIENT.
%
%   k           = k parameter for the SO
%   X           = d*n data matrix
%   W           = d*d PSD metric
%   Ypos        = cell-array of relevant results for each point
%   Yneg        = cell-array of irrelevant results for each point; if empty,
%                 the negatives of point j are all points not in [j; Ypos{j}]
%   batchSize   = number of points to use in the constraint batch; only the
%                 first min(batchSize, numel(SAMPLES)) entries of SAMPLES are
%                 processed, but the loss is averaged over batchSize
%   SAMPLES     = indices of valid points to include in the batch
%   ClassScores = structure for synthetic constraints; if non-empty the batch
%                 is built class-wise from its fields (classes, Y, Ypos, Yneg)
%                 and the Ypos, Yneg, batchSize and SAMPLES arguments are unused
%
%   dPsi    = dPsi vector for this batch (the value returned by CPGRADIENT)
%   M       = mean loss on this batch (0 if the batch is empty)
%   SO_time = time spent in separation oracle
%
% Relies on the globals SO, PSI, DISTANCE and CPGRADIENT, which must be
% initialised before this function is called.

global SO PSI DISTANCE CPGRADIENT;

n = size(X, 2);
D = DISTANCE(W, X);

M = 0;
S = zeros(n);
% Linear indices of the diagonal of S
dIndex = sub2ind([n n], 1:n, 1:n);

SO_time = 0;

if isempty(ClassScores)
    % Clamp the batch to the available samples so that TS has exactly one
    % row per index used to rebuild S (otherwise S(SAMPLES,:) = TS fails
    % whenever batchSize ~= numel(SAMPLES)).
    nBatch   = min(batchSize, numel(SAMPLES));
    batchIdx = SAMPLES(1:nBatch);

    TS = zeros(nBatch, n);
    parfor i = 1:nBatch
        j = batchIdx(i);

        % Points without relevant results contribute nothing
        if isempty(Ypos{j})
            continue;
        end
        if isempty(Yneg)
            % Construct a negative set: everything except j and its positives
            Ynegative = setdiff((1:n)', [j ; Ypos{j}]);
        else
            Ynegative = Yneg{j};
        end
        SO_start = tic();
        [yi, li] = SO(j, D, Ypos{j}, Ynegative, k);
        SO_time = SO_time + toc(SO_start);

        % Loss is averaged over the requested batch size
        M = M + li / batchSize;
        TS(i,:) = PSI(j, yi', n, Ypos{j}, Ynegative);
    end

    % Reconstruct the S matrix from TS: rows, columns, then diagonal
    S(batchIdx,:) = TS;
    S(:,batchIdx) = S(:,batchIdx) + TS';
    S(dIndex) = S(dIndex) - sum(TS, 1);
else
    % Do it class-wise for efficiency; batchSize becomes the number of
    % points actually processed.
    batchSize = 0;
    for j = 1:length(ClassScores.classes)
        c = ClassScores.classes(j);
        points = find(ClassScores.Y == c);

        YnegClass = find(ClassScores.Yneg{j});
        yp = ClassScores.Ypos{j};

        % Classes with at most one point are skipped
        if length(points) <= 1
            continue;
        end

        batchSize = batchSize + length(points);
        TS = zeros(length(points), n);
        parfor x = 1:length(points)
            i = points(x);
            % Remove the query point from its own positive set
            yl = yp;
            yl(i) = 0;
            YposPoint = find(yl);
            SO_start = tic();
            [yi, li] = SO(i, D, YposPoint, YnegClass, k);
            SO_time = SO_time + toc(SO_start);

            M = M + li;
            TS(x,:) = PSI(i, yi', n, YposPoint, YnegClass);
        end

        S(points,:) = S(points,:) + TS;
        S(:,points) = S(:,points) + TS';
        S(dIndex) = S(dIndex) - sum(TS, 1);
    end
    % Avoid 0/0 = NaN when no class had more than one point
    if batchSize > 0
        M = M / batchSize;
    end
end

dPsi = CPGRADIENT(X, S, batchSize);

end