Mercurial > hg > hybrid-music-recommender-using-content-based-and-social-information
view Code/make_lists.py @ 18:c0a08cbdfacd
First script
author | Paulo Chiliguano <p.e.chiliguano@se14.qmul.ac.uk> |
---|---|
date | Tue, 28 Jul 2015 20:58:57 +0100 |
parents | 0a0d6203638a |
children | 2e3c57fba632 |
line wrap: on
line source
import numpy #import numpy.random as random import os import pickle import sys import utils as U #import pdb def read_file(filename): """ Loads a file into a list """ file_list=[l.strip() for l in open(filename,'r').readlines()] return file_list def get_folds(filelist, n_folds): n_per_fold = len(filelist) / n_folds folds = [] for i in range(n_folds-1): folds.append(filelist[i * n_per_fold: (i + 1) * n_per_fold]) i = n_folds - 1 folds.append(filelist[i * n_per_fold:]) return folds def generate_mirex_list(train_list, annotations): out_list = [] for song in train_list: annot = annotations.get(song,None) if annot is None: print 'No annotations for song %s' % song continue assert(type('') == type(annot)) out_list.append('%s\t%s\n' % (song,annot)) return out_list def make_file_list(gtzan_path, n_folds=5,): """ Generates lists """ audio_path = os.path.join(gtzan_path,'preview_clip') out_path = os.path.join(gtzan_path,'lists') files_list = [] for ext in ['.au', '.mp3', '.wav']: files = U.getFiles(audio_path, ext) files_list.extend(files) #random.shuffle(files_list) if not os.path.exists(out_path): os.makedirs(out_path) audio_list_path = os.path.join(out_path, 'audio_files.txt') open(audio_list_path,'w').writelines(['%s\n' % f for f in files_list]) annotations = get_annotations(files_list) ground_truth_path = os.path.join(out_path, 'ground_truth.txt') open(ground_truth_path,'w').writelines(generate_mirex_list(files_list, annotations)) generate_ground_truth_pickle(ground_truth_path) folds = get_folds(files_list, n_folds=n_folds) ### Single fold for quick experiments create_fold(0, 1, folds, annotations, out_path) for n in range(n_folds): create_fold(n, n_folds, folds, annotations, out_path) def create_fold(n, n_folds, folds, annotations, out_path): train_path = os.path.join(out_path, 'train_%i_of_%i.txt' % (n+1, n_folds)) valid_path = os.path.join(out_path, 'valid_%i_of_%i.txt' % (n+1, n_folds)) test_path = os.path.join(out_path, 'test_%i_of_%i.txt' % (n+1, n_folds)) test_list = folds[n] train_list = [] for m in range(len(folds)): if m != n: train_list.extend(folds[m]) open(train_path,'w').writelines(generate_mirex_list(train_list, annotations)) open(test_path,'w').writelines(generate_mirex_list(test_list, annotations)) split_list_file(train_path, train_path, valid_path, ratio=0.8) def split_list_file(input_file, out_file1, out_file2, ratio=0.8): input_list = open(input_file,'r').readlines() n = len(input_list) nsplit = int(n *ratio) list1 = input_list[:nsplit] list2 = input_list[nsplit:] open(out_file1, 'w').writelines(list1) open(out_file2, 'w').writelines(list2) def get_annotation(filename): genre = os.path.split(U.parseFile(filename)[0])[-1] return genre def get_annotations(files_list): annotations = {} for filename in files_list: annotations[filename] = get_annotation(filename) return annotations def generate_ground_truth_pickle(gt_file): gt_path,_ = os.path.split(gt_file) tag_file = os.path.join(gt_path,'tags.txt') gt_pickle = os.path.join(gt_path,'ground_truth.pickle') lines = open(gt_file,'r').readlines() tag_set = set() for line in lines: filename,tag = line.strip().split('\t') tag_set.add(tag) tag_list = sorted(list(tag_set)) open(tag_file,'w').writelines('\n'.join(tag_list + [''])) tag_dict = dict([(tag,i) for i,tag in enumerate(tag_list)]) n_tags = len(tag_dict) mp3_dict = {} for line in lines: filename,tag = line.strip().split('\t') tag_vector = mp3_dict.get(filename,numpy.zeros(n_tags)) if tag != '': tag_vector[tag_dict[tag]] = 1. mp3_dict[filename] = tag_vector pickle.dump(mp3_dict,open(gt_pickle,'w')) if __name__ == '__main__': if len(sys.argv) < 2: print 'Usage: python %s gtzan_path [n_folds=10]' % sys.argv[0] sys.exit() gtzan_path = os.path.abspath(sys.argv[1]) if len(sys.argv) > 2: n_folds = int(sys.argv[2]) else: n_folds = 10 make_file_list(gtzan_path, n_folds)