function [W, Xi, Diagnostics] = rmlr_admm(C, K, Delta, H, Q, lam)
% [W, Xi, D] = rmlr_admm(C, K, Delta, H, Q, lam)
%
% ADMM solver for the regularized MLR metric-learning subproblem.  The metric
% is split into three copies that are driven to consensus:
%   W  - loss/regularizer copy, obtained from a dual QP (mlr_dual + DUALW)
%   V  - sparsity copy, obtained by THRESH with parameter lam/RHO
%   Z  - consensus copy, mapped through FEASIBLE
%
%   C       >= 0  Slack trade-off parameter
%   K       = data matrix (or kernel)
%   Delta   = array of mean margin values, one per active constraint
%   H       = structural kernel matrix
%   Q       = kernel-structure interaction vector
%   lam     = weight of the column-wise l2 penalty sum(sqrt(sum(W.^2)))
%
%   W (output)  = the learned metric, passed through FEASIBLE before return
%   Xi          = 1-slack of the returned W over the active constraints
%   D           = diagnostics: 'f' (objective per step), 'num_steps',
%                 'stop_criteria' ('CONVERGENCE' or 'MAX STEPS')
%
% Solver state (ADMM_Z, ADMM_V, ADMM_UW, ADMM_UV, RHO) is held in globals; it
% is read before being updated, so each call starts from the state left by
% the previous call.

    global DEBUG REG FEASIBLE LOSS INIT STRUCTKERNEL DUALW THRESH;

    % Gradient directions, one cell per constraint
    global PsiR;

    global ADMM_Z ADMM_V ADMM_UW ADMM_UV;
    global ADMM_STEPS;
    global RHO;

    numConstraints = length(PsiR);

    Diagnostics = struct(   'f',            [], ...
                            'num_steps',    [], ...
                            'stop_criteria', []);

    % Convergence settings
    if ~isempty(ADMM_STEPS)
        MAX_ITER = ADMM_STEPS;
    else
        MAX_ITER = 10;
    end
    ABSTOL       = 1e-4 * sqrt(numel(ADMM_Z));
    RELTOL       = 1e-3;
    SCALE_THRESH = 10;      % residual imbalance ratio that triggers a RHO change
    RHO_RESCALE  = 2;       % multiplicative step applied to RHO
    stopcriteria = 'MAX STEPS';

    % Primal objective value at each step
    F = zeros(1, MAX_ITER);

    alpha = zeros(numConstraints, 1);
    Gamma = zeros(numConstraints, 1);

    for step = 1:MAX_ITER
        %%%
        % W-update.  Only Gamma changes between iterations; C, RHO, H, Q and
        % Delta are fixed for the duration of the inner solve.
        %
        for i = 1:numConstraints
            Gamma(i) = STRUCTKERNEL(ADMM_Z - ADMM_UW, PsiR{i});
        end
        alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha);

        % Convert the dual solution back to a metric
        W = DUALW(alpha, ADMM_Z, ADMM_UW, RHO, K);

        %%%
        % V-update
        %
        ADMM_V = THRESH(ADMM_Z - ADMM_UV, lam / RHO);

        %%%
        % Z-update: average both copies with their scaled duals, then FEASIBLE
        %
        Zold   = ADMM_Z;
        ADMM_Z = FEASIBLE(0.5 * (W + ADMM_V + ADMM_UW + ADMM_UV));

        %%%
        % Scaled dual updates
        %
        ADMM_UW = ADMM_UW + W - ADMM_Z;
        ADMM_UV = ADMM_UV + ADMM_V - ADMM_Z;

        %%%
        % Primal objective: slack term + regularizer + column-wise l2 penalty
        %
        Xi      = max_slack(ADMM_Z, PsiR, Delta);
        F(step) = C * Xi + REG(W, K, 0) + lam * sum(sqrt(sum(W.^2)));

        %%%
        % Convergence test
        %
        % Primal residual stacks both consensus constraints (W = Z, V = Z).
        % Summing them instead would let W - Z and V - Z cancel and report
        % convergence while neither copy agrees with Z.
        N1 = norm([W(:) - ADMM_Z(:); ADMM_V(:) - ADMM_Z(:)]);
        % Dual residual: RHO * (Z - Zold) appears once for each of the two copies
        N2 = RHO * sqrt(2) * norm(ADMM_Z(:) - Zold(:));

        eps_primal = ABSTOL + RELTOL * max(norm([W(:); ADMM_V(:)]), ...
                                           sqrt(2) * norm(ADMM_Z(:)));
        eps_dual   = ABSTOL + RELTOL * RHO * norm([ADMM_UW(:); ADMM_UV(:)]);

        if N1 < eps_primal && N2 < eps_dual
            stopcriteria = 'CONVERGENCE';
            break;
        end

        %%%
        % Residual balancing.  The duals are stored scaled (U = Y / RHO), so
        % EVERY scaled dual must be rescaled inversely whenever RHO changes.
        %
        if N1 > SCALE_THRESH * N2
            dbprint(3, sprintf('RHO: %.2e UP %.2e', RHO, RHO * RHO_RESCALE));
            RHO     = RHO * RHO_RESCALE;
            ADMM_UW = ADMM_UW / RHO_RESCALE;
            ADMM_UV = ADMM_UV / RHO_RESCALE;
        elseif N2 > SCALE_THRESH * N1
            dbprint(3, sprintf('RHO: %.2e DN %.2e', RHO, RHO / RHO_RESCALE));
            RHO     = RHO / RHO_RESCALE;
            ADMM_UW = ADMM_UW * RHO_RESCALE;
            ADMM_UV = ADMM_UV * RHO_RESCALE;
        end
    end

    %%%
    % Ensure feasibility
    %
    W = FEASIBLE(W);

    %%%
    % Compute the slack of the returned metric
    %
    Xi = max_slack(W, PsiR, Delta);

    %%%
    % Update diagnostics
    %
    Diagnostics.f             = F(1:step)';
    Diagnostics.stop_criteria = stopcriteria;
    Diagnostics.num_steps     = step;

    dbprint(1, '\t%s after %d steps.\n', stopcriteria, step);
