annotate toolboxes/distance_learning/mlr/util/rmlr_admm.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [W, Xi, Diagnostics] = rmlr_admm(C, K, Delta, H, Q, lam)
% [W, Xi, Diagnostics] = rmlr_admm(C, K, Delta, H, Q, lam)
%
%   Runs ADMM iterations over the splitting variables stored in the globals
%   ADMM_Z, ADMM_V, ADMM_UW and ADMM_UV and returns the resulting metric.
%
%   C       >= 0    Slack trade-off parameter
%   K       =       data matrix (or kernel)
%   Delta   =       array of mean margin values
%   H       =       structural kernel matrix
%   Q       =       kernel-structure interaction vector
%   lam     =       weight of the sum-of-column-norms penalty on W; the
%                   V-update thresholds with lam / RHO
%
%   W (output)  =   the learned metric (passed through FEASIBLE on exit)
%   Xi          =   1-slack: the largest LOSS over all constraints in PsiR
%   Diagnostics =   struct with fields
%                       f               objective value at each step (column)
%                       num_steps       number of steps performed
%                       stop_criteria   'CONVERGENCE' or 'MAX STEPS'
%
%   Side effects: ADMM_Z, ADMM_V, ADMM_UW, ADMM_UV and RHO are globals and
%   are modified by this function. The dual solve (mlr_dual) additionally
%   writes to the global PsiClock.
%
%   The number of iterations is ADMM_STEPS when that global is non-empty,
%   and 10 otherwise.

    global DEBUG REG FEASIBLE LOSS INIT STRUCTKERNEL DUALW THRESH;

    %%%
    % Gradient directions for each constraint
    %
    global PsiR;

    % ADMM state: consensus variable Z, thresholded copy V, and the scaled
    % dual variables for the constraints W = Z and V = Z
    global ADMM_Z ADMM_V ADMM_UW ADMM_UV;

    global ADMM_STEPS;

    global RHO;

    numConstraints = length(PsiR);

    Diagnostics = struct(   'f',                [], ...
                            'num_steps',        [], ...
                            'stop_criteria',    []);

    % Convergence settings
    if ~isempty(ADMM_STEPS)
        MAX_ITER = ADMM_STEPS;
    else
        MAX_ITER = 10;
    end
    ABSTOL          = 1e-4 * sqrt(numel(ADMM_Z));
    RELTOL          = 1e-3;
    SCALE_THRESH    = 10;   % residual imbalance that triggers a RHO change
    RHO_RESCALE     = 2;    % multiplicative step applied to RHO
    stopcriteria    = 'MAX STEPS';

    % Objective function value at each step
    F = zeros(1, MAX_ITER);

    % One dual variable and one linear-term entry per constraint
    alpha = zeros(numConstraints, 1);
    Gamma = zeros(numConstraints, 1);

    for step = 1:MAX_ITER
        %%%
        % W-update, solved in the dual.
        %   C, RHO, H, Q and Delta are fixed within a step;
        %   Gamma depends on the current Z and UW and is rebuilt each time.
        %
        for i = 1:numConstraints
            Gamma(i) = STRUCTKERNEL(ADMM_Z - ADMM_UW, PsiR{i});
        end
        alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha);

        % Convert the dual solution back to W
        W = DUALW(alpha, ADMM_Z, ADMM_UW, RHO, K);

        % Update V
        ADMM_V = THRESH(ADMM_Z - ADMM_UV, lam / RHO);

        % Update Z (keep the previous value for the dual residual)
        Zold    = ADMM_Z;
        ADMM_Z  = FEASIBLE(0.5 * (W + ADMM_V + ADMM_UW + ADMM_UV));

        % Update the scaled dual variables
        ADMM_UW = ADMM_UW + W       - ADMM_Z;
        ADMM_UV = ADMM_UV + ADMM_V  - ADMM_Z;

        % Compute primal objective: slack term first
        Xi = 0;
        for R = numConstraints:-1:1
            Xi = max(Xi, LOSS(ADMM_Z, PsiR{R}, Delta(R), 0));
        end
        F(step) = C * Xi + REG(W, K, 0) + lam * sum(sqrt(sum(W.^2)));

        %%%
        % Test for convergence (marked WIP by the original author)
        %
        N1 = norm(ADMM_V(:) + W(:) - 2 * ADMM_Z(:));    % primal residual
        N2 = RHO * norm(2 * (Zold(:) - ADMM_Z(:)));     % dual residual

        eps_primal  = ABSTOL + RELTOL * max(norm(W(:)), norm(ADMM_Z(:)));
        % NOTE(review): only ADMM_UW enters the dual tolerance; ADMM_UV is
        % not included. Left as in the original -- confirm this is intended.
        eps_dual    = ABSTOL + RELTOL * RHO * norm(ADMM_UW(:));

        if N1 < eps_primal && N2 < eps_dual
            stopcriteria = 'CONVERGENCE';
            break;
        end

        %%%
        % Residual balancing. ADMM_UW and ADMM_UV are both scaled duals
        % (dual / RHO), so whenever RHO changes they must be rescaled by
        % the inverse factor to keep representing the same dual variables.
        %
        if N1 > SCALE_THRESH * N2
            dbprint(3, sprintf('RHO: %.2e UP %.2e', RHO, RHO * RHO_RESCALE));
            RHO     = RHO * RHO_RESCALE;
            ADMM_UW = ADMM_UW / RHO_RESCALE;
            ADMM_UV = ADMM_UV / RHO_RESCALE;
        elseif N2 > SCALE_THRESH * N1
            dbprint(3, sprintf('RHO: %.2e DN %.2e', RHO, RHO / RHO_RESCALE));
            RHO     = RHO / RHO_RESCALE;
            ADMM_UW = ADMM_UW * RHO_RESCALE;
            ADMM_UV = ADMM_UV * RHO_RESCALE;
        end
    end

    %%%
    % Ensure feasibility
    %
    W = FEASIBLE(W);

    %%%
    % Compute the slack for the returned W
    %
    Xi = 0;
    for R = numConstraints:-1:1
        Xi = max(Xi, LOSS(W, PsiR{R}, Delta(R), 0));
    end

    %%%
    % Update diagnostics
    %
    Diagnostics.f               = F(1:step)';
    Diagnostics.stop_criteria   = stopcriteria;
    Diagnostics.num_steps       = step;

    dbprint(1, '\t%s after %d steps.\n', stopcriteria, step);
