# Part of DML (Digital Music Laboratory)
# Copyright 2014-2015 Daniel Wolff, City University
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA

# -*- coding: utf-8 -*-

# json testfile
#{ "module":"chord_seq_key_relative",
#      "function":"aggregate",
#      "arguments": [[
#      {"keys": { "tag": "csv", "value":"D:\\mirg\\Chord_Analysis20141216\\Beethoven\\qm_vamp_key_standard.n3_50ac9\\1CD0000653_BD01_vamp_qm-vamp-plugins_qm-keydetector_key.csv"},
#          "chords": { "tag": "csv", "value":"D:\\mirg\\Chord_Analysis20141216\\Beethoven\\chordino_simple.n3_1a812\\1CD0000653_BD01_vamp_nnls-chroma_chordino_simplechord.csv"},
#            "trackuri": "Eins"},
#      {"keys": { "tag": "csv", "value":"D:\\mirg\\Chord_Analysis20141216\\Beethoven\\qm_vamp_key_standard.n3_50ac9\\1CD0000653_BD01_vamp_qm-vamp-plugins_qm-keydetector_key.csv"},
#          "chords": { "tag": "csv", "value":"D:\\mirg\\Chord_Analysis20141216\\Beethoven\\chordino_simple.n3_1a812\\1CD0000653_BD01_vamp_nnls-chroma_chordino_simplechord.csv"}}
#      ]]
#}

# these for file reading etc
import re
import os
import csv
import numpy

# spmf functions
import chord_seq_spmf_helper as spmf

from aggregate import *
from csvutils import *

# ---
# roots
# ---
chord_roots = ["C", "D", "E", "F", "G", "A", "B"]

# natural note name -> pitch class (C = 0), built once for fast lookups
roots_dic = {name: pc for name, pc in zip(chord_roots, (0, 2, 4, 5, 7, 9, 11))}

mode_lbls = ['major', 'minor']
# mode label -> mode index (0 = major, 1 = minor)
mode_dic = {lbl: idx for idx, lbl in enumerate(mode_lbls)}
# ---
# types
# ---
type_labels = ["", "6", "7", "m", "m6", "m7", "maj7", "m7b5", "dim", "dim7", "aug"]
# chord type label -> index into type_labels
type_dic = {lbl: idx for idx, lbl in enumerate(type_labels)}

# bass-note interval above the chord root (semitones) -> label; '' = unnamed
base_labels = ["1", "", "2", "b3", "3", "4", "", "5", "", "6", "b7", "7"]

# functions: name of the chord root's interval above the key tonic,
# indexed by semitone offset, for major and minor keys respectively
root_funs_maj = ['I', '#I', 'II', '#II', 'III', 'IV', '#IV', 'V', '#V', 'VI', '#VI', 'VII']
root_funs_min = ['I', '#I', 'II', 'III', '#III', 'IV', '#IV', 'V', 'VI', '#VI', 'VII', '#VII']
# dan's suggestion
#root_funs_maj = ['I','#I','II','#II','(M)III','IV','#IV','V','#V','VI','#VI','(M)VII']
#root_funs_min = ['I','#I','II','(m)III','#III','IV','#IV','V','VI','#VI','(m)VII','#VII']

# semitone offset above the tonic -> function name
fun_dic_maj = dict(enumerate(root_funs_maj))
fun_dic_min = dict(enumerate(root_funs_min))
# regex that separates roots and types, and gets chord base (bass note).
# Roots and bass notes may carry sharps (#) and/or flats (b); 'N' = no chord.
# Character classes hold no commas: inside [...] a ',' is a literal and would
# let stray commas match as roots or types.
p = re.compile(r'(?P<root>[A-GN](#|b)*)(?P<type>[a-z0-9]*)(/(?P<base>[A-G](#|b)*))*')
# regex for key labels such as 'C major' or 'Eb / D# minor'; the first note
# name is captured as the key. The mode is an alternation, not a character
# class, so only the words 'major' / 'minor' are accepted.
p2 = re.compile(r'(?P<key>[A-G](#|b)*)(\s/\s[A-G](#|b)*)*\s(?P<mode>major|minor)')
# regex for feature file names: <clipid>_vamp<...>.<ext> (the dot is literal)
pclip = re.compile(r'(?P<clipid>[A-Z0-9]+(\-|_)[A-Z0-9]+((\-|_)[A-Z0-9]+)*((\-|_)[A-Z0-9]+)*)_(?P<type>vamp.*)\.(?P<ext>(csv|xml|txt|n3)+)')

