toolboxes/distance_learning/mlr/util/mlr_solver.m @ 0:e9a9cd732c1e (tip)

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
function [W, Xi, Diagnostics] = mlr_solver(C, Margins, W, K)
% [W, Xi, D] = mlr_solver(C, Margins, W, K)
%
%   C        >= 0   slack trade-off parameter
%   Margins  =      array of mean margin values
%   W        =      initial value for W
%   K        =      data matrix (or kernel)
%
%   W (output)  =   the learned metric
%   Xi          =   1-slack variable (largest constraint violation)
%   D           =   diagnostics struct

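%%%
% Editorial note (sketch, not part of the original interface): besides the
% explicit arguments above, the solver reads several globals that the
% calling training routine is assumed to have initialized -- the loss
% handle LOSS, the regularizer REG, the feasible-set projection FEASIBLE,
% and the constraint cache PsiR with its activity clock PsiClock.
% A hypothetical invocation, once those globals are in place:
%
%   [W, Xi, Diag] = mlr_solver(C, Margins, W0, K);
%
% where W0 is an initial (typically PSD) metric of the appropriate size.
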
global DEBUG REG FEASIBLE LOSS;

%%%
% Initialize the gradient directions for each constraint
%
global PsiR;
global PsiClock;

numConstraints = length(PsiR);

%%%
% Some optimization details

% Armijo rule number
armijo = 1e-5;

% Initial learning rate
lambda0 = 1e-4;

% Increase/decrease after each iteration
lambdaup = ((1+sqrt(5))/2)^(1/3);
lambdadown = ((1+sqrt(5))/2)^(-1);
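% Editorial note: both factors are powers of the golden ratio
% phi = (1+sqrt(5))/2 ~= 1.618, so lambdaup = phi^(1/3) ~= 1.17 and
% lambdadown = 1/phi ~= 0.618; three accepted steps scale the learning
% rate up by the same factor that a single backtrack scales it down.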

% Maximum steps to take
maxsteps = 1e4;

% Size of convergence window
frame = 10;

% Convergence threshold
convthresh = 1e-5;

% Maximum number of backtracks
maxbackcount = 100;


Diagnostics = struct(   'f',                [], ...
                        'num_steps',        [], ...
                        'stop_criteria',    []);

% Repeat until convergence:
%   1) Calculate f
%   2) Take a gradient step
%   3) Project W back onto PSD

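% Editorial note: written out, each iteration evaluates the 1-slack
% objective
%
%   f(W) = C * max(0, max_r LOSS(W, PsiR{r}, Margins(r), 0)) + REG(W, K, 0)
%
% and takes a projected (sub)gradient step on W with an adaptive learning
% rate, as sketched in the three steps above.
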
%%%
% Initialize
%

f = inf;
dfdW = zeros(size(W));
lambda = lambda0;
F = Inf * ones(1,maxsteps+1);
XiR = zeros(numConstraints,1);


stepcount = -1;
backcount = 0;
done = 0;


while 1
    fold = f;
    Wold = W;

    %%%
    % Count constraint violations and build the gradient
    dbprint(3, 'Computing gradient');

    %%%
    % Calculate constraint violations
    %
    XiR(:) = 0;
    for R = numConstraints:-1:1
        XiR(R) = LOSS(W, PsiR{R}, Margins(R), 0);
    end
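    % Editorial note: judging from the calls in this file, the final
    % argument of LOSS (and of REG) appears to select the return value --
    % 0 for the scalar value, 1 for its gradient with respect to W.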

    %%%
    % Find the most active constraint
    %
    [Xi, mgrad] = max(XiR);
    Xi = max(Xi, 0);

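    % Editorial note: resetting PsiClock for the most-violated constraint
    % presumably marks it as recently active, so the caller can age out
    % constraints that remain inactive.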
    PsiClock(mgrad) = 0;

    %%%
    % Evaluate f
    %

    f = C * max(Xi, 0) ...
        + REG(W, K, 0);

    %%%
    % Test for convergence
    %
    objDiff = fold - f;

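    % Editorial note: this is an Armijo-style sufficient-decrease test --
    % the previous step is accepted only if it reduced the objective by at
    % least armijo * lambda * ||dfdW||_F^2; acceptance raises the learning
    % rate, failure triggers the backtracking branch below.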
    if objDiff > armijo * lambda * (dfdW(:)' * dfdW(:))

        stepcount = stepcount + 1;

        F(stepcount+1) = f;

        sdiff = inf;
        if stepcount >= frame
            sdiff = log(F(stepcount+1-frame) / f);
        end

        if stepcount >= maxsteps
            done = 1;
            stopcriteria = 'MAXSTEPS';
        elseif sdiff <= convthresh
            done = 1;
            stopcriteria = 'CONVERGENCE';
        else
            %%%
            % If it's positive, add the corresponding gradient
            dfdW = C * LOSS(W, PsiR{mgrad}, Margins(mgrad), 1) ...
                   + REG(W, K, 1);
        end

        dbprint(3, 'Lambda up!');
        Wold = W;
        lambda = lambdaup * lambda;
        backcount = 0;

    else
        % Backtracking time, drop the learning rate
        if backcount >= maxbackcount
            W = Wold;
            f = fold;
            done = 1;

            stopcriteria = 'BACKTRACK';
        else
            dbprint(3, 'Lambda down!');
            lambda = lambdadown * lambda;
            backcount = backcount + 1;
        end
    end

    %%%
    % Take a gradient step
    %
    W = W - lambda * dfdW;

    %%%
    % Project back onto the feasible set
    %

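    % Editorial note: FEASIBLE is supplied by the caller; for the PSD
    % feasible set mentioned at the top of this file, a typical projection
    % clips negative eigenvalues, e.g. (illustrative sketch only, not
    % necessarily the toolbox's implementation):
    %
    %   [V, D] = eig(0.5 * (W + W'));
    %   W      = V * max(D, 0) * V';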
    dbprint(3, 'Projecting onto feasible set');
    W = FEASIBLE(W);
    if done
        break;
    end

end

Diagnostics.f = F(2:(stepcount+1))';
Diagnostics.stop_criteria = stopcriteria;
Diagnostics.num_steps = stepcount;

dbprint(1, '\t%s after %d steps.\n', stopcriteria, stepcount);
end