# Drawing of bootstrap train/test sets together with a filtered test set
# ("filt") that keeps only test excerpts sharing no artist with the training
# set, plus retry ("survival") strategies for when some class ends up with too
# few filtered excerpts.
#
# NOTE(review): relies on project helpers defined elsewhere
# (get_excerpts_artists, get_artists, get_class_names, get_classes,
# get_bs_samples, get_excerpts_classes) and on dplyr (`%>%`, filter,
# inner_join, arrange, mutate) being attached.

# Selection strategies understood by select_samples()/survival_selection().
.selection_types <- c("rand_surv", "class_surv", "art_class_surv")

# Stop with an informative error unless `selection_type` is a single supported
# strategy name. Returns `selection_type` invisibly.
.check_selection_type <- function(selection_type) {
  if (!(length(selection_type) == 1 && selection_type %in% .selection_types)) {
    stop("Unknown selection_type: ", paste(selection_type, collapse = ", "),
         ". Expected one of: ", paste(.selection_types, collapse = ", "),
         call. = FALSE)
  }
  invisible(selection_type)
}

# Number of samples requested from get_bs_samples() per draw: 1000 for a full
# redraw, 100 when a single class is redrawn. survival_selection() uses the
# same numbers to account for draws.
.sample_size <- function(class = NULL) {
  if (is.null(class)) 1000 else 100
}

# Remove from `set` every excerpt that shares at least one artist with
# `ref_set` (used to drop test excerpts whose artists also appear in training).
#
# set, ref_set: data frames with an `ex_id` column.
# Returns the rows of `set`, in their original order, whose excerpt has at
# least one entry in get_excerpts_artists() and none of whose artists appear
# among the artists of `ref_set`.
filter_set <- function(set, ref_set) {
  excerpt_artists <- get_excerpts_artists()
  ref_artists <- get_artists(ref_set$ex_id, unique_artists = TRUE)

  # Excerpts credited to any artist that also appears in `ref_set`.
  shared_ex_ids <-
    unique(excerpt_artists$ex_id[excerpt_artists$artist_id %in% ref_artists])
  # Excerpts with at least one artist entry (the former inner_join kept only
  # these, so they are still the only ones eligible).
  known_ex_ids <- unique(excerpt_artists$ex_id)

  # Filter `set` itself instead of its join with the artist table: the join has
  # one row per (excerpt, artist), so multi-artist excerpts used to come back
  # duplicated and inflated the per-class counts checked later.
  # NOTE(review): columns of get_excerpts_artists() other than ex_id/artist_id
  # (if there are any) are no longer appended to the result.
  set %>% filter(ex_id %in% known_ex_ids, !ex_id %in% shared_ex_ids)
}

# Check, per class, whether `set` holds at least `minimum` excerpts.
#
# Returns a logical vector named by class. Every class listed by
# get_class_names() is present, so classes with zero excerpts in `set` are
# reported as FALSE. A NULL or empty `set` yields all FALSE.
check_filt <- function(set, minimum = 10) {
  all_classes <- get_class_names()$class

  # Same test as plyr::empty(), with the NULL check evaluated first.
  if (is.null(set) || nrow(set) == 0 || ncol(set) == 0) {
    return(setNames(logical(length(all_classes)), all_classes))
  }

  observed <- get_classes(set$ex_id, unique_classes = FALSE, use_names = TRUE)
  # A plain table() would drop classes that never occur, and a class missing
  # entirely from `set` would then never be flagged as failing.
  counts <- table(factor(observed, levels = union(all_classes, observed)))
  setNames(as.vector(counts) >= minimum, names(counts))
}

# Entry point: draw train/test/filt sets using one of the survival strategies
# ("rand_surv", "class_surv", "art_class_surv"). Unknown strategies raise an
# error instead of silently returning NULL.
# Returns the list produced by survival_selection().
select_samples <- function(min_per_class = 10, selection_type = 'rand_surv',
                           max_draws = 1e6) {
  .check_selection_type(selection_type)
  survival_selection(min_per_class, selection_type, max_draws)
}

