danstowell@0
|
1
|
danstowell@0
|
2 # Spectrogram auto-encoder
|
danstowell@0
|
3 # Dan Stowell 2016.
|
danstowell@0
|
4 #
|
danstowell@0
|
5 # Unusual things about this implementation:
|
danstowell@0
|
6 # * Data is not pre-whitened, instead we use a custom layer (NormalisationLayer) to normalise the mean-and-variance of the data for us. This is because I want the spectrogram to be normalised when it is input but not normalised when it is output.
|
danstowell@0
|
7 # * It's a convolutional net but only along the time axis; along the frequency axis it's fully-connected.
|
danstowell@0
|
8
|
danstowell@0
|
9 import numpy as np
|
danstowell@0
|
10
|
danstowell@0
|
11 import theano
|
danstowell@0
|
12 import theano.tensor as T
|
danstowell@0
|
13 import lasagne
|
danstowell@0
|
14 #import downhill
|
danstowell@0
|
15 from lasagne.nonlinearities import rectify, leaky_rectify, very_leaky_rectify
|
danstowell@0
|
16 from numpy import float32
|
danstowell@0
|
17
|
danstowell@1
|
18 try:
|
danstowell@1
|
19 from lasagne.layers import InverseLayer as _
|
danstowell@1
|
20 use_maxpool = True
|
danstowell@1
|
21 except ImportError:
|
danstowell@1
|
22 print("""**********************
|
danstowell@1
|
23 WARNING: InverseLayer not found in Lasagne. Please use a more recent version of Lasagne.
|
danstowell@1
|
24 WARNING: We'll deactivate the maxpooling part of the network (since we can't use InverseLayer to undo it)""")
|
danstowell@1
|
25 use_maxpool = False
|
danstowell@1
|
26
|
danstowell@0
|
27 import matplotlib
|
danstowell@0
|
28 #matplotlib.use('PDF') # http://www.astrobetter.com/plotting-to-a-file-in-python/
|
danstowell@0
|
29 import matplotlib.pyplot as plt
|
danstowell@0
|
30 import matplotlib.cm as cm
|
danstowell@0
|
31 from matplotlib.backends.backend_pdf import PdfPages
|
danstowell@0
|
32 plt.rcParams.update({'font.size': 6})
|
danstowell@0
|
33
|
danstowell@0
|
34 from userconfig import *
|
danstowell@0
|
35 import util
|
danstowell@0
|
36 from layers_custom import *
|
danstowell@0
|
37
|
danstowell@0
|
38 ###################################################################################################################
|
danstowell@0
|
# Symbolic 4D tensor holding one input minibatch.
# In general, the main data tensors will have these axes, in this order:
#   0: minibatchsize
#   1: numchannels (always 1 for us, since spectrograms)
#   2: numfilts (or specbinnum for input)
#   3: numtimebins
input_var = T.tensor4('X')

if example_is_audio:
    # Load the example audio file and convert it into a spectrogram.
    examplegram = util.standard_specgram(
        util.load_soundfile(examplewavpath, 0)
    )
    print("examplegram is of shape %s" % str(np.shape(examplegram)))
|
danstowell@0
|
51
|
danstowell@0
|
52 ###################################################################################################################
|
danstowell@0
|
53 # here we define our "semi-convolutional" autoencoder
|
danstowell@0
|
54 # NOTE: lasagne assumes pooling is on the TRAILING axis of the tensor, so we always use time as the trailing axis
|
danstowell@0
|
55
|
danstowell@0
|
def make_custom_convlayer(network, in_num_chans, out_num_chans):
    """Apply our special padding and reshaping to do 1D convolution on 2D data.

    The incoming layer is padded "manually" (so that only one dimension is
    padded), convolved with filters of size (in_num_chans, featframe_len),
    and the result is reshaped so that the convolution's channels are
    reinterpreted as rows.

    Args:
        network: the incoming Lasagne layer.
        in_num_chans: height of each convolution filter. Assumed to equal
            the number of rows of the incoming layer, so that the filter
            cannot slide along that axis -- TODO confirm for new callers.
        out_num_chans: number of convolution filters.

    Returns:
        A tuple (network, filters): the final (reshaped) layer, and the
        weight variable W of the convolution layer.
    """
    # Integer division is required here: under Python 3 a plain "/" would
    # produce a float, which is not a valid padding width.
    pad_width = (featframe_len - 1) // 2
    # NOTE: the "batch_ndim" is used to stop batch dims being padded, but here ALSO to skip first data dim
    network = lasagne.layers.PadLayer(network, width=pad_width, batch_ndim=3)
    print("shape after pad layer: %s" % str(network.output_shape))
    # we pad "manually" (above) in order to do it in one dimension only, hence pad=0
    network = lasagne.layers.Conv2DLayer(
        network,
        out_num_chans,
        (in_num_chans, featframe_len),
        stride=(1, 1),
        pad=0,
        nonlinearity=very_leaky_rectify,
        W=lasagne.init.Orthogonal(),
    )
    filters = network.W
    network = lasagne.layers.ReshapeLayer(network, ([0], [2], [1], [3]))  # reinterpret channels as rows
    print("shape after conv layer: %s" % str(network.output_shape))
    return network, filters
