from vampy import *
from numpy import zeros,float64, float32, array
import sys
from scipy.signal import butter
import numpy as np

''' 
Define a common base class where we define the methods common to plugins without 
a fusion process involved. 
The base class implements adaptive whitening and onset location backtracking.
This also makes individual plugins easier to change / manage / overview.
'''

class OnsetDetectBase(object):
	# WARNING: Apparently vampy doesn't handle errors in super classes (bases) as gracefully as they are handled
	# in the single class scenario with no inheritance. This is to be fixed later in vampy itself.
	# For now syntax errors, missing imports, etc. are likely to cause a segfault, without printing a detailed Python traceback.
	# However, the source of the error is still printed in debug mode, so at least we know which function to fix.
	
	def __init__(self,inputSampleRate):
		'''Set up the state shared by the onset detector plugins built on this base.

		inputSampleRate -- sample rate passed in by the caller; stored unchanged
		and used by getPreferredStepSize() to derive the step size in samples.
		'''
		# vampy flag constant (the class-level note refers to this as debug mode)
		self.vampy_flags = vf_DEBUG

		# processing geometry: the zeros are placeholders until initialise() runs
		self.inputSampleRate = inputSampleRate
		self.channels = 0
		self.blockSize = 0
		self.stepSize = 0
		# preferred step expressed in seconds (0.02322 was kept around as an alternative)
		self.preferredStepSecs = 0.01161

		# values the user can change through setParameter()
		self.threshold = 50
		self.delta_threshold = 0.0
		self.backtracking_threshold = 1.9
		self.cutoff = 0.34
		self.medianWin = 7

		# on/off switches, also reachable through setParameter()
		self.whitening_on = False
		self.polyfitting_on = True
		self.medfilter_on = True
		self.LPfilter_on = True
		# not exposed as a parameter in this class
		self.simplePick = False

		# state and constants read by whiten()
		self.magPeaks = None
		self.whitenRelaxCoeff = 0.9997
		self.whitenFloor = 0.01

		# low-pass filter coefficients; setParameter('cut-off', ...) replaces
		# them with the result of butter(2, cutoff)
		self.bCoeffs = [0.1600,  0.3200, 0.1600]
		self.aCoeffs = [1.0000, -0.5949, 0.2348]
		
		
	def initialise(self,channels,stepSize,blockSize):
		'''Record the processing geometry and allocate the whitening peak memory.

		channels  -- number of input channels, stored in self.channels
		stepSize  -- step between blocks in samples, stored in self.stepSize
		blockSize -- block length in samples, stored in self.blockSize

		self.half_length is set to blockSize / 2 + 1 and self.magPeaks becomes a
		float64 array of zeros of that length.

		Always returns True.
		'''
		self.channels = channels
		self.stepSize = stepSize
		self.blockSize = blockSize
		# Keep this an integer: it is used as an array length, and numpy refuses
		# (older releases only warned about) a float such as blockSize * 0.5 + 1.0.
		self.half_length = int(self.blockSize) // 2 + 1
		self.magPeaks = zeros(self.half_length, dtype = float64)
		return True
		
	def reset(self):
		'''Forget the whitening peak memory; whiten() allocates a fresh one when
		it finds self.magPeaks set to None. Returns None.'''
		self.magPeaks = None
	
	def getMaker(self):
		'''Return the maker string of the plugin.'''
		maker = 'Mi Tian, Testing'
		return maker

	def getIdentifier(self):
		'''Return the identifier string used for this base class.'''
		identifier = 'vampy-base'
		return identifier
		
	def getPreferredBlockSize(self):
		'''Return the preferred block (window) size in samples: two times the
		value of getPreferredStepSize(), converted to int.'''
		step = self.getPreferredStepSize()
		return int(step * 2)
		
	def getPreferredStepSize(self):
		'''Preferred block size is set to 256 in the QM Vamp plugin in case SR is 22.5kHz'''
		step = int(self.inputSampleRate * self.preferredStepSecs + 0.0001)
		if step < 1 : return 1
		return step
		# return 1024
		
	def getMaxChannelCount(self):
		'''Return the maximum channel count, which is one.'''
		max_channels = 1
		return max_channels
		
	def getInputDomain(self):
		'''Return the FrequencyDomain constant (provided by the vampy star import)
		as the input domain of the plugin.'''
		domain = FrequencyDomain
		return domain
		
	def getParameterDescriptors(self):
		'''Define all common parameters of the plugins.

		Returns a ParameterList holding, in this order: threshold, dthreshold,
		bt-threshold, cut-off, med-threshold, whitening, polyfit, medfilt and
		filtfilt. The identifiers are the ones understood by setParameter() and
		getParameter().

		Every default value declared here matches the value assigned in
		__init__, so what a host is told agrees with the state a freshly created
		plugin really has. This includes the polyfit, medfilt and filtfilt
		switches, which start switched on.
		'''

		threshold = ParameterDescriptor()
		threshold.identifier = 'threshold'
		threshold.name = 'Detection Sensitivity'
		threshold.description = 'Detection Sensitivity'
		threshold.unit = '%'
		threshold.minValue = 0
		threshold.maxValue = 100
		threshold.defaultValue = 50
		threshold.isQuantized = False

		delta_thd = ParameterDescriptor()
		delta_thd.identifier = 'dthreshold'
		delta_thd.name = 'Delta Threshold'
		delta_thd.description = 'Delta threshold used for adaptive theresholding using the median of the detection function'
		delta_thd.unit = ''
		delta_thd.minValue = -1.0
		delta_thd.maxValue = 1.0
		delta_thd.defaultValue = 0.0
		delta_thd.isQuantized = False

		# NOTE: GF: Not sure this should really be called a threshold. 'Tolerance' may be better.
		bt_thd = ParameterDescriptor()
		bt_thd.identifier = 'bt-threshold'
		bt_thd.name = 'Backtracking Threshold'
		bt_thd.description = 'Backtracking threshold used determine the stopping condition for backtracking the onset location'
		bt_thd.unit = ''
		bt_thd.minValue = -1.0
		bt_thd.maxValue = 3.0
		bt_thd.defaultValue = 1.9
		bt_thd.isQuantized = False

		cutoff = ParameterDescriptor()
		cutoff.identifier = 'cut-off'
		cutoff.name = 'cut off value'
		cutoff.description = 'low pass filter cut off value'
		cutoff.unit = ''
		cutoff.minValue = 0.1
		cutoff.maxValue = 0.6
		cutoff.defaultValue = 0.34
		cutoff.isQuantized = False

		med_thd = ParameterDescriptor()
		med_thd.identifier = 'med-threshold'
		med_thd.name = 'Median filter window'
		med_thd.description = 'Median filter window size'
		med_thd.unit = ''
		med_thd.minValue = 3.0
		med_thd.maxValue = 12.0
		med_thd.defaultValue = 7.0
		med_thd.isQuantized = True
		med_thd.quantizeStep = 1

		# save some typing by defining a descriptor type for the on/off switches;
		# the descriptors below are created as copies of it
		boolDescriptor = ParameterDescriptor()
		boolDescriptor.isQuantized = True
		boolDescriptor.minValue = 0
		boolDescriptor.maxValue = 1
		boolDescriptor.quantizeStep = 1

		polyfit = ParameterDescriptor(boolDescriptor)
		polyfit.identifier = 'polyfit'
		polyfit.name = 'polynomial fitting'
		polyfit.description = 'Use polynomial fitting to evaluate detection function peaks.'
		# __init__ starts with polyfitting_on = True
		polyfit.defaultValue = 1.0

		medfilt = ParameterDescriptor(boolDescriptor)
		medfilt.identifier = 'medfilt'
		medfilt.name = 'median filtering'
		medfilt.description = 'Use median filtering'
		# __init__ starts with medfilter_on = True
		medfilt.defaultValue = 1.0

		filtfilt = ParameterDescriptor(boolDescriptor)
		filtfilt.identifier = 'filtfilt'
		filtfilt.name = 'low-pass filtering'
		filtfilt.description = 'Use zero-phase foward-backward low-pass filtering'
		# __init__ starts with LPfilter_on = True
		filtfilt.defaultValue = 1.0

		whitening = ParameterDescriptor(boolDescriptor)
		whitening.identifier = 'whitening'
		whitening.name = 'Adaptive whitening'
		whitening.description = 'Turn adaptive whitening on or off'
		# __init__ starts with whitening_on = False
		whitening.defaultValue = False

		return ParameterList(threshold, delta_thd, bt_thd, cutoff, med_thd, whitening, polyfit, medfilt, filtfilt)
		

	def setParameter(self,paramid,newval):
		'''Store newval for the parameter identified by paramid and log the change.

		paramid -- one of 'threshold', 'dthreshold', 'bt-threshold', 'cut-off',
		           'med-threshold', 'medfilt', 'filtfilt', 'polyfit', 'whitening';
		           any other identifier is silently ignored
		newval  -- the new value; the four on/off switches are turned on only
		           when newval equals 1.0

		Setting 'cut-off' also recomputes the low-pass filter coefficients
		(self.bCoeffs, self.aCoeffs) with butter(2, cutoff).

		Each change is reported on sys.stderr. Returns None.
		'''
		def report(*items):
			# Writes the items separated by single spaces plus a newline, i.e. the
			# text the former "print >> sys.stderr, ..." statement produced. That
			# statement only works on Python 2 and raises TypeError on Python 3.
			sys.stderr.write(' '.join([str(item) for item in items]) + '\n')

		if paramid == 'threshold' :
			self.threshold = newval
			report("sensitivity threshold: ", newval)
		elif paramid == 'dthreshold' :
			self.delta_threshold = newval
			report("delta threshold: ", newval)
		elif paramid == 'bt-threshold' :
			self.backtracking_threshold = newval
			report("backtracking threshold: ", newval)
		elif paramid == 'cut-off' :
			self.cutoff = newval
			# second order low-pass filter at the new cut-off
			self.bCoeffs, self.aCoeffs = butter(2, self.cutoff)
			report("low pass filter cut off value: ", newval)
		elif paramid == 'med-threshold' :
			self.medianWin = newval
			report("meidan filter windown: ", newval)
		elif paramid == 'medfilt' :
			self.medfilter_on = newval == 1.0
			report("median filering: ", self.medfilter_on, newval)
		elif paramid == 'filtfilt' :
			self.LPfilter_on = newval == 1.0
			report("foward-backward filering: ", self.LPfilter_on, newval)
		elif paramid == 'polyfit' :
			self.polyfitting_on = newval == 1.0
			report("polynomial fitting: ", self.polyfitting_on, newval)
		elif paramid == 'whitening' :
			self.whitening_on = newval == 1.0
			report("whitening: ", self.whitening_on, newval)

		return None

		
	def getParameter(self,paramid):
		'''Return the current value of the parameter identified by paramid.

		The four on/off switches ('whitening', 'medfilt', 'filtfilt', 'polyfit')
		are reported as 1.0 or 0.0. 'threshold', 'dthreshold', 'bt-threshold',
		'med-threshold' and 'cut-off' return the stored attribute as it is.
		Any other identifier yields 0.0.
		'''
		# (identifier, attribute) pairs of the boolean switches
		switches = (
			('whitening', 'whitening_on'),
			('medfilt', 'medfilter_on'),
			('filtfilt', 'LPfilter_on'),
			('polyfit', 'polyfitting_on'),
		)
		for ident, attr in switches :
			if paramid == ident :
				return 1.0 if getattr(self, attr) else 0.0

		# (identifier, attribute) pairs of the plain valued parameters
		plain = (
			('threshold', 'threshold'),
			('dthreshold', 'delta_threshold'),
			('bt-threshold', 'backtracking_threshold'),
			('med-threshold', 'medianWin'),
			('cut-off', 'cutoff'),
		)
		for ident, attr in plain :
			if paramid == ident :
				return getattr(self, attr)

		return 0.0
		
			
			
	def getGenericOutputDescriptors(self):
		'''Build the three generic output descriptors and return them as a tuple:
		the raw detection function, the smoothed detection function and the
		onsets, in that order. (The original note points out that the QM plugin
		has the same three outputs, but with the onsets first.)

		This is deliberately not the real output descriptor method: the base is
		not meant to have outputs of its own. The identifiers are left out here
		and are expected to be assigned by the sub-classes.
		'''
		# raw detection function: a single value for every processing step
		raw_df = OutputDescriptor()
		raw_df.name = 'Onset Detection Function'
		raw_df.description ='Onset Detection Function'
		raw_df.unit = ''
		raw_df.sampleType = OneSamplePerStep
		raw_df.isQuantized = False
		raw_df.hasKnownExtents = False
		raw_df.hasFixedBinCount = True
		raw_df.binCount = 1

		# smoothed detection function: starts as a copy of the raw descriptor,
		# only the fields that differ are overwritten
		smooth_df = OutputDescriptor(raw_df)
		smooth_df.sampleType = VariableSampleRate
		smooth_df.sampleRate = 1.0 / self.preferredStepSecs
		smooth_df.name = 'Smoothed Onset Detection Function'
		smooth_df.description ='Smoothed Onset Detection Function'

		# onsets: zero bins per feature
		onsets = OutputDescriptor()
		onsets.sampleType = VariableSampleRate
		onsets.unit = ''
		onsets.isQuantized = False
		onsets.hasKnownExtents = False
		onsets.hasFixedBinCount = True
		onsets.binCount = 0
		onsets.name = 'Onsets'
		onsets.description ='Onsets using spectral difference'

		return raw_df, smooth_df, onsets


	def backtrack(self, onset_index, smoothed_df):
		'''Move an onset from the detected peak back to an earlier position.

		The rationale given by the original authors is that the perceived onset
		tends to lie a few frames before the detected peak, so the position is
		traced back to where the peak starts to build up.

		onset_index -- index of the detected peak in smoothed_df
		smoothed_df -- indexable sequence holding the smoothed detection function

		Starting at onset_index, the backward difference df[i] - df[i-1] is
		compared with the previous difference scaled by
		self.backtracking_threshold. The walk stops as soon as the current
		difference is smaller than that, or when index 1 is reached. Returns the
		index where the walk stopped.
		'''
		position = onset_index
		last_rise = 0.0
		while position > 1 :
			rise = smoothed_df[position] - smoothed_df[position - 1]
			if rise < last_rise * self.backtracking_threshold :
				return position
			last_rise = rise
			position = position - 1
		return position
		
	def trackDF(self, onset1_index, df2):
		'''In the second round of detection, remove the known onsets from the DF by tracking from the peak given by the first round
		to a valley to deminish the recognised peaks on top of which to start new detection.'''	
		for idx in xrange(len(onset1_index)) :
			remove = True
			for i in xrange(onset1_index[idx], 1, -1) :
				if remove :
					if df2[i] >= df2[i-1] :	
						df2[i] == 0.0
					else:
						remove = False
		return df2	
		
	def whiten(self,magnitudeSpectrum):
		'''This function reproduces adaptive whitening as described in Dan Stowell's paper.'''
		if self.magPeaks is None :
			self.magPeaks = zeros(self.half_length, dtype = float32)
		m = array(magnitudeSpectrum, dtype=float32)
		idx = m < self.magPeaks
		# print " m", m[idx]
		
		m[idx] += (self.magPeaks[idx] - m[idx]) * self.whitenRelaxCoeff
		m[m < self.whitenFloor] = self.whitenFloor
		self.magPeaks = m

		magnitudeSpectrum /= m
		
		return magnitudeSpectrum
	
	

	
