33

How do you group first and then apply filtering using Java streams?

Example: Consider this Employee class. I want to group employees by department, keeping in each group only the employees with a salary greater than 2000.

public class Employee {
    private String department;
    private Integer salary;
    private String name;

    //getter and setter

    public Employee(String department, Integer salary, String name) {
        this.department = department;
        this.salary = salary;
        this.name = name;
    }
}   

This is how I currently do it:

List<Employee> list   = new ArrayList<>();
list.add(new Employee("A", 5000, "A1"));
list.add(new Employee("B", 1000, "B1"));
list.add(new Employee("C", 6000, "C1"));
list.add(new Employee("C", 7000, "C2"));

Map<String, List<Employee>> collect = list.stream()
    .filter(e -> e.getSalary() > 2000)
    .collect(Collectors.groupingBy(Employee::getDepartment));  

Output

{A=[Employee [department=A, salary=5000, name=A1]],
 C=[Employee [department=C, salary=6000, name=C1], Employee [department=C, salary=7000, name=C2]]}

Since there are no employees in department B with a salary greater than 2000, there is no key for department B. However, I want that key to be present with an empty list:

Expected output

{A=[Employee [department=A, salary=5000, name=A1]],
 B=[],
 C=[Employee [department=C, salary=6000, name=C1], Employee [department=C, salary=7000, name=C2]]}

How can we do this?

Naman
Niraj Sonawane
  • 4
    Version tags should be used for questions specific to that version. If this is about streams across multiple versions, it shouldn't have either tag IMO. – shmosel Jan 16 '18 at 02:20

5 Answers

36

You can use the Collectors.filtering API, introduced in Java 9, for this:

Map<String, List<Employee>> output = list.stream()
            .collect(Collectors.groupingBy(Employee::getDepartment,
                    Collectors.filtering(e -> e.getSalary() > 2000, Collectors.toList())));

Important points from the API note:

  • The filtering() collectors are most useful when used in a multi-level reduction, such as downstream of a groupingBy or partitioningBy.

  • A filtering collector differs from a stream's filter() operation.
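To make that difference concrete, here is a minimal sketch (Java 9 or later) that runs both pipelines on the question's sample data. The class name FilterVersusFiltering, its helper methods, and the trimmed-down Employee with a name-only toString are assumptions made for illustration; TreeMap is used only to get a stable print order.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

public class FilterVersusFiltering {

    // Trimmed-down stand-in for the question's Employee (illustrative only).
    static final class Employee {
        private final String department;
        private final int salary;
        private final String name;

        Employee(String department, int salary, String name) {
            this.department = department;
            this.salary = salary;
            this.name = name;
        }

        String getDepartment() { return department; }
        int getSalary() { return salary; }
        @Override public String toString() { return name; }
    }

    static List<Employee> sample() {
        return Arrays.asList(
                new Employee("A", 5000, "A1"),
                new Employee("B", 1000, "B1"),
                new Employee("C", 6000, "C1"),
                new Employee("C", 7000, "C2"));
    }

    // filter() drops elements before grouping, so department B never becomes a key.
    static Map<String, List<Employee>> filterThenGroup(List<Employee> employees) {
        return employees.stream()
                .filter(e -> e.getSalary() > 2000)
                .collect(Collectors.groupingBy(Employee::getDepartment,
                        TreeMap::new, Collectors.toList()));
    }

    // filtering() runs inside each group, so every department keeps its key.
    static Map<String, List<Employee>> groupThenFilter(List<Employee> employees) {
        return employees.stream()
                .collect(Collectors.groupingBy(Employee::getDepartment, TreeMap::new,
                        Collectors.filtering(e -> e.getSalary() > 2000, Collectors.toList())));
    }

    public static void main(String[] args) {
        System.out.println(filterThenGroup(sample())); // {A=[A1], C=[C1, C2]}
        System.out.println(groupThenFilter(sample())); // {A=[A1], B=[], C=[C1, C2]}
    }
}
```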

Naman
  • 1
    Interesting, I wouldn't have thought this behaves differently from filtering the stream directly. – shmosel Jan 16 '18 at 02:12
  • 5
    @shmosel it does the same when you pass `filtering(…)` directly to the `collect` method, e.g. `filtering(…, groupingBy(…))`. But when you pass it to `groupingBy` as downstream collector, i.e. `groupingBy(…, filtering(…))`, it will receive the elements after the group has been created. It’s as simple as that and similar to how `mapping` or `flatMapping` work. – Holger Jan 16 '18 at 08:07
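The nesting order described in the comment above can be checked with a small sketch (Java 9 or later). CollectorOrderDemo, its method names, and the minimal Employee are hypothetical names used only for this illustration.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

public class CollectorOrderDemo {

    // Minimal stand-in for the question's Employee (illustrative only).
    static final class Employee {
        private final String department;
        private final int salary;
        private final String name;

        Employee(String department, int salary, String name) {
            this.department = department;
            this.salary = salary;
            this.name = name;
        }

        String getDepartment() { return department; }
        int getSalary() { return salary; }
        @Override public String toString() { return name; }
    }

    static List<Employee> sample() {
        return Arrays.asList(
                new Employee("A", 5000, "A1"),
                new Employee("B", 1000, "B1"),
                new Employee("C", 6000, "C1"),
                new Employee("C", 7000, "C2"));
    }

    // filtering(..., groupingBy(...)): elements are rejected before they reach
    // groupingBy, which is equivalent to calling filter() on the stream.
    static Map<String, List<Employee>> filteringAroundGrouping(List<Employee> employees) {
        return employees.stream().collect(
                Collectors.filtering(e -> e.getSalary() > 2000,
                        Collectors.groupingBy(Employee::getDepartment,
                                TreeMap::new, Collectors.toList())));
    }

    // groupingBy(..., filtering(...)): the group (and its key) is created first,
    // and only then is each element tested.
    static Map<String, List<Employee>> filteringInsideGrouping(List<Employee> employees) {
        return employees.stream().collect(
                Collectors.groupingBy(Employee::getDepartment, TreeMap::new,
                        Collectors.filtering(e -> e.getSalary() > 2000, Collectors.toList())));
    }

    public static void main(String[] args) {
        System.out.println(filteringAroundGrouping(sample())); // {A=[A1], C=[C1, C2]}
        System.out.println(filteringInsideGrouping(sample())); // {A=[A1], B=[], C=[C1, C2]}
    }
}
```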
22

nullpointer’s answer shows the straightforward way to go. If you can’t update to Java 9, no problem: this filtering collector is no magic. Here is a Java 8 compatible version:

public static <T, A, R> Collector<T, ?, R> filtering(
    Predicate<? super T> predicate, Collector<? super T, A, R> downstream) {

    BiConsumer<A, ? super T> accumulator = downstream.accumulator();
    return Collector.of(downstream.supplier(),
        (r, t) -> { if(predicate.test(t)) accumulator.accept(r, t); },
        downstream.combiner(), downstream.finisher(),
        downstream.characteristics().toArray(new Collector.Characteristics[0]));
}

You can add it to your codebase and use it the same way as Java 9’s counterpart, so you don’t have to change the code in any way if you’re using import static.
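For completeness, here is a sketch of how the helper might be dropped into a class and called on Java 8. The class name Java8FilteringDemo, the highEarnersByDepartment method, and the minimal Employee are assumptions for illustration; the filtering method itself is copied from the snippet above.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.BiConsumer;
import java.util.function.Predicate;
import java.util.stream.Collector;
import java.util.stream.Collectors;

public class Java8FilteringDemo {

    // Minimal stand-in for the question's Employee (illustrative only).
    static final class Employee {
        private final String department;
        private final int salary;
        private final String name;

        Employee(String department, int salary, String name) {
            this.department = department;
            this.salary = salary;
            this.name = name;
        }

        String getDepartment() { return department; }
        int getSalary() { return salary; }
        @Override public String toString() { return name; }
    }

    // Java 8 compatible filtering collector, as given in the answer above.
    public static <T, A, R> Collector<T, ?, R> filtering(
            Predicate<? super T> predicate, Collector<? super T, A, R> downstream) {
        BiConsumer<A, ? super T> accumulator = downstream.accumulator();
        return Collector.of(downstream.supplier(),
                (r, t) -> { if (predicate.test(t)) accumulator.accept(r, t); },
                downstream.combiner(), downstream.finisher(),
                downstream.characteristics().toArray(new Collector.Characteristics[0]));
    }

    // Groups first, then filters inside each group, so empty groups keep their key.
    static Map<String, List<Employee>> highEarnersByDepartment(List<Employee> employees) {
        return employees.stream().collect(
                Collectors.groupingBy(Employee::getDepartment, TreeMap::new,
                        filtering(e -> e.getSalary() > 2000, Collectors.toList())));
    }

    public static void main(String[] args) {
        List<Employee> employees = Arrays.asList(
                new Employee("A", 5000, "A1"),
                new Employee("B", 1000, "B1"),
                new Employee("C", 6000, "C1"),
                new Employee("C", 7000, "C2"));
        System.out.println(highEarnersByDepartment(employees)); // {A=[A1], B=[], C=[C1, C2]}
    }
}
```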

Naman
Holger
6

Use Map#putIfAbsent(K,V) to fill in the gaps after filtering:

Map<String, List<Employee>> map = list.stream()
        .filter(e -> e.getSalary() > 2000)
        .collect(Collectors.groupingBy(Employee::getDepartment, HashMap::new, Collectors.toList()));
list.forEach(e -> map.putIfAbsent(e.getDepartment(), Collections.emptyList()));

Note: Since the map returned by groupingBy is not guaranteed to be mutable, you need to specify a Map Supplier to be sure (thanks to shmosel for pointing that out).
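One thing to keep in mind with the putIfAbsent snippet above: Collections.emptyList() is immutable, so the gap-filling lists cannot be added to later. If that matters, a possible variant (a sketch, not part of the original answer) is to use computeIfAbsent with a fresh ArrayList and to collect groups with toCollection(ArrayList::new). The class and method names here are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

public class FillGapsDemo {

    // Minimal stand-in for the question's Employee (illustrative only).
    static final class Employee {
        private final String department;
        private final int salary;
        private final String name;

        Employee(String department, int salary, String name) {
            this.department = department;
            this.salary = salary;
            this.name = name;
        }

        String getDepartment() { return department; }
        int getSalary() { return salary; }
        @Override public String toString() { return name; }
    }

    static Map<String, List<Employee>> highEarnersByDepartment(List<Employee> employees) {
        // Filter first; departments without a match are missing at this point.
        Map<String, List<Employee>> byDepartment = employees.stream()
                .filter(e -> e.getSalary() > 2000)
                .collect(Collectors.groupingBy(Employee::getDepartment, TreeMap::new,
                        Collectors.toCollection(ArrayList::new)));
        // Second pass: add a fresh, mutable empty list for every missing department.
        for (Employee e : employees) {
            byDepartment.computeIfAbsent(e.getDepartment(), k -> new ArrayList<>());
        }
        return byDepartment;
    }

    public static void main(String[] args) {
        List<Employee> employees = Arrays.asList(
                new Employee("A", 5000, "A1"),
                new Employee("B", 1000, "B1"),
                new Employee("C", 6000, "C1"),
                new Employee("C", 7000, "C2"));
        System.out.println(highEarnersByDepartment(employees)); // {A=[A1], B=[], C=[C1, C2]}
    }
}
```

Like the original, this walks the input twice, so it only works when the source is a collection rather than a one-shot stream.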


Another (not recommended) solution is to use toMap instead of groupingBy, which has the downside of creating a temporary list for every Employee. It also looks a bit messy:

// Uses the getters, since the fields of Employee are private
Predicate<Employee> filter = e -> e.getSalary() > 2000;
Map<String, List<Employee>> collect = list.stream().collect(
        Collectors.toMap(
            Employee::getDepartment,
            e -> new ArrayList<Employee>(filter.test(e) ? Collections.singleton(e) : Collections.<Employee>emptyList()),
            (l1, l2) -> { l1.addAll(l2); return l1; }
        )
);
Michael A. Schaffrath
  • 1
    You should change the function to `e-> new ArrayList<>(filter.test(e)? Collections.singleton(e): Collections.emptyList())`, to ensure that it will always return (mutable) `ArrayList`. Otherwise, you risk calling `addAll` on the (immutable) result of `emptyList()` in the merge function. Or you let the function always create an immutable list, `e-> filter.test(e)? Collections.singleton(e): Collections.emptyList()` and change the merge function to create a new list. – Holger Jan 16 '18 at 08:47
  • 1
    Of course, I meant, `e-> filter.test(e)? Collections.singletonList(e): Collections.emptyList()` in my last comment. – Holger Jan 16 '18 at 09:00
  • Right again. Changed it. Thanks :-) – Michael A. Schaffrath Jan 16 '18 at 09:16
  • 3
    The map returned by `groupingBy()` is not guaranteed to be mutable. You'll need to use the `Supplier` overload. – shmosel Jan 16 '18 at 09:48
3

There is no cleaner built-in way of doing this in Java 8, but Holger has shown a clear Java 8 approach in his answer, which I have accepted.

This is how I have done it in Java 8:

Step 1: Group by department.

Step 2: Loop through each group and check whether it has any employee with a salary > 2000.

Step 3: Copy the values into a new map, using an empty list for the groups where noneMatch returns true.

Map<String, List<Employee>> employeeMap = list.stream()
        .collect(Collectors.groupingBy(Employee::getDepartment));
Map<String, List<Employee>> newMap = new HashMap<>();
employeeMap.forEach((k, v) -> {
    // Note: as the comments below point out, a group with at least one salary
    // above 2000 is copied whole, so lower-paid employees in it are not removed.
    if (v.stream().noneMatch(emp -> emp.getSalary() > 2000)) {
        newMap.put(k, new ArrayList<>());
    } else {
        newMap.put(k, v);
    }
});

Java 9: Collectors.filtering

Java 9 added a new collector, Collectors.filtering, which makes it possible to group first and then apply filtering. The filtering collector is designed to be used together with grouping.

Collectors.filtering takes a predicate for filtering the input elements and a collector to collect the filtered elements:

list.stream().collect(Collectors.groupingBy(Employee::getDepartment,
        Collectors.filtering(e -> e.getSalary() > 2000, Collectors.toList())));
Niraj Sonawane
  • 3
    You can't add to a map while iterating over it. You'll get a ConcurrentModificationException. – shmosel Jan 16 '18 at 02:19
  • 2
    Here's a simple demonstration in case you don't believe me: https://ideone.com/SRtWiJ – shmosel Jan 16 '18 at 02:23
  • Your logic is faulty. Why are you only clearing a list if *all* salaries are <= 2000? You should be evaluating each element individually. – shmosel Jan 16 '18 at 02:40
  • 1
    Why don't you use instead `toCollection(ArrayList::new)` and then `employeeMap.values().forEach(list -> list.removeIf(e -> e.getSalary() <= 2000));`? – shmosel Jan 16 '18 at 02:43
  • I am creating a list to hold the final output and adding an empty list if no employee has a salary greater than 2000; otherwise, I copy the same list. I am not evaluating each element individually. I think noneMatch is better than list.removeIf. – Niraj Sonawane Jan 16 '18 at 02:53
  • 2
    I don't know what you're saying. Your current solution doesn't filter out individual employees. – shmosel Jan 16 '18 at 08:10
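The removeIf alternative suggested in the comments above might look like the following on Java 8. This is a sketch under assumptions: RemoveIfDemo, the method name, and the minimal Employee are hypothetical, and toCollection(ArrayList::new) is used so that the grouped lists are known to be mutable.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

public class RemoveIfDemo {

    // Minimal stand-in for the question's Employee (illustrative only).
    static final class Employee {
        private final String department;
        private final int salary;
        private final String name;

        Employee(String department, int salary, String name) {
            this.department = department;
            this.salary = salary;
            this.name = name;
        }

        String getDepartment() { return department; }
        int getSalary() { return salary; }
        @Override public String toString() { return name; }
    }

    static Map<String, List<Employee>> highEarnersByDepartment(List<Employee> employees) {
        // Step 1: group everything, so every department gets a key.
        Map<String, List<Employee>> byDepartment = employees.stream()
                .collect(Collectors.groupingBy(Employee::getDepartment, TreeMap::new,
                        Collectors.toCollection(ArrayList::new)));
        // Step 2: remove low earners from each group individually.
        byDepartment.values().forEach(group -> group.removeIf(e -> e.getSalary() <= 2000));
        return byDepartment;
    }

    public static void main(String[] args) {
        List<Employee> employees = Arrays.asList(
                new Employee("A", 5000, "A1"),
                new Employee("A", 1500, "A2"),
                new Employee("B", 1000, "B1"),
                new Employee("C", 6000, "C1"));
        System.out.println(highEarnersByDepartment(employees)); // {A=[A1], B=[], C=[C1]}
    }
}
```

The extra employee A2 is there to show that a mixed group is filtered element by element, which is the case the noneMatch version does not handle.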
2

Java 8 version: you can group by department, then stream the entry set and collect again, applying the predicate as a filter on each group's list:

    Map<String, List<Employee>> collect = list.stream()
        .collect(Collectors.groupingBy(Employee::getDepartment)).entrySet()
        .stream()
        .collect(Collectors.toMap(Map.Entry::getKey,
            entry -> entry.getValue()
                .stream()
                .filter(employee -> employee.getSalary() > 2000)
                .collect(Collectors.toList())
            )
        );