Mercurial > hg > camir-aes2014
comparison toolboxes/distance_learning/mlr/cuttingPlane/cuttingPlaneRandom.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [dPsi, M, SO_time] = cuttingPlaneRandom(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores)
%
% [dPsi, M, SO_time] = cuttingPlaneRandom(k, X, W, Yp, Yn, batchSize, SAMPLES, ClassScores)
%
% Draws a random batch of query points from SAMPLES, runs the separation
% oracle on each, and accumulates their PSI rows into a gradient.
%
%   k           = k parameter for the SO
%   X           = d*n data matrix (points are the columns of X)
%   W           = d*d PSD metric
%   Yp          = cell-array of relevant results for each point
%   Yn          = cell-array of irrelevant results for each point.
%                 May be empty: then every point other than the query and
%                 its relevant results is treated as irrelevant.
%   batchSize   = number of points to use in the constraint batch
%                 (clamped to length(SAMPLES) if larger)
%   SAMPLES     = indices of valid points to include in the batch
%   ClassScores = structure for synthetic constraints, or [] to use Yp/Yn.
%                 Fields read here: classes, Y, Ypos{j}, Yneg{j}.
%
%   dPsi        = dPsi vector for this batch (whatever CPGRADIENT returns)
%   M           = mean loss on this batch
%   SO_time     = time spent in separation oracle
%
% The caller must set the global function handles SO, PSI, SETDISTANCE and
% CPGRADIENT before calling.

    global SO PSI SETDISTANCE CPGRADIENT;

    % size(X, 2) is the column count even if X has trailing dimensions,
    % unlike [d, n] = size(X), which folds them into n.
    n = size(X, 2);

    % A batch cannot be larger than the pool it is drawn from. Without this
    % clamp, Batch(1:batchSize) raises an index error.
    batchSize = min(batchSize, length(SAMPLES));

    if length(SAMPLES) == n
        % All samples are fair game (full data)
        Batch = randperm(n);
        Batch = Batch(1:batchSize);
        D     = SETDISTANCE(X, W, Batch);
    else
        Batch = randperm(length(SAMPLES));
        Batch = SAMPLES(Batch(1:batchSize));

        if isempty(ClassScores) && ~isempty(Yneg)
            % Explicit relevant/irrelevant lists: only distances to the
            % points that appear in them are needed.
            Ito = sparse(n, 1);
            for i = Batch
                Ito(Ypos{i}) = 1;
                Ito(Yneg{i}) = 1;
            end
            D = SETDISTANCE(X, W, Batch, find(Ito));
        else
            % Synthetic constraints or implicit negatives (empty Yneg): any
            % point can appear in a ranking, so compute all distances.
            D = SETDISTANCE(X, W, Batch, 1:n);
        end
    end

    M       = 0;
    S       = zeros(n);
    dIndex  = sub2ind([n n], 1:n, 1:n);   % linear indices of diag(S)
    SO_time = 0;

    if isempty(ClassScores)
        TS = zeros(batchSize, n);
        parfor j = 1:batchSize
            i = Batch(j);
            if isempty(Yneg)
                Ynegative = setdiff((1:n)', [i ; Ypos{i}]);
            else
                Ynegative = Yneg{i};
            end
            SO_start = tic();
            [yi, li] = SO(i, D, Ypos{i}, Ynegative, k);
            SO_time  = SO_time + toc(SO_start);

            M        = M + li / batchSize;
            TS(j,:)  = PSI(i, yi', n, Ypos{i}, Ynegative);
        end
        % Row i and column i of S receive query i's PSI row. The column sums
        % of TS are subtracted from the diagonal.
        S(Batch,:)  = TS;
        S(:,Batch)  = S(:,Batch) + TS';
        S(dIndex)   = S(dIndex) - sum(TS, 1);
    else
        for j = 1:length(ClassScores.classes)
            c      = ClassScores.classes(j);
            points = find(ClassScores.Y(Batch) == c);
            if isempty(points)
                continue;
            end

            % Local names avoid shadowing the Ypos/Yneg input arguments.
            negIdx = find(ClassScores.Yneg{j});
            yp     = ClassScores.Ypos{j};

            TS = zeros(length(points), n);
            parfor x = 1:length(points)
                i        = Batch(points(x));
                yl       = yp;
                yl(i)    = 0;              % a query is never its own relevant result
                posIdx   = find(yl);
                SO_start = tic();
                [yi, li] = SO(i, D, posIdx, negIdx, k);
                SO_time  = SO_time + toc(SO_start);

                M        = M + li / batchSize;
                TS(x,:)  = PSI(i, yi', n, posIdx, negIdx);
            end
            S(Batch(points),:) = S(Batch(points),:) + TS;
            S(:,Batch(points)) = S(:,Batch(points)) + TS';
            S(dIndex)          = S(dIndex) - sum(TS, 1);
        end
    end

    dPsi = CPGRADIENT(X, S, batchSize);

end