m@94
|
1 {
|
m@94
|
2 "cells": [
|
m@94
|
3 {
|
m@94
|
4 "cell_type": "code",
|
m@94
|
5 "execution_count": 6,
|
m@94
|
6 "metadata": {},
|
m@94
|
7 "outputs": [
|
m@94
|
8 {
|
m@94
|
9 "name": "stdout",
|
m@94
|
10 "output_type": "stream",
|
m@94
|
11 "text": [
|
m@94
|
12 "The autoreload extension is already loaded. To reload it, use:\n",
|
m@94
|
13 " %reload_ext autoreload\n"
|
m@94
|
14 ]
|
m@94
|
15 }
|
m@94
|
16 ],
|
m@94
|
17 "source": [
|
m@94
|
18 "import numpy as np\n",
|
m@94
|
19 "import pandas as pd\n",
|
m@94
|
20 "from sklearn.model_selection import train_test_split\n",
|
m@94
|
21 "from collections import Counter\n",
|
m@94
|
22 "\n",
|
m@94
|
23 "%matplotlib inline\n",
|
m@94
|
24 "import matplotlib.pyplot as plt\n",
|
m@94
|
25 "\n",
|
m@94
|
26 "%load_ext autoreload\n",
|
m@94
|
27 "%autoreload 2\n",
|
m@94
|
28 "\n",
|
m@94
|
29 "import sys\n",
|
m@94
|
30 "sys.path.append('../')\n",
|
m@94
|
31 "import scripts.classification as classification\n",
|
m@94
|
32 "import scripts.outliers as outliers"
|
m@94
|
33 ]
|
m@94
|
34 },
|
m@94
|
35 {
|
m@94
|
36 "cell_type": "markdown",
|
m@94
|
37 "metadata": {},
|
m@94
|
38 "source": [
|
m@94
|
39 "## Sample 80% of the dataset, for 10 times"
|
m@94
|
40 ]
|
m@94
|
41 },
|
m@94
|
42 {
|
m@94
|
43 "cell_type": "markdown",
|
m@94
|
44 "metadata": {},
|
m@94
|
45 "source": [
|
m@94
|
46 "Let's sample only 80% of the recordings each time (in a stratified manner) so that the set of recordings considered for each country is changed every time."
|
m@94
|
47 ]
|
m@94
|
48 },
|
m@94
|
49 {
|
m@94
|
50 "cell_type": "code",
|
m@94
|
51 "execution_count": null,
|
m@94
|
52 "metadata": {
|
m@94
|
53 "collapsed": true
|
m@94
|
54 },
|
m@94
|
55 "outputs": [],
|
m@94
|
56 "source": [
|
m@94
|
57 "results_file = '../data/lda_data_8.pickle'\n",
|
m@94
|
58 "n_iters = 10\n",
|
m@94
|
59 "for n in range(n_iters):\n",
|
m@94
|
60 " print \"iteration %d\" % n\n",
|
m@94
|
61 " print results_file\n",
|
m@94
|
62 " X, Y, Yaudio = classification.load_data_from_pickle(results_file)\n",
|
m@94
|
63 " # get only 80% of the dataset.. to vary the choice of outliers\n",
|
m@94
|
64 " X, _, Y, _ = train_test_split(X, Y, train_size=0.8, stratify=Y)\n",
|
m@94
|
65 " print X.shape, Y.shape\n",
|
m@94
|
66 " # outliers\n",
|
m@94
|
67 " print \"detecting outliers...\"\n",
|
m@94
|
68 " df_global, threshold, MD = outliers.get_outliers_df(X, Y, chi2thr=0.999)\n",
|
m@94
|
69 " outliers.print_most_least_outliers_topN(df_global, N=10)\n",
|
m@94
|
70 " \n",
|
m@94
|
71 " # write output\n",
|
m@94
|
72 " print \"writing file\"\n",
|
m@94
|
73 " df_global.to_csv('../data/outliers_'+str(n)+'.csv', index=False)"
|
m@94
|
74 ]
|
m@94
|
75 },
|
m@94
|
76 {
|
m@94
|
77 "cell_type": "code",
|
m@94
|
78 "execution_count": 3,
|
m@94
|
79 "metadata": {},
|
m@94
|
80 "outputs": [],
|
m@94
|
81 "source": [
|
m@94
|
82 "n_iters = 10\n",
|
m@94
|
83 "ranked_countries = pd.DataFrame()\n",
|
m@94
|
84 "ranked_outliers = pd.DataFrame()\n",
|
m@94
|
85 "for n in range(n_iters):\n",
|
m@94
|
86 " df_global = pd.read_csv('../data/outliers_'+str(n)+'.csv')\n",
|
m@94
|
87 " df_global = df_global.sort_values('Outliers', axis=0, ascending=False).reset_index()\n",
|
m@94
|
88 " ranked_countries = pd.concat([ranked_countries, df_global['Country']], axis=1)\n",
|
m@94
|
89 " ranked_outliers = pd.concat([ranked_outliers, df_global['Outliers']], axis=1)\n",
|
m@94
|
90 "ranked_countries_arr = ranked_countries.get_values()"
|
m@94
|
91 ]
|
m@94
|
92 },
|
m@94
|
93 {
|
m@94
|
94 "cell_type": "markdown",
|
m@94
|
95 "metadata": {},
|
m@94
|
96 "source": [
|
m@94
|
97 "## Estimate precision at K"
|
m@94
|
98 ]
|
m@94
|
99 },
|
m@94
|
100 {
|
m@94
|
101 "cell_type": "markdown",
|
m@94
|
102 "metadata": {},
|
m@94
|
103 "source": [
|
m@94
|
104 "First get the ground truth from a majority vote on the top K=10 positions."
|
m@94
|
105 ]
|
m@94
|
106 },
|
m@94
|
107 {
|
m@94
|
108 "cell_type": "code",
|
m@94
|
109 "execution_count": 5,
|
m@94
|
110 "metadata": {
|
m@94
|
111 "collapsed": true
|
m@94
|
112 },
|
m@94
|
113 "outputs": [],
|
m@94
|
114 "source": [
|
m@94
|
115 "# majority voting + precision at K\n",
|
m@94
|
116 "K_vote = 10\n",
|
m@94
|
117 "country_vote = Counter(ranked_countries_arr[:K_vote, :].ravel())"
|
m@94
|
118 ]
|
m@94
|
119 },
|
m@94
|
120 {
|
m@94
|
121 "cell_type": "code",
|
m@94
|
122 "execution_count": 8,
|
m@94
|
123 "metadata": {},
|
m@94
|
124 "outputs": [
|
m@94
|
125 {
|
m@94
|
126 "data": {
|
m@94
|
127 "text/html": [
|
m@94
|
128 "<div>\n",
|
m@94
|
129 "<table border=\"1\" class=\"dataframe\">\n",
|
m@94
|
130 " <thead>\n",
|
m@94
|
131 " <tr style=\"text-align: right;\">\n",
|
m@94
|
132 " <th></th>\n",
|
m@94
|
133 " <th>index</th>\n",
|
m@94
|
134 " <th>0</th>\n",
|
m@94
|
135 " </tr>\n",
|
m@94
|
136 " </thead>\n",
|
m@94
|
137 " <tbody>\n",
|
m@94
|
138 " <tr>\n",
|
m@94
|
139 " <th>0</th>\n",
|
m@94
|
140 " <td>Pakistan</td>\n",
|
m@94
|
141 " <td>10</td>\n",
|
m@94
|
142 " </tr>\n",
|
m@94
|
143 " <tr>\n",
|
m@94
|
144 " <th>2</th>\n",
|
m@94
|
145 " <td>Chad</td>\n",
|
m@94
|
146 " <td>10</td>\n",
|
m@94
|
147 " </tr>\n",
|
m@94
|
148 " <tr>\n",
|
m@94
|
149 " <th>5</th>\n",
|
m@94
|
150 " <td>Gambia</td>\n",
|
m@94
|
151 " <td>10</td>\n",
|
m@94
|
152 " </tr>\n",
|
m@94
|
153 " <tr>\n",
|
m@94
|
154 " <th>10</th>\n",
|
m@94
|
155 " <td>Ivory Coast</td>\n",
|
m@94
|
156 " <td>10</td>\n",
|
m@94
|
157 " </tr>\n",
|
m@94
|
158 " <tr>\n",
|
m@94
|
159 " <th>12</th>\n",
|
m@94
|
160 " <td>Botswana</td>\n",
|
m@94
|
161 " <td>10</td>\n",
|
m@94
|
162 " </tr>\n",
|
m@94
|
163 " <tr>\n",
|
m@94
|
164 " <th>6</th>\n",
|
m@94
|
165 " <td>Nepal</td>\n",
|
m@94
|
166 " <td>9</td>\n",
|
m@94
|
167 " </tr>\n",
|
m@94
|
168 " <tr>\n",
|
m@94
|
169 " <th>13</th>\n",
|
m@94
|
170 " <td>Benin</td>\n",
|
m@94
|
171 " <td>8</td>\n",
|
m@94
|
172 " </tr>\n",
|
m@94
|
173 " <tr>\n",
|
m@94
|
174 " <th>8</th>\n",
|
m@94
|
175 " <td>Senegal</td>\n",
|
m@94
|
176 " <td>7</td>\n",
|
m@94
|
177 " </tr>\n",
|
m@94
|
178 " <tr>\n",
|
m@94
|
179 " <th>9</th>\n",
|
m@94
|
180 " <td>French Guiana</td>\n",
|
m@94
|
181 " <td>7</td>\n",
|
m@94
|
182 " </tr>\n",
|
m@94
|
183 " <tr>\n",
|
m@94
|
184 " <th>4</th>\n",
|
m@94
|
185 " <td>El Salvador</td>\n",
|
m@94
|
186 " <td>5</td>\n",
|
m@94
|
187 " </tr>\n",
|
m@94
|
188 " <tr>\n",
|
m@94
|
189 " <th>11</th>\n",
|
m@94
|
190 " <td>Mozambique</td>\n",
|
m@94
|
191 " <td>5</td>\n",
|
m@94
|
192 " </tr>\n",
|
m@94
|
193 " <tr>\n",
|
m@94
|
194 " <th>7</th>\n",
|
m@94
|
195 " <td>Uganda</td>\n",
|
m@94
|
196 " <td>4</td>\n",
|
m@94
|
197 " </tr>\n",
|
m@94
|
198 " <tr>\n",
|
m@94
|
199 " <th>1</th>\n",
|
m@94
|
200 " <td>Bhutan</td>\n",
|
m@94
|
201 " <td>3</td>\n",
|
m@94
|
202 " </tr>\n",
|
m@94
|
203 " <tr>\n",
|
m@94
|
204 " <th>3</th>\n",
|
m@94
|
205 " <td>Liberia</td>\n",
|
m@94
|
206 " <td>2</td>\n",
|
m@94
|
207 " </tr>\n",
|
m@94
|
208 " </tbody>\n",
|
m@94
|
209 "</table>\n",
|
m@94
|
210 "</div>"
|
m@94
|
211 ],
|
m@94
|
212 "text/plain": [
|
m@94
|
213 " index 0\n",
|
m@94
|
214 "0 Pakistan 10\n",
|
m@94
|
215 "2 Chad 10\n",
|
m@94
|
216 "5 Gambia 10\n",
|
m@94
|
217 "10 Ivory Coast 10\n",
|
m@94
|
218 "12 Botswana 10\n",
|
m@94
|
219 "6 Nepal 9\n",
|
m@94
|
220 "13 Benin 8\n",
|
m@94
|
221 "8 Senegal 7\n",
|
m@94
|
222 "9 French Guiana 7\n",
|
m@94
|
223 "4 El Salvador 5\n",
|
m@94
|
224 "11 Mozambique 5\n",
|
m@94
|
225 "7 Uganda 4\n",
|
m@94
|
226 "1 Bhutan 3\n",
|
m@94
|
227 "3 Liberia 2"
|
m@94
|
228 ]
|
m@94
|
229 },
|
m@94
|
230 "execution_count": 8,
|
m@94
|
231 "metadata": {},
|
m@94
|
232 "output_type": "execute_result"
|
m@94
|
233 }
|
m@94
|
234 ],
|
m@94
|
235 "source": [
|
m@94
|
236 "df_country_vote = pd.DataFrame.from_dict(country_vote, orient='index').reset_index()\n",
|
m@94
|
237 "df_country_vote.sort_values(0, ascending=False)"
|
m@94
|
238 ]
|
m@94
|
239 },
|
m@94
|
240 {
|
m@94
|
241 "cell_type": "code",
|
m@94
|
242 "execution_count": 9,
|
m@94
|
243 "metadata": {
|
m@94
|
244 "collapsed": true
|
m@94
|
245 },
|
m@94
|
246 "outputs": [],
|
m@94
|
247 "source": [
|
m@94
|
248 "def precision_at_k(array, gr_truth, k):\n",
|
m@94
|
249 " return len(set(array[:k]) & set(gr_truth[:k])) / float(k)\n",
|
m@94
|
250 " \n",
|
m@94
|
251 "k = 10\n",
|
m@94
|
252 "ground_truth = df_country_vote['index'].get_values()\n",
|
m@94
|
253 "p_ = []\n",
|
m@94
|
254 "for j in range(ranked_countries_arr.shape[1]):\n",
|
m@94
|
255 " p_.append(precision_at_k(ranked_countries_arr[:, j], ground_truth, k))\n",
|
m@94
|
256 "p_ = np.array(p_)"
|
m@94
|
257 ]
|
m@94
|
258 },
|
m@94
|
259 {
|
m@94
|
260 "cell_type": "code",
|
m@94
|
261 "execution_count": 10,
|
m@94
|
262 "metadata": {},
|
m@94
|
263 "outputs": [
|
m@94
|
264 {
|
m@94
|
265 "name": "stdout",
|
m@94
|
266 "output_type": "stream",
|
m@94
|
267 "text": [
|
m@94
|
268 "mean 0.67\n",
|
m@94
|
269 "std 0.0640312423743\n"
|
m@94
|
270 ]
|
m@94
|
271 }
|
m@94
|
272 ],
|
m@94
|
273 "source": [
|
m@94
|
274 "print 'mean', np.mean(p_) \n",
|
m@94
|
275 "print 'std', np.std(p_)"
|
m@94
|
276 ]
|
m@94
|
277 },
|
m@94
|
278 {
|
m@94
|
279 "cell_type": "code",
|
m@94
|
280 "execution_count": 11,
|
m@94
|
281 "metadata": {},
|
m@94
|
282 "outputs": [
|
m@94
|
283 {
|
m@94
|
284 "name": "stdout",
|
m@94
|
285 "output_type": "stream",
|
m@94
|
286 "text": [
|
m@94
|
287 "[ 0.6 0.7 0.7 0.6 0.6 0.7 0.8 0.6 0.7 0.7]\n"
|
m@94
|
288 ]
|
m@94
|
289 }
|
m@94
|
290 ],
|
m@94
|
291 "source": [
|
m@94
|
292 "print p_"
|
m@94
|
293 ]
|
m@94
|
294 },
|
m@94
|
295 {
|
m@94
|
296 "cell_type": "code",
|
m@94
|
297 "execution_count": null,
|
m@94
|
298 "metadata": {
|
m@94
|
299 "collapsed": true
|
m@94
|
300 },
|
m@94
|
301 "outputs": [],
|
m@94
|
302 "source": []
|
m@94
|
303 }
|
m@94
|
304 ],
|
m@94
|
305 "metadata": {
|
m@94
|
306 "kernelspec": {
|
m@94
|
307 "display_name": "Python 2",
|
m@94
|
308 "language": "python",
|
m@94
|
309 "name": "python2"
|
m@94
|
310 },
|
m@94
|
311 "language_info": {
|
m@94
|
312 "codemirror_mode": {
|
m@94
|
313 "name": "ipython",
|
m@94
|
314 "version": 2
|
m@94
|
315 },
|
m@94
|
316 "file_extension": ".py",
|
m@94
|
317 "mimetype": "text/x-python",
|
m@94
|
318 "name": "python",
|
m@94
|
319 "nbconvert_exporter": "python",
|
m@94
|
320 "pygments_lexer": "ipython2",
|
m@94
|
321 "version": "2.7.12"
|
m@94
|
322 }
|
m@94
|
323 },
|
m@94
|
324 "nbformat": 4,
|
m@94
|
325 "nbformat_minor": 2
|
m@94
|
326 }
|