annotate sampling/simulations.R @ 0:205974c9568c tip

Initial commit. Predictions not included for lack of space.
author franrodalg <f.rodriguezalgarra@qmul.ac.uk>
date Sat, 29 Jun 2019 18:45:50 +0100
parents
children
rev   line source
# Remove from `set` every excerpt that shares at least one artist with the
# excerpts of `ref_set`.
#
# Args:
#   set:     data frame of excerpts with an `ex_id` column (e.g. a test set).
#   ref_set: data frame with an `ex_id` column whose artists must not appear
#            in the result (e.g. the matching training set).
#
# Returns:
#   The rows of `set` (same columns, same multiplicity and order) whose
#   excerpt appears in get_excerpts_artists() and none of whose artists is
#   among the artists of `ref_set`.
filter_set <- function(set, ref_set){

  excerpts_artists <- get_excerpts_artists()

  # Artists present in the reference set.
  ref_artists <- get_artists(ref_set$ex_id, unique_artists = TRUE)

  # Excerpts of `set` credited to at least one reference artist.
  shared_ex_ids <- set %>%
    inner_join(excerpts_artists, by = c("ex_id")) %>%
    filter(artist_id %in% ref_artists) %>%
    pull(ex_id) %>%
    unique()

  # Filter the original rows rather than the joined table: the join has one
  # row per (excerpt, artist) pair, so dropping `artist_id` from it would
  # leave duplicated rows for multi-artist excerpts. semi_join keeps the
  # restriction to excerpts with known artists without adding columns.
  # NOTE(review): assumes get_excerpts_artists() contributes no column other
  # than ex_id/artist_id that callers rely on — confirm.
  set %>%
    semi_join(excerpts_artists, by = c("ex_id")) %>%
    filter(!ex_id %in% shared_ex_ids)
}
f@0 18
# Check whether a (filtered) set holds at least `minimum` excerpts of every
# class.
#
# Args:
#   set:     data frame with an `ex_id` column, or NULL.
#   minimum: minimum number of excerpts required per class.
#
# Returns:
#   A logical vector named by class, TRUE where the class reaches `minimum`.
#   Every class from get_class_names() is included, so a class with no
#   excerpts in `set` is reported as FALSE instead of being silently omitted.
#   An empty or NULL set yields FALSE for every class.
check_filt <- function(set, minimum = 10){

  class_names <- get_class_names()$class

  # Test NULL first with short-circuit `||` so nothing else runs on NULL.
  if (is.null(set) || plyr::empty(set)) {
    res <- logical(length(class_names))
    names(res) <- class_names
    return(res)
  }

  observed <- get_classes(set$ex_id, unique_classes = FALSE, use_names = TRUE)

  # Count over every known class (plus any unexpected label actually seen) so
  # absent classes get a zero count; a bare table() drops them, which made
  # all(check_filt(...)) TRUE even when a class had no excerpts left.
  # NOTE(review): assumes get_classes(use_names = TRUE) yields the same labels
  # as get_class_names()$class — confirm.
  all_levels <- union(class_names, unique(observed))
  table(factor(observed, levels = all_levels)) >= minimum
}
f@0 32
# Draw one train/test/filtered selection with the requested strategy.
#
# Args:
#   min_per_class:  minimum number of filtered test excerpts per class.
#   selection_type: one of 'rand_surv', 'class_surv' or 'art_class_surv'.
#   max_draws:      draw budget passed on to survival_selection().
#
# Returns:
#   The result of survival_selection(). Stops with an error when
#   `selection_type` is not one of the supported strategies.
select_samples <- function(min_per_class = 10, selection_type = 'rand_surv',
                           max_draws = 1e6){

  survival_types <- c('rand_surv', 'class_surv', 'art_class_surv')

  if (length(selection_type) != 1L || !selection_type %in% survival_types) {
    stop("Unknown selection_type: '", paste(selection_type, collapse = ", "),
         "'. Expected one of: ", paste(survival_types, collapse = ", "),
         call. = FALSE)
  }

  survival_selection(min_per_class, selection_type, max_draws)
}
f@0 43
# Replace, in `original_set`, every excerpt of the class redrawn in
# `class_set` by the rows of `class_set`.
#
# Args:
#   original_set: data frame with an `ex_id` column (train, test or filt).
#   class_set:    data frame with an `ex_id` column holding the redrawn
#                 excerpts; its class is taken from its first class label.
#
# Returns:
#   The merged data frame, ordered by `ex_id`.
.update_set <- function(original_set, class_set){

  redrawn_class <- get_classes(class_set$ex_id, unique_classes = TRUE)[1]
  # NOTE(review): assumes this returns one label per row of original_set.
  classes <- get_classes(original_set$ex_id, unique_classes = FALSE)

  # %in% never yields NA, so an unknown class (or an empty class_set, which
  # makes redrawn_class NA) cannot produce all-NA rows as `!=` would.
  # drop = FALSE keeps a data frame even for a single-column set, so the
  # rbind() below still binds rows.
  kept <- original_set[!(classes %in% redrawn_class), , drop = FALSE]

  rbind(kept, class_set) %>% arrange(ex_id)
}
f@0 57
# Splice a redrawn class into the train, test and filtered sets.
#
# Args:
#   original_sets: list with `train`, `test` and `filt` elements.
#   class_sets:    list with the same elements for the redrawn class.
#
# Returns:
#   `original_sets` with each of those three elements passed through
#   .update_set() together with its counterpart in `class_sets`.
.update_sets <- function(original_sets, class_sets) {
  for (part in c("train", "test", "filt")) {
    original_sets[[part]] <- .update_set(original_sets[[part]],
                                         class_sets[[part]])
  }
  original_sets
}
f@0 67
# Draw one bootstrap train/test split and add its artist-filtered test set.
#
# Args:
#   class:           NULL for a draw over all classes, or the class(es) to
#                    redraw, forwarded to get_bs_samples() as `classes`.
#   hold_out_ex_ids: excerpt ids forwarded unchanged to get_bs_samples().
#
# Returns:
#   The list from get_bs_samples() plus a `filt` element: the test set with
#   every excerpt sharing an artist with the training set removed.
.get_sets <- function(class = NULL, hold_out_ex_ids = NULL){

  # First argument of get_bs_samples(): 1000 for a full draw, 100 for a
  # single-class redraw (presumably the sample size — NOTE(review): confirm).
  sample_size <- if (is.null(class)) 1000 else 100

  drawn <- get_bs_samples(sample_size,
                          num_iter = 1,
                          stratified = TRUE,
                          keep_prop = FALSE,
                          classes = class,
                          hold_out_ex_ids = hold_out_ex_ids)
  drawn$filt <- filter_set(drawn$test, drawn$train)

  drawn
}
f@0 84
# Pick random artists until their excerpts number at least `min_per_class`
# and return the distinct ids of all those excerpts.
#
# Args:
#   ex_art:        data frame with one row per (ex_id, artist_id) pair.
#   min_per_class: target number of distinct excerpts to hold out.
#
# Returns:
#   Vector of distinct excerpt ids credited to the chosen artists. Stops when
#   fewer than `min_per_class` excerpts exist at all (the target could never
#   be reached).
.draw_hold_out_ex_ids <- function(ex_art, min_per_class) {

  total_excerpts <- length(unique(ex_art$ex_id))
  if (total_excerpts < min_per_class) {
    stop("Only ", total_excerpts, " excerpts available, fewer than ",
         "min_per_class = ", min_per_class, call. = FALSE)
  }

  remaining <- unique(ex_art$artist_id)
  hold_out_ex_ids <- ex_art$ex_id[0]

  # Count distinct excerpts, so an excerpt shared by two chosen artists is
  # counted once.
  while (length(hold_out_ex_ids) < min_per_class && length(remaining) > 0) {
    # Draw an index rather than calling sample(remaining, 1): for a single
    # numeric id, sample(x, 1) would draw from 1:x. Removing the chosen
    # artist avoids wasting iterations on already-selected artists.
    idx <- sample.int(length(remaining), 1)
    new_artist <- remaining[idx]
    remaining <- remaining[-idx]
    # Take ex_ids from the very rows matching the artist, so ids stay
    # aligned with their artists.
    hold_out_ex_ids <- union(hold_out_ex_ids,
                             ex_art$ex_id[ex_art$artist_id %in% new_artist])
  }

  hold_out_ex_ids
}

