comparison toolboxes/MIRtoolbox1.3.2/somtoolbox/som_prototrain.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [sM,sTrain] = som_prototrain(sM, D)

%SOM_PROTOTRAIN  Use sequential algorithm to train the Self-Organizing Map.
%
% [sM,sT] = som_prototrain(sM, D)
%
%  sM = som_prototrain(sM,D);
%
%  Input and output arguments:
%   sM      (struct) map struct, the trained and updated map is returned
%           (matrix) codebook matrix of a self-organizing map
%                    size munits x dim or  msize(1) x ... x msize(k) x dim
%                    The trained map codebook is returned.
%   D       (struct) training data; data struct
%           (matrix) training data, size dlen x dim
%   sT      (struct) 'som_train' struct (built with SOM_SET) recording the
%                    parameters used: algorithm 'proto', data_name, neigh,
%                    mask, radius_ini, radius_fin, alpha_ini, alpha_type,
%                    trainlen and time. In struct mode it is also appended
%                    to sM.trainhist.
%
% This function is otherwise just like SOM_SEQTRAIN except that
% the implementation of the sequential training algorithm is very
% straightforward (and slower). This should make it easy for you
% to modify the algorithm, if you want to.
%
% The training parameters are not arguments; they are hard-coded in the
% "initialize" section below and may be edited there:
%   trainlen   20*dlen single-sample steps (20 epochs)
%   radius     'linear' decrease from max(msize)/2 to 1
%   alpha      'inv' schedule starting from 0.2
%   neigh      taken from the map struct ('gaussian' for a matrix map)
%   mask       taken from the map struct (all ones for a matrix map)
% Data vectors whose components are all NaN are discarded; NaN
% components of the remaining vectors are ignored when finding the BMU
% and when updating the codebook.
%
% For help on input and output parameters, try
% 'type som_prototrain' or check out the help for SOM_SEQTRAIN.
% See also  SOM_SEQTRAIN, SOM_BATCHTRAIN.

% Contributed to SOM Toolbox vs2, February 2nd, 2000 by Juha Vesanto
% http://www.cis.hut.fi/projects/somtoolbox/

% Version 2.0beta juuso 080200 130300

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Check input arguments

% map
struct_mode = isstruct(sM);
if struct_mode,
  M      = sM.codebook;
  sTopol = sM.topol;
  mask   = sM.mask;
  msize  = sTopol.msize;
  neigh  = sM.neigh;
else
  M = sM; orig_size = size(M);
  if ndims(sM) > 2,
    % codebook given as msize(1) x ... x msize(k) x dim array
    si = size(sM); dim = si(end); msize = si(1:end-1);
    M = reshape(sM,[prod(msize) dim]);
  else
    % codebook given as munits x dim matrix: treat it as a munits x 1 map
    msize = [orig_size(1) 1]; dim = orig_size(2);
  end
  % the map struct is only needed here for its topology (unit distances)
  sM = som_map_struct(dim,'msize',msize); sTopol = sM.topol;
  mask = ones(dim,1);
  neigh = 'gaussian';
end
[munits dim] = size(M);

% data
if isstruct(D), data_name = D.name; D = D.data;
else data_name = inputname(2);
end
% check input dimension before using dim to filter the data rows
if size(D,2) ~= dim, error('Map and data input space dimensions disagree.'); end
D = D(sum(isnan(D),2) < dim,:); % remove empty (all-NaN) vectors from the data
dlen = size(D,1);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% initialize (these are default values, change as you will)

% training length
trainlen = 20*dlen; % 20 epochs by default

% neighborhood radius
radius_type = 'linear';
rini = max(msize)/2;
rfin = 1;

% learning rate
alpha_type = 'inv';
alpha_ini = 0.2;

% initialize random number generator
% rand('state',...) puts MATLAB's global generator into legacy mode for
% the rest of the session, so use rng where it exists.
try
  rng('shuffle');                 % MATLAB R2011a+, recent Octave
catch
  rand('state',sum(100*clock));   % older MATLAB / Octave
end

% tracking
start = clock; trackstep = 100;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Action

Ud = som_unit_dists(sTopol); % distance between map units on the grid
mu_x_1 = ones(munits,1);     % this is used pretty often
tspan = max(trainlen-1,1);   % schedule denominator; avoids 0/0 if trainlen == 1

for t = 1:trainlen,

  %% find BMU
  ind = ceil(dlen*rand(1)+eps);       % select one vector
  x = D(ind,:);                       % pick it up
  known = ~isnan(x);                  % its known components
  Dx = M(:,known) - x(mu_x_1,known);  % each map unit minus the vector
  dist2 = (Dx.^2)*mask(known);        % squared (masked) distances
  [qerr bmu] = min(dist2);            % find BMU

  %% neighborhood
  switch radius_type, % radius
   case 'linear', r = rini+(rfin-rini)*(t-1)/tspan;
   otherwise, error(['Unknown radius type: ' radius_type]);
  end
  if ~r, r=eps; end % zero neighborhood radius may cause div-by-zero error
  ud = Ud(:,bmu);   % grid distances from the BMU to every unit
  switch neigh, % neighborhood function
   case 'bubble',   h = (ud <= r);
   case 'gaussian', h = exp(-(ud.^2)/(2*r*r));
   case 'cutgauss', h = exp(-(ud.^2)/(2*r*r)) .* (ud <= r);
   case 'ep',       h = (1 - (ud.^2)/(r*r)) .* (ud <= r);
   otherwise, error(['Unknown neighborhood function: ' neigh]);
  end

  %% learning rate
  switch alpha_type,
   case 'linear', a = (1-t/trainlen)*alpha_ini;
   case 'inv',    a = alpha_ini / (1 + 99*(t-1)/tspan);
   case 'power',  a = alpha_ini * (0.005/alpha_ini)^((t-1)/trainlen);
   otherwise, error(['Unknown learning rate type: ' alpha_type]);
  end

  %% update
  % M <- M + a*h*(x - M) on the known components only; h is replicated
  % across those columns to match the size of Dx
  M(:,known) = M(:,known) - a*h(:,ones(sum(known),1)).*Dx;

  %% tracking
  if t==1 || ~rem(t,trackstep),
    elap_t = etime(clock,start); tot_t = elap_t*trainlen/t;
    fprintf(1,'\rTraining: %3.0f/ %3.0f s',elap_t,tot_t)
  end

end; % for t = 1:trainlen
fprintf(1,'\n');

% outputs
sTrain = som_set('som_train','algorithm','proto',...
		 'data_name',data_name,...
		 'neigh',neigh,...
		 'mask',mask,...
		 'radius_ini',rini,...
		 'radius_fin',rfin,...
		 'alpha_ini',alpha_ini,...
		 'alpha_type',alpha_type,...
		 'trainlen',trainlen,...
		 'time',datestr(now,0));

if struct_mode,
  sM = som_set(sM,'codebook',M,'mask',mask,'neigh',neigh);
  sM.trainhist(end+1) = sTrain;
else
  sM = reshape(M,orig_size);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%