annotate classify_scenes.m @ 2:def2b3fa1450 tip master

corrected README
author Gerard Roma <gerard.roma@upf.edu>
date Mon, 04 Nov 2013 10:46:05 +0000
parents 96b1b8697b60
children
rev   line source
gerard@1 1 % Copyright 2013 MUSIC TECHNOLOGY GROUP, UNIVERSITAT POMPEU FABRA
gerard@1 2 %
gerard@1 3 % Written by Gerard Roma <gerard.roma@upf.edu>
gerard@1 4 %
gerard@1 5 % This program is free software: you can redistribute it and/or modify
gerard@1 6 % it under the terms of the GNU Affero General Public License as published by
gerard@1 7 % the Free Software Foundation, either version 3 of the License, or
gerard@1 8 % (at your option) any later version.
gerard@1 9 %
gerard@1 10 % This program is distributed in the hope that it will be useful,
gerard@1 11 % but WITHOUT ANY WARRANTY; without even the implied warranty of
gerard@1 12 % MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
gerard@1 13 % GNU Affero General Public License for more details.
gerard@1 14 %
gerard@1 15 % You should have received a copy of the GNU Affero General Public License
gerard@1 16 % along with this program. If not, see <http://www.gnu.org/licenses/>.
gerard@1 17
function [ predict_label ] = classify_scenes(tmp_path, train_file,test_file, output_file, grid_search)
%CLASSIFY_SCENES Train a libsvm classifier on scene features and label test files.
%
%   predict_label = classify_scenes(tmp_path, train_file, test_file, ...
%                                   output_file, grid_search)
%
%   Inputs:
%     tmp_path    - passed unchanged to analyze_files (presumably a scratch
%                   directory for feature extraction; confirm in analyze_files).
%     train_file  - list file read by loadClassificationOutput; supplies the
%                   training file names and their class labels.
%     test_file   - list file read by loadClassificationOutput; only the file
%                   names are used, any labels in it are ignored.
%     output_file - path of the text file to write. One line per test file:
%                   <name> TAB <predicted class> NEWLINE.
%     grid_search - (optional, default false) when true, C and gamma are
%                   selected on a random 90/10 split of the training data;
%                   otherwise C = 70 and gamma = 0.003 are used.
%
%   Output:
%     predict_label - cell array with the predicted class name of each test
%                     file, in the order of the names read from test_file.
%
%   Side effects: adds 'rastamat/' and 'libsvm/' to the MATLAB path and
%   (over)writes output_file.

    if nargin < 5
        grid_search = false;
    end

    addpath('rastamat/');
    addpath('libsvm/')

    % Shared with the nested helper get_class_indices below.
    classes = {'bus' 'busystreet' 'office' 'openairmarket' 'park' 'quietstreet' 'restaurant' 'supermarket' 'tube' 'tubestation'};

    [tr_names,tr_labels] = loadClassificationOutput(train_file);
    [te_names,te_labels] = loadClassificationOutput(test_file); %#ok<NASGU> labels of the test list are not used

    %% extract features
    disp('analyzing training files');
    train_features = analyze_files(tr_names,tmp_path);
    disp('analyzing test files');
    test_features = analyze_files(te_names, tmp_path);

    train_labels = get_class_indices(tr_labels);      % column vector
    % svmpredict requires a label vector; the values are placeholders.
    test_labels = ones(length(te_names),1);

    %% standardise with the training statistics
    [train_z, mu, sigma] = zscore(train_features);
    % zscore returns the raw sigma; a constant training feature has sigma == 0
    % and would turn the test matrix into NaN/Inf. Treat it as unit scale.
    sigma(sigma == 0) = 1;
    test_z = bsxfun(@rdivide, bsxfun(@minus, test_features, mu), sigma);

    %% SVM grid search inspired by http://labrosa.ee.columbia.edu/projects/consumervideo/
    % Defaults, used when no grid search is requested or possible.
    best_C = 70;
    best_g = 0.003;
    % %.10g keeps small gamma values (e.g. 2^-10) from being truncated.
    svm_opts = @(c, g) sprintf('-c %g -g %.10g -q', c, g);

    if grid_search
        n_train = size(train_z,1);
        tuning_data_fraction = 0.9;
        train_size = round(tuning_data_fraction*n_train);
        if train_size < 1 || train_size >= n_train
            warning('classify_scenes:gridSearchSkipped', ...
                'Too few training examples (%d) for a validation split; using default C and gamma.', n_train);
        else
            disp('performing grid search');
            perm = randperm(n_train);
            fit_rows = perm(1:train_size);
            val_rows = perm(train_size+1:end);
            trainX = train_z(fit_rows,:);
            validateX = train_z(val_rows,:);
            trainY = train_labels(fit_rows);
            validateY = train_labels(val_rows);

            gamma = 2.^(-10:1:10);
            C = 2.^(0:13);

            % -Inf so that the first candidate is always accepted, even when
            % every candidate reaches 0% validation accuracy.
            best_a = -Inf;
            for gi = 1:length(gamma)
                for Ci = 1:length(C)
                    m = svmtrain(trainY, trainX, svm_opts(C(Ci), gamma(gi)));
                    % Three outputs are requested on purpose: some libsvm
                    % builds reject a two-output call to svmpredict.
                    [val_pred, val_acc, val_dec] = svmpredict(validateY, validateX, m, '-q'); %#ok<ASGLU>
                    if val_acc(1) > best_a
                        best_C = C(Ci);
                        best_g = gamma(gi);
                        best_a = val_acc(1);
                    end
                end
            end
            disp('grid search done');
        end
    end

    %% final model on the full training set
    model = svmtrain(train_labels, train_z, svm_opts(best_C, best_g));
    [predict_indices, accuracy_obj, prob_values] = svmpredict(test_labels, test_z, model,'-q'); %#ok<ASGLU>
    predict_label = classes(predict_indices);

    %% write "<name>\t<label>" lines
    outfd = fopen(output_file,'w+');
    if outfd < 0
        error('classify_scenes:cannotOpenOutput', ...
            'Could not open output file "%s" for writing.', output_file);
    end
    % Close the file even if a write below throws.
    close_outfd = onCleanup(@() fclose(outfd)); %#ok<NASGU>

    for row = 1:length(te_names)
        fprintf(outfd,'%s\t%s\n', char(te_names(row)), char(predict_label(row)));
    end

    function idx = get_class_indices(labels)
        % Map each label to its position in the enclosing CLASSES list.
        % Returns a column vector; raises an error on a label that is not
        % one of the known classes.
        idx = zeros(length(labels),1);
        for k = 1:length(labels)
            match = find(strcmp(classes,labels(k)), 1);
            if isempty(match)
                error('classify_scenes:unknownLabel', ...
                    'Unknown class label "%s" at position %d.', char(labels(k)), k);
            end
            idx(k) = match;
        end
    end

end
gerard@1 107