Mercurial > hg > confint
view sampling/simulations.R @ 0:205974c9568c tip
Initial commit. Predictions not included for lack of space.
author | franrodalg <f.rodriguezalgarra@qmul.ac.uk> |
---|---|
date | Sat, 29 Jun 2019 18:45:50 +0100 |
parents | |
children |
line wrap: on
line source
# Remove from `set` every excerpt that shares at least one artist with the
# excerpts of `ref_set`.
#
# set, ref_set: data frames with an `ex_id` column.
# Returns the rows of `set` whose excerpts appear in get_excerpts_artists()
# and have no artist in common with `ref_set`. The result has one row per
# kept row of `set`, so excerpts credited to several artists are not
# duplicated.
# NOTE(review): columns of get_excerpts_artists() other than ex_id/artist_id
# are not carried into the result -- confirm nothing downstream needs them.
filter_set <- function(set, ref_set) {
  excerpts_artists <- get_excerpts_artists()
  # presumably the artist ids credited on the reference excerpts
  ref_artists <- get_artists(ref_set$ex_id, unique_artists = TRUE)
  shared_ex_ids <- unique(
    excerpts_artists$ex_id[excerpts_artists$artist_id %in% ref_artists]
  )
  set %>%
    # only excerpts with artist information are kept
    semi_join(excerpts_artists, by = "ex_id") %>%
    filter(!ex_id %in% shared_ex_ids)
}

# Check whether `set` holds at least `minimum` excerpts of every class.
#
# set:     data frame with an `ex_id` column, or NULL.
# minimum: required number of rows per class.
# Returns a logical vector named by class. Every class listed by
# get_class_names() is present, so a class with no rows in `set` is reported
# as FALSE instead of being left out of the check. A NULL or empty `set`
# yields all FALSE.
check_filt <- function(set, minimum = 10) {
  all_classes <- get_class_names()$class
  # NULL, zero-row and zero-column input count as empty (as in plyr::empty)
  if (is.null(set) || NROW(set) == 0 || NCOL(set) == 0) {
    res <- logical(length(all_classes))
    names(res) <- all_classes
    return(res)
  }
  observed <- get_classes(set$ex_id, unique_classes = FALSE, use_names = TRUE)
  # NOTE(review): assumes get_classes(use_names = TRUE) returns the same labels
  # as get_class_names()$class -- confirm. Observed labels are kept as levels
  # too, so their counts are never lost.
  counts <- table(factor(observed, levels = sort(union(all_classes, observed))))
  counts >= minimum
}

# Draw train/test/filtered sets with one of the survival strategies.
# Returns the list produced by survival_selection(), or NULL (invisibly) when
# `selection_type` is not a supported strategy.
select_samples <- function(min_per_class = 10, selection_type = "rand_surv",
                           max_draws = 1e6) {
  if (selection_type %in% c("rand_surv", "class_surv", "art_class_surv")) {
    survival_selection(min_per_class, selection_type, max_draws)
  }
}

#
# Replace all rows of one class in `original_set` with the rows of
# `class_set`, returning the result ordered by ex_id.
#
# class: the class to replace. When NULL it is taken from the first class of
#        `class_set`; if `class_set` is also empty, no rows are removed.
.update_set <- function(original_set, class_set, class = NULL) {
  if (is.null(class) && NROW(class_set) > 0) {
    class <- get_classes(class_set$ex_id, unique_classes = TRUE)[1]
  }
  row_classes <- get_classes(original_set$ex_id, unique_classes = FALSE)
  # %in% never yields NA (unlike `!=`), so no NA rows can be introduced
  kept <- original_set[!(row_classes %in% class), , drop = FALSE]
  rbind(kept, class_set) %>% arrange(ex_id)
}

