# view sampling/simulations.R @ 0:205974c9568c tip
#
# Initial commit. Predictions not included for lack of space.
# author franrodalg <f.rodriguezalgarra@qmul.ac.uk>
# date Sat, 29 Jun 2019 18:45:50 +0100
# parents
# children
# line wrap: on
# line source
#' Drop excerpts that share an artist with a reference set
#'
#' Keeps the rows of `set` whose excerpt has at least one artist record and
#' none of whose artists also appears among the artists of `ref_set`.
#'
#' @param set Data frame with an `ex_id` column (the set to be filtered).
#' @param ref_set Data frame with an `ex_id` column; its artists are looked
#'   up with `get_artists()`.
#' @return The retained rows of `set`, in their original order and with the
#'   same columns as `set`.
filter_set <- function(set, ref_set) {

  # Excerpt/artist pairs. NOTE(review): the original per-excerpt `any()`
  # suggests an excerpt can be listed with several artists -- confirm.
  ex_artists <- get_excerpts_artists()
  ref_artists <- get_artists(ref_set$ex_id, unique_artists = TRUE)

  # Excerpts with at least one artist that also occurs in the reference set.
  shared_ids <- unique(
    ex_artists$ex_id[ex_artists$artist_id %in% ref_artists])

  # The previous implementation inner-joined `set` with the artist table and
  # filtered the joined rows, so an excerpt with k artists came back k times.
  # Filtering `set` itself keeps each input row at most once while still
  # dropping excerpts without any artist record (as the inner join did).
  has_artist <- set$ex_id %in% ex_artists$ex_id
  keep <- has_artist & !(set$ex_id %in% shared_ids)

  kept <- set[keep, , drop = FALSE]
  rownames(kept) <- NULL
  kept

}

#' Check which classes have enough excerpts in a (filtered) set
#'
#' @param set Data frame with an `ex_id` column, or `NULL`.
#' @param minimum Minimum number of excerpts required per class.
#' @return Named logical vector with one entry per class listed by
#'   `get_class_names()` (plus any other class reported by `get_classes()`);
#'   `TRUE` where the class has at least `minimum` excerpts in `set`. For a
#'   `NULL` or empty `set` every entry is `FALSE`, whatever `minimum` is.
check_filt <- function(set, minimum = 10) {

  class_names <- as.character(get_class_names()$class)

  # Same emptiness test that plyr::empty() performs, with the NULL check
  # first and short-circuiting so nrow()/ncol() are never asked of NULL.
  if (is.null(set) || nrow(set) == 0 || ncol(set) == 0) {
    res <- logical(length(class_names))
    names(res) <- class_names
    return(res)
  }

  observed <- table(
    get_classes(set$ex_id, unique_classes = FALSE, use_names = TRUE))

  # table() only reports classes that actually occur, so a class with no
  # excerpts at all would be absent from the result and `all()` on it would
  # pass. Start every known class at zero and fill in the observed counts.
  all_names <- union(class_names, names(observed))
  counts <- integer(length(all_names))
  names(counts) <- all_names
  counts[names(observed)] <- as.integer(observed)

  counts >= minimum

}

#' Dispatch sample selection by strategy name
#'
#' @param min_per_class Minimum number of excerpts required per class.
#' @param selection_type One of `"rand_surv"`, `"class_surv"` or
#'   `"art_class_surv"`.
#' @param max_draws Draw budget forwarded to `survival_selection()`.
#' @return The result of `survival_selection()` for a supported
#'   `selection_type`; otherwise `NULL` (invisibly), with a warning.
select_samples <- function(min_per_class = 10, selection_type = 'rand_surv',
                           max_draws = 1e6) {

  supported <- c("rand_surv", "class_surv", "art_class_surv")

  # Unsupported types used to fall through without any signal; keep the
  # NULL result for backward compatibility but tell the caller why.
  if (length(selection_type) != 1 || !(selection_type %in% supported)) {
    warning("Unsupported selection_type: ",
            paste(selection_type, collapse = ", "),
            call. = FALSE)
    return(invisible(NULL))
  }

  survival_selection(min_per_class, selection_type, max_draws)

}

#' Replace the rows of one class in a set
#'
#' The class to replace is the first class reported by `get_classes()` for
#' the excerpts of `class_set`. Rows of `original_set` belonging to that
#' class are removed, `class_set` is appended, and the result is sorted by
#' `ex_id`.
#'
#' @param original_set Data frame with an `ex_id` column.
#' @param class_set Data frame with the same columns, holding the new rows
#'   for a single class.
#' @return The updated data frame, ordered by `ex_id`. If the class of
#'   `class_set` cannot be determined (e.g. it has no rows), no row of
#'   `original_set` is removed.
.update_set <- function(original_set, class_set) {

  target_class <- get_classes(class_set$ex_id,
                              unique_classes = TRUE)[1]
  row_classes <- get_classes(original_set$ex_id,
                             unique_classes = FALSE)

  # `row_classes != target_class` is NA whenever either side is NA, and an
  # NA row index inserts an all-NA row. Build a mask that is never NA so
  # rows of unknown class are kept untouched instead.
  is_target <- !is.na(row_classes) & row_classes %in% target_class

  remaining <- original_set[!is_target, ]

  arrange(rbind(remaining, class_set), ex_id)

}

