comparison sampling/simulations.R @ 0:205974c9568c tip

Initial commit. Predictions not included for lack of space.
author franrodalg <f.rodriguezalgarra@qmul.ac.uk>
date Sat, 29 Jun 2019 18:45:50 +0100
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:205974c9568c
# Drop from `set` every excerpt that shares an artist with `ref_set`.
#
# Args:
#   set:     data frame with an `ex_id` column (the candidate excerpts).
#   ref_set: data frame with an `ex_id` column (the reference excerpts).
#
# Returns: `set` joined with the excerpt/artist table, minus the excerpts
#   that have at least one artist in common with `ref_set`, and without the
#   `artist_id` column.
#
# NOTE(review): the result is built from the joined table, so an excerpt with
# several artist rows presumably appears several times in the output --
# confirm against get_excerpts_artists().
filter_set <- function(set, ref_set) {

  # One row per (excerpt, artist) pair for the candidate excerpts.
  with_artists <- inner_join(set, get_excerpts_artists(), by = c("ex_id"))

  ref_artists <- get_artists(ref_set$ex_id, unique_artists = TRUE)

  # An excerpt is rejected as soon as any of its artists is a reference artist.
  shares_artist <- with_artists$artist_id %in% ref_artists
  rejected_ids <- unique(with_artists$ex_id[shares_artist])

  with_artists %>%
    filter(!ex_id %in% rejected_ids) %>%
    select(-artist_id)

}
18
# Report, for every class, whether `set` holds at least `minimum` excerpts.
#
# Args:
#   set:     data frame with an `ex_id` column; may be NULL or empty.
#   minimum: smallest acceptable number of excerpts per class.
#
# Returns: a named logical, one entry per class. For a NULL/empty `set` it is
#   an all-FALSE vector named after get_class_names()$class; otherwise it is
#   a logical table of `count >= minimum`.
check_filt <- function(set, minimum = 10) {

  classes <- get_class_names()

  # is.null() goes first so a NULL never reaches plyr::empty(); `||` is the
  # scalar, short-circuiting operator an `if` condition calls for.
  if (is.null(set) || plyr::empty(set)) {
    res <- logical(nrow(classes))
    names(res) <- classes$class
    return(res)
  }

  found <- as.character(
    get_classes(set$ex_id, unique_classes = FALSE, use_names = TRUE))

  # table() on the raw labels only lists the classes that are present, so a
  # class that lost all of its excerpts would silently pass the check.
  # Counting over explicit levels makes such a class show up with count 0.
  # NOTE(review): assumes the labels returned with use_names = TRUE are the
  # same names as get_class_names()$class -- confirm. The union keeps any
  # label that is not in that list, so nothing that was counted before is
  # dropped.
  all_labels <- sort(union(as.character(classes$class), found))

  table(factor(found, levels = all_labels)) >= minimum

}
32
# Entry point that dispatches to survival_selection() for the supported
# selection types.
#
# Args:
#   min_per_class:  minimum number of excerpts required per class.
#   selection_type: 'rand_surv', 'class_surv' or 'art_class_surv'.
#   max_draws:      draw budget handed to survival_selection().
#
# Returns: the result of survival_selection(), or invisible NULL when
#   `selection_type` is none of the three supported values.
select_samples <- function(min_per_class = 10, selection_type = 'rand_surv',
                           max_draws = 1e6) {

  is_survival <- selection_type == 'rand_surv' ||
    selection_type == 'class_surv' ||
    selection_type == 'art_class_surv'

  # Nothing to do for any other selection type.
  if (!is_survival) {
    return(invisible(NULL))
  }

  survival_selection(min_per_class, selection_type, max_draws)

}
43
44 #
# Replace the rows of one class in `original_set` with the rows of
# `class_set`.
#
# Args:
#   original_set: data frame with an `ex_id` column.
#   class_set:    data frame with an `ex_id` column; its first class (as
#                 reported by get_classes()) is the class being replaced.
#
# Returns: the rows of `original_set` whose class differs from that class,
#   stacked with `class_set` and sorted by `ex_id`.
.update_set <- function(original_set, class_set) {

  # Class carried by the replacement rows (only the first one is used).
  redrawn_class <- get_classes(class_set$ex_id, unique_classes = TRUE)[1]

  # Class of every row of the set being updated.
  row_classes <- get_classes(original_set$ex_id, unique_classes = FALSE)

  kept_rows <- original_set[row_classes != redrawn_class, ]

  arrange(rbind(kept_rows, class_set), ex_id)

}
57
58 #
# Apply .update_set() to the train, test and filt members of a set list.
#
# Args:
#   original_sets: list with `train`, `test` and `filt` data frames.
#   class_sets:    list with the same members, holding the replacement rows.
#
# Returns: `original_sets` with its three members updated.
.update_sets <- function(original_sets, class_sets) {

  new_train <- .update_set(original_sets$train, class_sets$train)
  new_test <- .update_set(original_sets$test, class_sets$test)
  new_filt <- .update_set(original_sets$filt, class_sets$filt)

  original_sets$train <- new_train
  original_sets$test <- new_test
  original_sets$filt <- new_filt

  original_sets

}
67
68 #
69
# Draw one sample via get_bs_samples() and attach its filtered test set.
#
# Args:
#   class:           class to restrict the draw to, or NULL for all classes.
#   hold_out_ex_ids: excerpt ids forwarded to get_bs_samples(), or NULL.
#
# Returns: whatever get_bs_samples() returns, with an extra `filt` member
#   holding filter_set(test, train).
.get_sets <- function(class = NULL, hold_out_ex_ids = NULL) {

  # 1000 is requested for an unrestricted draw, 100 for a single class.
  N <- if (is.null(class)) 1000 else 100

  drawn <- get_bs_samples(
    N,
    num_iter = 1,
    stratified = TRUE,
    keep_prop = FALSE,
    classes = class,
    hold_out_ex_ids = hold_out_ex_ids
  )

  drawn$filt <- filter_set(drawn$test, drawn$train)

  drawn

}
84
# Draw the sets for a single class.
#
# Args:
#   class:          class to draw for.
#   selection_type: 'class_surv' draws directly with .get_sets();
#                   'art_class_surv' first picks random artists of the class
#                   until their rows add up to at least `min_per_class`, and
#                   passes the excerpts of those artists to .get_sets() as
#                   `hold_out_ex_ids`.
#   min_per_class:  number of excerpt rows the picked artists must cover.
#
# Returns: the result of .get_sets(); invisible NULL for any other
#   `selection_type`.
#
# Raises: an error when the class does not have `min_per_class` rows in
#   total (the rejection loop this replaces never terminated in that case).
.get_class_sets <- function(class = NULL,
                            selection_type = 'class_surv',
                            min_per_class = 0) {

  if (selection_type == 'class_surv') {
    return(.get_sets(class = class))
  }

  if (selection_type != 'art_class_surv') {
    return(invisible(NULL))
  }

  ex_art <-
    get_excerpts_classes(class = class) %>%
    inner_join(get_excerpts_artists(), by = c('ex_id'))

  artist_ids <- unique(ex_art$artist_id)
  hold_out_artists <- artist_ids[0]

  if (min_per_class > 0) {

    # A uniformly random ordering of the artists is what drawing with
    # replacement and skipping repeats produces, but it always terminates.
    # Indexing through sample.int() also avoids sample(x, 1) drawing from
    # 1:x when `x` is a single number.
    shuffled <- artist_ids[sample.int(length(artist_ids))]

    # Rows of `ex_art` per artist, in shuffled order.
    rows_per_artist <- tabulate(match(ex_art$artist_id, shuffled),
                                nbins = length(shuffled))

    enough <- which(cumsum(rows_per_artist) >= min_per_class)

    if (length(enough) == 0) {
      stop("Cannot hold out ", min_per_class, " excerpts: only ",
           sum(rows_per_artist), " available for the requested class",
           call. = FALSE)
    }

    hold_out_artists <- shuffled[seq_len(enough[1])]

  }

  # Select the rows first and de-duplicate afterwards, so the row mask lines
  # up with the vector it indexes.
  hold_out_ex_ids <- unique(
    ex_art$ex_id[ex_art$artist_id %in% hold_out_artists])

  .get_sets(class, hold_out_ex_ids)

}
120
# Build the result returned when a selection runs out of draws.
#
# Args:
#   num_draws: number of draws spent so far.
#
# Returns: list(sets = NULL, success = FALSE, draws = num_draws).
.fail_selection <- function(num_draws) {

  failure <- list(sets = NULL, success = FALSE, draws = num_draws)
  failure

}
126
# Draw train/test/filt sets until the filtered set passes check_filt().
#
# Args:
#   min_per_class:  minimum number of excerpts per class in the filtered set.
#   selection_type: 'rand_surv' redraws everything until the check passes;
#                   'class_surv' and 'art_class_surv' redraw only the
#                   failing classes through .get_class_sets(). For any other
#                   value the initial draw is returned without redrawing.
#   max_draws:      draw budget; once exceeded, the selection gives up.
#
# Returns: list(sets, success, draws). On failure this is the value of
#   .fail_selection(), i.e. `sets` is NULL and `success` is FALSE.
survival_selection <- function(min_per_class = 10,
                               selection_type = 'rand_surv',
                               max_draws = 1e6) {

  # 1.- Draw initial sets
  sets <- .get_sets()
  num_draws <- nrow(sets$train)

  # 2.- Check if initial test set meets requirements
  checked <- check_filt(sets$filt, minimum = min_per_class)

  if (!all(checked)) {

    if (selection_type == 'rand_surv') {

      # 3.1.- Random Survival: Complete redraw
      while (!all(checked)) {
        if (num_draws > max_draws) return(.fail_selection(num_draws))
        sets <- .get_sets()
        # 1000 is the sample size .get_sets() asks for in a full draw.
        num_draws <- num_draws + 1000
        checked <- check_filt(sets$filt, minimum = min_per_class)
      }

    } else if (selection_type == 'class_surv' ||
               selection_type == 'art_class_surv') {

      # 3.2.- Class Survival: Redraw only problematic classes
      # 3.3.- Artist-informed Class Survival
      for (fail_class in names(checked[!checked])) {
        print(paste0("Redrawing class: ", fail_class))
        class_checked <- FALSE
        while (!class_checked) {
          # Enforce the budget on every attempt; checking it only once per
          # class left this retry loop unbounded.
          if (num_draws > max_draws) return(.fail_selection(num_draws))
          print(paste0(" Attempting: ", fail_class))
          class_sets <- .get_class_sets(class = fail_class,
                                        selection_type = selection_type,
                                        min_per_class = min_per_class)
          # 100 is the sample size .get_sets() asks for in a class draw.
          num_draws <- num_draws + 100
          class_checked <- nrow(class_sets$filt) >= min_per_class
        }
        sets <- .update_sets(sets, class_sets)
      }

    }

  }

  list(sets = sets, success = TRUE, draws = num_draws)

}
173
174
175 #
# Run survival_selection() until `num_iters` selections have succeeded and
# collect their sets.
#
# Args:
#   num_iters:      number of successful selections to collect.
#   min_per_class:  forwarded to survival_selection().
#   selection_type: forwarded to survival_selection().
#   max_draws:      draw budget of each attempt, forwarded to
#                   survival_selection(); the default matches that
#                   function's own default.
#
# Returns: list with
#   sets:        list(train, test, filt); each is the row-bound result of
#                the successful attempts, tagged with an `iter` column.
#   sim_results: data frame with one row per attempt (`success`,
#                `num_draws`), failed attempts included.
#
# Failed attempts are logged and retried, so the loop only ends once
# `num_iters` attempts have succeeded.
draw_samples <- function(num_iters = 40,
                         min_per_class = 10,
                         selection_type = "art_class_surv",
                         max_draws = 1e6) {

  # Per-iteration pieces are kept in lists and bound once at the end rather
  # than growing three data frames inside the loop.
  train_parts <- list()
  test_parts <- list()
  filt_parts <- list()

  sim_results <- data.frame(
    success = logical(0), num_draws = numeric(0))

  i <- 1
  while (i <= num_iters) {

    print("Starting new attempt")

    res <- survival_selection(
      min_per_class = min_per_class,
      selection_type = selection_type,
      max_draws = max_draws
    )

    outcome <- if (res$success) "Success " else "Failure "
    print(
      paste0(
        "Finishing new attempt: ",
        outcome,
        "after ", res$draws, " draws"))
    print("")

    # One small row per attempt; the number of attempts is not known ahead.
    sim_results <- rbind(sim_results,
                         list(success = res$success,
                              num_draws = res$draws))

    if (res$success) {
      train_parts[[i]] <- res$sets$train %>% mutate(iter = i)
      test_parts[[i]] <- res$sets$test %>% mutate(iter = i)
      filt_parts[[i]] <- res$sets$filt %>% mutate(iter = i)
      i <- i + 1
    }

  }

  # The leading empty data frame keeps the result a data frame even when no
  # iteration was requested.
  bind_parts <- function(parts) {
    do.call(rbind, c(list(data.frame()), parts))
  }

  sets <- list(train = bind_parts(train_parts),
               test = bind_parts(test_parts),
               filt = bind_parts(filt_parts))

  list(sets = sets, sim_results = sim_results)

}