%%%%%%%%%%

function [theta, eta] = mixture_of_experts(q, data, num_iter, theta, eta, mu, sigma)
% MIXTURE_OF_EXPERTS Fit a piecewise linear regression model using stochastic gradient descent.
% [theta, eta] = mixture_of_experts(q, data, num_iter)
% [theta, eta] = mixture_of_experts(q, data, num_iter, theta, eta)
% [theta, eta] = mixture_of_experts(q, data, num_iter, theta, eta, mu, sigma)
%
% Inputs:
% q = number of pieces (experts)
% data(l,:) = example l; columns 1:end-1 are the inputs, the last column is the target
% num_iter = number of passes over all rows of data
% theta, eta = optional initial parameters, each q x size(data,2); an omitted or
%   empty argument is initialized to 0.1*rand(q, size(data,2))
% mu = optional step size (default 0.5)
% sigma = optional standard deviation of the Gaussian noise on the target (default 1)
%
% Outputs:
% theta(i,:) = regression vector for expert i (first entry multiplies the constant offset)
% eta(i,:) = softmax (gating) params for expert i (first entry multiplies the constant offset)
%
% Progress is printed to stdout every 100 passes.

if nargin < 6 || isempty(mu)
  mu = 0.5; % step size
end
if nargin < 7 || isempty(sigma)
  sigma = 1; % standard deviation of the noise
end

[num_cases, dim] = size(data);
% Prepend a column of ones so the first weight of every expert/gate is an offset.
% Afterwards columns 1:dim are the (offset-augmented) inputs and column dim+1 is the target.
data = [ones(num_cases,1) data];

% Initialize theta and eta independently, so that supplying only theta works.
if nargin < 4 || isempty(theta)
  theta = 0.1*rand(q, dim);
end
if nargin < 5 || isempty(eta)
  eta = 0.1*rand(q, dim);
end

for t=1:num_iter
  for iter=1:num_cases
    x = data(iter, 1:dim);
    ystar = data(iter, dim+1); % target
    % yhat(i) = E[y | Q=i, x] = prediction of i'th expert
    yhat = theta * x';
    resid = ystar - yhat;
    % gate_prior(i) = Pr(Q=i | x): softmax of the gate activations, shifted by
    % their maximum so exp() cannot overflow (the shift cancels in the ratio).
    a = eta * x';
    a = a - max(a);
    e = exp(a);
    gate_prior = e / sum(e);
    log_prior = a - log(sum(e));
    % log Pr(y | Q=i, x) up to an additive constant: the Gaussian normalizer
    % 1/(sqrt(2*pi)*sigma) is identical for every expert and cancels below.
    log_lik = -(0.5/sigma^2) * (resid .* resid);
    % gate_posterior(i) = Pr(Q=i | x, y). Normalizing in the log domain keeps a
    % large residual (which would underflow exp(log_lik) to 0 for every expert)
    % from producing 0/0 = NaN.
    log_post = log_prior + log_lik;
    gate_posterior = exp(log_post - max(log_post));
    gate_posterior = gate_posterior / sum(gate_posterior);
    % Update: stochastic gradient ascent on log Pr(ystar | x).
    eta = eta + mu*(gate_posterior - gate_prior)*x;
    theta = theta + (mu/sigma^2)*(gate_posterior .* resid)*x;
  end

  if mod(t,100)==0
    fprintf(1, 'iter %d\n', t);
  end

end
fprintf(1, '\n');