Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlab3.3/gtminit.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function net = gtminit(net, options, data, samp_type, varargin)
%GTMINIT Initialise the weights and latent sample in a GTM.
%
%   Description
%   NET = GTMINIT(NET, OPTIONS, DATA, SAMPTYPE) takes a GTM NET and
%   generates a sample of latent data points and sets the centres (and
%   widths if appropriate) of NET.RBFNET.  SAMPTYPE is compared case-
%   sensitively and must be one of 'regular', 'uniform' or 'gaussian'.
%
%   NET = GTMINIT(NET, OPTIONS, DATA, 'regular', LSAMPSIZE, RBFSAMPSIZE)
%   creates regular grids of latent data points and RBF centres.  The
%   dimension of the latent data space must be 1 or 2.  For a one-
%   dimensional latent space, the LSAMPSIZE parameter gives the number of
%   latent points and the RBFSAMPSIZE parameter gives the number of RBF
%   centres; both grids span [-1, 1].  For a two-dimensional latent
%   space, these parameters must be integer row vectors of length 2 with
%   the number of points in each of the x and y directions, creating a
%   rectangular grid over [-1, 1] x [-1, 1].  PROD(LSAMPSIZE) must equal
%   NET.GMMNET.NCENTRES and PROD(RBFSAMPSIZE) must equal
%   NET.RBFNET.NHIDDEN.  The widths of the RBF basis functions are set by
%   a call to RBFSETFW passing OPTIONS(7) as the scaling parameter.
%
%   If the SAMPTYPE is 'uniform' or 'gaussian' then the latent data is
%   found by sampling from a uniform or Gaussian distribution
%   correspondingly.  The RBF basis function parameters are set by a call
%   to RBFSETBF with the latent sample NET.X as dataset and the OPTIONS
%   vector.
%
%   Finally, the output layer weights of the RBF are initialised by
%   mapping the mean of the latent variable to the mean of the target
%   variable, and the L-dimensional latent variable variance to the
%   variance of the targets along the first L principal components.  The
%   GMM centres are set to the images of the latent points, and the
%   common GMM variance is initialised from the spacing of those centres
%   (capped by the (L+1)-th principal eigenvalue when it exists).
%
%   See also
%   GTM, GTMEM, PCA, RBFSETBF, RBFSETFW
%

%   Copyright (c) Ian T Nabney (1996-2001)

% Check for consistency
errstring = consist(net, 'gtm', data);
if ~isempty(errstring)
  error(errstring);
end

% Check type of sample.  STRCMP against a cell array returns one logical
% per entry, so reduce it explicitly with ANY.
stypes = {'regular', 'uniform', 'gaussian'};
if ~any(strcmp(samp_type, stypes))
  error('Undefined sample type.')
end

if net.dim_latent > size(data, 2)
  error('Latent space dimension must not be greater than data dimension')
end
nlatent = net.gmmnet.ncentres;
nhidden = net.rbfnet.nhidden;

