1

I developed a piece of multi-threaded code. This code is called within a web application, so potentially by several threads (requests) in parallel. To keep the number of threads this code creates under control (since it can be called by several parallel requests), I use a static shared ThreadPoolExecutor (Executors.newFixedThreadPool(nbOfThreads)). So I am sure that this code will never create more than nbOfThreads threads. To track the tasks belonging to a given request and wait until they have finished, I use a CompletionService for each request.

Now, I would like to have a bit of “fairness” (not sure it’s the right word) in the way the threads of the pool are given to requests. With the default fixed ThreadPoolExecutor, the waiting queue is a LinkedBlockingQueue. It hands tasks to the executor in their arrival order (FIFO). Imagine that the pool core size is 100 threads. A first request is big and creates 150 tasks. So it fills the pool and puts 50 tasks in the waiting queue. If a second, tiny request arrives 1 second later, even if it needs only 2 threads from the pool, its tasks are queued behind those 50 waiting tasks, so it has to wait until the tasks of the first big request have been served before being processed.

How can I make the pool give threads fairly and evenly to each request? How can I keep the 2 tasks of the second request from waiting behind all 50 waiting tasks of the first request?

My idea was to develop my own implementation of BlockingQueue to give to the ThreadPoolExecutor. This BlockingQueue would store the waiting tasks grouped by the request which created them (in a backing Map with the id of the request as key and a LinkedBlockingQueue storing the tasks of that request as value). Then, whenever the ThreadPoolExecutor takes or polls a new task from the queue, the queue gives it a task from a different request each time … Is this the correct approach? The use case seems quite common to me. I am surprised that I have to implement such custom and tedious stuff myself. That’s why I think I may be wrong and that there may be a well-known best practice for this.

Here is the code I wrote. It works, but I am still wondering whether this is the right approach.

public class TestThreadPoolExecutorWithTurningQueue {

    private final static Logger logger = LogManager.getLogger();

    private static ThreadPoolExecutor executorService;

    int nbRequest = 4;

    int nbThreadPerRequest = 8;

    int threadPoolSize = 5;

    /**
     * Creates the shared fixed-size pool ({@code threadPoolSize} core and maximum threads,
     * zero keep-alive) on top of the per-request work queue.
     */
    private void init() {
        // my custom blocking queue storing waiting tasks per request
        // (replace with a LinkedBlockingQueue<Runnable> to compare against plain FIFO ordering)
        BlockingQueue<Runnable> waitingTasks = new CategoryBlockingQueue<Runnable>();
        executorService = new ThreadPoolExecutor(
                threadPoolSize,
                threadPoolSize,
                0L,
                TimeUnit.MILLISECONDS,
                waitingTasks);
    }

    /**
     * Simulates {@code nbRequest} requests arriving 10 ms apart, each handled on its own
     * "container" thread, then waits for all of them to finish.
     */
    @Test
    public void test() throws Exception {
        init();
        // Parallel requests arriving
        ExecutorService tomcat = Executors.newFixedThreadPool(nbRequest);
        int arrived = 0;
        while (arrived < nbRequest) {
            Thread.sleep(10);
            final int requestId = arrived;
            Runnable incomingRequest = new Runnable() {
                @Override
                public void run() {
                    request(requestId);
                }
            };
            tomcat.execute(incomingRequest);
            arrived++;
        }
        // No more requests will be accepted; block until the running ones are done.
        tomcat.shutdown();
        tomcat.awaitTermination(1, TimeUnit.DAYS);
    }

    /**
     * Code executed by each request. The work is multi-threaded but runs on the single shared
     * {@code executorService}, so the total number of threads stays under control.
     *
     * <p>Submits {@code nbThreadPerRequest} tasks tagged with {@code requestId} and blocks until
     * all of them have completed. If any task fails, the remaining tasks of this request are
     * cancelled and the failure is rethrown.
     *
     * @param requestId id of the request; used as the category of every submitted task
     * @throws RuntimeException wrapping the cause of the first task failure, or the
     *         interruption/cancellation that ended the wait
     */
    public void request(final int requestId) {
        final List<Future<Object>> futures = new ArrayList<>();
        CustomCompletionService<Object> completionService = new CustomCompletionService<>(executorService);
        for (int j = 0; j < nbThreadPerRequest; j++) {
            final int finalJ = j;
            futures.add(completionService.submit(new CategoryRunnable(requestId) {
                @Override
                public void run() {
                    logger.debug("thread " + finalJ + " of request " + requestId);
                    try {
                        // here should come the useful things to be done
                        Thread.sleep(2000);
                    } catch (InterruptedException e) {
                        // The task was cancelled (or the pool is stopping): stop working and keep
                        // the interrupt flag set instead of silently swallowing the interruption.
                        Thread.currentThread().interrupt();
                    }
                }
            }, null));
        }
        // Wait for completion of all the tasks of the request
        // If a task threw an exception, cancel the other tasks of the request
        for (int j = 0; j < nbThreadPerRequest; j++) {
            try {
                completionService.take().get();
            } catch (Exception e) {
                // Cancel the remaining tasks
                for (Future<Object> future : futures) {
                    future.cancel(true);
                }

                // The waiting (request) thread itself was interrupted: preserve that for callers.
                if (e instanceof InterruptedException) {
                    Thread.currentThread().interrupt();
                }

                // Get the underlying failure. The cause may be an Error (e.g. AssertionError),
                // so it is kept as a Throwable rather than cast to Exception.
                Throwable toThrow = e;
                if (e instanceof ExecutionException && e.getCause() != null) {
                    toThrow = e.getCause();
                }
                throw new RuntimeException(toThrow);
            }
        }
    }

    /**
     * Completion service whose tasks carry a category (the request id). Each submitted task is
     * wrapped in a {@link CategorizedQueueingFuture} before being handed to the executor, and
     * lands on the completion queue once it is done.
     *
     * @param <V> result type of the submitted tasks
     */
    public class CustomCompletionService<V> implements CompletionService<V> {

