Mercurial > hg > confint
diff sampling/simulations.R @ 0:205974c9568c tip
Initial commit. Predictions not included for lack of space.
author | franrodalg <f.rodriguezalgarra@qmul.ac.uk> |
---|---|
date | Sat, 29 Jun 2019 18:45:50 +0100 |
parents | |
children | |
line wrap: on
line diff
# File: sampling/simulations.R
#
# Sampling simulations: draw train/test sets, filter the test set against the
# artists of the train set, and redraw until every class keeps enough
# filtered test excerpts.
#
# NOTE(review): get_excerpts_artists(), get_artists(), get_classes(),
# get_class_names(), get_excerpts_classes() and get_bs_samples(), as well as
# the dplyr verbs and `%>%`, are not defined in this file; they are assumed to
# be attached by the caller -- confirm.

# Drop from `set` every excerpt that shares an artist with `ref_set`.
#
# set:     data frame with an `ex_id` column.
# ref_set: data frame with an `ex_id` column; its artists define what is
#          removed.
# Returns `set` joined with the excerpt/artist table, minus every excerpt
# that has at least one artist in `ref_set`, with `artist_id` dropped again.
filter_set <- function(set, ref_set) {

  excerpts <- set %>%
    inner_join(get_excerpts_artists(), by = c("ex_id"))

  artists <- get_artists(ref_set$ex_id, unique_artists = TRUE)

  # An excerpt is excluded as soon as any of its artist rows matches.
  ex_ids <- unique(excerpts$ex_id[excerpts$artist_id %in% artists])

  excerpts %>%
    filter(!ex_id %in% ex_ids) %>%
    select(-artist_id)

}

# Check, per class, whether `set` holds at least `minimum` excerpts.
#
# Returns a named logical (one entry per class). For a NULL or empty `set`
# every class listed by get_class_names() is reported as FALSE.
# NOTE(review): for a non-empty set the result only has entries for the
# classes that table() reports; whether a class with no excerpts at all shows
# up depends on what get_classes() returns (e.g. a factor) -- confirm.
check_filt <- function(set, minimum = 10) {

  if (is.null(set) || plyr::empty(set)) {
    classes <- get_class_names()
    res <- logical(nrow(classes))
    names(res) <- classes$class
    return(res)
  }

  table(
    get_classes(set$ex_id, unique_classes = FALSE, use_names = TRUE)
  ) >= minimum

}

# Entry point: run the survival selection for a supported `selection_type`
# ('rand_surv', 'class_surv' or 'art_class_surv'). Any other value yields
# NULL (invisibly).
select_samples <- function(min_per_class = 10, selection_type = 'rand_surv',
                           max_draws = 1e6) {

  if (selection_type %in% c("rand_surv", "class_surv", "art_class_surv")) {
    survival_selection(min_per_class, selection_type, max_draws)
  }

}

# Replace, in `original_set`, all rows of the class found in `class_set` by
# the rows of `class_set`; the result is ordered by `ex_id`.
# The class is taken from the first unique class of `class_set`.
.update_set <- function(original_set, class_set) {

  class <- get_classes(class_set$ex_id,
                       unique_classes = TRUE)[1]
  classes <- get_classes(original_set$ex_id,
                         unique_classes = FALSE)

  new_set <- original_set[classes != class, ]

  rbind(new_set, class_set) %>% arrange(ex_id)

}

# Apply .update_set() to the train, test and filt members of `original_sets`.
.update_sets <- function(original_sets, class_sets) {

  original_sets$train <- .update_set(original_sets$train, class_sets$train)
  original_sets$test <- .update_set(original_sets$test, class_sets$test)
  original_sets$filt <- .update_set(original_sets$filt, class_sets$filt)
  original_sets

}

# Draw one set of samples through get_bs_samples() and attach the filtered
# test set as `$filt`.
#
# class:           NULL for a draw over all classes (N = 1000), otherwise the
#                  class to draw (N = 100).
# hold_out_ex_ids: forwarded unchanged to get_bs_samples().
.get_sets <- function(class = NULL, hold_out_ex_ids = NULL) {

  N <- if (is.null(class)) 1000 else 100

  sets <- get_bs_samples(N,
                         num_iter = 1, stratified = TRUE,
                         keep_prop = FALSE, classes = class,
                         hold_out_ex_ids = hold_out_ex_ids)
  sets$filt <- filter_set(sets$test, sets$train)

  sets

}

