
I want to sample n rows from each of the classes of my Spark DataFrame in sparklyr.

I understand that the dplyr::sample_n function can't be used for this (Is sample_n really a random sample when used with sparklyr?), so I have used the sparklyr::sdf_sample() function. The problem with this is that I can't sample by group, i.e. get 10 observations from each class; I can only specify the fraction of the entire dataset to sample.

I have a workaround that uses sdf_sample() on each group individually in a loop, but since the function does not return an exact sample size, this is still not ideal.

R code for workaround:

library(sparklyr)
library(dplyr)

sc <- spark_connect(master = "local", version = "2.3")

# copy iris to our spark cluster
iris_tbl <- copy_to(sc, iris, overwrite = TRUE)


# get class counts
class_counts <- iris_tbl %>% count(Species) %>%
  collect()
#  Species        n
#  <chr>      <dbl>
#1 versicolor    50
#2 virginica     50
#3 setosa        50

# we want to sample n = 10 points from each class
n <- 10 

sampled_iris <- data.frame(stringsAsFactors = FALSE)
for (i in seq_along(class_counts$Species)) {

  my_frac <- n / class_counts[[i, 'n']]
  my_class <- class_counts[[i, 'Species']]

  tmp <- iris_tbl %>%
    filter(Species == my_class) %>%
    sdf_sample(fraction = my_frac) %>%
    collect()

  sampled_iris <- bind_rows(sampled_iris, tmp)
}

We don't get exactly 10 samples from each class:

# new counts
sampled_iris %>% count(Species)


#  Species        n
#  <chr>      <int>
#1 setosa         7
#2 versicolor     9
#3 virginica      6

I'm wondering whether there's a better way to get a balanced sample across groups using sparklyr, or even a SQL query that I can pass directly to the cluster using DBI::dbGetQuery().

Chris

1 Answer


I can't sample by group

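As long as the grouping column is a string (that's a limitation of sparklyr's type mapping), that part can easily be handled using DataFrameStatFunctions.sampleBy. It draws a stratified sample without replacement, using a separate sampling fraction for each value of the grouping column: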

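# sampleBy lives on DataFrameStatFunctions, reached via the DataFrame's stat() method
spark_dataframe(iris_tbl) %>%
  sparklyr::invoke("stat") %>%
  sparklyr::invoke(
    "sampleBy",
    "Species",
    # per-class fraction; an R environment is serialized to a java.util.Map
    fractions=as.environment(list(
      "setosa"=0.2,
      "versicolor"=0.2,
      "virginica"=0.2
    )),
    seed=1L
  ) %>% sparklyr::sdf_register()  # wrap the result back into a tbl_spark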
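The fractions above are hard-coded. Below is a minimal plain-R sketch that derives them from per-class counts, such as the `class_counts` table collected in the question, so that each class yields about `n` rows in expectation. `class_fractions` and its `oversample` argument are hypothetical names, not part of sparklyr. The values are capped at 1 because `sampleBy` only accepts fractions in [0, 1].

```r
# Hypothetical helper: per-class sampling fractions for sampleBy.
# counts: named numeric vector of class sizes
# n: target number of rows per class (in expectation, not exact)
# oversample: optional factor > 1 to make shortfalls less likely
class_fractions <- function(counts, n, oversample = 1) {
  # vector first so pmin keeps its names; cap at 1 since sampleBy rejects larger fractions
  fracs <- pmin(oversample * n / counts, 1)
  as.list(fracs)  # named list, ready for as.environment()
}

counts <- c(setosa = 50, versicolor = 50, virginica = 50)
fracs <- class_fractions(counts, n = 10)
str(fracs)
```

The result can then be passed as `fractions = as.environment(fracs)` in the `sampleBy` call. With the question's `class_counts`, the counts vector would be `setNames(class_counts$n, class_counts$Species)`.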

However, no distributed and scalable method will give you an exact sample size. It is possible to use hacks such as the following, which shuffles each class with a random key and keeps the first 10 rows:

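iris_tbl %>%
  group_by(Species) %>%
  mutate(rand = rand()) %>%              # Spark SQL rand(): uniform random key per row
  arrange(rand, .by_group = TRUE) %>%    # random order within each class
  filter(row_number() <= 10) %>%         # keep the first 10 rows of each class
  select(-rand)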

but such methods depend on window functions, which move all rows of a group into a single partition and sort them there. They are therefore highly sensitive to skewed data distributions and in general don't scale well.
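Since the question also asks about SQL, the same window-function hack can be expressed in Spark SQL and sent with `DBI::dbGetQuery()`. Below is a minimal sketch that only builds the query string. `stratified_sample_sql` is a hypothetical helper. It assumes the table is registered as `iris`, which is the default name `copy_to(sc, iris)` uses. It also assumes the table and column names are trusted input, because they are pasted into the SQL with only backtick quoting.

```r
# Hypothetical helper: builds a Spark SQL query that keeps at most n random rows
# per group. The inner query attaches a random key r; ROW_NUMBER() over a
# per-group window ordered by r then numbers the rows, and the outer WHERE keeps n.
stratified_sample_sql <- function(table, group_col, n, seed = NULL) {
  rand_call <- if (is.null(seed)) "rand()" else sprintf("rand(%d)", as.integer(seed))
  sprintf(
    paste(
      "SELECT * FROM (",
      "  SELECT *, ROW_NUMBER() OVER (PARTITION BY `%s` ORDER BY r) AS rn",
      "  FROM (SELECT *, %s AS r FROM `%s`) AS randomized",
      ") AS ranked",
      "WHERE rn <= %d",
      sep = "\n"
    ),
    group_col, rand_call, table, as.integer(n)
  )
}

query <- stratified_sample_sql("iris", "Species", 10)
cat(query)
```

With the connection from the question, `DBI::dbGetQuery(sc, query)` would return the sample as a local data frame. The helper columns `r` and `rn` are included in the result. The scaling caveat above still applies, because each window partition holds a whole class.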

If the samples are small, you can push this a bit further by oversampling first (using the first method) and then taking exact samples (using the second method). However, if your data is large enough to be processed with Spark, small fluctuations shouldn't really matter.
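The oversample-then-cut idea can be sketched by chaining the two snippets from this answer. This is a minimal sketch that assumes a local Spark installation. The `oversample` factor of 1.5 and the names `n_per_class`, `fracs` and `sampled` are illustrative choices, not part of the answer. A class can still end up with fewer than `n_per_class` rows if step 1 happens to draw too few. A larger factor makes that less likely, at the cost of more rows reaching the window function in step 2.

```r
library(sparklyr)
library(dplyr)

sc <- spark_connect(master = "local")
iris_tbl <- copy_to(sc, iris, overwrite = TRUE)

n_per_class <- 10
oversample  <- 1.5  # assumption: extra draw factor; larger = fewer shortfalls

# Step 1: approximate stratified sample with sampleBy, using oversampled fractions
class_counts <- iris_tbl %>% count(Species) %>% collect()
fracs <- as.list(pmin(oversample * n_per_class / class_counts$n, 1))  # sampleBy needs <= 1
names(fracs) <- class_counts$Species

oversampled <- spark_dataframe(iris_tbl) %>%
  invoke("stat") %>%
  invoke("sampleBy", "Species",
         fractions = as.environment(fracs), seed = 1L) %>%
  sdf_register()

# Step 2: cut to at most n_per_class rows per class with the window-function hack;
# the window now only sees the (smaller) oversampled rows
sampled <- oversampled %>%
  group_by(Species) %>%
  mutate(rand = rand()) %>%
  arrange(rand, .by_group = TRUE) %>%
  filter(row_number() <= n_per_class) %>%
  select(-rand) %>%
  ungroup()

sampled %>% count(Species) %>% collect()
```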

10465355
  • Thanks, that's just what I was looking for! I agree we don't really need exact samples with large amounts of data, it was more a bonus if we could, but I do like the idea of oversampling by a factor and then cutting down as you suggest. – Chris Jan 13 '20 at 23:53