|
danstowell@0
|
65
|
danstowell@0
|
# Input layer: (minibatch, 1 channel, specbinnum, numtimebins).
network = lasagne.layers.InputLayer((None, 1, specbinnum, numtimebins), input_var)
print("shape after input layer: %s" % str(network.output_shape))

# Normalisation layer.
#  -- we deliberately normalise the input but do not undo that at the output.
#  -- the normalisation params are not set by the training procedure;
#     they need to be set before training begins.
# We keep a handle on this layer so that its parameters can be set later.
normlayer = NormalisationLayer(network, specbinnum)
network = normlayer

# Encoder convolution.
network, filters_enc = make_custom_convlayer(
    network, in_num_chans=specbinnum, out_num_chans=numfilters
)

# NOTE: here we're using max-pooling, along the time axis only, and then
# using Lasagne's "InverseLayer" to undo the maxpooling in one-hot fashion.
# There's a side-effect of this: if you use *overlapping* maxpooling windows,
# the InverseLayer may behave slightly unexpectedly, adding some points with
# double magnitude. It's OK here since we're not overlapping the windows.
#
# NOTE: "latents" marks the "middle" of the autoencoder! We remember it since
# we'll want to inspect the latents, and maybe regularise them too.
if use_maxpool:
    maxpool_layer = lasagne.layers.MaxPool2DLayer(network, pool_size=(1, 2), stride=(1, 2))
    latents = maxpool_layer
    network = lasagne.layers.InverseLayer(latents, maxpool_layer)
else:
    latents = network

# Decoder convolution.
network, filters_dec = make_custom_convlayer(
    network, in_num_chans=numfilters, out_num_chans=specbinnum
)

# Finally a standard rectify, since the target (a specgram) is nonnegative.
network = lasagne.layers.NonlinearityLayer(network, nonlinearity=rectify)

###################################################################################################################
# Simple L2 (squared error) reconstruction loss, plus a mild touch of L2 weight regularisation.
prediction = lasagne.layers.get_output(network)
loss = (
    lasagne.objectives.squared_error(prediction, input_var).mean()
    + 1e-4 * lasagne.regularization.regularize_network_params(
        network, lasagne.regularization.l2
    )
)
|
danstowell@0
|
101
|
danstowell@0
|
102 ###################################################################################################################
|
danstowell@0
|
103
|
danstowell@0
|
plot_probedata_data = None   # the fixed probe input; created on the first call to plot_probedata()
_plot_probedata_fn = None    # compiled Theano function; created on the first call to plot_probedata()


def _imshow_symmetric(plotdata, vlim=None):
    """Draw `plotdata` as an image whose colour scale is symmetric about zero.

    If `vlim` is None, the scale limit is the largest absolute value in `plotdata`.
    """
    if vlim is None:
        vlim = np.max(np.abs(plotdata))
    plt.imshow(plotdata, origin='lower', interpolation='nearest', cmap='RdBu', aspect='auto', vmin=-vlim, vmax=vlim)


def _make_probedata():
    """Build the probe input: an excerpt of the example spectrogram, or else a few random unit impulses."""
    if example_is_audio:
        return np.array([[examplegram[:, examplegram_startindex:examplegram_startindex+numtimebins]]], float32)
    probedata = np.zeros((minibatchsize, 1, specbinnum, numtimebins), dtype=float32)
    for _ in range(5):
        probedata[:, :, np.random.randint(specbinnum), np.random.randint(numtimebins)] = 1
    return probedata