# Replace, in `original_set`, every row of the class found in `class_set` by
# the rows of `class_set`. Returns the result sorted by ex_id.
.update_set <- function(original_set, class_set) {
  target_class <- get_classes(class_set$ex_id, unique_classes = TRUE)[1]
  row_classes <- get_classes(original_set$ex_id, unique_classes = FALSE)

  # %in% never yields NA (unlike !=). drop = FALSE keeps a one-column data
  # frame from collapsing into a vector.
  kept <- original_set[!(row_classes %in% target_class), , drop = FALSE]

  rbind(kept, class_set) %>% arrange(ex_id)
}

# Apply .update_set() to the train, test and filt parts of `original_sets`.
.update_sets <- function(original_sets, class_sets) {
  for (part in c("train", "test", "filt")) {
    original_sets[[part]] <-
      .update_set(original_sets[[part]], class_sets[[part]])
  }
  original_sets
}

# Draw one bootstrap train/test split (for all classes, or only `class`) and
# add `filt`: the test set without excerpts that share an artist with train.
# `hold_out_ex_ids` is forwarded unchanged to get_bs_samples().
.get_sets <- function(class = NULL, hold_out_ex_ids = NULL) {
  sets <- get_bs_samples(.sample_size(class),
                         num_iter = 1, stratified = TRUE,
                         keep_prop = FALSE,
                         classes = class,
                         hold_out_ex_ids = hold_out_ex_ids)
  sets$filt <- filter_set(sets$test, sets$train)
  sets
}

# Pick whole artists at random (without replacement) until their excerpts in
# `ex_art` number at least `min_excerpts`, or until no artist is left.
#
# ex_art: data frame with `ex_id` and `artist_id` columns, one row per
#   (excerpt, artist) pair.
# Returns the distinct ex_ids of every excerpt credited to a picked artist.
.pick_hold_out_ex_ids <- function(ex_art, min_excerpts) {
  remaining <- unique(ex_art$artist_id)
  held_out <- ex_art$ex_id[0]

  # Running out of artists ends the loop even when the class simply does not
  # have `min_excerpts` excerpts.
  while (length(held_out) < min_excerpts && length(remaining) > 0) {
    # Index with sample.int(): sample(x, 1) on a single number n draws from
    # 1:n instead of returning n.
    pick <- remaining[sample.int(length(remaining), 1)]
    remaining <- remaining[!(remaining %in% pick)]
    # union() counts excerpts shared by several picked artists only once.
    held_out <- union(held_out, ex_art$ex_id[ex_art$artist_id %in% pick])
  }

  held_out
}

# Redraw the sets of a single class.
#
# "class_surv": plain redraw restricted to `class`.
# "art_class_surv": additionally hold out every excerpt of randomly chosen
#   artists (at least `min_per_class` excerpts when the class has that many),
#   passed to get_bs_samples() as `hold_out_ex_ids`.
# Any other `selection_type` is an error.
.get_class_sets <- function(class = NULL,
                            selection_type = 'class_surv',
                            min_per_class = 0) {
  if (identical(selection_type, "class_surv")) {
    return(.get_sets(class = class))
  }
  if (!identical(selection_type, "art_class_surv")) {
    stop("Unsupported selection_type for a class redraw: ",
         paste(selection_type, collapse = ", "), call. = FALSE)
  }

  ex_art <- get_excerpts_classes(class = class) %>%
    inner_join(get_excerpts_artists(), by = c("ex_id"))

  # Map the chosen artists back to distinct excerpts. The former
  # unique(ex_id)[which(row mask)] mixed row positions of `ex_art` with
  # positions in the deduplicated id vector.
  hold_out_ex_ids <- .pick_hold_out_ex_ids(ex_art, min_per_class)

  .get_sets(class, hold_out_ex_ids)
}

# Result returned when the draw budget is exhausted.
.fail_selection <- function(num_draws) {
  list(sets = NULL, success = FALSE, draws = num_draws)
}

