annotate toolboxes/FullBNT-1.0.7/bnt/examples/static/Misc/mixexp_graddesc.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
wolffd@0 1
wolffd@0 2 %%%%%%%%%%
wolffd@0 3
function [theta, eta] = mixture_of_experts(q, data, num_iter, theta, eta)
% MIXTURE_OF_EXPERTS Fit a piecewise linear regression model using stochastic gradient descent.
% [theta, eta] = mixture_of_experts(q, data, num_iter)
% [theta, eta] = mixture_of_experts(q, data, num_iter, theta0)
% [theta, eta] = mixture_of_experts(q, data, num_iter, theta0, eta0)
%
% Inputs:
% q           = number of pieces (experts)
% data(l,:)   = training case l: all columns but the last are the inputs,
%               the last column is the scalar regression target
% num_iter    = number of full passes over the training cases
% theta, eta  = optional initial parameters, each of size q x dim, where
%               dim = size(data,2) (the inputs plus a constant offset term).
%               Whichever is omitted or empty is initialised to 0.1*rand(q,dim).
%
% Outputs:
% theta(i,:) = regression vector for expert i (theta(i,1) multiplies the offset)
% eta(i,:)   = softmax (gating) params for expert i

[num_cases, dim] = size(data);
data = [ones(num_cases,1) data]; % prepend with offset
mu = 0.5;  % step size
sigma = 1; % standard deviation of the observation noise (sigma^2 is the variance)

% Initialise each parameter matrix independently so that a caller may supply
% theta alone; theta is drawn first to keep the random stream order unchanged.
if nargin < 4 || isempty(theta)
    theta = 0.1*rand(q, dim);
end
if nargin < 5 || isempty(eta)
    eta = 0.1*rand(q, dim);
end

for t = 1:num_iter
    for iter = 1:num_cases
        x = data(iter, 1:dim);     % [1, inputs] as a row vector
        ystar = data(iter, dim+1); % target

        % yhat(i) = E[y | Q=i, x] = prediction of i'th expert
        yhat = theta * x';
        err = ystar - yhat;

        % gate_prior(i) = Pr(Q=i | x). Subtracting the largest activation
        % before exponentiating leaves the softmax unchanged but prevents
        % exp() from overflowing to Inf (which would yield NaN).
        act = eta * x';
        act = act - max(act);
        gate_prior = exp(act);
        gate_prior = gate_prior / sum(gate_prior);

        % log_lik(i) = log Pr(y | Q=i, x) up to an additive constant; the
        % Gaussian normaliser 1/(sqrt(2*pi)*sigma) is the same for every
        % expert and cancels when the posterior is normalised.
        log_lik = -(0.5/sigma^2) * (err .* err);

        % gate_posterior(i) = Pr(Q=i | x, y), normalised in the log domain
        % so that it stays well defined even when every likelihood would
        % underflow to zero.
        log_post = act + log_lik;
        log_post = log_post - max(log_post);
        gate_posterior = exp(log_post);
        gate_posterior = gate_posterior / sum(gate_posterior);

        % Update (stochastic gradient ascent on the log-likelihood; the
        % 1/sigma^2 factor of the expert gradient equals 1 for sigma = 1).
        eta = eta + mu*(gate_posterior - gate_prior)*x;
        theta = theta + mu*(gate_posterior .* err)*x;
    end

    if mod(t,100)==0
        fprintf(1, 'iter %d\n', t);
    end

end
fprintf(1, '\n');
wolffd@0 51