comparison toolboxes/distance_learning/mlr/util/mlr_admm.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [W, Xi, Diagnostics] = mlr_admm(C, K, Delta, H, Q)
% [W, Xi, D] = mlr_admm(C, K, Delta, H, Q)
%
%   Runs up to ADMM_STEPS (default 10) ADMM iterations: a dual QP solve for
%   W, a projection step for Z, and a scaled residual update for U.
%
%   C >= 0      Slack trade-off parameter
%   K           = data matrix (or kernel)
%   Delta       = array of mean margin values
%   H           = structural kernel matrix
%   Q           = kernel-structure interaction vector
%
%   W (output)  = the learned metric
%   Xi          = 1-slack
%   D           = diagnostics struct with fields
%                   f               objective value after each step (column)
%                   num_steps       number of steps performed
%                   stop_criteria   'CONVERGENCE' or 'MAX STEPS'
%
%   Side effects: the globals ADMM_Z, ADMM_U and RHO are updated in place,
%   so state carries over to the next call.  The globals REG, FEASIBLE,
%   LOSS, STRUCTKERNEL and DUALW are invoked as callables (presumably
%   function handles configured by the caller -- confirm against the
%   training driver).

    global DEBUG REG FEASIBLE LOSS INIT STRUCTKERNEL DUALW;

    %%%
    % Gradient directions for each constraint
    %
    global PsiR;

    global ADMM_Z ADMM_U;

    global ADMM_STEPS;

    global RHO;

    global ADMM_RELTOL;

    numConstraints = length(PsiR);

    Diagnostics = struct(   'f',                [], ...
                            'num_steps',        [], ...
                            'stop_criteria',    []);

    % Convergence settings
    if ~isempty(ADMM_STEPS)
        MAX_ITER = ADMM_STEPS;
    else
        MAX_ITER = 10;
    end
    ABSTOL = 1e-4 * sqrt(numel(ADMM_Z));

    if ~isempty(ADMM_RELTOL)
        RELTOL = ADMM_RELTOL;
    else
        RELTOL = 1e-3;
    end

    % RHO is rescaled by RHO_RESCALE whenever one residual exceeds the
    % other by more than a factor of SCALE_THRESH.
    SCALE_THRESH    = 10;
    RHO_RESCALE     = 2;
    stopcriteria    = 'MAX STEPS';

    % Objective function value per step
    F = zeros(1, MAX_ITER);

    % One dual variable and one Gamma entry per constraint
    alpha = zeros(numConstraints, 1);
    Gamma = zeros(numConstraints, 1);

    for step = 1:MAX_ITER
        % do a w-update
        % dubstep needs:
        %   C       <-- static
        %   RHO     <-- static
        %   H       <-- static
        %   Q       <-- static
        %   Delta   <-- static
        %   Gamma   <-- this one's dynamic

        % Z - U does not depend on the constraint index: compute it once.
        ZminusU = ADMM_Z - ADMM_U;
        for i = 1:numConstraints
            Gamma(i) = STRUCTKERNEL(ZminusU, PsiR{i});
        end

        alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha);

        %%%
        % 3) convert back to W
        %
        W = DUALW(alpha, ADMM_Z, ADMM_U, RHO, K);

        % figure(1), imagesc(W), drawnow;
        % Update Z
        Zold    = ADMM_Z;
        ADMM_Z  = FEASIBLE(W + ADMM_U);

        % Update residuals
        ADMM_U  = ADMM_U + W - ADMM_Z;

        % Compute primal objective
        %   slack term: largest loss over all constraints
        Xi = 0;
        for R = numConstraints:-1:1
            Xi = max(Xi, LOSS(ADMM_Z, PsiR{R}, Delta(R), 0));
        end
        F(step) = C * Xi + REG(ADMM_Z, K, 0);

        % figure(2), loglog(1:step, F(1:step)), xlim([0, MAX_ITER]), drawnow;

        % Test for convergence
        %   N1: primal residual ||W - Z||
        %   N2: dual residual   RHO * ||Zold - Z||
        N1 = norm(W(:) - ADMM_Z(:));
        N2 = RHO * norm(Zold(:) - ADMM_Z(:));

        eps_primal  = ABSTOL + RELTOL * max(norm(W(:)), norm(ADMM_Z(:)));
        eps_dual    = ABSTOL + RELTOL * RHO * norm(ADMM_U(:));

        if N1 < eps_primal && N2 < eps_dual
            stopcriteria = 'CONVERGENCE';
            break;
        end

        % Rebalance the penalty.  U is the scaled dual variable, so it is
        % rescaled inversely to RHO to keep RHO * U unchanged.
        if N1 > SCALE_THRESH * N2
            dbprint(3, sprintf('RHO: %.2e UP %.2e', RHO, RHO * RHO_RESCALE));
            RHO     = RHO * RHO_RESCALE;
            ADMM_U  = ADMM_U / RHO_RESCALE;
        elseif N2 > SCALE_THRESH * N1
            dbprint(3, sprintf('RHO: %.2e DN %.2e', RHO, RHO / RHO_RESCALE));
            RHO     = RHO / RHO_RESCALE;
            ADMM_U  = ADMM_U * RHO_RESCALE;
        end
    end

    %%%
    % Ensure feasibility
    %
    W = FEASIBLE(W);

    %%%
    % Compute the slack of the returned (projected) W
    %
    Xi = 0;
    for R = numConstraints:-1:1
        Xi = max(Xi, LOSS(W, PsiR{R}, Delta(R), 0));
    end

    %%%
    % Update diagnostics
    %
    Diagnostics.f               = F(1:step)';
    Diagnostics.stop_criteria   = stopcriteria;
    Diagnostics.num_steps       = step;

    % Pre-format the message so this call matches the other dbprint call
    % sites in this function (level + a single string).
    dbprint(1, sprintf('\t%s after %d steps.\n', stopcriteria, step));
end
155
function alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha)
% alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha)
%
%   Solves the dual quadratic program for one ADMM w-update and returns
%   the dual variables.
%
%   C       = slack trade-off parameter, passed to the solver as the
%             right-hand side of the single all-ones inequality row
%   RHO     = ADMM penalty parameter
%   H       = quadratic term of the QP
%   Q       = vector subtracted from the linear term
%   Delta   = margin values; its length sets the number of dual variables
%   Gamma   = per-constraint values computed by the caller
%   alpha   = optional; note the value passed in is overwritten by the
%             solver result without being read
%
%   Side effect: entries of the global PsiClock at positions where the
%   solution is strictly positive are reset to 0.

    global PsiClock;

    numDuals = length(Delta);

    % Allow callers to omit the trailing argument.
    if nargin < 7
        alpha = zeros(numDuals, 1);
    end

    %%%
    % 1) construct the QP parameters
    %
    linearTerm  = RHO * (Gamma - Delta) - Q;
    budgetRow   = ones(1, numDuals);

    %%%
    % 2) solve the QP
    %    NOTE(review): argument order presumably follows quadprog
    %    (H, f, A, b, Aeq, beq, lb, ub) -- confirm against qplcprog.
    %
    alpha = qplcprog(H, linearTerm, budgetRow, C, [], [], 0, []);

    %%%
    % 3) update the Psi clock
    %
    activeMask = alpha > 0;
    PsiClock(activeMask) = 0;

end