diff toolboxes/distance_learning/mlr/util/rmlr_admm.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/distance_learning/mlr/util/rmlr_admm.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,188 @@
+function [W, Xi, Diagnostics] = rmlr_admm(C, K, Delta, H, Q, lam)
+% [W, Xi, D] = mlr_admm(C, Delta, W, X)
+%
+%   C       >= 0    Slack trade-off parameter
+%   K       =       data matrix (or kernel)
+%   Delta   =       array of mean margin values
+%   H       =       structural kernel matrix
+%   Q       =       kernel-structure interaction vector
+%
+%   W (output)  =   the learned metric
+%   Xi          =   1-slack
+%   D           =   diagnostics
+
+global DEBUG REG FEASIBLE LOSS INIT STRUCTKERNEL DUALW THRESH;
+
+%%%
+% Initialize the gradient directions for each constraint
+%
+global PsiR;
+
+global ADMM_Z ADMM_V ADMM_UW ADMM_UV;
+
+global ADMM_STEPS;
+
+global RHO;
+
+numConstraints = length(PsiR);
+
+Diagnostics = struct(   'f',                [], ...
+    'num_steps',        [], ...
+    'stop_criteria',    []);
+
+
+% Convergence settings
+if ~isempty(ADMM_STEPS)
+    MAX_ITER = ADMM_STEPS;
+else
+    MAX_ITER = 10;
+end
+ABSTOL      = 1e-4 * sqrt(numel(ADMM_Z));
+RELTOL      = 1e-3;
+SCALE_THRESH    = 10;
+RHO_RESCALE     = 2;
+stopcriteria= 'MAX STEPS';
+
+% Objective function
+F           = zeros(1,MAX_ITER);
+
+% how many constraints
+
+alpha           = zeros(numConstraints,  1);
+Gamma           = zeros(numConstraints,  1);
+
+ln1 = 0;
+ln2 = 0;
+
+% figure(2)
+% hold off
+% plot(0)
+% delete(abc)
+% delete(abc2)
+for step = 1:MAX_ITER
+    % do a w-update
+    % dubstep needs:
+    %   C       <-- static
+    %   RHO     <-- static
+    %   H       <-- static
+    %   Q       <-- static
+    %   Delta   <-- static
+    %   Gamma   <-- this one's dynamic
+    
+    for i = 1:numConstraints
+        Gamma(i) = STRUCTKERNEL(ADMM_Z-ADMM_UW, PsiR{i});
+    end
+    %     d = length(K);
+    alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha);
+    
+    %%%
+    % 3) convert back to W
+    %
+    W = DUALW(alpha, ADMM_Z, ADMM_UW, RHO, K);
+    
+    %   figure(1), imagesc(W), drawnow;
+    
+    % Update V
+    ADMM_V = THRESH(ADMM_Z - ADMM_UV, lam/RHO);
+    
+    % Update Z
+    Zold    = ADMM_Z;
+    ADMM_Z  = FEASIBLE(0.5* (W + ADMM_V + ADMM_UW + ADMM_UV));
+    
+    % Update residuals
+    ADMM_UW  = ADMM_UW + W - ADMM_Z;
+    ADMM_UV  = ADMM_UV + ADMM_V - ADMM_Z;
+    
+    % Compute primal objective
+    %   slack term
+    Xi      = 0;
+    for R = numConstraints:-1:1
+        Xi = max(Xi, LOSS(ADMM_Z, PsiR{R}, Delta(R), 0));
+    end
+    F(step)     = C * Xi + REG(W, K, 0) + lam * sum(sqrt(sum(W.^2)));
+    
+%     figure(2), loglog(1:step, F(1:step)), xlim([0, MAX_ITER]), drawnow;
+    % Test for convergence
+    
+    %WIP
+    N1          = norm(ADMM_V(:) + W(:) - 2* ADMM_Z(:));
+    N2          = RHO * norm(2* (Zold(:) - ADMM_Z(:)));
+    
+    eps_primal = ABSTOL + RELTOL * max(norm(W(:)), norm(ADMM_Z(:)));
+    eps_dual   = ABSTOL + RELTOL * RHO * norm(ADMM_UW(:));
+    %end WIP
+    
+    
+%            figure(2), loglog(step + (-1:0), [ln1, N1/eps_primal], 'b'), xlim([0, MAX_ITER]), hold('on');
+%            figure(2), loglog(step + (-1:0), [ln2, N2/eps_dual], 'r-'), xlim([0, MAX_ITER]), hold('on'), drawnow;
+%           ln1 = N1/eps_primal;
+%           ln2 = N2/eps_dual;
+    
+    if N1 < eps_primal && N2 < eps_dual
+        stopcriteria = 'CONVERGENCE';
+        break;
+    end
+    
+    if N1 > SCALE_THRESH * N2
+         dbprint(3, sprintf('RHO: %.2e UP %.2e', RHO, RHO * RHO_RESCALE));
+        RHO = RHO * RHO_RESCALE;
+        ADMM_UW  = ADMM_UW / RHO_RESCALE;
+    elseif N2 > SCALE_THRESH * N1
+         dbprint(3, sprintf('RHO: %.2e DN %.2e', RHO, RHO / RHO_RESCALE));
+        RHO = RHO / RHO_RESCALE;
+        ADMM_UW  = ADMM_UW * RHO_RESCALE;
+    end
+end
+%     figure(2), hold('off');
+
+%%%
+% Ensure feasibility
+%
+W = FEASIBLE(W);
+
+
+%%%
+% Compute the slack
+%
+Xi = 0;
+for R = numConstraints:-1:1
+    Xi  = max(Xi, LOSS(W, PsiR{R}, Delta(R), 0));
+end
+
+%%%
+% Update diagnostics
+%
+
+Diagnostics.f               = F(1:step)';
+Diagnostics.stop_criteria   = stopcriteria;
+Diagnostics.num_steps       = step;
+
+dbprint(1, '\t%s after %d steps.\n', stopcriteria, step);
+end
+
+function alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha)
+
+global PsiClock;
+
+m = length(Delta);
+
+if nargin < 7
+    alpha = zeros(m,1);
+end
+
+%%%
+% 1) construct the QP parameters
+%
+b = RHO * (Gamma - Delta) - Q;
+
+%%%
+% 2) solve the QP
+%
+alpha = qplcprog(H, b, ones(1, m), C, [], [], 0, []);
+
+%%%
+% 3) update the Psi clock
+%
+PsiClock(alpha > 0)  = 0;
+
+end