        /** Finished (or cancelled) tasks, in completion order. */
        private final BlockingQueue<Future<V>> completionQueue;

        /** Executor that actually runs the wrapped tasks. */
        private final Executor executor;

        /**
         * @param executor executor used to run submitted tasks
         * @throws NullPointerException if {@code executor} is null
         */
        public CustomCompletionService(Executor executor) {
            if (executor == null) {
                throw new NullPointerException();
            }
            this.completionQueue = new LinkedBlockingQueue<Future<V>>();
            this.executor = executor;
        }

        /** Wraps the future with its category and hands the wrapper to the executor. */
        private Future<V> dispatch(RunnableFuture<V> future, int category) {
            executor.execute(new CategorizedQueueingFuture(future, category));
            return future;
        }

        /**
         * Submits a categorized value-returning task.
         *
         * @throws NullPointerException if {@code task} is null
         */
        public Future<V> submit(CategoryCallable<V> task) {
            if (task == null) {
                throw new NullPointerException();
            }
            RunnableFuture<V> future = new FutureTask<V>(task);
            return dispatch(future, task.getCategory());
        }

        /**
         * Submits a categorized task whose future yields {@code result} on success.
         *
         * @throws NullPointerException if {@code task} is null
         */
        public Future<V> submit(CategoryRunnable task, V result) {
            if (task == null) {
                throw new NullPointerException();
            }
            RunnableFuture<V> future = new FutureTask<V>(task, result);
            return dispatch(future, task.getCategory());
        }

        /** Submits a categorized task whose future yields {@code null} on success. */
        public Future<V> submit(CategoryRunnable task) {
            return submit(task, null);
        }

        /** Always rejected: tasks without a category cannot be queued. */
        @Override
        public Future<V> submit(Callable<V> task) {
            throw new IllegalArgumentException("Must use a 'CategoryCallable'");
        }

        /** Always rejected: tasks without a category cannot be queued. */
        @Override
        public Future<V> submit(Runnable task, V result) {
            throw new IllegalArgumentException("Must use a 'CategoryRunnable'");
        }

        /** Blocks until a task has completed and returns its future. */
        @Override
        public Future<V> take() throws InterruptedException {
            return completionQueue.take();
        }

        /** Returns the future of a completed task, or {@code null} if none has completed yet. */
        @Override
        public Future<V> poll() {
            return completionQueue.poll();
        }

        /** Like {@link #poll()}, but waits up to the given timeout for a completion. */
        @Override
        public Future<V> poll(long timeout, TimeUnit unit)
                throws InterruptedException {
            return completionQueue.poll(timeout, unit);
        }

        /**
         * FutureTask extension to enqueue upon completion + Category
         */
        public class CategorizedQueueingFuture extends FutureTask<Void> {

            /** Category (request id) the wrapped task belongs to. */
            private final int category;

            /** The caller-visible future that is published on completion. */
            private final Future<V> task;

            CategorizedQueueingFuture(RunnableFuture<V> task, int category) {
                super(task, null);
                this.category = category;
                this.task = task;
            }

            /** Publishes the wrapped future on the completion queue. */
            @Override
            protected void done() {
                completionQueue.add(task);
            }

            public int getCategory() {
                return category;
            }
        }
    }

    /**
     * A {@link Runnable} tagged with the category (request id) it belongs to.
     * Subclasses supply {@code run()}.
     */
    public abstract class CategoryRunnable implements Runnable {

        /** Category assigned at construction; never changes afterwards. */
        private final int category;

        /**
         * @param category category (request id) of this task
         */
        public CategoryRunnable(int category) {
            this.category = category;
        }

        /**
         * @return the category given to the constructor
         */
        public int getCategory() {
            return this.category;
        }
    }

    /**
     * A {@link Callable} tagged with the category (request id) it belongs to.
     * Subclasses supply {@code call()}.
     *
     * @param <V> result type of the task
     */
    public abstract class CategoryCallable<V> implements Callable<V> {

        /** Category assigned at construction; never changes afterwards. */
        private final int category;

        /**
         * @param category category (request id) of this task
         */
        public CategoryCallable(int category) {
            this.category = category;
        }

        /**
         * @return the category given to the constructor
         */
        public int getCategory() {
            return this.category;
        }
    }

    /**
     * Unbounded blocking queue that serves its elements round-robin across categories instead of
     * in plain FIFO order: each category (request) keeps its own FIFO sub-queue and consumers
     * receive one element from each waiting category in turn.
     *
     * <p>Elements must be {@code CustomCompletionService.CategorizedQueueingFuture} instances; the
     * category is read from them.
     *
     * <p>All state is guarded by one lock. Consumers wait on a condition, which releases the lock
     * while they sleep, so a blocked {@code take()} can never prevent {@code offer()} from adding
     * the element it is waiting for.
     *
     * @param <E> element type (in practice {@code Runnable})
     */
    public class CategoryBlockingQueue<E> extends AbstractQueue<E> implements BlockingQueue<E> {

        /** Guards {@link #map} and {@link #nextCategories}. */
        private final ReentrantLock lock = new ReentrantLock();

        /** Signalled when an element becomes available. */
        private final java.util.concurrent.locks.Condition notEmpty = lock.newCondition();

        /** Waiting elements per category; a category is present only while it has elements. */
        private final Map<Integer, LinkedBlockingQueue<E>> map = new HashMap<>();

        /** Rotation of the categories present in {@link #map}; the head is served next. */
        private final LinkedBlockingQueue<Integer> nextCategories = new LinkedBlockingQueue<>();

        /** Total number of waiting elements; atomic so that {@link #size()} needs no lock. */
        private final AtomicInteger count = new AtomicInteger(0);

