comparison notebooks/sensitivity_experiment.ipynb @ 82:4395037087b6 branch-tests

notebooks
author mpanteli <m.x.panteli@gmail.com>
date Tue, 26 Sep 2017 21:18:26 +0100
parents 92a5e280946d
children e279ccea5f9b
comparison
equal deleted inserted replaced
79:98fc06ba2938 82:4395037087b6
1 { 1 {
2 "cells": [ 2 "cells": [
3 { 3 {
4 "cell_type": "code", 4 "cell_type": "code",
5 "execution_count": 1, 5 "execution_count": 1,
6 "metadata": { 6 "metadata": {},
7 "collapsed": false
8 },
9 "outputs": [ 7 "outputs": [
8 {
9 "name": "stdout",
10 "output_type": "stream",
11 "text": [
12 "ERROR! Session/line number was not unique in database. History logging moved to new session 32\n"
13 ]
14 },
10 { 15 {
11 "name": "stderr", 16 "name": "stderr",
12 "output_type": "stream", 17 "output_type": "stream",
13 "text": [ 18 "text": [
14 "/homes/mp305/anaconda/lib/python2.7/site-packages/librosa/core/audio.py:33: UserWarning: Could not import scikits.samplerate. Falling back to scipy.signal\n", 19 "/homes/mp305/anaconda/lib/python2.7/site-packages/librosa/core/audio.py:33: UserWarning: Could not import scikits.samplerate. Falling back to scipy.signal\n",
35 ] 40 ]
36 }, 41 },
37 { 42 {
38 "cell_type": "code", 43 "cell_type": "code",
39 "execution_count": 2, 44 "execution_count": 2,
40 "metadata": { 45 "metadata": {},
41 "collapsed": true
42 },
43 "outputs": [], 46 "outputs": [],
44 "source": [ 47 "source": [
45 "OUTPUT_FILES = load_dataset.OUTPUT_FILES\n", 48 "OUTPUT_FILES = load_dataset.OUTPUT_FILES\n",
46 "n_iters = 10" 49 "n_iters = 10"
47 ] 50 ]
48 }, 51 },
49 { 52 {
50 "cell_type": "code", 53 "cell_type": "code",
51 "execution_count": 5, 54 "execution_count": 5,
52 "metadata": { 55 "metadata": {},
53 "collapsed": false
54 },
55 "outputs": [ 56 "outputs": [
56 { 57 {
57 "data": { 58 "data": {
58 "text/plain": [ 59 "text/plain": [
59 "(8396, 108)" 60 "(8396, 108)"
70 ] 71 ]
71 }, 72 },
72 { 73 {
73 "cell_type": "code", 74 "cell_type": "code",
74 "execution_count": 48, 75 "execution_count": 48,
75 "metadata": { 76 "metadata": {},
76 "collapsed": false
77 },
78 "outputs": [ 77 "outputs": [
79 { 78 {
80 "name": "stdout", 79 "name": "stdout",
81 "output_type": "stream", 80 "output_type": "stream",
82 "text": [ 81 "text": [
282 ] 281 ]
283 }, 282 },
284 { 283 {
285 "cell_type": "code", 284 "cell_type": "code",
286 "execution_count": 52, 285 "execution_count": 52,
287 "metadata": { 286 "metadata": {},
288 "collapsed": false
289 },
290 "outputs": [ 287 "outputs": [
291 { 288 {
292 "name": "stdout", 289 "name": "stdout",
293 "output_type": "stream", 290 "output_type": "stream",
294 "text": [ 291 "text": [
453 ] 450 ]
454 }, 451 },
455 { 452 {
456 "cell_type": "code", 453 "cell_type": "code",
457 "execution_count": 56, 454 "execution_count": 56,
458 "metadata": { 455 "metadata": {},
459 "collapsed": false
460 },
461 "outputs": [ 456 "outputs": [
462 { 457 {
463 "data": { 458 "data": {
464 "text/plain": [ 459 "text/plain": [
465 "array([ 59, 1, 1, 1, 1, 733, 733, 733, 733, 733, 733, 733, 733,\n", 460 "array([ 59, 1, 1, 1, 1, 733, 733, 733, 733, 733, 733, 733, 733,\n",
478 ] 473 ]
479 }, 474 },
480 { 475 {
481 "cell_type": "code", 476 "cell_type": "code",
482 "execution_count": 8, 477 "execution_count": 8,
483 "metadata": { 478 "metadata": {},
484 "collapsed": false
485 },
486 "outputs": [ 479 "outputs": [
487 { 480 {
488 "data": { 481 "data": {
489 "text/html": [ 482 "text/html": [
490 "<div>\n", 483 "<div>\n",
739 ] 732 ]
740 }, 733 },
741 { 734 {
742 "cell_type": "code", 735 "cell_type": "code",
743 "execution_count": 47, 736 "execution_count": 47,
744 "metadata": { 737 "metadata": {},
745 "collapsed": false
746 },
747 "outputs": [ 738 "outputs": [
748 { 739 {
749 "name": "stdout", 740 "name": "stdout",
750 "output_type": "stream", 741 "output_type": "stream",
751 "text": [ 742 "text": [
840 ] 831 ]
841 }, 832 },
842 { 833 {
843 "cell_type": "code", 834 "cell_type": "code",
844 "execution_count": 59, 835 "execution_count": 59,
845 "metadata": { 836 "metadata": {},
846 "collapsed": false
847 },
848 "outputs": [ 837 "outputs": [
849 { 838 {
850 "name": "stdout", 839 "name": "stdout",
851 "output_type": "stream", 840 "output_type": "stream",
852 "text": [ 841 "text": [
4589 ] 4578 ]
4590 }, 4579 },
4591 { 4580 {
4592 "cell_type": "code", 4581 "cell_type": "code",
4593 "execution_count": 21, 4582 "execution_count": 21,
4594 "metadata": { 4583 "metadata": {},
4595 "collapsed": false
4596 },
4597 "outputs": [ 4584 "outputs": [
4598 { 4585 {
4599 "name": "stdout", 4586 "name": "stdout",
4600 "output_type": "stream", 4587 "output_type": "stream",
4601 "text": [ 4588 "text": [
4941 ] 4928 ]
4942 }, 4929 },
4943 { 4930 {
4944 "cell_type": "code", 4931 "cell_type": "code",
4945 "execution_count": 52, 4932 "execution_count": 52,
4946 "metadata": { 4933 "metadata": {},
4947 "collapsed": false
4948 },
4949 "outputs": [ 4934 "outputs": [
4950 { 4935 {
4951 "name": "stdout", 4936 "name": "stdout",
4952 "output_type": "stream", 4937 "output_type": "stream",
4953 "text": [ 4938 "text": [
5269 ] 5254 ]
5270 }, 5255 },
5271 { 5256 {
5272 "cell_type": "code", 5257 "cell_type": "code",
5273 "execution_count": 67, 5258 "execution_count": 67,
5274 "metadata": { 5259 "metadata": {},
5275 "collapsed": false
5276 },
5277 "outputs": [ 5260 "outputs": [
5278 { 5261 {
5279 "name": "stdout", 5262 "name": "stdout",
5280 "output_type": "stream", 5263 "output_type": "stream",
5281 "text": [ 5264 "text": [
5611 "<br> Sort by outlier percentage in descending order." 5594 "<br> Sort by outlier percentage in descending order."
5612 ] 5595 ]
5613 }, 5596 },
5614 { 5597 {
5615 "cell_type": "code", 5598 "cell_type": "code",
5616 "execution_count": 68, 5599 "execution_count": 7,
5617 "metadata": { 5600 "metadata": {
5618 "collapsed": true 5601 "collapsed": true
5619 }, 5602 },
5620 "outputs": [], 5603 "outputs": [],
5621 "source": [ 5604 "source": [
5628 " ranked_outliers = pd.concat([ranked_outliers, df_global['Outliers']], axis=1)" 5611 " ranked_outliers = pd.concat([ranked_outliers, df_global['Outliers']], axis=1)"
5629 ] 5612 ]
5630 }, 5613 },
5631 { 5614 {
5632 "cell_type": "code", 5615 "cell_type": "code",
5633 "execution_count": 69, 5616 "execution_count": 8,
5634 "metadata": { 5617 "metadata": {},
5635 "collapsed": false
5636 },
5637 "outputs": [ 5618 "outputs": [
5638 { 5619 {
5639 "data": { 5620 "data": {
5640 "text/plain": [ 5621 "text/plain": [
5641 "(137, 10)" 5622 "(137, 10)"
5642 ] 5623 ]
5643 }, 5624 },
5644 "execution_count": 69, 5625 "execution_count": 8,
5645 "metadata": {}, 5626 "metadata": {},
5646 "output_type": "execute_result" 5627 "output_type": "execute_result"
5647 } 5628 }
5648 ], 5629 ],
5649 "source": [ 5630 "source": [
5657 "Remove countries with 0% outliers as these are in random (probably alphabetical) order." 5638 "Remove countries with 0% outliers as these are in random (probably alphabetical) order."
5658 ] 5639 ]
5659 }, 5640 },
5660 { 5641 {
5661 "cell_type": "code", 5642 "cell_type": "code",
5662 "execution_count": 70, 5643 "execution_count": 9,
5663 "metadata": { 5644 "metadata": {},
5664 "collapsed": false
5665 },
5666 "outputs": [ 5645 "outputs": [
5667 { 5646 {
5668 "name": "stdout", 5647 "name": "stdout",
5669 "output_type": "stream", 5648 "output_type": "stream",
5670 "text": [ 5649 "text": [
5700 "source": [ 5679 "source": [
5701 "zero_idx = np.where(np.sum(ranked_outliers, axis=1)==0)[0]\n", 5680 "zero_idx = np.where(np.sum(ranked_outliers, axis=1)==0)[0]\n",
5702 "first_zero_idx = np.min(zero_idx)\n", 5681 "first_zero_idx = np.min(zero_idx)\n",
5703 "ranked_countries = ranked_countries.iloc[:first_zero_idx, :]\n", 5682 "ranked_countries = ranked_countries.iloc[:first_zero_idx, :]\n",
5704 "ranked_outliers = ranked_outliers.iloc[:first_zero_idx, :]\n", 5683 "ranked_outliers = ranked_outliers.iloc[:first_zero_idx, :]\n",
5684 "ranked_countries_arr = ranked_countries.get_values()\n",
5705 "\n", 5685 "\n",
5706 "print ranked_countries.head()\n", 5686 "print ranked_countries.head()\n",
5707 "print ranked_outliers.head()" 5687 "print ranked_outliers.head()"
5708 ] 5688 ]
5709 }, 5689 },
5735 ] 5715 ]
5736 }, 5716 },
5737 { 5717 {
5738 "cell_type": "code", 5718 "cell_type": "code",
5739 "execution_count": 72, 5719 "execution_count": 72,
5740 "metadata": { 5720 "metadata": {},
5741 "collapsed": false
5742 },
5743 "outputs": [ 5721 "outputs": [
5744 { 5722 {
5745 "name": "stdout", 5723 "name": "stdout",
5746 "output_type": "stream", 5724 "output_type": "stream",
5747 "text": [ 5725 "text": [
5754 ] 5732 ]
5755 }, 5733 },
5756 { 5734 {
5757 "cell_type": "code", 5735 "cell_type": "code",
5758 "execution_count": 80, 5736 "execution_count": 80,
5759 "metadata": { 5737 "metadata": {},
5760 "collapsed": false
5761 },
5762 "outputs": [ 5738 "outputs": [
5763 { 5739 {
5764 "name": "stdout", 5740 "name": "stdout",
5765 "output_type": "stream", 5741 "output_type": "stream",
5766 "text": [ 5742 "text": [
5785 ] 5761 ]
5786 }, 5762 },
5787 { 5763 {
5788 "cell_type": "code", 5764 "cell_type": "code",
5789 "execution_count": 81, 5765 "execution_count": 81,
5790 "metadata": { 5766 "metadata": {},
5791 "collapsed": false
5792 },
5793 "outputs": [ 5767 "outputs": [
5794 { 5768 {
5795 "name": "stdout", 5769 "name": "stdout",
5796 "output_type": "stream", 5770 "output_type": "stream",
5797 "text": [ 5771 "text": [
5822 ] 5796 ]
5823 }, 5797 },
5824 { 5798 {
5825 "cell_type": "code", 5799 "cell_type": "code",
5826 "execution_count": 76, 5800 "execution_count": 76,
5827 "metadata": { 5801 "metadata": {},
5828 "collapsed": false
5829 },
5830 "outputs": [ 5802 "outputs": [
5831 { 5803 {
5832 "data": { 5804 "data": {
5833 "text/plain": [ 5805 "text/plain": [
5834 "{'Chad', 'French Guiana', 'Gambia'}" 5806 "{'Chad', 'French Guiana', 'Gambia'}"
5842 "source": [ 5814 "source": [
5843 "common_set" 5815 "common_set"
5844 ] 5816 ]
5845 }, 5817 },
5846 { 5818 {
5819 "cell_type": "markdown",
5820 "metadata": {},
5821 "source": [
5822 "## Try precision at K"
5823 ]
5824 },
5825 {
5847 "cell_type": "code", 5826 "cell_type": "code",
5848 "execution_count": 97, 5827 "execution_count": 10,
5849 "metadata": { 5828 "metadata": {},
5850 "collapsed": true
5851 },
5852 "outputs": [], 5829 "outputs": [],
5853 "source": [ 5830 "source": [
5854 "# majority voting + precision at K (top5?)\n", 5831 "# majority voting + precision at K (top5?)\n",
5855 "from collections import Counter\n", 5832 "from collections import Counter\n",
5856 "K_vote = 10\n", 5833 "K_vote = 10\n",
5857 "country_vote = Counter(ranked_countries_arr[:K_vote, :].ravel())" 5834 "country_vote = Counter(ranked_countries_arr[:K_vote, :].ravel())"
5858 ] 5835 ]
5859 }, 5836 },
5860 { 5837 {
5861 "cell_type": "code", 5838 "cell_type": "code",
5862 "execution_count": 98, 5839 "execution_count": 11,
5863 "metadata": { 5840 "metadata": {},
5864 "collapsed": false
5865 },
5866 "outputs": [ 5841 "outputs": [
5867 { 5842 {
5868 "data": { 5843 "data": {
5869 "text/html": [ 5844 "text/html": [
5870 "<div>\n", 5845 "<div>\n",
5913 "2 Belize 2\n", 5888 "2 Belize 2\n",
5914 "3 Chad 10\n", 5889 "3 Chad 10\n",
5915 "4 Bhutan 7" 5890 "4 Bhutan 7"
5916 ] 5891 ]
5917 }, 5892 },
5918 "execution_count": 98, 5893 "execution_count": 11,
5919 "metadata": {}, 5894 "metadata": {},
5920 "output_type": "execute_result" 5895 "output_type": "execute_result"
5921 } 5896 }
5922 ], 5897 ],
5923 "source": [ 5898 "source": [
5925 "df_country_vote.head()" 5900 "df_country_vote.head()"
5926 ] 5901 ]
5927 }, 5902 },
5928 { 5903 {
5929 "cell_type": "code", 5904 "cell_type": "code",
5930 "execution_count": 99, 5905 "execution_count": 12,
5931 "metadata": { 5906 "metadata": {},
5932 "collapsed": false
5933 },
5934 "outputs": [ 5907 "outputs": [
5935 { 5908 {
5936 "data": { 5909 "data": {
5937 "text/html": [ 5910 "text/html": [
5938 "<div>\n", 5911 "<div>\n",
6065 "0 Brazil 1\n", 6038 "0 Brazil 1\n",
6066 "13 Laos 1\n", 6039 "13 Laos 1\n",
6067 "9 Malta 1" 6040 "9 Malta 1"
6068 ] 6041 ]
6069 }, 6042 },
6070 "execution_count": 99, 6043 "execution_count": 12,
6071 "metadata": {}, 6044 "metadata": {},
6072 "output_type": "execute_result" 6045 "output_type": "execute_result"
6073 } 6046 }
6074 ], 6047 ],
6075 "source": [ 6048 "source": [
6076 "df_country_vote.sort_values(0, ascending=False)" 6049 "df_country_vote.sort_values(0, ascending=False)"
6077 ] 6050 ]
6078 }, 6051 },
6079 { 6052 {
6080 "cell_type": "code", 6053 "cell_type": "code",
6081 "execution_count": 102, 6054 "execution_count": 14,
6082 "metadata": { 6055 "metadata": {},
6083 "collapsed": false
6084 },
6085 "outputs": [ 6056 "outputs": [
6086 { 6057 {
6087 "data": { 6058 "name": "stdout",
6088 "text/plain": [ 6059 "output_type": "stream",
6089 "0.51000000000000001" 6060 "text": [
6090 ] 6061 "0.51 0.0830662386292\n"
6091 }, 6062 ]
6092 "execution_count": 102,
6093 "metadata": {},
6094 "output_type": "execute_result"
6095 } 6063 }
6096 ], 6064 ],
6097 "source": [ 6065 "source": [
6098 "def precision_at_k(array, gr_truth, k):\n", 6066 "def precision_at_k(array, gr_truth, k):\n",
6099 " return len(set(array[:k]) & set(gr_truth[:k])) / float(k)\n", 6067 " return len(set(array[:k]) & set(gr_truth[:k])) / float(k)\n",
6102 "ground_truth = df_country_vote['index'].get_values()\n", 6070 "ground_truth = df_country_vote['index'].get_values()\n",
6103 "p_ = []\n", 6071 "p_ = []\n",
6104 "for j in range(ranked_countries_arr.shape[1]):\n", 6072 "for j in range(ranked_countries_arr.shape[1]):\n",
6105 " p_.append(precision_at_k(ranked_countries_arr[:, j], ground_truth, k))\n", 6073 " p_.append(precision_at_k(ranked_countries_arr[:, j], ground_truth, k))\n",
6106 "p_ = np.array(p_)\n", 6074 "p_ = np.array(p_)\n",
6107 "np.mean(p_)" 6075 "print np.mean(p_), np.std(p_)"
6076 ]
6077 },
6078 {
6079 "cell_type": "code",
6080 "execution_count": 15,
6081 "metadata": {},
6082 "outputs": [
6083 {
6084 "data": {
6085 "text/plain": [
6086 "array([ 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.4, 0.7, 0.4])"
6087 ]
6088 },
6089 "execution_count": 15,
6090 "metadata": {},
6091 "output_type": "execute_result"
6092 }
6093 ],
6094 "source": [
6095 "p_"
6108 ] 6096 ]
6109 }, 6097 },
6110 { 6098 {
6111 "cell_type": "code", 6099 "cell_type": "code",
6112 "execution_count": null, 6100 "execution_count": null,
6131 "file_extension": ".py", 6119 "file_extension": ".py",
6132 "mimetype": "text/x-python", 6120 "mimetype": "text/x-python",
6133 "name": "python", 6121 "name": "python",
6134 "nbconvert_exporter": "python", 6122 "nbconvert_exporter": "python",
6135 "pygments_lexer": "ipython2", 6123 "pygments_lexer": "ipython2",
6136 "version": "2.7.11" 6124 "version": "2.7.12"
6137 } 6125 }
6138 }, 6126 },
6139 "nbformat": 4, 6127 "nbformat": 4,
6140 "nbformat_minor": 1 6128 "nbformat_minor": 1
6141 } 6129 }