23

In an ETL process I'm retrieving a lot of entities from a Spring Data Repository. I'm then using a parallel stream to map the entities to different ones. I can either use a consumer to store those new entities in another repository one by one, or collect them into a List and store that in a single bulk operation. The former is costly, while the latter might exceed the available memory.

Is there a good way to collect a certain amount of elements in the stream (like limit does), consume that chunk, and keep on going in parallel until all elements are processed?

Stuart Marks
  • 127,867
  • 37
  • 205
  • 259
Christoph Grimmer
  • 4,210
  • 4
  • 40
  • 64
  • Possible duplicate of [Java 8 Stream with batch processing](https://stackoverflow.com/questions/30641383/java-8-stream-with-batch-processing) – ryenus Nov 09 '17 at 10:05

5 Answers5

21

My approach to bulk operations with chunking is to use a partitioning spliterator wrapper, and another wrapper which overrides the default splitting policy (arithmetic progression of batch sizes in increments of 1024) to simple fixed-batch splitting. Use it like this:

// Usage sketch: wrap an existing stream so that it yields lists ("chunks") of up to
// 100 elements each; the third argument (1) is the batch size passed on to
// FixedBatchSpliterator.
Stream<OriginalType> existingStream = ...;
Stream<List<OriginalType>> partitioned = partition(existingStream, 100, 1);
// NOTE(review): partition(...) creates the stream with parallel=false, so
// presumably .parallel() must be called on it if parallel processing is wanted.
partitioned.forEach(chunk -> ... process the chunk ...);

Here is the full code:

import java.util.ArrayList;
import java.util.List;
import java.util.Spliterator;
import java.util.Spliterators.AbstractSpliterator;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * A {@link Spliterator} over lists ("partitions") of elements pulled from a
 * wrapped spliterator. Every partition holds {@code partitionSize} elements,
 * except possibly the last one, which holds whatever is left.
 *
 * <p>Splitting is inherited from {@link AbstractSpliterator}; use
 * {@link #partition(Stream, int, int)} to combine it with fixed-size batching.
 *
 * @param <E> type of the elements of the wrapped spliterator
 */
public class PartitioningSpliterator<E> extends AbstractSpliterator<List<E>>
{
  private final Spliterator<E> spliterator;
  private final int partitionSize;

  /**
   * Wraps the given spliterator.
   *
   * @param toWrap        source of the elements to group
   * @param partitionSize maximum number of elements per partition
   * @throws IllegalArgumentException if {@code partitionSize} is not positive
   */
  public PartitioningSpliterator(Spliterator<E> toWrap, int partitionSize) {
    super(toWrap.estimateSize(), partitionCharacteristics(toWrap));
    if (partitionSize <= 0) {
      throw new IllegalArgumentException(
          "Partition size must be positive, but was " + partitionSize);
    }
    this.spliterator = toWrap;
    this.partitionSize = partitionSize;
  }

  /**
   * Derives the characteristics of the partition spliterator from the wrapped one.
   * SORTED is removed: it describes the ordering of the individual elements, not
   * of the lists produced here, and this class does not override
   * {@code getComparator()}, so advertising SORTED would make
   * {@code StreamSupport.stream(...)} fail with an IllegalStateException.
   * NONNULL is added because a partition handed to the consumer is never null.
   */
  private static int partitionCharacteristics(Spliterator<?> toWrap) {
    return (toWrap.characteristics() & ~Spliterator.SORTED) | Spliterator.NONNULL;
  }

  /**
   * Returns a sequential stream of partitions of the given stream.
   *
   * @param in   stream to partition; closing the returned stream closes it too
   * @param size maximum number of elements per partition
   * @return stream of lists of at most {@code size} elements
   * @throws IllegalArgumentException if {@code size} is not positive
   */
  public static <E> Stream<List<E>> partition(Stream<E> in, int size) {
    return StreamSupport
        .stream(new PartitioningSpliterator<E>(in.spliterator(), size), false)
        .onClose(in::close);
  }

  /**
   * Returns a sequential stream of partitions of the given stream whose
   * spliterator splits off fixed-size batches of partitions.
   *
   * @param in        stream to partition; closing the returned stream closes it too
   * @param size      maximum number of elements per partition
   * @param batchSize number of partitions per split, passed to {@code FixedBatchSpliterator}
   * @return stream of lists of at most {@code size} elements
   * @throws IllegalArgumentException if {@code size} is not positive
   */
  public static <E> Stream<List<E>> partition(Stream<E> in, int size, int batchSize) {
    final Spliterator<List<E>> partitions = new PartitioningSpliterator<E>(in.spliterator(), size);
    return StreamSupport
        .stream(new FixedBatchSpliterator<List<E>>(partitions, batchSize), false)
        .onClose(in::close);
  }

  /**
   * Collects up to {@code partitionSize} elements from the wrapped spliterator and
   * passes them to {@code action} as one list.
   *
   * @return {@code false} if the wrapped spliterator had no elements left
   */
  @Override public boolean tryAdvance(Consumer<? super List<E>> action) {
    final ArrayList<E> partition = new ArrayList<>(partitionSize);
    while (partition.size() < partitionSize && spliterator.tryAdvance(partition::add)) {
      // keep pulling until the partition is full or the source is exhausted
    }
    if (partition.isEmpty()) {
      return false;
    }
    action.accept(partition);
    return true;
  }

  /**
   * Estimates the number of remaining partitions: the wrapped estimate divided by
   * the partition size, rounded up. An unknown estimate ({@code Long.MAX_VALUE})
   * is passed through unchanged.
   */
  @Override public long estimateSize() {
    final long est = spliterator.estimateSize();
    if (est == Long.MAX_VALUE) {
      return est;
    }
    return est / partitionSize + (est % partitionSize > 0 ? 1 : 0);
  }
}

import static java.util.Spliterators.spliterator;

import java.util.Comparator;
import java.util.Spliterator;
import java.util.function.Consumer;

/**
 * Base class for spliterators that split off batches of a fixed size instead of
 * the growing batch sizes used by the default splitting policy. Subclasses only
 * need to supply {@link #tryAdvance(Consumer)}.
 *
 * @param <T> element type
 */
public abstract class FixedBatchSpliteratorBase<T> implements Spliterator<T> {
  private final int batchSize;
  private final int characteristics;
  private long est;

  /**
   * @param characteristics characteristics to report; ORDERED is always added, and
   *                        SUBSIZED is added whenever SIZED is present
   * @param batchSize       number of elements handed to each split
   * @param est             initial size estimate, or {@code Long.MAX_VALUE} if unknown
   * @throws IllegalArgumentException if {@code batchSize} is not positive
   */
  public FixedBatchSpliteratorBase(int characteristics, int batchSize, long est) {
    // Fail fast: a non-positive batch size would otherwise only surface inside
    // trySplit(), after an element has already been pulled from the source.
    if (batchSize <= 0) {
      throw new IllegalArgumentException(
          "Batch size must be positive, but was " + batchSize);
    }
    characteristics |= ORDERED;
    if ((characteristics & SIZED) != 0) {
      characteristics |= SUBSIZED;
    }
    this.characteristics = characteristics;
    this.batchSize = batchSize;
    this.est = est;
  }

  /** Same as the three-argument constructor with an unknown size estimate. */
  public FixedBatchSpliteratorBase(int characteristics, int batchSize) {
    this(characteristics, batchSize, Long.MAX_VALUE);
  }

  /** Same as the three-argument constructor with batch size 64 and an unknown size estimate. */
  public FixedBatchSpliteratorBase(int characteristics) {
    this(characteristics, 64, Long.MAX_VALUE);
  }

  /**
   * Pulls up to {@code batchSize} elements via {@link #tryAdvance(Consumer)} and
   * returns them as an array-backed spliterator.
   *
   * @return spliterator over the batch, or {@code null} if no element was left
   */
  @Override public Spliterator<T> trySplit() {
    final HoldingConsumer<T> holder = new HoldingConsumer<>();
    if (!tryAdvance(holder)) {
      return null;
    }
    final Object[] batch = new Object[batchSize];
    int count = 0;
    do {
      batch[count] = holder.value;
    } while (++count < batchSize && tryAdvance(holder));
    // Keep the estimate in step with what was handed off; MAX_VALUE means "unknown".
    if (est != Long.MAX_VALUE) {
      est -= count;
    }
    return spliterator(batch, 0, count, characteristics());
  }

  /**
   * Returns {@code null} (natural ordering) when SORTED is reported.
   *
   * @throws IllegalStateException if this spliterator is not SORTED
   */
  @Override public Comparator<? super T> getComparator() {
    if (hasCharacteristics(SORTED)) {
      return null;
    }
    throw new IllegalStateException();
  }

  @Override public long estimateSize() { return est; }

  @Override public int characteristics() { return characteristics; }

  /** Consumer that remembers the most recently accepted value. */
  static final class HoldingConsumer<T> implements Consumer<T> {
    Object value;
    @Override public void accept(T value) { this.value = value; }
  }
}

import static java.util.stream.StreamSupport.stream;

import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.stream.Stream;

/**
 * Wraps an existing spliterator so that it splits off fixed-size batches (see
 * {@link FixedBatchSpliteratorBase}); traversal is delegated to the wrapped
 * spliterator.
 *
 * @param <T> element type
 */
public class FixedBatchSpliterator<T> extends FixedBatchSpliteratorBase<T> {
  private final Spliterator<T> spliterator;

  /**
   * @param toWrap    spliterator supplying the elements and the characteristics
   * @param batchSize number of elements per split
   * @param est       initial size estimate, or {@code Long.MAX_VALUE} if unknown
   */
  public FixedBatchSpliterator(Spliterator<T> toWrap, int batchSize, long est) {
    super(toWrap.characteristics(), batchSize, est);
    this.spliterator = toWrap;
  }

  /** Uses the wrapped spliterator's own size estimate. */
  public FixedBatchSpliterator(Spliterator<T> toWrap, int batchSize) {
    this(toWrap, batchSize, toWrap.estimateSize());
  }

  /** Uses batch size 64 and the wrapped spliterator's own size estimate. */
  public FixedBatchSpliterator(Spliterator<T> toWrap) {
    this(toWrap, 64, toWrap.estimateSize());
  }

  /**
   * Returns a parallel stream over the elements of {@code in} whose spliterator
   * splits off batches of {@code batchSize} elements.
   *
   * @param in        source stream; closing the returned stream closes it too
   * @param batchSize number of elements per split
   */
  public static <T> Stream<T> withBatchSize(Stream<T> in, int batchSize) {
    return stream(new FixedBatchSpliterator<T>(in.spliterator(), batchSize), true)
        .onClose(in::close);
  }

  /** Factory equivalent of {@link #FixedBatchSpliterator(Spliterator, int)}. */
  public static <T> FixedBatchSpliterator<T> batchedSpliterator(Spliterator<T> toWrap, int batchSize) {
    return new FixedBatchSpliterator<>(toWrap, batchSize);
  }

  @Override public boolean tryAdvance(Consumer<? super T> action) {
    return spliterator.tryAdvance(action);
  }

  @Override public void forEachRemaining(Consumer<? super T> action) {
    spliterator.forEachRemaining(action);
  }

  /**
   * Delegates to the wrapped spliterator. The characteristics reported by this
   * class are taken from the wrapped spliterator, so the comparator has to come
   * from the same place; the inherited implementation would report natural
   * ordering for every SORTED source, including ones sorted by a custom comparator.
   */
  @Override public java.util.Comparator<? super T> getComparator() {
    return spliterator.getComparator();
  }
}
Marko Topolnik
  • 195,646
  • 29
  • 319
  • 436
  • Shouldn't `getComparator()` be defined in `FixedBatchSpliterator` and delegate, i.e: `return spliterator.getComparator();` ? – David Soroko Jan 20 '16 at 15:51
  • Possibly. The `getComparator()` API remains somewhat of a mystery to me. – Marko Topolnik Jan 21 '16 at 12:36
  • what is the difference between `FixedBatchSpliterator` and `PartitioningSpliterator` ? I don't get the `partition(existingStream, 100, 1)` – Maelig Dec 13 '16 at 16:39
  • 2
    @Titmael `FixedBatchSpliterator` is about non-semantic batching of data delivered to each worker thread. The effects of this batching are not visible to user-supplied lambdas. `PartitioningSpliterator`, on the other hand, produces chunks of data explicitly passed to your lambda. – Marko Topolnik Dec 14 '16 at 10:43
  • In my case I need to split my stream in chunks of 500 items, I can use only `PartitioningSpliterator#partition(Stream in, int size)` ? – Maelig Dec 14 '16 at 11:13
4

You might be able to write your own Collector that accumulates entities and then performs bulk updates.

The Collector.accumulator() method can add the entities to an internal temp cache until the cache grows too large. When the cache is large enough you can do a bulk store into your other repository.

Collector.combiner() (the merge step) needs to combine the caches of two threads' collectors into a single cache (and possibly store it, if the merged cache has grown too large).

Finally, the Collector.finisher() method is called when the Stream is done so store anything left in the cache here too.

Since you are already using a parallel stream and seem OK with doing multiple loads at the same time, I assume you have thread safety already handled.

UPDATE

My comment regarding thread safety and parallel streams was referring to the actual saving/storing into the repository, not concurrency in your temp collection.

Each Collector should (I think) be run in its own thread. A parallel stream should create multiple collector instances by calling supplier() multiple times. So you can treat a collector instance as single threaded and it should work fine.

For example in the Javadoc for java.util.IntSummaryStatistics it says:

This implementation is not thread safe. However, it is safe to use Collectors.toIntStatistics() on a parallel stream, because the parallel implementation of Stream.collect() provides the necessary partitioning, isolation, and merging of results for safe and efficient parallel execution.
dkatzel
  • 31,188
  • 3
  • 63
  • 67
  • When filling a concurrency-safe collection by using either peek() or Collector.accumulator(), I cannot safely determine when my cache has reached, say, 1000 entries. I would have to lock the collection, take the count, retrieve all entries if it is filled to my desired level, and release the collection again. That will kill parallelism dead. I had hoped that there is a _good_ way to do it hidden somewhere in the stream API... – Christoph Grimmer Aug 21 '14 at 06:37
  • 1
    @ChristophGrimmer-Dietrich I'm not sure you need to worry about that. Each Collector *should* (I think) be run in its own thread. A parallel stream should create multiple collector instances by calling `supplier()` multiple times. I will update my answer – dkatzel Aug 21 '14 at 18:09
1

You could use a custom collector to do this elegantly.

Please see my answer to a similar question here:

Custom batch processing collector

Then, you can simply batch process the stream in parallel using the above collector to store the records back in your repository, example usage:

// Usage sketch for the batch collector from the linked answer.
// NOTE(review): StreamUtils.batchCollector and repository are not defined in this
// snippet; presumably the collector hands lists of up to batchSize elements to
// batchProcessor - verify against the linked answer.
List<Integer> input = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);

int batchSize = 3;
// Invoked with each batch; here it stores the batch through the repository.
Consumer<List<Integer>> batchProcessor = xs -> repository.save(xs);

input.parallelStream()
     .map(i -> i + 1)
     .collect(StreamUtils.batchCollector(batchSize, batchProcessor));
Community
  • 1
  • 1
rohitvats
  • 1,811
  • 13
  • 11
0
  /**
   * Demonstrates lock-based batching on a parallel stream: every element is added
   * to a shared buffer under a lock, and each time 1000 elements have been
   * collected the full buffer is swapped for a fresh one and passed to sendBatch.
   */
  @Test
public void streamTest(){

    // Stream.generate has no size limit: elements are produced for as long as the
    // stream keeps being consumed.
    Stream<Integer> data = Stream.generate(() -> {
        // Each element comes from a simulated blocking I/O call.
        return blockOnIO();
    });


    // Number of elements still missing from the current batch.
    AtomicInteger countDown = new AtomicInteger(1000);
    // One-element array used as a mutable holder so the lambda can replace the buffer.
    final ArrayList[] buffer = new ArrayList[]{new ArrayList<Integer>()};
    // Lock guarding countDown and buffer[0].
    Object syncO = new Object();
    data.parallel().unordered().map(i -> i * 1000).forEach(i->{
        // NOTE(review): buffer[0].size() is read here without holding syncO, so the
        // printed size is only a racy snapshot - confirm this is acceptable.
        System.out.println(String.format("FE %s %d",Thread.currentThread().getName(), buffer[0].size()));
        int c;
        ArrayList<Integer> export=null;
        synchronized (syncO) {
            c = countDown.addAndGet(-1);
            buffer[0].add(i);
            if (c == 0) {
                // Batch complete: take the full buffer, install an empty one, reset the counter.
                export=buffer[0];
                buffer[0] = new ArrayList<Integer>();
                countDown.set(1000);
            }
        }
        // The export happens after the synchronized block has been left.
        if(export !=null){
            sendBatch(export);
        }

    });
    // Export whatever is left in the buffer.
    // NOTE(review): the source stream is unbounded, so forEach presumably never
    // returns normally and this line is not reached; sendBatch also asserts a size
    // of exactly 1000, which a partial buffer would not satisfy.
    sendBatch(buffer[0]);
}

/**
 * Simulates a blocking I/O call: sleeps for 50 ms and then returns a
 * pseudo-random value in the range 0..999.
 *
 * @return pseudo-random integer in {@code [0, 1000)}
 * @throws RuntimeException wrapping the {@link InterruptedException} if the
 *         sleeping thread is interrupted
 */
Integer blockOnIO(){
    try {
        Thread.sleep(50);
        // The product has to be parenthesized before the cast: a cast binds tighter
        // than '*', so (int)Math.random()*1000 truncates to 0 first and is always 0.
        return Integer.valueOf((int) (Math.random() * 1000));
    } catch (InterruptedException e) {
        // Restore the interrupt flag so callers further up can still observe it.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    }
}

/**
 * Stand-in for the bulk export: checks (when assertions are enabled) that the
 * batch holds exactly 1000 elements and prints the current thread's name
 * together with the batch size.
 *
 * @param al batch to export
 */
void sendBatch(ArrayList al){
    final int batchLength = al.size();
    assert batchLength == 1000;
    final String workerName = Thread.currentThread().getName();
    final String line = String.format("LOAD %s %d", workerName, batchLength);
    System.out.println(line);
}

This is maybe somewhat old-fashioned, but it should achieve batching with a minimum of locking.

It will produce output as

FE ForkJoinPool.commonPool-worker-2 996
FE ForkJoinPool.commonPool-worker-5 996
FE ForkJoinPool.commonPool-worker-4 998
FE ForkJoinPool.commonPool-worker-3 999
LOAD ForkJoinPool.commonPool-worker-3 1000
FE ForkJoinPool.commonPool-worker-6 0
FE ForkJoinPool.commonPool-worker-1 2
FE ForkJoinPool.commonPool-worker-7 2
FE ForkJoinPool.commonPool-worker-2 4
David Lilljegren
  • 1,799
  • 16
  • 19
0

Here is a solution using my library, abacus-common:

// Presumably: cut the stream into batches of batchSize, then apply
// yourBatchProcessFunction to the batches using threadNum threads - verify against
// the abacus-common Stream API. NOTE(review): no terminal operation is shown here;
// one is presumably still required for the pipeline to run.
stream.split(batchSize).parallel(threadNum).map(yourBatchProcessFunction);
user_3380739
  • 1
  • 14
  • 14