Mercurial > hg > confint
comparison sampling/strategies.R @ 0:205974c9568c tip
Initial commit. Predictions not included for lack of space.
author | franrodalg <f.rodriguezalgarra@qmul.ac.uk> |
---|---|
date | Sat, 29 Jun 2019 18:45:50 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:205974c9568c |
---|---|
## Generic sampling function.
##
## Dispatches on `mode`: 'cv' draws cross-validation folds through
## get_cv_samples(); any other value draws bootstrap replicates through
## get_bs_samples().
##
## Strategy-specific arguments are ignored by the other strategy:
## `num_folds` is only forwarded for 'cv', while `N` and `keep_prop` are only
## forwarded for bootstrap. Returns whatever the selected sampler returns.
get_samples <- function(num_iter = 1, mode = 'cv',
                        N = 0, num_folds = 4,
                        stratified = F, keep_prop = F,
                        db = "../db/gtzan.db", classes = NULL){

  use_bootstrap <- mode != 'cv'
  if (use_bootstrap) {
    return(get_bs_samples(N = N, num_iter = num_iter,
                          stratified = stratified, keep_prop = keep_prop,
                          db = db, classes = classes))
  }

  get_cv_samples(num_iter = num_iter, num_folds = num_folds,
                 stratified = stratified, db = db, classes = classes)
}
17 | |
# CROSS-VALIDATION

## Cross-Validation sampling.
##
## Repeats fold generation `num_iter` times (one get_folds() call per
## iteration) and stacks all splits into two long data frames.
##
## Returns a list with `train` and `test` data frames, one row per
## (excerpt, iteration, fold) with columns ex_id, iter, fold, mode ('cv') and
## stratified. With num_iter = 0 both data frames are empty.
get_cv_samples <- function(num_iter = 1, num_folds = 4, stratified = FALSE,
                           db = "../db/gtzan.db", classes = NULL) {

  # Rows for one (iteration, fold) split. Scalar columns are expanded with
  # rep_len() so an empty id vector gives a 0-row frame instead of a
  # "differing number of rows" error.
  make_rows <- function(ex_id, iter, fold) {
    n <- length(ex_id)
    data.frame(ex_id = ex_id,
               iter = rep_len(iter, n),
               fold = rep_len(fold, n),
               mode = rep_len('cv', n),
               stratified = rep_len(stratified, n))
  }

  # Collect pieces in preallocated lists and bind once at the end instead of
  # calling rbind() inside the loops (which copies everything each time).
  train_parts <- vector("list", num_iter * num_folds)
  test_parts <- vector("list", num_iter * num_folds)
  k <- 0L

  # seq_len() rather than 1:n so a zero count performs no iterations.
  for (i in seq_len(num_iter)) {
    folds <- get_folds(num_folds = num_folds, stratified = stratified,
                       db = db, classes = classes)
    for (j in seq_len(num_folds)) {
      k <- k + 1L
      train_parts[[k]] <- make_rows(folds[[j]]$train, i, j)
      test_parts[[k]] <- make_rows(folds[[j]]$test, i, j)
    }
  }

  bind_parts <- function(parts) {
    if (length(parts) == 0) data.frame() else do.call(rbind, parts)
  }

  list(train = bind_parts(train_parts), test = bind_parts(test_parts))
}
51 | |
52 | |
# Random assignment into folds.
#
# Returns a numeric vector of length `N` holding a random permutation of fold
# labels 1..num_folds, with fold sizes differing by at most one. Every one of
# the N positions receives a label (the previous rep(x, N/num_folds) version
# truncated non-integer counts and returned fewer than N labels).
.get_assignment <- function(N, num_folds){

  if (num_folds < 1) {
    stop("num_folds must be at least 1", call. = FALSE)
  }

  base_size <- N %/% num_folds
  remainder <- N %% num_folds

  # Grouped labels 1,1,...,2,2,...: when N is divisible by num_folds this is
  # exactly the vector the original code permuted, so a given seed still
  # produces the same assignment.
  labels <- rep(seq_len(num_folds), each = base_size)
  if (remainder > 0) {
    # Give the leftover excerpts to randomly chosen, distinct folds so no
    # fold is systematically larger than the others.
    labels <- sort(c(labels, sample.int(num_folds, remainder)))
  }
  labels <- as.numeric(labels)

  # Index-based permutation (what sample(x) does for length(x) > 1) avoids
  # sample()'s special case for a single numeric value.
  labels[sample.int(length(labels))]
}
62 | |
# Obtain folds for Cross-Validation.
#
# Reads the excerpt/class table with get_excerpts_classes(), orders it by
# class_id and ex_id_class, and assigns every row to one of `num_folds`
# folds with .get_assignment(). With `stratified = TRUE` the assignment is
# drawn separately within each class.
#
# Returns a list of length `num_folds`; element k is
# list(test = ex_ids assigned to fold k, train = all other ex_ids).
get_folds <- function(num_folds = 4, stratified = FALSE,
                      db = "../db/gtzan.db", classes = NULL){

  res <- get_excerpts_classes(db = db, classes) %>%
    arrange(class_id, ex_id_class)

  if (stratified) {
    # Rows are sorted by class_id, so concatenating per-class assignments in
    # unique(class_id) order lines them up with the rows of `res`.
    assignment <- unlist(lapply(unique(res$class_id), function(cls) {
      n_class <- filter(res, class_id == cls) %>%
        nrow()
      .get_assignment(n_class, num_folds)
    }))
  } else {
    assignment <- .get_assignment(nrow(res), num_folds)
  }

  # A length mismatch would silently leave excerpts out of every fold or, in
  # stratified mode, shift later classes onto the wrong rows.
  if (length(assignment) != nrow(res)) {
    stop("fold assignment covers ", length(assignment), " of ", nrow(res),
         " excerpts", call. = FALSE)
  }

  ex_ids <- res$ex_id

  # Build each fold as a complete list in one step: assigning
  # folds[[i]][['test']] into an empty (NULL) slot produces an atomic vector
  # when the test set has a single excerpt, after which the 'train'
  # assignment and `$` access fail.
  lapply(seq_len(num_folds), function(k) {
    in_fold <- assignment == k
    list(test = ex_ids[which(in_fold)],
         train = ex_ids[which(!in_fold)])
  })
}
96 | |
97 | |
# BOOTSTRAP

