Mercurial > hg > plosone_underreview
changeset 78:9e526f7c9715 branch-tests
not sure what changed..
author | Maria Panteli <m.x.panteli@gmail.com> |
---|---|
date | Tue, 26 Sep 2017 12:40:07 +0100 |
parents | bde45ce0eeab (current diff) ce368898158f (diff) |
children | 98fc06ba2938 103f7411c3ad |
files | notebooks/results_for_30_seconds.ipynb |
diffstat | 3 files changed, 1101 insertions(+), 191 deletions(-) [+] |
line wrap: on
line diff
--- a/notebooks/results_for_30_seconds.ipynb Fri Sep 22 18:02:59 2017 +0100 +++ b/notebooks/results_for_30_seconds.ipynb Tue Sep 26 12:40:07 2017 +0100 @@ -2,17 +2,17 @@ "cells": [ { "cell_type": "code", - "execution_count": 36, + "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" + "/homes/mp305/anaconda/lib/python2.7/site-packages/librosa/core/audio.py:33: UserWarning: Could not import scikits.samplerate. Falling back to scipy.signal\n", + " warnings.warn('Could not import scikits.samplerate. '\n" ] } ], @@ -89,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 3, "metadata": { "collapsed": false }, @@ -127,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 4, "metadata": { "collapsed": false }, @@ -136,7 +136,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "['/import/c4dm-04/mariap/train_data_melodia_8_30sec.pickle', '/import/c4dm-04/mariap/val_data_melodia_8_30sec.pickle', '/import/c4dm-04/mariap/test_data_melodia_8_30sec.pickle'] ['/import/c4dm-04/mariap/lda_data_melodia_8_30sec_30sec.pickle', '/import/c4dm-04/mariap/pca_data_melodia_8_30sec_30sec.pickle', '/import/c4dm-04/mariap/nmf_data_melodia_8_30sec_30sec.pickle', '/import/c4dm-04/mariap/ssnmf_data_melodia_8_30sec_30sec.pickle', '/import/c4dm-04/mariap/na_data_melodia_8_30sec_30sec.pickle']\n" + "['/import/c4dm-04/mariap/train_data_melodia_8_30sec.pickle', '/import/c4dm-04/mariap/val_data_melodia_8_30sec.pickle', '/import/c4dm-04/mariap/test_data_melodia_8_30sec.pickle'] ['/import/c4dm-04/mariap/lda_data_melodia_8_30sec.pickle', '/import/c4dm-04/mariap/pca_data_melodia_8_30sec.pickle', '/import/c4dm-04/mariap/nmf_data_melodia_8_30sec.pickle', '/import/c4dm-04/mariap/ssnmf_data_melodia_8_30sec.pickle', 
'/import/c4dm-04/mariap/na_data_melodia_8_30sec.pickle']\n" ] } ], @@ -150,7 +150,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": { "collapsed": false }, @@ -166,7 +166,7 @@ "variance explained 1.0\n", "138 400\n", "training with PCA transform...\n", - "variance explained 0.989999211296\n", + "variance explained 0.989994197011\n", "training with LDA transform...\n" ] }, @@ -174,6 +174,8 @@ "name": "stderr", "output_type": "stream", "text": [ + "/homes/mp305/anaconda/lib/python2.7/site-packages/sklearn/utils/validation.py:526: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n", + " y = column_or_1d(y, warn=True)\n", "/homes/mp305/anaconda/lib/python2.7/site-packages/sklearn/discriminant_analysis.py:455: UserWarning: The priors do not sum to 1. Renormalizing\n", " UserWarning)\n" ] @@ -183,6 +185,10 @@ "output_type": "stream", "text": [ "variance explained 1.0\n", + "training with NMF transform...\n", + "reconstruction error 6.59195506061\n", + "training with SSNMF transform...\n", + "reconstruction error 25.0727210368\n", "transform test data...\n", "mapping mel\n", "training with PCA transform...\n", @@ -192,32 +198,16 @@ "variance explained 0.990347897477\n", "training with LDA transform...\n", "variance explained 1.0\n", - "transform test data...\n", - "mapping mfc\n", - "training with PCA transform...\n", - "variance explained 1.0\n", - "39 80\n", - "training with PCA transform...\n", - "variance explained 0.991458741216\n", - "training with LDA transform...\n", - "variance explained 0.942657629903\n", - "transform test data...\n", - "mapping chr\n", - "training with PCA transform...\n", - "variance explained 1.0\n", - "70 120\n", - "training with PCA transform...\n", - "variance explained 0.990503308525\n", - "training with LDA transform...\n", - "variance explained 0.954607427999\n", - "transform test 
data...\n" + "training with NMF transform...\n" ] } ], "source": [ "print \"mapping...\"\n", - "_, _, ldadata_list, _, _, Y, Yaudio = mapper.lda_map_and_average_frames(min_variance=0.99)\n", - "mapper.write_output([], [], ldadata_list, [], [], Y, Yaudio)" + "#_, _, ldadata_list, _, _, Y, Yaudio = mapper.lda_map_and_average_frames(min_variance=0.99)\n", + "#mapper.write_output([], [], ldadata_list, [], [], Y, Yaudio)\n", + "data_list, pcadata_list, ldadata_list, nmfdata_list, ssnmfdata_list, classlabs, audiolabs = mapper.map_and_average_frames(min_variance=0.99)\n", + "mapper.write_output(data_list, pcadata_list, ldadata_list, nmfdata_list, ssnmfdata_list, classlabs, audiolabs)" ] }, {
--- a/notebooks/sensitivity_experiment.ipynb Fri Sep 22 18:02:59 2017 +0100 +++ b/notebooks/sensitivity_experiment.ipynb Tue Sep 26 12:40:07 2017 +0100 @@ -2,10 +2,20 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { - "outputs": [], + "collapsed": false }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/homes/mp305/anaconda/lib/python2.7/site-packages/librosa/core/audio.py:33: UserWarning: Could not import scikits.samplerate. Falling back to scipy.signal\n", + " warnings.warn('Could not import scikits.samplerate. '\n" + ] + } + ], "source": [ "import numpy as np\n", "import pandas as pd\n", @@ -26,7 +36,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "collapsed": true }, @@ -4799,13 +4809,7 @@ "105 Sudan 0.045455 66 3\n", "120 Kazakhstan 0.045455 88 4\n", "writing file\n", - "iteration 7\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "iteration 7\n", "classifying...\n", "/import/c4dm-04/mariap/train_data_melodia_8_7.pickle\n", "0.179777654473\n", @@ -4911,12 +4915,675 @@ " X, Y, Yaudio = classification.load_data_from_pickle(CLASS_INPUT_FILES[0])\n", " #X = np.concatenate(ldadata_list, axis=1)\n", " # classification and confusion\n", - " print \"classifying...\"\n", - " traininds, testinds = classification.get_train_test_indices(Yaudio)\n", - " X_train, Y_train, X_test, Y_test = classification.get_train_test_sets(X, Y, traininds, testinds)\n", - " accuracy, _ = classification.confusion_matrix(X_train, Y_train, X_test, Y_test, saveCF=False, plots=False)\n", - " print accuracy\n", + " if 1:\n", + " print \"classifying...\"\n", + " traininds, testinds = classification.get_train_test_indices(Yaudio)\n", + " X_train, Y_train, X_test, Y_test = classification.get_train_test_sets(X, Y, traininds, testinds)\n", + " accuracy, _ = classification.confusion_matrix(X_train, Y_train, X_test, Y_test, saveCF=False, 
plots=False)\n", + " print accuracy\n", + "\n", + " # outliers\n", + " print \"detecting outliers...\"\n", + " #ddf = outliers.load_metadata(Yaudio, metadata_file=load_dataset.METADATA_FILE)\n", + " df_global, threshold, MD = outliers.get_outliers_df(X, Y, chi2thr=0.999)\n", + " outliers.print_most_least_outliers_topN(df_global, N=10)\n", " \n", + " # write output\n", + " print \"writing file\"\n", + " df_global.to_csv('../data/outliers_'+str(n)+'.csv', index=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's sample only 80% of the recordings each time so outlier results are different. Otherwise if we are only including the same 10 recordings from Chad we have more chances of getting the same outliers from Chad." + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "iteration 0\n", + "(8089, 381) (8089,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "136 Botswana 0.590909 88 52\n", + "31 Ivory Coast 0.571429 14 8\n", + "86 Gambia 0.541667 48 26\n", + "42 Benin 0.538462 26 14\n", + "102 Fiji 0.466667 15 7\n", + "20 Pakistan 0.461538 91 42\n", + "64 Uganda 0.437500 80 35\n", + "14 Liberia 0.425000 40 17\n", + "78 El Salvador 0.424242 33 14\n", + "50 Western Sahara 0.421687 83 35\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "1 Lithuania 0.000000 47 0\n", + "30 Afghanistan 0.000000 24 0\n", + "28 Tajikistan 0.000000 19 0\n", + "27 South Korea 0.000000 11 0\n", + "113 Iceland 0.000000 14 0\n", + "119 Denmark 0.000000 16 0\n", + "74 Czech Republic 0.000000 41 0\n", + "15 Netherlands 0.014925 67 1\n", + "121 Poland 0.040000 100 4\n", + "134 Paraguay 0.043478 23 1\n", + "writing file\n", + "iteration 1\n", + "(8099, 381) (8099,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country 
N_Outliers\n", + "60 Chad 0.545455 11 6\n", + "62 Fiji 0.533333 15 8\n", + "86 Gambia 0.520833 48 25\n", + "21 Pakistan 0.500000 88 44\n", + "43 Benin 0.500000 26 13\n", + "32 Ivory Coast 0.500000 14 7\n", + "136 Botswana 0.488095 84 41\n", + "78 El Salvador 0.484848 33 16\n", + "106 Nepal 0.436782 87 38\n", + "135 French Guiana 0.428571 28 12\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "1 Lithuania 0.000000 47 0\n", + "113 Iceland 0.000000 14 0\n", + "119 Denmark 0.000000 16 0\n", + "74 Czech Republic 0.000000 41 0\n", + "28 South Korea 0.000000 11 0\n", + "16 Netherlands 0.029851 67 2\n", + "31 Afghanistan 0.041667 24 1\n", + "134 Paraguay 0.043478 23 1\n", + "105 Sudan 0.045455 66 3\n", + "120 Kazakhstan 0.045455 88 4\n", + "writing file\n", + "iteration 2\n", + "(8078, 380) (8078,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "136 Botswana 0.615385 78 48\n", + "86 Gambia 0.520833 48 25\n", + "72 Ivory Coast 0.500000 14 7\n", + "62 Fiji 0.466667 15 7\n", + "43 Benin 0.461538 26 12\n", + "20 Pakistan 0.451613 93 42\n", + "17 French Guiana 0.428571 28 12\n", + "14 Liberia 0.425000 40 17\n", + "78 El Salvador 0.424242 33 14\n", + "51 Western Sahara 0.414634 82 34\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "119 Denmark 0.000000 16 0\n", + "113 Iceland 0.000000 14 0\n", + "27 South Korea 0.000000 11 0\n", + "1 Lithuania 0.000000 47 0\n", + "31 Czech Republic 0.024390 41 1\n", + "15 Netherlands 0.029851 67 2\n", + "30 Afghanistan 0.041667 24 1\n", + "105 Sudan 0.045455 66 3\n", + "120 Kazakhstan 0.045455 88 4\n", + "100 Antigua and Barbuda 0.047619 42 2\n", + "writing file\n", + "iteration 3\n", + "(8103, 380) (8103,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "136 Botswana 0.617284 81 50\n", + "31 Ivory Coast 0.571429 14 8\n", + "86 Gambia 0.541667 48 26\n", + "43 Benin 0.538462 26 14\n", 
+ "62 Fiji 0.533333 15 8\n", + "20 Pakistan 0.468750 96 45\n", + "51 Western Sahara 0.439024 82 36\n", + "14 Liberia 0.425000 40 17\n", + "78 El Salvador 0.424242 33 14\n", + "106 Nepal 0.416667 96 40\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "113 Iceland 0.000000 14 0\n", + "30 Afghanistan 0.000000 24 0\n", + "119 Denmark 0.000000 16 0\n", + "134 Paraguay 0.000000 23 0\n", + "27 South Korea 0.000000 11 0\n", + "1 Lithuania 0.000000 47 0\n", + "100 Antigua and Barbuda 0.023810 42 1\n", + "74 Czech Republic 0.024390 41 1\n", + "15 Netherlands 0.029851 67 2\n", + "105 Sudan 0.045455 66 3\n", + "writing file\n", + "iteration 4\n", + "(8100, 381) (8100,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "60 Chad 0.727273 11 8\n", + "136 Botswana 0.630952 84 53\n", + "72 Ivory Coast 0.571429 14 8\n", + "62 Fiji 0.533333 15 8\n", + "86 Gambia 0.520833 48 25\n", + "43 Benin 0.500000 26 13\n", + "20 Pakistan 0.468085 94 44\n", + "135 French Guiana 0.464286 28 13\n", + "64 Mozambique 0.441176 34 15\n", + "51 Western Sahara 0.439024 82 36\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "1 Lithuania 0.000000 47 0\n", + "27 South Korea 0.000000 11 0\n", + "113 Iceland 0.000000 14 0\n", + "119 Denmark 0.000000 16 0\n", + "15 Netherlands 0.014925 67 1\n", + "31 Czech Republic 0.024390 41 1\n", + "112 Israel 0.030000 100 3\n", + "30 Afghanistan 0.041667 24 1\n", + "134 Paraguay 0.043478 23 1\n", + "105 Sudan 0.045455 66 3\n", + "writing file\n", + "iteration 5\n", + "(8101, 380) (8101,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "136 Botswana 0.607143 84 51\n", + "72 Ivory Coast 0.571429 14 8\n", + "21 Pakistan 0.553191 94 52\n", + "95 Chad 0.545455 11 6\n", + "63 Fiji 0.533333 15 8\n", + "86 Gambia 0.520833 48 25\n", + "44 Benin 0.500000 26 13\n", + "78 El Salvador 0.454545 33 15\n", + "117 Zimbabwe 0.428571 
14 6\n", + "66 Uganda 0.418605 86 36\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "119 Denmark 0.000000 16 0\n", + "1 Lithuania 0.000000 47 0\n", + "28 South Korea 0.000000 11 0\n", + "113 Iceland 0.000000 14 0\n", + "32 Czech Republic 0.024390 41 1\n", + "16 Netherlands 0.029851 67 2\n", + "31 Afghanistan 0.041667 24 1\n", + "134 Paraguay 0.043478 23 1\n", + "120 Kazakhstan 0.045455 88 4\n", + "105 Sudan 0.045455 66 3\n", + "writing file\n", + "iteration 6\n", + "(8110, 380) (8110,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "136 Botswana 0.574468 94 54\n", + "32 Ivory Coast 0.571429 14 8\n", + "86 Gambia 0.520833 48 25\n", + "21 Pakistan 0.516854 89 46\n", + "62 Fiji 0.466667 15 7\n", + "43 Benin 0.461538 26 12\n", + "95 Chad 0.454545 11 5\n", + "78 El Salvador 0.454545 33 15\n", + "51 Western Sahara 0.439024 82 36\n", + "63 Senegal 0.405405 37 15\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "1 Lithuania 0.000000 47 0\n", + "119 Denmark 0.000000 16 0\n", + "28 South Korea 0.000000 11 0\n", + "113 Iceland 0.000000 14 0\n", + "16 Netherlands 0.014925 67 1\n", + "74 Czech Republic 0.024390 41 1\n", + "13 Germany 0.040000 100 4\n", + "31 Afghanistan 0.041667 24 1\n", + "105 Sudan 0.045455 66 3\n", + "120 Kazakhstan 0.045455 88 4\n", + "writing file\n", + "iteration 7\n", + "(8048, 381) (8048,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "136 Botswana 0.636364 88 56\n", + "95 Chad 0.636364 11 7\n", + "86 Gambia 0.511111 45 23\n", + "42 Benin 0.500000 26 13\n", + "14 Liberia 0.500000 40 20\n", + "63 Mozambique 0.500000 34 17\n", + "78 El Salvador 0.424242 33 14\n", + "62 Senegal 0.416667 36 15\n", + "20 Pakistan 0.415730 89 37\n", + "106 Nepal 0.402174 92 37\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "1 Lithuania 0.000000 47 0\n", + "119 Denmark 0.000000 16 
0\n", + "113 Iceland 0.000000 14 0\n", + "27 South Korea 0.000000 11 0\n", + "15 Netherlands 0.015152 66 1\n", + "120 Kazakhstan 0.034884 86 3\n", + "30 Afghanistan 0.041667 24 1\n", + "102 Nicaragua 0.050000 20 1\n", + "112 Israel 0.050000 100 5\n", + "28 Tajikistan 0.052632 19 1\n", + "writing file\n", + "iteration 8\n", + "(8012, 380) (8012,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "95 Chad 0.636364 11 7\n", + "43 Benin 0.576923 26 15\n", + "136 Botswana 0.571429 77 44\n", + "14 Liberia 0.525000 40 21\n", + "86 Gambia 0.488889 45 22\n", + "78 El Salvador 0.484848 33 16\n", + "64 Mozambique 0.470588 34 16\n", + "62 Fiji 0.466667 15 7\n", + "20 Pakistan 0.436782 87 38\n", + "63 Senegal 0.416667 36 15\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "1 Lithuania 0.000000 47 0\n", + "119 Denmark 0.000000 16 0\n", + "113 Iceland 0.000000 14 0\n", + "27 South Korea 0.000000 11 0\n", + "102 Nicaragua 0.000000 20 0\n", + "28 Tajikistan 0.000000 19 0\n", + "15 Netherlands 0.015152 66 1\n", + "89 Croatia 0.032258 31 1\n", + "120 Kazakhstan 0.034884 86 3\n", + "30 Afghanistan 0.041667 24 1\n", + "writing file\n", + "iteration 9\n", + "(8032, 380) (8032,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "43 Benin 0.576923 26 15\n", + "136 Botswana 0.567901 81 46\n", + "60 Chad 0.545455 11 6\n", + "86 Gambia 0.533333 45 24\n", + "14 Liberia 0.525000 40 21\n", + "65 Uganda 0.482759 87 42\n", + "64 Mozambique 0.470588 34 16\n", + "20 Pakistan 0.465909 88 41\n", + "135 French Guiana 0.464286 28 13\n", + "67 Brazil 0.460000 100 46\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "1 Lithuania 0.000000 47 0\n", + "90 French Polynesia 0.000000 15 0\n", + "102 Nicaragua 0.000000 20 0\n", + "113 Iceland 0.000000 14 0\n", + "119 Denmark 0.000000 16 0\n", + "15 Netherlands 0.015152 66 1\n", + "18 New Zealand 0.029412 34 
1\n", + "120 Kazakhstan 0.034884 86 3\n", + "31 Czech Republic 0.048780 41 2\n", + "28 Tajikistan 0.052632 19 1\n", + "writing file\n" + ] + } + ], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "\n", + "n_iters = 10\n", + "OUTPUT_FILES = load_dataset.OUTPUT_FILES\n", + "MAPPER_OUTPUT_FILES = mapper.OUTPUT_FILES\n", + "for n in range(n_iters):\n", + " print \"iteration %d\" % n\n", + " CLASS_INPUT_FILES = [output_file.split('.pickle')[0]+'_'+str(n)+'.pickle' for \n", + " output_file in MAPPER_OUTPUT_FILES]\n", + " mapper.INPUT_FILES = [output_file.split('.pickle')[0]+'_'+str(n)+'.pickle' for \n", + " output_file in OUTPUT_FILES]\n", + " X, Y, Yaudio = classification.load_data_from_pickle(CLASS_INPUT_FILES[0])\n", + " # get only 80% of the dataset.. to vary the choice of outliers\n", + " X, _, Y, _ = train_test_split(X, Y, train_size=0.8, stratify=Y)\n", + " print X.shape, Y.shape\n", + " # outliers\n", + " print \"detecting outliers...\"\n", + " #ddf = outliers.load_metadata(Yaudio, metadata_file=load_dataset.METADATA_FILE)\n", + " df_global, threshold, MD = outliers.get_outliers_df(X, Y, chi2thr=0.999)\n", + " outliers.print_most_least_outliers_topN(df_global, N=10)\n", + " \n", + " # write output\n", + " print \"writing file\"\n", + " df_global.to_csv('../data/outliers_'+str(n)+'.csv', index=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's try without changing the LDA mapping, so just load the original dataset and get outlier countries by selecting 80% of the recordigns (in stratified manner).'" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "iteration 0\n", + "/import/c4dm-04/mariap/lda_data_melodia_8.pickle\n", + "(6560, 380) (6560,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "95 Chad 0.555556 9 5\n", 
+ "86 Gambia 0.525000 40 21\n", + "135 French Guiana 0.500000 22 11\n", + "44 Benin 0.476190 21 10\n", + "15 Liberia 0.468750 32 15\n", + "136 Botswana 0.458333 72 33\n", + "104 Bhutan 0.444444 9 4\n", + "68 Brazil 0.437500 80 35\n", + "92 Switzerland 0.428571 42 18\n", + "78 El Salvador 0.423077 26 11\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "1 Lithuania 0.000000 38 0\n", + "29 Tajikistan 0.000000 15 0\n", + "32 Czech Republic 0.000000 33 0\n", + "107 Kiribati 0.000000 14 0\n", + "120 Kazakhstan 0.000000 70 0\n", + "119 Denmark 0.000000 13 0\n", + "0 Canada 0.050000 80 4\n", + "73 Nigeria 0.051948 77 4\n", + "109 Democratic Republic of the Congo 0.052632 38 2\n", + "105 Sudan 0.055556 54 3\n", + "writing file\n", + "iteration 1\n", + "/import/c4dm-04/mariap/lda_data_melodia_8.pickle\n", + "(6560, 380) (6560,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "95 Chad 0.666667 9 6\n", + "17 French Guiana 0.545455 22 12\n", + "86 Gambia 0.525000 40 21\n", + "44 Benin 0.523810 21 11\n", + "6 Bolivia 0.500000 28 14\n", + "78 El Salvador 0.500000 26 13\n", + "136 Botswana 0.486111 72 35\n", + "10 Guatemala 0.465116 43 20\n", + "115 Senegal 0.454545 33 15\n", + "104 Bhutan 0.444444 9 4\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "120 Kazakhstan 0.000000 70 0\n", + "1 Lithuania 0.000000 38 0\n", + "107 Kiribati 0.000000 14 0\n", + "119 Denmark 0.000000 13 0\n", + "9 Saudi Arabia 0.000000 8 0\n", + "98 Uzbekistan 0.030303 33 1\n", + "15 Netherlands 0.037037 54 2\n", + "57 Russia 0.037975 79 3\n", + "109 Democratic Republic of the Congo 0.052632 38 2\n", + "105 Sudan 0.055556 54 3\n", + "writing file\n", + "iteration 2\n", + "/import/c4dm-04/mariap/lda_data_melodia_8.pickle\n", + "(6560, 380) (6560,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "95 Chad 0.666667 9 6\n", + "104 Bhutan 0.555556 9 5\n", 
+ "86 Gambia 0.550000 40 22\n", + "135 French Guiana 0.545455 22 12\n", + "78 El Salvador 0.538462 26 14\n", + "43 Benin 0.523810 21 11\n", + "6 Bolivia 0.500000 28 14\n", + "136 Botswana 0.486111 72 35\n", + "64 Mozambique 0.444444 27 12\n", + "14 Liberia 0.437500 32 14\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "1 Lithuania 0.000000 38 0\n", + "107 Kiribati 0.000000 14 0\n", + "119 Denmark 0.000000 13 0\n", + "120 Kazakhstan 0.000000 70 0\n", + "15 Netherlands 0.018519 54 1\n", + "105 Sudan 0.037037 54 2\n", + "0 Canada 0.050000 80 4\n", + "109 Democratic Republic of the Congo 0.052632 38 2\n", + "94 Iraq 0.057971 69 4\n", + "31 Czech Republic 0.060606 33 2\n", + "writing file\n", + "iteration 3\n", + "/import/c4dm-04/mariap/lda_data_melodia_8.pickle\n", + "(6560, 380) (6560,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "60 Chad 0.666667 9 6\n", + "17 French Guiana 0.590909 22 13\n", + "86 Gambia 0.550000 40 22\n", + "6 Bolivia 0.535714 28 15\n", + "136 Botswana 0.513889 72 37\n", + "64 Mozambique 0.481481 27 13\n", + "14 Liberia 0.468750 32 15\n", + "78 El Salvador 0.461538 26 12\n", + "115 Senegal 0.454545 33 15\n", + "108 Malta 0.437500 16 7\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "120 Kazakhstan 0.000000 70 0\n", + "1 Lithuania 0.000000 38 0\n", + "30 Afghanistan 0.000000 19 0\n", + "119 Denmark 0.000000 13 0\n", + "107 Kiribati 0.000000 14 0\n", + "31 Czech Republic 0.030303 33 1\n", + "98 Uzbekistan 0.030303 33 1\n", + "15 Netherlands 0.037037 54 2\n", + "105 Sudan 0.037037 54 2\n", + "84 Iraq 0.042857 70 3\n", + "writing file\n", + "iteration 4\n", + "/import/c4dm-04/mariap/lda_data_melodia_8.pickle\n", + "(6560, 380) (6560,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "117 Zimbabwe 0.583333 12 7\n", + "60 Chad 0.555556 9 5\n", + "86 Gambia 0.550000 40 22\n", + "43 Benin 
0.523810 21 11\n", + "6 Bolivia 0.500000 28 14\n", + "135 French Guiana 0.500000 22 11\n", + "136 Botswana 0.472222 72 34\n", + "78 El Salvador 0.461538 26 12\n", + "10 Guatemala 0.441860 43 19\n", + "14 Liberia 0.437500 32 14\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "1 Lithuania 0.000000 38 0\n", + "107 Kiribati 0.000000 14 0\n", + "119 Denmark 0.000000 13 0\n", + "120 Kazakhstan 0.000000 70 0\n", + "27 South Korea 0.000000 9 0\n", + "109 Democratic Republic of the Congo 0.026316 38 1\n", + "94 Iraq 0.028571 70 2\n", + "31 Czech Republic 0.030303 33 1\n", + "105 Sudan 0.037037 54 2\n", + "85 Sierra Leone 0.050000 80 4\n", + "writing file\n", + "iteration 5\n", + "/import/c4dm-04/mariap/lda_data_melodia_8.pickle\n", + "(6560, 380) (6560,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "61 Chad 0.666667 9 6\n", + "44 Benin 0.619048 21 13\n", + "104 Bhutan 0.555556 9 5\n", + "18 French Guiana 0.545455 22 12\n", + "86 Gambia 0.525000 40 21\n", + "136 Botswana 0.500000 72 36\n", + "117 Zimbabwe 0.500000 12 6\n", + "15 Liberia 0.500000 32 16\n", + "64 Senegal 0.484848 33 16\n", + "78 El Salvador 0.461538 26 12\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "1 Lithuania 0.000000 38 0\n", + "120 Kazakhstan 0.000000 70 0\n", + "119 Denmark 0.000000 13 0\n", + "107 Kiribati 0.000000 14 0\n", + "9 Saudi Arabia 0.000000 8 0\n", + "0 Canada 0.025000 80 2\n", + "57 Russia 0.050633 79 4\n", + "109 Democratic Republic of the Congo 0.052632 38 2\n", + "51 Finland 0.052632 19 1\n", + "105 Sudan 0.055556 54 3\n", + "writing file\n", + "iteration 6\n", + "/import/c4dm-04/mariap/lda_data_melodia_8.pickle\n", + "(6560, 380) (6560,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "60 Chad 0.666667 9 6\n", + "17 French Guiana 0.590909 22 13\n", + "117 Zimbabwe 0.583333 12 7\n", + "86 Gambia 0.575000 40 23\n", + "78 El 
Salvador 0.538462 26 14\n", + "43 Benin 0.523810 21 11\n", + "115 Senegal 0.515152 33 17\n", + "136 Botswana 0.472222 72 34\n", + "104 Bhutan 0.444444 9 4\n", + "84 Belize 0.441176 34 15\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "1 Lithuania 0.000000 38 0\n", + "107 Kiribati 0.000000 14 0\n", + "113 Iceland 0.000000 11 0\n", + "72 Ivory Coast 0.000000 12 0\n", + "119 Denmark 0.000000 13 0\n", + "120 Kazakhstan 0.000000 70 0\n", + "28 Tajikistan 0.000000 15 0\n", + "105 Sudan 0.018519 54 1\n", + "15 Netherlands 0.018519 54 1\n", + "109 Democratic Republic of the Congo 0.026316 38 1\n", + "writing file\n", + "iteration 7\n", + "/import/c4dm-04/mariap/lda_data_melodia_8.pickle\n", + "(6560, 380) (6560,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "95 Chad 0.555556 9 5\n", + "86 Gambia 0.525000 40 21\n", + "43 Benin 0.523810 21 11\n", + "135 French Guiana 0.500000 22 11\n", + "63 Senegal 0.484848 33 16\n", + "14 Liberia 0.468750 32 15\n", + "52 Indonesia 0.437500 80 35\n", + "136 Botswana 0.430556 72 31\n", + "6 Bolivia 0.428571 28 12\n", + "92 Switzerland 0.428571 42 18\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "119 Denmark 0.000000 13 0\n", + "1 Lithuania 0.000000 38 0\n", + "107 Kiribati 0.000000 14 0\n", + "120 Kazakhstan 0.000000 70 0\n", + "113 Iceland 0.000000 11 0\n", + "94 Iraq 0.028571 70 2\n", + "98 Uzbekistan 0.030303 33 1\n", + "105 Sudan 0.037037 54 2\n", + "85 Sierra Leone 0.037500 80 3\n", + "109 Democratic Republic of the Congo 0.052632 38 2\n", + "writing file\n", + "iteration 8\n", + "/import/c4dm-04/mariap/lda_data_melodia_8.pickle\n", + "(6560, 380) (6560,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "61 Chad 0.666667 9 6\n", + "78 El Salvador 0.576923 26 15\n", + "44 Benin 0.571429 21 12\n", + "104 Bhutan 0.555556 9 5\n", + "86 Gambia 0.550000 40 22\n", + "17 French 
Guiana 0.545455 22 12\n", + "94 Belize 0.470588 34 16\n", + "14 Liberia 0.468750 32 15\n", + "92 Switzerland 0.452381 42 19\n", + "53 Indonesia 0.450000 80 36\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "119 Denmark 0.000000 13 0\n", + "1 Lithuania 0.000000 38 0\n", + "120 Kazakhstan 0.000000 70 0\n", + "107 Kiribati 0.000000 14 0\n", + "98 Uzbekistan 0.030303 33 1\n", + "105 Sudan 0.037037 54 2\n", + "15 Netherlands 0.037037 54 2\n", + "85 Sierra Leone 0.037500 80 3\n", + "84 Iraq 0.042857 70 3\n", + "109 Democratic Republic of the Congo 0.052632 38 2\n", + "writing file\n", + "iteration 9\n", + "/import/c4dm-04/mariap/lda_data_melodia_8.pickle\n", + "(6560, 380) (6560,)\n", + "detecting outliers...\n", + "most outliers \n", + " Country Outliers N_Country N_Outliers\n", + "95 Chad 0.555556 9 5\n", + "104 Bhutan 0.555556 9 5\n", + "86 Gambia 0.550000 40 22\n", + "78 El Salvador 0.538462 26 14\n", + "18 French Guiana 0.500000 22 11\n", + "115 Senegal 0.484848 33 16\n", + "44 Benin 0.476190 21 10\n", + "41 Laos 0.470588 17 8\n", + "6 Bolivia 0.464286 28 13\n", + "65 Mozambique 0.444444 27 12\n", + "least outliers \n", + " Country Outliers N_Country N_Outliers\n", + "119 Denmark 0.000000 13 0\n", + "1 Lithuania 0.000000 38 0\n", + "120 Kazakhstan 0.000000 70 0\n", + "107 Kiribati 0.000000 14 0\n", + "32 Czech Republic 0.000000 33 0\n", + "85 Sierra Leone 0.050000 80 4\n", + "0 Canada 0.050000 80 4\n", + "109 Democratic Republic of the Congo 0.052632 38 2\n", + "105 Sudan 0.055556 54 3\n", + "16 Netherlands 0.055556 54 3\n", + "writing file\n" + ] + } + ], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "\n", + "n_iters = 10\n", + "for n in range(n_iters):\n", + " print \"iteration %d\" % n\n", + " results_file = mapper.OUTPUT_FILES[0]\n", + " print results_file\n", + " X, Y, Yaudio = classification.load_data_from_pickle(results_file)\n", + " # get only 80% of the dataset.. 
to vary the choice of outliers\n", + " X, _, Y, _ = train_test_split(X, Y, train_size=0.8, stratify=Y)\n", + " print X.shape, Y.shape\n", " # outliers\n", " print \"detecting outliers...\"\n", " #ddf = outliers.load_metadata(Yaudio, metadata_file=load_dataset.METADATA_FILE)\n", @@ -4946,7 +5613,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 68, "metadata": { "collapsed": true }, @@ -4963,16 +5630,18 @@ }, { "cell_type": "code", - "execution_count": 49, - "metadata": {}, + "execution_count": 69, + "metadata": { + "collapsed": false + }, "outputs": [ { "data": { "text/plain": [ - "(133, 10)" + "(137, 10)" ] }, - "execution_count": 49, + "execution_count": 69, "metadata": {}, "output_type": "execute_result" } @@ -4990,39 +5659,41 @@ }, { "cell_type": "code", - "execution_count": 48, - "metadata": {}, + "execution_count": 70, + "metadata": { + "collapsed": false + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " Country Country Country Country Country Country \\\n", - "0 Botswana Chad Botswana Botswana Chad Botswana \n", - "1 Ivory Coast Fiji Gambia Ivory Coast Botswana Ivory Coast \n", - "2 Gambia Gambia Ivory Coast Gambia Ivory Coast Pakistan \n", - "3 Benin Benin Fiji Benin Fiji Chad \n", - "4 Fiji Pakistan Benin Fiji Gambia Fiji \n", + " Country Country Country Country Country \\\n", + "0 Chad Chad Chad Chad Zimbabwe \n", + "1 Gambia French Guiana Bhutan French Guiana Chad \n", + "2 French Guiana Gambia Gambia Gambia Gambia \n", + "3 Benin Benin French Guiana Bolivia Benin \n", + "4 Liberia Bolivia El Salvador Botswana Bolivia \n", "\n", - " Country Country Country Country \n", - "0 Botswana Botswana Chad Benin \n", - "1 Ivory Coast Chad Benin Botswana \n", - "2 Gambia Gambia Botswana Chad \n", - "3 Pakistan Mozambique Liberia Gambia \n", - "4 Fiji Benin Gambia Liberia \n", + " Country Country Country Country Country \n", + "0 Chad Chad Chad Chad Bhutan \n", + "1 Benin French Guiana Gambia El Salvador Chad 
\n", + "2 Bhutan Zimbabwe Benin Benin Gambia \n", + "3 French Guiana Gambia French Guiana Bhutan El Salvador \n", + "4 Gambia El Salvador Senegal Gambia French Guiana \n", " Outliers Outliers Outliers Outliers Outliers Outliers Outliers \\\n", - "0 0.590909 0.545455 0.615385 0.617284 0.727273 0.607143 0.574468 \n", - "1 0.571429 0.533333 0.520833 0.571429 0.630952 0.571429 0.571429 \n", - "2 0.541667 0.520833 0.500000 0.541667 0.571429 0.553191 0.520833 \n", - "3 0.538462 0.500000 0.466667 0.538462 0.533333 0.545455 0.516854 \n", - "4 0.466667 0.500000 0.461538 0.533333 0.520833 0.533333 0.466667 \n", + "0 0.555556 0.666667 0.666667 0.666667 0.583333 0.666667 0.666667 \n", + "1 0.525000 0.545455 0.555556 0.590909 0.555556 0.619048 0.590909 \n", + "2 0.500000 0.525000 0.550000 0.550000 0.550000 0.555556 0.583333 \n", + "3 0.476190 0.523810 0.545455 0.535714 0.523810 0.545455 0.575000 \n", + "4 0.468750 0.500000 0.538462 0.513889 0.500000 0.525000 0.538462 \n", "\n", " Outliers Outliers Outliers \n", - "0 0.636364 0.636364 0.576923 \n", - "1 0.636364 0.576923 0.567901 \n", - "2 0.511111 0.571429 0.545455 \n", - "3 0.500000 0.525000 0.533333 \n", - "4 0.500000 0.488889 0.525000 \n" + "0 0.555556 0.666667 0.555556 \n", + "1 0.525000 0.576923 0.555556 \n", + "2 0.523810 0.571429 0.550000 \n", + "3 0.500000 0.555556 0.538462 \n", + "4 0.484848 0.550000 0.500000 \n" ] } ], @@ -5045,7 +5716,27 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 71, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from scipy.stats import kendalltau\n", + "r_, p_ = [], []\n", + "ranked_countries_arr = ranked_countries.get_values()\n", + "for i in range(n_iters-1):\n", + " for j in range(i+1, n_iters):\n", + " r, p = kendalltau(ranked_countries_arr[:, i], ranked_countries_arr[:, j])\n", + " r_.append(r)\n", + " p_.append(p)\n", + "r_ = np.array(r_)\n", + "p_ = np.array(p_)" + ] + }, + { + "cell_type": "code", + "execution_count": 72, 
"metadata": { "collapsed": false }, @@ -5054,64 +5745,85 @@ "name": "stdout", "output_type": "stream", "text": [ - "KendalltauResult(correlation=0.11870585554796083, pvalue=0.042684955693776824)\n", - "KendalltauResult(correlation=0.061289587605377081, pvalue=0.29535042403787393)\n", - "KendalltauResult(correlation=0.14057871952608797, pvalue=0.016384498702657929)\n", - "KendalltauResult(correlation=0.043062200956937809, pvalue=0.46219181347134564)\n", - "KendalltauResult(correlation=0.038049669628617007, pvalue=0.51591269004232343)\n", - "KendalltauResult(correlation=0.15516062884483939, pvalue=0.0080680863973824919)\n", - "KendalltauResult(correlation=0.097972203235361141, pvalue=0.094371801845320874)\n", - "KendalltauResult(correlation=0.070403280929596718, pvalue=0.22933906132681292)\n", - "KendalltauResult(correlation=0.087263613579403057, pvalue=0.13624109595088119)\n", - "KendalltauResult(correlation=0.026657552973342449, pvalue=0.64900123852931668)\n", - "KendalltauResult(correlation=0.012531328320802006, pvalue=0.83057867073317604)\n", - "KendalltauResult(correlation=0.15698336750968331, pvalue=0.0073549938316186895)\n", - "KendalltauResult(correlation=0.072226019594440652, pvalue=0.21750692637496993)\n", - "KendalltauResult(correlation=0.064479380268853956, pvalue=0.27093205134080134)\n", - "KendalltauResult(correlation=0.07518796992481204, pvalue=0.19922707586147026)\n", - "KendalltauResult(correlation=0.017088174982911826, pvalue=0.77046791234681555)\n", - "KendalltauResult(correlation=0.098200045568466648, pvalue=0.093608177106345392)\n", - "KendalltauResult(correlation=0.11004784688995217, pvalue=0.060250899787989511)\n", - "KendalltauResult(correlation=0.051720209614946465, pvalue=0.37719896672100306)\n", - "KendalltauResult(correlation=0.099567099567099596, pvalue=0.089129953079656793)\n", - "KendalltauResult(correlation=-0.081795397584871282, pvalue=0.16254238954046385)\n", - "KendalltauResult(correlation=0.089769879243563472, 
pvalue=0.12534294310051713)\n", - "KendalltauResult(correlation=0.10047846889952156, pvalue=0.086241531926005505)\n", - "KendalltauResult(correlation=0.014809751651856917, pvalue=0.80037548797424396)\n", - "KendalltauResult(correlation=0.021189336978810668, pvalue=0.71751195692767422)\n", - "KendalltauResult(correlation=0.020733652312599684, pvalue=0.7233346465763022)\n", - "KendalltauResult(correlation=-0.057644110275689227, pvalue=0.32501053989276085)\n", - "KendalltauResult(correlation=0.04647983595352017, pvalue=0.42743119135699703)\n", - "KendalltauResult(correlation=-0.02939166097060834, pvalue=0.6157855679677966)\n", - "KendalltauResult(correlation=-0.01754385964912281, pvalue=0.76452558103925983)\n", - "KendalltauResult(correlation=-0.00022784233310549102, pvalue=0.99689609964041026)\n", - "KendalltauResult(correlation=0.053087263613579412, pvalue=0.36471883993264553)\n", - "KendalltauResult(correlation=0.11027568922305765, pvalue=0.059721613251292195)\n", - "KendalltauResult(correlation=0.1319207108680793, pvalue=0.024296399889465414)\n", - "KendalltauResult(correlation=0.11050353155616316, pvalue=0.059196189350124301)\n", - "KendalltauResult(correlation=0.081339712918660295, pvalue=0.16489618845189757)\n", - "KendalltauResult(correlation=0.091136933242196419, pvalue=0.11969173188443738)\n", - "KendalltauResult(correlation=0.010252904989747097, pvalue=0.86103426355600943)\n", - "KendalltauResult(correlation=0.026201868307131469, pvalue=0.6546080905364744)\n", - "KendalltauResult(correlation=0.056049213943950793, pvalue=0.33857618122131272)\n", - "KendalltauResult(correlation=0.075415812257917533, pvalue=0.19786889281527764)\n", - "KendalltauResult(correlation=0.026657552973342449, pvalue=0.64900123852931668)\n", - "KendalltauResult(correlation=0.091136933242196419, pvalue=0.11969173188443738)\n", - "KendalltauResult(correlation=0.1964000911369333, pvalue=0.00079845943724486494)\n", - "KendalltauResult(correlation=0.049441786283891551, 
pvalue=0.39857590952666144)\n" + "0.0493253335359 0.410409379365\n" ] } ], "source": [ - "from scipy.stats import kendalltau\n", - "for i in range(n_iters-1):\n", - " for j in range(i+1, n_iters):\n", - " print kendalltau(ranked_countries.iloc[:, i], ranked_countries.iloc[:, j])" + "print np.mean(r_), np.mean(p_)" ] }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 80, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.240026302342 0.351418392739\n" + ] + } + ], + "source": [ + "from scipy.stats import spearmanr\n", + "r, p = spearmanr(ranked_countries_arr)\n", + "# strict upper triangle (k=1): each pair once, skipping the diagonal self-correlations (r=1) that inflate the mean\n", + "upper_idx = np.triu_indices(len(r), k=1)\n", + "r, p = r[upper_idx], p[upper_idx]\n", + "print np.mean(r), np.mean(p)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "let's focus only on the top K results" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.237245179063 0.417925582965\n" + ] + } + ], + "source": [ + "k=10\n", + "r, p = spearmanr(ranked_countries_arr[:k, :])\n", + "# strict upper triangle (k=1): each pair once, skipping the diagonal self-correlations (r=1) that inflate the mean\n", + "upper_idx = np.triu_indices(len(r), k=1)\n", + "r, p = r[upper_idx], p[upper_idx]\n", + "print np.mean(r), np.mean(p)" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "common_set = set(np.unique(Y))\n", + "for i in range(ranked_countries_arr.shape[1]):\n", + " common_set = common_set & set(ranked_countries_arr[:k, i])" + ] + }, + { + "cell_type": "code", + "execution_count": 76, "metadata": { "collapsed": false }, @@ -5119,87 +5831,280 @@ { "data": { "text/plain": [ - "133" + "{'Chad', 'French Guiana', 'Gambia'}" ] }, - "execution_count": 53, + "execution_count": 76, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ - "len(ranked_countries)" + "common_set" ] }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 97, "metadata": { "collapsed": true }, "outputs": [], "source": [ - "from scipy.stats import spearmanr\n", - "r, p = spearmanr(ranked_countries)" + "# majority voting + precision at K (top5?)\n", + "from collections import Counter\n", + "K_vote = 10\n", + "country_vote = Counter(ranked_countries_arr[:K_vote, :].ravel())" ] }, { "cell_type": "code", - "execution_count": 58, - "metadata": {}, + "execution_count": 98, + "metadata": { + "collapsed": false + }, "outputs": [ { "data": { + "text/html": [ + "<div>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>index</th>\n", + " <th>0</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>Brazil</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>Liberia</td>\n", + " <td>7</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>Belize</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>Chad</td>\n", + " <td>10</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>Bhutan</td>\n", + " <td>7</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], "text/plain": [ - "array([[ 1.00000000e+00, 1.74432009e-01, 8.97001663e-02,\n", - " 1.99727609e-01, 6.82200753e-02, 5.39272197e-02,\n", - " 2.21325022e-01, 1.33629528e-01, 1.08109487e-01,\n", - " 1.31114761e-01],\n", - " [ 1.74432009e-01, 1.00000000e+00, 4.20573142e-02,\n", - " 2.07251507e-02, 2.28481652e-01, 1.01916936e-01,\n", - " 1.01442548e-01, 1.12532008e-01, 1.89806266e-02,\n", - " 1.48213138e-01],\n", - " [ 8.97001663e-02, 4.20573142e-02, 1.00000000e+00,\n", - " 1.53308985e-01, 7.91412044e-02, 1.41734934e-01,\n", - " -1.14419359e-01, 1.23519450e-01, 1.50641189e-01,\n", - " 3.17074913e-02],\n", - " [ 
1.99727609e-01, 2.07251507e-02, 1.53308985e-01,\n", - " 1.00000000e+00, 3.04934657e-02, 3.27786903e-02,\n", - " -7.58255884e-02, 6.98727824e-02, -4.16900460e-02,\n", - " -2.15208986e-02],\n", - " [ 6.82200753e-02, 2.28481652e-01, 7.91412044e-02,\n", - " 3.04934657e-02, 1.00000000e+00, -8.00848798e-04,\n", - " 8.02532110e-02, 1.65796105e-01, 1.91678314e-01,\n", - " 1.62863060e-01],\n", - " [ 5.39272197e-02, 1.01916936e-01, 1.41734934e-01,\n", - " 3.27786903e-02, -8.00848798e-04, 1.00000000e+00,\n", - " 1.17969619e-01, 1.31221881e-01, 2.06996460e-02,\n", - " 3.92160863e-02],\n", - " [ 2.21325022e-01, 1.01442548e-01, -1.14419359e-01,\n", - " -7.58255884e-02, 8.02532110e-02, 1.17969619e-01,\n", - " 1.00000000e+00, 8.75832730e-02, 1.10578345e-01,\n", - " 4.28326583e-02],\n", - " [ 1.33629528e-01, 1.12532008e-01, 1.23519450e-01,\n", - " 6.98727824e-02, 1.65796105e-01, 1.31221881e-01,\n", - " 8.75832730e-02, 1.00000000e+00, 1.31374909e-01,\n", - " 2.78868814e-01],\n", - " [ 1.08109487e-01, 1.89806266e-02, 1.50641189e-01,\n", - " -4.16900460e-02, 1.91678314e-01, 2.06996460e-02,\n", - " 1.10578345e-01, 1.31374909e-01, 1.00000000e+00,\n", - " 7.53103927e-02],\n", - " [ 1.31114761e-01, 1.48213138e-01, 3.17074913e-02,\n", - " -2.15208986e-02, 1.62863060e-01, 3.92160863e-02,\n", - " 4.28326583e-02, 2.78868814e-01, 7.53103927e-02,\n", - " 1.00000000e+00]])" + " index 0\n", + "0 Brazil 1\n", + "1 Liberia 7\n", + "2 Belize 2\n", + "3 Chad 10\n", + "4 Bhutan 7" ] }, - "execution_count": 58, + "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "r" + "df_country_vote = pd.DataFrame.from_dict(country_vote, orient='index').reset_index()\n", + "df_country_vote.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " 
<th></th>\n", + " <th>index</th>\n", + " <th>0</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>Chad</td>\n", + " <td>10</td>\n", + " </tr>\n", + " <tr>\n", + " <th>6</th>\n", + " <td>Gambia</td>\n", + " <td>10</td>\n", + " </tr>\n", + " <tr>\n", + " <th>12</th>\n", + " <td>French Guiana</td>\n", + " <td>10</td>\n", + " </tr>\n", + " <tr>\n", + " <th>18</th>\n", + " <td>Benin</td>\n", + " <td>9</td>\n", + " </tr>\n", + " <tr>\n", + " <th>5</th>\n", + " <td>El Salvador</td>\n", + " <td>9</td>\n", + " </tr>\n", + " <tr>\n", + " <th>17</th>\n", + " <td>Botswana</td>\n", + " <td>8</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>Bhutan</td>\n", + " <td>7</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>Liberia</td>\n", + " <td>7</td>\n", + " </tr>\n", + " <tr>\n", + " <th>16</th>\n", + " <td>Bolivia</td>\n", + " <td>6</td>\n", + " </tr>\n", + " <tr>\n", + " <th>10</th>\n", + " <td>Senegal</td>\n", + " <td>6</td>\n", + " </tr>\n", + " <tr>\n", + " <th>11</th>\n", + " <td>Zimbabwe</td>\n", + " <td>3</td>\n", + " </tr>\n", + " <tr>\n", + " <th>14</th>\n", + " <td>Switzerland</td>\n", + " <td>3</td>\n", + " </tr>\n", + " <tr>\n", + " <th>15</th>\n", + " <td>Mozambique</td>\n", + " <td>3</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>Belize</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <th>7</th>\n", + " <td>Indonesia</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <th>8</th>\n", + " <td>Guatemala</td>\n", + " <td>2</td>\n", + " </tr>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>Brazil</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>13</th>\n", + " <td>Laos</td>\n", + " <td>1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>9</th>\n", + " <td>Malta</td>\n", + " <td>1</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " index 0\n", + "3 Chad 10\n", + "6 Gambia 10\n", + "12 French Guiana 10\n", + "18 Benin 9\n", + 
"5 El Salvador 9\n", + "17 Botswana 8\n", + "4 Bhutan 7\n", + "1 Liberia 7\n", + "16 Bolivia 6\n", + "10 Senegal 6\n", + "11 Zimbabwe 3\n", + "14 Switzerland 3\n", + "15 Mozambique 3\n", + "2 Belize 2\n", + "7 Indonesia 2\n", + "8 Guatemala 2\n", + "0 Brazil 1\n", + "13 Laos 1\n", + "9 Malta 1" + ] + }, + "execution_count": 99, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_country_vote.sort_values(0, ascending=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.51000000000000001" + ] + }, + "execution_count": 102, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def precision_at_k(array, gr_truth, k):  # fraction of the first k items of `array` that also occur in the first k items of `gr_truth`\n", + " return len(set(array[:k]) & set(gr_truth[:k])) / float(k)\n", + " \n", + "k = 10\n", + "ground_truth = df_country_vote.sort_values(0, ascending=False)['index'].get_values()  # rank by vote count first: the frame is in arbitrary Counter order, so its first k rows are not the top-k voted\n", + "p_ = []\n", + "for j in range(ranked_countries_arr.shape[1]):\n", + " p_.append(precision_at_k(ranked_countries_arr[:, j], ground_truth, k))\n", + "p_ = np.array(p_)\n", + "np.mean(p_)" ] }, {
--- a/notebooks/test_hubness.ipynb Fri Sep 22 18:02:59 2017 +0100 +++ b/notebooks/test_hubness.ipynb Tue Sep 26 12:40:07 2017 +0100 @@ -2,11 +2,20 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 16, "metadata": { - "collapsed": true + "collapsed": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "import numpy as np\n", "import pickle\n", @@ -28,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 17, "metadata": { "collapsed": true }, @@ -46,7 +55,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 18, "metadata": { "collapsed": false }, @@ -57,7 +66,7 @@ "(8200, 380)" ] }, - "execution_count": 4, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -81,12 +90,13 @@ }, "outputs": [], "source": [ - "D = pairwise_distances(X, metric='mahalanobis')" + "D = pairwise_distances(X, metric='mahalanobis')\n", + "np.savetxt('../data/D_mahal.csv', D)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 19, "metadata": { "collapsed": false }, @@ -97,13 +107,12 @@ "(8200, 8200)" ] }, - "execution_count": 5, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "np.savetxt('../data/D_mahal.csv', D)\n", "D = np.loadtxt('../data/D_mahal.csv')\n", "D.shape" ] @@ -238,8 +247,6 @@ ] }, { - "collapsed": false - }, "cell_type": "code", "execution_count": 16, "metadata": { @@ -653,7 +660,9 @@ { "cell_type": "code", "execution_count": 7, - "metadata": {}, + "metadata": { + "collapsed": false + }, "outputs": [ { "name": "stdout", @@ -693,7 +702,9 @@ { "cell_type": "code", "execution_count": 15, - "metadata": {}, + "metadata": { + "collapsed": false + }, "outputs": [ { "name": "stdout", @@ -724,7 +735,9 @@ { "cell_type": "code", 
"execution_count": 6, - "metadata": {}, + "metadata": { + "collapsed": false + }, "outputs": [ { "data": { @@ -746,7 +759,9 @@ { "cell_type": "code", "execution_count": 7, - "metadata": {}, + "metadata": { + "collapsed": false + }, "outputs": [ { "data": {