m@94: { m@94: "cells": [ m@94: { m@94: "cell_type": "code", m@94: "execution_count": 1, m@94: "metadata": { m@94: "collapsed": true m@94: }, m@94: "outputs": [], m@94: "source": [ m@94: "import numpy as np\n", m@94: "import pandas as pd\n", m@94: "import pickle\n", m@94: "\n", m@94: "%load_ext autoreload\n", m@94: "%autoreload 2\n", m@94: "\n", m@94: "%matplotlib inline\n", m@94: "import matplotlib.pyplot as plt\n", m@94: "\n", m@94: "import sys\n", m@94: "sys.path.append('../')\n", m@94: "import scripts.map_and_average as mapper\n", m@94: "import scripts.classification as classification" m@94: ] m@94: }, m@94: { m@94: "cell_type": "markdown", m@94: "metadata": {}, m@94: "source": [ m@94: "## Feature learning and write output" m@94: ] m@94: }, m@94: { m@94: "cell_type": "code", m@94: "execution_count": null, m@94: "metadata": { m@94: "collapsed": true m@94: }, m@94: "outputs": [], m@94: "source": [ m@94: "# Compute the feature sets (data_list plus the PCA/LDA/NMF/SSNMF variants) and save them with\n", m@94: "# mapper.write_output. The Classification section below reads mapper.OUTPUT_FILES, presumably\n", m@94: "# the files written here -- confirm in scripts/map_and_average.\n", m@94: "print(\"mapping...\")\n", m@94: "data_list, pcadata_list, ldadata_list, nmfdata_list, ssnmfdata_list, classlabs, audiolabs = mapper.map_and_average_frames(min_variance=0.99)\n", m@94: "mapper.write_output(data_list, pcadata_list, ldadata_list, nmfdata_list, ssnmfdata_list, classlabs, audiolabs)" m@94: ] m@94: }, m@94: { m@94: "cell_type": "markdown", m@94: "metadata": {}, m@94: "source": [ m@94: "## Classification" m@94: ] m@94: }, m@94: { m@94: "cell_type": "code", m@94: "execution_count": 2, m@94: "metadata": { m@94: "scrolled": true m@94: }, m@94: "outputs": [ m@94: { m@94: "name": "stderr", m@94: "output_type": "stream", m@94: "text": [ m@94: "/homes/mp305/anaconda/lib/python2.7/site-packages/sklearn/metrics/classification.py:1113: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.\n", m@94: " 'precision', 'predicted', average, warn_for)\n", m@94: "/homes/mp305/anaconda/lib/python2.7/site-packages/sklearn/discriminant_analysis.py:455: UserWarning: The priors do not sum to 1. 
Renormalizing\n", m@94: " UserWarning)\n" m@94: ] m@94: }, m@94: { m@94: "name": "stdout", m@94: "output_type": "stream", m@94: "text": [ m@94: "KNN LDA 0.151978449974\n", m@94: "LDA LDA 0.320669835863\n", m@94: "SVM LDA 0.0231101788399\n", m@94: "RF LDA 0.0742913265128\n", m@94: "KNN LDA 0.0547390436205\n", m@94: "LDA LDA 0.150312531138\n", m@94: "SVM LDA 0.0787628988868\n", m@94: "RF LDA 0.0427708629723\n", m@94: "KNN LDA 0.0232330458268\n", m@94: "LDA LDA 0.0702474072041\n", m@94: "SVM LDA 0.050068706152\n", m@94: "RF LDA 0.0193967985786\n", m@94: "KNN LDA 0.281733731607\n", m@94: "LDA LDA 0.198582742899\n", m@94: "SVM LDA 0.296355560166\n", m@94: "RF LDA 0.132752660164\n", m@94: "KNN LDA 0.0857923493684\n", m@94: "LDA LDA 0.107355289483\n", m@94: "SVM LDA 0.0896098014444\n", m@94: "RF LDA 0.0419252843345\n", m@94: "KNN PCA 0.140643930221\n", m@94: "LDA PCA 0.175099072208\n", m@94: "SVM PCA 0.0149273059799\n", m@94: "RF PCA 0.044347638128\n", m@94: "KNN PCA 0.052516908106\n", m@94: "LDA PCA 0.055028942176\n", m@94: "SVM PCA 0.0479512645907\n", m@94: "RF PCA 0.0325567872284\n", m@94: "KNN PCA 0.0268729640269\n", m@94: "LDA PCA 0.0459303318699\n", m@94: "SVM PCA 0.0386730267598\n", m@94: "RF PCA 0.0184694543728\n", m@94: "KNN PCA 0.220850433533\n", m@94: "LDA PCA 0.161502657527\n", m@94: "SVM PCA 0.245790916558\n", m@94: "RF PCA 0.131939188698\n", m@94: "KNN PCA 0.0814272808267\n", m@94: "LDA PCA 0.0839732813486\n", m@94: "SVM PCA 0.0918638232782\n", m@94: "RF PCA 0.0449817232296\n", m@94: "KNN NMF 0.114298949339\n", m@94: "LDA NMF 0.178244078869\n", m@94: "SVM NMF 0.0164055663008\n", m@94: "RF NMF 0.0588656307204\n", m@94: "KNN NMF 0.043057794756\n", m@94: "LDA NMF 0.0586662842996\n", m@94: "SVM NMF 0.00781273342686\n", m@94: "RF NMF 0.0285937566916\n", m@94: "KNN NMF 0.0285281454673\n", m@94: "LDA NMF 0.0463659955869\n", m@94: "SVM NMF 0.00768887594564\n", m@94: "RF NMF 0.0206293416635\n", m@94: "KNN NMF 0.177819886656\n", m@94: "LDA NMF 0.166221515627\n", m@94: 
"SVM NMF 0.010788613595\n", m@94: "RF NMF 0.111500698621\n", m@94: "KNN NMF 0.0795454671166\n", m@94: "LDA NMF 0.0856428557896\n", m@94: "SVM NMF 0.0116920633048\n", m@94: "RF NMF 0.0421056105664\n", m@94: "KNN SSNMF 0.14322692821\n", m@94: "LDA SSNMF 0.18320247367\n", m@94: "SVM SSNMF 0.0205784326384\n", m@94: "RF SSNMF 0.0438349321113\n", m@94: "KNN SSNMF 0.0431300683181\n", m@94: "LDA SSNMF 0.0533449581285\n", m@94: "SVM SSNMF 0.0106542141335\n", m@94: "RF SSNMF 0.0272971462205\n", m@94: "KNN SSNMF 0.0152235481009\n", m@94: "LDA SSNMF 0.038872838043\n", m@94: "SVM SSNMF 0.00536127803533\n", m@94: "RF SSNMF 0.0189951953248\n", m@94: "KNN SSNMF 0.227101074174\n", m@94: "LDA SSNMF 0.165382484171\n", m@94: "SVM SSNMF 0.0184921176111\n", m@94: "RF SSNMF 0.105334465578\n", m@94: "KNN SSNMF 0.0715413500709\n", m@94: "LDA SSNMF 0.0819764377219\n", m@94: "SVM SSNMF 0.0138822224913\n", m@94: "RF SSNMF 0.035838117053\n", m@94: "KNN NA 0.140075287804\n", m@94: "LDA NA 0.176953549195\n", m@94: "SVM NA 0.0149485545637\n", m@94: "RF NA 0.0891679877647\n", m@94: "KNN NA 0.0515315452955\n", m@94: "LDA NA 0.0599453579616\n", m@94: "SVM NA 0.0468615478392\n", m@94: "RF NA 0.0440373525097\n", m@94: "KNN NA 0.0273364752119\n", m@94: "LDA NA 0.0378819151174\n", m@94: "SVM NA 0.038290667129\n", m@94: "RF NA 0.0256534114754\n", m@94: "KNN NA 0.221769305159\n", m@94: "LDA NA 0.191217962613\n", m@94: "SVM NA 0.250268813953\n", m@94: "RF NA 0.118133604659\n", m@94: "KNN NA 0.0814734970192\n", m@94: "LDA NA 0.0839348156722\n", m@94: "SVM NA 0.0881235182136\n", m@94: "RF NA 0.0532974158539\n" m@94: ] m@94: } m@94: ], m@94: "source": [ m@94: "# Evaluate every classifier (KNN, LDA, SVM, RF) on every feature file; df_results gets one row\n", m@94: "# per (feature set, classifier) pair -- column 0 is the feature set, column 1 the classifier.\n", m@94: "feature_files = mapper.OUTPUT_FILES\n", m@94: "df_results = classification.classify_for_filenames(file_list=feature_files)" m@94: ] m@94: }, m@94: { m@94: "cell_type": "markdown", m@94: "metadata": { m@94: "collapsed": true m@94: }, m@94: "source": [ m@94: "Sort results by accuracy of all features ('All' - Column 2)" m@94: ] m@94: }, m@94: { m@94: "cell_type": "code", m@94: "execution_count": 
6, m@94: "metadata": {}, m@94: "outputs": [ m@94: { m@94: "name": "stdout", m@94: "output_type": "stream", m@94: "text": [ m@94: "\\begin{tabular}{llrrrrr}\n", m@94: "\\toprule\n", m@94: " 0 & 1 & 2 & 3 & 4 & 5 & 6 \\\\\n", m@94: "\\midrule\n", m@94: " LDA & LDA & 0.320670 & 0.150313 & 0.070247 & 0.198583 & 0.107355 \\\\\n", m@94: " SSNMF & LDA & 0.183202 & 0.053345 & 0.038873 & 0.165382 & 0.081976 \\\\\n", m@94: " NMF & LDA & 0.178244 & 0.058666 & 0.046366 & 0.166222 & 0.085643 \\\\\n", m@94: " NA & LDA & 0.176954 & 0.059945 & 0.037882 & 0.191218 & 0.083935 \\\\\n", m@94: " PCA & LDA & 0.175099 & 0.055029 & 0.045930 & 0.161503 & 0.083973 \\\\\n", m@94: " LDA & KNN & 0.151978 & 0.054739 & 0.023233 & 0.281734 & 0.085792 \\\\\n", m@94: " SSNMF & KNN & 0.143227 & 0.043130 & 0.015224 & 0.227101 & 0.071541 \\\\\n", m@94: " PCA & KNN & 0.140644 & 0.052517 & 0.026873 & 0.220850 & 0.081427 \\\\\n", m@94: " NA & KNN & 0.140075 & 0.051532 & 0.027336 & 0.221769 & 0.081473 \\\\\n", m@94: " NMF & KNN & 0.114299 & 0.043058 & 0.028528 & 0.177820 & 0.079545 \\\\\n", m@94: " NA & RF & 0.084140 & 0.045801 & 0.022834 & 0.118752 & 0.052336 \\\\\n", m@94: " LDA & RF & 0.075053 & 0.040452 & 0.014805 & 0.133543 & 0.052025 \\\\\n", m@94: " NMF & RF & 0.065347 & 0.036663 & 0.024069 & 0.121136 & 0.049071 \\\\\n", m@94: " PCA & RF & 0.053220 & 0.029322 & 0.017777 & 0.113936 & 0.046819 \\\\\n", m@94: " SSNMF & RF & 0.031423 & 0.021354 & 0.015184 & 0.100996 & 0.045024 \\\\\n", m@94: " LDA & SVM & 0.023110 & 0.078763 & 0.050069 & 0.296356 & 0.089610 \\\\\n", m@94: " SSNMF & SVM & 0.020578 & 0.010654 & 0.005361 & 0.018492 & 0.013882 \\\\\n", m@94: " NMF & SVM & 0.016406 & 0.007813 & 0.007689 & 0.010789 & 0.011692 \\\\\n", m@94: " NA & SVM & 0.014949 & 0.046862 & 0.038291 & 0.250269 & 0.088124 \\\\\n", m@94: " PCA & SVM & 0.014927 & 0.047951 & 0.038673 & 0.245791 & 0.091864 \\\\\n", m@94: "\\bottomrule\n", m@94: "\\end{tabular}\n", m@94: "\n" m@94: ] m@94: } m@94: ], m@94: "source": [ m@94: 
"# Rank the (feature set, classifier) pairs by column 2 (score using all features), best first,\n", m@94: "# and print the table as LaTeX.\n", m@94: "df_results_sorted = df_results.sort_values(2, ascending=False)\n", m@94: "print(df_results_sorted.to_latex(index=False))" m@94: ] m@94: }, m@94: { m@94: "cell_type": "markdown", m@94: "metadata": {}, m@94: "source": [ m@94: "## Confusion matrix" m@94: ] m@94: }, m@94: { m@94: "cell_type": "markdown", m@94: "metadata": {}, m@94: "source": [ m@94: "According to the results above, the best classifier is LDA and the best transformation is LDA." m@94: ] m@94: }, m@94: { m@94: "cell_type": "code", m@94: "execution_count": 3, m@94: "metadata": { m@94: "collapsed": true m@94: }, m@94: "outputs": [], m@94: "source": [ m@94: "# NOTE(review): mapper.OUTPUT_FILES[0] is hard-coded. Its LDA accuracy (printed below) matches the\n", m@94: "# LDA/LDA row of the table above, so it presumably holds the LDA-transformed features -- confirm.\n", m@94: "accuracy, CF, labels = classification.confusion_matrix_for_dataset(df_results, classifier='LDA', filename=mapper.OUTPUT_FILES[0])" m@94: ] m@94: }, m@94: { m@94: "cell_type": "code", m@94: "execution_count": 4, m@94: "metadata": {}, m@94: "outputs": [ m@94: { m@94: "name": "stdout", m@94: "output_type": "stream", m@94: "text": [ m@94: "0.320669835863\n" m@94: ] m@94: } m@94: ], m@94: "source": [ m@94: "print(accuracy)" m@94: ] m@94: }, m@94: { m@94: "cell_type": "markdown", m@94: "metadata": {}, m@94: "source": [ m@94: "Use the interactive figure toolbar to zoom into the confusion matrix." m@94: ] m@94: }, m@94: { m@94: "cell_type": "code", m@94: "execution_count": 12, m@94: "metadata": { m@94: "scrolled": false m@94: }, m@94: "outputs": [ m@94: { m@94: "data": { m@94: "application/javascript": [ m@94: "/* Put everything inside the global mpl namespace */\n", m@94: "window.mpl = {};\n", m@94: "\n", m@94: "mpl.get_websocket_type = function() {\n", m@94: " if (typeof(WebSocket) !== 'undefined') {\n", m@94: " return WebSocket;\n", m@94: " } else if (typeof(MozWebSocket) !== 'undefined') {\n", m@94: " return MozWebSocket;\n", m@94: " } else {\n", m@94: " alert('Your browser does not have WebSocket support.' +\n", m@94: " 'Please try Chrome, Safari or Firefox ≥ 6. 
' +\n", m@94: " 'Firefox 4 and 5 are also supported but you ' +\n", m@94: " 'have to enable WebSockets in about:config.');\n", m@94: " };\n", m@94: "}\n", m@94: "\n", m@94: "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", m@94: " this.id = figure_id;\n", m@94: "\n", m@94: " this.ws = websocket;\n", m@94: "\n", m@94: " this.supports_binary = (this.ws.binaryType != undefined);\n", m@94: "\n", m@94: " if (!this.supports_binary) {\n", m@94: " var warnings = document.getElementById(\"mpl-warnings\");\n", m@94: " if (warnings) {\n", m@94: " warnings.style.display = 'block';\n", m@94: " warnings.textContent = (\n", m@94: " \"This browser does not support binary websocket messages. \" +\n", m@94: " \"Performance may be slow.\");\n", m@94: " }\n", m@94: " }\n", m@94: "\n", m@94: " this.imageObj = new Image();\n", m@94: "\n", m@94: " this.context = undefined;\n", m@94: " this.message = undefined;\n", m@94: " this.canvas = undefined;\n", m@94: " this.rubberband_canvas = undefined;\n", m@94: " this.rubberband_context = undefined;\n", m@94: " this.format_dropdown = undefined;\n", m@94: "\n", m@94: " this.image_mode = 'full';\n", m@94: "\n", m@94: " this.root = $('
');\n", m@94: " this._root_extra_style(this.root)\n", m@94: " this.root.attr('style', 'display: inline-block');\n", m@94: "\n", m@94: " $(parent_element).append(this.root);\n", m@94: "\n", m@94: " this._init_header(this);\n", m@94: " this._init_canvas(this);\n", m@94: " this._init_toolbar(this);\n", m@94: "\n", m@94: " var fig = this;\n", m@94: "\n", m@94: " this.waiting = false;\n", m@94: "\n", m@94: " this.ws.onopen = function () {\n", m@94: " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", m@94: " fig.send_message(\"send_image_mode\", {});\n", m@94: " fig.send_message(\"refresh\", {});\n", m@94: " }\n", m@94: "\n", m@94: " this.imageObj.onload = function() {\n", m@94: " if (fig.image_mode == 'full') {\n", m@94: " // Full images could contain transparency (where diff images\n", m@94: " // almost always do), so we need to clear the canvas so that\n", m@94: " // there is no ghosting.\n", m@94: " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", m@94: " }\n", m@94: " fig.context.drawImage(fig.imageObj, 0, 0);\n", m@94: " fig.waiting = false;\n", m@94: " };\n", m@94: "\n", m@94: " this.imageObj.onunload = function() {\n", m@94: " this.ws.close();\n", m@94: " }\n", m@94: "\n", m@94: " this.ws.onmessage = this._make_on_message_function(this);\n", m@94: "\n", m@94: " this.ondownload = ondownload;\n", m@94: "}\n", m@94: "\n", m@94: "mpl.figure.prototype._init_header = function() {\n", m@94: " var titlebar = $(\n", m@94: " '
');\n", m@94: " var titletext = $(\n", m@94: " '
');\n", m@94: " titlebar.append(titletext)\n", m@94: " this.root.append(titlebar);\n", m@94: " this.header = titletext[0];\n", m@94: "}\n", m@94: "\n", m@94: "\n", m@94: "\n", m@94: "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", m@94: "\n", m@94: "}\n", m@94: "\n", m@94: "\n", m@94: "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", m@94: "\n", m@94: "}\n", m@94: "\n", m@94: "mpl.figure.prototype._init_canvas = function() {\n", m@94: " var fig = this;\n", m@94: "\n", m@94: " var canvas_div = $('
');\n", m@94: "\n", m@94: " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", m@94: "\n", m@94: " function canvas_keyboard_event(event) {\n", m@94: " return fig.key_event(event, event['data']);\n", m@94: " }\n", m@94: "\n", m@94: " canvas_div.keydown('key_press', canvas_keyboard_event);\n", m@94: " canvas_div.keyup('key_release', canvas_keyboard_event);\n", m@94: " this.canvas_div = canvas_div\n", m@94: " this._canvas_extra_style(canvas_div)\n", m@94: " this.root.append(canvas_div);\n", m@94: "\n", m@94: " var canvas = $('');\n", m@94: " canvas.addClass('mpl-canvas');\n", m@94: " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", m@94: "\n", m@94: " this.canvas = canvas[0];\n", m@94: " this.context = canvas[0].getContext(\"2d\");\n", m@94: "\n", m@94: " var rubberband = $('');\n", m@94: " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", m@94: "\n", m@94: " var pass_mouse_events = true;\n", m@94: "\n", m@94: " canvas_div.resizable({\n", m@94: " start: function(event, ui) {\n", m@94: " pass_mouse_events = false;\n", m@94: " },\n", m@94: " resize: function(event, ui) {\n", m@94: " fig.request_resize(ui.size.width, ui.size.height);\n", m@94: " },\n", m@94: " stop: function(event, ui) {\n", m@94: " pass_mouse_events = true;\n", m@94: " fig.request_resize(ui.size.width, ui.size.height);\n", m@94: " },\n", m@94: " });\n", m@94: "\n", m@94: " function mouse_event_fn(event) {\n", m@94: " if (pass_mouse_events)\n", m@94: " return fig.mouse_event(event, event['data']);\n", m@94: " }\n", m@94: "\n", m@94: " rubberband.mousedown('button_press', mouse_event_fn);\n", m@94: " rubberband.mouseup('button_release', mouse_event_fn);\n", m@94: " // Throttle sequential mouse events to 1 every 20ms.\n", m@94: " rubberband.mousemove('motion_notify', mouse_event_fn);\n", m@94: "\n", m@94: " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", m@94: " rubberband.mouseleave('figure_leave', 
mouse_event_fn);\n", m@94: "\n", m@94: " canvas_div.on(\"wheel\", function (event) {\n", m@94: " event = event.originalEvent;\n", m@94: " event['data'] = 'scroll'\n", m@94: " if (event.deltaY < 0) {\n", m@94: " event.step = 1;\n", m@94: " } else {\n", m@94: " event.step = -1;\n", m@94: " }\n", m@94: " mouse_event_fn(event);\n", m@94: " });\n", m@94: "\n", m@94: " canvas_div.append(canvas);\n", m@94: " canvas_div.append(rubberband);\n", m@94: "\n", m@94: " this.rubberband = rubberband;\n", m@94: " this.rubberband_canvas = rubberband[0];\n", m@94: " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", m@94: " this.rubberband_context.strokeStyle = \"#000000\";\n", m@94: "\n", m@94: " this._resize_canvas = function(width, height) {\n", m@94: " // Keep the size of the canvas, canvas container, and rubber band\n", m@94: " // canvas in synch.\n", m@94: " canvas_div.css('width', width)\n", m@94: " canvas_div.css('height', height)\n", m@94: "\n", m@94: " canvas.attr('width', width);\n", m@94: " canvas.attr('height', height);\n", m@94: "\n", m@94: " rubberband.attr('width', width);\n", m@94: " rubberband.attr('height', height);\n", m@94: " }\n", m@94: "\n", m@94: " // Set the figure to an initial 600x600px, this will subsequently be updated\n", m@94: " // upon first draw.\n", m@94: " this._resize_canvas(600, 600);\n", m@94: "\n", m@94: " // Disable right mouse context menu.\n", m@94: " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", m@94: " return false;\n", m@94: " });\n", m@94: "\n", m@94: " function set_focus () {\n", m@94: " canvas.focus();\n", m@94: " canvas_div.focus();\n", m@94: " }\n", m@94: "\n", m@94: " window.setTimeout(set_focus, 100);\n", m@94: "}\n", m@94: "\n", m@94: "mpl.figure.prototype._init_toolbar = function() {\n", m@94: " var fig = this;\n", m@94: "\n", m@94: " var nav_element = $('
')\n", m@94: " nav_element.attr('style', 'width: 100%');\n", m@94: " this.root.append(nav_element);\n", m@94: "\n", m@94: " // Define a callback function for later on.\n", m@94: " function toolbar_event(event) {\n", m@94: " return fig.toolbar_button_onclick(event['data']);\n", m@94: " }\n", m@94: " function toolbar_mouse_event(event) {\n", m@94: " return fig.toolbar_button_onmouseover(event['data']);\n", m@94: " }\n", m@94: "\n", m@94: " for(var toolbar_ind in mpl.toolbar_items) {\n", m@94: " var name = mpl.toolbar_items[toolbar_ind][0];\n", m@94: " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", m@94: " var image = mpl.toolbar_items[toolbar_ind][2];\n", m@94: " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", m@94: "\n", m@94: " if (!name) {\n", m@94: " // put a spacer in here.\n", m@94: " continue;\n", m@94: " }\n", m@94: " var button = $('