        /**
         * Reads the category of an element.
         *
         * @throws NullPointerException if {@code element} is null
         * @throws ClassCastException if {@code element} is not a categorized future
         */
        private int categoryOf(Object element) {
            return ((CustomCompletionService.CategorizedQueueingFuture) element).getCategory();
        }

        /**
         * Removes and returns the next element in round-robin order, or {@code null} if the queue
         * is empty. Must be called with {@link #lock} held.
         */
        private E dequeue() {
            Integer nextCategory = nextCategories.poll();
            if (nextCategory == null) {
                return null;
            }
            LinkedBlockingQueue<E> categoryElements = map.get(nextCategory);
            E e = categoryElements.poll();
            if (categoryElements.isEmpty()) {
                map.remove(nextCategory);
            } else {
                // The category still has work: send it to the back of the rotation.
                nextCategories.offer(nextCategory);
            }
            if (count.decrementAndGet() > 0) {
                // More elements remain: wake another waiting consumer.
                notEmpty.signal();
            }
            return e;
        }

        /**
         * Adds the element to the sub-queue of its category. Never blocks and always succeeds,
         * since the queue is unbounded.
         *
         * @return {@code true}
         * @throws NullPointerException if {@code e} is null
         * @throws ClassCastException if {@code e} is not a categorized future
         */
        @Override
        public boolean offer(E e) {
            Integer category = categoryOf(e);
            lock.lock();
            try {
                LinkedBlockingQueue<E> categoryElements = map.get(category);
                if (categoryElements == null) {
                    categoryElements = new LinkedBlockingQueue<E>();
                    map.put(category, categoryElements);
                    nextCategories.offer(category);
                }
                categoryElements.offer(e);
                count.incrementAndGet();
                notEmpty.signal();
                return true;
            } finally {
                lock.unlock();
            }
        }

        /** Equivalent to {@link #offer(Object)}: an unbounded queue never has to wait for space. */
        @Override
        public void put(E e) throws InterruptedException {
            offer(e);
        }

        /** Equivalent to {@link #offer(Object)}: an unbounded queue never has to wait for space. */
        @Override
        public boolean offer(E e, long timeout, TimeUnit unit) throws InterruptedException {
            return offer(e);
        }

        /** Returns the next element in round-robin order, or {@code null} if the queue is empty. */
        @Override
        public E poll() {
            lock.lock();
            try {
                return dequeue();
            } finally {
                lock.unlock();
            }
        }

        /**
         * Returns the next element in round-robin order, waiting up to the given time for one to
         * arrive; {@code null} if the time elapses first.
         */
        @Override
        public E poll(long timeout, TimeUnit unit) throws InterruptedException {
            long nanos = unit.toNanos(timeout);
            lock.lockInterruptibly();
            try {
                while (count.get() == 0) {
                    if (nanos <= 0) {
                        return null;
                    }
                    nanos = notEmpty.awaitNanos(nanos);
                }
                return dequeue();
            } finally {
                lock.unlock();
            }
        }

        /** Returns the next element in round-robin order, waiting until one is available. */
        @Override
        public E take() throws InterruptedException {
            lock.lockInterruptibly();
            try {
                while (count.get() == 0) {
                    // await() releases the lock while waiting, so producers can still offer.
                    notEmpty.await();
                }
                return dequeue();
            } finally {
                lock.unlock();
            }
        }

        /** Returns, without removing it, the element the next poll/take would return. */
        @Override
        public E peek() {
            lock.lock();
            try {
                Integer nextCategory = nextCategories.peek();
                return nextCategory == null ? null : map.get(nextCategory).peek();
            } finally {
                lock.unlock();
            }
        }

        /**
         * Removes one occurrence of the given element, if present.
         *
         * @return {@code true} if the element was queued and has been removed; {@code false} for
         *         null, for objects that are not categorized futures, and for absent elements
         */
        @Override
        public boolean remove(Object o) {
            if (!(o instanceof CustomCompletionService.CategorizedQueueingFuture)) {
                return false;
            }
            Integer category = categoryOf(o);
            lock.lock();
            try {
                LinkedBlockingQueue<E> categoryElements = map.get(category);
                if (categoryElements == null || !categoryElements.remove(o)) {
                    return false;
                }
                count.decrementAndGet();
                if (categoryElements.isEmpty()) {
                    // Drop the category from BOTH structures so the rotation never points at a
                    // category that no longer has a sub-queue.
                    map.remove(category);
                    nextCategories.remove(category);
                }
                return true;
            } finally {
                lock.unlock();
            }
        }

        /** Moves all elements, in round-robin order, into {@code c}. */
        @Override
        public int drainTo(Collection<? super E> c) {
            return drainTo(c, Integer.MAX_VALUE);
        }

        /**
         * Moves at most {@code maxElements} elements, in round-robin order, into {@code c}.
         *
         * @return the number of elements transferred
         * @throws NullPointerException if {@code c} is null
         * @throws IllegalArgumentException if {@code c} is this queue
         */
        @Override
        public int drainTo(Collection<? super E> c, int maxElements) {
            if (c == null) {
                throw new NullPointerException();
            }
            if (c == this) {
                throw new IllegalArgumentException();
            }
            lock.lock();
            try {
                int transferred = 0;
                while (transferred < maxElements) {
                    E e = dequeue();
                    if (e == null) {
                        break;
                    }
                    c.add(e);
                    transferred++;
                }
                return transferred;
            } finally {
                lock.unlock();
            }
        }

        /**
         * Returns a read-only iterator over a snapshot of the queue, grouped by category in the
         * current rotation order. Later changes to the queue are not reflected.
         */
        @Override
        public Iterator<E> iterator() {
            List<E> snapshot = new ArrayList<E>();
            lock.lock();
            try {
                for (Integer category : nextCategories) {
                    snapshot.addAll(map.get(category));
                }
            } finally {
                lock.unlock();
            }
            return java.util.Collections.unmodifiableList(snapshot).iterator();
        }