#' Apply `.update_set()` to the train, test and filt members of a set list
#'
#' @param original_sets List with `train`, `test` and `filt` elements.
#' @param class_sets List with the same three elements, holding the
#'   replacement rows.
#' @return `original_sets` with its `train`, `test` and `filt` elements
#'   replaced by the corresponding `.update_set()` results; any other
#'   element is left as is.
.update_sets <- function(original_sets, class_sets) {

  for (part in c("train", "test", "filt")) {
    original_sets[[part]] <- .update_set(original_sets[[part]],
                                         class_sets[[part]])
  }

  original_sets

}

#' Draw one collection of sets and attach its artist-filtered set
#'
#' Requests a single stratified draw from `get_bs_samples()` -- of size 1000
#' when no class is given and 100 otherwise -- and stores
#' `filter_set(test, train)` under `filt`.
#'
#' @param class Class passed on as `classes`; `NULL` for no restriction.
#' @param hold_out_ex_ids Excerpt ids passed on unchanged to
#'   `get_bs_samples()`.
#' @return The value returned by `get_bs_samples()` with an added `filt`
#'   element.
.get_sets <- function(class = NULL, hold_out_ex_ids = NULL) {

  sample_size <- if (is.null(class)) 1000 else 100

  drawn <- get_bs_samples(sample_size,
                          num_iter = 1,
                          stratified = TRUE,
                          keep_prop = FALSE,
                          classes = class,
                          hold_out_ex_ids = hold_out_ex_ids)

  # The filtered set is the test set minus anything sharing artists with
  # the training set (see filter_set()).
  drawn$filt <- filter_set(drawn$test, drawn$train)

  drawn

}

#' Draw sets for a single class
#'
#' For `"class_surv"` this is a plain `.get_sets(class)` draw. For
#' `"art_class_surv"` artists of the class are picked at random, one after
#' another, until the picked artists account for at least `min_per_class`
#' excerpt/artist rows (or no artist is left); the excerpts of the picked
#' artists are then passed to `.get_sets()` as `hold_out_ex_ids`.
#'
#' @param class Class to draw for.
#' @param selection_type `"class_surv"` or `"art_class_surv"`.
#' @param min_per_class Number of rows the held-out artists must cover.
#' @return The result of `.get_sets()`; `NULL` (invisibly) for any other
#'   `selection_type`.
.get_class_sets <- function(class = NULL,
                            selection_type = 'class_surv',
                            min_per_class = 0) {

  if (selection_type == "class_surv") {
    return(.get_sets(class = class))
  }
  if (selection_type != "art_class_surv") {
    return(invisible(NULL))
  }

  ex_art <- inner_join(get_excerpts_classes(class = class),
                       get_excerpts_artists(),
                       by = c("ex_id"))

  # Rows per artist, looked up by the artist id's character form.
  per_artist <- table(ex_art$artist_id)

  candidates <- unique(ex_art$artist_id)
  candidates <- candidates[!is.na(candidates)]

  # A random permutation is equivalent to repeatedly drawing one not-yet-
  # chosen artist, but it cannot spin forever when `min_per_class` is larger
  # than what the class offers. Permuting positions also avoids sample()'s
  # special case for a single numeric value, where sample(57, 1) draws from
  # 1:57 instead of returning 57.
  shuffled <- candidates[sample.int(length(candidates))]
  covered <- cumsum(as.numeric(per_artist[as.character(shuffled)]))

  if (min_per_class <= 0) {
    n_hold <- 0
  } else {
    reached <- which(covered >= min_per_class)
    # If the target is unreachable, hold out every artist.
    n_hold <- if (length(reached) > 0) reached[1] else length(shuffled)
  }
  hold_out_artists <- shuffled[seq_len(n_hold)]

  # Select the excerpt ids on the rows of the held-out artists first and
  # deduplicate afterwards; indexing the deduplicated ids with row positions
  # is misaligned whenever an excerpt occupies more than one row.
  hold_out_ex_ids <- unique(
    ex_art$ex_id[ex_art$artist_id %in% hold_out_artists])

  .get_sets(class, hold_out_ex_ids)

}

#' Build the result returned when a selection runs out of draws
#'
#' @param num_draws Number of draws used so far.
#' @return List with `sets = NULL`, `success = FALSE` and `draws = num_draws`.
.fail_selection <- function(num_draws) {

  outcome <- list(sets = NULL, success = FALSE, draws = num_draws)
  outcome

}