# Redraw the sets of a single class.
#
# 'class_surv':     plain redraw of the class.
# 'art_class_surv': first pick random artists of the class until their rows
#                   add up to at least `min_per_class`, then pass the
#                   excerpts of those artists to .get_sets() as hold-out ids.
# Any other `selection_type` yields NULL (invisibly).
.get_class_sets <- function(class = NULL,
                            selection_type = 'class_surv',
                            min_per_class = 0) {

  if (selection_type == "class_surv") {
    .get_sets(class = class)
  } else if (selection_type == "art_class_surv") {

    ex_art <- inner_join(get_excerpts_classes(class = class),
                         get_excerpts_artists(),
                         by = c("ex_id"))

    artists <- unique(ex_art$artist_id)
    available <- seq_along(artists)  # positions of artists not yet held out
    held <- integer(0)               # positions of held-out artists
    num_excerpts <- 0

    # Sample artists without replacement. Drawing a position with
    # sample.int() avoids sample()'s special case for a single numeric value,
    # and the loop stops once every artist is held out, so it terminates even
    # when `min_per_class` exceeds what the class can provide.
    while (num_excerpts < min_per_class && length(available) > 0) {
      pos <- sample.int(length(available), 1)
      pick <- available[pos]
      available <- available[-pos]
      held <- c(held, pick)
      num_excerpts <- num_excerpts +
        sum(ex_art$artist_id == artists[pick], na.rm = TRUE)
    }

    hold_out_artists <- artists[held]

    # Select the matching rows first and de-duplicate afterwards, so the row
    # mask is applied to the vector it was computed from.
    hold_out_ex_ids <- unique(
      ex_art$ex_id[ex_art$artist_id %in% hold_out_artists]
    )

    .get_sets(class, hold_out_ex_ids)

  }

}

# Result returned when the draw budget is exhausted.
.fail_selection <- function(num_draws) {

  list(sets = NULL, success = FALSE, draws = num_draws)

}

# Draw sets until every class passes check_filt() or `max_draws` is exceeded.
#
# min_per_class:  minimum number of filtered test excerpts per class.
# selection_type: 'rand_surv' redraws everything; 'class_surv' and
#                 'art_class_surv' redraw only the failing classes.
# max_draws:      draw budget; once the running count exceeds it the
#                 function returns .fail_selection().
# Returns list(sets, success, draws).
# NOTE(review): with any other `selection_type` a failing initial draw is
# returned as-is with success = TRUE.
survival_selection <- function(min_per_class = 10,
                               selection_type = 'rand_surv',
                               max_draws = 1e6) {

  # 1.- Draw initial sets
  sets <- .get_sets()
  num_draws <- nrow(sets$train)

  # 2.- Check if initial test set meets requirements
  checked <- check_filt(sets$filt, minimum = min_per_class)

  if (!all(checked)) {

    if (selection_type == "rand_surv") {

      # 3.1.- Random Survival: Complete redraw
      while (!all(checked)) {
        if (num_draws > max_draws) return(.fail_selection(num_draws))
        sets <- .get_sets()
        num_draws <- num_draws + 1000
        checked <- check_filt(sets$filt, minimum = min_per_class)
      }

    } else if (selection_type %in% c("class_surv", "art_class_surv")) {

      # 3.2.- Class Survival: Redraw only problematic classes
      # 3.3.- Artist-informed Class Survival
      for (fail_class in names(checked[!checked])) {
        print(paste0("Redrawing class: ", fail_class))
        class_checked <- FALSE
        while (!class_checked) {
          # Enforce the budget on every redraw, not once per class, so a
          # class that can never pass does not loop without bound.
          if (num_draws > max_draws) return(.fail_selection(num_draws))
          print(paste0(" Attempting: ", fail_class))
          class_sets <- .get_class_sets(class = fail_class,
                                        selection_type = selection_type,
                                        min_per_class = min_per_class)
          num_draws <- num_draws + 100
          class_checked <- nrow(class_sets$filt) >= min_per_class
        }
        sets <- .update_sets(sets, class_sets)
      }

    }
  }

  list(sets = sets, success = TRUE, draws = num_draws)

}


# Repeat survival_selection() until `num_iters` attempts have succeeded.
#
# Returns list(sets, sim_results): `sets` holds the train/test/filt rows of
# all successful attempts tagged with their iteration in `iter`;
# `sim_results` has one row per attempt (successful or not).
draw_samples <- function(num_iters = 40,
                         min_per_class = 10,
                         selection_type = "art_class_surv") {

  sets <- list(train = data.frame(),
               test = data.frame(),
               filt = data.frame())
  sim_results <- data.frame(
    success = logical(0), num_draws = numeric(0))
  i <- 1
  while (i <= num_iters) {

    print("Starting new attempt")

    res <- survival_selection(
      min_per_class = min_per_class,
      selection_type = selection_type
    )

    outcome <- if (res$success) "Success " else "Failure "
    print(
      paste0(
        "Finishing new attempt: ",
        outcome,
        "after ", res$draws, " draws"))
    print("")

    sim_results <- rbind(sim_results,
                         list(success = res$success,
                              num_draws = res$draws))

    # Only successful attempts count towards `num_iters`.
    if (res$success) {
      sets$train <- rbind(sets$train,
                          res$sets$train %>% mutate(iter = i))
      sets$test <- rbind(sets$test,
                         res$sets$test %>% mutate(iter = i))
      sets$filt <- rbind(sets$filt,
                         res$sets$filt %>% mutate(iter = i))
      i <- i + 1
    }

  }

  list(sets = sets, sim_results = sim_results)

}