annotate notebooks/sensitivity_experiment_outliers.ipynb @ 105:edd82eb89b4b branch-tests tip

Merge
author Maria Panteli
date Sun, 15 Oct 2017 13:36:59 +0100
parents 69521f86d931
children
rev   line source
m@94 1 {
m@94 2 "cells": [
m@94 3 {
m@94 4 "cell_type": "code",
m@94 5 "execution_count": 6,
m@94 6 "metadata": {},
m@94 7 "outputs": [
m@94 8 {
m@94 9 "name": "stdout",
m@94 10 "output_type": "stream",
m@94 11 "text": [
m@94 12 "The autoreload extension is already loaded. To reload it, use:\n",
m@94 13 " %reload_ext autoreload\n"
m@94 14 ]
m@94 15 }
m@94 16 ],
m@94 17 "source": [
m@94 18 "import numpy as np\n",
m@94 19 "import pandas as pd\n",
m@94 20 "from sklearn.model_selection import train_test_split\n",
m@94 21 "from collections import Counter\n",
m@94 22 "\n",
m@94 23 "%matplotlib inline\n",
m@94 24 "import matplotlib.pyplot as plt\n",
m@94 25 "\n",
m@94 26 "%load_ext autoreload\n",
m@94 27 "%autoreload 2\n",
m@94 28 "\n",
m@94 29 "import sys\n",
m@94 30 "sys.path.append('../')\n",
m@94 31 "import scripts.classification as classification\n",
m@94 32 "import scripts.outliers as outliers"
m@94 33 ]
m@94 34 },
m@94 35 {
m@94 36 "cell_type": "markdown",
m@94 37 "metadata": {},
m@94 38 "source": [
m@94 39 "## Sample 80% of the dataset, for 10 times"
m@94 40 ]
m@94 41 },
m@94 42 {
m@94 43 "cell_type": "markdown",
m@94 44 "metadata": {},
m@94 45 "source": [
m@94 46 "Let's sample only 80% of the recordings each time (in a stratified manner) so that the set of recordings considered for each country is changed every time."
m@94 47 ]
m@94 48 },
m@94 49 {
m@94 50 "cell_type": "code",
m@94 51 "execution_count": null,
m@94 52 "metadata": {
m@94 53 "collapsed": true
m@94 54 },
m@94 55 "outputs": [],
m@94 56 "source": [
m@94 57 "results_file = '../data/lda_data_8.pickle'\n",
m@94 58 "n_iters = 10\n",
m@94 59 "for n in range(n_iters):\n",
m@94 60 " print \"iteration %d\" % n\n",
m@94 61 " print results_file\n",
m@94 62 " X, Y, Yaudio = classification.load_data_from_pickle(results_file)\n",
m@94 63 " # get only 80% of the dataset.. to vary the choice of outliers\n",
m@94 64 " X, _, Y, _ = train_test_split(X, Y, train_size=0.8, stratify=Y)\n",
m@94 65 " print X.shape, Y.shape\n",
m@94 66 " # outliers\n",
m@94 67 " print \"detecting outliers...\"\n",
m@94 68 " df_global, threshold, MD = outliers.get_outliers_df(X, Y, chi2thr=0.999)\n",
m@94 69 " outliers.print_most_least_outliers_topN(df_global, N=10)\n",
m@94 70 " \n",
m@94 71 " # write output\n",
m@94 72 " print \"writing file\"\n",
m@94 73 " df_global.to_csv('../data/outliers_'+str(n)+'.csv', index=False)"
m@94 74 ]
m@94 75 },
m@94 76 {
m@94 77 "cell_type": "code",
m@94 78 "execution_count": 3,
m@94 79 "metadata": {},
m@94 80 "outputs": [],
m@94 81 "source": [
m@94 82 "n_iters = 10\n",
m@94 83 "ranked_countries = pd.DataFrame()\n",
m@94 84 "ranked_outliers = pd.DataFrame()\n",
m@94 85 "for n in range(n_iters):\n",
m@94 86 " df_global = pd.read_csv('../data/outliers_'+str(n)+'.csv')\n",
m@94 87 " df_global = df_global.sort_values('Outliers', axis=0, ascending=False).reset_index()\n",
m@94 88 " ranked_countries = pd.concat([ranked_countries, df_global['Country']], axis=1)\n",
m@94 89 " ranked_outliers = pd.concat([ranked_outliers, df_global['Outliers']], axis=1)\n",
m@94 90 "ranked_countries_arr = ranked_countries.get_values()"
m@94 91 ]
m@94 92 },
m@94 93 {
m@94 94 "cell_type": "markdown",
m@94 95 "metadata": {},
m@94 96 "source": [
m@94 97 "## Estimate precision at K"
m@94 98 ]
m@94 99 },
m@94 100 {
m@94 101 "cell_type": "markdown",
m@94 102 "metadata": {},
m@94 103 "source": [
m@94 104 "First get the ground truth from a majority vote on the top K=10 positions."
m@94 105 ]
m@94 106 },
m@94 107 {
m@94 108 "cell_type": "code",
m@94 109 "execution_count": 5,
m@94 110 "metadata": {
m@94 111 "collapsed": true
m@94 112 },
m@94 113 "outputs": [],
m@94 114 "source": [
m@94 115 "# majority voting + precision at K\n",
m@94 116 "K_vote = 10\n",
m@94 117 "country_vote = Counter(ranked_countries_arr[:K_vote, :].ravel())"
m@94 118 ]
m@94 119 },
m@94 120 {
m@94 121 "cell_type": "code",
m@94 122 "execution_count": 8,
m@94 123 "metadata": {},
m@94 124 "outputs": [
m@94 125 {
m@94 126 "data": {
m@94 127 "text/html": [
m@94 128 "<div>\n",
m@94 129 "<table border=\"1\" class=\"dataframe\">\n",
m@94 130 " <thead>\n",
m@94 131 " <tr style=\"text-align: right;\">\n",
m@94 132 " <th></th>\n",
m@94 133 " <th>index</th>\n",
m@94 134 " <th>0</th>\n",
m@94 135 " </tr>\n",
m@94 136 " </thead>\n",
m@94 137 " <tbody>\n",
m@94 138 " <tr>\n",
m@94 139 " <th>0</th>\n",
m@94 140 " <td>Pakistan</td>\n",
m@94 141 " <td>10</td>\n",
m@94 142 " </tr>\n",
m@94 143 " <tr>\n",
m@94 144 " <th>2</th>\n",
m@94 145 " <td>Chad</td>\n",
m@94 146 " <td>10</td>\n",
m@94 147 " </tr>\n",
m@94 148 " <tr>\n",
m@94 149 " <th>5</th>\n",
m@94 150 " <td>Gambia</td>\n",
m@94 151 " <td>10</td>\n",
m@94 152 " </tr>\n",
m@94 153 " <tr>\n",
m@94 154 " <th>10</th>\n",
m@94 155 " <td>Ivory Coast</td>\n",
m@94 156 " <td>10</td>\n",
m@94 157 " </tr>\n",
m@94 158 " <tr>\n",
m@94 159 " <th>12</th>\n",
m@94 160 " <td>Botswana</td>\n",
m@94 161 " <td>10</td>\n",
m@94 162 " </tr>\n",
m@94 163 " <tr>\n",
m@94 164 " <th>6</th>\n",
m@94 165 " <td>Nepal</td>\n",
m@94 166 " <td>9</td>\n",
m@94 167 " </tr>\n",
m@94 168 " <tr>\n",
m@94 169 " <th>13</th>\n",
m@94 170 " <td>Benin</td>\n",
m@94 171 " <td>8</td>\n",
m@94 172 " </tr>\n",
m@94 173 " <tr>\n",
m@94 174 " <th>8</th>\n",
m@94 175 " <td>Senegal</td>\n",
m@94 176 " <td>7</td>\n",
m@94 177 " </tr>\n",
m@94 178 " <tr>\n",
m@94 179 " <th>9</th>\n",
m@94 180 " <td>French Guiana</td>\n",
m@94 181 " <td>7</td>\n",
m@94 182 " </tr>\n",
m@94 183 " <tr>\n",
m@94 184 " <th>4</th>\n",
m@94 185 " <td>El Salvador</td>\n",
m@94 186 " <td>5</td>\n",
m@94 187 " </tr>\n",
m@94 188 " <tr>\n",
m@94 189 " <th>11</th>\n",
m@94 190 " <td>Mozambique</td>\n",
m@94 191 " <td>5</td>\n",
m@94 192 " </tr>\n",
m@94 193 " <tr>\n",
m@94 194 " <th>7</th>\n",
m@94 195 " <td>Uganda</td>\n",
m@94 196 " <td>4</td>\n",
m@94 197 " </tr>\n",
m@94 198 " <tr>\n",
m@94 199 " <th>1</th>\n",
m@94 200 " <td>Bhutan</td>\n",
m@94 201 " <td>3</td>\n",
m@94 202 " </tr>\n",
m@94 203 " <tr>\n",
m@94 204 " <th>3</th>\n",
m@94 205 " <td>Liberia</td>\n",
m@94 206 " <td>2</td>\n",
m@94 207 " </tr>\n",
m@94 208 " </tbody>\n",
m@94 209 "</table>\n",
m@94 210 "</div>"
m@94 211 ],
m@94 212 "text/plain": [
m@94 213 " index 0\n",
m@94 214 "0 Pakistan 10\n",
m@94 215 "2 Chad 10\n",
m@94 216 "5 Gambia 10\n",
m@94 217 "10 Ivory Coast 10\n",
m@94 218 "12 Botswana 10\n",
m@94 219 "6 Nepal 9\n",
m@94 220 "13 Benin 8\n",
m@94 221 "8 Senegal 7\n",
m@94 222 "9 French Guiana 7\n",
m@94 223 "4 El Salvador 5\n",
m@94 224 "11 Mozambique 5\n",
m@94 225 "7 Uganda 4\n",
m@94 226 "1 Bhutan 3\n",
m@94 227 "3 Liberia 2"
m@94 228 ]
m@94 229 },
m@94 230 "execution_count": 8,
m@94 231 "metadata": {},
m@94 232 "output_type": "execute_result"
m@94 233 }
m@94 234 ],
m@94 235 "source": [
m@94 236 "df_country_vote = pd.DataFrame.from_dict(country_vote, orient='index').reset_index()\n",
m@94 237 "df_country_vote.sort_values(0, ascending=False)"
m@94 238 ]
m@94 239 },
m@94 240 {
m@94 241 "cell_type": "code",
m@94 242 "execution_count": 9,
m@94 243 "metadata": {
m@94 244 "collapsed": true
m@94 245 },
m@94 246 "outputs": [],
m@94 247 "source": [
m@94 248 "def precision_at_k(array, gr_truth, k):\n",
m@94 249 " return len(set(array[:k]) & set(gr_truth[:k])) / float(k)\n",
m@94 250 " \n",
m@94 251 "k = 10\n",
m@94 252 "ground_truth = df_country_vote['index'].get_values()\n",
m@94 253 "p_ = []\n",
m@94 254 "for j in range(ranked_countries_arr.shape[1]):\n",
m@94 255 " p_.append(precision_at_k(ranked_countries_arr[:, j], ground_truth, k))\n",
m@94 256 "p_ = np.array(p_)"
m@94 257 ]
m@94 258 },
m@94 259 {
m@94 260 "cell_type": "code",
m@94 261 "execution_count": 10,
m@94 262 "metadata": {},
m@94 263 "outputs": [
m@94 264 {
m@94 265 "name": "stdout",
m@94 266 "output_type": "stream",
m@94 267 "text": [
m@94 268 "mean 0.67\n",
m@94 269 "std 0.0640312423743\n"
m@94 270 ]
m@94 271 }
m@94 272 ],
m@94 273 "source": [
m@94 274 "print 'mean', np.mean(p_) \n",
m@94 275 "print 'std', np.std(p_)"
m@94 276 ]
m@94 277 },
m@94 278 {
m@94 279 "cell_type": "code",
m@94 280 "execution_count": 11,
m@94 281 "metadata": {},
m@94 282 "outputs": [
m@94 283 {
m@94 284 "name": "stdout",
m@94 285 "output_type": "stream",
m@94 286 "text": [
m@94 287 "[ 0.6 0.7 0.7 0.6 0.6 0.7 0.8 0.6 0.7 0.7]\n"
m@94 288 ]
m@94 289 }
m@94 290 ],
m@94 291 "source": [
m@94 292 "print p_"
m@94 293 ]
m@94 294 },
m@94 295 {
m@94 296 "cell_type": "code",
m@94 297 "execution_count": null,
m@94 298 "metadata": {
m@94 299 "collapsed": true
m@94 300 },
m@94 301 "outputs": [],
m@94 302 "source": []
m@94 303 }
m@94 304 ],
m@94 305 "metadata": {
m@94 306 "kernelspec": {
m@94 307 "display_name": "Python 2",
m@94 308 "language": "python",
m@94 309 "name": "python2"
m@94 310 },
m@94 311 "language_info": {
m@94 312 "codemirror_mode": {
m@94 313 "name": "ipython",
m@94 314 "version": 2
m@94 315 },
m@94 316 "file_extension": ".py",
m@94 317 "mimetype": "text/x-python",
m@94 318 "name": "python",
m@94 319 "nbconvert_exporter": "python",
m@94 320 "pygments_lexer": "ipython2",
m@94 321 "version": "2.7.12"
m@94 322 }
m@94 323 },
m@94 324 "nbformat": 4,
m@94 325 "nbformat_minor": 2
m@94 326 }