gerard@1
|
1 % Copyright 2013 MUSIC TECHNOLOGY GROUP, UNIVERSITAT POMPEU FABRA
|
gerard@1
|
2 %
|
gerard@1
|
3 % Written by Gerard Roma <gerard.roma@upf.edu>
|
gerard@1
|
4 %
|
gerard@1
|
5 % This program is free software: you can redistribute it and/or modify
|
gerard@1
|
6 % it under the terms of the GNU Affero General Public License as published by
|
gerard@1
|
7 % the Free Software Foundation, either version 3 of the License, or
|
gerard@1
|
8 % (at your option) any later version.
|
gerard@1
|
9 %
|
gerard@1
|
10 % This program is distributed in the hope that it will be useful,
|
gerard@1
|
11 % but WITHOUT ANY WARRANTY; without even the implied warranty of
|
gerard@1
|
12 % MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
gerard@1
|
13 % GNU Affero General Public License for more details.
|
gerard@1
|
14 %
|
gerard@1
|
15 % You should have received a copy of the GNU Affero General Public License
|
gerard@1
|
16 % along with this program. If not, see <http://www.gnu.org/licenses/>.
|
gerard@1
|
17
|
gerard@1
|
function [ predict_label ] = classify_scenes(tmp_path, train_file,test_file, output_file, grid_search)
% CLASSIFY_SCENES Train a libsvm classifier on a training list and label a test list.
%
%   predict_label = classify_scenes(tmp_path, train_file, test_file, ...
%                                   output_file, grid_search)
%
%   Inputs
%     tmp_path     passed unchanged to analyze_files.
%     train_file   file read with loadClassificationOutput; supplies the
%                  training item names and their class labels.
%     test_file    file read with loadClassificationOutput; only the item
%                  names are used, its labels are ignored.
%     output_file  path of the result file. One line per test item is
%                  written: "<name><TAB><predicted class>".
%     grid_search  when true, C and gamma are chosen by validation accuracy
%                  on a random 90/10 split of the training data; otherwise
%                  the fixed values C = 70, gamma = 0.003 are used.
%
%   Output
%     predict_label  cell array with the predicted class name of each test
%                    item, in the order of the test list.
%
%   Errors
%     classify_scenes:unknownLabel      a training label is not in the class list.
%     classify_scenes:cannotOpenOutput  output_file cannot be opened for writing.

    addpath('rastamat/');
    addpath('libsvm/');

    classes = {'bus' 'busystreet' 'office' 'openairmarket' 'park' 'quietstreet' 'restaurant' 'supermarket' 'tube' 'tubestation'};

    % Parameters used when no grid search is requested, and as a fallback
    % when the grid search finds nothing better than 0% accuracy.
    default_C = 70;
    default_g = 0.003;

    % Builds the libsvm option string (format kept identical for both the
    % grid search and the final model).
    svm_options = @(cost, g) sprintf('-c %d -g %2.5f -q', cost, g);

    [tr_names,tr_labels] = loadClassificationOutput(train_file);
    [te_names,~] = loadClassificationOutput(test_file);

    %% extract features
    % NOTE(review): assumes analyze_files returns one feature row per file
    % (rows are used as instances below) -- confirm against analyze_files.
    disp('analyzing training files');
    train_features = analyze_files(tr_names,tmp_path);
    disp('analyzing test files');
    test_features = analyze_files(te_names, tmp_path);

    train_labels = get_class_indices(tr_labels);
    % Test labels are unknown; svmpredict still needs a label vector, so a
    % dummy one is supplied.
    test_labels = ones(length(te_names),1);

    %% normalize test data with the training statistics
    [train_z, mu, sigma] = zscore(train_features);
    % zscore reports sigma == 0 for constant columns; dividing by it would
    % put NaN/Inf into the test features, so those columns are only centred.
    sigma(sigma == 0) = 1;
    n_test = size(test_features,1);
    test_z = (test_features - repmat(mu,n_test,1)) ./ repmat(sigma,n_test,1);

    %% SVM grid search inspired by http://labrosa.ee.columbia.edu/projects/consumervideo/
    best_C = default_C;
    best_g = default_g;

    if grid_search
        disp('performing grid search');
        tuning_data_fraction = 0.9;
        n_train = size(train_z,1);
        train_size = round(tuning_data_fraction*n_train);

        % Random split of the training set into fitting and validation rows.
        perm = randperm(n_train);
        fit_rows = perm(1:train_size);
        val_rows = perm(train_size+1:end);

        trainX = train_z(fit_rows,:);
        validateX = train_z(val_rows,:);
        trainY = train_labels(fit_rows);
        validateY = train_labels(val_rows);

        gamma = 2.^(-10:1:10);
        C = 2.^(0:13);

        best_a = 0;
        for gi = 1:length(gamma)
            for Ci = 1:length(C)
                m = svmtrain(trainY', trainX, svm_options(C(Ci), gamma(gi)));
                % Three outputs are requested on purpose: some libsvm
                % releases only accept one or three output arguments.
                [~, a, ~] = svmpredict(validateY', validateX, m, '-q');
                % a(1) is the first element of svmpredict's accuracy output.
                if a(1) > best_a
                    best_C = C(Ci);
                    best_g = gamma(gi);
                    best_a = a(1);
                end
            end
        end
        disp('grid search done');
    end

    %% train on all training data and predict the test set
    model = svmtrain(train_labels', train_z, svm_options(best_C, best_g));
    [predict_indices, ~, ~] = svmpredict(test_labels, test_z, model,'-q');
    predict_label = classes(predict_indices);

    %% write "<name>\t<class>" lines
    outfd = fopen(output_file,'w+');
    if outfd == -1
        error('classify_scenes:cannotOpenOutput', ...
              'Could not open output file "%s" for writing.', output_file);
    end

    for t = 1:length(te_names)
        fprintf(outfd,'%s\t',[char(te_names(t))]);
        fprintf(outfd,'%s\n',[char(predict_label(t))]);
    end
    fclose(outfd);

    function idx = get_class_indices(labels)
    % Map each class-name label to its position in the parent's 'classes'
    % list. Returns a row vector with one index per label.
        idx = zeros(1, length(labels));
        for n = 1:length(labels)
            match = find(strcmp(classes,labels(n)));
            if isempty(match)
                error('classify_scenes:unknownLabel', ...
                      'Unknown class label "%s".', char(labels(n)));
            end
            idx(n) = match;
        end
    end

end
|
gerard@1
|
107
|