Mercurial > hg > camir-aes2014
diff toolboxes/distance_learning/mlr/util/mlr_admm.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/distance_learning/mlr/util/mlr_admm.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,181 @@ +function [W, Xi, Diagnostics] = mlr_admm(C, K, Delta, H, Q) +% [W, Xi, D] = mlr_admm(C, Delta, W, X) +% +% C >= 0 Slack trade-off parameter +% K = data matrix (or kernel) +% Delta = array of mean margin values +% H = structural kernel matrix +% Q = kernel-structure interaction vector +% +% W (output) = the learned metric +% Xi = 1-slack +% D = diagnostics + + global DEBUG REG FEASIBLE LOSS INIT STRUCTKERNEL DUALW; + + %%% + % Initialize the gradient directions for each constraint + % + global PsiR; + + global ADMM_Z ADMM_U; + + global ADMM_STEPS; + + global RHO; + + global ADMM_RELTOL; + + + numConstraints = length(PsiR); + + Diagnostics = struct( 'f', [], ... + 'num_steps', [], ... + 'stop_criteria', []); + + + % Convergence settings + if ~isempty(ADMM_STEPS) + MAX_ITER = ADMM_STEPS; + else + MAX_ITER = 10; + end + ABSTOL = 1e-4 * sqrt(numel(ADMM_Z)); + + if ~isempty(ADMM_RELTOL) + RELTOL = ADMM_RELTOL; + else + RELTOL = 1e-3; + end + + SCALE_THRESH = 10; + RHO_RESCALE = 2; + stopcriteria= 'MAX STEPS'; + + % Objective function + F = zeros(1,MAX_ITER); + + % how many constraints + + alpha = zeros(numConstraints, 1); + Gamma = zeros(numConstraints, 1); + + ln1 = 0; + ln2 = 0; + for step = 1:MAX_ITER + % do a w-update + % dubstep needs: + % C <-- static + % RHO <-- static + % H <-- static + % Q <-- static + % Delta <-- static + % Gamma <-- this one's dynamic + + for i = 1:numConstraints + Gamma(i) = STRUCTKERNEL(ADMM_Z-ADMM_U, PsiR{i}); + end + + alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha); + + %%% + % 3) convert back to W + % + W = DUALW(alpha, ADMM_Z, ADMM_U, RHO, K); + + %figure(1), imagesc(W), drawnow; + % Update Z + Zold = ADMM_Z; + ADMM_Z = FEASIBLE(W + ADMM_U); + + % Update residuals + ADMM_U = ADMM_U + W - ADMM_Z; + + % Compute primal objective + % slack term + Xi = 0; + for R = numConstraints:-1:1 + Xi = max(Xi, LOSS(ADMM_Z, PsiR{R}, Delta(R), 0)); + end + F(step) = C * Xi + REG(ADMM_Z, K, 0); + +% figure(2), loglog(1:step, F(1:step)), xlim([0, MAX_ITER]), drawnow; + % Test for convergence + + N1 = norm(W(:)-ADMM_Z(:)); + N2 = RHO * norm(Zold(:) - ADMM_Z(:)); + + eps_primal = ABSTOL + RELTOL * max(norm(W(:)), norm(ADMM_Z(:))); + eps_dual = ABSTOL + RELTOL * RHO * norm(ADMM_U(:)); +% figure(2), loglog(step + (-1:0), [ln1, N1/eps_primal], 'b'), xlim([0, MAX_ITER]), hold('on'); +% figure(2), loglog(step + (-1:0), [ln2, N2/eps_dual], 'r-'), xlim([0, MAX_ITER]), hold('on'), drawnow; +% ln1 = N1/eps_primal; +% ln2 = N2/eps_dual; + if N1 < eps_primal && N2 < eps_dual + stopcriteria = 'CONVERGENCE'; + break; + end + + if N1 > SCALE_THRESH * N2 + dbprint(3, sprintf('RHO: %.2e UP %.2e', RHO, RHO * RHO_RESCALE)); + RHO = RHO * RHO_RESCALE; + ADMM_U = ADMM_U / RHO_RESCALE; + elseif N2 > SCALE_THRESH * N1 + dbprint(3, sprintf('RHO: %.2e DN %.2e', RHO, RHO / RHO_RESCALE)); + RHO = RHO / RHO_RESCALE; + ADMM_U = ADMM_U * RHO_RESCALE; + end + end +% figure(2), hold('off'); + + %%% + % Ensure feasibility + % + W = FEASIBLE(W); + + + %%% + % Compute the slack + % + Xi = 0; + for R = numConstraints:-1:1 + Xi = max(Xi, LOSS(W, PsiR{R}, Delta(R), 0)); + end + + %%% + % Update diagnostics + % + + Diagnostics.f = F(1:step)'; + Diagnostics.stop_criteria = stopcriteria; + Diagnostics.num_steps = step; + + dbprint(1, '\t%s after %d steps.\n', stopcriteria, step); +end + +function alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha) + + global PsiClock; + + m = length(Delta); + + if nargin < 7 + alpha = zeros(m,1); + end + + %%% + % 1) construct the QP parameters + % + b = RHO * (Gamma - Delta) - Q; + + %%% + % 2) solve the QP + % + alpha = qplcprog(H, b, ones(1, m), C, [], [], 0, []); + + %%% + % 3) update the Psi clock + % + PsiClock(alpha > 0) = 0; + +end