end

function Xi = max_slack(M, PsiR, Delta)
% Xi = max_slack(M, PsiR, Delta)
%
% 1-slack of metric M: the largest LOSS over all constraints, floored at 0.
    global LOSS;

    Xi = 0;
    for R = length(PsiR):-1:1
        Xi = max(Xi, LOSS(M, PsiR{R}, Delta(R), 0));
    end
end

function alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha)
% alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha)
%
% Solve the dual QP of the W-update step and refresh PsiClock.
%
%   C       right-hand side of the single linear constraint ones(1,m)*alpha
%           passed to qplcprog (presumably sum(alpha) <= C -- confirm)
%   RHO     current ADMM penalty parameter
%   H       quadratic term of the QP
%   Q       subtracted from the linear term
%   Delta   one value per constraint; its length fixes the QP size m
%   Gamma   one value per constraint, combined with Delta in the linear term
%   alpha   previous dual solution (optional)
%
% Returns the qplcprog solution, and zeroes the PsiClock entry of every
% constraint whose dual weight is positive.
%
% NOTE(review): the incoming alpha is never handed to qplcprog, so it does
% not warm-start the solve; it is overwritten unconditionally below.

    global PsiClock;

    numDuals = length(Delta);

    if nargin < 7
        alpha = zeros(numDuals, 1);
    end

    % Linear term of the QP
    linearTerm = RHO * (Gamma - Delta) - Q;

    % One linear constraint over all dual variables, lower bound 0
    sumRow = ones(1, numDuals);
    alpha  = qplcprog(H, linearTerm, sumRow, C, [], [], 0, []);

    % Reset the clock of every constraint that received positive weight
    isActive = alpha > 0;
    PsiClock(isActive) = 0;
end