diff sampling/simulations.R @ 0:205974c9568c tip

Initial commit. Predictions not included for lack of space.
author franrodalg <f.rodriguezalgarra@qmul.ac.uk>
date Sat, 29 Jun 2019 18:45:50 +0100
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sampling/simulations.R	Sat Jun 29 18:45:50 2019 +0100
@@ -0,0 +1,220 @@
+# Drop from `set` every excerpt that shares an artist with `ref_set`.
+#
+# Args:
+#   set:     data frame with an `ex_id` column (the sample to be filtered).
+#   ref_set: data frame with an `ex_id` column; excerpts of `set` sharing an
+#            artist with any of these excerpts are removed.
+#
+# Returns:
+#   `set` inner-joined with get_excerpts_artists() on `ex_id`, without the
+#   rows of the excluded excerpts and without the `artist_id` column.
+#
+# NOTE(review): the join presumably yields one row per (excerpt, artist)
+# pair, so an excerpt with several artists would appear several times in the
+# result -- confirm this is intended.
+filter_set <- function(set, ref_set) {
+
+  with_artists <- inner_join(set, get_excerpts_artists(), by = c("ex_id"))
+
+  ref_artists <- get_artists(ref_set$ex_id, unique_artists = TRUE)
+
+  # Excerpts with at least one artist that also appears in `ref_set`.
+  shared_ex_ids <- unique(
+    with_artists$ex_id[with_artists$artist_id %in% ref_artists])
+
+  with_artists %>%
+    filter(!ex_id %in% shared_ex_ids) %>%
+    select(-artist_id)
+
+}
+
+# Check, class by class, whether `set` holds at least `minimum` excerpts.
+#
+# Args:
+#   set:     data frame with an `ex_id` column, or NULL.
+#   minimum: smallest acceptable number of excerpts per class.
+#
+# Returns:
+#   A named logical, one entry per class. For a NULL or empty `set` every
+#   class listed by get_class_names() is reported as FALSE; otherwise the
+#   entries come from tabulating get_classes() over `set$ex_id`.
+#
+# NOTE(review): unless get_classes() returns a factor carrying all levels, a
+# class with no excerpt at all in `set` has no entry in the table and cannot
+# be reported as failing -- confirm against get_classes().
+check_filt <- function(set, minimum = 10) {
+
+  has_rows <- !is.null(set) && !plyr::empty(set)
+
+  if (has_rows) {
+    per_class <- table(
+      get_classes(set$ex_id, unique_classes = FALSE, use_names = TRUE))
+    return(per_class >= minimum)
+  }
+
+  # Nothing to count: every known class fails the requirement.
+  class_info <- get_class_names()
+  setNames(logical(nrow(class_info)), class_info$class)
+
+}
+
+# Entry point for sample selection.
+#
+# Forwards to survival_selection() when `selection_type` is one of
+# 'rand_surv', 'class_surv' or 'art_class_surv'; for any other value nothing
+# is drawn and NULL is returned invisibly.
+#
+# Args:
+#   min_per_class:  minimum number of excerpts required per class.
+#   selection_type: name of the selection strategy.
+#   max_draws:      draw budget handed to survival_selection().
+select_samples <- function(min_per_class = 10, selection_type = 'rand_surv',
+                           max_draws = 1e6) {
+
+  is_survival <- selection_type == 'rand_surv' ||
+    selection_type == 'class_surv' ||
+    selection_type == 'art_class_surv'
+
+  if (!is_survival) {
+    return(invisible(NULL))
+  }
+
+  survival_selection(min_per_class, selection_type, max_draws)
+
+}
+
+# Replace the rows of one class in `original_set` with the rows of
+# `class_set`.
+#
+# The class to replace is the first class get_classes() reports for the
+# excerpts of `class_set`. Rows of `original_set` whose class differs from it
+# are kept, `class_set` is appended, and the result is ordered by `ex_id`.
+#
+# Args:
+#   original_set: data frame with an `ex_id` column.
+#   class_set:    data frame with the same columns, holding the redrawn class.
+.update_set <- function(original_set, class_set) {
+
+  target_class <- get_classes(class_set$ex_id, unique_classes = TRUE)[1]
+  row_classes <- get_classes(original_set$ex_id, unique_classes = FALSE)
+
+  kept_rows <- original_set[row_classes != target_class, ]
+
+  arrange(rbind(kept_rows, class_set), ex_id)
+
+}
+
+# Apply .update_set() to the `train`, `test` and `filt` elements of
+# `original_sets`, taking the replacement rows from the element of the same
+# name in `class_sets`.
+#
+# Returns `original_sets` with those three elements updated.
+.update_sets <- function(original_sets, class_sets) {
+
+  for (part in c("train", "test", "filt")) {
+    original_sets[[part]] <- .update_set(original_sets[[part]],
+                                         class_sets[[part]])
+  }
+
+  original_sets
+
+}
+
+# Draw one set of samples with get_bs_samples() and attach the filtered test
+# set.
+#
+# Args:
+#   class:           class to restrict the draw to, or NULL for all classes.
+#   hold_out_ex_ids: excerpt ids forwarded to get_bs_samples() as
+#                    `hold_out_ex_ids`.
+#
+# Returns:
+#   The object returned by get_bs_samples() with an extra `filt` element:
+#   its `test` element filtered against its `train` element by filter_set().
+.get_sets <- function(class = NULL, hold_out_ex_ids = NULL) {
+
+  # N passed to get_bs_samples(): 1000 for a full draw, 100 for one class.
+  sample_size <- if (is.null(class)) 1000 else 100
+
+  drawn <- get_bs_samples(sample_size,
+                          num_iter = 1, stratified = TRUE,
+                          keep_prop = FALSE, classes = class,
+                          hold_out_ex_ids = hold_out_ex_ids)
+  drawn$filt <- filter_set(drawn$test, drawn$train)
+
+  drawn
+
+}
+
+# Redraw the sets for a single class.
+#
+# Args:
+#   class:          class to redraw.
+#   selection_type: 'class_surv' draws with .get_sets(class) directly;
+#                   'art_class_surv' first picks random artists of the class
+#                   and passes their excerpts to .get_sets() as
+#                   `hold_out_ex_ids`. Any other value returns NULL invisibly.
+#   min_per_class:  minimum number of distinct excerpts the picked artists
+#                   must cover ('art_class_surv' only).
+#
+# Returns:
+#   The result of .get_sets(), or invisible NULL for an unhandled
+#   `selection_type`.
+#
+# Raises:
+#   An error when `min_per_class` exceeds the number of distinct excerpts of
+#   the class, since no choice of artists could then satisfy it.
+.get_class_sets <- function(class = NULL,
+                            selection_type = 'class_surv',
+                            min_per_class = 0) {
+
+  if (selection_type == 'class_surv') {
+    return(.get_sets(class = class))
+  }
+  if (selection_type != 'art_class_surv') {
+    return(invisible(NULL))
+  }
+
+  ex_art <-
+    get_excerpts_classes(class = class) %>%
+    inner_join(get_excerpts_artists(), by = c('ex_id'))
+
+  artists <- unique(ex_art$artist_id)
+  total_excerpts <- length(unique(ex_art$ex_id))
+
+  # Without this guard the loop below could never reach its target.
+  if (min_per_class > total_excerpts) {
+    stop("Cannot hold out ", min_per_class, " excerpts of class '", class,
+         "': only ", total_excerpts, " available.", call. = FALSE)
+  }
+
+  hold_out_artists <- artists[0]
+  held_out <- rep(FALSE, nrow(ex_art))
+  num_excerpts <- 0
+
+  while (num_excerpts < min_per_class) {
+    # Index with sample.int(): sample(x, 1) would draw from 1:x when `x` is a
+    # single number.
+    new_artist <- artists[sample.int(length(artists), 1)]
+    if (!new_artist %in% hold_out_artists) {
+      hold_out_artists <- c(hold_out_artists, new_artist)
+      held_out <- ex_art$artist_id %in% hold_out_artists
+      # Count distinct excerpts so that an excerpt appearing on several rows
+      # of `ex_art` is not counted more than once.
+      num_excerpts <- length(unique(ex_art$ex_id[held_out]))
+    }
+  }
+
+  # Select the rows first and de-duplicate afterwards; indexing the unique
+  # ids with row positions of `ex_art` would pick the wrong excerpts.
+  hold_out_ex_ids <- unique(ex_art$ex_id[held_out])
+
+  .get_sets(class, hold_out_ex_ids)
+
+}
+
+# Build the result reported when a selection gives up: no sets, a FALSE
+# `success` flag, and the number of draws made so far in `draws`.
+.fail_selection <- function(num_draws) {
+
+  outcome <- list(sets = NULL, success = FALSE, draws = num_draws)
+  outcome
+
+}
+
+# Draw sets until the filtered test set holds enough excerpts of every class.
+#
+# Args:
+#   min_per_class:  minimum number of excerpts required per class in the
+#                   filtered test set.
+#   selection_type: 'rand_surv' redraws everything until the check passes;
+#                   'class_surv' and 'art_class_surv' redraw only the failing
+#                   classes through .get_class_sets().
+#   max_draws:      draw budget; once `num_draws` exceeds it the selection
+#                   is abandoned.
+#
+# Returns:
+#   list(sets, success, draws). On failure `sets` is NULL and `success` is
+#   FALSE (see .fail_selection()).
+#
+# Raises:
+#   An error when the initial draw fails the check and `selection_type` is
+#   not one of the three strategies above.
+survival_selection <- function(min_per_class = 10,
+                               selection_type = 'rand_surv',
+                               max_draws = 1e6) {
+
+  # 1.- Draw initial sets
+  sets <- .get_sets()
+  num_draws <- nrow(sets$train)
+
+  # 2.- Check if initial test set meets requirements
+  checked <- check_filt(sets$filt, minimum = min_per_class)
+
+  if (all(checked)) {
+    return(list(sets = sets, success = TRUE, draws = num_draws))
+  }
+
+  if (selection_type == 'rand_surv') {
+
+    # 3.1.- Random Survival: Complete redraw
+    while (!all(checked)) {
+      if (num_draws > max_draws) return(.fail_selection(num_draws))
+      sets <- .get_sets()
+      num_draws <- num_draws + 1000
+      checked <- check_filt(sets$filt, minimum = min_per_class)
+    }
+
+  } else if (selection_type == 'class_surv' ||
+             selection_type == 'art_class_surv') {
+
+    # 3.2.- Class Survival: Redraw only problematic classes
+    # 3.3.- Artist-informed Class Survival
+    for (fail_class in names(checked[!checked])) {
+      print(paste0("Redrawing class: ", fail_class))
+      class_checked <- FALSE
+      while (!class_checked) {
+        # Test the budget before every attempt so that a class which never
+        # passes cannot keep the loop running past `max_draws`.
+        if (num_draws > max_draws) return(.fail_selection(num_draws))
+        print(paste0("  Attempting: ", fail_class))
+        class_sets <- .get_class_sets(class = fail_class,
+                                      selection_type = selection_type,
+                                      min_per_class = min_per_class)
+        num_draws <- num_draws + 100
+        class_checked <- nrow(class_sets$filt) >= min_per_class
+      }
+      sets <- .update_sets(sets, class_sets)
+    }
+
+  } else {
+
+    # The draw failed the check and there is no strategy to repair it;
+    # reporting success here would hand back sets that do not qualify.
+    stop("Unsupported selection_type: '", selection_type, "'.",
+         call. = FALSE)
+
+  }
+
+  list(sets = sets, success = TRUE, draws = num_draws)
+
+}
+
+
+# Run survival_selection() repeatedly until `num_iters` selections succeed.
+#
+# Args:
+#   num_iters:      number of successful selections to collect.
+#   min_per_class:  minimum number of excerpts required per class.
+#   selection_type: selection strategy handed to survival_selection().
+#   max_draws:      draw budget of each attempt, handed to
+#                   survival_selection().
+#
+# Returns:
+#   list(sets, sim_results). `sets` holds the `train`, `test` and `filt` data
+#   frames of all successful attempts stacked together, with an `iter` column
+#   numbering the successful attempt each row comes from. `sim_results` has
+#   one row per attempt, failed ones included, with columns `success` and
+#   `num_draws`.
+draw_samples <- function(num_iters = 40,
+                         min_per_class = 10,
+                         selection_type = "art_class_surv",
+                         max_draws = 1e6) {
+
+  # Successful draws are collected here and stacked once after the loop.
+  train_parts <- list()
+  test_parts <- list()
+  filt_parts <- list()
+
+  sim_results <- data.frame(
+    success = logical(0), num_draws = numeric(0))
+
+  i <- 1
+  while (i <= num_iters) {
+
+    print("Starting new attempt")
+
+    res <- survival_selection(
+      min_per_class = min_per_class,
+      selection_type = selection_type,
+      max_draws = max_draws
+    )
+
+    outcome <- if (res$success) "Success " else "Failure "
+    print(
+      paste0(
+        "Finishing new attempt: ",
+        outcome,
+        "after ", res$draws, " draws"))
+    print("")
+
+    sim_results <- rbind(sim_results,
+                         list(success = res$success,
+                              num_draws = res$draws))
+
+    # Only successful attempts advance the counter; failed ones are retried.
+    if (res$success) {
+      train_parts[[i]] <- res$sets$train %>% mutate(iter = i)
+      test_parts[[i]] <- res$sets$test %>% mutate(iter = i)
+      filt_parts[[i]] <- res$sets$filt %>% mutate(iter = i)
+      i <- i + 1
+    }
+
+  }
+
+  # Starting from an empty data frame keeps the result a data frame even
+  # when no attempt was collected.
+  stack_parts <- function(parts) {
+    do.call(rbind, c(list(data.frame()), parts))
+  }
+
+  sets <- list(train = stack_parts(train_parts),
+               test = stack_parts(test_parts),
+               filt = stack_parts(filt_parts))
+
+  list(sets = sets, sim_results = sim_results)
+
+}