
I want to implement adaptive mesh refinement in 3D.

The basic principle is the following:

I have a set of cells with unique cell IDs. I test each cell to see if it needs to be refined.

  • If refinement is required, I create 8 new child cells and add them to the list of cells to check for refinement.
  • Otherwise, this is a leaf node and I add it to my list of leaf nodes.
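Before adding any parallelism, the loop described in the two bullets above can be written sequentially with an explicit work list, which is also handy as a reference to check parallel results against. This is a minimal sketch: the child-ID scheme (`cellId * 10 + i`) is borrowed from Edit 1 further down, the class name is hypothetical, and the `LongPredicate` argument stands in for the real geometric refinement test.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.function.LongPredicate;

class SequentialRefiner {
    // Returns the IDs of all leaf cells reachable from the given roots.
    static List<Long> refine(List<Long> roots, LongPredicate needsRefinement) {
        Deque<Long> toCheck = new ArrayDeque<>(roots); // cells still waiting to be tested
        List<Long> leaves = new ArrayList<>();
        while (!toCheck.isEmpty()) {
            long cell = toCheck.pop();
            if (needsRefinement.test(cell)) {
                // Assumed ID scheme from Edit 1: children of n are n*10 + 0..7
                for (int i = 0; i < 8; i++) {
                    toCheck.push(cell * 10 + i);
                }
            } else {
                leaves.add(cell); // no refinement needed: this is a leaf
            }
        }
        return leaves;
    }
}
```

With roots 1 to 4 and `id -> id < 100`, every root and every first-level child is refined, so this returns 256 leaves with IDs between 100 and 477 (in no particular order).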

I want to implement it using the ForkJoin framework and Java 8 streams. I read this article, but I don't know how to apply it to my case.

For now, what I came up with is this:

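import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.function.BinaryOperator;

public class ForkJoinAttempt {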
    private final double[] cellIds;

    public ForkJoinAttempt(double[] cellIds) {
        this.cellIds = cellIds;
    }

    public void refineGrid() {
        ForkJoinPool pool = ForkJoinPool.commonPool();
        double[] result = pool.invoke(new RefineTask(100));
    }

    private class RefineTask extends RecursiveTask<double[]> {
        final double cellId;

        private RefineTask(double cellId) {
            this.cellId = cellId;
        }

        @Override
        protected double[] compute() {
            return ForkJoinTask.invokeAll(createSubtasks())
                    .stream()
                    .map(ForkJoinTask::join)
                    .reduce(new double[0], new Concat());
        }
    }

    private double[] refineCell(double cellId) {
        double[] result;
        if (checkCell()) {
            result = new double[8];

            for (int i = 0; i < 8; i++) {
                result[i] = Math.random();
            }

        } else {
            result = new double[1];
            result[0] = cellId;
        }

        return result;
    }

    private Collection<RefineTask> createSubtasks() {
        List<RefineTask> dividedTasks = new ArrayList<>();

        for (int i = 0; i < cellIds.length; i++) {
            dividedTasks.add(new RefineTask(cellIds[i]));
        }
        
        return dividedTasks;
    }

    private class Concat implements BinaryOperator<double[]>  {

        @Override
        public double[] apply(double[] a, double[] b) {
            int aLen = a.length;
            int bLen = b.length;

            @SuppressWarnings("unchecked")
            double[] c = (double[]) Array.newInstance(a.getClass().getComponentType(), aLen + bLen);
            System.arraycopy(a, 0, c, 0, aLen);
            System.arraycopy(b, 0, c, aLen, bLen);

            return c;
        }
    }

    public boolean checkCell() {
        return Math.random() < 0.5;
    }
}

... and I'm stuck here.

This doesn't do much for now, because I never call the refineCell function.

I might also have a performance issue with all the double[] arrays I create, and merging them this way may not be the most efficient approach either.

But first things first: can anyone help me implement fork/join in this case?

The expected result of the algorithm is an array of leaf cell IDs (double[]).
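On the merging concern: reducing with the pairwise `Concat` allocates a new array and copies everything accumulated so far at every step. A sketch of a one-pass alternative, assuming all partial `double[]` results are available as a list (the `ArrayMerge.concatAll` name is hypothetical): sum the lengths first, then allocate once and copy each part into place.

```java
import java.util.List;

class ArrayMerge {
    // Concatenates all parts with one allocation and one copy per element.
    static double[] concatAll(List<double[]> parts) {
        int total = 0;
        for (double[] part : parts) {
            total += part.length; // first pass: size of the final array
        }
        double[] result = new double[total];
        int pos = 0;
        for (double[] part : parts) {
            System.arraycopy(part, 0, result, pos, part.length); // second pass: copy in place
            pos += part.length;
        }
        return result;
    }
}
```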

Edit 1:

Thanks to the comments, I came up with something that works a little better.

Some changes:

  • I went from arrays to lists. This is not good for the memory footprint, because I can't use Java primitives, but it made the implementation simpler.
  • The cell IDs are now Long instead of Double.
  • IDs are no longer chosen randomly:
    • Root level cells have IDs 1, 2, 3 etc.;
    • Children of 1 have IDs 10, 11, 12, etc.;
    • Children of 2 have IDs 20, 21, 22, etc.;
    • You get the idea...
  • I refine all cells whose ID is lower than 100.

For the sake of this example, this makes the results easier to check.
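With the ID scheme above, parent/child relationships can be computed with decimal arithmetic alone. A small sketch of hypothetical helpers, assuming root IDs 1 to 9 (a root ID of 10 or more would collide with the children of cell 1) and that IDs fit in a `long`, i.e. at most 18 digits, since each level appends one digit.

```java
import java.util.stream.LongStream;

// Hypothetical helpers for the decimal ID scheme: children of n are n*10 + 0..7.
class CellIds {
    static long[] childrenOf(long id) {
        return LongStream.range(0, 8).map(i -> id * 10 + i).toArray();
    }

    static long parentOf(long id) {
        return id / 10; // dropping the last digit gives the parent
    }

    static int childIndex(long id) {
        return (int) (id % 10); // the last digit is the position among the 8 siblings
    }
}
```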

Here is the new implementation:

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.*;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class ForkJoinAttempt {
    private static final int THRESHOLD = 2;
    private List<Long> leafCellIds;

    public void refineGrid(List<Long> cellsToProcess) {
        leafCellIds = ForkJoinPool.commonPool().invoke(new RefineTask(cellsToProcess));
    }

    public List<Long> getLeafCellIds() {
        return leafCellIds;
    }

    private class RefineTask extends RecursiveTask<List<Long>> {

        private final CopyOnWriteArrayList<Long> cellsToProcess = new CopyOnWriteArrayList<>();

        private RefineTask(List<Long> cellsToProcess) {
            this.cellsToProcess.addAll(cellsToProcess);
        }

        @Override
        protected List<Long> compute() {
            if (cellsToProcess.size() > THRESHOLD) {
                System.out.println("Fork/Join");
                return ForkJoinTask.invokeAll(createSubTasks())
                        .stream()
                        .map(ForkJoinTask::join)
                        .reduce(new ArrayList<>(), new Concat());
            } else {
                System.out.println("Direct computation");
                
                List<Long> leafCells = new ArrayList<>();

                for (Long cell : cellsToProcess) {
                    Long result = refineCell(cell);
                    if (result != null) {
                        leafCells.add(result);
                    }
                }

                return leafCells;
            }
        }

        private Collection<RefineTask> createSubTasks() {
            List<RefineTask> dividedTasks = new ArrayList<>();

            for (List<Long> list : split(cellsToProcess)) {
                dividedTasks.add(new RefineTask(list));
            }

            return dividedTasks;
        }

        private Long refineCell(Long cellId) {
            if (checkCell(cellId)) {
                for (int i = 0; i < 8; i++) {
                    Long newCell = cellId * 10 + i;
                    cellsToProcess.add(newCell);
                    System.out.println("Adding child " + newCell + " to cell " + cellId);
                }
                return null;
            } else {
                System.out.println("Leaf node " + cellId);
                return cellId;
            }
        }

        private List<List<Long>> split(List<Long> list)
        {
            int[] index = {0, (list.size() + 1)/2, list.size()};

            List<List<Long>> lists = IntStream.rangeClosed(0, 1)
                    .mapToObj(i -> list.subList(index[i], index[i + 1]))
                    .collect(Collectors.toList());

            return lists;
        }


    }



    private class Concat implements BinaryOperator<List<Long>> {
        @Override
        public List<Long> apply(List<Long> listOne, List<Long> listTwo) {
            return Stream.concat(listOne.stream(), listTwo.stream())
                    .collect(Collectors.toList());
        }
    }

    public boolean checkCell(Long cellId) {
        return cellId < 100;
    }
}

And the method testing it:

    int initialSize = 4;
    List<Long> cellIds = new ArrayList<>(initialSize);
    for (int i = 0; i < initialSize; i++) {
        cellIds.add(Long.valueOf(i + 1));
    }

    ForkJoinAttempt test = new ForkJoinAttempt();
    test.refineGrid(cellIds);
    List<Long> leafCellIds = test.getLeafCellIds();
    System.out.println("Leaf nodes: " + leafCellIds.size());
    for (Long node : leafCellIds) {
        System.out.println(node);
    }

The output confirms that it adds 8 children to each root cell. But it does not go further.

I know why, but I don't know how to solve it: even though the refineCell method adds the new cells to the list of cells to process, the createSubTasks method is not called again, so the task never finds out that new cells have been added.
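One way around this, roughly what the comments below suggest, is to stop appending children to the list being iterated and instead give the children of each refined cell their own subtask whose result is joined. A minimal sketch under Edit 1's conventions (children `cellId * 10 + i`, refine while the ID is below 100); the class name is hypothetical and the order of the leaves is not preserved.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

// Hypothetical task: processes a list of cells; every refined cell's 8 children
// become a new subtask instead of being appended to the list being iterated.
class CellRefineTask extends RecursiveTask<List<Long>> {
    private final List<Long> cells;

    CellRefineTask(List<Long> cells) {
        this.cells = cells;
    }

    // Edit 1's test: refine every cell whose ID is lower than 100
    static boolean checkCell(long cellId) {
        return cellId < 100;
    }

    @Override
    protected List<Long> compute() {
        List<Long> leaves = new ArrayList<>();
        List<CellRefineTask> subtasks = new ArrayList<>();
        for (long cell : cells) {
            if (checkCell(cell)) {
                List<Long> children = new ArrayList<>();
                for (int i = 0; i < 8; i++) {
                    children.add(cell * 10 + i);
                }
                subtasks.add(new CellRefineTask(children)); // new cells get their own task
            } else {
                leaves.add(cell);
            }
        }
        // Fork all subtasks, wait for them, then collect their leaves
        for (CellRefineTask task : ForkJoinTask.invokeAll(subtasks)) {
            leaves.addAll(task.join());
        }
        return leaves;
    }
}
```

Usage: `ForkJoinPool.commonPool().invoke(new CellRefineTask(Arrays.asList(1L, 2L, 3L, 4L)))` yields 256 leaves. One task per refined cell is fine-grained, so for cheap cell tests the task overhead may dominate.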

Edit 2:

To state the problem differently: what I'm looking for is a mechanism where a queue of cell IDs is processed by some RecursiveTasks while others add to the queue in parallel.
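A queue that is consumed and refilled concurrently also needs a way to detect termination (the queue is empty and no running task can still add to it). A simpler technique, swapped in here instead of a concurrent queue, is to process the mesh level by level: the current level is handled with a parallel stream, refined cells contribute their children to the next level, and the loop stops when a level is empty. A sketch assuming the Edit 1 conventions; the names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.LongStream;

// Level-synchronous refinement: each level is processed in parallel, and the
// children of refined cells form the next level until no cell needs refinement.
class FrontierRefiner {
    // Assumed test from Edit 1: refine every cell whose ID is lower than 100
    static boolean needsRefinement(long id) {
        return id < 100;
    }

    static long[] refine(long[] roots) {
        List<long[]> leavesPerLevel = new ArrayList<>();
        long[] level = roots;
        while (level.length > 0) {
            // Cells of this level that need no refinement are leaves
            leavesPerLevel.add(Arrays.stream(level).parallel()
                    .filter(id -> !needsRefinement(id))
                    .toArray());
            // Children (n*10 + 0..7) of the refined cells make up the next level
            level = Arrays.stream(level).parallel()
                    .filter(FrontierRefiner::needsRefinement)
                    .flatMap(id -> LongStream.range(0, 8).map(i -> id * 10 + i))
                    .toArray();
        }
        return leavesPerLevel.stream().flatMapToLong(Arrays::stream).toArray();
    }
}
```

The trade-off is a synchronization point per level: a deep, narrow branch keeps the loop running with little parallel work per iteration.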

Asked by Ben
  • As you can probably see from the article you've linked, a Fork-Join pool is designed for "divide and conquer" types of action. This means that your recursive action ought to have a condition in which it actually does some work inside the `compute` method. Your implementation doesn't do that to my knowledge, and the closest thing to proper implementation of `compute` I can see in your code is the `refineCell` method, in the branch where it assigns a `Math.random` to a cell. Also, checkCell probably actually needs to know something about the cell, otherwise your description makes little sense. – M. Prokhorov Jan 09 '18 at 12:37
  • I know that "this doesn't do much for now, because I never call the refineCell function". I just don't understand how I should call it. The checkCell method does not take the cells into account; it just randomly selects about half of the cells on average. In real life I have an actual function that computes the coordinates of the cell and checks whether or not it needs to be refined. This is provided as a reproducible example that is focused on the problem I have. – Ben Jan 09 '18 at 13:58
  • Return to your example article again and look closely: each task operates on a threshold, which means the number of elements that is "OK" (fast enough) to be processed sequentially, thus not needing any subtask. In your case, this is the branch entered when `checkCell == false`. Otherwise, you should spawn child tasks and then join with their results, like in your current `compute`, but that should be moved inside the branch with `checkCell == true`. You can also look at the JDK code for the `Arrays.parallelSort` implementation. That is a classic one as well. – M. Prokhorov Jan 09 '18 at 14:07
  • Instead of `.map(ForkJoinTask::join) .reduce(new ArrayList<>(), new Concat());` you should use `.flatMap(task -> task.join().stream()) .collect(Collectors.toList())` and get rid of the `Concat` class. The `split` method can be implemented as simply as `int middle = (list.size() + 1)/2; return Arrays.asList(list.subList(0,middle), list.subList(middle, list.size()));` Regarding the threshold, [this answer](https://stackoverflow.com/a/48174508/2711488) might be helpful. But note that you are just reinventing parallel streams here. Currently, I don’t see anything that wouldn’t work with them. – Holger Jan 10 '18 at 09:32
  • Thanks for your useful comment. I don't want to reinvent parallel streams. So if this can be achieved with them, I would be happy to do so. Can you tell me how? – Ben Jan 10 '18 at 17:39
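Applied to Edit 1, the suggestions in the comment above (a two-`subList` `split` and merging with `flatMap(task -> task.join().stream())` instead of `Concat`) combine with the missing piece: a cell that needs refinement gets a fresh task for its children instead of being appended to `cellsToProcess`. A sketch that keeps Edit 1's `THRESHOLD` and `checkCell` (ID below 100); the class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.stream.Collectors;

// Edit 1's task reworked: split via two subList views, merge with flatMap + collect,
// and hand the children of refined cells to a fresh task.
class SplittingRefineTask extends RecursiveTask<List<Long>> {
    private static final int THRESHOLD = 2; // from Edit 1
    private final List<Long> cells;

    SplittingRefineTask(List<Long> cells) {
        this.cells = cells;
    }

    static boolean checkCell(long cellId) {
        return cellId < 100; // Edit 1: refine every cell whose ID is lower than 100
    }

    static List<List<Long>> split(List<Long> list) {
        int middle = (list.size() + 1) / 2;
        return Arrays.asList(list.subList(0, middle), list.subList(middle, list.size()));
    }

    @Override
    protected List<Long> compute() {
        if (cells.size() > THRESHOLD) {
            List<SplittingRefineTask> halves = split(cells).stream()
                    .map(SplittingRefineTask::new)
                    .collect(Collectors.toList());
            return ForkJoinTask.invokeAll(halves).stream()
                    .flatMap(task -> task.join().stream()) // replaces the Concat reduction
                    .collect(Collectors.toList());
        }
        List<Long> leaves = new ArrayList<>();
        for (long cell : cells) {
            if (checkCell(cell)) {
                List<Long> children = new ArrayList<>();
                for (int i = 0; i < 8; i++) {
                    children.add(cell * 10 + i);
                }
                // invoke() runs the child task here; its compute() splits the 8 children further
                leaves.addAll(new SplittingRefineTask(children).invoke());
            } else {
                leaves.add(cell);
            }
        }
        return leaves;
    }
}
```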

1 Answer

First, let’s start with the Stream-based solution:

import java.util.Arrays;
import java.util.stream.LongStream;

public class Mesh {
    public static long[] refineGrid(long[] cellsToProcess) {
        return Arrays.stream(cellsToProcess).parallel().flatMap(Mesh::expand).toArray();
    }
    static LongStream expand(long d) {
        return checkCell(d)? LongStream.of(d): generate(d).flatMap(Mesh::expand);
    }
    private static boolean checkCell(long cellId) {
        return cellId > 100;
    }
    private static LongStream generate(long cellId) {
        return LongStream.range(0, 8).map(j -> cellId * 10 + j);
    }
}

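While the current flatMap implementation has known issues with parallel processing that might apply when the mesh is too unbalanced (in the JDK 8 implementation, the stream returned by the mapping function is always processed sequentially, so the whole subtree of a root cell is expanded by a single thread), the performance for your actual task might be reasonable. So this simple solution is always worth a try before you start implementing something more complicated.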

If you really need a custom implementation, e.g. if the workload is unbalanced and the Stream implementation can’t adapt well enough, you can do it like this:

import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.concurrent.RecursiveTask;

public class MeshTask extends RecursiveTask<long[]> {
    public static long[] refineGrid(long[] cellsToProcess) {
        return new MeshTask(cellsToProcess, 0, cellsToProcess.length).compute();
    }
    private final long[] source;
    private final int from, to;

    private MeshTask(long[] src, int from, int to) {
        source = src;
        this.from = from;
        this.to = to;
    }
    @Override
    protected long[] compute() {
        return compute(source, from, to);
    }
    private static long[] compute(long[] source, int from, int to) {
        long[] result = new long[to - from];
        ArrayDeque<MeshTask> next = new ArrayDeque<>();
        while(getSurplusQueuedTaskCount()<3) {
            int mid = (from+to)>>>1;
            if(mid == from) break;
            MeshTask task = new MeshTask(source, mid, to);
            next.push(task);
            task.fork();
            to = mid;
        }
        int pos = 0;
        for(; from < to; ) {
            long value = source[from++];
            if(checkCell(value)) result[pos++]=value;
            else {
                long[] array = generate(value);
                array = compute(array, 0, array.length);
                result = Arrays.copyOf(result, result.length+array.length-1);
                System.arraycopy(array, 0, result, pos, array.length);
                pos += array.length;
            }
            while(from == to && !next.isEmpty()) {
                MeshTask task = next.pop();
                if(task.tryUnfork()) {
                    to = task.to;
                }
                else {
                    long[] array = task.join();
                    int newLen = pos+to-from+array.length;
                    if(newLen != result.length)
                        result = Arrays.copyOf(result, newLen);
                    System.arraycopy(array, 0, result, pos, array.length);
                    pos += array.length;
                }
            }
        }
        return result;
    }
    static boolean checkCell(long cellId) {
        return cellId > 1000;
    }
    static long[] generate(long cellId) {
        long[] sub = new long[8];
        for(int i = 0; i < sub.length; i++) sub[i] = cellId*10+i;
        return sub;
    }
}

This implementation calls the compute method of the root task directly, to incorporate the caller thread into the computation. The compute method uses getSurplusQueuedTaskCount() to decide whether to split. As its documentation says, the idea is to always keep a small surplus, e.g. 3. This ensures that the evaluation can adapt to unbalanced workloads, as idle threads can steal work from other tasks.

The splitting is not done by creating two subtasks and waiting for both. Instead, only one task is split off, representing the second half of the pending work, and the current task’s workload is reduced to the first half.

Then, the remaining workload is processed locally. Afterwards, the most recently pushed subtask is popped and an attempt is made to unfork it. If unforking succeeds, the current workload’s range is extended to cover that task’s range too, and the local iteration continues.

That way, any surplus task that has not been stolen by another thread is processed in the simplest and most lightweight way, as if it was never forked.

If the task has been picked up by another thread, we have to wait for its completion now and merge the result array.
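The split-one-half, then unfork-or-join pattern described above is easier to see without the mesh-specific array resizing. A sketch on a hypothetical task that merely sums a `long[]`: `getSurplusQueuedTaskCount()`, `fork()`, `tryUnfork()` and `join()` are standard `ForkJoinTask` methods, and the surplus limit of 3 follows the answer. Since partial sums simply add up, the order in which ranges are finished does not matter here.

```java
import java.util.ArrayDeque;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class SurplusSumTask extends RecursiveTask<Long> {
    private final long[] data;
    private final int from, to;

    SurplusSumTask(long[] data, int from, int to) {
        this.data = data;
        this.from = from;
        this.to = to;
    }

    @Override
    protected Long compute() {
        int lo = from, hi = to;
        ArrayDeque<SurplusSumTask> forked = new ArrayDeque<>();
        // Split off the upper half only while this worker has fewer than 3 surplus tasks
        while (getSurplusQueuedTaskCount() < 3) {
            int mid = (lo + hi) >>> 1;
            if (mid == lo) break;
            SurplusSumTask upper = new SurplusSumTask(data, mid, hi);
            forked.push(upper);
            upper.fork();
            hi = mid; // keep the lower half for this task
        }
        long sum = 0;
        while (true) {
            for (; lo < hi; lo++) sum += data[lo];
            if (forked.isEmpty()) return sum;
            SurplusSumTask next = forked.pop(); // most recently forked first
            if (next.tryUnfork()) {
                lo = next.from; // not stolen: process its range here, as if never forked
                hi = next.to;
            } else {
                sum += next.join(); // stolen by another worker: wait for its result
            }
        }
    }

    static long sum(long[] data) {
        return ForkJoinPool.commonPool().invoke(new SurplusSumTask(data, 0, data.length));
    }
}
```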

Note that when waiting for a subtask via join(), the underlying implementation will also check whether unforking and local evaluation are possible, to keep all worker threads busy. However, adjusting our loop variable and accumulating the results directly in our target array is still better than a nested compute invocation that would still require merging the result arrays.

If a cell is not a leaf, the resulting child nodes are processed recursively by the same logic. This again allows for adaptive local and concurrent evaluation, so the execution adapts to unbalanced workloads, e.g. if a particular cell has a larger subtree or the evaluation of a particular cell takes much longer than others.

It must be emphasized that in all cases, a significant processing workload is needed to benefit from parallel processing. If, as in the example, the work is mostly data copying, the benefit might be much smaller or non-existent, or in the worst case, the parallel processing may perform worse than sequential processing.

Answered by Holger
  • Wow, that's a great answer, thank you! I tried the stream version in a real-life case (a sphere refined up to level 9) with 1,192,192 cells. 721ms. Impressive! – Ben Jan 11 '18 at 20:36
  • One additional question, though, about LongStream.range(0, 8).map(j -> cellId * 10 + j): suppose that instead of only j as input I need three integers i, j and k. Can you do this with range? – Ben Jan 11 '18 at 20:53
  • I found an answer here: https://stackoverflow.com/questions/26439163/java-8-foreach-over-multiple-intstreams – Ben Jan 12 '18 at 06:11
  • For numerical variables you can use, e.g. `LongStream.range(0, 8).flatMap(i -> LongStream.range(0, 8).flatMap(j -> LongStream.range(0, 8).map(k -> cellId*1000+i*100+j*10+k)))` – Holger Jan 12 '18 at 09:46
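For the 3D case, if `i`, `j` and `k` are octant indices (an assumption: each refined cell is split into 2 × 2 × 2 children), each index only ranges over 0 and 1, and the nested `range` calls from the comment above can be packed back into a single digit 0 to 7 so that the decimal ID scheme from the question still works. A sketch with hypothetical `Octree` helpers:

```java
import java.util.stream.LongStream;

class Octree {
    // Assumption: i, j, k are octant indices in {0, 1}; i*4 + j*2 + k maps them to 0..7
    static long[] children(long cellId) {
        return LongStream.range(0, 2).flatMap(i ->
                LongStream.range(0, 2).flatMap(j ->
                        LongStream.range(0, 2).map(k -> cellId * 10 + i * 4 + j * 2 + k)))
                .toArray();
    }

    // Recovers (i, j, k) from a child's last digit
    static int[] octant(long childId) {
        int d = (int) (childId % 10);
        return new int[] {d >> 2, (d >> 1) & 1, d & 1};
    }
}
```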