function [W, Xi, Diagnostics] = mlr_admm(C, K, Delta, H, Q)
% [W, Xi, D] = mlr_admm(C, K, Delta, H, Q)
%
% Learns a metric W for the current working set of constraints (global
% PsiR) by ADMM with an adaptively rescaled penalty parameter RHO.
%
%   C  >= 0  Slack trade-off parameter
%   K  =     data matrix (or kernel)
%   Delta =  array of mean margin values, one entry per constraint in PsiR
%   H  =     structural kernel matrix (passed through to mlr_dual)
%   Q  =     kernel-structure interaction vector (passed through to mlr_dual)
%
%   W (output) = the learned metric, projected with FEASIBLE
%   Xi         = 1-slack of the returned W:
%                max(0, max_R LOSS(W, PsiR{R}, Delta(R), 0))
%   D          = diagnostics struct with fields
%                  f             - C*Xi + REG(Z, K, 0) at each iteration
%                  num_steps     - number of ADMM iterations performed
%                  stop_criteria - 'CONVERGENCE' or 'MAX STEPS'
%
% Reads globals:   PsiR, ADMM_STEPS, ADMM_RELTOL and the function handles
%                  STRUCTKERNEL, DUALW, FEASIBLE, LOSS, REG.
% Updates globals: ADMM_Z, ADMM_U and RHO. These are not reset here, so
%                  their values carry over to the next call.

    global DEBUG REG FEASIBLE LOSS INIT STRUCTKERNEL DUALW;

    % Working set of constraints; PsiR{i} is consumed by STRUCTKERNEL and LOSS
    global PsiR;

    % ADMM state: consensus variable Z and scaled dual variable U
    global ADMM_Z ADMM_U;

    % Optional overrides for the iteration budget and relative tolerance
    global ADMM_STEPS ADMM_RELTOL;

    % ADMM penalty parameter (adapted below and kept for the next call)
    global RHO;

    numConstraints = length(PsiR);

    Diagnostics = struct(   'f',                [], ...
                            'num_steps',        [], ...
                            'stop_criteria',    []);

    %%%
    % Convergence settings
    %
    if ~isempty(ADMM_STEPS)
        % At least one iteration is required: W and step are only assigned
        % inside the loop and are used after it.
        MAX_ITER = max(1, ADMM_STEPS);
    else
        MAX_ITER = 10;
    end

    % Absolute tolerance scaled by sqrt of the number of entries in Z
    ABSTOL = 1e-4 * sqrt(numel(ADMM_Z));

    if ~isempty(ADMM_RELTOL)
        RELTOL = ADMM_RELTOL;
    else
        RELTOL = 1e-3;
    end

    % Residual balancing: when one residual exceeds the other by more than
    % SCALE_THRESH, RHO is multiplied or divided by RHO_RESCALE.
    SCALE_THRESH    = 10;
    RHO_RESCALE     = 2;
    stopcriteria    = 'MAX STEPS';

    % Objective value recorded at each iteration
    F = zeros(1, MAX_ITER);

    % One dual weight per constraint (fed back to mlr_dual each step), and
    % the per-constraint STRUCTKERNEL values of Z - U
    alpha = zeros(numConstraints, 1);
    Gamma = zeros(numConstraints, 1);

    for step = 1:MAX_ITER
        %%%
        % 1) W-update. Within a step C, RHO, H, Q and Delta are fixed;
        %    only Gamma depends on the current Z and U.
        %
        for i = 1:numConstraints
            Gamma(i) = STRUCTKERNEL(ADMM_Z - ADMM_U, PsiR{i});
        end

        alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha);

        % Convert the dual solution back to W
        W = DUALW(alpha, ADMM_Z, ADMM_U, RHO, K);

        %%%
        % 2) Z-update: map W + U through FEASIBLE
        %
        Zold   = ADMM_Z;
        ADMM_Z = FEASIBLE(W + ADMM_U);

        %%%
        % 3) Dual update: accumulate the residual W - Z into U
        %
        ADMM_U = ADMM_U + W - ADMM_Z;

        %%%
        % Objective at Z: C * (1-slack) + regularizer
        %
        Xi = 0;
        for R = numConstraints:-1:1
            Xi = max(Xi, LOSS(ADMM_Z, PsiR{R}, Delta(R), 0));
        end
        F(step) = C * Xi + REG(ADMM_Z, K, 0);

        %%%
        % Convergence test: N1 = ||W - Z|| (primal residual),
        % N2 = RHO * ||Z - Zold|| (dual residual)
        %
        N1 = norm(W(:) - ADMM_Z(:));
        N2 = RHO * norm(Zold(:) - ADMM_Z(:));

        eps_primal  = ABSTOL + RELTOL * max(norm(W(:)), norm(ADMM_Z(:)));
        eps_dual    = ABSTOL + RELTOL * RHO * norm(ADMM_U(:));

        if N1 < eps_primal && N2 < eps_dual
            stopcriteria = 'CONVERGENCE';
            break;
        end

        %%%
        % Adapt RHO. U is used here as a scaled dual variable, so it is
        % rescaled by the inverse factor whenever RHO changes.
        %
        if N1 > SCALE_THRESH * N2
            dbprint(3, sprintf('RHO: %.2e UP %.2e', RHO, RHO * RHO_RESCALE));
            RHO    = RHO * RHO_RESCALE;
            ADMM_U = ADMM_U / RHO_RESCALE;
        elseif N2 > SCALE_THRESH * N1
            dbprint(3, sprintf('RHO: %.2e DN %.2e', RHO, RHO / RHO_RESCALE));
            RHO    = RHO / RHO_RESCALE;
            ADMM_U = ADMM_U * RHO_RESCALE;
        end
    end

    %%%
    % Ensure feasibility of the returned metric
    %
    W = FEASIBLE(W);

    %%%
    % Compute the slack of the returned W
    %
    Xi = 0;
    for R = numConstraints:-1:1
        Xi = max(Xi, LOSS(W, PsiR{R}, Delta(R), 0));
    end

    %%%
    % Update diagnostics
    %
    Diagnostics.f             = F(1:step)';
    Diagnostics.stop_criteria = stopcriteria;
    Diagnostics.num_steps     = step;

    dbprint(1, '\t%s after %d steps.\n', stopcriteria, step);