        /** Number of elements currently waiting. */
        @Override
        public int size() {
            return count.get();
        }

        /** The queue is unbounded. */
        @Override
        public int remainingCapacity() {
            return Integer.MAX_VALUE;
        }

    }
}

Output with the traditional LinkedBlockingQueue

2017-01-09 14:56:13,061 [pool-2-thread-1] DEBUG - thread 0 of request 0
2017-01-09 14:56:13,061 [pool-2-thread-4] DEBUG - thread 3 of request 0
2017-01-09 14:56:13,061 [pool-2-thread-2] DEBUG - thread 1 of request 0
2017-01-09 14:56:13,061 [pool-2-thread-3] DEBUG - thread 2 of request 0
2017-01-09 14:56:13,061 [pool-2-thread-5] DEBUG - thread 4 of request 0
2017-01-09 14:56:15,063 [pool-2-thread-2] DEBUG - thread 5 of request 0
2017-01-09 14:56:15,063 [pool-2-thread-1] DEBUG - thread 6 of request 0
2017-01-09 14:56:15,063 [pool-2-thread-4] DEBUG - thread 7 of request 0
2017-01-09 14:56:15,063 [pool-2-thread-3] DEBUG - thread 0 of request 1
2017-01-09 14:56:15,063 [pool-2-thread-5] DEBUG - thread 1 of request 1
2017-01-09 14:56:17,064 [pool-2-thread-2] DEBUG - thread 2 of request 1
2017-01-09 14:56:17,064 [pool-2-thread-4] DEBUG - thread 3 of request 1
2017-01-09 14:56:17,064 [pool-2-thread-1] DEBUG - thread 5 of request 1
2017-01-09 14:56:17,064 [pool-2-thread-3] DEBUG - thread 4 of request 1
2017-01-09 14:56:17,064 [pool-2-thread-5] DEBUG - thread 6 of request 1
2017-01-09 14:56:19,064 [pool-2-thread-4] DEBUG - thread 7 of request 1
2017-01-09 14:56:19,064 [pool-2-thread-1] DEBUG - thread 0 of request 2
2017-01-09 14:56:19,064 [pool-2-thread-3] DEBUG - thread 1 of request 2
2017-01-09 14:56:19,064 [pool-2-thread-5] DEBUG - thread 2 of request 2
2017-01-09 14:56:19,064 [pool-2-thread-2] DEBUG - thread 3 of request 2
2017-01-09 14:56:21,064 [pool-2-thread-4] DEBUG - thread 4 of request 2
2017-01-09 14:56:21,064 [pool-2-thread-3] DEBUG - thread 5 of request 2
2017-01-09 14:56:21,064 [pool-2-thread-5] DEBUG - thread 6 of request 2
2017-01-09 14:56:21,064 [pool-2-thread-2] DEBUG - thread 7 of request 2
2017-01-09 14:56:21,064 [pool-2-thread-1] DEBUG - thread 0 of request 3
2017-01-09 14:56:23,064 [pool-2-thread-4] DEBUG - thread 2 of request 3
2017-01-09 14:56:23,064 [pool-2-thread-3] DEBUG - thread 1 of request 3
2017-01-09 14:56:23,064 [pool-2-thread-2] DEBUG - thread 3 of request 3
2017-01-09 14:56:23,064 [pool-2-thread-1] DEBUG - thread 4 of request 3
2017-01-09 14:56:23,064 [pool-2-thread-5] DEBUG - thread 5 of request 3
2017-01-09 14:56:25,064 [pool-2-thread-2] DEBUG - thread 7 of request 3
2017-01-09 14:56:25,064 [pool-2-thread-1] DEBUG - thread 6 of request 3

Output with my custom CategoryBlockingQueue