# Draw train/test/filt sets until every class has at least `min_per_class`
# filtered test excerpts, or until more than `max_draws` samples were drawn.
#
# selection_type:
#   "rand_surv"      redraw everything until all classes pass;
#   "class_surv"     redraw only the failing classes;
#   "art_class_surv" like "class_surv", holding out whole artists.
# Returns list(sets, success, draws).
survival_selection <- function(min_per_class = 10,
                               selection_type = 'rand_surv',
                               max_draws = 1e6) {
  .check_selection_type(selection_type)

  # 1.- Draw initial sets
  sets <- .get_sets()
  # NOTE(review): the initial cost is the training-set size, while redraws
  # below add .sample_size(); confirm that these units agree.
  num_draws <- nrow(sets$train)

  # 2.- Check whether the initial filtered test set meets the requirements
  checked <- check_filt(sets$filt, minimum = min_per_class)

  if (selection_type == "rand_surv") {
    # 3.1.- Random Survival: complete redraw
    while (!all(checked)) {
      if (num_draws > max_draws) return(.fail_selection(num_draws))
      sets <- .get_sets()
      num_draws <- num_draws + .sample_size()
      checked <- check_filt(sets$filt, minimum = min_per_class)
    }
  } else {
    # 3.2.- Class Survival: redraw only problematic classes
    # 3.3.- Artist-informed Class Survival
    for (fail_class in names(checked)[!checked]) {
      print(paste0("Redrawing class: ", fail_class))
      class_checked <- FALSE
      while (!class_checked) {
        # Checked on every attempt, so a class that can never pass stops at
        # the budget instead of looping forever.
        if (num_draws > max_draws) return(.fail_selection(num_draws))
        print(paste0(" Attempting: ", fail_class))
        class_sets <- .get_class_sets(class = fail_class,
                                      selection_type = selection_type,
                                      min_per_class = min_per_class)
        num_draws <- num_draws + .sample_size(fail_class)
        class_checked <- nrow(class_sets$filt) >= min_per_class
      }
      sets <- .update_sets(sets, class_sets)
    }
  }

  list(sets = sets, success = TRUE, draws = num_draws)
}

# Run survival_selection() until `num_iters` successful draws are collected.
#
# Returns list(sets, sim_results):
#   sets        train/test/filt data frames of all successful draws, each row
#               tagged with its iteration number in column `iter`;
#   sim_results one row per attempt (success, num_draws).
# max_draws is forwarded to survival_selection(). max_attempts caps the total
# number of attempts; with the default Inf, a selection that can never succeed
# loops forever, as before. When the cap is hit, a warning is raised and the
# draws collected so far are returned.
draw_samples <- function(num_iters = 40,
                         min_per_class = 10,
                         selection_type = "art_class_surv",
                         max_draws = 1e6,
                         max_attempts = Inf) {
  sets <- list(train = data.frame(),
               test = data.frame(),
               filt = data.frame())
  sim_results <- data.frame(success = logical(0), num_draws = numeric(0))

  i <- 1
  attempts <- 0
  while (i <= num_iters) {
    if (attempts >= max_attempts) {
      warning("Stopped after ", attempts, " attempts with ", i - 1, " of ",
              num_iters, " successful iterations", call. = FALSE)
      break
    }
    attempts <- attempts + 1

    print("Starting new attempt")

    res <- survival_selection(
      min_per_class = min_per_class,
      selection_type = selection_type,
      max_draws = max_draws
    )

    print(
      paste0(
        "Finishing new attempt: ",
        if (res$success) "Success " else "Failure ",
        "after ", res$draws, " draws"))
    print("")

    sim_results <- rbind(sim_results,
                         list(success = res$success,
                              num_draws = res$draws))

    if (res$success) {
      for (part in names(sets)) {
        sets[[part]] <- rbind(sets[[part]],
                              res$sets[[part]] %>% mutate(iter = i))
      }
      i <- i + 1
    }
  }

  list(sets = sets, sim_results = sim_results)
}