Mercurial > hg > plosone_underreview
comparison notebooks/explain_components.ipynb @ 9:c4841876a8ff branch-tests
adding notebooks and trying to explain classifier coefficients
author | Maria Panteli <m.x.panteli@gmail.com> |
---|---|
date | Mon, 11 Sep 2017 19:06:40 +0100 |
parents | |
children | a1a9b472bcdb |
comparison
equal
deleted
inserted
replaced
8:0f3eba42b425 | 9:c4841876a8ff |
---|---|
1 { | |
2 "cells": [ | |
3 { | |
4 "cell_type": "code", | |
5 "execution_count": 7, | |
6 "metadata": { | |
7 "collapsed": false | |
8 }, | |
9 "outputs": [], | |
10 "source": [ | |
11 "import numpy as np\n", | |
12 "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA\n", | |
13 "\n", | |
14 "import sys\n", | |
15 "sys.path.append('../')\n", | |
16 "import scripts.map_and_average as mapper\n", | |
17 "import scripts.util_feature_learning as util_feature_learning" | |
18 ] | |
19 }, | |
20 { | |
21 "cell_type": "markdown", | |
22 "metadata": {}, | |
23 "source": [ | |
24 "## Load data" | |
25 ] | |
26 }, | |
27 { | |
28 "cell_type": "code", | |
29 "execution_count": 8, | |
30 "metadata": { | |
31 "collapsed": false | |
32 }, | |
33 "outputs": [ | |
34 { | |
35 "name": "stdout", | |
36 "output_type": "stream", | |
37 "text": [ | |
38 "/import/c4dm-04/mariap/train_data_melodia_8.pickle\n" | |
39 ] | |
40 }, | |
41 { | |
42 "ename": "IOError", | |
43 "evalue": "[Errno 2] No such file or directory: '/import/c4dm-04/mariap/train_data_melodia_8.pickle'", | |
44 "output_type": "error", | |
45 "traceback": [ | |
46 "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
47 "\u001b[0;31mIOError\u001b[0m Traceback (most recent call last)", | |
48 "\u001b[0;32m<ipython-input-8-aa3c9e978b25>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrainset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtestset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmapper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_train_val_test_sets\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
49 "\u001b[0;32m/Users/mariapanteli/Documents/QMUL/Code/MyPythonCode/plosone_underreview/scripts/map_and_average.pyc\u001b[0m in \u001b[0;36mload_train_val_test_sets\u001b[0;34m()\u001b[0m\n\u001b[1;32m 69\u001b[0m '''\n\u001b[1;32m 70\u001b[0m \u001b[0;32mprint\u001b[0m \u001b[0mINPUT_FILES\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 71\u001b[0;31m \u001b[0mtrainset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_data_from_pickle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mINPUT_FILES\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 72\u001b[0m \u001b[0mvalset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_data_from_pickle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mINPUT_FILES\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0mtestset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_data_from_pickle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mINPUT_FILES\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
50 "\u001b[0;32m/Users/mariapanteli/Documents/QMUL/Code/MyPythonCode/plosone_underreview/scripts/map_and_average.pyc\u001b[0m in \u001b[0;36mload_data_from_pickle\u001b[0;34m(pickle_file)\u001b[0m\n\u001b[1;32m 56\u001b[0m '''load frame based features and labels from pickle file\n\u001b[1;32m 57\u001b[0m '''\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpickle_file\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maudiolabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;31m# remove 'unknown' and 'unidentified' country\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
51 "\u001b[0;31mIOError\u001b[0m: [Errno 2] No such file or directory: '/import/c4dm-04/mariap/train_data_melodia_8.pickle'" | |
52 ] | |
53 } | |
54 ], | |
55 "source": [ | |
56 "trainset, valset, testset = mapper.load_train_val_test_sets()\n", | |
57 "traindata, trainlabels, trainaudiolabels = trainset\n", | |
58 "valdata, vallabels, valaudiolabels = valset\n", | |
59 "testdata, testlabels, testaudiolabels = testset\n", | |
60 "labels = np.concatenate((trainlabels, vallabels, testlabels)).ravel()\n", | |
61 "audiolabels = np.concatenate((trainaudiolabels, valaudiolabels, testaudiolabels)).ravel()\n", | |
62 "print traindata.shape, valdata.shape, testdata.shape" | |
63 ] | |
64 }, | |
65 { | |
66 "cell_type": "markdown", | |
67 "metadata": {}, | |
68 "source": [ | |
69 "## Explain LDA" | |
70 ] | |
71 }, | |
72 { | |
73 "cell_type": "code", | |
74 "execution_count": null, | |
75 "metadata": { | |
76 "collapsed": true | |
77 }, | |
78 "outputs": [], | |
79 "source": [ | |
80 "min_variance = 0.99\n", | |
81 "feat_labels, feat_inds = mapper.get_feat_inds(n_dim=traindata.shape[1])\n", | |
82 "for i in range(len(feat_inds)):\n", | |
83 " print \"mapping \" + feat_labels[i]\n", | |
84 " inds = feat_inds[i]\n", | |
85 " ssm_feat = util_feature_learning.Transformer()\n", | |
86 " if min_variance is not None:\n", | |
87 " ssm_feat.fit_data(traindata[:, inds], trainlabels, n_components=len(inds), pca_only=True)\n", | |
88 " n_components = np.where(ssm_feat.pca_transformer.explained_variance_ratio_.cumsum()>min_variance)[0][0]+1\n", | |
89 " print n_components, len(inds)\n", | |
90 " ssm_feat.fit_lda_data(traindata[:, inds], trainlabels, n_components=n_components)\n", | |
91 "\n", | |
92 " WW = ssm_feat.lda_transformer.scalings_\n", | |
93 " plt.figure()\n", | |
94 " plt.imshow(WW[:, :n_components].T, aspect='auto')\n", | |
95 " plt.colorbar()" | |
96 ] | |
97 }, | |
98 { | |
99 "cell_type": "markdown", | |
100 "metadata": {}, | |
101 "source": [ | |
102 "## Explain classifier" | |
103 ] | |
104 }, | |
105 { | |
106 "cell_type": "code", | |
107 "execution_count": null, | |
108 "metadata": { | |
109 "collapsed": true | |
110 }, | |
111 "outputs": [], | |
112 "source": [ | |
113 "X_list, Y, Yaudio = pickle.load(open('../data/lda_data_melodia_8.pickle','rb'))\n", | |
114 "Xrhy, Xmel, Xmfc, Xchr = X_list\n", | |
115 "X = np.concatenate((Xrhy, Xmel, Xmfc, Xchr), axis=1)" | |
116 ] | |
117 }, | |
118 { | |
119 "cell_type": "code", | |
120 "execution_count": null, | |
121 "metadata": { | |
122 "collapsed": true | |
123 }, | |
124 "outputs": [], | |
125 "source": [ | |
126 "ssm_feat.classify_and_save(X_train, Y_train, X_test, Y_test, transform_label=\" \")" | |
127 ] | |
128 }, | |
129 { | |
130 "cell_type": "code", | |
131 "execution_count": null, | |
132 "metadata": { | |
133 "collapsed": true | |
134 }, | |
135 "outputs": [], | |
136 "source": [ | |
137 "def components_plot(lda_transformer, XX, n_comp=42, figurename=None):\n", | |
138 " WW=lda_transformer.scalings_\n", | |
139 " Xlda=lda_transformer.transform(XX)\n", | |
140 " Xww=numpy.dot(XX, WW[:, :n_comp])\n", | |
141 " plt.figure()\n", | |
142 " plt.imshow(Xlda - Xww, aspect='auto')\n", | |
143 " plt.figure()\n", | |
144 " plt.imshow(Xlda, aspect='auto')\n", | |
145 " plt.figure()\n", | |
146 " plt.imshow(Xww, aspect='auto')\n", | |
147 " plt.figure()\n", | |
148 " plt.imshow(WW[:, :n_comp], aspect='auto') # this explains the weights up to n_components=64\n", | |
149 " if figurename is not None:\n", | |
150 " plt.savefig(figurename)\n", | |
151 "\n", | |
152 "XX = traindata[:, inds]\n", | |
153 "components_plot(ssm_feat.lda_transformer, XX, n_comp=n_components)" | |
154 ] | |
155 } | |
156 ], | |
157 "metadata": { | |
158 "kernelspec": { | |
159 "display_name": "Python 2", | |
160 "language": "python", | |
161 "name": "python2" | |
162 }, | |
163 "language_info": { | |
164 "codemirror_mode": { | |
165 "name": "ipython", | |
166 "version": 2 | |
167 }, | |
168 "file_extension": ".py", | |
169 "mimetype": "text/x-python", | |
170 "name": "python", | |
171 "nbconvert_exporter": "python", | |
172 "pygments_lexer": "ipython2", | |
173 "version": "2.7.12" | |
174 } | |
175 }, | |
176 "nbformat": 4, | |
177 "nbformat_minor": 0 | |
178 } |