annotate autoencoder-specgram.py @ 1:04f1e3463466 tip master

Implement maxpooling and unpooling aspect
author Dan Stowell <danstowell@users.sourceforge.net>
date Wed, 13 Jan 2016 09:56:16 +0000
parents 73317239d6d1
children
rev   line source
danstowell@0 1
danstowell@0 2 # Spectrogram auto-encoder
danstowell@0 3 # Dan Stowell 2016.
danstowell@0 4 #
danstowell@0 5 # Unusual things about this implementation:
danstowell@0 6 # * Data is not pre-whitened, instead we use a custom layer (NormalisationLayer) to normalise the mean-and-variance of the data for us. This is because I want the spectrogram to be normalised when it is input but not normalised when it is output.
danstowell@0 7 # * It's a convolutional net but only along the time axis; along the frequency axis it's fully-connected.
danstowell@0 8
danstowell@0 9 import numpy as np
danstowell@0 10
danstowell@0 11 import theano
danstowell@0 12 import theano.tensor as T
danstowell@0 13 import lasagne
danstowell@0 14 #import downhill
danstowell@0 15 from lasagne.nonlinearities import rectify, leaky_rectify, very_leaky_rectify
danstowell@0 16 from numpy import float32
danstowell@0 17
# Probe whether this Lasagne version provides InverseLayer, which is needed to
# "unpool" the max-pooling stage. If it is missing, max-pooling is switched off.
try:
    from lasagne.layers import InverseLayer as _  # imported only to test availability
except ImportError:
    print("""**********************
WARNING: InverseLayer not found in Lasagne. Please use a more recent version of Lasagne.
WARNING: We'll deactivate the maxpooling part of the network (since we can't use InverseLayer to undo it)""")
    use_maxpool = False
else:
    use_maxpool = True
danstowell@0 27 import matplotlib
danstowell@0 28 #matplotlib.use('PDF') # http://www.astrobetter.com/plotting-to-a-file-in-python/
danstowell@0 29 import matplotlib.pyplot as plt
danstowell@0 30 import matplotlib.cm as cm
danstowell@0 31 from matplotlib.backends.backend_pdf import PdfPages
danstowell@0 32 plt.rcParams.update({'font.size': 6})
danstowell@0 33
danstowell@0 34 from userconfig import *
danstowell@0 35 import util
danstowell@0 36 from layers_custom import *
danstowell@0 37
###################################################################################################################
# Symbolic input for one minibatch. Data tensors in this script are laid out as:
#   (minibatchsize, numchannels, numfilts, numtimebins)
# numchannels is always 1 (spectrograms); the third axis holds specbinnum bins
# for the input itself, or numfilts for filtered data.
input_var = T.tensor4('X')

if example_is_audio:
    # Load the example audio file and turn it into a spectrogram.
    examplegram = util.standard_specgram(util.load_soundfile(examplewavpath, 0))
    print("examplegram is of shape %s" % (np.shape(examplegram),))
###################################################################################################################
# here we define our "semi-convolutional" autoencoder
# NOTE: lasagne assumes pooling is on the TRAILING axis of the tensor, so we always use time as the trailing axis