def chords_from_csv(filename):
    """Read a chord CSV whose rows are (time, chord_string).

    Each row is mapped to a (float time, chord label) tuple via csv_map_rows;
    the 2 passed along is presumably the expected column count - confirm in
    csvutils.
    """
    def to_chord(row):
        return float(row[0]), row[1]
    return csv_map_rows(filename, 2, to_chord)

def keys_from_csv(filename):
    """Read a key CSV whose rows are (time, key_code, key_string).

    Each row is mapped to a (float time, key code, key label) tuple via
    csv_map_rows; the 3 passed along is presumably the expected column
    count - confirm in csvutils.
    """
    def to_key(row):
        return float(row[0]), row[1], row[2]
    return csv_map_rows(filename, 3, to_key)

# parsers for tagged feature inputs, keyed by tag; only 'csv' is handled so
# far (there is no n3 parser yet)
key_parser_table = dict(csv=keys_from_csv)
chord_parser_table = dict(csv=chords_from_csv)

# Extract key-relative chord sequences from chord / key features and mine
# frequent chord-function sequences with SPMF, separately per key mode.
#
# inputs: list of dicts, one per track:
#    ['chords']   tagged chord feature, e.g. chordino_simple.n3_1a812
#    ['keys']     tagged key feature, e.g. qm_vamp_key_standard.n3_50ac9
#    ['trackuri'] optional track identifier (a running number otherwise)
# @note: in future we could add support for qm_key_tonic input
# opts: optional dictionary:
#    "spm_algorithm"  : "TKS", "ClaSP", "SPADE"; anything else -> "CM-SPADE"
#    "spm_options"    : option string passed to the algorithm, e.g. "70%"
#    "spm_maxseqs"    : number of sequences (default 500, halved internally)
#    "spm_minlen"     : min. sequence length (default 2)
#    "spm_maxtime"    : timeout in seconds (default 60, halved per mode run)
#    "spm_ignore_n"   : non-zero to drop 'N' (no chord) labels (default 1)
#    "spm_minsupport" : min. support in percent (default 50)
# output:
#    { 'result': { 'sequences': seq, 'support': sup }, 'stats': st }
# raises Exception if fewer than 2 tracks were processed

# running count of processed tracks; also the fallback track identifier
trackctr = 0

def aggregate(inputs,opts=None):
    global trackctr
    if opts is None:
        opts = {}
    print_status('In chord_seq_key_relative')
    # SPADE, TKS or ClaSP algorithm?
    algo = opts.get("spm_algorithm","CM-SPADE")
    # number of sequences; halved, presumably because both modes are mined
    # separately.
    # NOTE(review): the final filter below caps the combined output at this
    # halved value as well - confirm that is intended
    maxseqs = int(opts.get("spm_maxseqs",500)/2)
    # min. length of sequences
    minlen = int(opts.get("spm_minlen",2))
    # timeout in seconds for each of the two SPMF runs
    maxtime = int(opts.get("spm_maxtime",1*60)/2)
    # drop 'N' (no chord) labels?
    ignoreN = int(opts.get("spm_ignore_n",1))
    # min. support in percent
    minsup = int(opts.get("spm_minsupport",50))

    # count only the tracks of this call: the module-level counter would
    # otherwise carry over between calls and defeat the track check below
    trackctr = 0

    # chords of each piece are stored by key mode (0 major, 1 minor)
    # so both modes are treated separately
    out_chords = [dict(), dict()]

    # fills out_chords[mode][trackuri] = [ (time,key,mode,fun,typ,bfun) ]
    def accum(item):
        global trackctr
        # increase virtual identifier
        trackctr += 1
        # parse the key annotation and pick the most frequent key
        keys = decode_tagged(key_parser_table,item['keys'])
        key,mode = most_frequent_key(keys)
        relchords = []
        for (time,chord) in decode_tagged(chord_parser_table,item['chords']):
            (root,fun,typ,bfun) = chord2function(chord,key,mode)
            # ignore chords that are 'N':
            # a. the open pattern matching allows for arbitrary chords
            #   to appear inbetween those in a sequence
            # b. the N chord potentially maps to any contents, so the
            #   inclusion of N chord has limited (or no) use
            if ignoreN and root == -1:
                continue
            relchords.append((time,key,mode,fun,typ,bfun))
        # save results into dict for this track
        trackuri = item.get('trackuri',trackctr)
        out_chords[mode][trackuri] = relchords

    # collate relative chord information per file
    # NOTE(review): for_each is expected from the aggregate module (star
    # import); its return value is reported as 'stats'
    st = for_each(inputs,accum)

    if trackctr < 2:
        raise Exception("Need more than 1 track")

    seq = [[],[]]
    sup = [[],[]]
    for mode in [0,1]:
        # write this mode's chord sequences to an SPMF input file
        spmffile = spmf.relchords2spmf(out_chords[mode])
        # run sequential pattern mining
        if algo == "TKS":
            algoopts = opts.get("spm_options","")
            seqfile = spmf.spmf(spmffile,'TKS',[str(maxseqs), algoopts])
        elif algo == "ClaSP":
            algoopts = opts.get("spm_options",str(minsup) + "%")
            seqfile = spmf.spmf(spmffile,'ClaSP',[algoopts, str(minlen)], timeout = maxtime)
        elif algo == "SPADE":
            algoopts = opts.get("spm_options",str(minsup) + "%")
            seqfile = spmf.spmf(spmffile,'SPADE',[algoopts, str(minlen)], timeout = maxtime)
        else:
            print_status('Running CM-SPADE algo')
            algoopts = opts.get("spm_options",str(minsup) + "%")
            seqfile = spmf.spmf(spmffile,'CM-SPADE',[algoopts, str(minlen)], timeout = maxtime)
        # parse spmf output
        seq[mode],sup[mode] = spmf.spmf2table(seqfile)
        # NOTE(review): the SPMF input/output files are not removed here -
        # check whether the spmf helper cleans them up

    # fold back sequences and support; both modes are ranked together below
    seq = [s for sublist in seq for s in sublist]
    sup = [s for sublist in sup for s in sublist]

    # walk sequences by descending support, keeping those of sufficient
    # length until maxseqs have been collected
    seq_out = []
    sup_out = []
    for i in numpy.argsort(sup)[::-1]:
        if len(seq_out) >= maxseqs:
            break
        if len(seq[i]) >= minlen:
            seq_out.append(seq[i])
            sup_out.append(sup[i])

    return { 'result': { 'sequences': seq_out, 'support': sup_out},
             'stats' : st }