% Create latent data sample and set RBF centres
switch samp_type
  case 'regular'
    if nargin ~= 6
      error('Regular type must specify latent and RBF shapes');
    end
    l_samp_size = varargin{1};
    rbf_samp_size = varargin{2};
    % IF on a vector condition is true only when ALL elements are
    % non-zero, so ANY is needed to reject a spec with a single
    % non-integer entry (e.g. [2.5 4]).
    if any(round(l_samp_size(:)) ~= l_samp_size(:))
      error('Latent sample specification must contain integers')
    end
    % Check existence, size and integrality of rbf specification.  A
    % non-integer spec whose product happens to equal NHIDDEN would
    % otherwise build a grid with the wrong number of centres.
    if any(size(rbf_samp_size) ~= [1 net.dim_latent]) | ...
        prod(rbf_samp_size) ~= nhidden | ...
        any(round(rbf_samp_size(:)) ~= rbf_samp_size(:))
      error('Incorrect specification of RBF centres')
    end
    % Check dimension and type of latent data specification
    if any(size(l_samp_size) ~= [1 net.dim_latent]) | ...
        prod(l_samp_size) ~= nlatent
      error('Incorrect dimension of latent sample spec.')
    end
    if net.dim_latent == 1
      % LINSPACE yields exactly the requested number of points with both
      % end points at -1 and 1.
      net.X = linspace(-1, 1, l_samp_size)';
      net.rbfnet.c = linspace(-1, 1, rbf_samp_size)';
    elseif net.dim_latent == 2
      net.X = gtm_rctg(l_samp_size);
      net.rbfnet.c = gtm_rctg(rbf_samp_size);
    else
      error('For regular sample, input dimension must be 1 or 2.')
    end
    % Widths are set from the regular centre grid in both cases
    net.rbfnet = rbfsetfw(net.rbfnet, options(7));

  case {'uniform', 'gaussian'}
    if strcmp(samp_type, 'uniform')
      % Uniform on [-1, 1] in each latent dimension
      net.X = 2 * (rand(nlatent, net.dim_latent) - 0.5);
    else
      % Sample from N(0, 0.25) distribution to ensure most latent
      % data is inside square
      net.X = randn(nlatent, net.dim_latent)/2;
    end
    net.rbfnet = rbfsetbf(net.rbfnet, options, net.X);

  otherwise
    % Shouldn't get here: sample type was validated above
    error('Invalid sample type');
end

% Latent data sample and basis function parameters chosen.
% Now set output weights
[PCcoeff, PCvec] = pca(data);

% Scale the leading DIM_LATENT principal directions by the square roots
% of their eigenvalues
A = PCvec(:, 1:net.dim_latent)*diag(sqrt(PCcoeff(1:net.dim_latent)));

% Second output of RBFFWD is the design matrix used below
% (presumably the hidden-unit activations; confirm against RBFFWD)
[temp, Phi] = rbffwd(net.rbfnet, net.X);
% Normalise X to zero mean and unit standard deviation per column to
% ensure a 1:1 mapping of variances, and calculate weights as the
% least-squares solution of Phi*W = normX*A'
normX = (net.X - ones(size(net.X))*diag(mean(net.X)))*diag(1./std(net.X));
net.rbfnet.w2 = Phi \ (normX*A');
% Bias is mean of target data
net.rbfnet.b2 = mean(data);

% Must also set initial value of variance
% Find average distance between nearest centres
% Ensure that distance of centre to itself is excluded by setting diagonal
% entries to realmax
net.gmmnet.centres = rbffwd(net.rbfnet, net.X);
d = dist2(net.gmmnet.centres, net.gmmnet.centres) + ...
  diag(ones(net.gmmnet.ncentres, 1)*realmax);
sigma = mean(min(d))/2;

% Now set covariance to minimum of this and next largest eigenvalue
if net.dim_latent < size(data, 2)
  sigma = min(sigma, PCcoeff(net.dim_latent+1));
end
net.gmmnet.covars = sigma*ones(1, net.gmmnet.ncentres);
% Sub-function to create the sample data in 2d
function sample = gtm_rctg(samp_size)
%GTM_RCTG Regular rectangular grid of 2-D points scaled onto [-1, 1].
%   SAMPLE = GTM_RCTG(SAMP_SIZE) returns a PROD(SAMP_SIZE(1:2))-by-2
%   matrix.  SAMP_SIZE(1) is the number of distinct x positions and
%   SAMP_SIZE(2) the number of distinct y positions.  Points are listed
%   with x varying slowest; within each x position, y runs from +1 down
%   to -1.  Each coordinate range [0, max] is mapped linearly onto
%   [-1, 1], so a size of 1 in either direction gives a zero divisor.

ncols = samp_size(1);
nrows = samp_size(2);

% Integer grid coordinates, in the same order MESHGRID followed by (:)
% would produce: each x value repeated NROWS times, y values descending
% and repeated once per x value
xcoord = kron((0:ncols-1)', ones(nrows, 1));
ycoord = repmat((nrows-1:-1:0)', ncols, 1);
sample = [xcoord, ycoord];

% Centre each column on zero and scale it to span [-1, 1]
upper = max(sample);
for k = 1:2
  sample(:, k) = 2*(sample(:, k) - upper(k)/2) ./ upper(k);
end