2017-01-09 14:54:54,765 [pool-2-thread-3] DEBUG - thread 2 of request 0
2017-01-09 14:54:54,765 [pool-2-thread-2] DEBUG - thread 1 of request 0
2017-01-09 14:54:54,765 [pool-2-thread-5] DEBUG - thread 4 of request 0
2017-01-09 14:54:54,765 [pool-2-thread-1] DEBUG - thread 0 of request 0
2017-01-09 14:54:54,765 [pool-2-thread-4] DEBUG - thread 3 of request 0
2017-01-09 14:54:56,767 [pool-2-thread-1] DEBUG - thread 0 of request 1
2017-01-09 14:54:56,767 [pool-2-thread-4] DEBUG - thread 0 of request 3
2017-01-09 14:54:56,767 [pool-2-thread-3] DEBUG - thread 5 of request 0
2017-01-09 14:54:56,767 [pool-2-thread-5] DEBUG - thread 0 of request 2
2017-01-09 14:54:56,767 [pool-2-thread-2] DEBUG - thread 6 of request 0
2017-01-09 14:54:58,767 [pool-2-thread-1] DEBUG - thread 1 of request 1
2017-01-09 14:54:58,767 [pool-2-thread-5] DEBUG - thread 1 of request 2
2017-01-09 14:54:58,767 [pool-2-thread-2] DEBUG - thread 7 of request 0
2017-01-09 14:54:58,767 [pool-2-thread-4] DEBUG - thread 1 of request 3
2017-01-09 14:54:58,767 [pool-2-thread-3] DEBUG - thread 2 of request 1
2017-01-09 14:55:00,767 [pool-2-thread-1] DEBUG - thread 2 of request 2
2017-01-09 14:55:00,767 [pool-2-thread-5] DEBUG - thread 2 of request 3
2017-01-09 14:55:00,767 [pool-2-thread-2] DEBUG - thread 3 of request 1
2017-01-09 14:55:00,767 [pool-2-thread-4] DEBUG - thread 3 of request 2
2017-01-09 14:55:00,767 [pool-2-thread-3] DEBUG - thread 3 of request 3
2017-01-09 14:55:02,767 [pool-2-thread-5] DEBUG - thread 4 of request 1
2017-01-09 14:55:02,767 [pool-2-thread-3] DEBUG - thread 4 of request 2
2017-01-09 14:55:02,767 [pool-2-thread-2] DEBUG - thread 4 of request 3
2017-01-09 14:55:02,767 [pool-2-thread-1] DEBUG - thread 5 of request 1
2017-01-09 14:55:02,767 [pool-2-thread-4] DEBUG - thread 5 of request 2
2017-01-09 14:55:04,767 [pool-2-thread-2] DEBUG - thread 5 of request 3
2017-01-09 14:55:04,767 [pool-2-thread-1] DEBUG - thread 6 of request 1
2017-01-09 14:55:04,767 [pool-2-thread-5] DEBUG - thread 6 of request 2
2017-01-09 14:55:04,767 [pool-2-thread-3] DEBUG - thread 6 of request 3
2017-01-09 14:55:04,768 [pool-2-thread-4] DEBUG - thread 7 of request 1
2017-01-09 14:55:06,768 [pool-2-thread-2] DEBUG - thread 7 of request 3
2017-01-09 14:55:06,768 [pool-2-thread-1] DEBUG - thread 7 of request 2
Comencau
  • 1,084
  • 15
  • 35
  • Is it necessary that all "Tasks" of one request are executed in parallel? With 100 Threads (which is much)? Or would it suffice to give each Request a smaller but dedicated pool? – Fildor Jan 09 '17 at 11:09
  • @Fildor Indeed, my first implementation used smaller dedicated pool per request (core size = 20). But, when I tried to "stress test" by launching several big requests in parallel, it launched too many threads. For example 10 concurrent requests let the application create 200 threads (10 x thread pool of core size 20). That's why, I thought having a single shared ThreadPoolExecutor would really allow to "control" the maximum threads that could exist, even under high workload. – Comencau Jan 09 '17 at 12:27
  • You can use a PriorityQueue: https://docs.oracle.com/javase/8/docs/api/java/util/PriorityQueue.html See http://stackoverflow.com/questions/3198660/java-executors-how-can-i-set-task-priority – Wim Deblauwe Jan 09 '17 at 14:12
  • @Wim Deblauwe Thank you. Yes I thought to use PriorityQueue. But I did not find which logic to implement to make those task "comparable" and achieve the behavior I am searching. Here I want that each request has one of its task regularly taken by a thread becoming available in the pool. I don't want for example that each task of the same request have the same priority, potentially higher than task of another request. For example I could say that task from "tiny" request have higher priority than task from"big" request. But here this is not the behavior I am searching. Request each in its turn. – Comencau Jan 09 '17 at 14:22
  • @Comencau, did you use fixed thread pool? if it ix fixed thread pool, it does not create 200 threads. If you have two dedicated thread pools for short & long duration tasks, your problem is solved. – Ravindra babu Jan 09 '17 at 14:31
  • @Ravindra babu. Thank you for your comment. I was saying it would use 200 threads by answering Fildor's suggestion to have a dedicated smaller thread pool per request. But this is not what I am doing. I do use a single fixed ThreadPoolExecutor. So it limits the number of threads, no problem on this side. The problem now is to avoid that a big request arriving first creates a traffic jam preventing tiny request (for which end-user expects a fast response) from being processed before. – Comencau Jan 09 '17 at 14:45
  • @Wim Deblauwe. Sorry to bother you. Just to make it clearer concerning the use of a PriorityQueue. If I use a PriorityQueue, I don't see how to implement a relevant logic. If I have a continuous stream of small requests (needing 5 threads), I still want that big requests move forward. If I use priority which is proportional to the size of the request (number of needed threads), the continuous stream of small requests will prevent bigger request from being processed. I don't see how to make use of PriorityQueue to overcome that. – Comencau Jan 09 '17 at 14:49
  • You could also do a 2-stage approach: Requests get queued in one Executor, for example a Work-Stealing one. In this Request-Task, you create "local" Executors of (small) fixed size for each Requests' Tasks. So you'll end up with for example 4 x 10 Threads ( for 4 Processors and 10 Threads for each Request that is handled in parallel ) ... – Fildor Jan 10 '17 at 08:37

3 Answers3

0

I have gone through the link below, and it might be useful for your own fair-locking implementation.

http://tutorials.jenkov.com/java-concurrency/starvation-and-fairness.html

0

Keep the things simple.

  1. Have two dedicated thread pools for smaller and longer tasks.
  2. Preferably use Executors.html#newFixedThreadPool or Executors.html#newWorkStealingPool

work-stealing thread pool effectively uses available CPU cores.

Have a look at below related SE question for more details:

Java: How to scale threads according to cpu cores?

Community
  • 1
  • 1
Ravindra babu
  • 37,698
  • 11
  • 250
  • 211
  • Thank you for your help. Your suggestion could be okay as well. However, I see 2 small drawbacks. In the case the app is not too busy, the big request will always leverage less threads than the number of threads that would have been available with a single pool. Secondly, using the 2 pools approach creates the new question : how to classify the request is big or small ? What is the limit ? That's why I preferred to stay on my initial approach to have threads given to parallel request each in its turn. – Comencau Jan 12 '17 at 09:48
  • It's specific to Application and developer. If he is using global queue for all types of messages, he can add into two different executors depending on class name. In this case, the developer knows that a particular class takes longer time for execution. – Ravindra babu Jan 12 '17 at 10:01
0

Finally, here is what I did to give the threads of the pool to each parallel request in turn, in a "fair", "balanced" manner. This works. If something is wrong or there is a better way of doing it, please let me know.

