Mercurial > hg > camir-aes2014
comparison toolboxes/distance_learning/mlr/util/rmlr_admm.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [W, Xi, Diagnostics] = rmlr_admm(C, K, Delta, H, Q, lam)
% [W, Xi, D] = rmlr_admm(C, K, Delta, H, Q, lam)
%
% ADMM solver for the regularized MLR subproblem over the current set of
% constraints PsiR. Two copies of the variable, W and V, are each tied to a
% consensus variable Z (constraints W = Z and V = Z) with scaled dual
% variables ADMM_UW and ADMM_UV:
%   - W-update: dual QP (mlr_dual) converted back via DUALW
%   - V-update: THRESH(Z - U_V, lam/RHO)  (presumably the prox of the
%               column-norm penalty below -- NOTE(review): confirm THRESH)
%   - Z-update: FEASIBLE projection of the average of both copies + duals
%
% C      >= 0 Slack trade-off parameter
% K      = data matrix (or kernel)
% Delta  = array of mean margin values
% H      = structural kernel matrix
% Q      = kernel-structure interaction vector
% lam    = weight of the penalty lam * sum(sqrt(sum(W.^2))) (sum of the
%          column 2-norms of W)
%
% W  (output) = the learned metric, projected with FEASIBLE
% Xi          = 1-slack: max over constraints of LOSS(W, PsiR{R}, Delta(R), 0)
% D           = diagnostics struct with fields
%                 f             objective value after each step
%                 num_steps     number of ADMM steps taken
%                 stop_criteria 'CONVERGENCE' or 'MAX STEPS'
%
% Side effects: reads and updates the globals ADMM_Z, ADMM_V, ADMM_UW,
% ADMM_UV and RHO, so the ADMM state carries over between calls.

global DEBUG REG FEASIBLE LOSS INIT STRUCTKERNEL DUALW THRESH;

%%%
% Gradient directions, one per constraint
%
global PsiR;

% Consensus variable, auxiliary copy, and scaled duals for W = Z, V = Z
global ADMM_Z ADMM_V ADMM_UW ADMM_UV;

global ADMM_STEPS;

% ADMM penalty parameter; adapted below and kept across calls
global RHO;

numConstraints = length(PsiR);

Diagnostics = struct( 'f', [], ...
                      'num_steps', [], ...
                      'stop_criteria', []);

% Convergence settings
if ~isempty(ADMM_STEPS)
    MAX_ITER = ADMM_STEPS;
else
    MAX_ITER = 10;
end
ABSTOL          = 1e-4 * sqrt(numel(ADMM_Z));
RELTOL          = 1e-3;
SCALE_THRESH    = 10;
RHO_RESCALE     = 2;
stopcriteria    = 'MAX STEPS';

% Objective value after each step
F = zeros(1, MAX_ITER);

alpha = zeros(numConstraints, 1);
Gamma = zeros(numConstraints, 1);

