Mercurial > hg > plosone_underreview
comparison notebooks/sensitivity_experiment_outliers.ipynb @ 94:69521f86d931 branch-tests
notebooks for publication
author | mpanteli <m.x.panteli@gmail.com> |
---|---|
date | Mon, 02 Oct 2017 18:59:29 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
93:f9513664fe42 | 94:69521f86d931 |
---|---|
1 { | |
2 "cells": [ | |
3 { | |
4 "cell_type": "code", | |
5 "execution_count": 6, | |
6 "metadata": {}, | |
7 "outputs": [ | |
8 { | |
9 "name": "stdout", | |
10 "output_type": "stream", | |
11 "text": [ | |
12 "The autoreload extension is already loaded. To reload it, use:\n", | |
13 " %reload_ext autoreload\n" | |
14 ] | |
15 } | |
16 ], | |
17 "source": [ | |
18 "import numpy as np\n", | |
19 "import pandas as pd\n", | |
20 "from sklearn.model_selection import train_test_split\n", | |
21 "from collections import Counter\n", | |
22 "\n", | |
23 "%matplotlib inline\n", | |
24 "import matplotlib.pyplot as plt\n", | |
25 "\n", | |
26 "%load_ext autoreload\n", | |
27 "%autoreload 2\n", | |
28 "\n", | |
29 "import sys\n", | |
30 "sys.path.append('../')\n", | |
31 "import scripts.classification as classification\n", | |
32 "import scripts.outliers as outliers" | |
33 ] | |
34 }, | |
35 { | |
36 "cell_type": "markdown", | |
37 "metadata": {}, | |
38 "source": [ | |
39 "## Sample 80% of the dataset, 10 times" | |
40 ] | |
41 }, | |
42 { | |
43 "cell_type": "markdown", | |
44 "metadata": {}, | |
45 "source": [ | |
46 "Let's sample only 80% of the recordings in each iteration (in a stratified manner), so that the set of recordings considered for each country changes every time." | |
47 ] | |
48 }, | |
49 { | |
50 "cell_type": "code", | |
51 "execution_count": null, | |
52 "metadata": { | |
53 "collapsed": true | |
54 }, | |
55 "outputs": [], | |
56 "source": [ | |
57 "results_file = '../data/lda_data_8.pickle'\n", | |
58 "n_iters = 10\n", | |
59 "for n in range(n_iters):\n", | |
60 " print \"iteration %d\" % n\n", | |
61 " print results_file\n", | |
62 " X, Y, Yaudio = classification.load_data_from_pickle(results_file)\n", | |
63 " # get only 80% of the dataset.. to vary the choice of outliers\n", | |
64 " X, _, Y, _ = train_test_split(X, Y, train_size=0.8, stratify=Y)\n", | |
65 " print X.shape, Y.shape\n", | |
66 " # outliers\n", | |
67 " print \"detecting outliers...\"\n", | |
68 " df_global, threshold, MD = outliers.get_outliers_df(X, Y, chi2thr=0.999)\n", | |
69 " outliers.print_most_least_outliers_topN(df_global, N=10)\n", | |
70 " \n", | |
71 " # write output\n", | |
72 " print \"writing file\"\n", | |
73 " df_global.to_csv('../data/outliers_'+str(n)+'.csv', index=False)" | |
74 ] | |
75 }, | |
76 { | |
77 "cell_type": "code", | |
78 "execution_count": 3, | |
79 "metadata": {}, | |
80 "outputs": [], | |
81 "source": [ | |
82 "n_iters = 10\n", | |
83 "ranked_countries = pd.DataFrame()\n", | |
84 "ranked_outliers = pd.DataFrame()\n", | |
85 "for n in range(n_iters):\n", | |
86 " df_global = pd.read_csv('../data/outliers_'+str(n)+'.csv')\n", | |
87 " df_global = df_global.sort_values('Outliers', axis=0, ascending=False).reset_index()\n", | |
88 " ranked_countries = pd.concat([ranked_countries, df_global['Country']], axis=1)\n", | |
89 " ranked_outliers = pd.concat([ranked_outliers, df_global['Outliers']], axis=1)\n", | |
90 "ranked_countries_arr = ranked_countries.get_values()" | |
91 ] | |
92 }, | |
93 { | |
94 "cell_type": "markdown", | |
95 "metadata": {}, | |
96 "source": [ | |
97 "## Estimate precision at K" | |
98 ] | |
99 }, | |
100 { | |
101 "cell_type": "markdown", | |
102 "metadata": {}, | |
103 "source": [ | |
104 "First get the ground truth from a majority vote on the top K=10 positions." | |
105 ] | |
106 }, | |
107 { | |
108 "cell_type": "code", | |
109 "execution_count": 5, | |
110 "metadata": { | |
111 "collapsed": true | |
112 }, | |
113 "outputs": [], | |
114 "source": [ | |
115 "# majority voting + precision at K\n", | |
116 "K_vote = 10\n", | |
117 "country_vote = Counter(ranked_countries_arr[:K_vote, :].ravel())" | |
118 ] | |
119 }, | |
120 { | |
121 "cell_type": "code", | |
122 "execution_count": 8, | |
123 "metadata": {}, | |
124 "outputs": [ | |
125 { | |
126 "data": { | |
127 "text/html": [ | |
128 "<div>\n", | |
129 "<table border=\"1\" class=\"dataframe\">\n", | |
130 " <thead>\n", | |
131 " <tr style=\"text-align: right;\">\n", | |
132 " <th></th>\n", | |
133 " <th>index</th>\n", | |
134 " <th>0</th>\n", | |
135 " </tr>\n", | |
136 " </thead>\n", | |
137 " <tbody>\n", | |
138 " <tr>\n", | |
139 " <th>0</th>\n", | |
140 " <td>Pakistan</td>\n", | |
141 " <td>10</td>\n", | |
142 " </tr>\n", | |
143 " <tr>\n", | |
144 " <th>2</th>\n", | |
145 " <td>Chad</td>\n", | |
146 " <td>10</td>\n", | |
147 " </tr>\n", | |
148 " <tr>\n", | |
149 " <th>5</th>\n", | |
150 " <td>Gambia</td>\n", | |
151 " <td>10</td>\n", | |
152 " </tr>\n", | |
153 " <tr>\n", | |
154 " <th>10</th>\n", | |
155 " <td>Ivory Coast</td>\n", | |
156 " <td>10</td>\n", | |
157 " </tr>\n", | |
158 " <tr>\n", | |
159 " <th>12</th>\n", | |
160 " <td>Botswana</td>\n", | |
161 " <td>10</td>\n", | |
162 " </tr>\n", | |
163 " <tr>\n", | |
164 " <th>6</th>\n", | |
165 " <td>Nepal</td>\n", | |
166 " <td>9</td>\n", | |
167 " </tr>\n", | |
168 " <tr>\n", | |
169 " <th>13</th>\n", | |
170 " <td>Benin</td>\n", | |
171 " <td>8</td>\n", | |
172 " </tr>\n", | |
173 " <tr>\n", | |
174 " <th>8</th>\n", | |
175 " <td>Senegal</td>\n", | |
176 " <td>7</td>\n", | |
177 " </tr>\n", | |
178 " <tr>\n", | |
179 " <th>9</th>\n", | |
180 " <td>French Guiana</td>\n", | |
181 " <td>7</td>\n", | |
182 " </tr>\n", | |
183 " <tr>\n", | |
184 " <th>4</th>\n", | |
185 " <td>El Salvador</td>\n", | |
186 " <td>5</td>\n", | |
187 " </tr>\n", | |
188 " <tr>\n", | |
189 " <th>11</th>\n", | |
190 " <td>Mozambique</td>\n", | |
191 " <td>5</td>\n", | |
192 " </tr>\n", | |
193 " <tr>\n", | |
194 " <th>7</th>\n", | |
195 " <td>Uganda</td>\n", | |
196 " <td>4</td>\n", | |
197 " </tr>\n", | |
198 " <tr>\n", | |
199 " <th>1</th>\n", | |
200 " <td>Bhutan</td>\n", | |
201 " <td>3</td>\n", | |
202 " </tr>\n", | |
203 " <tr>\n", | |
204 " <th>3</th>\n", | |
205 " <td>Liberia</td>\n", | |
206 " <td>2</td>\n", | |
207 " </tr>\n", | |
208 " </tbody>\n", | |
209 "</table>\n", | |
210 "</div>" | |
211 ], | |
212 "text/plain": [ | |
213 " index 0\n", | |
214 "0 Pakistan 10\n", | |
215 "2 Chad 10\n", | |
216 "5 Gambia 10\n", | |
217 "10 Ivory Coast 10\n", | |
218 "12 Botswana 10\n", | |
219 "6 Nepal 9\n", | |
220 "13 Benin 8\n", | |
221 "8 Senegal 7\n", | |
222 "9 French Guiana 7\n", | |
223 "4 El Salvador 5\n", | |
224 "11 Mozambique 5\n", | |
225 "7 Uganda 4\n", | |
226 "1 Bhutan 3\n", | |
227 "3 Liberia 2" | |
228 ] | |
229 }, | |
230 "execution_count": 8, | |
231 "metadata": {}, | |
232 "output_type": "execute_result" | |
233 } | |
234 ], | |
235 "source": [ | |
236 "df_country_vote = pd.DataFrame.from_dict(country_vote, orient='index').reset_index()  # columns: 'index' (country), 0 (vote count)\n", | |
237 "df_country_vote.sort_values(0, ascending=False)  # display only: the sorted frame is not assigned, df_country_vote keeps its row order" | |
238 ] | |
239 }, | |
240 { | |
241 "cell_type": "code", | |
242 "execution_count": 9, | |
243 "metadata": { | |
244 "collapsed": true | |
245 }, | |
246 "outputs": [], | |
247 "source": [ | |
248 "def precision_at_k(array, gr_truth, k):\n", | |
249 " return len(set(array[:k]) & set(gr_truth[:k])) / float(k)\n", | |
250 " \n", | |
251 "k = 10\n", | |
252 "ground_truth = df_country_vote['index'].get_values()\n", | |
253 "p_ = []\n", | |
254 "for j in range(ranked_countries_arr.shape[1]):\n", | |
255 " p_.append(precision_at_k(ranked_countries_arr[:, j], ground_truth, k))\n", | |
256 "p_ = np.array(p_)" | |
257 ] | |
258 }, | |
259 { | |
260 "cell_type": "code", | |
261 "execution_count": 10, | |
262 "metadata": {}, | |
263 "outputs": [ | |
264 { | |
265 "name": "stdout", | |
266 "output_type": "stream", | |
267 "text": [ | |
268 "mean 0.67\n", | |
269 "std 0.0640312423743\n" | |
270 ] | |
271 } | |
272 ], | |
273 "source": [ | |
274 "print 'mean', np.mean(p_) \n", | |
275 "print 'std', np.std(p_)" | |
276 ] | |
277 }, | |
278 { | |
279 "cell_type": "code", | |
280 "execution_count": 11, | |
281 "metadata": {}, | |
282 "outputs": [ | |
283 { | |
284 "name": "stdout", | |
285 "output_type": "stream", | |
286 "text": [ | |
287 "[ 0.6 0.7 0.7 0.6 0.6 0.7 0.8 0.6 0.7 0.7]\n" | |
288 ] | |
289 } | |
290 ], | |
291 "source": [ | |
292 "print p_  # precision at k for each resampling run" | |
293 ] | |
294 }, | |
295 { | |
296 "cell_type": "code", | |
297 "execution_count": null, | |
298 "metadata": { | |
299 "collapsed": true | |
300 }, | |
301 "outputs": [], | |
302 "source": [] | |
303 } | |
304 ], | |
305 "metadata": { | |
306 "kernelspec": { | |
307 "display_name": "Python 2", | |
308 "language": "python", | |
309 "name": "python2" | |
310 }, | |
311 "language_info": { | |
312 "codemirror_mode": { | |
313 "name": "ipython", | |
314 "version": 2 | |
315 }, | |
316 "file_extension": ".py", | |
317 "mimetype": "text/x-python", | |
318 "name": "python", | |
319 "nbconvert_exporter": "python", | |
320 "pygments_lexer": "ipython2", | |
321 "version": "2.7.12" | |
322 } | |
323 }, | |
324 "nbformat": 4, | |
325 "nbformat_minor": 2 | |
326 } |