Mercurial > hg > confint
diff sampling/strategies.R @ 0:205974c9568c tip
Initial commit. Predictions not included for lack of space.
author | franrodalg <f.rodriguezalgarra@qmul.ac.uk> |
---|---|
date | Sat, 29 Jun 2019 18:45:50 +0100 |
parents | |
children |
line wrap: on
line diff
# sampling/strategies.R
# (recovered from diff: --- /dev/null  +++ b/sampling/strategies.R  @@ -0,0 +1,204 @@)
#
# Train/test sampling strategies (cross-validation and bootstrap).
# Relies on names defined outside this file: `get_excerpts_classes()`,
# `get_classes()` and the dplyr verbs `%>%` / `arrange()`.

## Generic sampling function
#
# Dispatches to cross-validation sampling when `mode == 'cv'` and to
# bootstrap sampling for any other value of `mode`.
#
# num_iter   number of repetitions of the sampling procedure
# mode       'cv' for cross-validation; anything else selects bootstrap
# N          bootstrap sample size (0 or NULL: use the number of excerpts)
# num_folds  number of cross-validation folds
# stratified sample within each class separately
# keep_prop  bootstrap only: subsample the test set so that its class
#            proportions follow those of the training sample
# db, classes forwarded untouched to the data-access helpers
#
# Returns list(train = <data.frame>, test = <data.frame>).
get_samples <- function(num_iter = 1, mode = 'cv',
                        N = 0, num_folds = 4,
                        stratified = FALSE, keep_prop = FALSE,
                        db = "../db/gtzan.db", classes = NULL) {

  if (mode == 'cv') {
    get_cv_samples(
      num_iter = num_iter, num_folds = num_folds,
      stratified = stratified, db = db, classes = classes)
  } else {
    get_bs_samples(
      N = N, num_iter = num_iter, stratified = stratified,
      keep_prop = keep_prop, db = db, classes = classes)
  }

}

# Row-bind a list of data frames; an empty list gives an empty data frame
# (do.call(rbind, list()) would return NULL instead).
.stack_frames <- function(frames) {
  if (length(frames) == 0) {
    return(data.frame())
  }
  do.call(rbind, frames)
}

# CROSS-VALIDATION

## Cross-Validation sampling
#
# Draws `num_iter` independent fold partitions and returns, in long format,
# the excerpt ids of every train and test split, tagged with the iteration
# (`iter`), the fold index (`fold`), `mode = 'cv'` and `stratified`.
get_cv_samples <- function(num_iter = 1, num_folds = 4, stratified = FALSE,
                           db = "../db/gtzan.db", classes = NULL) {

  # One slot per (iteration, fold) pair; filled below and bound once at the
  # end instead of growing a data frame inside the loop.
  train_parts <- vector("list", max(num_iter, 0) * max(num_folds, 0))
  test_parts <- train_parts

  # seq_len() so that a count of 0 yields no iterations (1:0 would run twice)
  for (i in seq_len(num_iter)) {
    folds <-
      get_folds(num_folds = num_folds, stratified = stratified,
                db = db, classes = classes)

    for (j in seq_len(num_folds)) {
      slot <- (i - 1) * num_folds + j
      train_parts[[slot]] <-
        data.frame(ex_id = folds[[j]]$train,
                   iter = i, fold = j,
                   mode = 'cv',
                   stratified = stratified)
      test_parts[[slot]] <-
        data.frame(ex_id = folds[[j]]$test,
                   iter = i, fold = j,
                   mode = 'cv',
                   stratified = stratified)
    }
  }

  list(train = .stack_frames(train_parts), test = .stack_frames(test_parts))

}


# Random assignment into folds
#
# Returns a shuffled numeric vector of length N with labels 1..num_folds.
# Every fold receives N %/% num_folds labels; when N is not a multiple of
# num_folds, the N %% num_folds left-over items go to distinct, randomly
# chosen folds, so no item is left without a fold.
# When N is a multiple of num_folds no extra random numbers are drawn, so
# results for a given seed are unchanged in that case.
.get_assignment <- function(N, num_folds) {

  labels <- rep(seq_len(num_folds), each = N %/% num_folds)

  leftover <- N %% num_folds
  if (leftover > 0) {
    labels <- c(labels, sample.int(num_folds, leftover))
  }

  labels <- as.numeric(labels)
  # Index-based shuffle: sample(x) on a length-one numeric x would instead
  # draw a permutation of 1:x.
  labels[sample.int(length(labels))]

}

# Obtain folds for Cross-Validation
#
# Reads the excerpt table through get_excerpts_classes() (used here through
# its columns `ex_id`, `class_id` and `ex_id_class`), orders it by class and
# assigns each row to one fold, either over the whole table or - when
# `stratified` - independently inside every class.
#
# Returns a list of length num_folds; element i is
# list(test = <ids in fold i>, train = <ids in all other folds>).
get_folds <- function(num_folds = 4, stratified = FALSE,
                      db = "../db/gtzan.db", classes = NULL) {

  res <- get_excerpts_classes(db = db, classes) %>%
    arrange(class_id, ex_id_class)

  if (stratified) {
    # Rows are ordered by class_id, so concatenating per-class label vectors
    # in order of first appearance lines up with the rows of `res`.
    class_ids <- unique(res$class_id)
    per_class <- lapply(class_ids, function(cl) {
      .get_assignment(sum(res$class_id == cl, na.rm = TRUE), num_folds)
    })
    assignment <- c(numeric(0), unlist(per_class, use.names = FALSE))
  } else {
    assignment <- .get_assignment(nrow(res), num_folds)
  }

  ids <- res$ex_id
  folds <- vector("list", num_folds)

  for (i in seq_len(num_folds)) {
    # Build each fold as an explicit list so that `$train` / `$test` work
    # whatever the number of ids in the fold.
    folds[[i]] <- list(test = ids[which(assignment == i)],
                       train = ids[which(assignment != i)])
  }

  folds

}