To summarize, I created a BlockingQueue to be used by the pool. This queue stores the requests' tasks in a Map which groups them by the request they belong to. Then the take or poll method, called by the pool to get a new task to execute, gives a task from a different request each time.

I needed to adapt CompletionService to deal with Runnable and Callable instances that carry an additional field: the id of the request.

public class TestThreadPoolExecutorWithTurningQueue {

    private final static Logger logger = LogManager.getLogger();

    private static ThreadPoolExecutor executorService;

    int nbRequest = 4;

    int nbThreadPerRequest = 8;

    int threadPoolSize = 5;

    /**
     * Creates the shared pool ({@code threadPoolSize} core and maximum threads, 10 s keep-alive)
     * on top of the per-request work queue, and lets idle core threads time out.
     */
    private void init() {
        // my custom blocking queue storing waiting tasks per request
        // (replace with a LinkedBlockingQueue<Runnable> to compare against plain FIFO ordering)
        BlockingQueue<Runnable> waitingTasks = new CategoryBlockingQueue<Runnable>();
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                threadPoolSize,
                threadPoolSize,
                10L,
                TimeUnit.SECONDS,
                waitingTasks);
        executorService = pool;
        pool.allowCoreThreadTimeOut(true);
    }

    /**
     * Simulates {@code nbRequest} requests arriving 10 ms apart, each on its own "container"
     * thread, then logs the state of the shared pool once per second for 100 seconds before
     * waiting for the request threads to finish.
     */
    @Test
    public void test() throws Exception {
        init();
        ExecutorService tomcat = Executors.newFixedThreadPool(nbRequest);
        int arrived = 0;
        while (arrived < nbRequest) {
            Thread.sleep(10);
            final int requestId = arrived;
            Runnable incomingRequest = new Runnable() {
                @Override
                public void run() {
                    request(requestId);
                }
            };
            tomcat.execute(incomingRequest);
            arrived++;
        }

        // Periodically dump the pool state.
        int remainingDumps = 100;
        while (remainingDumps > 0) {
            Thread.sleep(1000);
            logger.debug("TPE = " + executorService);
            remainingDumps--;
        }

        tomcat.shutdown();
        tomcat.awaitTermination(1, TimeUnit.DAYS);
    }

    /**
     * Code executed by each request: submits {@code nbThreadPerRequest} tasks tagged with
     * {@code requestId} to the shared pool, then blocks until all of them have completed.
     *
     * @param requestId id of the request; used as the category of every submitted task
     */
    public void request(final int requestId) {
        CustomCompletionService<Object> completionService = new CustomCompletionService<>(executorService);
        int submitted = 0;
        while (submitted < nbThreadPerRequest) {
            final int taskIndex = submitted;
            CategoryRunnable task = new CategoryRunnable(requestId) {
                @Override
                public void go() throws Exception {
                    logger.debug("thread " + taskIndex + " of request " + requestId + "   " + executorService);
                    // here should come the useful things to be done
                    Thread.sleep(2000);
                }
            };
            completionService.submit(task);
            submitted++;
        }
        completionService.awaitCompletion();
    }

    /**
     * Completion service whose tasks carry a category (the request id). Each submitted task is
     * wrapped in a {@link CategorizedQueueingFuture} before being handed to the executor, and
     * lands on the completion queue once it is done. The service also remembers what it
     * submitted so that {@link #awaitCompletion()} can wait for, or cancel, all of it.
     *
     * <p>{@code submit} and {@code awaitCompletion} are meant to be called from the single
     * thread that owns the request; the list of submitted tasks is not synchronized.
     *
     * @param <V> result type of the submitted tasks
     */
    public class CustomCompletionService<V> implements CompletionService<V> {

        private final Executor executor;

        private final BlockingQueue<Future<V>> completionQueue;

        /** Every task successfully handed to the executor, in submission order. */
        private final List<RunnableFuture<V>> submittedTasks = new ArrayList<>();

        /**
         * @param executor executor used to run submitted tasks
         * @throws NullPointerException if {@code executor} is null
         */
        public CustomCompletionService(Executor executor) {
            if (executor == null)
                throw new NullPointerException();
            this.executor = executor;
            this.completionQueue = new LinkedBlockingQueue<>();
        }

        /**
         * Blocks until every submitted task has completed. On the first failure the remaining
         * tasks are cancelled and the failure is rethrown.
         *
         * @throws RuntimeException wrapping the cause of the first task failure, or the
         *         interruption/cancellation that ended the wait
         */
        public void awaitCompletion() {
            for (int i = 0; i < submittedTasks.size(); i++) {
                try {
                    take().get();
                } catch (Exception e) {
                    // Cancel the remaining tasks
                    for (RunnableFuture<V> f : submittedTasks) {
                        f.cancel(true);
                    }

                    // The waiting thread itself was interrupted: preserve that for callers.
                    if (e instanceof InterruptedException) {
                        Thread.currentThread().interrupt();
                    }

                    // Get the underlying failure. The cause may be an Error (e.g. AssertionError),
                    // so it is kept as a Throwable rather than cast to Exception.
                    Throwable toThrow = e;
                    if (e instanceof ExecutionException && e.getCause() != null) {
                        toThrow = e.getCause();
                    }
                    throw new RuntimeException(toThrow);
                }
            }
        }

        private RunnableFuture<V> newTaskFor(Callable<V> task) {
            return new FutureTask<V>(task);
        }

        private RunnableFuture<V> newTaskFor(Runnable task, V result) {
            return new FutureTask<V>(task, result);
        }

        /**
         * Submits a categorized value-returning task.
         *
         * @throws NullPointerException if {@code task} is null
         */
        public Future<V> submit(CategoryCallable<V> task) {
            if (task == null) throw new NullPointerException();
            RunnableFuture<V> f = newTaskFor(task);
            executor.execute(new CategorizedQueueingFuture(f, task.getCategory()));
            // Recorded only after the executor accepted it, so awaitCompletion never waits for a
            // rejected task.
            submittedTasks.add(f);
            return f;
        }

        /**
         * Submits a categorized task whose future yields {@code result} on success.
         *
         * @throws NullPointerException if {@code task} is null
         */
        public Future<V> submit(CategoryRunnable task, V result) {
            if (task == null) throw new NullPointerException();
            RunnableFuture<V> f = newTaskFor(task, result);
            executor.execute(new CategorizedQueueingFuture(f, task.getCategory()));
            // Recorded only after the executor accepted it (see above).
            submittedTasks.add(f);
            return f;
        }

        /** Submits a categorized task whose future yields {@code null} on success. */
        public Future<V> submit(CategoryRunnable task) {
            return submit(task, null);
        }

        /** Always rejected: tasks without a category cannot be queued. */
        @Override
        public Future<V> submit(Callable<V> task) {
            throw new IllegalArgumentException("Must use a 'CategoryCallable'");
        }

        /** Always rejected: tasks without a category cannot be queued. */
        @Override
        public Future<V> submit(Runnable task, V result) {
            throw new IllegalArgumentException("Must use a 'CategoryRunnable'");
        }

        /** Blocks until a task has completed and returns its future. */
        @Override
        public Future<V> take() throws InterruptedException {
            return completionQueue.take();
        }

        /** Returns the future of a completed task, or {@code null} if none has completed yet. */
        @Override
        public Future<V> poll() {
            return completionQueue.poll();
        }

        /** Like {@link #poll()}, but waits up to the given timeout for a completion. */
        @Override
        public Future<V> poll(long timeout, TimeUnit unit)
                throws InterruptedException {
            return completionQueue.poll(timeout, unit);
        }

        /**
         * FutureTask extension to enqueue upon completion + Category
         */
        public class CategorizedQueueingFuture extends FutureTask<Void> {

            private final Future<V> task;

            private final int category;

            CategorizedQueueingFuture(RunnableFuture<V> task, int category) {
                super(task, null);
                this.task = task;
                this.category = category;
            }

            /** Publishes the wrapped future on the completion queue. */
            @Override
            protected void done() {
                completionQueue.add(task);
            }

            public int getCategory() {
                return category;
            }
        }
    }

    /**
     * A {@link Runnable} tagged with the category (request id) it belongs to. Subclasses put
     * their work in {@code go()}, which may throw checked exceptions; {@code run()} rethrows
     * them wrapped in a {@link RuntimeException}.
     */
    public abstract class CategoryRunnable implements Runnable {

        /** Category assigned at construction; never changes afterwards. */
        private final int category;

        /**
         * @param category category (request id) of this task
         */
        public CategoryRunnable(int category) {
            this.category = category;
        }

        /**
         * @return the category given to the constructor
         */
        public int getCategory() {
            return this.category;
        }

        /**
         * The actual work of the task. The default implementation only logs a warning, so
         * subclasses are expected to override it.
         *
         * @throws Exception any failure of the task
         */
        void go() throws Exception {
            logger.warn("Implement go method !");
        }

        /** Runs {@code go()}, converting any exception into an unchecked one. */
        @Override
        public void run() {
            try {
                this.go();
            } catch (Exception failure) {
                throw new RuntimeException(failure);
            }
        }
    }

    /**
     * A {@link Callable} tagged with the category (request id) it belongs to.
     * Subclasses supply {@code call()}.
     *
     * @param <V> result type of the task
     */
    public abstract class CategoryCallable<V> implements Callable<V> {

        /** Category assigned at construction; never changes afterwards. */
        private final int category;

        /**
         * @param category category (request id) of this task
         */
        public CategoryCallable(int category) {
            this.category = category;
        }

        /**
         * @return the category given to the constructor
         */
        public int getCategory() {
            return this.category;
        }
    }

    /**
     * Unbounded blocking queue that serves its elements round-robin across categories instead of
     * in plain FIFO order: each category (request) keeps its own FIFO sub-queue and consumers
     * receive one element from each waiting category in turn.
     *
     * <p>Elements must be {@code CustomCompletionService.CategorizedQueueingFuture} instances; the
     * category is read from them.
     *
     * <p>Producers and consumers both modify the same structures (the per-category map and the
     * rotation), so one lock guards all of them. Separate put/take locks would let a consumer
     * discard a category's sub-queue at the very moment a producer adds to it, losing the
     * element. Consumers wait on a condition, which releases the lock while they sleep.
     *
     * @param <E> element type (in practice {@code Runnable})
     */
    public class CategoryBlockingQueue<E> extends AbstractQueue<E> implements BlockingQueue<E> {

        /**
         * Lock held by every operation that touches the map or the rotation
         */
        private final ReentrantLock lock = new ReentrantLock();

        /**
         * Wait queue for waiting takes and timed polls
         */
        private final Condition notEmpty = lock.newCondition();

        /**
         * Waiting elements per category; a category is present only while it has elements.
         * Only accessed with the lock held.
         */
        private final Map<Integer, LinkedBlockingQueue<E>> map = new ConcurrentHashMap<>();

        /**
         * Total number of waiting elements; atomic so that size() needs no lock
         */
        private final AtomicInteger count = new AtomicInteger(0);

        /**
         * Rotation of the categories present in the map; the head is served next.
         * Only accessed with the lock held.
         */
        private final LinkedBlockingQueue<Integer> nextCategories = new LinkedBlockingQueue<>();

        /**
         * Reads the category of an element.
         *
         * @throws NullPointerException if {@code element} is null
         * @throws ClassCastException if {@code element} is not a categorized future
         */
        private int categoryOf(Object element) {
            return ((CustomCompletionService.CategorizedQueueingFuture) element).getCategory();
        }

        /**
         * Adds the element to the sub-queue of its category. Never blocks and always succeeds,
         * since the queue is unbounded.
         *
         * @return {@code true}
         * @throws NullPointerException if {@code e} is null
         * @throws ClassCastException if {@code e} is not a categorized future
         */
        @Override
        public boolean offer(E e) {
            Integer category = categoryOf(e);
            lock.lock();
            try {
                LinkedBlockingQueue<E> categoryElements = map.get(category);
                if (categoryElements == null) {
                    categoryElements = new LinkedBlockingQueue<E>();
                    map.put(category, categoryElements);
                    nextCategories.offer(category);
                }
                categoryElements.offer(e);
                count.incrementAndGet();
                notEmpty.signal();
                return true;
            } finally {
                lock.unlock();
            }
        }

        /** Returns the next element in round-robin order, waiting until one is available. */
        @Override
        public E take() throws InterruptedException {
            lock.lockInterruptibly();
            try {
                while (count.get() == 0) {
                    // await() releases the lock while waiting, so producers can still offer.
                    notEmpty.await();
                }
                return dequeue();
            } finally {
                lock.unlock();
            }
        }

        /**
         * Removes and returns the next element in round-robin order, or {@code null} if the queue
         * is empty. Must be called with the lock held.
         */
        private E dequeue() {
            Integer nextCategory = nextCategories.poll();
            if (nextCategory == null) {
                return null;
            }
            LinkedBlockingQueue<E> categoryElements = map.get(nextCategory);
            E e = categoryElements.poll();
            if (categoryElements.isEmpty()) {
                map.remove(nextCategory);
            } else {
                // The category still has work: send it to the back of the rotation.
                nextCategories.offer(nextCategory);
            }
            if (count.decrementAndGet() > 0) {
                // More elements remain: wake another waiting consumer.
                notEmpty.signal();
            }
            return e;
        }

        /**
         * Returns the next element in round-robin order, waiting up to the given time for one to
         * arrive; {@code null} if the time elapses first.
         */
        @Override
        public E poll(long timeout, TimeUnit unit) throws InterruptedException {
            long nanos = unit.toNanos(timeout);
            lock.lockInterruptibly();
            try {
                while (count.get() == 0) {
                    if (nanos <= 0) {
                        return null;
                    }
                    nanos = notEmpty.awaitNanos(nanos);
                }
                return dequeue();
            } finally {
                lock.unlock();
            }
        }

        /**
         * Removes one occurrence of the given element, if present.
         *
         * @return {@code true} if the element was queued and has been removed; {@code false} for
         *         null, for objects that are not categorized futures, and for absent elements
         */
        @Override
        public boolean remove(Object o) {
            if (!(o instanceof CustomCompletionService.CategorizedQueueingFuture)) {
                return false;
            }
            Integer category = categoryOf(o);
            lock.lock();
            try {
                LinkedBlockingQueue<E> categoryElements = map.get(category);
                if (categoryElements == null || !categoryElements.remove(o)) {
                    return false;
                }
                count.decrementAndGet();
                if (categoryElements.isEmpty()) {
                    // Drop the category from BOTH structures so the rotation never points at a
                    // category that no longer has a sub-queue.
                    map.remove(category);
                    nextCategories.remove(category);
                }
                return true;
            } finally {
                lock.unlock();
            }
        }

        /** Returns the next element in round-robin order, or {@code null} if the queue is empty. */
        @Override
        public E poll() {
            lock.lock();
            try {
                return dequeue();
            } finally {
                lock.unlock();
            }
        }

        /** Returns, without removing it, the element the next poll/take would return. */
        @Override
        public E peek() {
            lock.lock();
            try {
                Integer nextCategory = nextCategories.peek();
                return nextCategory == null ? null : map.get(nextCategory).peek();
            } finally {
                lock.unlock();
            }
        }

        /** Equivalent to {@link #offer(Object)}: an unbounded queue never has to wait for space. */
        @Override
        public void put(E e) throws InterruptedException {
            offer(e);
        }

        /** Equivalent to {@link #offer(Object)}: an unbounded queue never has to wait for space. */
        @Override
        public boolean offer(E e, long timeout, TimeUnit unit) throws InterruptedException {
            return offer(e);
        }

        /** Moves all elements, in round-robin order, into {@code c}. */
        @Override
        public int drainTo(Collection<? super E> c) {
            return drainTo(c, Integer.MAX_VALUE);
        }

        /**
         * Moves at most {@code maxElements} elements, in round-robin order, into {@code c}.
         *
         * @return the number of elements transferred
         * @throws NullPointerException if {@code c} is null
         * @throws IllegalArgumentException if {@code c} is this queue
         */
        @Override
        public int drainTo(Collection<? super E> c, int maxElements) {
            if (c == null) {
                throw new NullPointerException();
            }
            if (c == this) {
                throw new IllegalArgumentException();
            }
            lock.lock();
            try {
                int transferred = 0;
                while (transferred < maxElements) {
                    E e = dequeue();
                    if (e == null) {
                        break;
                    }
                    c.add(e);
                    transferred++;
                }
                return transferred;
            } finally {
                lock.unlock();
            }
        }

        /**
         * Returns a read-only iterator over a snapshot of the queue, grouped by category in the
         * current rotation order. Later changes to the queue are not reflected.
         */
        @Override
        public Iterator<E> iterator() {
            List<E> snapshot = new ArrayList<E>();
            lock.lock();
            try {
                for (Integer category : nextCategories) {
                    snapshot.addAll(map.get(category));
                }
            } finally {
                lock.unlock();
            }
            return java.util.Collections.unmodifiableList(snapshot).iterator();
        }

        /** Number of elements currently waiting. */
        @Override
        public int size() {
            return count.get();
        }

        /** The queue is unbounded. */
        @Override
        public int remainingCapacity() {
            return Integer.MAX_VALUE;
        }

    }

}
Comencau
  • 1,084
  • 15
  • 35