# most simple note2num
def note2num(notein = 'Cb'):
    """Return the pitch class (0-11, C = 0) of a note name.

    The first character is a natural note name looked up in roots_dic. Every
    further character must be an accidental: each '#' raises and each 'b'
    lowers the note by a semitone (so 'C' -> 0, 'Cb' -> 11, 'C##' -> 2).

    Raises:
        KeyError: if the first character is not in roots_dic.
        ValueError: if a later character is not '#' or 'b'.
    """
    base = roots_dic[notein[0]]
    accidentals = notein[1:]
    # fail loudly instead of returning None, which only broke later arithmetic
    if any(c not in '#b' for c in accidentals):
        raise ValueError("Error parsing chord " + notein)
    # the chord/key regexes allow repeated accidentals, so count them all
    return (base + accidentals.count('#') - accidentals.count('b')) % 12

# convert key to number
def key2num(keyin = 'C major'):
    """Parse a key label such as 'C major' or 'Eb / D# minor'.

    Returns:
        (key, mode): pitch class of the first note name (via note2num) and
        the mode index from mode_dic (0 = major, 1 = minor), or -1 when no
        mode text was captured.

    Raises:
        ValueError: if keyin does not match the key pattern p2.
    """
    # ---
    # parse key string: separate root from rest
    # ---
    sepstring = p2.match(keyin)
    if not sepstring:
        # continuing would only crash on None.group(); report the label instead
        raise ValueError("Error parsing key " + keyin)
    # pitch class of the tonic; note2num takes care of sharps and flats
    key = note2num(sepstring.group('key'))
    # ---
    # parse mode. care for (unknown) string
    # ---
    mode = sepstring.group('mode')
    mode = mode_dic[mode] if mode else -1
    return (key, mode)


# convert chord to relative function
def chord2function(cin = 'B',key=3, mode=0):
    """Express a chord label relative to a key.

    Args:
        cin: chord label such as 'C#m7/E', or 'N' for no chord.
        key: pitch class (0-11) of the key tonic.
        mode: key mode index; not used by this computation.

    Returns:
        (root, fun, typ, bfun):
          root - pitch class of the chord root,
          fun  - semitones from the key tonic up to the chord root (0-11),
          typ  - index of the chord type in type_labels,
          bfun - semitones from the chord root up to the bass note
                 (0 when no bass note is given).
        All four values are -1 for the 'N' chord.

    Raises:
        ValueError: if cin does not match the chord pattern p.
        KeyError: if the chord type is not listed in type_dic.
    """
    # ---
    # parse chord string: separate root from rest
    # ---
    sepstring = p.match(cin)
    if sepstring is None:
        raise ValueError("Error parsing chord " + cin)
    # test for N code -> no chord detected
    if sepstring.group('root') == 'N':
        return (-1,-1,-1,-1)
    # get root and type otherwise
    root = note2num(sepstring.group('root'))
    typ = type_dic[sepstring.group('type')]

    # get relative position
    fun = (root - key) % 12
    # bass note relative to the chord root; without one, the root itself (0)
    base = sepstring.group('base')
    bfun = (note2num(base) - root) % 12 if base else 0
    # ---
    # todo: integrate bfun in final type list
    # ---
    return (root,fun,typ,bfun)

