Mercurial > hg > plosone_underreview
comparison notebooks/sensitivity_experiment.ipynb @ 82:4395037087b6 branch-tests
notebooks
author | mpanteli <m.x.panteli@gmail.com> |
---|---|
date | Tue, 26 Sep 2017 21:18:26 +0100 |
parents | 92a5e280946d |
children | e279ccea5f9b |
comparison
equal
deleted
inserted
replaced
79:98fc06ba2938 | 82:4395037087b6 |
---|---|
1 { | 1 { |
2 "cells": [ | 2 "cells": [ |
3 { | 3 { |
4 "cell_type": "code", | 4 "cell_type": "code", |
5 "execution_count": 1, | 5 "execution_count": 1, |
6 "metadata": { | 6 "metadata": {}, |
7 "collapsed": false | |
8 }, | |
9 "outputs": [ | 7 "outputs": [ |
8 { | |
9 "name": "stdout", | |
10 "output_type": "stream", | |
11 "text": [ | |
12 "ERROR! Session/line number was not unique in database. History logging moved to new session 32\n" | |
13 ] | |
14 }, | |
10 { | 15 { |
11 "name": "stderr", | 16 "name": "stderr", |
12 "output_type": "stream", | 17 "output_type": "stream", |
13 "text": [ | 18 "text": [ |
14 "/homes/mp305/anaconda/lib/python2.7/site-packages/librosa/core/audio.py:33: UserWarning: Could not import scikits.samplerate. Falling back to scipy.signal\n", | 19 "/homes/mp305/anaconda/lib/python2.7/site-packages/librosa/core/audio.py:33: UserWarning: Could not import scikits.samplerate. Falling back to scipy.signal\n", |
35 ] | 40 ] |
36 }, | 41 }, |
37 { | 42 { |
38 "cell_type": "code", | 43 "cell_type": "code", |
39 "execution_count": 2, | 44 "execution_count": 2, |
40 "metadata": { | 45 "metadata": {}, |
41 "collapsed": true | |
42 }, | |
43 "outputs": [], | 46 "outputs": [], |
44 "source": [ | 47 "source": [ |
45 "OUTPUT_FILES = load_dataset.OUTPUT_FILES\n", | 48 "OUTPUT_FILES = load_dataset.OUTPUT_FILES\n", |
46 "n_iters = 10" | 49 "n_iters = 10" |
47 ] | 50 ] |
48 }, | 51 }, |
49 { | 52 { |
50 "cell_type": "code", | 53 "cell_type": "code", |
51 "execution_count": 5, | 54 "execution_count": 5, |
52 "metadata": { | 55 "metadata": {}, |
53 "collapsed": false | |
54 }, | |
55 "outputs": [ | 56 "outputs": [ |
56 { | 57 { |
57 "data": { | 58 "data": { |
58 "text/plain": [ | 59 "text/plain": [ |
59 "(8396, 108)" | 60 "(8396, 108)" |
70 ] | 71 ] |
71 }, | 72 }, |
72 { | 73 { |
73 "cell_type": "code", | 74 "cell_type": "code", |
74 "execution_count": 48, | 75 "execution_count": 48, |
75 "metadata": { | 76 "metadata": {}, |
76 "collapsed": false | |
77 }, | |
78 "outputs": [ | 77 "outputs": [ |
79 { | 78 { |
80 "name": "stdout", | 79 "name": "stdout", |
81 "output_type": "stream", | 80 "output_type": "stream", |
82 "text": [ | 81 "text": [ |
282 ] | 281 ] |
283 }, | 282 }, |
284 { | 283 { |
285 "cell_type": "code", | 284 "cell_type": "code", |
286 "execution_count": 52, | 285 "execution_count": 52, |
287 "metadata": { | 286 "metadata": {}, |
288 "collapsed": false | |
289 }, | |
290 "outputs": [ | 287 "outputs": [ |
291 { | 288 { |
292 "name": "stdout", | 289 "name": "stdout", |
293 "output_type": "stream", | 290 "output_type": "stream", |
294 "text": [ | 291 "text": [ |
453 ] | 450 ] |
454 }, | 451 }, |
455 { | 452 { |
456 "cell_type": "code", | 453 "cell_type": "code", |
457 "execution_count": 56, | 454 "execution_count": 56, |
458 "metadata": { | 455 "metadata": {}, |
459 "collapsed": false | |
460 }, | |
461 "outputs": [ | 456 "outputs": [ |
462 { | 457 { |
463 "data": { | 458 "data": { |
464 "text/plain": [ | 459 "text/plain": [ |
465 "array([ 59, 1, 1, 1, 1, 733, 733, 733, 733, 733, 733, 733, 733,\n", | 460 "array([ 59, 1, 1, 1, 1, 733, 733, 733, 733, 733, 733, 733, 733,\n", |
478 ] | 473 ] |
479 }, | 474 }, |
480 { | 475 { |
481 "cell_type": "code", | 476 "cell_type": "code", |
482 "execution_count": 8, | 477 "execution_count": 8, |
483 "metadata": { | 478 "metadata": {}, |
484 "collapsed": false | |
485 }, | |
486 "outputs": [ | 479 "outputs": [ |
487 { | 480 { |
488 "data": { | 481 "data": { |
489 "text/html": [ | 482 "text/html": [ |
490 "<div>\n", | 483 "<div>\n", |
739 ] | 732 ] |
740 }, | 733 }, |
741 { | 734 { |
742 "cell_type": "code", | 735 "cell_type": "code", |
743 "execution_count": 47, | 736 "execution_count": 47, |
744 "metadata": { | 737 "metadata": {}, |
745 "collapsed": false | |
746 }, | |
747 "outputs": [ | 738 "outputs": [ |
748 { | 739 { |
749 "name": "stdout", | 740 "name": "stdout", |
750 "output_type": "stream", | 741 "output_type": "stream", |
751 "text": [ | 742 "text": [ |
840 ] | 831 ] |
841 }, | 832 }, |
842 { | 833 { |
843 "cell_type": "code", | 834 "cell_type": "code", |
844 "execution_count": 59, | 835 "execution_count": 59, |
845 "metadata": { | 836 "metadata": {}, |
846 "collapsed": false | |
847 }, | |
848 "outputs": [ | 837 "outputs": [ |
849 { | 838 { |
850 "name": "stdout", | 839 "name": "stdout", |
851 "output_type": "stream", | 840 "output_type": "stream", |
852 "text": [ | 841 "text": [ |
4589 ] | 4578 ] |
4590 }, | 4579 }, |
4591 { | 4580 { |
4592 "cell_type": "code", | 4581 "cell_type": "code", |
4593 "execution_count": 21, | 4582 "execution_count": 21, |
4594 "metadata": { | 4583 "metadata": {}, |
4595 "collapsed": false | |
4596 }, | |
4597 "outputs": [ | 4584 "outputs": [ |
4598 { | 4585 { |
4599 "name": "stdout", | 4586 "name": "stdout", |
4600 "output_type": "stream", | 4587 "output_type": "stream", |
4601 "text": [ | 4588 "text": [ |
4941 ] | 4928 ] |
4942 }, | 4929 }, |
4943 { | 4930 { |
4944 "cell_type": "code", | 4931 "cell_type": "code", |
4945 "execution_count": 52, | 4932 "execution_count": 52, |
4946 "metadata": { | 4933 "metadata": {}, |
4947 "collapsed": false | |
4948 }, | |
4949 "outputs": [ | 4934 "outputs": [ |
4950 { | 4935 { |
4951 "name": "stdout", | 4936 "name": "stdout", |
4952 "output_type": "stream", | 4937 "output_type": "stream", |
4953 "text": [ | 4938 "text": [ |
5269 ] | 5254 ] |
5270 }, | 5255 }, |
5271 { | 5256 { |
5272 "cell_type": "code", | 5257 "cell_type": "code", |
5273 "execution_count": 67, | 5258 "execution_count": 67, |
5274 "metadata": { | 5259 "metadata": {}, |
5275 "collapsed": false | |
5276 }, | |
5277 "outputs": [ | 5260 "outputs": [ |
5278 { | 5261 { |
5279 "name": "stdout", | 5262 "name": "stdout", |
5280 "output_type": "stream", | 5263 "output_type": "stream", |
5281 "text": [ | 5264 "text": [ |
5611 "<br> Sort by outlier percentage in descending order." | 5594 "<br> Sort by outlier percentage in descending order." |
5612 ] | 5595 ] |
5613 }, | 5596 }, |
5614 { | 5597 { |
5615 "cell_type": "code", | 5598 "cell_type": "code", |
5616 "execution_count": 68, | 5599 "execution_count": 7, |
5617 "metadata": { | 5600 "metadata": { |
5618 "collapsed": true | 5601 "collapsed": true |
5619 }, | 5602 }, |
5620 "outputs": [], | 5603 "outputs": [], |
5621 "source": [ | 5604 "source": [ |
5628 " ranked_outliers = pd.concat([ranked_outliers, df_global['Outliers']], axis=1)" | 5611 " ranked_outliers = pd.concat([ranked_outliers, df_global['Outliers']], axis=1)" |
5629 ] | 5612 ] |
5630 }, | 5613 }, |
5631 { | 5614 { |
5632 "cell_type": "code", | 5615 "cell_type": "code", |
5633 "execution_count": 69, | 5616 "execution_count": 8, |
5634 "metadata": { | 5617 "metadata": {}, |
5635 "collapsed": false | |
5636 }, | |
5637 "outputs": [ | 5618 "outputs": [ |
5638 { | 5619 { |
5639 "data": { | 5620 "data": { |
5640 "text/plain": [ | 5621 "text/plain": [ |
5641 "(137, 10)" | 5622 "(137, 10)" |
5642 ] | 5623 ] |
5643 }, | 5624 }, |
5644 "execution_count": 69, | 5625 "execution_count": 8, |
5645 "metadata": {}, | 5626 "metadata": {}, |
5646 "output_type": "execute_result" | 5627 "output_type": "execute_result" |
5647 } | 5628 } |
5648 ], | 5629 ], |
5649 "source": [ | 5630 "source": [ |
5657 "Remove countries with 0% outliers as these are in random (probably alphabetical) order." | 5638 "Remove countries with 0% outliers as these are in random (probably alphabetical) order." |
5658 ] | 5639 ] |
5659 }, | 5640 }, |
5660 { | 5641 { |
5661 "cell_type": "code", | 5642 "cell_type": "code", |
5662 "execution_count": 70, | 5643 "execution_count": 9, |
5663 "metadata": { | 5644 "metadata": {}, |
5664 "collapsed": false | |
5665 }, | |
5666 "outputs": [ | 5645 "outputs": [ |
5667 { | 5646 { |
5668 "name": "stdout", | 5647 "name": "stdout", |
5669 "output_type": "stream", | 5648 "output_type": "stream", |
5670 "text": [ | 5649 "text": [ |
5700 "source": [ | 5679 "source": [ |
5701 "zero_idx = np.where(np.sum(ranked_outliers, axis=1)==0)[0]\n", | 5680 "zero_idx = np.where(np.sum(ranked_outliers, axis=1)==0)[0]\n", |
5702 "first_zero_idx = np.min(zero_idx)\n", | 5681 "first_zero_idx = np.min(zero_idx)\n", |
5703 "ranked_countries = ranked_countries.iloc[:first_zero_idx, :]\n", | 5682 "ranked_countries = ranked_countries.iloc[:first_zero_idx, :]\n", |
5704 "ranked_outliers = ranked_outliers.iloc[:first_zero_idx, :]\n", | 5683 "ranked_outliers = ranked_outliers.iloc[:first_zero_idx, :]\n", |
5684 "ranked_countries_arr = ranked_countries.get_values()\n", | |
5705 "\n", | 5685 "\n", |
5706 "print ranked_countries.head()\n", | 5686 "print ranked_countries.head()\n", |
5707 "print ranked_outliers.head()" | 5687 "print ranked_outliers.head()" |
5708 ] | 5688 ] |
5709 }, | 5689 }, |
5735 ] | 5715 ] |
5736 }, | 5716 }, |
5737 { | 5717 { |
5738 "cell_type": "code", | 5718 "cell_type": "code", |
5739 "execution_count": 72, | 5719 "execution_count": 72, |
5740 "metadata": { | 5720 "metadata": {}, |
5741 "collapsed": false | |
5742 }, | |
5743 "outputs": [ | 5721 "outputs": [ |
5744 { | 5722 { |
5745 "name": "stdout", | 5723 "name": "stdout", |
5746 "output_type": "stream", | 5724 "output_type": "stream", |
5747 "text": [ | 5725 "text": [ |
5754 ] | 5732 ] |
5755 }, | 5733 }, |
5756 { | 5734 { |
5757 "cell_type": "code", | 5735 "cell_type": "code", |
5758 "execution_count": 80, | 5736 "execution_count": 80, |
5759 "metadata": { | 5737 "metadata": {}, |
5760 "collapsed": false | |
5761 }, | |
5762 "outputs": [ | 5738 "outputs": [ |
5763 { | 5739 { |
5764 "name": "stdout", | 5740 "name": "stdout", |
5765 "output_type": "stream", | 5741 "output_type": "stream", |
5766 "text": [ | 5742 "text": [ |
5785 ] | 5761 ] |
5786 }, | 5762 }, |
5787 { | 5763 { |
5788 "cell_type": "code", | 5764 "cell_type": "code", |
5789 "execution_count": 81, | 5765 "execution_count": 81, |
5790 "metadata": { | 5766 "metadata": {}, |
5791 "collapsed": false | |
5792 }, | |
5793 "outputs": [ | 5767 "outputs": [ |
5794 { | 5768 { |
5795 "name": "stdout", | 5769 "name": "stdout", |
5796 "output_type": "stream", | 5770 "output_type": "stream", |
5797 "text": [ | 5771 "text": [ |
5822 ] | 5796 ] |
5823 }, | 5797 }, |
5824 { | 5798 { |
5825 "cell_type": "code", | 5799 "cell_type": "code", |
5826 "execution_count": 76, | 5800 "execution_count": 76, |
5827 "metadata": { | 5801 "metadata": {}, |
5828 "collapsed": false | |
5829 }, | |
5830 "outputs": [ | 5802 "outputs": [ |
5831 { | 5803 { |
5832 "data": { | 5804 "data": { |
5833 "text/plain": [ | 5805 "text/plain": [ |
5834 "{'Chad', 'French Guiana', 'Gambia'}" | 5806 "{'Chad', 'French Guiana', 'Gambia'}" |
5842 "source": [ | 5814 "source": [ |
5843 "common_set" | 5815 "common_set" |
5844 ] | 5816 ] |
5845 }, | 5817 }, |
5846 { | 5818 { |
5819 "cell_type": "markdown", | |
5820 "metadata": {}, | |
5821 "source": [ | |
5822 "## Try precision at K" | |
5823 ] | |
5824 }, | |
5825 { | |
5847 "cell_type": "code", | 5826 "cell_type": "code", |
5848 "execution_count": 97, | 5827 "execution_count": 10, |
5849 "metadata": { | 5828 "metadata": {}, |
5850 "collapsed": true | |
5851 }, | |
5852 "outputs": [], | 5829 "outputs": [], |
5853 "source": [ | 5830 "source": [ |
5854 "# majority voting + precision at K (top5?)\n", | 5831 "# majority voting + precision at K (top5?)\n", |
5855 "from collections import Counter\n", | 5832 "from collections import Counter\n", |
5856 "K_vote = 10\n", | 5833 "K_vote = 10\n", |
5857 "country_vote = Counter(ranked_countries_arr[:K_vote, :].ravel())" | 5834 "country_vote = Counter(ranked_countries_arr[:K_vote, :].ravel())" |
5858 ] | 5835 ] |
5859 }, | 5836 }, |
5860 { | 5837 { |
5861 "cell_type": "code", | 5838 "cell_type": "code", |
5862 "execution_count": 98, | 5839 "execution_count": 11, |
5863 "metadata": { | 5840 "metadata": {}, |
5864 "collapsed": false | |
5865 }, | |
5866 "outputs": [ | 5841 "outputs": [ |
5867 { | 5842 { |
5868 "data": { | 5843 "data": { |
5869 "text/html": [ | 5844 "text/html": [ |
5870 "<div>\n", | 5845 "<div>\n", |
5913 "2 Belize 2\n", | 5888 "2 Belize 2\n", |
5914 "3 Chad 10\n", | 5889 "3 Chad 10\n", |
5915 "4 Bhutan 7" | 5890 "4 Bhutan 7" |
5916 ] | 5891 ] |
5917 }, | 5892 }, |
5918 "execution_count": 98, | 5893 "execution_count": 11, |
5919 "metadata": {}, | 5894 "metadata": {}, |
5920 "output_type": "execute_result" | 5895 "output_type": "execute_result" |
5921 } | 5896 } |
5922 ], | 5897 ], |
5923 "source": [ | 5898 "source": [ |
5925 "df_country_vote.head()" | 5900 "df_country_vote.head()" |
5926 ] | 5901 ] |
5927 }, | 5902 }, |
5928 { | 5903 { |
5929 "cell_type": "code", | 5904 "cell_type": "code", |
5930 "execution_count": 99, | 5905 "execution_count": 12, |
5931 "metadata": { | 5906 "metadata": {}, |
5932 "collapsed": false | |
5933 }, | |
5934 "outputs": [ | 5907 "outputs": [ |
5935 { | 5908 { |
5936 "data": { | 5909 "data": { |
5937 "text/html": [ | 5910 "text/html": [ |
5938 "<div>\n", | 5911 "<div>\n", |
6065 "0 Brazil 1\n", | 6038 "0 Brazil 1\n", |
6066 "13 Laos 1\n", | 6039 "13 Laos 1\n", |
6067 "9 Malta 1" | 6040 "9 Malta 1" |
6068 ] | 6041 ] |
6069 }, | 6042 }, |
6070 "execution_count": 99, | 6043 "execution_count": 12, |
6071 "metadata": {}, | 6044 "metadata": {}, |
6072 "output_type": "execute_result" | 6045 "output_type": "execute_result" |
6073 } | 6046 } |
6074 ], | 6047 ], |
6075 "source": [ | 6048 "source": [ |
6076 "df_country_vote.sort_values(0, ascending=False)" | 6049 "df_country_vote.sort_values(0, ascending=False)" |
6077 ] | 6050 ] |
6078 }, | 6051 }, |
6079 { | 6052 { |
6080 "cell_type": "code", | 6053 "cell_type": "code", |
6081 "execution_count": 102, | 6054 "execution_count": 14, |
6082 "metadata": { | 6055 "metadata": {}, |
6083 "collapsed": false | |
6084 }, | |
6085 "outputs": [ | 6056 "outputs": [ |
6086 { | 6057 { |
6087 "data": { | 6058 "name": "stdout", |
6088 "text/plain": [ | 6059 "output_type": "stream", |
6089 "0.51000000000000001" | 6060 "text": [ |
6090 ] | 6061 "0.51 0.0830662386292\n" |
6091 }, | 6062 ] |
6092 "execution_count": 102, | |
6093 "metadata": {}, | |
6094 "output_type": "execute_result" | |
6095 } | 6063 } |
6096 ], | 6064 ], |
6097 "source": [ | 6065 "source": [ |
6098 "def precision_at_k(array, gr_truth, k):\n", | 6066 "def precision_at_k(array, gr_truth, k):\n", |
6099 " return len(set(array[:k]) & set(gr_truth[:k])) / float(k)\n", | 6067 " return len(set(array[:k]) & set(gr_truth[:k])) / float(k)\n", |
6102 "ground_truth = df_country_vote['index'].get_values()\n", | 6070 "ground_truth = df_country_vote['index'].get_values()\n", |
6103 "p_ = []\n", | 6071 "p_ = []\n", |
6104 "for j in range(ranked_countries_arr.shape[1]):\n", | 6072 "for j in range(ranked_countries_arr.shape[1]):\n", |
6105 " p_.append(precision_at_k(ranked_countries_arr[:, j], ground_truth, k))\n", | 6073 " p_.append(precision_at_k(ranked_countries_arr[:, j], ground_truth, k))\n", |
6106 "p_ = np.array(p_)\n", | 6074 "p_ = np.array(p_)\n", |
6107 "np.mean(p_)" | 6075 "print np.mean(p_), np.std(p_)" |
6076 ] | |
6077 }, | |
6078 { | |
6079 "cell_type": "code", | |
6080 "execution_count": 15, | |
6081 "metadata": {}, | |
6082 "outputs": [ | |
6083 { | |
6084 "data": { | |
6085 "text/plain": [ | |
6086 "array([ 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.4, 0.7, 0.4])" | |
6087 ] | |
6088 }, | |
6089 "execution_count": 15, | |
6090 "metadata": {}, | |
6091 "output_type": "execute_result" | |
6092 } | |
6093 ], | |
6094 "source": [ | |
6095 "p_" | |
6108 ] | 6096 ] |
6109 }, | 6097 }, |
6110 { | 6098 { |
6111 "cell_type": "code", | 6099 "cell_type": "code", |
6112 "execution_count": null, | 6100 "execution_count": null, |
6131 "file_extension": ".py", | 6119 "file_extension": ".py", |
6132 "mimetype": "text/x-python", | 6120 "mimetype": "text/x-python", |
6133 "name": "python", | 6121 "name": "python", |
6134 "nbconvert_exporter": "python", | 6122 "nbconvert_exporter": "python", |
6135 "pygments_lexer": "ipython2", | 6123 "pygments_lexer": "ipython2", |
6136 "version": "2.7.11" | 6124 "version": "2.7.12" |
6137 } | 6125 } |
6138 }, | 6126 }, |
6139 "nbformat": 4, | 6127 "nbformat": 4, |
6140 "nbformat_minor": 1 | 6128 "nbformat_minor": 1 |
6141 } | 6129 } |