# Draw train/test/filtered sets for a single class that failed the check.
#
# Args:
#   class:          the class to redraw.
#   selection_type: 'class_surv' to simply redraw the class, or
#                   'art_class_surv' to first hold out the excerpts of
#                   randomly chosen artists of the class (at least
#                   `min_per_class` distinct excerpts) and pass them to
#                   .get_sets() as `hold_out_ex_ids`.
#   min_per_class:  number of excerpts to hold out for 'art_class_surv'.
#
# Returns:
#   The list returned by .get_sets(). Stops for any other selection_type.
.get_class_sets <- function(class = NULL,
                            selection_type = 'class_surv',
                            min_per_class = 0){

  if (selection_type == 'class_surv') {
    return(.get_sets(class = class))
  }

  if (selection_type != 'art_class_surv') {
    stop("Unknown selection_type for a class redraw: '", selection_type, "'",
         call. = FALSE)
  }

  ex_art <- get_excerpts_classes(class = class) %>%
    inner_join(get_excerpts_artists(), by = c('ex_id'))

  hold_out_ex_ids <- .draw_hold_out_ex_ids(ex_art, min_per_class)

  .get_sets(class, hold_out_ex_ids)
}
f@0 120
# Build the result of a selection that ran out of draws.
#
# Args:
#   num_draws: number of draws spent before giving up.
#
# Returns:
#   list(sets = NULL, success = FALSE, draws = num_draws).
.fail_selection <- function(num_draws) {
  setNames(list(NULL, FALSE, num_draws), c("sets", "success", "draws"))
}
f@0 126
# Draw train/test/filtered sets until every class has at least
# `min_per_class` excerpts in the artist-filtered test set.
#
# Args:
#   min_per_class:  minimum number of filtered test excerpts per class.
#   selection_type: 'rand_surv' redraws everything on failure;
#                   'class_surv' redraws only the failing classes;
#                   'art_class_surv' redraws the failing classes, holding out
#                   the excerpts of randomly chosen artists.
#   max_draws:      draw budget; once exceeded, the selection fails.
#
# Returns:
#   list(sets, success, draws): the selected sets (NULL on failure), whether
#   the requirement was met, and the draws spent. Draws are counted as
#   nrow(sets$train) for the first draw, then +1000 per full redraw and +100
#   per class redraw. Stops for an unknown selection_type.
survival_selection <- function(min_per_class = 10,
                               selection_type = 'rand_surv',
                               max_draws = 1e6){

  # Without this check an unknown type skipped both redraw branches and
  # reported success for sets that had failed the check.
  survival_types <- c('rand_surv', 'class_surv', 'art_class_surv')
  if (length(selection_type) != 1L || !selection_type %in% survival_types) {
    stop("Unknown selection_type: '", paste(selection_type, collapse = ", "),
         "'. Expected one of: ", paste(survival_types, collapse = ", "),
         call. = FALSE)
  }

  # 1.- Draw initial sets
  sets <- .get_sets()
  num_draws <- nrow(sets$train)

  # 2.- Check if initial test set meets requirements
  checked <- check_filt(sets$filt, minimum = min_per_class)

  if (!all(checked)) {

    if (selection_type == 'rand_surv') {
      # 3.1.- Random Survival: Complete redraw
      while (!all(checked)) {
        if (num_draws > max_draws) return(.fail_selection(num_draws))
        sets <- .get_sets()
        num_draws <- num_draws + 1000
        checked <- check_filt(sets$filt, minimum = min_per_class)
      }
    } else {
      # 3.2.- Class Survival: Redraw only problematic classes
      # 3.3.- Artist-informed Class Survival
      for (fail_class in names(checked[!checked])) {
        print(paste0("Redrawing class: ", fail_class))
        class_checked <- FALSE
        while (!class_checked) {
          # Check the budget before every attempt, not once per class, so a
          # class that can never pass cannot loop forever.
          if (num_draws > max_draws) return(.fail_selection(num_draws))
          print(paste0(" Attempting: ", fail_class))
          class_sets <- .get_class_sets(class = fail_class,
                                        selection_type = selection_type,
                                        min_per_class = min_per_class)
          num_draws <- num_draws + 100
          class_checked <- nrow(class_sets$filt) >= min_per_class
        }
        sets <- .update_sets(sets, class_sets)
      }
    }
  }

  list(sets = sets, success = TRUE, draws = num_draws)
}
f@0 173
f@0 174
# Run survival_selection() until `num_iters` successful selections have been
# collected.
#
# Args:
#   num_iters:      number of successful selections to collect.
#   min_per_class:  passed to survival_selection().
#   selection_type: passed to survival_selection().
#   max_draws:      per-attempt draw budget, passed to survival_selection().
#   max_attempts:   maximum number of attempts, successful or not. The
#                   default Inf retries until `num_iters` successes (which
#                   never ends if every attempt fails); when a finite limit
#                   is hit, a warning is issued and the partial results are
#                   returned.
#
# Returns:
#   list(sets, sim_results): `sets` holds the `train`, `test` and `filt` rows
#   of every successful selection with an `iter` column giving the index of
#   the successful iteration; `sim_results` has one row
#   (success, num_draws) per attempt.
draw_samples <- function(num_iters = 40,
                         min_per_class = 10,
                         selection_type = "art_class_surv",
                         max_draws = 1e6,
                         max_attempts = Inf){

  sets <- list(train = data.frame(),
               test = data.frame(),
               filt = data.frame())
  sim_results <- data.frame(
    success = logical(0), num_draws = numeric(0))

  i <- 1
  while (i <= num_iters && nrow(sim_results) < max_attempts) {

    print("Starting new attempt")

    res <- survival_selection(
      min_per_class = min_per_class,
      selection_type = selection_type,
      max_draws = max_draws
    )

    print(
      paste0(
        "Finishing new attempt: ",
        if (res$success) "Success " else "Failure ",
        "after ", res$draws, " draws"))
    print("")

    sim_results <- rbind(sim_results,
                         list(success = res$success,
                              num_draws = res$draws))

    if (res$success) {
      for (part in names(sets)) {
        sets[[part]] <- rbind(sets[[part]],
                              res$sets[[part]] %>% mutate(iter = i))
      }
      i <- i + 1
    }
  }

  if (i <= num_iters) {
    warning("Only ", i - 1, " of ", num_iters, " selections succeeded ",
            "within max_attempts = ", max_attempts, call. = FALSE)
  }

  list(sets = sets, sim_results = sim_results)
}