# reads in any csv and returns a list of structure
# time(float), data1, data2 ....dataN
def read_vamp_csv(filein = ''):
    """Read a Vamp CSV file into a list of rows.

    Each non-blank row becomes [time as float, field1, field2, ...] with the
    remaining fields kept as strings. Blank lines (which csv.reader yields as
    empty rows) are skipped instead of failing on row[0].
    """
    import sys
    # the csv module wants binary mode on Python 2 and newline='' text mode
    # on Python 3
    if sys.version_info[0] < 3:
        csvfile = open(filein, 'rb')
    else:
        csvfile = open(filein, 'r', newline='')
    with csvfile:
        contents = csv.reader(csvfile, delimiter=',', quotechar='"')
        return [[float(row[0])] + row[1:] for row in contents if row]

# histogram of the last entry in a list
# returns the most frequently used key
def histogram(keysin = ()):
    """Count the last element of every row and find the most frequent one.

    Returns:
        (histo, top): histo maps each value to its count; top is the value
        with the highest count, the earliest-seen one on ties.

    Raises:
        ValueError: if keysin is empty.
    """
    histo = {}
    # first-appearance order, so ties are resolved deterministically
    order = []
    for row in keysin:
        lbl = row[-1]
        if lbl not in histo:
            histo[lbl] = 0
            order.append(lbl)
        histo[lbl] += 1

    if not order:
        raise ValueError("histogram() needs at least one row")
    # max() keeps the first maximal element, i.e. the earliest-seen value
    return (histo, max(order, key=histo.get))

def most_frequent_key(keys):
    """Return (key, mode) of the most frequent known key label in keys.

    keys: rows whose last element is a key label, as produced by
    keys_from_csv; rows labelled '(unknown)' are ignored.

    Raises:
        ValueError: if no row carries a known key label.
    """
    # delete 'unknown' keys
    known = [row for row in keys if row[-1] != '(unknown)']
    if not known:
        raise ValueError("No known key in key annotation")

    # aggregate to one key
    (histo, skey) = histogram(known)

    # get key number and mode
    return key2num(skey)

def fun2txt(fun,typ, bfun,mode):
    """Render a chord function as text, e.g. '(M)V7' or '(m)I/b3'.

    Args:
        fun: semitones from the key tonic to the chord root (0-11);
             negative for the 'N' chord.
        typ: index into type_labels.
        bfun: bass interval above the chord root (index into base_labels);
              negative to omit the bass.
        mode: 0 for major, 1 for minor.

    Returns:
        The text label, or 'N' when fun is negative.

    Raises:
        ValueError: for a mode other than 0 or 1.
    """
    if fun < 0:
        return 'N'

    # interpret the function according to the key mode
    if mode == 1:
        pfun = fun_dic_min[fun]
        md = '(m)'
    elif mode == 0:
        pfun = fun_dic_maj[fun]
        md = '(M)'
    else:
        raise ValueError("Unknown mode: " + str(mode))

    tlb = type_labels[typ] if typ > 0 else ''
    # add a bass suffix only for intervals that have a label
    blb = '/' + base_labels[bfun] if (bfun >= 0 and base_labels[bfun]) else ''
    return md + pfun + tlb + blb

def fun2num(fun,typ, bfun,mode):
    """Encode a chord function as one integer.

    Fields are packed as (mode+1)*10^6 + (fun+1)*10^4 + (typ+1)*10^2 + (bfun+1),
    each in its own two-digit slot. The 'N' chord (fun == -1) encodes as 0.
    """
    if fun == -1:
        return 0
    return (mode+1)* 1000000 + (fun+1) * 10000 + (typ+1) * 100 + (bfun+1)

if __name__ == "__main__":
    print("Creates a key-independent chord histogram. Usage: chord2function path_vamp_chords path_vamp_keys")
    # sys.argv[1]
    # NOTE(review): folder2histogram is not defined in this module and can only
    # resolve through the star imports above; exit with a clear message rather
    # than a bare NameError when it is missing. The module's entry point for
    # requests is aggregate() (see the example request at the top of the file).
    if 'folder2histogram' not in globals():
        raise SystemExit("folder2histogram() is not available; use aggregate() instead")
    result = folder2histogram()
    print("Please input a description for the chord function histogram")