#' Draw sets until every class keeps enough excerpts after filtering
#'
#' @param min_per_class Minimum number of excerpts each class must have in
#'   the filtered set.
#' @param selection_type `"rand_surv"` (redraw everything),
#'   `"class_surv"` (redraw only failing classes) or `"art_class_surv"`
#'   (as `"class_surv"`, drawing through the artist-informed branch of
#'   `.get_class_sets()`).
#' @param max_draws Draw budget; once the running count exceeds it the
#'   selection gives up.
#' @return List with `sets`, `success` and `draws`. On failure this is the
#'   value of `.fail_selection()`.
survival_selection <- function(min_per_class = 10,
                               selection_type = 'rand_surv',
                               max_draws = 1e6) {

  # An unknown strategy used to skip every redraw branch and still report
  # success, even when the requirements were not met.
  supported <- c("rand_surv", "class_surv", "art_class_surv")
  if (length(selection_type) != 1 || !(selection_type %in% supported)) {
    stop("Unsupported selection_type: ",
         paste(selection_type, collapse = ", "),
         call. = FALSE)
  }

  # 1.- Draw initial sets
  sets <- .get_sets()
  num_draws <- nrow(sets$train)

  # 2.- Check if initial test set meets requirements
  checked <- check_filt(sets$filt, minimum = min_per_class)

  if (selection_type == "rand_surv") {

    # 3.1.- Random Survival: Complete redraw
    while (!all(checked)) {
      if (num_draws > max_draws) return(.fail_selection(num_draws))
      sets <- .get_sets()
      num_draws <- num_draws + 1000
      checked <- check_filt(sets$filt, minimum = min_per_class)
    }

  } else {

    # 3.2.- Class Survival: Redraw only problematic classes
    # 3.3.- Artist-informed Class Survival
    for (fail_class in names(checked[!checked])) {
      print(paste0("Redrawing class: ", fail_class))
      class_checked <- FALSE
      while (!class_checked) {
        # Checked before every attempt (not just once per class) so a class
        # that can never reach the minimum cannot loop without bound.
        if (num_draws > max_draws) return(.fail_selection(num_draws))
        print(paste0("  Attempting: ", fail_class))
        class_sets <- .get_class_sets(class = fail_class,
                                      selection_type = selection_type,
                                      min_per_class = min_per_class)
        num_draws <- num_draws + 100
        class_checked <- nrow(class_sets$filt) >= min_per_class
      }
      sets <- .update_sets(sets, class_sets)
    }

  }

  list(sets = sets, success = TRUE, draws = num_draws)

}


#' Collect a number of successful survival selections
#'
#' Calls `survival_selection()` repeatedly until `num_iters` attempts have
#' succeeded (or `max_attempts` attempts have been made), tagging the sets
#' of each successful attempt with its iteration number.
#'
#' @param num_iters Number of successful selections to collect.
#' @param min_per_class Forwarded to `survival_selection()`.
#' @param selection_type Forwarded to `survival_selection()`.
#' @param max_draws Per-attempt draw budget forwarded to
#'   `survival_selection()`.
#' @param max_attempts Upper bound on the total number of attempts; the
#'   default `Inf` keeps trying until `num_iters` successes are reached.
#' @return List with `sets` (data frames `train`, `test` and `filt`, each
#'   with an added `iter` column) and `sim_results` (one row per attempt
#'   with `success` and `num_draws`).
draw_samples <- function(num_iters = 40,
                         min_per_class = 10,
                         selection_type = "art_class_surv",
                         max_draws = 1e6,
                         max_attempts = Inf) {

  # Tag every row of a set with the iteration it belongs to; rep() keeps
  # this valid for sets without rows.
  tag_iter <- function(part, iter) {
    part$iter <- rep(iter, nrow(part))
    part
  }

  # Stack the collected pieces once at the end instead of growing a data
  # frame on every iteration; the leading empty data frame reproduces the
  # result for the case where nothing was collected.
  stack_parts <- function(parts) {
    do.call(rbind, c(list(data.frame()), parts))
  }

  train_parts <- list()
  test_parts <- list()
  filt_parts <- list()
  success_log <- list()
  draws_log <- list()

  i <- 1
  attempts <- 0
  while (i <= num_iters && attempts < max_attempts) {

    attempts <- attempts + 1
    print("Starting new attempt")

    res <- survival_selection(
      min_per_class = min_per_class,
      selection_type = selection_type,
      max_draws = max_draws
    )

    outcome <- if (res$success) "Success " else "Failure "
    print(
      paste0(
        "Finishing new attempt: ",
        outcome,
        "after ", res$draws, " draws"))
    print("")

    success_log[[attempts]] <- res$success
    draws_log[[attempts]] <- res$draws

    if (res$success) {
      train_parts[[i]] <- tag_iter(res$sets$train, i)
      test_parts[[i]] <- tag_iter(res$sets$test, i)
      filt_parts[[i]] <- tag_iter(res$sets$filt, i)
      i <- i + 1
    }

  }

  if (i <= num_iters) {
    warning("Stopped after ", attempts, " attempts with ", i - 1,
            " of ", num_iters, " successful selections",
            call. = FALSE)
  }

  sets <- list(train = stack_parts(train_parts),
               test = stack_parts(test_parts),
               filt = stack_parts(filt_parts))
  sim_results <- data.frame(
    success = as.logical(unlist(success_log)),
    num_draws = as.numeric(unlist(draws_log)))

  list(sets = sets, sim_results = sim_results)

}