Mercurial > hg > confint
comparison sampling/simulations.R @ 0:205974c9568c tip
Initial commit. Predictions not included for lack of space.
author | franrodalg <f.rodriguezalgarra@qmul.ac.uk> |
---|---|
date | Sat, 29 Jun 2019 18:45:50 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:205974c9568c |
---|---|
# Remove from `set` every excerpt that shares at least one artist with
# `ref_set`, so that the returned excerpts have no artist overlap with it.
#
# set:     data frame with an `ex_id` column (e.g. a test set).
# ref_set: data frame with an `ex_id` column (e.g. the matching train set).
#
# Returns the rows of `set` (same columns, same multiplicity, same order)
# whose excerpt is credited to at least one artist in
# get_excerpts_artists() and to none of the artists of `ref_set`.
#
# Rows are filtered with semi/anti joins instead of an inner join: an
# inner join repeats a row once per credited artist, and after dropping
# `artist_id` those repeats stayed in the result as duplicate excerpts,
# inflating per-class counts downstream.
# NOTE(review): the previous version also carried any extra columns of
# get_excerpts_artists() besides ex_id/artist_id into the result; confirm
# that table only holds those two columns.
filter_set <- function(set, ref_set){

  excerpt_artists <- get_excerpts_artists()
  ref_artists <- get_artists(ref_set$ex_id, unique_artists = TRUE)

  # Excerpts credited to any artist that also appears in the reference set.
  leaking <- excerpt_artists %>%
    filter(artist_id %in% ref_artists) %>%
    select(ex_id)

  set %>%
    # Keep only excerpts with an artist record (as the inner join did) ...
    semi_join(excerpt_artists, by = "ex_id") %>%
    # ... and drop those sharing an artist with the reference set.
    anti_join(leaking, by = "ex_id")

}
18 | |
# Check, per class, whether a (filtered) set holds at least `minimum`
# excerpts.
#
# set:     data frame with an `ex_id` column, or NULL / empty.
# minimum: required number of excerpts per class.
#
# Returns a logical vector named by class. Every class listed by
# get_class_names() is present, so a class with zero excerpts in `set`
# is reported as FALSE instead of being silently left out (table() alone
# only lists classes that actually occur). Names are sorted as table()
# sorts them. An empty or NULL set yields all FALSE.
check_filt <- function(set, minimum = 10){

  class_names <- get_class_names()$class

  # Same emptiness test as plyr::empty(), with NULL checked first.
  if (is.null(set) || NROW(set) == 0 || NCOL(set) == 0) {
    res <- logical(length(class_names))
    names(res) <- class_names
    return(res)
  }

  found <- get_classes(set$ex_id, unique_classes = FALSE, use_names = TRUE)

  # Union with the observed values so no observed class is dropped either;
  # sort() also removes a possible NA, which table() ignores anyway.
  all_levels <- sort(union(as.character(class_names), as.character(found)))

  table(factor(found, levels = all_levels)) >= minimum

}
32 | |
# Entry point for drawing train/test/filtered sets with a given strategy.
#
# min_per_class:  minimum excerpts per class required in the filtered set.
# selection_type: one of "rand_surv", "class_surv", "art_class_surv".
# max_draws:      draw budget forwarded to survival_selection().
#
# Returns whatever survival_selection() returns. An unsupported
# selection_type is an error, rather than a silent NULL.
select_samples <- function(min_per_class = 10, selection_type = 'rand_surv',
                           max_draws = 1e6){

  survival_types <- c("rand_surv", "class_surv", "art_class_surv")

  if (length(selection_type) == 1 && selection_type %in% survival_types) {
    return(survival_selection(min_per_class, selection_type, max_draws))
  }

  stop("Unknown selection_type: ", paste(selection_type, collapse = ", "),
       call. = FALSE)

}
43 | |
# Replace the rows of one class in `original_set` with the rows of
# `class_set`.
#
# The class to replace is the first class returned by get_classes() for
# `class_set`. Rows of `original_set` whose class is different are kept,
# `class_set` is appended, and the result is ordered by ex_id.
.update_set <- function(original_set, class_set){

  class <- get_classes(class_set$ex_id,
                       unique_classes = TRUE)[1]
  classes <- get_classes(original_set$ex_id,
                         unique_classes = FALSE)

  # `%in%` never returns NA (unlike `!=`), so an NA class does not turn rows
  # into all-NA rows. drop = FALSE keeps a one-column data frame a data frame.
  new_set <- original_set[!(classes %in% class), , drop = FALSE]

  rbind(new_set, class_set) %>% arrange(ex_id)

}
57 | |
# Apply .update_set() to each of the train, test and filt partitions,
# swapping in the redrawn class rows from `class_sets`.
# Any other elements of `original_sets` are returned untouched.
.update_sets <- function(original_sets, class_sets){

  for (part in c("train", "test", "filt")) {
    original_sets[[part]] <- .update_set(original_sets[[part]],
                                         class_sets[[part]])
  }

  original_sets

}
67 | |
# Draw one stratified sample via get_bs_samples() and add a `filt`
# element: the test set with excerpts sharing artists with the train
# set removed (see filter_set()).
#
# class:           NULL for all classes, or the class(es) to redraw.
# hold_out_ex_ids: excerpt ids forwarded to get_bs_samples().
#
# The first argument of get_bs_samples() is 1000 for a full draw and 100
# for a single-class redraw; survival_selection() adds the same amounts
# to its draw counter.
.get_sets <- function(class = NULL, hold_out_ex_ids = NULL){

  draw_size <- if (is.null(class)) 1000 else 100

  drawn <- get_bs_samples(draw_size,
                          num_iter = 1,
                          stratified = TRUE,
                          keep_prop = FALSE,
                          classes = class,
                          hold_out_ex_ids = hold_out_ex_ids)
  drawn$filt <- filter_set(drawn$test, drawn$train)

  drawn

}
84 | |
# Redraw the sets for a single class.
#
# class:          the class to redraw.
# selection_type: "class_surv" redraws the class with .get_sets() alone.
#                 "art_class_surv" first picks whole artists at random and
#                 passes all of their excerpts of this class to .get_sets()
#                 as `hold_out_ex_ids`.
# min_per_class:  for "art_class_surv", artists are added until at least
#                 this many distinct excerpts are held out.
#
# Returns the list produced by .get_sets(). An unsupported selection_type
# is an error (it used to fall through and return NULL).
.get_class_sets <- function(class = NULL,
                            selection_type = 'class_surv',
                            min_per_class = 0){

  if (selection_type == "class_surv") {
    return(.get_sets(class = class))
  }
  if (selection_type != "art_class_surv") {
    stop("Unsupported selection_type for class redraw: ", selection_type,
         call. = FALSE)
  }

  ex_art <- get_excerpts_classes(class = class) %>%
    inner_join(get_excerpts_artists(), by = c("ex_id"))

  # Random artist order. Taking artists in this order is equivalent to
  # repeatedly drawing not-yet-chosen artists. sample.int() avoids sample()
  # drawing from 1:x when only one numeric artist id is left.
  artists <- unique(ex_art$artist_id)
  artists <- artists[sample.int(length(artists))]

  n_held <- 0L
  hold_out_ex_ids <- ex_art$ex_id[0]
  # Stop once enough distinct excerpts are held out or no artists remain
  # (the old loop spun forever when the class was too small).
  while (length(hold_out_ex_ids) < min_per_class && n_held < length(artists)) {
    n_held <- n_held + 1L
    held_rows <- ex_art$artist_id %in% artists[seq_len(n_held)]
    hold_out_ex_ids <- unique(ex_art$ex_id[held_rows])
  }

  if (length(hold_out_ex_ids) < min_per_class) {
    warning("Class ", class, " has only ", length(hold_out_ex_ids),
            " excerpts to hold out; fewer than min_per_class = ",
            min_per_class, call. = FALSE)
  }

  .get_sets(class, hold_out_ex_ids)

}
120 | |
# Result record for a selection that gave up after exceeding the draw
# budget: no sets, success = FALSE, and the number of draws consumed.
.fail_selection <- function(num_draws){
  structure(list(NULL, FALSE, num_draws),
            names = c("sets", "success", "draws"))
}
126 | |
# Draw train/test/filt sets until every class has at least `min_per_class`
# excerpts in the filtered set, or the draw budget runs out.
#
# selection_type:
#   "rand_surv"      - redraw all sets until the check passes.
#   "class_surv"     - redraw only the failing classes (.get_class_sets()).
#   "art_class_surv" - as class_surv, but holding out whole artists.
# max_draws: budget. The counter starts at nrow(sets$train) and grows by
#   1000 per full redraw and 100 per class redraw. It is checked before
#   every redraw, including each attempt inside a class redraw, which
#   previously looped without any budget check.
#
# Returns list(sets, success = TRUE, draws) on success, or the
# .fail_selection() record when the budget is exceeded. An unknown
# selection_type is an error (it previously returned success = TRUE with
# unchecked sets).
survival_selection <- function(min_per_class = 10,
                               selection_type = 'rand_surv',
                               max_draws = 1e6){

  supported <- c("rand_surv", "class_surv", "art_class_surv")
  if (length(selection_type) != 1 || !(selection_type %in% supported)) {
    stop("Unknown selection_type: ", paste(selection_type, collapse = ", "),
         call. = FALSE)
  }

  # 1.- Draw initial sets
  sets <- .get_sets()
  num_draws <- nrow(sets$train)

  # 2.- Check if initial test set meets requirements
  checked <- check_filt(sets$filt, minimum = min_per_class)

  if (selection_type == "rand_surv") {
    # 3.1.- Random Survival: Complete redraw
    while (!all(checked)) {
      if (num_draws > max_draws) return(.fail_selection(num_draws))
      sets <- .get_sets()
      num_draws <- num_draws + 1000
      checked <- check_filt(sets$filt, minimum = min_per_class)
    }
  } else {
    # 3.2.- Class Survival: Redraw only problematic classes
    # 3.3.- Artist-informed Class Survival
    for (fail_class in names(checked[!checked])) {
      print(paste0("Redrawing class: ", fail_class))
      repeat {
        if (num_draws > max_draws) return(.fail_selection(num_draws))
        print(paste0(" Attempting: ", fail_class))
        class_sets <- .get_class_sets(class = fail_class,
                                      selection_type = selection_type,
                                      min_per_class = min_per_class)
        num_draws <- num_draws + 100
        # NROW() treats a missing filt as 0 rows instead of erroring.
        if (NROW(class_sets$filt) >= min_per_class) break
      }
      sets <- .update_sets(sets, class_sets)
    }
  }

  list(sets = sets, success = TRUE, draws = num_draws)

}
173 | |
174 | |
# Run survival_selection() repeatedly until `num_iters` successful draws
# have been collected.
#
# num_iters:      number of successful iterations to collect.
# min_per_class, selection_type: forwarded to survival_selection().
# max_draws:      draw budget per attempt, forwarded to survival_selection()
#                 (the default matches that function's default).
# max_attempts:   cap on total attempts, successful or not. The default Inf
#                 keeps the old "retry forever" behaviour. When the cap is
#                 hit, a warning is raised and the partial results are
#                 returned.
#
# Returns list(sets, sim_results):
#   sets        - list(train, test, filt), each the row-bind of the
#                 successful iterations' sets tagged with an `iter` column
#                 (an empty data.frame() if there were none).
#   sim_results - one row per attempt with `success` and `num_draws`.
draw_samples <- function(num_iters = 40,
                         min_per_class = 10,
                         selection_type = "art_class_surv",
                         max_draws = 1e6,
                         max_attempts = Inf){

  # Per-partition lists of iteration chunks, bound once at the end instead
  # of growing data frames with rbind() inside the loop.
  kept <- list(train = list(), test = list(), filt = list())
  success_log <- logical(0)
  draws_log <- numeric(0)

  i <- 1
  attempts <- 0
  while (i <= num_iters) {

    if (attempts >= max_attempts) {
      warning("draw_samples: stopped after ", attempts, " attempts with ",
              i - 1, " of ", num_iters, " successful iterations",
              call. = FALSE)
      break
    }
    attempts <- attempts + 1

    print("Starting new attempt")

    res <- survival_selection(
      min_per_class = min_per_class,
      selection_type = selection_type,
      max_draws = max_draws
    )

    print(
      paste0(
        "Finishing new attempt: ",
        if (res$success) "Success " else "Failure ",
        "after ", res$draws, " draws"))
    print("")

    success_log <- c(success_log, res$success)
    draws_log <- c(draws_log, res$draws)

    if (res$success) {
      for (part in names(kept)) {
        kept[[part]][[i]] <- res$sets[[part]] %>% mutate(iter = i)
      }
      i <- i + 1
    }

  }

  sets <- lapply(kept, function(chunks) {
    if (length(chunks) == 0) data.frame() else do.call(rbind, chunks)
  })
  sim_results <- data.frame(success = success_log, num_draws = draws_log)

  list(sets = sets, sim_results = sim_results)

}