comparison toolboxes/distance_learning/mlr/cuttingPlane/cuttingPlaneRandom.m @ 0:cc4b1211e677 tip

initial commit to HG from Changeset: 646 (e263d8a21543) added further path and more save "camirversion.m"
author Daniel Wolff
date Fri, 19 Aug 2016 13:07:06 +0200
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:cc4b1211e677
function [dPsi, M, SO_time] = cuttingPlaneRandom(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores)
%
% [dPsi, M, SO_time] = cuttingPlaneRandom(k, X, W, Yp, Yn, batchSize, SAMPLES, ClassScores)
%
% Builds a cutting-plane constraint from a random batch of points: for each
% point in the batch the separation oracle SO is called, its loss is
% averaged into M, and the PSI vector it yields is accumulated into an
% n*n matrix S that CPGRADIENT maps to dPsi.
%
%   k           = k parameter for the SO
%   X           = d*n data matrix
%   W           = d*d PSD metric
%   Yp          = cell-array of relevant results for each point
%   Yn          = cell-array of irrelevant results for each point
%   batchSize   = number of points to use in the constraint batch
%                 (clamped to the number of available samples)
%   SAMPLES     = indices of valid points to include in the batch
%                 (row or column vector)
%   ClassScores = structure for synthetic constraints; when empty, Yp/Yn
%                 are used instead
%
%   dPsi        = dPsi vector for this batch
%   M           = mean loss on this batch
%   SO_time     = time spent in separation oracle
%
% Reads the globals SO, PSI, SETDISTANCE and CPGRADIENT, which the caller
% must have set (as function handles) before calling this function.

    global SO PSI DISTANCE SETDISTANCE CPGRADIENT;

    n = size(X, 2);

    if length(SAMPLES) == n
        % All samples are fair game (full data)
        batchSize = min(batchSize, n);
        Batch = randperm(n);
        Batch = Batch(1:batchSize);
        D = SETDISTANCE(X, W, Batch);
    else
        batchSize = min(batchSize, length(SAMPLES));
        Batch = randperm(length(SAMPLES));
        % Force a row vector: "for i = Batch" iterates over COLUMNS, so a
        % column-shaped SAMPLES would otherwise give a single iteration
        % with i equal to the whole batch.
        Batch = reshape(SAMPLES(Batch(1:batchSize)), 1, []);

        if isempty(ClassScores)
            % Only distances to points referenced by the batch's
            % relevant/irrelevant sets are requested.
            Ito = sparse(n, 1);
            for i = Batch
                Ito(Ypos{i}) = 1;
                Ito(Yneg{i}) = 1;
            end
            D = SETDISTANCE(X, W, Batch, find(Ito));
        else
            D = SETDISTANCE(X, W, Batch, 1:n);
        end
    end

    M = 0;
    S = zeros(n);
    % Linear indices of the diagonal of S
    dIndex = sub2ind([n n], 1:n, 1:n);

    SO_time = 0;

    if isempty(ClassScores)
        for i = Batch
            SO_start = tic();
            [yi, li] = SO(i, D, Ypos{i}, Yneg{i}, k);
            SO_time = SO_time + toc(SO_start);

            M = M + li / batchSize;
            snew = PSI(i, yi', n, Ypos{i}, Yneg{i});
            % Add snew to row i and column i of S, subtract it from the diagonal
            S(i,:) = S(i,:) + snew';
            S(:,i) = S(:,i) + snew;
            S(dIndex) = S(dIndex) - snew';
        end
    else
        for j = 1:length(ClassScores.classes)
            c = ClassScores.classes(j);
            points = find(ClassScores.Y(Batch) == c);
            if isempty(points)
                continue;
            end

            Yneg = find(ClassScores.Yneg{j});
            yp = ClassScores.Ypos{j};

            for x = 1:length(points)
                i = Batch(points(x));
                % Exclude the query point from its own relevant set, but
                % remember its original flag so it is restored exactly
                % (not forced to 1) for the remaining batch points.
                yp_i = yp(i);
                yp(i) = 0;
                Ypos = find(yp);
                SO_start = tic();
                [yi, li] = SO(i, D, Ypos, Yneg, k);
                SO_time = SO_time + toc(SO_start);

                M = M + li / batchSize;
                snew = PSI(i, yi', n, Ypos, Yneg);
                % Add snew to row i and column i of S, subtract it from the diagonal
                S(i,:) = S(i,:) + snew';
                S(:,i) = S(:,i) + snew;
                S(dIndex) = S(dIndex) - snew';

                yp(i) = yp_i;
            end
        end
    end

    dPsi = CPGRADIENT(X, S) / batchSize;

end