Mercurial > hg > camir-aes2014
comparison toolboxes/distance_learning/mlr/cuttingPlane/cuttingPlaneFull.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [dPsi, M, SO_time] = cuttingPlaneFull(k, X, W, Ypos, Yneg, batchSize, SAMPLES, ClassScores)
%
% [dPsi, M, SO_time] = cuttingPlaneFull(k, X, W, Yp, Yn, batchSize, SAMPLES, ClassScores)
%
%   Builds one cutting-plane constraint batch: runs the separation oracle
%   (global SO) for every point in the batch, accumulates the resulting PSI
%   rows into an n*n matrix S, and converts S to dPsi via global CPGRADIENT.
%
%   k           = k parameter for the SO
%   X           = d*n data matrix
%   W           = d*d PSD metric
%   Yp          = cell-array of relevant results for each point
%   Yn          = cell-array of irrelevant results for each point
%                 (may be empty: the complement of [j; Yp{j}] is used instead)
%   batchSize   = number of points to use in the constraint batch
%   SAMPLES     = indices of valid points to include in the batch
%   ClassScores = structure for synthetic constraints
%                 (fields used here: classes, Y, Ypos, Yneg); pass [] to
%                 use Yp/Yn/SAMPLES instead
%
%   dPsi        = dPsi vector for this batch
%   M           = mean loss on this batch
%   SO_time     = time spent in separation oracle
%
%   Relies on the global function handles SO, PSI, DISTANCE and CPGRADIENT
%   being set by the caller before this function is invoked.

    global SO PSI DISTANCE CPGRADIENT;

    % Only the number of points (second dimension of X) is needed here.
    n       = size(X, 2);
    D       = DISTANCE(W, X);

    M       = 0;
    S       = zeros(n);
    dIndex  = sub2ind([n n], 1:n, 1:n);     % linear indices of diag(S)

    SO_time = 0;

    if isempty(ClassScores)
        % Process at most batchSize samples, and never more than SAMPLES
        % actually provides.  TS is sized to the points really processed so
        % that the reconstruction of S below cannot hit a dimension mismatch
        % when numel(SAMPLES) differs from batchSize.
        nUsed   = min(batchSize, numel(SAMPLES));
        used    = SAMPLES(1:nUsed);
        TS      = zeros(nUsed, n);

        parfor i = 1:nUsed
            j = used(i);

            % Points without any relevant results contribute nothing;
            % their TS row stays zero.
            if isempty(Ypos{j})
                continue;
            end
            if isempty(Yneg)
                % Construct a negative set
                Ynegative = setdiff((1:n)', [j ; Ypos{j}]);
            else
                Ynegative = Yneg{j};
            end

            SO_start        = tic();
                [yi, li]    = SO(j, D, Ypos{j}, Ynegative, k);
            SO_time         = SO_time + toc(SO_start);

            % Normalised by batchSize, the same value handed to CPGRADIENT.
            M               = M + li / batchSize;
            TS(i,:)         = PSI(j, yi', n, Ypos{j}, Ynegative);
        end

        % Reconstruct the S matrix from TS
        % NOTE(review): the row assignment overwrites rather than
        % accumulates, so this presumably expects the entries of SAMPLES to
        % be distinct -- confirm against the caller.
        S(used,:)   = TS;
        S(:,used)   = S(:,used) + TS';
        S(dIndex)   = S(dIndex) - sum(TS, 1);
    else

        % Do it class-wise for efficiency
        % batchSize is recomputed as the number of points actually used.
        batchSize = 0;
        for j = 1:length(ClassScores.classes)
            c       = ClassScores.classes(j);
            points  = find(ClassScores.Y == c);

            % Local names: do not overwrite the Ypos/Yneg input arguments.
            negIdx  = find(ClassScores.Yneg{j});
            yp      = ClassScores.Ypos{j};

            % Classes with at most one point are skipped entirely.
            if length(points) <= 1
                continue;
            end

            batchSize   = batchSize + length(points);
            TS          = zeros(length(points), n);
            parfor x = 1:length(points)
                i           = points(x);

                % A point is never its own relevant result.
                yl          = yp;
                yl(i)       = 0;
                posIdx      = find(yl);

                SO_start    = tic();
                    [yi, li]    = SO(i, D, posIdx, negIdx, k);
                SO_time     = SO_time + toc(SO_start);

                M           = M + li;
                TS(x,:)     = PSI(i, yi', n, posIdx, negIdx);
            end

            S(points,:) = S(points,:) + TS;
            S(:,points) = S(:,points) + TS';
            S(dIndex)   = S(dIndex) - sum(TS, 1);
        end

        % If every class was skipped, batchSize is 0: keep M at 0 instead of
        % producing 0/0 = NaN.
        % NOTE(review): CPGRADIENT still receives batchSize == 0 in that
        % case; its handling of an empty batch is not visible from here.
        if batchSize > 0
            M = M / batchSize;
        end
    end

    dPsi = CPGRADIENT(X, S, batchSize);

end