import os, sys
import numpy as np
import scipy as sp
from scipy.ndimage.filters import *
from vampy import *
from copy import deepcopy

class HPspectrogram:
	'''Vampy plugin emitting harmonic- and percussive-emphasised spectrograms.

	process() only collects magnitude frames; the separation is done once, over
	the whole spectrogram, in getRemainingFeatures() by comparing a median filter
	along time (axis 0) with a median filter along frequency (axis 1).
	'''

	def __init__(self, inputSampleRate):

		# flags:
		self.vampy_flags = vf_DEBUG | vf_ARRAY | vf_REALTIME
		self.inputSampleRate = inputSampleRate
		self.stepSize = 1024
		self.blockSize = 2048
		self.channels = 1
		# median filter lengths: frames for the harmonic filter, bins for the percussive one
		self.harmonic_medwin = 8
		self.percussive_medwin = 8
		# NOTE(review): exposed as parameters but not currently used by the separation.
		self.harmonic_thresh = 0.5
		self.percussive_thresh = 0.5
		self.whitening_on = False
		self.whitenRelaxCoeff = 0.9997
		self.whitenFloor = 0.01
		# per-bin peak estimate carried from frame to frame by whiten()
		self.magPeaks = None
		# accumulated magnitude frames and their timestamps
		self.maglist = []
		self.timestamp = []

	def initialise(self, channels, stepSize, blockSize):
		'''Store the host-chosen framing and clear any accumulated state.'''
		self.channels = channels
		self.stepSize = stepSize
		self.blockSize = blockSize
		self.reset()
		return True

	def reset(self):
		'''Discard collected frames and the whitening peak history.'''
		self.maglist = []
		self.timestamp = []
		self.magPeaks = None
		return None

	def getMaker(self):
		return 'Mi Tian'

	def getName(self):
		return 'Harmonic percussive spectrogram'

	def getIdentifier(self):
		return 'hps'

	def getDescription(self):
		return 'Harmonic and percussive emphasised spectrograms.'

	def getCopyright(self):
		return ''

	def getMaxChannelCount(self):
		return 1

	def getInputDomain(self):
		return FrequencyDomain

	def getOutputDescriptors(self):
		'''Describe output 0 (harmonic) and output 1 (percussive) spectrograms.'''

		# one value per frequency bin of a blockSize-point frame
		nBins = self.blockSize // 2 + 1

		#Generic values are the same for all
		Generic = OutputDescriptor()
		Generic.hasFixedBinCount = True
		Generic.binCount = nBins
		Generic.hasKnownExtents = False
		Generic.isQuantized = False
		Generic.sampleType = OneSamplePerStep
		Generic.unit = 'Hz'

		HSPEC = OutputDescriptor(Generic)
		HSPEC.identifier = 'harmonic_spectrogram'
		HSPEC.name = 'Harmonic emphasised spectrogram'
		HSPEC.description = 'Harmonic emphasised spectrogram'
		HSPEC.hasFixedBinCount = True
		HSPEC.binCount = nBins

		PSPEC = OutputDescriptor(Generic)
		PSPEC.identifier = 'percussive_spectrogram'
		PSPEC.name = 'Percussive emphasised spectrogram'
		PSPEC.description = 'Percussive emphasised spectrogram'
		PSPEC.hasFixedBinCount = True
		PSPEC.binCount = nBins

		return OutputList(HSPEC, PSPEC)

	def getParameterDescriptors(self):
		'''Describe the user-settable parameters (ids match setParameter/getParameter).'''

		harmonic_thresh = ParameterDescriptor()
		harmonic_thresh.identifier = 'harmonic_thresh'
		harmonic_thresh.name = 'harmonic_thresh'
		harmonic_thresh.description = 'Binarisation threshold for harmonic emphasised magnitude spectrogram'
		harmonic_thresh.unit = 'v'
		harmonic_thresh.minValue = 0
		harmonic_thresh.maxValue = 1
		harmonic_thresh.defaultValue = 0.5
		harmonic_thresh.isQuantized = False

		percussive_thresh = ParameterDescriptor(harmonic_thresh)
		percussive_thresh.identifier = 'percussive_thresh'
		percussive_thresh.name = 'percussive_thresh'
		percussive_thresh.description = 'Binarisation threshold for percussive emphasised magnitude spectrogram'

		harmonic_medwin = ParameterDescriptor()
		harmonic_medwin.identifier = 'harmonic_medwin'
		harmonic_medwin.description = 'Median window length for harmonic part.'
		harmonic_medwin.name = 'Median window length for harmonic part.'
		harmonic_medwin.minValue = 1
		harmonic_medwin.defaultValue = 8
		harmonic_medwin.maxValue = 15
		harmonic_medwin.isQuantized = True
		# quantized parameters need a step so hosts offer whole-number values
		harmonic_medwin.quantizeStep = 1

		percussive_medwin = ParameterDescriptor(harmonic_medwin)
		percussive_medwin.identifier = 'percussive_medwin'
		percussive_medwin.description = 'Median window length for percussive part.'
		percussive_medwin.name = 'Median window length for percussive part.'
		percussive_medwin.quantizeStep = 1

		boolDescriptor = ParameterDescriptor()
		boolDescriptor.isQuantized = True
		boolDescriptor.minValue = 0
		boolDescriptor.maxValue = 1
		boolDescriptor.quantizeStep = 1

		whitening = ParameterDescriptor(boolDescriptor)
		whitening.identifier = 'whitening'
		whitening.name = 'Adaptive whitening'
		whitening.description = 'Turn adaptive whitening on or off'
		whitening.defaultValue = 0.0

		return ParameterList(harmonic_thresh, percussive_thresh, harmonic_medwin, percussive_medwin, whitening)

	def setParameter(self, paramid, newval):
		'''Store a parameter value; unknown ids are ignored.'''
		if paramid == 'harmonic_thresh':
			self.harmonic_thresh = newval
		elif paramid == 'percussive_thresh':
			self.percussive_thresh = newval
		elif paramid == 'harmonic_medwin':
			self.harmonic_medwin = newval
		elif paramid == 'percussive_medwin':
			self.percussive_medwin = newval
		elif paramid == 'whitening':
			self.whitening_on = newval == 1.0
		return None

	def getParameter(self, paramid):
		'''Return the current value of a parameter, or 0.0 for unknown ids.'''
		if paramid == 'whitening':
			return 1.0 if self.whitening_on else 0.0
		# these parameter ids are also the attribute names they are stored under
		if paramid in ('harmonic_thresh', 'percussive_thresh', 'harmonic_medwin', 'percussive_medwin'):
			return getattr(self, paramid)
		return 0.0

	def whiten(self, magnitudeSpectrogram):
		'''Adaptively whiten a (frames x bins) magnitude spectrogram.

		Intended to follow the adaptive whitening described by Dan Stowell.
		For each frame, the per-bin peak estimate starts from the current
		magnitude. Where the magnitude is below the previous peak, it is raised
		toward that peak by ``whitenRelaxCoeff``, and it is then clipped from
		below at ``whitenFloor``. The frame is divided by this estimate. The
		estimate is kept in ``self.magPeaks`` for subsequent calls. The input
		array is not modified.
		'''
		spec = np.asarray(magnitudeSpectrogram, dtype=float)
		nFrames, nBins = spec.shape
		whitened_ms = np.zeros((nFrames, nBins))

		# (re)start the peak history if absent or sized for a different spectrum length
		if self.magPeaks is None or np.shape(self.magPeaks) != (nBins,):
			self.magPeaks = np.zeros(nBins)

		for i in range(nFrames):
			frame = spec[i, :]
			# copy: updating a view would mutate the input and make frame / peaks == 1
			peaks = frame.copy()
			idx = frame < self.magPeaks
			peaks[idx] += (self.magPeaks[idx] - frame[idx]) * self.whitenRelaxCoeff
			peaks[peaks < self.whitenFloor] = self.whitenFloor
			self.magPeaks = peaks
			whitened_ms[i, :] = frame / peaks

		return whitened_ms

	def process(self, inputbuffers, timestamp):
		'''Collect the magnitude of the first channel's spectrum; emit nothing yet.'''

		output_featureSet = FeatureSet()
		self.timestamp.append(timestamp)
		complexSpectrum = inputbuffers[0]
		magnitudeSpectrum = abs(complexSpectrum)
		self.maglist.append(magnitudeSpectrum)

		return output_featureSet

	def _normalise(self, spec):
		'''Rescale spec linearly to [0, 1]; a constant spectrogram maps to zeros.'''
		lo = np.min(spec)
		span = np.max(spec) - lo
		if span > 0:
			return (spec - lo) / float(span)
		# avoid 0/0 -> NaN when every value is identical
		return np.zeros_like(spec, dtype=float)

	def _separate(self, magnitudeSpectrogram):
		'''Return (harmonic, percussive) spectrograms for a (frames x bins) array.

		The input is whitened or min-max normalised. Each cell then goes to the
		harmonic output if its time-direction median is >= its
		frequency-direction median, and to the percussive output otherwise.
		'''
		if self.whitening_on:
			spec = self.whiten(magnitudeSpectrogram)
		else:
			spec = self._normalise(magnitudeSpectrogram)

		# parameters may arrive as floats from the host; filter sizes must be ints >= 1
		h_win = max(1, int(self.harmonic_medwin))
		p_win = max(1, int(self.percussive_medwin))
		# axis 0 is frames (one row per process() call), axis 1 is frequency bins
		harmonic_ma = median_filter(spec, size=(h_win, 1))
		percussive_ma = median_filter(spec, size=(1, p_win))

		harmonicSpectrogram = spec * (harmonic_ma >= percussive_ma).astype(float)
		percussiveSpectrogram = spec * (harmonic_ma < percussive_ma).astype(float)
		return harmonicSpectrogram, percussiveSpectrogram

	def getRemainingFeatures(self):
		'''Separate the collected spectrogram and emit one feature per frame per output.'''

		output_featureSet = FeatureSet()

		nFrames = len(self.maglist)
		if nFrames == 0:
			# nothing was processed; np.min/np.max would fail on an empty array
			return output_featureSet

		harmonicSpectrogram, percussiveSpectrogram = self._separate(np.array(self.maglist))

		output_featureSet[0] = flist0 = FeatureList()
		output_featureSet[1] = flist1 = FeatureList()
		for frame in range(nFrames):
			f = Feature(timestamp = self.timestamp[frame], values = harmonicSpectrogram[frame, :])
			flist0.append(f)
			f = Feature(timestamp = self.timestamp[frame], values = percussiveSpectrogram[frame, :])
			flist1.append(f)

		return output_featureSet