annotate matlab/GAP/ArgminOperL2Constrained.m @ 18:a8ff9a881d2f

GAP test almost working. For some data the results are not the same because of representation error, so the test doesn't fully work for now. But the results seem to be accurate.
author nikcleju
date Mon, 07 Nov 2011 17:48:05 +0000
parents
children
rev   line source
function [xhat, arepr, lagmult] = ArgminOperL2Constrained(y, M, MH, Omega, OmegaH, Lambdahat, xinit, ilagmult, params)

%
% This function aims to compute
%   xhat = argmin || Omega(Lambdahat, :) * x ||_2 subject to || y - M*x ||_2 <= epsilon.
% arepr is the analysis representation corresponding to Lambdahat, i.e.,
%   arepr = Omega(Lambdahat, :) * xhat.
% The function also returns the Lagrange multiplier used to compute xhat.
%
% The constrained problem is handled through the penalized problem
%   min_x || y - M*x ||_2^2 + lagmult * || Omega(Lambdahat, :) * x ||_2^2.
% lagmult is multiplied (constraint satisfied) or divided (constraint violated)
% by a factor of 2 until the fidelity constraint || y - M*xhat ||_2 <= epsilon
% switches state, or until lagmult leaves the range [1e-4, 1e5].
%
% Inputs:
%   y         : observation/measurements of an unknown vector x0. It is equal to M*x0 + noise.
%   M         : measurement matrix, or a function handle computing M*x
%   MH        : M', the conjugate transpose of M (matrix or function handle)
%   Omega     : analysis operator (matrix or function handle)
%   OmegaH    : Omega', the conjugate transpose of Omega. Also, synthesis operator.
%   Lambdahat : an index set indicating some rows of Omega.
%   xinit     : initial estimate used by the conjugate gradient algorithm.
%   ilagmult  : initial Lagrange multiplier; it is clamped to [1e-4, 1e5] before use.
%   params    : parameters
%     params.noise_level         : this corresponds to epsilon above.
%     params.max_inner_iteration : `maximum' number of iterations in the conjugate gradient
%                                  method (the counter restarts whenever lagmult changes).
%     params.l2_accuracy         : the l2 accuracy parameter used in the conjugate gradient method.
%     params.l2solver            : if the value is 'pseudoinverse' and both M and Omega are
%                                  double matrices, direct matrix computation (backslash) is
%                                  used. Otherwise, the conjugate gradient method is used.
%
% Outputs:
%   xhat    : the estimate.
%   arepr   : Omega(Lambdahat, :) * xhat.
%   lagmult : the Lagrange multiplier xhat was computed with. When the search stops
%             because lagmult left [1e-4, 1e5], the returned value is that out-of-range
%             value (one factor away from the one that produced xhat).
%

lagmultmax = 1e5;
lagmultmin = 1e-4;
lagmultfactor = 2;
accuracy_adjustment_exponent = 4/5;
lagmult = max(min(ilagmult, lagmultmax), lagmultmin);
was_infeasible = 0;
was_feasible = 0;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Computation done using direct matrix computation from matlab. (no conjugate gradient method.)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if strcmp(params.l2solver, 'pseudoinverse') && isa(M, 'double') && isa(Omega, 'double')
    OmegaL = Omega(Lambdahat, :);              % rows selected once; loop invariant
    rhs = [y; zeros(size(OmegaL, 1), 1)];      % one zero per selected row
    while true
        % Least-squares solution of [M; sqrt(lagmult)*OmegaL] * x = [y; 0],
        % i.e. the minimizer of ||y - M*x||^2 + lagmult*||OmegaL*x||^2.
        xhat = [M; sqrt(lagmult)*OmegaL] \ rhs;
        temp = norm(y - M*xhat, 2);
        %disp(['fidelity error=', num2str(temp), ' lagmult=', num2str(lagmult)]);
        if temp <= params.noise_level
            was_feasible = 1;
            if was_infeasible == 1
                break;
            end
            lagmult = lagmult*lagmultfactor;
        elseif temp > params.noise_level
            was_infeasible = 1;
            % Undo the last increase first, so that the returned lagmult is the one
            % that produced the restored solution (same order as the CG branch below).
            lagmult = lagmult/lagmultfactor;
            if was_feasible == 1
                xhat = xprev;
                break;
            end
        end
        if lagmult < lagmultmin || lagmult > lagmultmax
            break;
        end
        xprev = xhat;
    end
    arepr = OmegaL * xhat;
    return;
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Computation using conjugate gradient method.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Normal equations: (M'*M + lagmult*Omega(L,:)'*Omega(L,:)) * x = M'*y.
if isa(MH, 'function_handle')
    b = MH(y);
else
    b = MH * y;
end
norm_b = norm(b, 2);
xhat = xinit;
xprev = xinit;
residual = TheHermitianMatrix(xhat, M, MH, Omega, OmegaH, Lambdahat, lagmult) - b;
direction = -residual;
iter = 0;

while iter < params.max_inner_iteration
    iter = iter + 1;

    % One conjugate gradient step.
    Adir = TheHermitianMatrix(direction, M, MH, Omega, OmegaH, Lambdahat, lagmult);
    curvature = direction' * Adir;
    if curvature == 0
        % Zero search direction (system already solved) or a direction in the
        % operator's null space: take no step instead of producing NaN/Inf.
        alpha = 0;
    else
        alpha = norm(residual,2)^2 / curvature;
    end
    xhat = xhat + alpha*direction;
    prev_residual = residual;
    residual = TheHermitianMatrix(xhat, M, MH, Omega, OmegaH, Lambdahat, lagmult) - b;
    prev_norm2 = norm(prev_residual,2)^2;
    if prev_norm2 == 0
        beta = 0;                                % avoid 0/0 when the residual vanished
    else
        beta = norm(residual,2)^2 / prev_norm2;
    end
    direction = -residual + beta*direction;

    % Checkpoint: CG converged to the lagmult-dependent tolerance, or ran out of iterations.
    if norm(residual,2)/norm_b < params.l2_accuracy*(lagmult^(accuracy_adjustment_exponent)) || iter == params.max_inner_iteration
        if isa(M, 'function_handle')
            temp = norm(y-M(xhat), 2);
        else
            temp = norm(y-M*xhat, 2);
        end
        %disp(['residual=', num2str(norm(residual,2)), ' norm_b=', num2str(norm_b), ' fidelity error=', num2str(temp), ' lagmult=', num2str(lagmult), ' iter=', num2str(iter)]);

        if temp <= params.noise_level
            was_feasible = 1;
            if was_infeasible == 1
                break;
            end
            % Feasible: penalize the analysis term more and restart CG from xhat.
            lagmult = lagmultfactor*lagmult;
            residual = TheHermitianMatrix(xhat, M, MH, Omega, OmegaH, Lambdahat, lagmult) - b;
            direction = -residual;
            iter = 0;
        elseif temp > params.noise_level
            lagmult = lagmult/lagmultfactor;
            if was_feasible == 1
                % Restore the last feasible solution; lagmult now matches it.
                xhat = xprev;
                break;
            end
            was_infeasible = 1;
            residual = TheHermitianMatrix(xhat, M, MH, Omega, OmegaH, Lambdahat, lagmult) - b;
            direction = -residual;
            iter = 0;
        end
        if lagmult > lagmultmax || lagmult < lagmultmin
            break;
        end
        xprev = xhat;
    end
end
%disp(['fidelity_error=', num2str(temp)]);

%%
% Compute analysis representation for xhat
%%
if isa(Omega, 'function_handle')
    omega_xhat = Omega(xhat);
    arepr = omega_xhat(Lambdahat);
else %% here Omega is assumed to be a matrix
    arepr = Omega(Lambdahat, :) * xhat;
end
return;
nikcleju@18 167
nikcleju@18 168
%%
% This function computes (M'*M + lm*Omega(L,:)'*Omega(L,:)) * x.
%%
function w = TheHermitianMatrix(x, M, MH, Omega, OmegaH, L, lm)
% Applies the normal-equation operator
%   A(lm) = M'*M + lm * Omega(L,:)' * Omega(L,:)
% to the vector x without forming A(lm) explicitly.
%
% Inputs:
%   x             : vector the operator is applied to
%   M, MH         : measurement operator and its conjugate transpose; the branch is chosen
%                   from the class of M only, so both are expected to be function handles
%                   or both matrices
%   Omega, OmegaH : analysis operator and its conjugate transpose (both function handles,
%                   or both matrices; the branch is chosen from the class of Omega)
%   L             : indices of the selected rows of Omega
%   lm            : Lagrange multiplier weighting the analysis term
% Output:
%   w             : A(lm) * x
%
% Products are grouped right to left, so only matrix-vector products are computed.
% The d-by-d matrices MH*M and OmegaH(:,L)*Omega(L,:) are never formed.
if isa(M, 'function_handle')
    w = MH(M(x));
else %% M and MH are matrices
    w = MH * (M * x);
end
if isa(Omega, 'function_handle')
    % Zero all analysis coefficients outside L before synthesizing back.
    v = Omega(x);
    vt = zeros(size(v));
    vt(L) = v(L);
    w = w + lm*OmegaH(vt);
else %% Omega is assumed to be a matrix and OmegaH is its conjugate transpose
    w = w + lm*(OmegaH(:, L)*(Omega(L, :)*x));
end