def make_custom_convlayer(network, in_num_chans, out_num_chans):
    """Apply our special padding and reshaping to do 1D convolution on 2D data.

    The convolution slides along time only. Along the row (frequency) axis the
    filters span all ``in_num_chans`` rows, so that axis is fully connected.

    Parameters:
        network: incoming Lasagne layer. As built in this script, its output is
            (batch, 1, in_num_chans, time).
        in_num_chans: number of rows in the incoming data (the filter height).
        out_num_chans: number of convolution filters, i.e. rows in the output.

    Returns:
        (network, filters): the new top layer, which outputs
        (batch, 1, out_num_chans, time) with the same time length as its input,
        and the Conv2DLayer's weight variable ``W``.

    Raises:
        ValueError: if the global ``featframe_len`` is even. Symmetric padding of
            (featframe_len - 1) // 2 frames preserves the time length only for
            odd filter lengths. An even length would shorten the output, and the
            reconstruction could no longer be compared with the input.
    """
    if featframe_len % 2 != 1:
        raise ValueError("featframe_len must be odd so that the time axis keeps its length; got %r" % (featframe_len,))
    # Integer division: PadLayer needs an integer width, and "/" would give a float under Python 3.
    padwidth = (featframe_len - 1) // 2
    network = lasagne.layers.PadLayer(network, width=padwidth, batch_ndim=3)  # NOTE: the "batch_ndim" is used to stop batch dims being padded, but here ALSO to skip first data dim
    print("shape after pad layer: %s" % str(network.output_shape))
    # We padded "manually" above so that only the time axis is padded.
    # The filter covers all in_num_chans rows, giving an output height of 1.
    network = lasagne.layers.Conv2DLayer(network, out_num_chans, (in_num_chans, featframe_len), stride=(1,1), pad=0, nonlinearity=very_leaky_rectify, W=lasagne.init.Orthogonal())
    filters = network.W
    # (batch, out_num_chans, 1, time) -> (batch, 1, out_num_chans, time): reinterpret channels as rows
    network = lasagne.layers.ReshapeLayer(network, ([0], [2], [1], [3]))
    print("shape after conv layer: %s" % str(network.output_shape))
    return network, filters
danstowell@0 65
network = lasagne.layers.InputLayer((None, 1, specbinnum, numtimebins), input_var)
print("shape after input layer: %s" % (network.output_shape,))

# --- Normalisation ---
# The input is normalised on the way in and is deliberately NOT de-normalised at the output.
# The training procedure does not set these parameters; they must be set before training
# begins, so we keep a handle on the layer.
normlayer = NormalisationLayer(network, specbinnum)
network = normlayer

# --- Encoder: convolution along time only ---
network, filters_enc = make_custom_convlayer(network, in_num_chans=specbinnum, out_num_chans=numfilters)

# --- Optional max-pooling along the time axis ---
# The windows do not overlap (pool size == stride == 2 in time). Lasagne's InverseLayer undoes
# this further down in one-hot fashion. With *overlapping* windows, InverseLayer may add some
# points at double magnitude, which is why the windows here are kept non-overlapping.
if use_maxpool:
    maxpool_layer = lasagne.layers.MaxPool2DLayer(network, pool_size=(1,2), stride=(1,2))
    network = maxpool_layer

# --- Midpoint of the autoencoder ---
# Remember the "latents" so they can be inspected later (and maybe regularised).
latents = network

# --- Decoder: undo the pooling, then convolve back to specbinnum rows ---
if use_maxpool:
    network = lasagne.layers.InverseLayer(network, maxpool_layer)
network, filters_dec = make_custom_convlayer(network, in_num_chans=numfilters, out_num_chans=specbinnum)

# Finish with a standard rectifier, since the target (a spectrogram) is non-negative.
network = lasagne.layers.NonlinearityLayer(network, nonlinearity=rectify)
danstowell@0 95
###################################################################################################################
# Objective: mean squared reconstruction error against the raw input_var,
# plus a small (1e-4) L2 penalty over all parameters of the network.
prediction = lasagne.layers.get_output(network)
loss = (lasagne.objectives.squared_error(prediction, input_var).mean()
        + 1e-4 * lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2))
