f@0
|
# Keep only the excerpts of `set` that share no artist with `ref_set`.
#
# Args:
#   set:     data frame with an `ex_id` column (candidate excerpts).
#   ref_set: data frame with an `ex_id` column; any artist attached to one
#            of its excerpts disqualifies every excerpt of `set` that
#            carries the same artist.
#
# Returns:
#   The rows of `set` whose excerpt has at least one artist record and none
#   of whose artists is also an artist of `ref_set`.
filter_set <- function(set, ref_set) {

  # One row per (excerpt, artist) pair. Excerpts without any artist record
  # disappear in the inner join and are therefore never returned.
  excerpt_artists <- set %>%
    inner_join(get_excerpts_artists(), by = "ex_id")

  ref_artists <- get_artists(ref_set$ex_id, unique_artists = TRUE)

  # An excerpt is rejected as soon as any one of its artists is shared.
  shared_ids <- unique(
    excerpt_artists$ex_id[excerpt_artists$artist_id %in% ref_artists])
  keep_ids <- setdiff(excerpt_artists$ex_id, shared_ids)

  # Filter the rows of `set` itself rather than the joined table: the join
  # repeats an excerpt once per artist, and returning those repeats would
  # inflate every per-class count computed on the result.
  # NOTE(review): assumes get_excerpts_artists() contributes no column other
  # than `artist_id` that callers expect in the output -- confirm.
  set %>% filter(ex_id %in% keep_ids)
}
|
f@0
|
18
|
f@0
|
# Check, class by class, whether `set` holds at least `minimum` excerpts.
#
# Args:
#   set:     data frame with an `ex_id` column, or NULL.
#   minimum: smallest acceptable number of excerpts per class.
#
# Returns:
#   A named logical vector/table with one entry per class; TRUE when the
#   class reaches `minimum`. For a NULL or empty `set` every entry is FALSE.
check_filt <- function(set, minimum = 10) {

  class_names <- as.character(get_class_names()$class)

  # Test for NULL first and short-circuit, so the emptiness check is only
  # evaluated on an actual object.
  if (is.null(set) || plyr::empty(set)) {
    res <- logical(length(class_names))
    names(res) <- class_names
    return(res)
  }

  observed <- as.character(
    get_classes(set$ex_id, unique_classes = FALSE, use_names = TRUE))

  # Count over the complete class list: a plain table() only reports the
  # classes that occur, so a class with no excerpt at all would be missing
  # from the result and could never be flagged as failing. The union keeps
  # any observed label that is absent from get_class_names().
  # NOTE(review): assumes get_classes(use_names = TRUE) uses the same labels
  # as get_class_names()$class -- confirm.
  all_levels <- union(class_names, unique(observed[!is.na(observed)]))
  table(factor(observed, levels = all_levels)) >= minimum
}
|
f@0
|
32
|
f@0
|
# Entry point for sample selection; dispatches on `selection_type`.
#
# Args:
#   min_per_class:  minimum number of filtered test excerpts per class.
#   selection_type: one of "rand_surv", "class_surv" or "art_class_surv".
#   max_draws:      draw budget forwarded to survival_selection().
#
# Returns:
#   The result of survival_selection() for a supported `selection_type`;
#   otherwise NULL (invisibly), with a warning.
select_samples <- function(min_per_class = 10, selection_type = 'rand_surv',
                           max_draws = 1e6) {

  survival_types <- c("rand_surv", "class_surv", "art_class_surv")

  # `if` needs a single TRUE/FALSE: guard the length and use `%in%`, which
  # never yields NA.
  if (length(selection_type) == 1 && selection_type %in% survival_types) {
    return(survival_selection(min_per_class, selection_type, max_draws))
  }

  # Unsupported types still yield NULL, but no longer silently.
  warning("Unsupported selection_type; expected one of: ",
          paste(survival_types, collapse = ", "), call. = FALSE)
  invisible(NULL)
}
|
f@0
|
43
|
f@0
|
44 #
|
f@0
|
# Replace, inside `original_set`, every row of the class represented by
# `class_set` with the rows of `class_set`.
#
# Args:
#   original_set: data frame with an `ex_id` column.
#   class_set:    data frame with an `ex_id` column; the class of its first
#                 excerpt (as reported by get_classes()) is the one replaced.
#
# Returns:
#   The combined data frame, ordered by `ex_id`. When `class_set` is NULL or
#   has no rows, `original_set` is returned (ordered) with nothing replaced.
.update_set <- function(original_set, class_set) {

  # Without replacement rows the target class cannot be identified; the
  # comparison below would run against NA and index the set with NA,
  # producing rows filled with NA.
  if (is.null(class_set) || nrow(class_set) == 0) {
    return(original_set %>% arrange(ex_id))
  }

  target_class <- get_classes(class_set$ex_id,
                              unique_classes = TRUE)[1]
  row_classes <- get_classes(original_set$ex_id,
                             unique_classes = FALSE)

  # Rows whose class is unknown (NA) are not of the target class: keep them
  # instead of letting NA leak into the row index.
  keep <- is.na(row_classes) | row_classes != target_class

  # drop = FALSE keeps a data frame even when the set has a single column.
  rbind(original_set[keep, , drop = FALSE], class_set) %>% arrange(ex_id)
}
|
f@0
|
57
|
f@0
|
58 #
|
f@0
|
# Apply .update_set() to the train, test and filt members of a set list.
#
# Args:
#   original_sets: list with data-frame members `train`, `test`, `filt`.
#   class_sets:    list with the same members, holding the replacement rows
#                  for a single class.
#
# Returns:
#   `original_sets` with each of the three members updated.
.update_sets <- function(original_sets, class_sets) {

  for (part in c("train", "test", "filt")) {
    original_sets[[part]] <- .update_set(original_sets[[part]],
                                         class_sets[[part]])
  }

  original_sets
}
|
f@0
|
67
|
f@0
|
68 #
|
f@0
|
69
|
f@0
|
# Draw one stratified bootstrap sample and attach the filtered test set.
#
# Args:
#   class:           class (or classes) to restrict the draw to; NULL draws
#                    over all classes.
#   hold_out_ex_ids: excerpt ids forwarded to get_bs_samples() as
#                    `hold_out_ex_ids`.
#
# Returns:
#   The list returned by get_bs_samples() with an extra member `filt`: the
#   test set after filter_set() removed excerpts sharing an artist with the
#   training set.
.get_sets <- function(class = NULL, hold_out_ex_ids = NULL) {

  # 1000 draws for a full sample, 100 when redrawing specific classes.
  sample_size <- if (is.null(class)) 1000 else 100

  # TRUE/FALSE spelled out: T and F are ordinary variables and can be
  # reassigned elsewhere.
  sets <- get_bs_samples(sample_size,
                         num_iter = 1, stratified = TRUE,
                         keep_prop = FALSE, classes = class,
                         hold_out_ex_ids = hold_out_ex_ids)
  sets$filt <- filter_set(sets$test, sets$train)

  sets
}
|
f@0
|
84
|
f@0
|
# Redraw the sets for a single class.
#
# Args:
#   class:          class to redraw.
#   selection_type: "class_surv" draws directly; "art_class_surv" first
#                   picks random artists of the class until their excerpt
#                   rows add up to `min_per_class`, and passes the excerpts
#                   of those artists to .get_sets() as hold-out ids.
#   min_per_class:  number of excerpt rows the held-out artists must cover.
#
# Returns:
#   The list produced by .get_sets() for `class`.
#
# Raises:
#   An error for an unsupported `selection_type`, or when the class has
#   fewer artist-excerpt rows than `min_per_class` (the requirement can
#   never be met, so retrying would not terminate).
.get_class_sets <- function(class = NULL,
                            selection_type = 'class_surv',
                            min_per_class = 0) {

  if (selection_type == "class_surv") {
    return(.get_sets(class = class))
  }
  if (selection_type != "art_class_surv") {
    stop("Unsupported selection_type for a class redraw: ", selection_type,
         call. = FALSE)
  }

  # One row per (excerpt, artist) pair of the class.
  ex_art <-
    get_excerpts_classes(class = class) %>%
    inner_join(get_excerpts_artists(), by = "ex_id")

  artists_dist <- table(ex_art$artist_id)
  artist_ids <- unique(ex_art$artist_id[!is.na(ex_art$artist_id)])
  hold_out_artists <- artist_ids[0]

  if (min_per_class > 0) {
    # A uniform random permutation visits artists in the same way as drawing
    # one at a time and rejecting repeats, but always terminates. Indexing
    # through sample.int() also avoids sample(x, 1) treating a single
    # numeric id as the range 1:x.
    shuffled <- artist_ids[sample.int(length(artist_ids))]
    covered <- cumsum(as.vector(artists_dist[as.character(shuffled)]))
    n_needed <- which(covered >= min_per_class)[1]

    if (is.na(n_needed)) {
      stop("Cannot hold out ", min_per_class, " excerpts: the class only ",
           "has ", nrow(ex_art), " artist-excerpt rows", call. = FALSE)
    }
    hold_out_artists <- shuffled[seq_len(n_needed)]
  }

  # Select on the rows first and de-duplicate afterwards; the row mask must
  # not be used to index the already de-duplicated id vector.
  hold_out_ex_ids <- unique(
    ex_art$ex_id[ex_art$artist_id %in% hold_out_artists])

  .get_sets(class, hold_out_ex_ids)
}
|
f@0
|
120
|
f@0
|
# Build the result object reported when a selection runs out of draws.
#
# Args:
#   num_draws: number of draws spent before giving up.
#
# Returns:
#   list(sets = NULL, success = FALSE, draws = num_draws).
.fail_selection <- function(num_draws) {

  # FALSE rather than F: F is an ordinary variable that can be reassigned.
  list(sets = NULL, success = FALSE, draws = num_draws)
}
|
f@0
|
126
|
f@0
|
# Draw train/test/filt sets until every class of the filtered test set
# holds at least `min_per_class` excerpts.
#
# Args:
#   min_per_class:  minimum number of filtered test excerpts per class.
#   selection_type: "rand_surv" redraws everything on failure;
#                   "class_surv" and "art_class_surv" redraw only the
#                   failing classes through .get_class_sets().
#   max_draws:      draw budget; once exceeded the selection is abandoned.
#
# Returns:
#   list(sets, success, draws). On success `sets` holds the selected sets;
#   when the budget is exceeded the value of .fail_selection() is returned.
#
# Raises:
#   An error for an unsupported `selection_type`; otherwise a failed initial
#   check would fall through and be reported as a success.
survival_selection <- function(min_per_class = 10,
                               selection_type = 'rand_surv',
                               max_draws = 1e6) {

  supported <- c("rand_surv", "class_surv", "art_class_surv")
  if (length(selection_type) != 1 || !selection_type %in% supported) {
    stop("Unsupported selection_type; expected one of: ",
         paste(supported, collapse = ", "), call. = FALSE)
  }

  # 1.- Draw initial sets
  sets <- .get_sets()
  num_draws <- nrow(sets$train)

  # 2.- Check if initial test set meets requirements
  checked <- check_filt(sets$filt, minimum = min_per_class)

  if (selection_type == "rand_surv") {

    # 3.1.- Random Survival: Complete redraw
    while (!all(checked)) {
      if (num_draws > max_draws) return(.fail_selection(num_draws))
      sets <- .get_sets()
      num_draws <- num_draws + 1000
      checked <- check_filt(sets$filt, minimum = min_per_class)
    }

  } else {

    # 3.2.- Class Survival: Redraw only problematic classes
    # 3.3.- Artist-informed Class Survival
    # (no iterations when every class already passed the check)
    for (fail_class in names(checked[!checked])) {
      print(paste0("Redrawing class: ", fail_class))
      class_checked <- FALSE
      while (!class_checked) {
        # The budget is checked before every attempt, so a class that can
        # never reach the minimum ends in a failure instead of looping
        # forever.
        if (num_draws > max_draws) return(.fail_selection(num_draws))
        print(paste0(" Attempting: ", fail_class))
        class_sets <- .get_class_sets(class = fail_class,
                                      selection_type = selection_type,
                                      min_per_class = min_per_class)
        num_draws <- num_draws + 100
        class_checked <- nrow(class_sets$filt) >= min_per_class
      }
      sets <- .update_sets(sets, class_sets)
    }

  }

  list(sets = sets, success = TRUE, draws = num_draws)
}
|
f@0
|
173
|
f@0
|
174
|
f@0
|
175 #
|
f@0
|
# Repeat survival_selection() until `num_iters` selections have succeeded
# and stack the resulting sets.
#
# Args:
#   num_iters:      number of successful selections to collect.
#   min_per_class:  forwarded to survival_selection().
#   selection_type: forwarded to survival_selection().
#   max_attempts:   upper bound on the total number of attempts (successful
#                   or not). The default, Inf, retries without limit.
#
# Returns:
#   list(sets, sim_results). `sets` has members train/test/filt, each the
#   row-bound sets of all successful attempts with an `iter` column giving
#   the index of the successful iteration. `sim_results` has one row per
#   attempt with columns `success` and `num_draws`.
draw_samples <- function(num_iters = 40,
                         min_per_class = 10,
                         selection_type = "art_class_surv",
                         max_attempts = Inf) {

  # Collect the pieces and bind them once at the end instead of growing
  # data frames with rbind() on every iteration.
  train_parts <- list()
  test_parts <- list()
  filt_parts <- list()
  successes <- logical(0)
  draws <- numeric(0)

  i <- 1
  attempts <- 0
  while (i <= num_iters) {

    if (attempts >= max_attempts) {
      warning("Stopped after ", attempts, " attempts with ", i - 1,
              " of ", num_iters, " successful selections", call. = FALSE)
      break
    }
    attempts <- attempts + 1

    print("Starting new attempt")

    res <- survival_selection(
      min_per_class = min_per_class,
      selection_type = selection_type
    )

    # Plain if/else: the condition is a single value.
    outcome <- if (res$success) "Success " else "Failure "
    print(
      paste0(
        "Finishing new attempt: ",
        outcome,
        "after ", res$draws, " draws"))
    print("")

    successes[attempts] <- res$success
    draws[attempts] <- res$draws

    if (res$success) {
      train_parts[[i]] <- res$sets$train %>% mutate(iter = i)
      test_parts[[i]] <- res$sets$test %>% mutate(iter = i)
      filt_parts[[i]] <- res$sets$filt %>% mutate(iter = i)
      i <- i + 1
    }

  }

  # The leading empty data frame makes the result an empty data frame when
  # no attempt succeeded.
  bind_parts <- function(parts) {
    do.call(rbind, c(list(data.frame()), parts))
  }

  list(
    sets = list(train = bind_parts(train_parts),
                test = bind_parts(test_parts),
                filt = bind_parts(filt_parts)),
    sim_results = data.frame(success = successes, num_draws = draws))
}
|