# BOOTSTRAP

## Bootstrap sampling
#
# Repeats get_bs_sample() `num_iter` times and returns the train and test
# excerpt ids in long format, tagged with `iter`, `mode = 'bs'`,
# `stratified` and `keep_prop`.
#
# N               bootstrap sample size; NULL or 0 means "number of rows
#                 returned by get_excerpts_classes()"
# hold_out_ex_ids excerpt ids never drawn for training and always added to
#                 the test candidates
get_bs_samples <- function(N, num_iter, stratified = FALSE, keep_prop = FALSE,
                           db = "../db/gtzan.db", classes = NULL,
                           hold_out_ex_ids = NULL) {

  res <- get_excerpts_classes(db, classes)

  # NULL must be tested first and with the short-circuit operator:
  # `NULL == 0` is logical(0), which if() rejects.
  if (is.null(N) || N == 0) {
    N <- nrow(res)
  }

  train_parts <- vector("list", max(num_iter, 0))
  test_parts <- train_parts

  for (i in seq_len(num_iter)) {
    train_test <-
      get_bs_sample(N, stratified, keep_prop, db, classes, res,
                    hold_out_ex_ids)
    train_parts[[i]] <-
      data.frame(ex_id = train_test$train,
                 iter = i,
                 mode = 'bs',
                 stratified = stratified,
                 keep_prop = keep_prop)
    test_parts[[i]] <-
      data.frame(ex_id = train_test$test,
                 iter = i,
                 mode = 'bs',
                 stratified = stratified,
                 keep_prop = keep_prop)
  }

  list(train = .stack_frames(train_parts), test = .stack_frames(test_parts))

}

## Correct proportions for test set stratification
#
# Subsamples `samples` (without replacement) so that its per-class counts
# follow the class proportions observed in `reference_samples`. Class
# labels are obtained from get_classes(); presumably it returns one label
# per id in the order given - TODO confirm against its definition.
#
# Returns the retained ids, grouped by class in the order of the reference
# class levels.
.correct_prop <- function(samples, reference_samples,
                          db = "../db/gtzan.db") {

  ref <- as.factor(
    get_classes(reference_samples, unique_classes = FALSE, db = db))
  classes <- factor(
    get_classes(samples, unique_classes = FALSE, db = db),
    levels = levels(ref))

  ref_counts <- table(ref)
  ref_prop <- ref_counts / sum(ref_counts)

  prior_counts <- table(classes)
  # Largest total size at which every class can still be served at its
  # reference proportion, converted back to per-class counts.
  post_counts <- round(ref_prop * min(prior_counts / ref_prop))

  post_samples <- numeric(0)
  for (cl in levels(classes)) {
    pool <- samples[which(classes == cl)]
    # Index-based draw: sample(pool, ...) would sample from 1:pool when the
    # pool holds a single numeric id.
    picked <- pool[sample.int(length(pool), size = post_counts[[cl]])]
    post_samples <- c(post_samples, picked)
  }

  post_samples

}

## Bootstrap single train/test pair
#
# N               size of the training sample (non-stratified case)
# stratified      draw, with replacement, as many ids per class as the class
#                 has rows in `df` before hold-out ids are removed
# keep_prop       pass the test ids through .correct_prop()
# df              excerpt table with columns `ex_id` and `class_id`; fetched
#                 with get_excerpts_classes(db, classes) when NULL
# hold_out_ex_ids ids excluded from training and appended to the test ids
#
# Returns list(train = <ids drawn with replacement>,
#              test  = <sorted ids not drawn, plus hold-out ids>).
get_bs_sample <- function(N, stratified = FALSE, keep_prop = FALSE,
                          db = "../db/gtzan.db", classes = NULL,
                          df = NULL,
                          hold_out_ex_ids = NULL
                          ) {

  if (is.null(df)) {
    df <- get_excerpts_classes(db, classes)
  }

  # Class sizes are taken before removing hold-out ids.
  n_ex_class <- table(df$class_id)
  # NOTE(review): looks like leftover debug output; kept so that console
  # output does not change - consider removing.
  print(n_ex_class)

  if (!is.null(hold_out_ex_ids)) {
    df <- df[!(df$ex_id %in% hold_out_ex_ids), , drop = FALSE]
  }

  if (stratified) {
    train <- numeric(0)
    for (j in unique(df$class_id)) {
      ex_class <- df$ex_id[which(df$class_id == j)]
      n_draw <- as.integer(n_ex_class[names(n_ex_class) == j])
      # Index-based draw, safe when the class holds a single numeric id.
      train <- c(train,
                 ex_class[sample.int(length(ex_class), size = n_draw,
                                     replace = TRUE)])
    }
  } else {
    train <- df$ex_id[sample.int(length(df$ex_id), size = N, replace = TRUE)]
  }

  test <- sort(c(df$ex_id[!(df$ex_id %in% train)], hold_out_ex_ids))
  if (keep_prop) {
    # Full argument name, and forward `db` so a non-default database is
    # honoured when looking up class labels.
    test <- .correct_prop(test, reference_samples = train, db = db)
  }

  list(train = train, test = test)

}