diff pyspark/ilm/spark_feat_extract.py @ 0:e34cf1b6fe09 tip

commit
author Daniel Wolff
date Sat, 20 Feb 2016 18:14:24 +0100
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pyspark/ilm/spark_feat_extract.py	Sat Feb 20 18:14:24 2016 +0100
@@ -0,0 +1,158 @@
+# Part of DML (Digital Music Laboratory)
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; either version 2
+# of the License, or (at your option) any later version.
+# 
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+# 
+# You should have received a copy of the GNU General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
+
+#!/usr/local/spark-1.0.0-bin-hadoop2/bin/spark-submit
+# -*- coding: utf-8 -*-
+__author__="hargreavess"
+
+from assetDB import assetDB
+from pyspark import SparkConf, SparkContext
+import ConfigParser
+import logging
+from transform import *
+import os
+import time
+import shutil
+
def main():
    """Batch-extract Vamp features from audio assets using Spark.

    Reads settings from ``server.cfg``, queries the asset database for
    tracks of the configured genre in batches of ``array-step-size``
    rows, and maps the configured Vamp transform list over each batch of
    audio files with a local Spark job.  Results and a log file are
    written to a timestamped output directory so repeated runs never
    clobber each other.
    """
    start_complete = time.time()

    # read runtime configuration
    config = ConfigParser.ConfigParser()
    config.read('server.cfg')

    vamp_transform_list = config.get('Sonic Annotator', 'vamp-transform-list')
    genre_id = config.getint('Queries', 'genre-id')

    # build a unique, timestamped output directory name
    output_dir = config.get('Sonic Annotator', 'output-dir')
    ltime = time.localtime()
    output_dir += '_%d_%d_%d' % (ltime.tm_mday, ltime.tm_mon, ltime.tm_year)
    output_dir += '_%d%d_%d' % (ltime.tm_hour, ltime.tm_min, ltime.tm_sec)
    output_dir += '_genre_id_%d' % genre_id
    # create output directory, if it doesn't exist
    if not os.access(output_dir, os.F_OK):
        os.makedirs(output_dir)

    # keep a copy of the transform list alongside the results for provenance
    shutil.copy(vamp_transform_list, output_dir)

    # file logger (DEBUG level) writing into the run's output directory
    logger = logging.getLogger('spark_feat_extract')
    logger.setLevel(logging.DEBUG)
    fh = logging.FileHandler(output_dir + "/ilm.assets.spark.features.log")
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(fh)

    # record the effective configuration for this run
    logger.info('starting new spark_feat_extract job')
    logger.info('using vamp transform list: %s', vamp_transform_list)
    logger.info('audio-file-size-limit: %s', config.get('Audio Files', 'audio-file-size-limit'))
    logger.info('audio-prefix: %s', config.get('Audio Files', 'audio-prefix'))
    logger.info('num-cores: %s', config.get('Spark', 'num-cores'))
    logger.info('spark memory: %s', config.get('Spark', 'memory'))
    logger.info('genre_id: %s', genre_id)

    # create a local Spark context sized from the config
    conf = (SparkConf()
            .setMaster("local[" + config.get('Spark', 'num-cores') + "]")
            .setAppName("spark feature extractor")
            .set("spark.executor.memory", config.get('Spark', 'memory')))
    sc = SparkContext(conf=conf)

    SQL_start = config.getint('Queries', 'sql-start')
    SQL_limit = config.getint('Queries', 'sql-limit')
    local_SQL_start = SQL_start
    logger.info('SQL_start = %i', SQL_start)
    logger.info('SQL_limit = %i', SQL_limit)

    array_step_size = config.getint('Application', 'array-step-size')
    logger.info('array-step-size = %i', array_step_size)
    local_SQL_limit = min(SQL_limit, array_step_size)

    while local_SQL_limit <= SQL_limit:

        # query db for the next batch of assets (song tracks)
        db = assetDB(prefix=config.get('Audio Files', 'audio-prefix'), config=config)
        db.connect()

        data = []
        logger.info('local_start = %i', local_SQL_start)
        logger.info('local_SQL_limit = %i', local_SQL_limit)

        for path, asset in db.get_assets_by_genre(genre_id, local_SQL_start, local_SQL_limit):
            if path is None:
                logger.warning("Asset not found for: %s. (Album ID: %i Track No: %i)",
                               asset.song_title, asset.album_id, asset.track_no)
            else:
                data.append(path)

        # BUG FIX: the original read "db.close" (bound method, never
        # invoked), leaking one DB connection per batch.
        db.close()

        # If the db query returned no results, stop here
        if not data:
            break

        batch_output_dir = output_dir + '/batch' + str(local_SQL_start) + '-' + str(local_SQL_limit)
        os.makedirs(batch_output_dir)
        logger.info('created results directory %s', batch_output_dir)

        # define distributed dataset
        logger.info("calling sc.parallelize(data)...")
        start = time.time()
        distData = sc.parallelize(data)
        logger.info("finished in %s", time.time() - start)

        # define map: run the Vamp transform list on each audio file
        logger.info("calling distData.map...")
        start = time.time()
        m1 = distData.map(lambda x: transform(audio_file=x,
                                              vamp_transform_list=vamp_transform_list,
                                              output_dir=batch_output_dir))
        logger.info("finished in %s", time.time() - start)

        # collect results (triggers the actual computation)
        logger.info("calling m1.collect()...")
        start = time.time()
        m1.collect()
        logger.info("finished in %s", time.time() - start)

        # advance the batch window
        local_SQL_start += array_step_size
        local_SQL_limit += min(SQL_limit, array_step_size)

    # release the Spark context (original never stopped it)
    sc.stop()

    # BUG FIX: the original reported "end - start_complete" where "end" was
    # only bound inside the loop body, raising NameError when the very first
    # batch was empty, and otherwise excluding teardown time.
    elapsed = time.time() - start_complete
    print("finished all in " + str(elapsed))
    logger.info("finished all in %s", elapsed)

if __name__ == "__main__":
    main()
+