function kmeans_demo()

% Generate T points from K=5 1D clusters, and try to recover the cluster
% centers using k-means.
% Requires BNT, netlab and the matlab stats toolbox v4.

K = 5;
ndim = 1;
true_centers = 1:K;
sigma = 1e-6;
T = 100;
% data(t,:) is the t'th data point
data = zeros(T, ndim);
% ndx(t) = i means the t'th data point is sampled from cluster i
%ndx = sample_discrete(normalise(ones(1,K)));
ndx = [1*ones(1,20) 2*ones(1,20) 3*ones(1,20) 4*ones(1,20) 5*ones(1,20)];
for t=1:T
  data(t) = sample_gaussian(true_centers(ndx(t)), sigma, 1);
end
plot(1:T, data, 'x')

% set the centers randomly from Gauss(0)
mix = gmm(ndim, K, 'spherical');
h = plot_centers_as_lines(mix, [], T);

if 0
  % Place initial centers at K data points chosen at random, but add some noise
  choose_ndx = randperm(T);
  choose_ndx = choose_ndx(1:K);
  init_centers = data(choose_ndx) + sample_gaussian(0, 0.1, K);
  mix.centres = init_centers;
  h = plot_centers_as_lines(mix, h, T);
end

if 0
  % update centers using netlab k-means
  options = foptions;
  niter = 10;
  options(14) = niter;
  mix = gmminit(mix, data, options);
  h = plot_centers_as_lines(mix, h, T);
end

% use matlab stats toolbox k-means with multiple restarts
nrestarts = 5;
[idx, centers] = kmeans(data, K, 'replicates', nrestarts, ...
                        'emptyAction', 'singleton', 'display', 'iter');
mix.centres = centers;
h = plot_centers_as_lines(mix, h, T);

% fine tune with EM; compute covariances of each cluster
options = foptions;
niter = 20;
options(1) = 1; % display cost fn at each iter
options(14) = niter;
mix = gmmem(mix, data, options);
h = plot_centers_as_lines(mix, h, T);

%%%%%%%%%
function h = plot_centers_as_lines(mix, h, T)

% Draw one horizontal line per cluster center across x = [0 T].
% If h is empty, create the lines; otherwise update the existing handles.
K = mix.ncentres;
hold on
if isempty(h)
  for k=1:K
    h(k)=line([0 T], [mix.centres(k) mix.centres(k)]);
  end
else
  for k=1:K
    set(h(k), 'xdata', [0 T], 'ydata', [mix.centres(k) mix.centres(k)]);
  end
end
hold off
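
%%%%%%%%%
function [idx, centers] = lloyd_kmeans_sketch(data, centers, niter)

% Illustrative sketch only; not part of the original demo and not called by it.
% The demo above delegates clustering to the stats toolbox kmeans (or to
% netlab's gmminit). This helper shows, under stated assumptions, what a
% plain Lloyd-style k-means iteration looks like. It is NOT claimed to be
% the algorithm those libraries use internally.
% Assumptions: the function name and arguments are hypothetical;
% data is T x ndim, centers is K x ndim (initial guesses), niter >= 1.
% Possible usage: [idx, c] = lloyd_kmeans_sketch(data, data(1:20:T,:), 10);

T = size(data, 1);
K = size(centers, 1);
idx = zeros(T, 1);
for iter=1:niter
  % assignment step: squared distance from every point to every center
  d2 = zeros(T, K);
  for k=1:K
    delta = data - repmat(centers(k,:), T, 1);
    d2(:,k) = sum(delta.^2, 2);
  end
  [junk, idx] = min(d2, [], 2);
  % update step: move each center to the mean of the points assigned to it
  for k=1:K
    members = (idx == k);
    if any(members)
      centers(k,:) = mean(data(members,:), 1);
    end
    % an empty cluster simply keeps its previous center in this sketch
  end
end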