camir-aes2014: toolboxes/distance_learning/mlr/util/mlr_solver.m @ 0:e9a9cd732c1e (tip)

description: first hg version after svn
author:      wolffd
date:        Tue, 10 Feb 2015 15:05:51 +0000
comparison:  -1:000000000000 -> 0:e9a9cd732c1e (file added)
function [W, Xi, Diagnostics] = mlr_solver(C, Margins, W, K)
% [W, Xi, D] = mlr_solver(C, Margins, W, K)
%
%   C       >= 0    Slack trade-off parameter
%   Margins =       array of mean margin values
%   W       =       initial value for W
%   K       =       data matrix (or kernel)
%
%   W (output)  =   the learned metric
%   Xi          =   1-slack variable (largest constraint violation)
%   D           =   diagnostics

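%%%
% The solver minimizes, by projected (sub)gradient descent,
%       f(W) = C * max(Xi, 0) + REG(W, K)
% where Xi is the largest margin violation over the cached constraint
% directions in PsiR (the 1-slack formulation).  Each iteration evaluates f,
% takes a gradient step with an adaptively grown or shrunk learning rate, and
% projects W back onto the feasible set via FEASIBLE.
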
global DEBUG REG FEASIBLE LOSS;

%%%
% Initialize the gradient directions for each constraint
%
global PsiR;
global PsiClock;

numConstraints = length(PsiR);

%%%
% Some optimization details

% Armijo rule number
armijo = 1e-5;

% Initial learning rate
lambda0 = 1e-4;

% Increase/decrease after each iteration
lambdaup   = ((1+sqrt(5))/2)^(1/3);
lambdadown = ((1+sqrt(5))/2)^(-1);

% Maximum steps to take
maxsteps = 1e4;

% Size of convergence window
frame = 10;

% Convergence threshold
convthresh = 1e-5;

% Maximum number of backtracks
maxbackcount = 100;

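%%%
% Step-size schedule implied by the constants above:
%   lambdaup   = ((1+sqrt(5))/2)^(1/3)  ~ 1.174  (growth after an accepted step)
%   lambdadown = ((1+sqrt(5))/2)^(-1)   ~ 0.618  (shrink after a rejected step)
% armijo is the tolerance used in the sufficient-decrease test in the main loop.
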
Diagnostics = struct(   'f',                [], ...
                        'num_steps',        [], ...
                        'stop_criteria',    []);

% Repeat until convergence:
%   1) Calculate f
%   2) Take a gradient step
%   3) Project W back onto the PSD cone

%%%
% Initialize
%

f           = inf;
dfdW        = zeros(size(W));
lambda      = lambda0;
F           = Inf * ones(1, maxsteps+1);
XiR         = zeros(numConstraints, 1);

stepcount   = -1;
backcount   = 0;
done        = 0;

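% F records the objective after each accepted evaluation (sized for the worst
% case of maxsteps accepted steps plus the initial evaluation); stepcount
% starts at -1 so the objective at the initial W lands in F(1), which is later
% dropped from Diagnostics.f.
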
while 1
    fold = f;
    Wold = W;

    %%%
    % Count constraint violations and build the gradient
    dbprint(3, 'Computing gradient');

    %%%
    % Calculate constraint violations
    %
    XiR(:) = 0;
    for R = numConstraints:-1:1
        XiR(R) = LOSS(W, PsiR{R}, Margins(R), 0);
    end
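    % (Each entry of XiR is the margin violation reported by LOSS for one
    % cached constraint direction PsiR{R}; the last argument presumably
    % toggles between value (0) and gradient (1), matching how LOSS and REG
    % are called below.  The hinge is applied afterwards via max(., 0).)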

    %%%
    % Find the most active constraint
    %
    [Xi, mgrad] = max(XiR);
    Xi = max(Xi, 0);

    PsiClock(mgrad) = 0;
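    % (Resetting PsiClock(mgrad) marks this constraint as recently active; the
    % clock is presumably incremented and used by the caller to prune cached
    % constraints that have not been active for a while.)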

    %%%
    % Evaluate f
    %

    f = C * max(Xi, 0) ...
        + REG(W, K, 0);

    %%%
    % Test for convergence
    %
    objDiff = fold - f;

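    % Armijo-style sufficient-decrease test: the previous step is accepted
    % only if the objective dropped by more than armijo * lambda * ||dfdW||^2.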
    if objDiff > armijo * lambda * (dfdW(:)' * dfdW(:))

        stepcount = stepcount + 1;

        F(stepcount+1) = f;

        sdiff = inf;
        if stepcount >= frame
            sdiff = log(F(stepcount+1-frame) / f);
        end
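        % Convergence is declared once the log-ratio of the objective over the
        % last `frame` accepted steps falls to or below convthresh.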

        if stepcount >= maxsteps
            done = 1;
            stopcriteria = 'MAXSTEPS';
        elseif sdiff <= convthresh
            done = 1;
            stopcriteria = 'CONVERGENCE';
        else
            %%%
            % Not yet converged: recompute the (sub)gradient from the most
            % violated constraint and the regularizer
            dfdW = C * LOSS(W, PsiR{mgrad}, Margins(mgrad), 1) ...
                    + REG(W, K, 1);
        end

        dbprint(3, 'Lambda up!');
        Wold = W;
        lambda = lambdaup * lambda;
        backcount = 0;

    else
        % Backtracking time, drop the learning rate
        if backcount >= maxbackcount
            W = Wold;
            f = fold;
            done = 1;

            stopcriteria = 'BACKTRACK';
        else
            dbprint(3, 'Lambda down!');
            lambda = lambdadown * lambda;
            backcount = backcount + 1;
        end
    end
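    % Whether or not the step was accepted, a (sub)gradient step is now taken
    % with the updated learning rate; dfdW is only recomputed after an
    % accepted step.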

    %%%
    % Take a gradient step
    %
    W = W - lambda * dfdW;

    %%%
    % Project back onto the feasible set
    %

    dbprint(3, 'Projecting onto feasible set');
    W = FEASIBLE(W);
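    % (FEASIBLE is expected to return the nearest point in the feasible set.
    % For a learned metric this is typically a projection onto the PSD cone;
    % a hypothetical implementation consistent with how it is called here:
    %     [V, D] = eig((W + W') / 2);
    %     W      = V * max(D, 0) * V';
    % The actual projection used by the toolbox is defined by the caller.)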
    if done
        break;
    end

end

Diagnostics.f               = F(2:(stepcount+1))';
Diagnostics.stop_criteria   = stopcriteria;
Diagnostics.num_steps       = stepcount;

dbprint(1, '\t%s after %d steps.\n', stopcriteria, stepcount);
end
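% ---------------------------------------------------------------------------
% Hypothetical caller sketch (not part of this file).  mlr_solver reads its
% regularizer, loss, projection, and constraint cache from globals, so the
% caller is assumed to set up something along these lines, with handles that
% match the call signatures used above (last argument 0 = value, 1 = gradient):
%
%   global REG FEASIBLE LOSS PsiR PsiClock;
%   REG      = @(W, K, grad) ...;       % regularizer value or gradient
%   LOSS     = @(W, Psi, M, grad) ...;  % margin violation or its gradient
%   FEASIBLE = @(W) ...;                % projection onto the feasible set
%   PsiR     = {Psi1, Psi2};            % cached constraint directions
%   PsiClock = zeros(length(PsiR), 1);  % constraint activity clocks
%
%   C  = 10;
%   W0 = eye(size(K, 1));
%   [W, Xi, Diag] = mlr_solver(C, Margins, W0, K);
%
% All names above (Psi1, Psi2, Margins, K, C, W0) are placeholders.
% ---------------------------------------------------------------------------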