danstowell@0 101
danstowell@0 102 ###################################################################################################################
danstowell@0 103
danstowell@0 104 plot_probedata_data = None
danstowell@0 105 def plot_probedata(outpostfix, plottitle=None):
danstowell@0 106 """Visualises the network behaviour.
danstowell@0 107 NOTE: currently accesses globals. Should really be passed in the network, filters etc"""
danstowell@0 108 global plot_probedata_data
danstowell@0 109
danstowell@0 110 if plottitle==None:
danstowell@0 111 plottitle = outpostfix
danstowell@0 112
danstowell@0 113 if np.shape(plot_probedata_data)==():
danstowell@0 114 if example_is_audio:
danstowell@0 115 plot_probedata_data = np.array([[examplegram[:, examplegram_startindex:examplegram_startindex+numtimebins]]], float32)
danstowell@0 116 else:
danstowell@0 117 plot_probedata_data = np.zeros((minibatchsize, 1, specbinnum, numtimebins), dtype=float32)
danstowell@0 118 for _ in range(5):
danstowell@0 119 plot_probedata_data[:, :, np.random.randint(specbinnum), np.random.randint(numtimebins)] = 1
danstowell@0 120
danstowell@0 121 test_prediction = lasagne.layers.get_output(network, deterministic=True)
danstowell@0 122 test_latents = lasagne.layers.get_output(latents, deterministic=True)
danstowell@0 123 predict_fn = theano.function([input_var], test_prediction)
danstowell@0 124 latents_fn = theano.function([input_var], test_latents)
danstowell@0 125 prediction = predict_fn(plot_probedata_data)
danstowell@0 126 latentsval = latents_fn(plot_probedata_data)
danstowell@0 127 if False:
danstowell@0 128 print("Probedata has shape %s and meanabs %g" % ( plot_probedata_data.shape, np.mean(np.abs(plot_probedata_data ))))
danstowell@0 129 print("Latents has shape %s and meanabs %g" % (latentsval.shape, np.mean(np.abs(latentsval))))
danstowell@0 130 print("Prediction has shape %s and meanabs %g" % (prediction.shape, np.mean(np.abs(prediction))))
danstowell@0 131 print("Ratio %g" % (np.mean(np.abs(prediction)) / np.mean(np.abs(plot_probedata_data))))
danstowell@0 132
danstowell@0 133 util.mkdir_p('pdf')
danstowell@0 134 pdf = PdfPages('pdf/autoenc_probe_%s.pdf' % outpostfix)
danstowell@0 135 plt.figure(frameon=False)
danstowell@0 136 #
danstowell@0 137 plt.subplot(3, 1, 1)
danstowell@0 138 plotdata = plot_probedata_data[0,0,:,:]
danstowell@0 139 plt.imshow(plotdata, origin='lower', interpolation='nearest', cmap='RdBu', aspect='auto', vmin=-np.max(np.abs(plotdata)), vmax=np.max(np.abs(plotdata)))
danstowell@0 140 plt.ylabel('Input')
danstowell@0 141 plt.title("%s" % (plottitle))
danstowell@0 142 #
danstowell@0 143 plt.subplot(3, 1, 2)
danstowell@0 144 plotdata = latentsval[0,0,:,:]
danstowell@0 145 plt.imshow(plotdata, origin='lower', interpolation='nearest', cmap='RdBu', aspect='auto', vmin=-np.max(np.abs(plotdata)), vmax=np.max(np.abs(plotdata)))
danstowell@0 146 plt.ylabel('Latents')
danstowell@0 147 #
danstowell@0 148 plt.subplot(3, 1, 3)
danstowell@0 149 plotdata = prediction[0,0,:,:]
danstowell@0 150 plt.imshow(plotdata, origin='lower', interpolation='nearest', cmap='RdBu', aspect='auto', vmin=-np.max(np.abs(plotdata)), vmax=np.max(np.abs(plotdata)))
danstowell@0 151 plt.ylabel('Output')
danstowell@0 152 #
danstowell@0 153 pdf.savefig()
danstowell@0 154 plt.close()
danstowell@0 155 ##
danstowell@0 156 for filtvar, filtlbl, isenc in [
danstowell@0 157 (filters_enc, 'encoding', True),
danstowell@0 158 (filters_dec, 'decoding', False),
danstowell@0 159 ]:
danstowell@0 160 plt.figure(frameon=False)
danstowell@0 161 vals = filtvar.get_value()
danstowell@0 162 #print(" %s filters have shape %s" % (filtlbl, vals.shape))
danstowell@0 163 vlim = np.max(np.abs(vals))
danstowell@0 164 for whichfilt in range(numfilters):
danstowell@0 165 plt.subplot(3, 8, whichfilt+1)
danstowell@0 166 # NOTE: for encoding/decoding filters, we grab the "slice" of interest from the tensor in different ways: different axes, and flipped.
danstowell@0 167 if isenc:
danstowell@0 168 plotdata = vals[numfilters-(1+whichfilt),0,::-1,::-1]
danstowell@0 169 else:
danstowell@0 170 plotdata = vals[:,0,whichfilt,:]
danstowell@0 171
danstowell@0 172 plt.imshow(plotdata, origin='lower', interpolation='nearest', cmap='RdBu', aspect='auto', vmin=-vlim, vmax=vlim)
danstowell@0 173 plt.xticks([])
danstowell@0 174 if whichfilt==0:
danstowell@0 175 plt.title("%s filters (%s)" % (filtlbl, outpostfix))
danstowell@0 176 else:
danstowell@0 177 plt.yticks([])
danstowell@0 178
danstowell@0 179 pdf.savefig()
danstowell@0 180 plt.close()
danstowell@0 181 ##
danstowell@0 182 pdf.close()
danstowell@0 183
plot_probedata('init')