def plot_probedata(outpostfix, plottitle=None, verbose=False):
    """Visualises the network behaviour.

    Writes 'pdf/autoenc_probe_<outpostfix>.pdf' containing one page showing
    the probe input, the latents and the network output, followed by one
    page of encoding filters and one page of decoding filters.

    NOTE: currently accesses globals. Should really be passed in the network, filters etc

    Args:
        outpostfix: suffix used in the output PDF filename.
        plottitle: title for the first page; defaults to `outpostfix`.
        verbose: if True, also print shapes and mean absolute values of the
            probe data, latents and prediction.
    """
    global plot_probedata_data, _plot_probedata_fn

    if plottitle is None:
        plottitle = outpostfix

    # The probe data is created once and then reused, so that plots made
    # at different stages of training are directly comparable.
    if plot_probedata_data is None:
        plot_probedata_data = _make_probedata()

    # Compile once only. The network weights are read when the function is
    # called, so the cached function still reflects the latest training state.
    if _plot_probedata_fn is None:
        test_prediction = lasagne.layers.get_output(network, deterministic=True)
        test_latents = lasagne.layers.get_output(latents, deterministic=True)
        _plot_probedata_fn = theano.function([input_var], [test_prediction, test_latents])
    prediction, latentsval = _plot_probedata_fn(plot_probedata_data)

    if verbose:
        print("Probedata has shape %s and meanabs %g" % ( plot_probedata_data.shape, np.mean(np.abs(plot_probedata_data ))))
        print("Latents has shape %s and meanabs %g" % (latentsval.shape, np.mean(np.abs(latentsval))))
        print("Prediction has shape %s and meanabs %g" % (prediction.shape, np.mean(np.abs(prediction))))
        print("Ratio %g" % (np.mean(np.abs(prediction)) / np.mean(np.abs(plot_probedata_data))))

    util.mkdir_p('pdf')
    pdf = PdfPages('pdf/autoenc_probe_%s.pdf' % outpostfix)
    try:
        # Page 1: input, latents and output for the first item of the probe minibatch.
        plt.figure(frameon=False)
        panels = [
            (plot_probedata_data[0, 0, :, :], 'Input'),
            (latentsval[0, 0, :, :], 'Latents'),
            (prediction[0, 0, :, :], 'Output'),
        ]
        for panelindex, (plotdata, panellbl) in enumerate(panels):
            plt.subplot(3, 1, panelindex + 1)
            _imshow_symmetric(plotdata)
            plt.ylabel(panellbl)
            if panelindex == 0:
                plt.title("%s" % (plottitle))
        pdf.savefig()
        plt.close()

        # Following pages: the filters. The grid is 8 columns wide, with at
        # least 3 rows, and grows if there are more than 24 filters.
        numcols = 8
        numrows = max(3, -(-numfilters // numcols))
        for filtvar, filtlbl, isenc in [
            (filters_enc, 'encoding', True),
            (filters_dec, 'decoding', False),
        ]:
            plt.figure(frameon=False)
            vals = filtvar.get_value()
            vlim = np.max(np.abs(vals))  # one shared colour scale per page
            for whichfilt in range(numfilters):
                plt.subplot(numrows, numcols, whichfilt+1)
                # NOTE: for encoding/decoding filters, we grab the "slice" of interest from the tensor in different ways: different axes, and flipped.
                if isenc:
                    plotdata = vals[numfilters-(1+whichfilt),0,::-1,::-1]
                else:
                    plotdata = vals[:,0,whichfilt,:]

                _imshow_symmetric(plotdata, vlim)
                plt.xticks([])
                if whichfilt==0:
                    plt.title("%s filters (%s)" % (filtlbl, outpostfix))
                else:
                    plt.yticks([])

            pdf.savefig()
            plt.close()
    finally:
        # Always finalise the PDF file, even if plotting raised an exception.
        pdf.close()
|
danstowell@0
|
183
|
danstowell@0
|
plot_probedata('init')

###################################################################################################################
###################################
# here we set up some training data. this is ALL A BIT SIMPLE - for a proper experiment we'd prepare a full dataset, and it might be too big to be all in memory.
training_data_size = 100
training_data = np.zeros((training_data_size, minibatchsize, 1, specbinnum, numtimebins), dtype=float32)
if example_is_audio:
    # manually grab a load of random subsets of the training audio.
    # randint() excludes its upper bound, so the "+ 1" is needed to make the
    # last full-length window selectable (and to allow a spectrogram that
    # is exactly numtimebins long).
    num_startindices = examplegram.shape[1] - numtimebins + 1
    for which_training_batch in range(training_data_size):
        for which_training_datum in range(minibatchsize):
            startindex = np.random.randint(num_startindices)
            training_data[which_training_batch, which_training_datum, :, :, :] = examplegram[:, startindex:startindex+numtimebins]
else:
    # make some simple (sparse) data that we can train with
    for which_training_batch in range(training_data_size):
        for which_training_datum in range(minibatchsize):
            for _ in range(5):
                training_data[which_training_batch, which_training_datum, :, np.random.randint(specbinnum), np.random.randint(numtimebins)] = 1

###################################
# pre-training setup

# set the normalisation parameters manually as an estimate from the training data
normlayer.set_normalisation(training_data)

###################################
# training

# compile training function that updates parameters and returns training loss
params = lasagne.layers.get_all_params(network, trainable=True)
updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=0.01, momentum=0.9)
train_fn = theano.function([input_var], loss, updates=updates)

# train network
numepochs = 2048  # other values previously tried: 3, 100, 5000
lossreadout = float('nan')  # stays NaN only if numepochs is zero
for epoch in range(numepochs):
    # Accumulate under a separate name, so that the symbolic "loss"
    # expression defined earlier is not overwritten by a number.
    epoch_loss = 0
    for input_batch in training_data:
        epoch_loss += train_fn(input_batch)
    # Report on the first epoch, the last epoch, and every power-of-two epoch.
    # (epoch & (epoch - 1)) == 0 holds exactly for 0 and for powers of two.
    if epoch == numepochs - 1 or (epoch & (epoch - 1)) == 0:
        lossreadout = epoch_loss / len(training_data)
        infostring = "Epoch %d/%d: Loss %g" % (epoch, numepochs, lossreadout)
        print(infostring)
        plot_probedata('progress', plottitle="progress (%s)" % infostring)

plot_probedata('trained', plottitle="trained (%d epochs; Loss %g)" % (numepochs, lossreadout))
|
danstowell@0
|
233
|