view experiment-reverb/code/ui.py @ 2:c87a9505f294 tip

Added LICENSE for code, removed .wav files
author Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk>
date Sat, 30 Sep 2017 13:25:50 +0100
parents 246d5546657c
children
line wrap: on
line source
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 11 11:03:04 2015

@author: mmxgn
"""


import matplotlib
matplotlib.use("TkAgg")
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2TkAgg

import pyaudio
from sys import argv, exit
from essentia.standard import YamlInput, YamlOutput, AudioLoader, AudioWriter
from essentia import Pool
from pca import *

from numpy import *
from sklearn import cluster
from sklearn.metrics import pairwise_distances
from sklearn.cluster import KMeans, MiniBatchKMeans
from matplotlib.pyplot import *
#from sklearn.mixture import GMM
from sklearn.naive_bayes import GaussianNB, MultinomialNB
from scipy.signal import decimate
from sklearn import cross_validation
from Tkinter import *
import tkMessageBox
import thread
from numpy.core._internal import _gcd as gcd
import time
import os
import subprocess
from scikits.audiolab import Format, Sndfile
from scipy.signal import fftconvolve
from glob import glob

def zafar(lx, rx, d1, g1, da, G, gc, m):
    """ Rafii & Pardo Reverberator (2009) controlled by High Level parameters

        The two inputs are mixed to mono (0.5*lx + 0.5*rx) and fed through
        six parallel comb filters whose delays are derived from d1 and made
        pairwise coprime. The comb sum then goes through one allpass filter
        per output channel (left delay da - m//2, right delay da + m//2),
        a one-pole lowpass, and finally a dry/wet mix with the original
        channel.

        Inputs:
            lx : left channel input (1-D array)
            rx : right channel input (1-D array, added element-wise to lx)
            d1 : delay of first comb filter in samples
            g1 : gain of first comb filter
            da : delay of allpass filter in samples
            G  : dry/wet mix gain (0 -> dry input only, 1 -> reverb only)
            gc : lowpass filter gain
            m  : difference between left and right allpass delays, in samples

        Outputs:
            (ly, ry): left and right channel outputs, each as long as the
            corresponding input.
    """
    from scipy.signal import lfilter

    def calculate_parameters(d1, g1):
        """Return (d1..d6, g1..g6) for the six comb filters.

        d_k starts at round(d1 / 1.5**(k-1)) and is incremented until it is
        coprime with every earlier delay; g_k = g1**(1.5**-(k-1)) * g1.
        """
        delays = [d1]
        for k in range(1, 6):
            d = int(round(1.5 ** (-k) * d1))
            # list (not generator) on purpose: numpy's star-imported all()
            # would treat a generator object as a single truthy value.
            while not all([gcd(d, previous) == 1 for previous in delays]):
                d += 1
            delays.append(d)
        gains = [g1] + [g1 ** (1.5 ** (-k)) * g1 for k in range(1, 6)]
        return tuple(delays) + tuple(gains)

    def feedback(ff, g, d):
        """In place: ff[n] += g*ff[n-d] for n >= d, i.e. y[n] = ff[n] + g*y[n-d].

        Processed one block of d samples at a time: every block only reads
        the previous, already final block, so the arithmetic per sample is
        the same as the sample-by-sample recursion.
        """
        if d <= 0:
            return ff
        length = len(ff)
        for start in range(d, length, d):
            stop = min(start + d, length)
            ff[start:stop] += g * ff[start - d:stop - d]
        return ff

    def comb(x, g, d):
        """Feedback comb: y[n] = x[n-d] + g*y[n-d]; output is len(x)+d long."""
        d = int(d)
        ff = zeros((len(x) + d,))
        ff[d:] = x
        return feedback(ff, g, d)

    def allpass(x, g, d):
        """Allpass: y[n] = x[n-d] - g*x[n] + g*y[n-d]; output is len(x)+d long.

        Terms whose index falls outside x count as zero, so a delay longer
        than the input is handled (it used to raise IndexError).
        """
        d = int(d)
        ff = zeros((len(x) + d,))
        ff[d:] = x              # delayed path x[n-d]
        ff[:len(x)] -= g * x    # direct path -g*x[n]
        return feedback(ff, g, d)

    def lowpass(x, g):
        """One-pole lowpass: y[n] = (1-g)*x[n] + g*y[n-1]."""
        return lfilter([1 - g], [1, -g], x)

    def comb_array(x, g1, d1):
        """Sum of the six parallel comb filters, zero-padded to the longest one."""
        params = calculate_parameters(d1, g1)
        outputs = [comb(x, g, d) for d, g in zip(params[:6], params[6:])]
        y = zeros((max([len(c) for c in outputs]),))
        for c in outputs:
            y[:len(c)] += c
        return y

    ga = 1. / sqrt(2.)
    # Floor division keeps the allpass delays integral on Python 3 as well.
    half = m // 2

    cin = 0.5 * lx + 0.5 * rx
    cout = comb_array(cin, g1, d1)

    la = allpass(cout, ga, da - half)
    ra = allpass(cout, ga, da + half)

    ly = G * lowpass(la, gc)[0:len(lx)] + (1 - G) * lx
    ry = G * lowpass(ra, gc)[0:len(rx)] + (1 - G) * rx

    # Order matches the docstring and the callers' `(ly, ry) = zafar(...)`.
    return (ly, ry)

class UI:
    """Tk front-end for hand-tuning the zafar() reverberator, one song at a time.

    The songs still to be tuned are tracked in ``session.yaml`` (working
    directory). For the current song the user sets five reverberator
    parameters with sliders, sees the estimated acoustic measures, can plot
    or play the raw and reverberated audio, and saves the parameters to
    ``<song>_parameters.yaml``. "Next >>" is only allowed after saving.
    """

    # Longest impulse response that is rendered, in seconds.
    MAX_IMPULSE_SECONDS = 120

    # Sliders, one grid column each in this order:
    # (attribute, label, bottom value, top value, resolution).
    SLIDERS = (
        ('scale_d1', "d1", 0.01, 0.1, 0.01),
        ('scale_g1', "g1", 0.01, 0.99, 0.01),
        ('scale_da', "da", 0.006, 0.012, 0.001),
        ('scale_G', "G", 0.01, 0.99, 0.01),
        ('scale_gc', "gc", 0.01, 0.99, 0.01),
    )

    # Measure labels in column 6, starting at row 2: (attribute, caption).
    MEASURES = (
        ('label_T60', "Reverberation Time: "),
        ('label_D', "Echo Density: "),
        ('label_C', "Clarity: "),
        ('label_Tc', "Central Time: "),
        ('label_SC', "Spectral Centroid: "),
    )

    # Buttons in column 6, starting at row 7:
    # (attribute, text, callback method name, colour options).
    BUTTONS = (
        ('button_plot_impulse', "Plot Impulse", 'callback_plot_impulse', ()),
        ('button_plot_raw', "Plot Raw", 'callback_plot_raw', ()),
        ('button_plot_reverb', "Plot Reverb", 'callback_plot_reverb', ()),
        ('button_play_raw', "Play Raw", 'callback_play_raw', (('bg', "green"), ('fg', "white"))),
        ('button_play_reverb', "Play Reverb", 'callback_play_reverb', (('bg', "green"), ('fg', "white"))),
        ('button_stop', "Stop Playing", 'callback_stop', (('bg', "red"), ('fg', "white"))),
        ('button_save', "Save", 'callback_save', (('fg', "white"), ('bg', "orange"))),
        ('button_reset', "Undo", 'callback_reset', ()),
        ('button_next', "Next >>", 'callback_next', ()),
    )

    def __init__(self, master, directory):
        """Build the window on ``master`` and load the next song of the session.

        If no files are left to visit, an info box is shown and ``master``
        is destroyed without building the rest of the UI.
        """
        self.master = master
        self.directory = directory
        # mplayer subprocess; None until something has been played.
        self.playerprocess = None

        self._load_session(directory)

        if len(self.files_to_visit) == 0:
            tkMessageBox.showinfo("", "No files to visit")
            master.destroy()
            return

        self.filename = self.files_to_visit[-1]

        # Top label; its text is filled in by load_song().
        self.label_top = Label(master, text="")
        self.label_top.grid(row=0, column=0, columnspan=6)
        self.load_song(self.filename)

        self._build_controls(master)
        self._build_figure(master)

        # Status bar
        self.status_bar_text = Label(master, text="Ready.")
        self.status_bar_text.grid(row=18, column=0, columnspan=8, sticky=N+S+E, padx=10)

        self.lastplot = ''
        self.parameterschanged_render = True
        self.pendingactions = []

        # Initial values: delays are shown in seconds on the sliders but
        # stored in samples in self.d1 / self.da.
        d1t, dat, g1, G, gc = 0.05, 0.006, 0.5, 0.5, 0.5
        self.d1 = d1t * self.SR
        self.da = dat * self.SR
        self.g1 = g1
        self.G = G
        self.gc = gc

        self.scale_d1.set(d1t)
        self.scale_da.set(dat)
        self.scale_g1.set(g1)
        self.scale_gc.set(gc)
        self.scale_G.set(G)

        self._reset_pool()
        self.callback_update_parameters(None)

    # ------------------------------------------------------------------ setup

    def _load_session(self, directory):
        """Resume the session in session.yaml, or start one from directory/*.wav."""
        try:
            self.sessionpool = YamlInput(filename="session.yaml")()
        except Exception:
            print("[II] Could not open sessionpool file, creating a new one")
            self.sessionpool = Pool()
            self.files_to_visit = glob(os.path.join(directory, "*.wav"))
            for name in self.files_to_visit:
                self.sessionpool.add('files_to_visit', name)
            self.visited_files = []
            return
        self.files_to_visit = self._pool_list('files_to_visit')
        self.visited_files = self._pool_list('visited_files')

    def _pool_list(self, key):
        """Return self.sessionpool[key], or [] when the key cannot be read.

        Catches Exception because the missing-key error type raised by
        essentia's Pool is not pinned down in this file.
        """
        try:
            return self.sessionpool[key]
        except Exception:
            return []

    def _build_controls(self, master):
        """Create the parameter sliders, the measure labels and the buttons."""
        for column, (attr, label, bottom, top, resolution) in enumerate(self.SLIDERS):
            scale = Scale(master, to=bottom, from_=top, resolution=resolution,
                          label=label, showvalue=True)
            # Measures are recomputed when the mouse button is released.
            scale.bind("<ButtonRelease-1>", self.callback_update_parameters)
            scale.grid(row=1, column=column, rowspan=17, sticky=N+S+E+W)
            setattr(self, attr, scale)

        for row, (attr, caption) in enumerate(self.MEASURES, 2):
            label = Label(master, text=caption)
            label.grid(row=row, column=6, sticky=N+S+E+W)
            setattr(self, attr, label)

        for row, (attr, text, callback, colours) in enumerate(self.BUTTONS, 7):
            button = Button(master, text=text, width=15,
                            command=getattr(self, callback), **dict(colours))
            button.grid(row=row, column=6, sticky=N+S+E+W)
            setattr(self, attr, button)

    def _build_figure(self, master):
        """Create the matplotlib canvas (grid column 7) and its navigation toolbar."""
        self.figure = Figure(dpi=50)
        self.figure.text(0.5, 0.5, 'No plot selected', weight="bold",
                         horizontalalignment='center')

        self.canvas = FigureCanvasTkAgg(self.figure, master=master)
        self.canvas.draw()
        self.canvas.get_tk_widget().grid(row=0, column=7, rowspan=17, padx=20,
                                         sticky=E+W+N+S)

        self.toolbar_frame = Frame(master)
        self.toolbar_frame.grid(row=17, column=7, sticky=E+W+N+S, padx=19)
        self.toolbar = NavigationToolbar2TkAgg(self.canvas, self.toolbar_frame)

        # Let the plot column and the slider row absorb window resizes.
        Grid.columnconfigure(master, 7, weight=1)
        Grid.rowconfigure(master, 1, weight=1)

    def _reset_pool(self):
        """Start a fresh parameter pool describing the currently loaded song."""
        self.pool = Pool()
        self.pool.set('filename', self.filename)
        self.pool.set('sampleRate', self.SR)
        self.pool.set('nChannels', self.numChannels)

    # -------------------------------------------------------------- playback

    def pyaudio_callback_raw(self, in_data, frame_count, time_info, status):
        """PyAudio stream callback streaming self.audio from self.player_idx.

        NOTE(review): self.player_command and self.player_idx are not
        initialised anywhere in this class, so this callback only works if
        something else sets them up; playback currently goes through mplayer.
        """
        if self.player_command == 'Stop':
            return (0, pyaudio.paAbort)

        data = self.audio[self.player_idx:self.player_idx + frame_count, :]
        # Interleave the channels into a single column.
        data = reshape(data, (data.shape[0] * data.shape[1], 1))
        self.player_idx += frame_count

        return (data, pyaudio.paContinue)

    def _start_player(self, path):
        """Play ``path`` with mplayer as a subprocess stored in self.playerprocess.

        The command is an argument list (no shell), so file names with spaces
        or shell metacharacters are safe and terminate() signals mplayer
        itself. mplayer's stdout is discarded: a pipe that is never read can
        fill up and block the player.
        """
        devnull = open(os.devnull, 'w')
        try:
            # os.setsid: start mplayer in its own session, presumably to
            # detach it from the UI's terminal.
            self.playerprocess = subprocess.Popen(["mplayer", path],
                                                  stdout=devnull,
                                                  preexec_fn=os.setsid)
        finally:
            devnull.close()  # the child keeps its own copy

    def _stop_player(self):
        """Terminate the running player, if any; a no-op otherwise."""
        process = getattr(self, 'playerprocess', None)
        if process is None or process.poll() is not None:
            return
        try:
            process.terminate()
        except OSError:
            pass  # exited between poll() and terminate()

    def _convolve_with_impulse_response(self):
        """Return (left, right): each audio channel convolved with its impulse response."""
        self.calculate_impulse_response()
        print("Convolving left channel")
        l_out = fftconvolve(self.impulse_response_left_channel, self.audio[:, 0])
        print("Convolving right channel")
        r_out = fftconvolve(self.impulse_response_right_channel, self.audio[:, 1])
        return l_out, r_out

    def play_reverb(self):
        """Render the reverberated song to ``<base>_reverb_<ext>`` and play it."""
        l_out, r_out = self._convolve_with_impulse_response()
        lim = min(len(l_out), len(r_out))

        if self.numChannels == 1:
            audio_out = l_out[0:lim]
        else:
            # (samples, 2) frame array: one column per channel.
            audio_out = column_stack((l_out[0:lim], r_out[0:lim]))

        # splitext (not split('.')) so dotted directory names such as
        # "./songs/x.wav" keep their full path.
        base, ext = os.path.splitext(self.filename)
        ext = ext.lstrip('.')
        reverb_filename = "%s_reverb_%s" % (base, ext)

        audio_file = Sndfile(reverb_filename, 'w', Format(ext), self.numChannels, self.SR)
        try:
            audio_file.write_frames(audio_out)
        finally:
            audio_file.close()

        self.reverberated_audio = audio_out
        self.reverb_filename = reverb_filename
        self._start_player(reverb_filename)

    def play_raw(self):
        """Play the current song file unmodified."""
        self._start_player(self.filename)

    # ---------------------------------------------------------------- status

    def _refresh_status(self):
        """Show the pending actions (comma separated) or 'Ready.' in the status bar."""
        if self.pendingactions:
            text = ', '.join(self.pendingactions) + '.'
        else:
            text = 'Ready.'
        self.status_bar_text.config(text=text)

    def remove_action_from_status(self, text):
        """Drop ``text`` from the pending actions and refresh the status bar."""
        self.pendingactions.remove(text)
        self._refresh_status()

    def add_action_to_status(self, text):
        """Append ``text`` to the pending actions and refresh the status bar."""
        self.pendingactions.append(text)
        self._refresh_status()

    # ------------------------------------------------------------------ song

    def load_song(self, filename):
        """Load ``filename`` with essentia's AudioLoader and update the top label.

        Sets self.audio (indexed [sample, channel] by this class), self.SR
        and self.numChannels, and marks the song as not yet saved.
        """
        loaded = AudioLoader(filename=filename)()
        self.audio, self.SR, self.numChannels = loaded[0], loaded[1], loaded[2]
        self.label_top.config(text="Training song: %s (sampleRate: %.0f, nChannels: %d) - %d songs left" % (filename, self.SR, self.numChannels, len(self.files_to_visit) - 1))
        self.saved = False

    # -------------------------------------------------------------- measures

    def estimate_T60(self, d1, g1, gc, G, SR):
        """Reverberation time in seconds: d1/SR/log(g1) * log(1e-3/ga/(1-gc)/G), ga = 1/sqrt(2).

        Can come out negative for small G*(1-gc).
        """
        ga = 1 / sqrt(2)
        return d1 / SR / log(g1) * log(10 ** -3 / ga / (1 - gc) / G)

    def _comb_gains(self, g1):
        """Gains of the six comb filters: g1, then g1**(1.5**-k) * g1 for k = 1..5."""
        return [g1] + [g1 ** (1.5 ** (-k)) * g1 for k in range(1, 6)]

    def calculate_parameters(self, d1, g1):
        """Return (d1..d6, g1..g6) for the six comb filters.

        d_k starts at round(d1 / 1.5**(k-1)) and is incremented until it is
        coprime with every earlier delay (d1 is returned unchanged).
        """
        delays = [d1]
        for k in range(1, 6):
            d = int(round(1.5 ** (-k) * d1))
            # list on purpose: numpy's star-imported all() mishandles generators.
            while not all([gcd(d, previous) == 1 for previous in delays]):
                d += 1
            delays.append(d)
        return tuple(delays) + tuple(self._comb_gains(g1))

    def estimate_C(self, g1, G, gc):
        """Clarity in dB: -10*log10(G**2 * (1-gc)/(1+gc) * sum_k 1/(1-g_k**2))."""
        gains = array(self._comb_gains(g1), dtype=float)
        return -10 * log10(G ** 2 * (1 - gc) / (1 + gc) * (1 / (1 - gains ** 2)).sum())

    def estimate_D(self, d1, g1, da, SR):
        """Echo-density estimate: 10*SR/da * sum_{k,m} max(0.1 - m*d_k/SR, 0), m = 1..10."""
        delays = array(self.calculate_parameters(d1, g1)[:6], dtype=float)
        echoes = arange(1, 11)
        Dm = maximum(0.1 - outer(delays, echoes) / SR, 0)
        # Uses the SR argument (previously self.SR) like the other estimators.
        return 10.0 / da * SR * Dm.sum()

    def estimate_Tc(self, d1, g1, da, SR):
        """Central time in seconds.

        sum_k(d_k/SR * g_k**2/(1-g_k**2)**2) / sum_k(g_k**2/(1-g_k**2)) + da/SR,
        with both sums taken element-wise over the six combs. The previous
        (6,) x (6,1) arrays broadcast into a 6x6 outer product instead.
        """
        params = self.calculate_parameters(d1, g1)
        delays = array(params[:6], dtype=float)
        g2 = array(params[6:], dtype=float) ** 2
        numerator = (delays / SR * g2 / (1 - g2) ** 2).sum()
        denominator = (g2 / (1 - g2)).sum()
        return numerator / denominator + float(da) / SR

    def estimate_SC(self, gc, SR):
        """Spectral centroid (Hz) of the one-pole lowpass |1/(1 - gc*e^-jw)|**2 over 0..SR/2."""
        n = arange(0, SR / 2 + 1)
        weights = 1.0 / (1 + gc ** 2 - 2 * gc * cos(2 * pi * n / SR))
        return (n * weights).sum() / weights.sum()

    def callback_update_parameters(self, _):
        """Read the sliders, recompute the measures and refresh labels and pool.

        The event argument is ignored. The previous parameter values are kept
        in the *_old attributes for callback_reset. Delays are read in seconds
        and stored rounded to samples in self.d1 / self.da.
        """
        SR = self.SR
        d1t = self.scale_d1.get()
        dat = self.scale_da.get()
        d1 = round(d1t * SR)
        da = round(dat * SR)
        g1 = self.scale_g1.get()
        G = self.scale_G.get()
        gc = self.scale_gc.get()

        T60 = self.estimate_T60(d1, g1, gc, G, SR)
        D = self.estimate_D(d1, g1, da, SR) / 10
        C = self.estimate_C(g1, G, gc)
        Tc = self.estimate_Tc(d1, g1, da, SR)
        SC = self.estimate_SC(gc, SR)

        self.d1_old, self.G_old, self.gc_old = self.d1, self.G, self.gc
        self.g1_old, self.da_old = self.g1, self.da
        self.d1, self.G, self.gc, self.g1, self.da = d1, G, gc, g1, da

        # Delays go into the pool in seconds, as shown on the sliders.
        self.pool.set('parameters.d1', d1t)
        self.pool.set('parameters.G', G)
        self.pool.set('parameters.gc', gc)
        self.pool.set('parameters.g1', g1)
        self.pool.set('parameters.da', dat)

        self.T60, self.D, self.Tc, self.SC, self.C = T60, D, Tc, SC, C

        self.pool.set('parameters.T60', T60)
        self.pool.set('parameters.D', D)
        self.pool.set('parameters.C', C)
        self.pool.set('parameters.Tc', Tc)
        self.pool.set('parameters.SC', SC)

        self.label_T60.config(text="Reverberation Time: %.3fs" % T60)
        self.label_D.config(text="Echo Density: %.3f at 0.1s" % D)
        self.label_C.config(text="Clarity: %.3f dB" % C)
        self.label_Tc.config(text="Central Time: %.3fs" % Tc)
        self.label_SC.config(text="Spectral Centroid: %.3f Hz" % SC)

        # Force the next plot request to redraw with the new parameters.
        self.lastplot = ''
        self.parameterschanged_render = True

    def say_hi(self):
        print("Hi, there")

    # ------------------------------------------------------------- plotting

    def _run_in_thread(self, target):
        """Run target() on a new thread; report failure to start on stdout.

        NOTE(review): the targets touch Tk/matplotlib widgets from that worker
        thread, which Tk does not guarantee to be safe.
        """
        try:
            thread.start_new_thread(target, ())
        except Exception:
            print("[EE] Could not start new thread")

    def callback_plot_impulse(self):
        self._run_in_thread(self.plot_impulse)

    def callback_plot_raw(self):
        self._run_in_thread(self.plot_raw)

    def callback_plot_reverb(self):
        self._run_in_thread(self.plot_reverb)

    def _impulse_length(self):
        """Impulse-response length in samples: T60*SR clamped to [1, MAX_IMPULSE_SECONDS*SR].

        The lower bound guards against a zero or negative T60 estimate.
        """
        lim = int(self.T60 * self.SR)
        return max(1, min(lim, int(self.MAX_IMPULSE_SECONDS * self.SR)))

    def calculate_impulse_response(self):
        """Run zafar() on a unit impulse and store both channels' responses.

        Sets self.impulse_response_left_channel and
        self.impulse_response_right_channel, each _impulse_length() samples.
        """
        self.add_action_to_status('Calculating impulse response')
        try:
            delta = zeros((self._impulse_length(),))
            delta[0] = 1.0
            m = int(0.002 * self.SR)  # 2 ms left/right allpass offset
            ly, ry = zafar(delta, delta, int(self.d1), self.g1, int(self.da),
                           self.G, self.gc, m)
            self.impulse_response_left_channel = ly
            self.impulse_response_right_channel = ry
        finally:
            self.remove_action_from_status('Calculating impulse response')

    def _draw_stereo(self, t, left, right, suptitle):
        """Clear the figure, plot two channels on stacked x-shared axes, return (top, bottom)."""
        self.figure.clear()
        top = self.figure.add_subplot(2, 1, 1)
        top.plot(t, left)
        top.set_title('Left Channel')
        bottom = self.figure.add_subplot(2, 1, 2, sharex=top)
        bottom.plot(t, right)
        bottom.set_title('Right Channel')
        for axes in (top, bottom):
            axes.set_xlabel('time (s)')
            axes.set_ylabel('amplitude')
        self.figure.suptitle(suptitle)
        return top, bottom

    def plot_impulse(self):
        """Plot |impulse response| per channel with the Tc and 0.1 s markers."""
        if self.lastplot == 'impulse':
            return
        self.add_action_to_status('Plotting impulse response')
        try:
            self.calculate_impulse_response()
            ly = self.impulse_response_left_channel
            ry = self.impulse_response_right_channel
            t = arange(0, len(ly)) * (1.0 / self.SR)

            top, bottom = self._draw_stereo(t, abs(ly), abs(ry),
                                            "Reverberation Impulse Response")
            for axes in (top, bottom):
                axes.axvspan(0, 0.1, alpha=0.1, color='cyan')
                axes.axvline(self.Tc, color='red', linestyle='--')
                axes.axvline(0.1, color='cyan', linestyle='--')
            top.annotate('Central Time (Tc)', xy=(self.Tc, 0.5), xytext=(self.Tc + 0.01, 0.52),
                         arrowprops=dict(facecolor='black', width=1))
            top.annotate('Echo Density (D) Measurement Point ', xy=(0.1, 0.6), xytext=(.11, 0.62),
                         arrowprops=dict(facecolor='black', width=1))
        finally:
            self.remove_action_from_status('Plotting impulse response')
        self.canvas.draw()
        self.lastplot = 'impulse'

    def plot_raw(self):
        """Plot each of the song's channels on its own x-shared axes."""
        if self.lastplot == 'raw':
            return
        self.add_action_to_status('Plotting raw')
        try:
            N = self.numChannels
            t = arange(0, len(self.audio[:, 0])) * (1.0 / self.SR)
            self.figure.clear()
            previous = None
            for n in range(N):
                axes = self.figure.add_subplot(N, 1, n + 1, sharex=previous)
                axes.plot(t, self.audio[:, n])
                axes.set_title('Channel %d' % n)
                axes.set_xlabel('time (s)')
                axes.set_ylabel('amplitude')
                previous = axes
            self.figure.suptitle('Raw Signal')
            self.canvas.draw()
            self.lastplot = 'raw'
        finally:
            self.remove_action_from_status('Plotting raw')

    def plot_reverb(self):
        """Plot the song convolved with the current impulse response."""
        if self.lastplot == 'reverb':
            return
        self.add_action_to_status('Plotting reverberated signal')
        try:
            l_out, r_out = self._convolve_with_impulse_response()
            lim = min(len(l_out), len(r_out))
            # 1.0: avoid Python 2 integer division if SR is an int.
            t = arange(0, lim) * (1.0 / self.SR)
            self._draw_stereo(t, l_out[0:lim], r_out[0:lim], "Reverberated Signal")
        finally:
            self.remove_action_from_status('Plotting reverberated signal')
        self.canvas.draw()
        self.lastplot = 'reverb'

    # ------------------------------------------------------------- buttons

    def callback_play_raw(self):
        print("[II] Called callback_play_raw")
        self._stop_player()
        self.play_raw()

    def callback_play_reverb(self):
        print("[II] Called callback_play_reverb")
        self._stop_player()
        self.play_reverb()

    def callback_stop(self):
        """Stop playback; safe to press when nothing is playing."""
        self._stop_player()

    def callback_save(self):
        """Write the parameter pool to ``<song path without extension>_parameters.yaml``."""
        outf = "%s_parameters.yaml" % os.path.splitext(self.filename)[0]
        YamlOutput(filename=outf)(self.pool)
        print("[II] Parameters Saved")
        self.saved = True

    def callback_reset(self):
        """Undo: put the previous parameter values back on the sliders.

        The restored values go through callback_update_parameters, so the
        measures, labels and saved pool follow. A second Undo therefore
        re-applies the undone change.
        """
        self.scale_d1.set(self.d1_old / float(self.SR))
        self.scale_g1.set(self.g1_old)
        self.scale_da.set(self.da_old / float(self.SR))
        self.scale_G.set(self.G_old)
        self.scale_gc.set(self.gc_old)
        self.callback_update_parameters(None)

    def callback_next(self):
        """Mark the saved song as visited, persist the session and load the next song."""
        if not self.saved:
            tkMessageBox.showerror("File not saved", "You need to save your changes first")
            return

        self.visited_files.append(self.filename)
        self.sessionpool.add('visited_files', self.filename)
        self.files_to_visit.pop()
        self.sessionpool.remove('files_to_visit')
        for name in self.files_to_visit:
            self.sessionpool.add('files_to_visit', name)
        YamlOutput(filename="session.yaml")(self.sessionpool)

        if len(self.files_to_visit) == 0:
            tkMessageBox.showinfo("Congratulations!", "You finished the training session!")
            self.master.destroy()
            return

        self.filename = self.files_to_visit[-1]
        self.load_song(self.filename)
        # New song: fresh pool (filename/sampleRate/nChannels) and parameters
        # recomputed at the new sample rate.
        self._reset_pool()
        self.callback_update_parameters(None)


if __name__ == "__main__":
    # Usage: <script> <trainingdir>
    if len(argv) != 2:
        print("[EE] Wrong number of arguments")
        print("[II] Correct syntax is:")
        # argv[0] was previously never substituted, printing a literal "%s".
        print("[II] \t%s <trainingdir>" % argv[0])
        print("[II] where <trainingdir> contains the segments in .wav format and their corresponding .yaml files")
        exit(-1)

    print("[II] Using directory: %s" % argv[1])
    root = Tk()
    app = UI(root, argv[1])
    root.mainloop()