{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      " %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from collections import Counter\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "import scripts.classification as classification\n",
    "import scripts.outliers as outliers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample 80% of the dataset, 10 times"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each time, sample only 80% of the recordings, stratified by country, so that the set of recordings considered for each country changes from one iteration to the next."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "results_file = '../data/lda_data_8.pickle'\n",
    "n_iters = 10\n",
    "print(results_file)\n",
    "# load the full dataset once; every iteration subsamples from it\n",
    "X_all, Y_all, Yaudio = classification.load_data_from_pickle(results_file)\n",
    "for n in range(n_iters):\n",
    "    print(\"iteration %d\" % n)\n",
    "    # keep a stratified 80% of the recordings to vary the choice of outliers\n",
    "    X, _, Y, _ = train_test_split(X_all, Y_all, train_size=0.8, stratify=Y_all)\n",
    "    print(X.shape, Y.shape)\n",
    "    # detect outliers\n",
    "    print(\"detecting outliers...\")\n",
    "    df_global, threshold, MD = outliers.get_outliers_df(X, Y, chi2thr=0.999)\n",
    "    outliers.print_most_least_outliers_topN(df_global, N=10)\n",
    "\n",
    "    # write one CSV per iteration\n",
    "    print(\"writing file\")\n",
    "    df_global.to_csv('../data/outliers_' + str(n) + '.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_iters = 10\n",
    "country_cols, outlier_cols = [], []\n",
    "for n in range(n_iters):\n",
    "    df_global = pd.read_csv('../data/outliers_' + str(n) + '.csv')\n",
    "    # rank countries from most to least outliers\n",
    "    df_global = df_global.sort_values('Outliers', ascending=False).reset_index(drop=True)\n",
    "    country_cols.append(df_global['Country'])\n",
    "    outlier_cols.append(df_global['Outliers'])\n",
    "# one column per iteration, one row per rank position\n",
    "ranked_countries = pd.concat(country_cols, axis=1)\n",
    "ranked_outliers = pd.concat(outlier_cols, axis=1)\n",
    "# DataFrame.get_values() was removed in pandas 1.0; .values works in old and new versions\n",
    "ranked_countries_arr = ranked_countries.values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Estimate precision at K"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, obtain the ground truth by majority vote: count how many of the 10 iterations place each country in the top K=10 positions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# majority voting: count how many iterations place each country in the top K_vote ranks\n",
    "K_vote = 10\n",
    "country_vote = Counter(ranked_countries_arr[:K_vote, :].ravel())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <th>index</th>\n",
       "      <th>0</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Pakistan</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Chad</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Gambia</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>Ivory Coast</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>Botswana</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>Nepal</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>Benin</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>Senegal</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>French Guiana</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>El Salvador</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>Mozambique</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>Uganda</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Bhutan</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Liberia</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "