
I have a dataset with a few columns and a grouping variable, and I want to reduce it to at most max_n rows per grouping level. At the same time I want to preserve the distribution of the other columns. What I mean by that is that the lowest and highest values of a and b should still be present after the data has been filtered, which is why I use setorderv below.

library(data.table)

set.seed(22)
n=20
max_n = 6
dt <- data.table("grp"=sample(c("a", "b", "c"), n, replace=T),
                 "a"=sample(1:10, n, replace=T),
                 "b"=sample(1:20, n, replace=T),
                 "id"=1:n)
setorderv(dt, c("grp", "a", "b"))
dt

My temporary solution, which is not very elegant or data.table-ish, goes like this:

dt_new <- data.table()
for (gr in unique(dt[["grp"]])) {
  tmp <- dt[grp == gr, ]            # subset one group
  n_tmp <- nrow(tmp)
  if (n_tmp > max_n) {
    # keep max_n evenly spaced rows of the sorted group,
    # so the first (lowest) and last (highest) rows survive
    tmp <- tmp[as.integer(seq(1, n_tmp, length.out = max_n)), ]
  }
  dt_new <- rbindlist(list(dt_new, tmp))
}
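
For reference, a small illustration (toy numbers, not part of the actual data) of what the index expression above does for one group: it picks max_n evenly spaced positions from the sorted rows, always including the first and the last.

n_tmp <- 9                                      # rows in one (sorted) group
as.integer(seq(1, n_tmp, length.out = max_n))   # max_n = 6 from the setup above
# [1] 1 2 4 5 7 9    <- first and last row are always kept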

Is there a more elegant way of doing this? The code right now feels too bulky. EDIT: I want a data.table solution.

EnFiFa
  • What do you mean by keeping the distribution of other columns? Can you elaborate a bit? – Shibaprasadb Jul 12 '23 at 10:35
  • Yes, you see that I'm sorting the data on the variables! So when I choose the first and the last index of tmp, I'm basically choosing the lowest and highest value of combination c("a", "b"). Makes sense? – EnFiFa Jul 12 '23 at 10:40
  • Okay. So basically you want to reduce the dataset but you want the range to stay the same? How do you select the "highest" and "lowest" across two groups though? a=5,b=7 and a=6,b=5. What will be the highest in this case? @EnFiFa – Shibaprasadb Jul 12 '23 at 10:48
  • Hi EnFiFa! Welcome to StackOverflow! – Mark Jul 12 '23 at 10:54
  • I was wondering if you could clarify for me, when you say you want to get the lowest and highest values of a and b, you have (quite unfortunately) given the same name to the group column variables and the columns, so it's hard to know which one you are referring to – Mark Jul 12 '23 at 10:55
  • @Shibaprasadb Well, right now, since I'm sorting on c("a", "b"), (5, 7) will be lower than (6, 5). – EnFiFa Jul 12 '23 at 10:56
  • When you say you want to keep the distribution of the other columns, I assume you refer to this part in your code `as.integer(seq(1, n_tmp, length.out=max_n))` correct? You don't want to take the first 6 rows as in `seq_len(max_n)`, right? – TimTeaFan Jul 12 '23 at 10:57
  • I don't know if I follow EnFiFa. Maybe you could rename the groups or the columns so it's easier to distinguish – Mark Jul 12 '23 at 10:58
  • @TimTeaFan exactly! – EnFiFa Jul 12 '23 at 10:58
  • @Mark exactly what is it you mean "given the same name to the group column variables and the columns"? – EnFiFa Jul 12 '23 at 10:59
  • what I mean is, when you say a and b, it could refer to the groups, or the columns. – Mark Jul 12 '23 at 11:00

2 Answers


To keep the rows holding the min and max of a and b per group, plus randomly chosen rows up to a total of max_n rows per group, with data.table:

dt[, minmax := a %in% range(a) | b %in% range(b), by = grp]
set.seed(42)
dt[, .SD[minmax | 1:.N %in% head(sample(which(!minmax)), max_n - sum(minmax)),], grp]
#        grp     a     b    id minmax
#     <char> <int> <int> <int> <lgcl>
#  1:      a     1    11    14   TRUE
#  2:      a     2     9    13  FALSE
#  3:      a     2    19    17   TRUE
#  4:      a     5     7     6   TRUE
#  5:      a     8    12    19  FALSE
#  6:      a     9    11     7   TRUE
#  7:      b     1    20     1   TRUE
#  8:      b     2     1    16   TRUE
#  9:      b     3    19     3  FALSE
# 10:      b     4     3    11  FALSE
# 11:      b     7    10    18  FALSE
# 12:      b     9    17    10   TRUE
# 13:      c     1    16    12   TRUE
# 14:      c     3    14    20  FALSE
# 15:      c     5    18     9  FALSE
# 16:      c     6    20     5   TRUE
# 17:      c     7    13     8   TRUE
dt[, minmax := NULL] # cleanup

Walk-through:

  • minmax is TRUE wherever a or b is the min/max within its group (min/max per variable, per group)
  • which(!minmax) returns the row indices of the remaining rows (where neither a nor b is a min/max)
  • sample(.) randomizes that list of remaining row indices, and head(., max_n - sum(minmax)) keeps no more than the number of rows needed to end up with max_n rows
  • minmax | 1:.N %in% .. selects the final rows; in the special case where a group has fewer non-min/max rows than needed, this simply returns all of the group's rows (see the sketch below this list)
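
A minimal sketch of the same selection logic on a toy logical vector (not the real data), in case the combined condition is hard to read:

# max_n is 6, from the question's setup
minmax_toy <- c(TRUE, FALSE, FALSE, TRUE, FALSE, FALSE, FALSE, TRUE)  # 3 "extreme" rows
keep <- minmax_toy | seq_along(minmax_toy) %in%
  head(sample(which(!minmax_toy)), max_n - sum(minmax_toy))
sum(keep)
# [1] 6    (the 3 min/max rows plus 3 randomly sampled other rows)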

Data

dt <- data.table::as.data.table(structure(list(grp = c("a", "a", "a", "a", "a", "a", "a", "b", "b", "b", "b", "b", "b", "b", "b", "c", "c", "c", "c", "c"), a = c(1L, 2L, 2L, 5L, 8L, 8L, 9L, 1L, 2L, 3L, 4L, 6L, 7L, 8L, 9L, 1L, 3L, 5L, 6L, 7L), b = c(11L, 9L, 19L, 7L, 11L, 12L, 11L, 20L, 1L, 19L, 3L, 3L, 10L, 10L, 17L, 16L, 14L, 18L, 20L, 13L), id = c(14L, 13L, 17L, 6L, 2L, 19L, 7L, 1L, 16L, 3L, 11L, 15L, 18L, 4L, 10L, 12L, 20L, 9L, 5L, 8L)), row.names = c(NA, -20L), class = c("data.table", "data.frame")))
r2evans

You can do something like this.

First, let's identify the rows to keep:

library(tidyverse)


dt %>%
  group_by(grp) %>%
  arrange(grp, a, b) %>%
  mutate(
    new_id = 1:n(),
    # label the first (lowest a, b) and last (highest a, b) row of each group
    id_desc = case_when(
      new_id == min(new_id) ~ 'Head',
      new_id == max(new_id) ~ 'Tail',
      .default = 'Other'
    )
  ) %>%
  ungroup() -> dt_modified

Then sample rows from the 'Other' group. You can use slice_head() if you need the top observations instead of random sampling:

dt_modified %>%
  filter(id_desc == 'Other') %>%
  group_by(grp) %>%
  slice_sample(n = max_n - 2) %>% #Because head and tail will be there in the dataframe
  ungroup() %>%
  bind_rows(dt_modified %>%
              filter(id_desc != 'Other')) %>%
  arrange(grp, a, b) %>%
  select(-new_id, -id_desc)-> dt_new

The benefit of this method is that it avoids growing a data.table row by row inside a loop, so it should be considerably faster than the loop on large datasets.
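
If you want to verify that claim on your own (larger) data, a rough sketch along these lines should work. It assumes the microbenchmark package is installed and that data.table and the tidyverse are already attached as above; the two helper functions just wrap the loop from the question and the pipeline from this answer so they can be timed:

library(microbenchmark)

loop_version <- function(dt, max_n) {
  # the question's loop, unchanged
  dt_new <- data.table()
  for (gr in unique(dt[["grp"]])) {
    tmp <- dt[grp == gr, ]
    n_tmp <- nrow(tmp)
    if (n_tmp > max_n) {
      tmp <- tmp[as.integer(seq(1, n_tmp, length.out = max_n)), ]
    }
    dt_new <- rbindlist(list(dt_new, tmp))
  }
  dt_new
}

dplyr_version <- function(dt, max_n) {
  # the two steps shown above, combined
  dt_modified <- dt %>%
    group_by(grp) %>%
    arrange(grp, a, b) %>%
    mutate(new_id = 1:n(),
           id_desc = case_when(new_id == min(new_id) ~ "Head",
                               new_id == max(new_id) ~ "Tail",
                               .default = "Other")) %>%
    ungroup()
  dt_modified %>%
    filter(id_desc == "Other") %>%
    group_by(grp) %>%
    slice_sample(n = max_n - 2) %>%
    ungroup() %>%
    bind_rows(dt_modified %>% filter(id_desc != "Other")) %>%
    arrange(grp, a, b) %>%
    select(-new_id, -id_desc)
}

microbenchmark(loop  = loop_version(dt, max_n),
               dplyr = dplyr_version(dt, max_n),
               times = 20)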

Shibaprasadb