#
# Apply .update_set() to the train, test and filt sets. The replaced class
# is taken from the redrawn training set and used for all three, so an empty
# redrawn test/filt set still replaces that class's old rows.
.update_sets <- function(original_sets, class_sets) {
  class <- NULL
  if (NROW(class_sets$train) > 0) {
    class <- get_classes(class_sets$train$ex_id, unique_classes = TRUE)[1]
  }
  for (part in c("train", "test", "filt")) {
    original_sets[[part]] <- .update_set(
      original_sets[[part]], class_sets[[part]], class = class
    )
  }
  original_sets
}

#
# Draw one bootstrap split and derive its artist-filtered test set.
#
# class:           NULL for a draw over all classes, or one class to redraw.
# hold_out_ex_ids: excerpt ids forwarded to get_bs_samples() (presumably kept
#                  out of the training draw -- confirm).
# Returns the list from get_bs_samples() plus a `filt` element: the test set
# without excerpts that share an artist with the training set.
.get_sets <- function(class = NULL, hold_out_ex_ids = NULL) {
  # 1000 for a full draw, 100 for a single-class redraw; survival_selection()
  # adds these same amounts to its draw counter.
  N <- if (is.null(class)) 1000 else 100
  sets <- get_bs_samples(N, num_iter = 1, stratified = TRUE, keep_prop = FALSE,
                         classes = class, hold_out_ex_ids = hold_out_ex_ids)
  sets$filt <- filter_set(sets$test, sets$train)
  sets
}

# Pick whole artists at random (each at most once) and collect all their
# excerpts until at least `min_excerpts` distinct excerpts are collected or
# every artist has been used.
#
# ex_art: data frame with `ex_id` and `artist_id`, one row per pair.
# Returns the unique ex_ids of the chosen artists (same type as ex_art$ex_id).
.draw_hold_out_ex_ids <- function(ex_art, min_excerpts) {
  artists <- unique(ex_art$artist_id)
  # sample.int() avoids sample()'s 1:x behaviour for a single numeric id
  artists <- artists[sample.int(length(artists))]
  held_out <- rep(FALSE, nrow(ex_art))
  for (k in seq_along(artists)) {
    if (length(unique(ex_art$ex_id[held_out])) >= min_excerpts) break
    held_out <- held_out | ex_art$artist_id %in% artists[k]
  }
  unique(ex_art$ex_id[held_out])
}

# Redraw the sets of a single class.
#
# selection_type: "class_surv" redraws the class directly; "art_class_surv"
#                 first selects whole artists whose excerpts number at least
#                 `min_per_class` and passes them to .get_sets() as
#                 hold-outs. If the class has fewer excerpts than that, all
#                 of its artists are used.
# Returns the list from .get_sets(), or NULL for any other selection_type.
.get_class_sets <- function(class = NULL, selection_type = "class_surv",
                            min_per_class = 0) {
  if (selection_type == "class_surv") {
    return(.get_sets(class = class))
  }
  if (selection_type == "art_class_surv") {
    ex_art <- get_excerpts_classes(class = class) %>%
      inner_join(get_excerpts_artists(), by = c("ex_id"))
    hold_out_ex_ids <- .draw_hold_out_ex_ids(ex_art, min_per_class)
    return(.get_sets(class, hold_out_ex_ids))
  }
  NULL
}

# Result returned by survival_selection() when `max_draws` is exceeded.
.fail_selection <- function(num_draws) {
  list(sets = NULL, success = FALSE, draws = num_draws)
}

