# Spectrogram auto-encoder
# Dan Stowell 2016.
#
# Unusual things about this implementation:
# * Data is not pre-whitened; instead we use a custom layer (NormalisationLayer) to normalise the mean and variance of the data for us. This is because I want the spectrogram to be normalised when it is input, but not normalised when it is output.
# * It's a convolutional net, but only along the time axis; along the frequency axis it's fully connected.

import numpy as np

import theano
import theano.tensor as T
import lasagne
#import downhill
from lasagne.nonlinearities import rectify, leaky_rectify, very_leaky_rectify
from numpy import float32

try:
    from lasagne.layers import InverseLayer as _  # just checking it exists; we access it as lasagne.layers.InverseLayer below
    use_maxpool = True
except ImportError:
    print("""**********************
WARNING: InverseLayer not found in Lasagne. Please use a more recent version of Lasagne.
WARNING: We'll deactivate the max-pooling part of the network (since we can't use InverseLayer to undo it)""")
    use_maxpool = False

import matplotlib
#matplotlib.use('PDF') # http://www.astrobetter.com/plotting-to-a-file-in-python/
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.backends.backend_pdf import PdfPages
plt.rcParams.update({'font.size': 6})

from userconfig import *
import util
from layers_custom import *

###################################################################################################################
# create Theano variables for the input minibatch
input_var = T.tensor4('X')
# note that, in general, the main data tensors will have these axes:
# - minibatchsize
# - numchannels (always 1 for us, since spectrograms)
# - numfilts (or specbinnum for input)
# - numtimebins

if example_is_audio:
    # load our example audio file as a spectrogram
    examplegram = util.standard_specgram(util.load_soundfile(examplewavpath, 0))
    print("examplegram is of shape %s" % str(np.shape(examplegram)))
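    # Added sanity check (not in the original script): the network built below
    # assumes the spectrogram's frequency axis has exactly specbinnum bins, and
    # the training code cuts windows of numtimebins frames from this example.
    assert examplegram.shape[0] == specbinnum, "spectrogram frequency axis must equal specbinnum"
    assert examplegram.shape[1] >= numtimebins, "example audio is too short to cut even one input window"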

###################################################################################################################
# here we define our "semi-convolutional" autoencoder
# NOTE: lasagne assumes pooling is on the TRAILING axis of the tensor, so we always use time as the trailing axis

def make_custom_convlayer(network, in_num_chans, out_num_chans):
    "Applies our special padding and reshaping to do 1D convolution on 2D data"
    network = lasagne.layers.PadLayer(network, width=(featframe_len-1)/2, batch_ndim=3) # NOTE: "batch_ndim" is used to stop batch dims being padded, but here it ALSO skips the first data dim, so only the time axis gets padded
    print("shape after pad layer: %s" % str(network.output_shape))
    network = lasagne.layers.Conv2DLayer(network, out_num_chans, (in_num_chans, featframe_len), stride=(1, 1), pad=0, nonlinearity=very_leaky_rectify, W=lasagne.init.Orthogonal()) # we pad "manually" in order to do it in one dimension only
    filters = network.W
    network = lasagne.layers.ReshapeLayer(network, ([0], [2], [1], [3])) # reinterpret channels as rows
    print("shape after conv layer: %s" % str(network.output_shape))
    return network, filters

network = lasagne.layers.InputLayer((None, 1, specbinnum, numtimebins), input_var)
print("shape after input layer: %s" % str(network.output_shape))
#
# normalisation layer
# -- note that we deliberately normalise the input but do not undo that at the output.
# -- note that the normalisation params are not set by the training procedure; they need to be set before training begins.
network = NormalisationLayer(network, specbinnum)
normlayer = network # we need to remember this one so we can set its parameters
#
network, filters_enc = make_custom_convlayer(network, in_num_chans=specbinnum, out_num_chans=numfilters)
#
# NOTE: here we're using max-pooling along the time axis only, and then
# using Lasagne's "InverseLayer" to undo the max-pooling in one-hot fashion.
# There's a side-effect of this: if you use *overlapping* max-pooling windows,
# the InverseLayer may behave slightly unexpectedly, adding some points with
# double magnitude. It's OK here, since we're not overlapping the windows.
if use_maxpool:
    network = lasagne.layers.MaxPool2DLayer(network, pool_size=(1, 2), stride=(1, 2))
    maxpool_layer = network # store a pointer to this one

# NOTE: HERE is the "middle" of the autoencoder!
latents = network # we remember the "latents" at the midpoint of the net, since we'll want to inspect them, and maybe regularise them too

if use_maxpool:
    network = lasagne.layers.InverseLayer(network, maxpool_layer)

network, filters_dec = make_custom_convlayer(network, in_num_chans=numfilters, out_num_chans=specbinnum)

network = lasagne.layers.NonlinearityLayer(network, nonlinearity=rectify) # finally a standard rectify, since the (spectrogram) target is non-negative
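
# Added note (illustrative, not in the original script): a shape walk-through of a
# single make_custom_convlayer pass, using example values specbinnum=256,
# numtimebins=128, numfilters=24, featframe_len=9 (the real values come from userconfig):
#   input:        (batch, 1, 256, 128)
#   PadLayer:     (batch, 1, 256, 136)  # (featframe_len-1)/2 = 4 zeros at each end of the time axis
#   Conv2DLayer:  (batch, 24, 1, 128)   # the (256, 9) kernel spans the whole frequency axis, hence "semi-convolutional"
#   ReshapeLayer: (batch, 1, 24, 128)   # channels reinterpreted as rows, so the block can be stacked
print("shape after full autoencoder: %s" % str(network.output_shape))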

###################################################################################################################
# define a simple L2 loss function with a mild touch of regularisation
prediction = lasagne.layers.get_output(network)
loss = lasagne.objectives.squared_error(prediction, input_var)
loss = loss.mean() + 1e-4 * lasagne.regularization.regularize_network_params(network, lasagne.regularization.l2)

###################################################################################################################

plot_probedata_data = None
def plot_probedata(outpostfix, plottitle=None):
    """Visualises the network behaviour.
    NOTE: currently accesses globals. Should really be passed the network, filters etc."""
    global plot_probedata_data

    if plottitle is None:
        plottitle = outpostfix

    if plot_probedata_data is None:
        if example_is_audio:
            plot_probedata_data = np.array([[examplegram[:, examplegram_startindex:examplegram_startindex+numtimebins]]], float32)
        else:
            plot_probedata_data = np.zeros((minibatchsize, 1, specbinnum, numtimebins), dtype=float32)
            for _ in range(5):
                plot_probedata_data[:, :, np.random.randint(specbinnum), np.random.randint(numtimebins)] = 1

    test_prediction = lasagne.layers.get_output(network, deterministic=True)
    test_latents = lasagne.layers.get_output(latents, deterministic=True)
    predict_fn = theano.function([input_var], test_prediction)
    latents_fn = theano.function([input_var], test_latents)
    prediction = predict_fn(plot_probedata_data)
    latentsval = latents_fn(plot_probedata_data)
    if False:
        print("Probedata  has shape %s and meanabs %g" % (plot_probedata_data.shape, np.mean(np.abs(plot_probedata_data))))
        print("Latents    has shape %s and meanabs %g" % (latentsval.shape, np.mean(np.abs(latentsval))))
        print("Prediction has shape %s and meanabs %g" % (prediction.shape, np.mean(np.abs(prediction))))
        print("Ratio %g" % (np.mean(np.abs(prediction)) / np.mean(np.abs(plot_probedata_data))))

    util.mkdir_p('pdf')
    pdf = PdfPages('pdf/autoenc_probe_%s.pdf' % outpostfix)
    plt.figure(frameon=False)
    #
    plt.subplot(3, 1, 1)
    plotdata = plot_probedata_data[0, 0, :, :]
    plt.imshow(plotdata, origin='lower', interpolation='nearest', cmap='RdBu', aspect='auto', vmin=-np.max(np.abs(plotdata)), vmax=np.max(np.abs(plotdata)))
    plt.ylabel('Input')
    plt.title("%s" % (plottitle))
    #
    plt.subplot(3, 1, 2)
    plotdata = latentsval[0, 0, :, :]
    plt.imshow(plotdata, origin='lower', interpolation='nearest', cmap='RdBu', aspect='auto', vmin=-np.max(np.abs(plotdata)), vmax=np.max(np.abs(plotdata)))
    plt.ylabel('Latents')
    #
    plt.subplot(3, 1, 3)
    plotdata = prediction[0, 0, :, :]
    plt.imshow(plotdata, origin='lower', interpolation='nearest', cmap='RdBu', aspect='auto', vmin=-np.max(np.abs(plotdata)), vmax=np.max(np.abs(plotdata)))
    plt.ylabel('Output')
    #
    pdf.savefig()
    plt.close()
    ##
    for filtvar, filtlbl, isenc in [
            (filters_enc, 'encoding', True),
            (filters_dec, 'decoding', False),
            ]:
        plt.figure(frameon=False)
        vals = filtvar.get_value()
        #print(" %s filters have shape %s" % (filtlbl, vals.shape))
        vlim = np.max(np.abs(vals))
        for whichfilt in range(numfilters):
            plt.subplot(3, 8, whichfilt+1)
            # NOTE: for encoding/decoding filters, we grab the "slice" of interest from the tensor in different ways: different axes, and flipped
            if isenc:
                plotdata = vals[numfilters-(1+whichfilt), 0, ::-1, ::-1]
            else:
                plotdata = vals[:, 0, whichfilt, :]

            plt.imshow(plotdata, origin='lower', interpolation='nearest', cmap='RdBu', aspect='auto', vmin=-vlim, vmax=vlim)
            plt.xticks([])
            if whichfilt == 0:
                plt.title("%s filters (%s)" % (filtlbl, outpostfix))
            else:
                plt.yticks([])

        pdf.savefig()
        plt.close()
    ##
    pdf.close()

plot_probedata('init')
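
# Added convenience (not in the original script): a deterministic evaluation
# function that reports the reconstruction loss without updating any weights,
# handy for checking a batch before and after training.
eval_prediction = lasagne.layers.get_output(network, deterministic=True)
eval_loss = lasagne.objectives.squared_error(eval_prediction, input_var).mean()
eval_fn = theano.function([input_var], eval_loss)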

###################################################################################################################
if True:
    ###################################
    # here we set up some training data. this is ALL A BIT SIMPLE - for a proper experiment we'd prepare a full dataset, and it might be too big to fit in memory all at once.
    training_data_size = 100
    training_data = np.zeros((training_data_size, minibatchsize, 1, specbinnum, numtimebins), dtype=float32)
    if example_is_audio:
        # manually grab a load of random subsets of the training audio
        for which_training_batch in range(training_data_size):
            for which_training_datum in range(minibatchsize):
                startindex = np.random.randint(examplegram.shape[1] - numtimebins)
                training_data[which_training_batch, which_training_datum, :, :, :] = examplegram[:, startindex:startindex+numtimebins]
    else:
        # make some simple (sparse) data that we can train with
        for which_training_batch in range(training_data_size):
            for which_training_datum in range(minibatchsize):
                for _ in range(5):
                    training_data[which_training_batch, which_training_datum, :, np.random.randint(specbinnum), np.random.randint(numtimebins)] = 1

    ###################################
    # pre-training setup

    # set the normalisation parameters manually, as an estimate from the training data
    normlayer.set_normalisation(training_data)

    ###################################
    # training

    # compile a training function that updates the parameters and returns the training loss
    params = lasagne.layers.get_all_params(network, trainable=True)
    updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=0.01, momentum=0.9)
    train_fn = theano.function([input_var], loss, updates=updates)

    # train the network
    numepochs = 2048 # 3 # 100 # 5000
    for epoch in range(numepochs):
        epochloss = 0 # accumulate this epoch's total loss across minibatches (kept separate from the symbolic "loss" above)
        for input_batch in training_data:
            epochloss += train_fn(input_batch)
        if epoch == 0 or epoch == numepochs-1 or (2 ** int(np.log2(epoch)) == epoch):
            lossreadout = epochloss / len(training_data)
            infostring = "Epoch %d/%d: Loss %g" % (epoch, numepochs, lossreadout)
            print(infostring)
            plot_probedata('progress', plottitle="progress (%s)" % infostring)

    plot_probedata('trained', plottitle="trained (%d epochs; Loss %g)" % (numepochs, lossreadout))
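
    # Added suggestion (not in the original script): save the trained weights using
    # the standard Lasagne pattern, so they can be restored later with
    # lasagne.layers.set_all_param_values(network, ...); the filename is just an example.
    np.savez('autoenc_trained_params.npz', *lasagne.layers.get_all_param_values(network))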