for step = 1:MAX_ITER
    %%%
    % 1) W-update. Only Gamma depends on the ADMM state; C, H, Q, Delta are
    %    fixed for this call (RHO may change at the end of a step).
    %
    for i = 1:numConstraints
        Gamma(i) = STRUCTKERNEL(ADMM_Z - ADMM_UW, PsiR{i});
    end
    alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha);

    % Convert the dual solution back to W
    W = DUALW(alpha, ADMM_Z, ADMM_UW, RHO, K);

    %%%
    % 2) V-update
    %
    ADMM_V = THRESH(ADMM_Z - ADMM_UV, lam / RHO);

    %%%
    % 3) Z-update: project the average of both copies plus their duals
    %
    Zold   = ADMM_Z;
    ADMM_Z = FEASIBLE(0.5 * (W + ADMM_V + ADMM_UW + ADMM_UV));

    %%%
    % 4) Scaled dual updates
    %
    ADMM_UW = ADMM_UW + W      - ADMM_Z;
    ADMM_UV = ADMM_UV + ADMM_V - ADMM_Z;

    %%%
    % Primal objective (slack term + regularizer + column-norm penalty)
    %
    Xi = 0;
    for R = numConstraints:-1:1
        Xi = max(Xi, LOSS(ADMM_Z, PsiR{R}, Delta(R), 0));
    end
    F(step) = C * Xi + REG(W, K, 0) + lam * sum(sqrt(sum(W.^2)));

    %%%
    % Convergence test for the stacked constraint [W; V] - [Z; Z] = 0.
    % The primal residual must measure each block separately: the norm of
    % the sum (W - Z) + (V - Z) can cancel, and when FEASIBLE is the
    % identity it is exactly zero from the second step on.
    %
    N1 = sqrt(norm(W(:) - ADMM_Z(:))^2 + norm(ADMM_V(:) - ADMM_Z(:))^2);
    % Z enters both blocks, so the dual residual is sqrt(2) * RHO * ||dZ||
    N2 = sqrt(2) * RHO * norm(Zold(:) - ADMM_Z(:));

    eps_primal = ABSTOL + RELTOL * max(sqrt(norm(W(:))^2 + norm(ADMM_V(:))^2), ...
                                       sqrt(2) * norm(ADMM_Z(:)));
    eps_dual   = ABSTOL + RELTOL * RHO * ...
                 sqrt(norm(ADMM_UW(:))^2 + norm(ADMM_UV(:))^2);

    if N1 < eps_primal && N2 < eps_dual
        stopcriteria = 'CONVERGENCE';
        break;
    end

    %%%
    % Residual balancing: adapt RHO. Both duals are stored in scaled form
    % (U = Y / RHO), so BOTH must be rescaled whenever RHO changes.
    %
    if N1 > SCALE_THRESH * N2
        dbprint(3, sprintf('RHO: %.2e UP %.2e', RHO, RHO * RHO_RESCALE));
        RHO     = RHO * RHO_RESCALE;
        ADMM_UW = ADMM_UW / RHO_RESCALE;
        ADMM_UV = ADMM_UV / RHO_RESCALE;
    elseif N2 > SCALE_THRESH * N1
        dbprint(3, sprintf('RHO: %.2e DN %.2e', RHO, RHO / RHO_RESCALE));
        RHO     = RHO / RHO_RESCALE;
        ADMM_UW = ADMM_UW * RHO_RESCALE;
        ADMM_UV = ADMM_UV * RHO_RESCALE;
    end
end

%%%
% Ensure feasibility
%
W = FEASIBLE(W);

%%%
% Compute the slack
%
Xi = 0;
for R = numConstraints:-1:1
    Xi = max(Xi, LOSS(W, PsiR{R}, Delta(R), 0));
end

%%%
% Update diagnostics
%
Diagnostics.f             = F(1:step)';
Diagnostics.stop_criteria = stopcriteria;
Diagnostics.num_steps     = step;

dbprint(1, '\t%s after %d steps.\n', stopcriteria, step);
end
162 | |
function alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha)
% alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha)
%
% Solve the dual QP of the ADMM W-update with linear term
%     b = RHO * (Gamma - Delta) - Q,
% the single inequality ones(1, m) * alpha <= C, and alpha >= 0.
% NOTE(review): this reading assumes qplcprog takes quadprog-style
% arguments (H, f, A, b, Aeq, beq, lb, ub) -- confirm against qplcprog.
%
% C      >= 0 slack trade-off parameter (bound on sum(alpha))
% RHO    ADMM penalty parameter
% H      structural kernel matrix (quadratic term)
% Q      kernel-structure interaction vector
% Delta  margin values, one per constraint (m = length(Delta))
% Gamma  per-constraint values of STRUCTKERNEL(Z - U_W, PsiR{i})
% alpha  previous solution. It is accepted for interface compatibility
%        only: qplcprog is called without a starting point, so the incoming
%        value is never read.
%
% Side effect: sets the global PsiClock to 0 for every constraint with a
% positive dual weight in the new solution.

global PsiClock;

m = length(Delta);

%%%
% 1) construct the QP parameters
%
b = RHO * (Gamma - Delta) - Q;

%%%
% 2) solve the QP
%
alpha = qplcprog(H, b, ones(1, m), C, [], [], 0, []);

%%%
% 3) update the Psi clock for constraints active in the solution
%
PsiClock(alpha > 0) = 0;

end