function kmeans_demo()
% KMEANS_DEMO Recover the centers of K=5 one-dimensional clusters.
%
% Draws T points from K 1D Gaussian clusters, estimates the cluster centers
% with k-means, and then fine tunes the estimate with EM. The current center
% estimates are drawn as horizontal lines on top of the data after each stage.
%
% Requires BNT, netlab and the matlab stats toolbox v4.

K = 5;
ndim = 1;
true_centers = 1:K;
sigma = 1e-6;
T = 100;

% Row t of samples holds the t'th data point.
samples = zeros(T, ndim);

% labels(t) = i means the t'th data point is sampled from cluster i.
% Here: 20 consecutive points per cluster, clusters in increasing order.
% (Random alternative: labels = sample_discrete(normalise(ones(1,K)));)
labels = kron(1:K, ones(1, 20));

% Draw one point at a time so that each draw uses its own cluster mean.
for t = 1:T
    cluster_mean = true_centers(labels(t));
    samples(t) = sample_gaussian(cluster_mean, sigma, 1);
end
plot(1:T, samples, 'x')

% Creating the mixture sets the centers randomly from Gauss(0).
mix = gmm(ndim, K, 'spherical');
h = plot_centers_as_lines(mix, [], T);

% Optional stage (disabled): place the initial centers at K data points
% chosen at random, with some noise added.
init_from_noisy_data = false;
if init_from_noisy_data
    picks = randperm(T);
    picks = picks(1:K);
    mix.centres = samples(picks) + sample_gaussian(0, 0.1, K);
    h = plot_centers_as_lines(mix, h, T);
end

% Optional stage (disabled): update the centers using netlab k-means.
run_netlab_kmeans = false;
if run_netlab_kmeans
    options = foptions;
    options(14) = 10; % number of iterations
    mix = gmminit(mix, samples, options);
    h = plot_centers_as_lines(mix, h, T);
end

% Use the matlab stats toolbox k-means with multiple restarts.
nrestarts = 5;
[unused_idx, centers] = kmeans(samples, K, 'replicates', nrestarts, ...
    'emptyAction', 'singleton', 'display', 'iter');
mix.centres = centers;
h = plot_centers_as_lines(mix, h, T);

% Fine tune with EM; this also computes the covariances of each cluster.
em_options = foptions;
em_options(1) = 1;   % display cost fn at each iter
em_options(14) = 20; % number of iterations
mix = gmmem(mix, samples, em_options);
h = plot_centers_as_lines(mix, h, T);

%%%%%%%%%
function h = plot_centers_as_lines(mix, h, T)
% PLOT_CENTERS_AS_LINES Show each mixture center as a horizontal line.
%
% mix - mixture structure; reads mix.ncentres and mix.centres
% h   - handles of previously drawn lines, or [] to create new ones
% T   - right end of the x-range; every line spans x = 0..T
%
% Returns the line handles, one per center, so that a later call can move
% the existing lines instead of drawing new ones.

ncenters = mix.ncentres;
x_span = [0 T];

% Decide once, before the loop: h stops being empty after the first line is
% created, so the test must not be repeated per center.
create_lines = isempty(h);

hold on
for c = 1:ncenters
    level = mix.centres(c);
    if create_lines
        h(c) = line(x_span, [level level]);
    else
        set(h(c), 'xdata', x_span, 'ydata', [level level]);
    end
end
hold off