end

function alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha)
% alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha)
%
% Solves the dual QP used by the ADMM W-update in mlr_admm and returns
% one dual weight per constraint.
%
%   C      >= 0  bound on sum(alpha)
%   RHO    =     current ADMM penalty parameter
%   H      =     structural kernel matrix (QP quadratic term)
%   Q      =     kernel-structure interaction vector
%   Delta  =     margin values, one per constraint (row or column)
%   Gamma  =     per-constraint STRUCTKERNEL values of Z - U (row or column)
%   alpha  =     previous solution; NOTE(review): currently NOT passed to
%                the solver (no warm start), kept for caller compatibility
%
% The QP is handed to qplcprog as
%   qplcprog(H, b, ones(1, m), C, [], [], 0, [])
% NOTE(review): assuming qplcprog follows quadprog's argument order
% (H, f, A, b, Aeq, beq, lb, ub), this is
%   min 0.5*alpha'*H*alpha + b'*alpha  s.t.  sum(alpha) <= C, alpha >= 0.
% Confirm against the qplcprog in use.
%
% Side effect: sets the global PsiClock to 0 at every index where the
% returned alpha is positive.

    global PsiClock;

    m = length(Delta);

    %%%
    % 1) construct the QP parameters
    %
    % Flatten to columns: Gamma is built as a column by mlr_admm, but a row
    % Delta (or Q) would otherwise implicitly expand the difference into an
    % m-by-m matrix instead of an m-vector.
    b = RHO * (Gamma(:) - Delta(:)) - Q(:);

    %%%
    % 2) solve the QP
    %
    alpha = qplcprog(H, b, ones(1, m), C, [], [], 0, []);

    %%%
    % 3) update the Psi clock for constraints with positive dual weight
    %
    PsiClock(alpha > 0) = 0;
end