###################################################################################################################
if True:  # change to False to skip the training stage
    ###################################
    # Set up some training data. This is ALL A BIT SIMPLE - for a proper experiment we'd
    # prepare a full dataset, and it might be too big to be all in memory.
    training_data_size = 100
    training_data = np.zeros((training_data_size, minibatchsize, 1, specbinnum, numtimebins), dtype=float32)
    if example_is_audio:
        # Manually grab a load of random excerpts of the example spectrogram.
        # randint's upper bound is exclusive, so +1 lets the final possible excerpt be chosen too.
        num_start_positions = examplegram.shape[1] - numtimebins + 1
        if num_start_positions < 1:
            raise ValueError("examplegram has %d time frames, fewer than numtimebins=%d"
                             % (examplegram.shape[1], numtimebins))
        for which_training_batch in range(training_data_size):
            for which_training_datum in range(minibatchsize):
                startindex = np.random.randint(num_start_positions)
                training_data[which_training_batch, which_training_datum, :, :, :] = examplegram[:, startindex:startindex + numtimebins]
    else:
        # Make some simple (sparse) data that we can train with: up to 5 ones per datum.
        for which_training_batch in range(training_data_size):
            for which_training_datum in range(minibatchsize):
                for _ in range(5):
                    training_data[which_training_batch, which_training_datum, :, np.random.randint(specbinnum), np.random.randint(numtimebins)] = 1

    ###################################
    # Pre-training setup: set the normalisation parameters manually, estimated from the training data.
    normlayer.set_normalisation(training_data)

    ###################################
    # Training: compile a function that updates the parameters and returns the training loss.
    params = lasagne.layers.get_all_params(network, trainable=True)
    updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=0.01, momentum=0.9)
    train_fn = theano.function([input_var], loss, updates=updates)

    numepochs = 2048 # 3 # 100 # 5000
    for epoch in range(numepochs):
        # Separate accumulator, so the symbolic `loss` expression is not overwritten.
        epoch_loss = 0
        for input_batch in training_data:
            epoch_loss += train_fn(input_batch)
        # Report on the first and last epochs and at every power of two (1, 2, 4, 8, ...).
        if epoch == 0 or epoch == numepochs - 1 or (epoch & (epoch - 1)) == 0:
            lossreadout = epoch_loss / len(training_data)
            infostring = "Epoch %d/%d: Loss %g" % (epoch, numepochs, lossreadout)
            print(infostring)
            plot_probedata('progress', plottitle="progress (%s)" % infostring)

    plot_probedata('trained', plottitle="trained (%d epochs; Loss %g)" % (numepochs, lossreadout))
danstowell@0 233