view training-sessions/tom.deacon/ui.py @ 0:246d5546657c

initial commit, needs cleanup
author Emmanouil Theofanis Chourdakis <e.t.chourdakis@qmul.ac.uk>
date Wed, 14 Dec 2016 13:15:48 +0000
parents
children
line wrap: on
line source
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 11 11:03:04 2015

@author: mmxgn
"""



from essentia.standard import YamlInput, YamlOutput, AudioLoader, AudioWriter
from essentia import Pool
import matplotlib
matplotlib.use("TkAgg")
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2TkAgg
from mapping import *

from sys import argv, exit

from numpy import *
from matplotlib.pyplot import *

from Tkinter import *
import tkMessageBox, tkFileDialog

from numpy.core._internal import _gcd as gcd

import os
import subprocess
from scikits.audiolab import Format, Sndfile
from scipy.signal import fftconvolve
from glob import glob


def zafar(lx, rx, d1, g1, da, G, gc, m):
    """ Rafii & Pardo Reverberator (2009) controlled by High Level parameters

        A mono mix of the two inputs goes through a bank of six parallel comb
        filters, then through one allpass filter per channel (the two allpass
        delays differ by ``m`` samples), then through a one-pole lowpass. The
        result is mixed with the dry input.

        Inputs:
            lx : left channel input (1-D array)
            rx : right channel input (1-D array)
            d1 : delay of first comb filter in samples (integer)
            g1 : gain of first comb filter
            da : delay of allpass filter in samples (integer)
            G  : dry/wet mix gain (0 -> dry only, 1 -> wet only)
            gc : lowpass filter gain
            m  : difference between right and left allpass delays, in samples

        Outputs:
            (ly, ry): left and right channel outputs, truncated to len(lx)
                      and len(rx) respectively.
            """

    def calculate_parameters(d1, g1):
        # Each successive delay shrinks by a factor of 1.5 and is then bumped
        # up until it is coprime with every earlier delay (mutually prime comb
        # delays, presumably so the combs' echoes do not coincide).
        delays = [d1]
        for k in range(1, 6):
            d = int(round((1.5) ** (-k) * d1))
            while any(gcd(d, prev) != 1 for prev in delays):
                d += 1
            delays.append(d)
        # NOTE(review): this parses as (g1 ** (1.5 ** -k)) * g1; kept identical
        # to the estimators in class UI, which use the same expression.
        gains = [g1] + [g1 ** (1.5) ** (-k) * g1 for k in range(1, 6)]
        return tuple(delays) + tuple(gains)

    def comb_array(x, g1, d1):
        # Six parallel combs on the same input. Their outputs have different
        # lengths, so they are summed into a buffer as long as the longest.
        params = calculate_parameters(d1, g1)
        outputs = [comb(x, g, d) for d, g in zip(params[:6], params[6:])]
        y = zeros((max(len(out) for out in outputs),))
        for out in outputs:
            y[0:len(out)] += out
        return y

    def comb(x, g, d):
        # Feedback comb: y[n] = x[n-d] + g*y[n-d]. The output is d samples
        # longer than x so the last input sample still reaches the output.
        # The first d samples stay zero.
        y = zeros((len(x) + d,))
        for n in range(d, len(x) + d):
            y[n] = x[n - d] + g * y[n - d]
        return y

    def allpass(x, g, d):
        # Allpass: y[n] = x[n-d] - g*x[n] + g*y[n-d], with x treated as zero
        # outside its range. The output is d samples longer than x.
        LENx = len(x)
        LENy = LENx + d
        y = zeros((LENy,))
        for n in range(0, LENy):
            if n - d < 0:
                y[n] = -g * x[n]
            elif n >= LENx:
                y[n] = x[n - d] + g * y[n - d]
            else:
                y[n] = x[n - d] - g * x[n] + g * y[n - d]
        return y

    def lowpass(x, g):
        # One-pole lowpass: y[n] = (1-g)*x[n] + g*y[n-1], with y[-1] = 0.
        LEN = len(x)
        y = zeros((LEN,))
        for n in range(0, LEN):
            if n - 1 < 0:
                y[n] = (1 - g) * x[n]
            else:
                y[n] = (1 - g) * x[n] + g * y[n - 1]
        return y

    ga = 1. / sqrt(2.)

    # Mono mix of both inputs feeds the comb bank.
    cin = 0.5 * lx + 0.5 * rx
    cout = comb_array(cin, g1, d1)

    # Floor division keeps both allpass delays integral (they size zeros()
    # buffers). m/2 is a float under Python 3.
    half_m = m // 2
    ra = allpass(cout, ga, da + half_m)
    la = allpass(cout, ga, da - half_m)

    ral = lowpass(ra, gc)
    lal = lowpass(la, gc)

    ralg = G * ral
    lalg = G * lal

    # Wet/dry mix, truncated to the input lengths.
    ry = ralg[0:len(rx)] + (1 - G) * rx
    ly = lalg[0:len(lx)] + (1 - G) * lx

    # (left, right), as the docstring states and as the callers in class UI
    # unpack it: ``(ly, ry) = zafar(...)``.
    return (ly, ry)

class UI:
    
    def __init__(self, master, directory):
        """Build the reverb-training window.

        Parameters:
            master    : Tk widget that hosts every control (presumably the
                        application's root window).
            directory : folder scanned for ``*.wav`` files when no session
                        file can be read.

        Session progress is read from ``session.yaml`` in the current working
        directory (lists ``files_to_visit`` and ``visited_files``). If that
        file cannot be read, a fresh essentia Pool is built from the
        ``*.wav`` files in *directory*. When nothing is left to visit, an
        info box is shown and *master* is destroyed. Otherwise the last entry
        of ``files_to_visit`` is loaded, all widgets are created, and the
        ``almost_no_reverb`` preset is applied.
        """
        self.master = master
        self.directory = directory

        yamlinput = YamlInput(filename="session.yaml")

        try:
            self.sessionpool = yamlinput()
            # An existing session file may lack either key.
            try:
                self.files_to_visit = self.sessionpool['files_to_visit']
            except Exception:
                self.files_to_visit = []

            try:
                self.visited_files = self.sessionpool['visited_files']
            except Exception:
                self.visited_files = []
        except Exception:
            print("[II] Could not open sessionpool file, creating a new one")
            self.sessionpool = Pool()
            self.files_to_visit = glob("%s/*.wav" % directory)
            for i in self.files_to_visit:
                self.sessionpool.add('files_to_visit', i)
            self.visited_files = []

        if len(self.files_to_visit) == 0:
            tkMessageBox.showinfo("","No files to visit")
            master.destroy()
            return

        # The song being trained is the last entry still to visit.
        filename = self.files_to_visit[-1]
        self.filename = filename

        # Top label; load_song() configures it, and it is overwritten below.
        self.label_top = Label(master, text="")
        self.label_top.grid(row=0, column=0, columnspan=6)

        # Sets self.audio, self.SR, self.numChannels and the module-level
        # SC_max that bounds the SC slider below.
        self.load_song(filename)

        self.label_top.config( text="Training song: %s (sampleRate: %.0f, nChannels: %d) - %d songs left" % (filename, self.SR, self.numChannels, len(self.files_to_visit)-1))

        # High-level (perceptual) sliders. A vertical Tk Scale shows from_ at
        # the top, so every slider uses its maximum as from_. The bounds
        # (T60_min, ...) are not defined in the visible code; presumably they
        # come from `from mapping import *`.
        self.scale_T60 = Scale(master, to_=T60_min, from_=T60_max, resolution=0.01, label="RT60", showvalue=False)
        self.scale_T60.bind("<ButtonRelease-1>",self.callback_update_parameters_high)
        self.scale_T60.grid(row=1,column=0,rowspan=23,sticky=N+S+E+W)

        # D previously had from_/to_ swapped, which made it the only slider
        # with its minimum at the top.
        self.scale_D = Scale(master, to_=D_min, from_=D_max, resolution=0.01, label="D", showvalue=False)
        self.scale_D.bind("<ButtonRelease-1>",self.callback_update_parameters_high)
        self.scale_D.grid(row=1,column=1,rowspan=23,sticky=N+S+E+W)

        self.scale_C = Scale(master, to_=C_min, from_=C_max, resolution=0.01, label="C", showvalue=False)
        self.scale_C.bind("<ButtonRelease-1>",self.callback_update_parameters_high)
        self.scale_C.grid(row=1,column=2,rowspan=23,sticky=N+S+E+W)

        self.scale_Tc = Scale(master, to_=Tc_min, from_=Tc_max, resolution=0.01, label="T_c", showvalue=False)
        self.scale_Tc.bind("<ButtonRelease-1>",self.callback_update_parameters_high)
        self.scale_Tc.grid(row=1,column=3,rowspan=23,sticky=N+S+E+W)

        self.scale_SC = Scale(master, to_=SC_min, from_=SC_max, resolution=0.01, label="SC", showvalue=False)
        self.scale_SC.bind("<ButtonRelease-1>",self.callback_update_parameters_high)
        self.scale_SC.grid(row=1,column=4,rowspan=23,sticky=N+S+E+W)

        # Legend for the high-level sliders.
        self.label_legend1 = Label(master, text="Legend:")
        self.label_legend1.grid(row=2,column=6,sticky=N+E+W)
        self.label_legend2 = Label(master, text="RT60: Reverberation time")
        self.label_legend2.grid(row=3,column=6,sticky=N+W)
        self.label_legend3 = Label(master, text="D: Echo density")
        self.label_legend3.grid(row=4,column=6,sticky=N+W)
        self.label_legend4 = Label(master, text="C: Clarity")
        self.label_legend4.grid(row=5,column=6,sticky=N+W)
        self.label_legend5 = Label(master, text="Tc: Central Time")
        self.label_legend5.grid(row=6,column=6,sticky=N+W)
        self.label_legend6 = Label(master, text="SC: Spectral Centroid")
        self.label_legend6.grid(row=7,column=6,sticky=N+W)

        # Low-level coefficient sliders. Other methods read and set them, but
        # they are not placed on the grid, so they stay hidden.
        self.scale_d1 = Scale(master, to_=d1_min, from_=d1_max, resolution=0.01, label="d1", showvalue=False)
        self.scale_d1.bind("<ButtonRelease-1>",self.callback_update_parameters)
        self.scale_g1 = Scale(master,to_=g1_min, from_=g1_max, resolution=0.001, label="g1", showvalue=False)
        self.scale_g1.bind("<ButtonRelease-1>",self.callback_update_parameters)
        self.scale_da = Scale(master, to_=da_min, from_=da_max, resolution=0.001, label="da", showvalue=False)
        self.scale_da.bind("<ButtonRelease-1>",self.callback_update_parameters)
        self.scale_G = Scale(master,to_=G_min, from_=G_max, resolution=0.001, label="G", showvalue=False)
        self.scale_G.bind("<ButtonRelease-1>",self.callback_update_parameters)
        self.scale_gc = Scale(master, to_=gc_min, from_=gc_max, resolution=0.001, label="gc", showvalue=False)
        self.scale_gc.bind("<ButtonRelease-1>",self.callback_update_parameters)

        # Value labels, refreshed by update_parameters(); also not gridded.
        self.label_T60 = Label(master, text="Reverberation Time: ")
        self.label_D = Label(master, text="Echo Density: ")
        self.label_C = Label(master, text="Clarity: ")
        self.label_Tc = Label(master, text="Central Time: ")
        self.label_SC = Label(master, text="Spectral Centroid: ")

        self.label_d1 = Label(master, text="d_1: ")
        self.label_g1 = Label(master, text="g_1: ")
        self.label_da = Label(master, text="d_a: ")
        self.label_gc = Label(master, text="gc: ")
        self.label_G = Label(master, text="G: ")

        # Buttons. grid() returns None, so each widget is stored first and
        # gridded separately (chaining .grid() stored None in the attribute).
        self.button_plot_impulse = Button(master, text="Plot Impulse",command=self.callback_plot_impulse, width=15)
        self.button_plot_impulse.grid(row=13,column=6,sticky=N+S+E+W)
        self.button_play_raw = Button(master, text="Play Dry", bg="green", fg="white", width=15, command=self.callback_play_raw)
        self.button_play_raw.grid(row=18, column=6,sticky=N+S+E+W)
        self.button_play_reverb = Button(master, text="Play Reverb", bg="green", fg="white", width=15, command=self.callback_play_reverb)
        self.button_play_reverb.grid(row=19, column=6,sticky=N+S+E+W)
        self.button_stop = Button(master, text="Stop Playing", bg="red", fg="white", width=15, command=self.callback_stop)
        self.button_stop.grid(row=20, column=6,sticky=N+S+E+W)
        self.button_save = Button(master, text="Save", fg="white", bg="orange", width=15, command=self.callback_save)
        self.button_save.grid(row=21, column=6,sticky=N+S+E+W)
        self.button_reset = Button(master, text="Undo", width=15, command=self.callback_reset)
        self.button_reset.grid(row=22, column=6,sticky=N+S+E+W)
        self.button_next = Button(master, text="Next >>", width=15, command=self.callback_next)
        self.button_next.grid(row=23, column=6,sticky=N+S+E+W)

        # Figure shown until the first plot is drawn.
        self.figure = Figure( dpi=50)
        self.figure.text(0.5,0.5,'No plot selected', weight = "bold", horizontalalignment='center')

        # Parent the canvas to `master` rather than a global `root`, which
        # only existed if the calling script happened to use that name.
        self.canvas = FigureCanvasTkAgg(self.figure, master=master)
        # draw() renders the figure; show() is an older, since-deprecated
        # name for it in matplotlib's Tk backend.
        self.canvas.draw()
        self.canvas.get_tk_widget().grid(row=0, column=7, rowspan=17, padx=20,sticky=E+W+N+S)
        self.canvas._tkcanvas.grid(row=0, column=7, rowspan=23)

        # Toolbar for canvas
        self.toolbar_frame = Frame(master)
        self.toolbar_frame.grid(row=23,column=7,sticky=E+W+N+S, padx=19)
        self.toolbar = NavigationToolbar2TkAgg(self.canvas, self.toolbar_frame)

        Grid.columnconfigure(master, 7, weight=1)
        Grid.rowconfigure(master, 1, weight=1)

        # Status bar (explicitly parented to master).
        self.status_bar_text = Label(master, text="Ready.")
        self.status_bar_text.grid(row=18, column=0, columnspan = 8, sticky=N+S+E, padx=10)

        self.lastplot = ''
        self.parameterschanged_render = True
        self.pendingactions = []

        # Initial values: delays are given in seconds on the sliders and kept
        # in samples on the instance.
        d1t = 0.05
        self.d1 = d1t*self.SR
        dat = 0.006
        self.da = dat*self.SR
        g1 = 0.5
        self.g1 = g1
        G = 0.5
        self.G = G
        gc = 0.5
        self.gc = gc

        self.scale_d1.set(d1t)
        self.scale_da.set(dat)
        self.scale_g1.set(g1)
        self.scale_gc.set(gc)
        self.scale_G.set(G)

        # 120 s unit impulse used to render impulse responses. The shape must
        # be an int; self.SR may be a float.
        t = zeros((int(self.SR*120),))
        t[0] = 1

        self.impulse = t

        # Preset drop-down; callback_load_presets() rebuilds it from *.pre files.
        self.presets=[]
        self.var_presets = StringVar()
        self.drop_presets = OptionMenu(master, self.var_presets, self.presets, command=self.callback_load_preset)
        self.drop_presets.grid(row=0, column=6, sticky=N+S+E+W)
        self.callback_load_presets()

        # Preset options
        self.button_save_preset = Button(master, text="Save Preset", command=self.callback_save_preset)
        self.button_save_preset.grid(row=1, column=6, sticky=N+E+W)

        # Pool collecting the song's metadata and the chosen parameters.
        self.pool = Pool()
        self.pool.set('filename', self.filename)
        self.pool.set('sampleRate', self.SR)
        self.pool.set('nChannels', self.numChannels)

        self.callback_load_preset('almost_no_reverb')

        # Finally
        self.callback_update_parameters(None)


        
        
        
    def pyaudio_callback_raw(self, in_data, frame_count, time_info, status):
        """PyAudio-style stream callback that streams self.audio.

        If self.player_command is 'Stop', returns (0, pyaudio.paAbort).
        Otherwise it takes the next *frame_count* rows of self.audio starting
        at self.player_idx, flattens them row by row into a single column
        (interleaving the channels, since columns are channels), advances
        self.player_idx, and returns (block, pyaudio.paContinue).

        NOTE(review): `pyaudio` is not imported in the visible part of this
        file, and player_command / player_idx are never set in the visible
        code. This callback looks unused; verify before wiring it up.
        """
        stopping = self.player_command == 'Stop'
        if stopping:
            return (0, pyaudio.paAbort)

        start = self.player_idx
        frames = self.audio[start:start + frame_count, :]
        n_rows, n_cols = frames.shape[0], frames.shape[1]
        interleaved = reshape(frames, (n_rows * n_cols, 1))

        self.player_idx = start + frame_count
        return (interleaved, pyaudio.paContinue)
        
    
    def play_reverb(self):
        """Apply the current reverb to the loaded song, save it and play it.

        The stereo impulse response is recomputed. Each input channel is
        convolved with its response and peak-normalised. The result is
        written next to the source file as ``<name>_reverb_<ext>`` (stored in
        self.reverb_filename, samples in self.reverberated_audio) and played
        with mplayer in its own process group (self.playerprocess).
        """
        self.calculate_impulse_response()
        ly, ry = self.impulse_response_left_channel, self.impulse_response_right_channel

        lx = self.audio[:, 0]
        rx = self.audio[:, 1]

        print("Convolving left channel")
        l_out = fftconvolve(ly, lx)
        # Normalise by the absolute peak. Dividing by max() alone let negative
        # samples exceed full scale, and flipped or blew up the signal when the
        # maximum was <= 0.
        peak = abs(l_out).max()
        if peak > 0:
            l_out = l_out / peak

        print("Convolving right channel")
        r_out = fftconvolve(ry, rx)
        peak = abs(r_out).max()
        if peak > 0:
            r_out = r_out / peak

        lim = min(len(l_out), len(r_out))

        if self.numChannels == 1:
            audio_out = l_out[0:lim]
        else:
            audio_out = concatenate((matrix(l_out[0:lim]).T,
                                     matrix(r_out[0:lim]).T),
                                    axis=1)

        # splitext splits only at the last dot, so paths such as
        # "./song.wav" or "dir.v2/song.wav" keep their directory part.
        base, ext = os.path.splitext(self.filename)
        ext = ext[1:]
        reverb_filename = "%s_reverb_%s" % (base, ext)

        audio_file = Sndfile(reverb_filename, 'w', Format(ext), self.numChannels, self.SR)
        try:
            audio_file.write_frames(audio_out)
        finally:
            audio_file.close()

        self.reverberated_audio = audio_out

        self.reverb_filename = reverb_filename

        # Argument list (no shell): a path with spaces or shell
        # metacharacters is passed to mplayer intact. setsid still gives the
        # player its own process group.
        self.playerprocess = subprocess.Popen(["mplayer", reverb_filename],
                                              stdout=subprocess.PIPE,
                                              preexec_fn=os.setsid)
        
    def play_raw(self):
        """Play the unprocessed song (self.filename) with mplayer.

        The player is started in its own session/process group (os.setsid)
        and stored in self.playerprocess. Its stdout is piped.

        NOTE(review): stdout is piped but not read here. If nothing drains
        it, a chatty mplayer could block once the pipe buffer fills; confirm.
        """
        # Pass the path as its own argument (no shell). With shell=True a file
        # name containing spaces or shell metacharacters was split or
        # interpreted by the shell.
        self.playerprocess = subprocess.Popen(["mplayer", self.filename],
                                              stdout=subprocess.PIPE,
                                              preexec_fn=os.setsid)
        
        
        
        
        
    def remove_action_from_status(self, text):
        
        self.pendingactions.remove(text)
        
        if len(self.pendingactions) == 0:
            self.status_bar_text.config(text='Ready.')
        elif len(self.pendingactions) == 1:
            self.status_bar_text.config(text=self.pendingactions[0]+'.') 
        else:
            self.status_bar_text.config(text=reduce(lambda h,t: h+','+t, self.pendingactions)+'.')        
        
        
    def add_action_to_status(self, text):
        
        self.pendingactions.append(text)
        
        if len(self.pendingactions) == 0:
            self.status_bar_text.config(text='Ready.')
        elif len(self.pendingactions) == 1:
            self.status_bar_text.config(text=text+'.')             
        else:
            self.status_bar_text.config(text=reduce(lambda h,t: h+', '+t, self.pendingactions)+'.')
            
        print self.pendingactions, len(self.pendingactions)
        
        
    def load_song(self, filename):
        """Load *filename* with essentia's AudioLoader and refresh song state.

        Sets self.audio, self.SR and self.numChannels from the loader's first
        three outputs, sets the module-level SC_max to a quarter of the
        sample rate, updates the top label and clears self.saved.
        """
        global SC_max

        loaded = AudioLoader(filename=filename)()
        self.audio, self.SR = loaded[0], loaded[1]
        SC_max = self.SR / 4.0
        self.numChannels = loaded[2]

        remaining = len(self.files_to_visit) - 1
        caption = "Training song: %s (sampleRate: %.0f, nChannels: %d) \n %d songs left" % (
            filename, self.SR, self.numChannels, remaining)
        self.label_top.config(text=caption, wraplength=500)
        self.saved = False
        
        
        
  
    def estimate_T60(self, d1, g1, gc, G, SR):
        """Estimate the reverberation time RT60 in seconds.

        It counts how many passes through the first comb (d1/SR seconds and a
        gain of g1 each) it takes for the output, scaled by the allpass gain
        1/sqrt(2), the lowpass factor (1 - gc) and the wet gain G, to fall to
        10**-3 (-60 dB). That count is then multiplied by the pass duration.
        """
        allpass_gain = 1/sqrt(2)
        target_level = 10**-3/allpass_gain/(1-gc)/G
        return d1/SR/log(g1)*log(target_level)
        
    def calculate_parameters(self, d1, g1):
        """Derive all six comb-filter delays and gains from the first comb's.

        Delay k (k = 1..5) starts at int(round(1.5**-k * d1)) and is
        incremented until it is coprime with every earlier delay. Gain k is
        ``g1 ** (1.5) ** (-k) * g1``.

        Returns:
            (d1, d2, d3, d4, d5, d6, g1, g2, g3, g4, g5, g6)
        """
        delays = [d1]
        for k in range(1, 6):
            candidate = int(round((1.5) ** (-k) * d1))
            while not all(gcd(candidate, earlier) == 1 for earlier in delays):
                candidate += 1
            delays.append(candidate)

        # NOTE(review): ** binds tighter than *, so this is
        # (g1 ** (1.5 ** -k)) * g1. estimate_C uses the same expression.
        gains = [g1]
        gains.extend(g1 ** (1.5) ** (-k) * g1 for k in range(1, 6))

        return tuple(delays + gains)
    def estimate_C(self, g1, G, gc):
        """Estimate the clarity C in dB.

        Uses the six comb gains (derived from g1 by the same rule as
        calculate_parameters), the wet gain G and the lowpass gain gc:
        C = -10*log10(G**2 * (1-gc)/(1+gc) * sum(1/(1 - g_k**2))).
        """
        gains = zeros((6, 1))
        gains[0] = g1
        for k in range(1, 6):
            gains[k] = g1 ** (1.5) ** (-k) * g1

        comb_energy = sum(1 / (1 - gains**2))
        return -10 * log10(G**2 * (1 - gc) / (1 + gc) * comb_energy)


    def estimate_D(self, d1, g1, da, SR):
        """Estimate the echo density D at 0.1 s.

        Parameters:
            d1 : first comb delay (samples)
            g1 : first comb gain (unused; kept for a uniform signature)
            da : allpass delay (samples)
            SR : sample rate (Hz)
        """
        # Empirical constant; its origin is not documented in this file.
        kS = 20.78125

        # Use the SR argument. The method used to read self.SR and ignore it.
        return kS*0.1/d1/da*SR**2
                
    def estimate_Tc(self, d1, g1, da, SR):
        delays = zeros((6,))
        gains = zeros((6,1))        
        (delays[0],delays[1],delays[2],delays[3],delays[4],delays[5],gains[0],gains[1],gains[2],gains[3],gains[4],gains[5]) = self.calculate_parameters(d1,g1) 
        return sum(delays/SR*gains**2/(1-gains**2)**2)/sum(gains**2/(1-gains**2)) + da/SR           
        
    def update_parameters_high(self, _):
        """Map the high-level sliders onto the low-level coefficient sliders.

        Saves the current T60/D/C/Tc/SC as the *_old attributes and reads the
        five high-level sliders. It converts them with inverse_mapping (not
        defined in the visible code; presumably from `mapping`) and writes the
        resulting d1 (s), da (s), g1, gc and G into the low-level sliders.
        The render is then marked stale. *_* (Tk event or None) is ignored.
        """
        for attr in ('T60', 'D', 'C', 'Tc', 'SC'):
            setattr(self, attr + '_old', getattr(self, attr))

        high_level = (self.scale_T60.get(),
                      self.scale_D.get(),
                      self.scale_C.get(),
                      self.scale_Tc.get(),
                      self.scale_SC.get())
        print(self.SR)
        print(high_level)

        d1t, dat, g1, gc, G = inverse_mapping(*high_level, SR=self.SR)
        print("da %s" % ((d1t, dat, g1, gc, G),))

        targets = ((self.scale_d1, d1t),
                   (self.scale_da, dat),
                   (self.scale_g1, g1),
                   (self.scale_gc, gc),
                   (self.scale_G, G))
        for slider, value in targets:
            slider.set(value)

        self.parameterschanged_render = True

    def callback_update_parameters_high(self, _):
        """Release handler for the high-level sliders.

        Pushes the high-level values into the low-level sliders, recomputes
        every derived parameter, then requests an impulse-response plot.
        """
        print("callback_update_parameters_high")
        for stage in (self.update_parameters_high, self.update_parameters):
            stage(_)
        self.callback_plot_impulse()
    
    def update_parameters(self, _):
        """Recompute all derived values from the low-level sliders.

        The d1/da sliders hold seconds, which are rounded to whole samples
        using self.SR. The perceptual estimates (T60, D, C, Tc, SC) are
        recomputed and the previous coefficients are kept in the *_old
        attributes. Coefficients (d1/da in seconds) and estimates are stored
        in self.pool under 'parameters.*'. Finally the value labels are
        refreshed and the cached plot is invalidated. *_* (Tk event or None)
        is ignored.
        """
        sample_rate = self.SR

        d1_seconds = self.scale_d1.get()
        d1_samples = round(d1_seconds * sample_rate)
        gain1 = self.scale_g1.get()
        da_seconds = self.scale_da.get()
        da_samples = round(da_seconds * sample_rate)
        wet_gain = self.scale_G.get()
        lowpass_gain = self.scale_gc.get()

        T60, D, C, Tc, SC = (
            self.estimate_T60(d1_samples, gain1, lowpass_gain, wet_gain, sample_rate),
            self.estimate_D(d1_samples, gain1, da_samples, sample_rate),
            self.estimate_C(gain1, wet_gain, lowpass_gain),
            self.estimate_Tc(d1_samples, gain1, da_samples, sample_rate),
            self.estimate_SC(lowpass_gain, sample_rate),
        )

        # Keep the previous coefficients (used for undo), then store the new ones.
        (self.d1_old, self.G_old, self.gc_old, self.g1_old, self.da_old) = (
            self.d1, self.G, self.gc, self.g1, self.da)
        (self.d1, self.G, self.gc, self.g1, self.da) = (
            d1_samples, wet_gain, lowpass_gain, gain1, da_samples)

        for key, value in (('parameters.d1', d1_seconds),
                           ('parameters.G', wet_gain),
                           ('parameters.gc', lowpass_gain),
                           ('parameters.g1', gain1),
                           ('parameters.da', da_seconds)):
            self.pool.set(key, value)

        (self.T60, self.D, self.Tc, self.SC, self.C) = (T60, D, Tc, SC, C)

        for key, value in (('parameters.T60', T60),
                           ('parameters.D', D),
                           ('parameters.C', C),
                           ('parameters.Tc', Tc),
                           ('parameters.SC', SC)):
            self.pool.set(key, value)

        for label, template, value in (
                (self.label_T60, "Reverberation Time: %.3fs", T60),
                (self.label_D, "Echo Density: %.3f at 0.1s", D),
                (self.label_C, "Clarity: %.3f dB", C),
                (self.label_Tc, "Central Time: %.3fs", Tc),
                (self.label_SC, "Spectral Centroid: %.3f Hz", SC),
                (self.label_d1, "d_1: %.3fs", d1_seconds),
                (self.label_g1, "g_1: %.3f", gain1),
                (self.label_da, "d_a: %.3fs", da_seconds),
                (self.label_gc, "g_c: %.3f", lowpass_gain),
                (self.label_G, "G: %.3f", wet_gain)):
            label.config(text=template % value)

        self.lastplot = ''
        self.parameterschanged_render = True
        
        
        
    def callback_update_parameters(self, _):
        """Release handler for the low-level sliders; delegates to update_parameters."""
        return self.update_parameters(_)
        
        
    def estimate_SC(self, gc, SR):
        """Estimate the spectral centroid (Hz) of the one-pole lowpass with gain *gc*.

        Each frequency f on a 1 Hz grid from 0 to SR/2 is weighted by
        1 / (1 + gc**2 - 2*gc*cos(2*pi*f/SR)). This is the power response of
        y[n] = (1-gc)*x[n] + gc*y[n-1] up to a constant factor. The weighted
        mean frequency is returned.
        """
        freqs = arange(0, SR/2+1)
        denominator = 1 + gc**2 - 2*gc*cos(2*pi*freqs/SR)
        return sum(freqs/denominator)/sum(1/denominator)
       
        
        
        
    def say_hi(self):
        """Debug helper: print a fixed greeting to stdout."""
        greeting = "Hi, there"
        print(greeting)
        
        

    def callback_save_preset(self):
        """Ask for a file name and save the low-level slider values as a preset.

        Writes the slider values d1, g1, da, gc and G with essentia's
        YamlOutput, then rebuilds the preset menu. Nothing is written if the
        save dialog is cancelled.

        NOTE(review): callback_load_presets only lists ``*.pre`` files in the
        current working directory, so presets saved elsewhere will not appear
        in the menu.
        """
        print(self.presets)
        name = tkFileDialog.asksaveasfilename(defaultextension=".pre", filetypes=[("presets", ".pre")])
        print(name)
        if not name:
            # Dialog cancelled: asksaveasfilename returned an empty value.
            return

        p = Pool()
        p.set('d1', self.scale_d1.get())
        p.set('g1', self.scale_g1.get())
        p.set('da', self.scale_da.get())
        p.set('gc', self.scale_gc.get())
        p.set('G', self.scale_G.get())
        y = YamlOutput(filename=name)
        y(p)

        # Indented with spaces: the tab used here before is a TabError under
        # Python 3.
        self.callback_load_presets()
                
    
    def callback_load_presets(self):
        """Rebuild the preset drop-down from the ``*.pre`` files in the working directory.

        Menu entries are the file names without the extension, sorted
        alphabetically. The previous menu is always destroyed. If there are
        no presets, no new menu is created.
        """
        # glob() order is arbitrary; sort for a stable menu.
        self.presets = tuple(sorted(os.path.splitext(f)[0] for f in glob('*.pre')))

        self.drop_presets.destroy()

        if self.presets:
            # Parent to self.master rather than a global `root`, which only
            # existed if the calling script happened to use that name.
            self.drop_presets = OptionMenu(self.master, self.var_presets, *self.presets,
                                           command=self.callback_load_preset)
            self.drop_presets.grid(row=0, column=6, sticky=S+W+E)
        
    
    def callback_load_preset(self, preset):
        """Apply the preset stored in ``<preset>.pre`` and redraw.

        Loads the file with essentia's YamlInput and copies d1, g1, da, gc
        and G into the low-level sliders. It then recomputes the derived
        parameters, moves the high-level sliders to the new T60/D/C/Tc/SC and
        plots the impulse response.
        """
        self.var_presets.set(preset)
        print("loading preset: %s" % (preset,))

        stored = YamlInput(filename = '%s.pre' % preset)()

        low_level = (('d1', self.scale_d1),
                     ('g1', self.scale_g1),
                     ('da', self.scale_da),
                     ('gc', self.scale_gc),
                     ('G', self.scale_G))
        for key, slider in low_level:
            slider.set(stored[key])

        self.callback_update_parameters(None)

        high_level = ((self.scale_T60, self.T60),
                      (self.scale_D, self.D),
                      (self.scale_C, self.C),
                      (self.scale_Tc, self.Tc),
                      (self.scale_SC, self.SC))
        for slider, value in high_level:
            slider.set(value)

        self.plot_impulse()
        
    def callback_plot_impulse(self):
            self.plot_impulse()
    
    def calculate_impulse_response(self):
        """Compute the stereo impulse response for the current coefficients.

        A unit impulse of int(T60 * SR) samples is run through zafar() with
        the current d1, g1, da, G, gc and a 2 ms inter-channel allpass offset.
        Both channel responses are zero-padded to int(2 * T60 * SR) samples
        and stored in self.impulse_response_left_channel and
        self.impulse_response_right_channel. The action is shown in the status
        bar while it runs, and removed even if the computation fails.
        """
        self.add_action_to_status('Calculating impulse response')
        try:
            SR = self.SR
            delta = self.impulse[0:int(self.T60*SR)]

            m = int(0.002*SR)  # 2 ms offset between the channels' allpass delays

            (ly, ry) = zafar(delta, delta, int(self.d1), self.g1, int(self.da),
                             self.G, self.gc, m)

            # Pad both channels to the same length. The left channel's padding
            # used to be computed and then discarded, so it stayed unpadded.
            lim = int(2*self.T60*SR)
            padded_left = zeros((lim,))
            padded_left[0:len(ly)] = ly
            padded_right = zeros((lim,))
            padded_right[0:len(ry)] = ry

            self.impulse_response_left_channel = padded_left
            self.impulse_response_right_channel = padded_right
        finally:
            self.remove_action_from_status('Calculating impulse response')
            
        
    
    def plot_impulse(self):
        if self.lastplot != 'impulse':
            self.add_action_to_status('Plotting impulse response')
            N = self.numChannels
            SR = self.SR
            T = 1.0/self.SR
            limt = max(self.T60,1.0)
            
            lim = int(limt*SR)           
            delta = self.impulse[0:int(lim)]
     #       print "delta:"
      #      print delta
            
            d1 = int(self.d1)
            g1 = self.g1
            da = int(self.da)
            G = self.G
            gc = self.gc
            
            mt = 0.002
            m = int(mt*SR)
                   
            print "Calculating zafar"
            (ly, ry) = zafar(delta,delta,d1,g1,da,G,gc,m)
            
            print "Stopped calculating zafar"#ly.shape

     #       print "lim:", lim
            
            t = arange(0, lim)*T
           # print t
            
            # Pad ly to t
   #         print "Shape ly"
     #       print ly
     #       print len(ly)
            padded_y = zeros(shape(t))
            padded_y[0:len(ly)] = ly
            
            print "Padded y"
            #print padded_y
            
           # ly = padded_y
            
            # Pad ry to t
            
            padded_y = zeros(shape(t))
            padded_y[0:len(ry)] = ry
            
            ry = padded_y            
            
            
            
            self.figure.clear()
            
         #   print "Passed A"
            subplt0 =  self.figure.add_subplot(2,1,1)
#            subplt0.ion()

            subplt0.plot(t[1:lim],abs(ly[1:lim]))
            subplt0.set_title('Left Channel')
            subplt0.set_xlabel('time (s)')
            subplt0.set_ylabel('amplitude')   
            subplt0.axvspan(self.d1/self.SR,self.d1/self.SR+0.1, alpha=0.1,color='cyan')
            subplt0.axvline(self.Tc, color='red', linestyle='--')
            subplt0.axvline(self.d1/self.SR+0.1, color='cyan', linestyle='--')
            subplt0.axhline(0.001, color='black', linestyle='--')
            subplt0.axvline(self.d1/self.SR, color='cyan', linestyle='--')
            
            subplt0.annotate('Central Time (Tc)', xy=(self.Tc, 0.25), xytext=(self.Tc+0.01, 0.25), arrowprops=dict(facecolor='black',width=1))
      #      subplt0.annotate('Echo Density (D) Measurement Range',  xytext=(self.d1/self.SR, 0.62), arrowprops=dict(facecolor='black',width=1))
            
#            
            
            subplt1 = self.figure.add_subplot(2,1,2,sharex=subplt0)
            subplt1.set_title('Right Channel')
            
            subplt1.plot(t[1:lim],abs(ry[1:lim]))
            subplt1.set_xlabel('time (s)')
            subplt1.set_ylabel('amplitude')
            subplt1.axvspan(self.d1/self.SR,self.d1/self.SR+0.1, alpha=0.1,color='cyan')
            
            subplt1.axvline(self.Tc, color='red', linestyle='--')
            subplt1.axvline(self.d1/self.SR+0.1, color='cyan', linestyle='--')
            subplt1.axvline(self.d1/self.SR, color='cyan', linestyle='--')
            subplt1.axhline(0.001, color='black', linestyle='--')

            
            self.figure.suptitle("Reverberation Impulse Response")
            
#            print "Passed B"
#            
            self.remove_action_from_status('Plotting impulse response')
            self.canvas.draw()
            
            self.lastplot = 'impulse'
#            
#            thread.exit_thread()

                
    
    def plot_raw(self):
        """Draw each channel of the loaded audio on its own stacked subplot.

        Nothing is redrawn when the raw signal is already the current plot
        (``self.lastplot == 'raw'``). Every subplot after the first shares the
        x (time) axis of the one above it, so zooming/panning stays in sync.
        """
        if self.lastplot == 'raw':
            return

        self.add_action_to_status('Plotting raw')
        channels = self.numChannels
        print("Channels: %d" % channels)
        num_samples = len(self.audio[:, 0])

        self.figure.clear()

        # Time axis in seconds: sample index times the sampling period.
        period = 1.0 / self.SR
        times = arange(0, num_samples) * period

        previous = None
        for idx in range(channels):
            # Link each new axis to the previous one (none for the first).
            share = {} if previous is None else {'sharex': previous}
            axes = self.figure.add_subplot(channels, 1, idx + 1, **share)
            axes.plot(times, self.audio[:, idx])
            axes.set_title('Channel %d' % idx)
            axes.set_xlabel('time (s)')
            axes.set_ylabel('amplitude')
            previous = axes

        self.figure.suptitle('Raw Signal')
        self.canvas.draw()

        self.lastplot = 'raw'
        self.remove_action_from_status('Plotting raw')
    def callback_plot_raw(self):
        """Tk callback for the "plot raw signal" action.

        Plotting runs synchronously (the former threaded version is gone).
        Any ordinary error raised while plotting is reported on stdout and
        not re-raised. KeyboardInterrupt/SystemExit are no longer swallowed.
        """
        try:
            self.plot_raw()
        except Exception as err:
            # Report the actual failure; the old "could not start new thread"
            # message was a leftover from the removed threading code.
            print("[EE] Could not plot raw signal: %s" % err)
    
    def plot_reverb(self):
        """Plot the loaded audio after convolution with the impulse response.

        Recomputes the stereo impulse response with
        ``calculate_impulse_response()``. It then convolves the left/right
        input channels with the left/right impulse-response channels and
        draws the magnitude of both results on log-amplitude subplots that
        share a time axis. Does nothing when the reverberated signal is
        already the current plot (``self.lastplot == 'reverb'``).
        """
        if self.lastplot == 'reverb':
            return
        self.add_action_to_status('Plotting reverberated signal')

        self.calculate_impulse_response()
        ly = self.impulse_response_left_channel
        ry = self.impulse_response_right_channel

        # NOTE(review): reads column 1, so this assumes stereo input; a mono
        # file would presumably fail here -- confirm upstream guarantees it.
        lx = self.audio[:, 0]
        rx = self.audio[:, 1]

        print("Convolving left channel")
        l_out = fftconvolve(ly, lx)

        print("Convolving right channel")
        r_out = fftconvolve(ry, rx)

        # Plot only the span both channels cover.
        lim = min(len(l_out), len(r_out))

        # 1.0 forces float division (as plot_raw does). Under Python 2,
        # 1/self.SR is 0 for an integer sample rate, which collapsed the
        # whole time axis onto t = 0.
        T = 1.0 / self.SR
        t = arange(0, lim) * T

        self.figure.clear()

        # A log-amplitude axis cannot show negative samples. Plot magnitudes
        # so the negative half of the waveform is not dropped; the
        # impulse-response plot likewise uses abs().
        subplt0 = self.figure.add_subplot(2, 1, 1)
        subplt0.semilogy(t, abs(l_out[0:lim]))
        subplt0.set_title('Left Channel')
        subplt0.set_xlabel('time (s)')
        subplt0.set_ylabel('amplitude')

        subplt1 = self.figure.add_subplot(2, 1, 2, sharex=subplt0)
        subplt1.set_title('Right Channel')
        subplt1.semilogy(t, abs(r_out[0:lim]))
        subplt1.set_xlabel('time (s)')
        subplt1.set_ylabel('amplitude')

        self.figure.suptitle("Reverberated Signal")

        self.remove_action_from_status('Plotting reverberated signal')
        self.canvas.draw()

        self.lastplot = 'reverb'
    def callback_plot_reverb(self):
        """Tk callback: render the reverberated-signal plot."""
        render = self.plot_reverb
        render()

    def callback_play_raw(self):
        """Tk callback: stop any playback in progress, then play the raw audio."""
        print("[II] Called callback_play_raw")
        try:
            self.playerprocess.terminate()
        except (AttributeError, OSError):
            # Nothing to stop. AttributeError: no player created yet (missing
            # or None). OSError: presumably the previous player process has
            # already exited. Anything else is a real error and propagates.
            pass
        self.play_raw()
        
    def callback_play_reverb(self):
        """Tk callback: stop any playback in progress, then play the reverberated audio."""
        print("[II] Called callback_play_reverb")
        try:
            self.playerprocess.terminate()
        except (AttributeError, OSError):
            # Nothing to stop. AttributeError: no player created yet (missing
            # or None). OSError: presumably the previous player process has
            # already exited. Anything else is a real error and propagates.
            pass

        self.play_reverb()
    
    def callback_stop(self):
        """Tk callback: stop the audio that is currently playing, if any.

        Pressing stop before anything was played is a no-op. So is pressing
        it after playback has ended. It no longer raises from the callback.
        """
        try:
            self.playerprocess.terminate()
        except (AttributeError, OSError):
            # AttributeError: no player created yet (missing or None).
            # OSError: presumably the player process has already exited.
            pass
    
    def callback_save(self):
        """Write the current parameter pool to a YAML file beside the audio.

        The output path is ``<filename without its extension>_parameters.yaml``,
        e.g. ``segs/take.1.wav`` -> ``segs/take.1_parameters.yaml``. Sets
        ``self.saved`` so ``callback_next`` will allow moving on.
        """
        # os.path.splitext strips only the final extension. The old
        # split('.')[0] cut at the first dot anywhere in the path, so
        # "./segs/a.wav" became "_parameters.yaml".
        outf = "%s_parameters.yaml" % os.path.splitext(self.filename)[0]
        out = YamlOutput(filename=outf)
        out(self.pool)
        print("[II] Parameters Saved")
        self.saved = True
    
    def callback_reset(self):
        """Restore every parameter from its ``*_old`` copy and resync the UI.

        This covers the low-level reverb parameters (d1, g1, da, G, gc) and
        the high-level descriptors (T60, D, C, Tc, SC). Each matching slider
        is moved back, and ``callback_update_parameters_high(None)`` is
        re-run.
        """
        # Low-level reverberator parameters.
        self.d1 = self.d1_old
        self.g1 = self.g1_old
        self.G = self.G_old
        self.gc = self.gc_old
        self.da = self.da_old

        # High-level descriptors.
        self.T60 = self.T60_old
        self.D = self.D_old
        self.C = self.C_old
        self.Tc = self.Tc_old
        self.SC = self.SC_old

        # The delay sliders display d1 and da divided by the sample rate.
        self.scale_d1.set(self.d1/self.SR)
        self.scale_g1.set(self.g1)
        self.scale_da.set(self.da/self.SR)
        self.scale_G.set(self.G)
        self.scale_gc.set(self.gc)

        self.scale_T60.set(self.T60)
        self.scale_C.set(self.C)
        self.scale_D.set(self.D)
        self.scale_Tc.set(self.Tc)
        self.scale_SC.set(self.SC)

        self.callback_update_parameters_high(None)

    
    def callback_next(self):
        """Mark the current file as finished and advance to the next one.

        Shows an error dialog and does nothing until the current parameters
        have been saved. Otherwise it records ``self.filename`` as visited,
        drops the last entry of ``files_to_visit``, and rewrites
        ``session.yaml`` from ``self.sessionpool``. It then either loads the
        new last entry of ``files_to_visit`` or, when none remain,
        congratulates the user and closes the window.
        """
        # `not` rather than `== False`, so a None/unset-style flag also blocks.
        if not self.saved:
            tkMessageBox.showerror("File not saved", "You need to save your changes first")
            return

        self.visited_files.append(self.filename)
        self.sessionpool.add('visited_files', self.filename)
        # Presumably the current file is the last entry: the next file is
        # loaded from files_to_visit[-1] below.
        self.files_to_visit.pop()
        # Rebuild the pool's list from scratch so it mirrors files_to_visit.
        self.sessionpool.remove('files_to_visit')
        for i in self.files_to_visit:
            self.sessionpool.add('files_to_visit', i)
        YamlOutput(filename="session.yaml")(self.sessionpool)

        if not self.files_to_visit:
            tkMessageBox.showinfo("Congratulations!", "You finished the training session!")
            self.master.destroy()
            return
        self.filename = self.files_to_visit[-1]
        self.load_song(self.filename)
        
        
        
    
if __name__ == "__main__":
    # Expect exactly one argument: the training directory.
    if len(argv) != 2:
        print("[EE] Wrong number of arguments")
        print("[II] Correct syntax is:")
        # Fill in the script name; previously a literal "%s" was printed.
        print("[II] \t%s <trainingdir>" % argv[0])
        print("[II] where <trainingdir> contains the segments in .wav format and their corresponding .yaml files")

        exit(-1)

    print("[II] Using directory: %s" % argv[1])
    root = Tk()
    app = UI(root, argv[1])
    root.mainloop()