I want to sample n rows from each of the classes of my Spark DataFrame in sparklyr.
I understand that the dplyr::sample_n() function can't be used for this (see: Is sample_n really a random sample when used with sparklyr?), so I have used the sparklyr::sdf_sample() function instead. The problem with this is that I can't sample by group, i.e. get 10 observations from each class; I can only specify the fraction of the entire dataset to sample.
My workaround is to call sdf_sample() on each group individually in a loop, but since the function does not return an exact sample size, this is still not ideal.
R code for the workaround:
library(sparklyr)
library(dplyr)
sc <- spark_connect(master = "local", version = "2.3")
# copy iris to our spark cluster
iris_tbl <- copy_to(sc, iris, overwrite = TRUE)
# get class counts
class_counts <- iris_tbl %>%
  count(Species) %>%
  collect()
#  Species        n
#  <chr>      <dbl>
#1 versicolor    50
#2 virginica     50
#3 setosa        50
# we want to sample n = 10 points from each class
n <- 10
sampled_iris <- data.frame(stringsAsFactors = FALSE)
for (i in seq_along(class_counts$Species)) {
  my_frac  <- n / class_counts[[i, 'n']]
  my_class <- class_counts[[i, 'Species']]
  tmp <- iris_tbl %>%
    filter(Species == my_class) %>%
    sdf_sample(fraction = my_frac) %>%
    collect()
  sampled_iris <- bind_rows(sampled_iris, tmp)
}
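One partial mitigation I've considered (still not ideal) is to oversample each group and then truncate with head(), which dbplyr translates to a LIMIT on the cluster. This is only a sketch: the oversampling factor of 2 is an arbitrary assumption, and a group can still come back with fewer than n rows if the oversample undershoots:

tmp <- iris_tbl %>%
  filter(Species == my_class) %>%
  sdf_sample(fraction = min(1, 2 * my_frac)) %>%  # oversample to reduce undershoot
  head(n) %>%                                     # then cap at exactly n rows
  collect()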
We don't get exactly 10 samples from each class:
# new counts
sampled_iris %>% count(Species)
#  Species        n
#  <chr>      <int>
#1 setosa         7
#2 versicolor     9
#3 virginica      6
I'm wondering if there's a better way to get a balanced sample across groups using sparklyr? Or even a SQL query that I could pass directly to the cluster using DBI::dbGetQuery()?
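For concreteness, the kind of SQL I have in mind is something like the following (untested sketch). It assumes copy_to() registered the table as "iris" and replaced the dots in the column names with underscores, and it uses a ROW_NUMBER() window over a random ordering to take exactly 10 rows per class:

library(DBI)

query <- "
  SELECT *
  FROM (
    SELECT *,
           ROW_NUMBER() OVER (PARTITION BY Species ORDER BY rand()) AS row_num
    FROM iris
  ) t
  WHERE row_num <= 10
"
sampled_iris <- dbGetQuery(sc, query)

The result would still carry the helper row_num column, which could be dropped afterwards.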