end
wolffd@0 162
function alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha)
% alpha = mlr_dual(C, RHO, H, Q, Delta, Gamma, alpha)
%
%   Solves the quadratic program for the dual variables by calling qplcprog,
%   then zeroes the global PsiClock entries for every constraint whose dual
%   variable came out strictly positive.
%
%   C       =   passed to qplcprog as the right-hand side paired with the
%               all-ones row
%   RHO     =   scales (Gamma - Delta) in the linear term
%   H       =   quadratic term handed to qplcprog
%   Q       =   subtracted from the linear term
%   Delta   =   per-constraint values; its length sets the problem size
%   Gamma   =   per-constraint values, same size as Delta
%   alpha   =   optional; the incoming value is not read before being
%               replaced by the qplcprog result
%
%   NOTE(review): the qplcprog argument order is presumably the quadprog
%   style (H, f, A, b, Aeq, beq, lb, ub) -- confirm against qplcprog.

    global PsiClock;

    numDuals = length(Delta);

    % Supply a default when the caller omits the seventh argument
    if nargin < 7
        alpha = zeros(numDuals, 1);
    end

    %%%
    % 1) construct the QP parameters
    %
    linearTerm  = RHO * (Gamma - Delta) - Q;
    onesRow     = ones(1, numDuals);

    %%%
    % 2) solve the QP
    %
    alpha = qplcprog(H, linearTerm, onesRow, C, [], [], 0, []);

    %%%
    % 3) update the Psi clock: reset it for constraints with positive alpha
    %
    isPositive              = alpha > 0;
    PsiClock(isPositive)    = 0;

end