## Bootstrap sampling.
##
## Draws `num_iter` bootstrap train/test pairs with get_bs_sample(), reading
## the excerpt table from `db` once and reusing it for every draw. `N` (the
## training size for non-stratified draws) falls back to the number of
## excerpts when it is 0 or NULL.
##
## Returns a list with `train` and `test` data frames, one row per drawn
## excerpt with columns ex_id, iter, mode ('bs'), stratified and keep_prop.
## With num_iter = 0 both data frames are empty.
get_bs_samples <- function(N, num_iter, stratified = FALSE, keep_prop = FALSE,
                           db = "../db/gtzan.db", classes = NULL,
                           hold_out_ex_ids = NULL){

  res <- get_excerpts_classes(db, classes)

  # is.null() must short-circuit first: `NULL == 0` is logical(0), which
  # made the former vectorised `|` condition an error instead of TRUE.
  if (is.null(N) || N == 0) {
    N <- nrow(res)
  }

  # Rows for one replicate; rep_len() lets an empty id vector (e.g. every
  # excerpt drawn into training) produce a 0-row frame instead of an error.
  make_rows <- function(ex_id, iter) {
    n <- length(ex_id)
    data.frame(ex_id = ex_id,
               iter = rep_len(iter, n),
               mode = rep_len('bs', n),
               stratified = rep_len(stratified, n),
               keep_prop = rep_len(keep_prop, n))
  }

  # Preallocate and bind once rather than growing with rbind() in the loop.
  train_parts <- vector("list", num_iter)
  test_parts <- vector("list", num_iter)

  for (i in seq_len(num_iter)) {
    train_test <- get_bs_sample(N = N, stratified = stratified,
                                keep_prop = keep_prop, db = db,
                                classes = classes, df = res,
                                hold_out_ex_ids = hold_out_ex_ids)
    train_parts[[i]] <- make_rows(train_test$train, i)
    test_parts[[i]] <- make_rows(train_test$test, i)
  }

  bind_parts <- function(parts) {
    if (length(parts) == 0) data.frame() else do.call(rbind, parts)
  }

  list(train = bind_parts(train_parts), test = bind_parts(test_parts))
}
134 | |
## Correct proportions for test set stratification.
##
## Subsamples `samples` (without replacement) so that its class proportions
## match those of `reference_samples`. Target per-class counts are
## round(ref_prop * m), where m = min(count / ref_prop) is the largest scale
## at which every class still has enough samples. Samples whose class does
## not occur in `reference_samples` are dropped.
##
## NOTE(review): assumes get_classes(x, unique_classes = FALSE) returns one
## class label per element of `x`, in the same order - confirm.
##
## Returns the kept samples, grouped by class.
.correct_prop <- function(samples, reference_samples,
                          db = "../db/gtzan.db"){

  ref <- as.factor(
    get_classes(reference_samples, unique_classes = FALSE, db = db))
  sample_classes <- factor(
    get_classes(samples, unique_classes = FALSE, db = db),
    levels = levels(ref))

  ref_counts <- table(ref)
  ref_prop <- ref_counts / sum(ref_counts)

  sample_counts <- table(sample_classes)
  post_counts <- round(ref_prop * min(sample_counts / ref_prop))

  post_samples <- numeric(0)
  for (lvl in levels(sample_classes)) {
    pool <- samples[which(sample_classes == lvl)]
    # Index-based draw: sample(pool, ...) on a single numeric id would draw
    # from 1:id instead of returning that id.
    picked <- pool[sample.int(length(pool), size = post_counts[[lvl]])]
    post_samples <- c(post_samples, picked)
  }

  post_samples
}
162 | |
## Bootstrap single train/test pair.
##
## Draws a training set with replacement from the excerpts in `df` (read
## with get_excerpts_classes() when `df` is NULL) after removing
## `hold_out_ex_ids`. The test set is every remaining excerpt not drawn into
## training plus the held-out ids, sorted. With `keep_prop = TRUE` the test
## set is subsampled by .correct_prop() to the training class proportions.
##
## `N` is the training size only when `stratified = FALSE`; when stratified,
## each class is resampled to its count in `df` before hold-out removal.
##
## Returns list(train = drawn ex_ids (with repeats), test = ex_ids).
get_bs_sample <- function(N, stratified = FALSE, keep_prop = FALSE,
                          db = "../db/gtzan.db", classes = NULL,
                          df = NULL,
                          hold_out_ex_ids = NULL
                          ){

  if (is.null(df)) {
    df <- get_excerpts_classes(db, classes)
  }

  # Class sizes are counted before hold-out removal, so stratified draws
  # keep the full per-class counts.
  n_ex_class <- table(df$class_id)

  if (!is.null(hold_out_ex_ids)) {
    df <- df %>% filter(!(ex_id %in% hold_out_ex_ids))
  }

  # Draw with replacement by index: sample(pool, ...) on a single numeric id
  # would draw from 1:id instead of repeating that id.
  resample <- function(pool, size) {
    pool[sample.int(length(pool), size = size, replace = TRUE)]
  }

  if (stratified) {
    train <- numeric(0)
    for (j in unique(df$class_id)) {
      ex_class <- df$ex_id[which(df$class_id == j)]
      train <- c(train,
                 resample(ex_class,
                          n_ex_class[which(names(n_ex_class) == j)]))
    }
  } else {
    train <- resample(df$ex_id, N)
  }

  test <- sort(c(df$ex_id[!(df$ex_id %in% train)], hold_out_ex_ids))
  if (keep_prop) {
    # Forward `db` so class lookups use the same database as this sample.
    test <- .correct_prop(test, reference_samples = train, db = db)
  }

  list(train = train, test = test)
}