# Draw sets until every class has at least `min_per_class` filtered test
# excerpts.
#
# selection_type: "rand_surv" redraws everything; "class_surv" and
#                 "art_class_surv" redraw only failing classes. Anything else
#                 is an error.
# max_draws:      once the draw counter exceeds this, give up.
# Returns list(sets, success, draws). `draws` starts at nrow(sets$train) and
# grows by 1000 per full redraw and 100 per class redraw.
survival_selection <- function(min_per_class = 10,
                               selection_type = "rand_surv",
                               max_draws = 1e6) {
  if (!selection_type %in% c("rand_surv", "class_surv", "art_class_surv")) {
    stop("Unknown selection_type: ", selection_type, call. = FALSE)
  }
  # 1.- Draw initial sets
  sets <- .get_sets()
  num_draws <- nrow(sets$train)
  # 2.- Check if initial test set meets requirements
  checked <- check_filt(sets$filt, minimum = min_per_class)

  if (selection_type == "rand_surv") {
    # 3.1.- Random Survival: complete redraw
    while (!all(checked)) {
      if (num_draws > max_draws) return(.fail_selection(num_draws))
      sets <- .get_sets()
      num_draws <- num_draws + 1000
      checked <- check_filt(sets$filt, minimum = min_per_class)
    }
  } else {
    # 3.2.- Class Survival: redraw only problematic classes
    # 3.3.- Artist-informed Class Survival
    for (fail_class in names(checked[!checked])) {
      print(paste0("Redrawing class: ", fail_class))
      repeat {
        # checked before every redraw so an unattainable class cannot spin
        if (num_draws > max_draws) return(.fail_selection(num_draws))
        print(paste0(" Attempting: ", fail_class))
        class_sets <- .get_class_sets(class = fail_class,
                                      selection_type = selection_type,
                                      min_per_class = min_per_class)
        num_draws <- num_draws + 100
        if (nrow(class_sets$filt) >= min_per_class) break
      }
      # NOTE(review): the new class filt is filtered only against that class's
      # train, and other classes' filt against the old train, so the merged
      # sets may not be artist-disjoint -- confirm this is intended.
      sets <- .update_sets(sets, class_sets)
    }
  }
  list(sets = sets, success = TRUE, draws = num_draws)
}
#
# Run survival_selection() repeatedly until `num_iters` attempts succeed,
# recording every attempt.
#
# num_iters:      number of successful attempts to collect.
# min_per_class:  forwarded to survival_selection().
# selection_type: forwarded to survival_selection().
# max_attempts:   cap on the total number of attempts (successful or not).
#                 The default, Inf, retries indefinitely; a finite value stops
#                 early with a warning.
# max_draws:      forwarded to survival_selection().
# Returns a list with
#   sets:        train/test/filt data frames of all successful attempts, each
#                row tagged with its iteration number in an `iter` column;
#   sim_results: one row per attempt with columns `success` and `num_draws`.
draw_samples <- function(num_iters = 40, min_per_class = 10,
                         selection_type = "art_class_surv",
                         max_attempts = Inf, max_draws = 1e6) {
  train_parts <- list()
  test_parts <- list()
  filt_parts <- list()
  successes <- logical(0)
  draws <- numeric(0)
  i <- 1
  while (i <= num_iters) {
    if (length(successes) >= max_attempts) {
      warning("draw_samples: stopped after ", length(successes),
              " attempts with ", i - 1, " successful iterations",
              call. = FALSE)
      break
    }
    print("Starting new attempt")
    res <- survival_selection(min_per_class = min_per_class,
                              selection_type = selection_type,
                              max_draws = max_draws)
    outcome <- if (res$success) "Success " else "Failure "
    print(paste0("Finishing new attempt: ", outcome,
                 "after ", res$draws, " draws"))
    print("")
    successes <- c(successes, res$success)
    draws <- c(draws, res$draws)
    if (res$success) {
      train_parts[[i]] <- res$sets$train %>% mutate(iter = i)
      test_parts[[i]] <- res$sets$test %>% mutate(iter = i)
      filt_parts[[i]] <- res$sets$filt %>% mutate(iter = i)
      i <- i + 1
    }
  }
  # Combine once at the end; the leading empty data.frame() keeps the
  # no-success result an empty data frame.
  sets <- list(
    train = do.call(rbind, c(list(data.frame()), train_parts)),
    test = do.call(rbind, c(list(data.frame()), test_parts)),
    filt = do.call(rbind, c(list(data.frame()), filt_parts))
  )
  sim_results <- data.frame(success = successes, num_draws = draws)
  list(sets = sets, sim_results = sim_results)
}