🟩 ForkJoinPool

吞佛童子2022年10月10日
  • Java
  • Concurrency
大约 5 分钟

🟩 ForkJoinPool

1. 类注释

  1. Fork/Join框架的核心是两个类:ForkJoinPool & ForkJoinTask
  2. ForkJoinPool 负责:接收任务的提交、实现工作窃取算法、管理工作线程,以及提供任务的状态和执行信息
    • 工作窃取算法
      • 将大任务拆分成多个子任务,交给多个工作线程执行;每个工作线程对应一条自己的工作队列,新任务从队列的 top 端压入。当一条线程自己队列中的任务全部执行完成时,会从其他线程工作队列的另一端(base 端)窃取任务来执行
      • 工作队列为 数组构成的双端队列
  3. ForkJoinTask 主要提供 fork() & join() 方法
    • ForkJoinTask 最常用的两个抽象子类是 RecursiveTask & RecursiveAction(JDK 8 中还有 CountedCompleter)
    • 我们需要继承抽象类,然后重写里面的 compute() 方法,实现自定义逻辑,满足业务需求

使用

/**
 * Demo: sums the integers 0..999 both sequentially and with Fork/Join on the common pool,
 * then prints the two results side by side so they can be compared.
 */
public class ForkJoinTest {

    public static void main(String[] args) throws Exception {
        // ArrayList, not LinkedList: CalculateTask reads elements by index, and LinkedList.get(i)
        // walks the list from one end, which would make every leaf scan quadratic.
        List<Integer> list = new java.util.ArrayList<>();
        int sum = 0; // expected result, computed sequentially (primitive: no boxing per step)
        for (int i = 0; i < 1000; i++) {
            list.add(i);
            sum += i;
        }

        CalculateTask task = new CalculateTask(0, list.size(), list); // root task covers [0, size)
        Future<Integer> future = ForkJoinPool.commonPool().submit(task); // run on the common pool
        System.out.println("sum=" + sum + ",Fork/Join result=" + future.get()); // compare both sums
    }

    /**
     * Sums {@code list} over the half-open index range [start, end) by divide and conquer.
     * Ranges with fewer than {@link #THRESHOLD} elements are summed directly; larger ranges are
     * split in half and the halves are summed in parallel.
     *
     * <p>{@code list} should support fast random access (e.g. ArrayList), because elements are
     * read with {@code list.get(i)}.
     */
    @Data
    static class CalculateTask extends RecursiveTask<Integer> {
        /** Ranges smaller than this are summed sequentially instead of being split further. */
        private static final int THRESHOLD = 200;

        private Integer start;       // inclusive lower index
        private Integer end;         // exclusive upper index
        private List<Integer> list;  // source values; only read, never modified

        public CalculateTask(Integer start, Integer end, List<Integer> list) {
            this.start = start;
            this.end = end;
            this.list = list;
        }

        @Override
        protected Integer compute() {
            if (end - start < THRESHOLD) {
                int sum = 0;
                for (int i = start; i < end; i++) {
                    sum += list.get(i);
                }
                return sum;
            }
            // Unsigned shift keeps the midpoint correct even if start + end overflows int.
            int middle = (start + end) >>> 1;
            CalculateTask left = new CalculateTask(start, middle, list);
            CalculateTask right = new CalculateTask(middle, end, list);
            // Fork only one half and compute the other directly in this thread (the pattern the
            // RecursiveTask Javadoc uses): this saves a queue push and keeps the current worker
            // doing useful work instead of immediately blocking in join().
            left.fork();
            int rightSum = right.compute();
            return left.join() + rightSum;
        }
    }
}

2. 类图

// Declaration of ForkJoinPool only; the class body is not reproduced in this excerpt.
// ForkJoinPool is an executor service built on AbstractExecutorService. @sun.misc.Contended is
// the JDK 8 internal annotation used to pad the annotated class against false sharing.
@sun.misc.Contended
public class ForkJoinPool extends AbstractExecutorService {
}

![ForkJoinPool 类图](img.png)


3. 构造函数

    /**
     * Core constructor that the public constructors delegate to. It stores its arguments and
     * initializes the packed {@code config} and {@code ctl} words; it performs no argument
     * validation itself (the public 4-arg constructor validates parallelism and factory first).
     *
     * @param parallelism      target parallelism level
     * @param factory          factory used to create worker threads
     * @param handler          uncaught-exception handler for worker threads, stored in
     *                         {@code ueh}; may be null. NOTE: this is not a task-rejection policy.
     * @param mode             queue-mode bits: LIFO_QUEUE (default) or FIFO_QUEUE (asyncMode)
     * @param workerNamePrefix prefix used when naming worker threads
     */
    private ForkJoinPool(int parallelism, // target parallelism
                         ForkJoinWorkerThreadFactory factory, // worker-thread factory
                         UncaughtExceptionHandler handler, // uncaught-exception handler for workers (not a rejection policy)
                         int mode, // LIFO_QUEUE or FIFO_QUEUE
                         String workerNamePrefix) { // worker thread name prefix
        this.workerNamePrefix = workerNamePrefix;
        this.factory = factory;
        this.ueh = handler;
        // Pack parallelism (masked by SMASK) and the queue-mode bits into one int.
        this.config = (parallelism & SMASK) | mode;
        long np = (long)(-parallelism); // offset ctl counts
        // Seed the AC and TC fields of ctl with -parallelism. Presumably AC/TC are the active and
        // total worker counts stored relative to the target parallelism; confirm against the
        // ctl field documentation in the full JDK source.
        this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    }
    
    /**
     * Creates a pool with the given parallelism, worker-thread factory, uncaught-exception
     * handler and queue mode. {@code asyncMode == true} selects FIFO_QUEUE, otherwise
     * LIFO_QUEUE. Worker names are prefixed with "ForkJoinPool-N-worker-", where N comes from
     * nextPoolId().
     *
     * <p>NOTE: checkPermission() runs only after the delegated constructor has completed.
     */
    public ForkJoinPool(int parallelism,
                        ForkJoinWorkerThreadFactory factory,
                        UncaughtExceptionHandler handler,
                        boolean asyncMode) {
        this(checkParallelism(parallelism), // validate the parallelism value
             checkFactory(factory), // throws if factory == null
             handler,
             asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
             "ForkJoinPool-" + nextPoolId() + "-worker-");
        checkPermission(); // security permission check
    }
    
    /**
     * Creates a pool with the given parallelism, the default worker-thread factory, no
     * uncaught-exception handler, and asyncMode == false (LIFO queues).
     */
    public ForkJoinPool(int parallelism) {
        this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
    }
    
    /**
     * Creates a pool whose parallelism is the number of available processors (capped at
     * MAX_CAP), with the default worker-thread factory, no uncaught-exception handler, and
     * asyncMode == false (LIFO queues).
     */
    public ForkJoinPool() {
        this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()), 
             defaultForkJoinWorkerThreadFactory, 
             null, 
             false);
    }
    
    /**
     * Returns the common pool instance.
     *
     * <p>This pool is statically constructed; its run state is unaffected by attempts to
     * {@link #shutdown} or {@link #shutdownNow}. However, this pool and any ongoing processing
     * are automatically terminated upon program {@link System#exit}. Any program that relies on
     * asynchronous task processing to complete before program termination should invoke
     * {@code commonPool().}{@link #awaitQuiescence awaitQuiescence} before exit.
     *
     * @return the common pool instance
     */
    public static ForkJoinPool commonPool() {
        return common;
    }

4. 内部类

    /**
     * Factory through which a ForkJoinPool obtains its worker threads. An implementation is
     * passed to the pool's constructor.
     */
    public static interface ForkJoinWorkerThreadFactory {
        /**
         * Returns a new worker thread operating in the given pool.
         *
         * @param pool the pool the new worker thread will belong to
         * @return the new worker thread
         */
        public ForkJoinWorkerThread newThread(ForkJoinPool pool);
    }
    
    /**
     * Default thread factory: creates a plain ForkJoinWorkerThread bound to the requesting pool.
     * (Presumably this is the type behind the defaultForkJoinWorkerThreadFactory instance used
     * by the convenience constructors; confirm in the full source.)
     */
    static final class DefaultForkJoinWorkerThreadFactory implements ForkJoinWorkerThreadFactory {
        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
            return new ForkJoinWorkerThread(pool);
        }
    }
    
    // Placeholder task that is already complete when constructed: the constructor sets status
    // to ForkJoinTask.NORMAL, exec() does no work and reports completion, and it has no result.
    static final class EmptyTask extends ForkJoinTask<Void> {
        private static final long serialVersionUID = -7721805057305804111L;
        EmptyTask() { status = ForkJoinTask.NORMAL; } // force done
        public final Void getRawResult() { return null; }
        public final void setRawResult(Void x) {}
        public final boolean exec() { return true; }
    }
    
    // Queue supporting both work stealing and external task submission: an array-backed
    // double-ended queue (see `array`), which makes work stealing cheap.
    // Only the fields are shown in this excerpt.
    @sun.misc.Contended // padded to prevent false sharing
    static final class WorkQueue {
        volatile int scanState;    // versioned, <0: inactive; odd:scanning
        int stackPred;             // pool stack (ctl) predecessor
        int nsteals;               // number of steals
        int hint;                  // randomization and stealer index hint
        int config;                // pool index and mode
        volatile int qlock;        // 1: locked, < 0: terminate; else 0
        volatile int base;         // index used by poll operations
        int top;                   // index used by push operations
        ForkJoinTask<?>[] array;   // task array backing this work queue
        final ForkJoinPool pool;   // the containing pool (may be null)
        final ForkJoinWorkerThread owner; // owning thread or null if shared
        volatile Thread parker;    // == owner during call to park; else null
        volatile ForkJoinTask<?> currentJoin;  // task currently being joined
        volatile ForkJoinTask<?> currentSteal; // task currently stolen (being run)
    }

5. 常用方法

    /**
     * Performs the given task, returning its result upon completion.
     *
     * <p>The task is handed to this pool via externalPush, then the caller waits in
     * {@code task.join()}, which returns the result or reports an abnormal completion.
     *
     * @param task the task to run
     * @return the task's result
     * @throws NullPointerException if {@code task} is {@code null}
     */
    public <T> T invoke(ForkJoinTask<T> task) {
        if (task != null) {
            externalPush(task); // enqueue into this pool
            return task.join();
        }
        throw new NullPointerException();
    }
    
    /**
     * Arranges for (asynchronous) execution of the given task; does not wait for it to finish.
     *
     * @param task the task to hand to the pool
     * @throws NullPointerException if {@code task} is {@code null}
     */
    public void execute(ForkJoinTask<?> task) {
        if (task != null) {
            externalPush(task);
            return;
        }
        throw new NullPointerException();
    }
    
    /**
     * Arranges for (asynchronous) execution of the given command.
     *
     * <p>A Runnable that already is a ForkJoinTask is pushed as-is; any other Runnable is first
     * wrapped in a ForkJoinTask.RunnableExecuteAction.
     *
     * @param task the command to run
     * @throws NullPointerException if {@code task} is {@code null}
     */
    public void execute(Runnable task) {
        if (task == null)
            throw new NullPointerException();
        // Reuse an existing ForkJoinTask rather than wrapping it a second time.
        ForkJoinTask<?> job = (task instanceof ForkJoinTask<?>)
            ? (ForkJoinTask<?>) task
            : new ForkJoinTask.RunnableExecuteAction(task);
        externalPush(job);
    }
    
    /**
     * Submits a ForkJoinTask for (asynchronous) execution and returns that same task as the
     * handle for its result. Does not wait for the task to complete.
     *
     * @param task the task to submit
     * @return {@code task} itself
     * @throws NullPointerException if {@code task} is {@code null}
     */
    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        if (task != null) {
            externalPush(task);
            return task;
        }
        throw new NullPointerException();
    }
    /**
     * Wraps the Callable in a ForkJoinTask.AdaptedCallable, submits the wrapper, and returns it.
     *
     * <p>NOTE(review): unlike the ForkJoinTask/Runnable overloads there is no explicit null
     * check here; presumably the AdaptedCallable constructor rejects null. Confirm this.
     *
     * @param task the computation to run
     * @return the submitted wrapper task
     */
    public <T> ForkJoinTask<T> submit(Callable<T> task) {
        final ForkJoinTask<T> adapted = new ForkJoinTask.AdaptedCallable<T>(task);
        externalPush(adapted);
        return adapted;
    }
    /**
     * Wraps the Runnable together with the value to report on completion in a
     * ForkJoinTask.AdaptedRunnable, submits the wrapper, and returns it.
     *
     * <p>NOTE(review): there is no explicit null check here; presumably the AdaptedRunnable
     * constructor rejects a null Runnable. Confirm this.
     *
     * @param task   the action to run
     * @param result the value the returned task yields when the action completes
     * @return the submitted wrapper task
     */
    public <T> ForkJoinTask<T> submit(Runnable task, T result) {
        final ForkJoinTask<T> adapted = new ForkJoinTask.AdaptedRunnable<T>(task, result);
        externalPush(adapted);
        return adapted;
    }
    /**
     * Submits a Runnable for execution and returns a ForkJoinTask handle for it.
     *
     * <p>A Runnable that already is a ForkJoinTask is submitted and returned as-is; any other
     * Runnable is wrapped in a ForkJoinTask.AdaptedRunnableAction first.
     *
     * @param task the action to run
     * @return the submitted task (the argument itself, or its wrapper)
     * @throws NullPointerException if {@code task} is {@code null}
     */
    public ForkJoinTask<?> submit(Runnable task) {
        if (task == null)
            throw new NullPointerException();
        // Reuse an existing ForkJoinTask rather than wrapping it a second time.
        ForkJoinTask<?> job = (task instanceof ForkJoinTask<?>)
            ? (ForkJoinTask<?>) task
            : new ForkJoinTask.AdaptedRunnableAction(task);
        externalPush(job);
        return job;
    }
    
    /**
     * Submits every Callable in {@code tasks} to this pool, waits until all of them have
     * completed, and returns their futures in iteration order.
     *
     * <p>If this method exits abnormally before every task has been submitted and waited for,
     * all futures created so far are cancelled (without interrupting) in the finally block.
     *
     * @param tasks the computations to run
     * @return one future per task, in iteration order
     */
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) {
        // Earlier versions built a task that ran ForkJoinTask.invokeAll; pushing each task
        // externally is at least as efficient.
        ArrayList<Future<T>> futures = new ArrayList<>(tasks.size());
        boolean allJoined = false;
        try {
            for (Callable<T> callable : tasks) {
                ForkJoinTask<T> adapted = new ForkJoinTask.AdaptedCallable<T>(callable);
                futures.add(adapted);
                externalPush(adapted);
            }
            for (Future<T> future : futures) {
                // quietlyJoin rather than join: presumably so that one failed task does not stop
                // the wait for the others; each outcome stays available through its Future.
                ((ForkJoinTask<?>) future).quietlyJoin();
            }
            allJoined = true;
            return futures;
        } finally {
            if (!allJoined) {
                for (Future<T> future : futures) {
                    future.cancel(false);
                }
            }
        }
    }

6. 任务类

类图

![ForkJoinTask 类图](img_1.png)

  • ForkJoinTask 为抽象基类
  • RecursiveTask & RecursiveAction 均为 ForkJoinTask 的抽象子类
  • 我们要使用时,需要继承 RecursiveTask | RecursiveAction 作为 实际任务类
  • RecursiveTask & RecursiveAction 的区别在于:
    • RecursiveTask 用于计算有返回结果的任务
    • RecursiveAction 用于计算没有返回结果的任务

1) ForkJoinTask

/**
 * Abstract base class for tasks that run inside a ForkJoinPool (only the methods discussed in
 * this article are reproduced). A task is scheduled with {@link #fork()}, awaited with
 * {@link #join()}, or run by the caller with {@link #invoke()}.
 *
 * @param <V> the result type returned by join() / invoke()
 */
public abstract class ForkJoinTask<V> implements Future<V>, Serializable {
    /**
     * Arranges to execute this task asynchronously. If the calling thread is a
     * ForkJoinWorkerThread, the task is pushed onto that worker's own work queue; otherwise it
     * is handed to the common pool via externalPush.
     *
     * @return this task, to allow call chaining
     */
    public final ForkJoinTask<V> fork() {
        Thread t;
        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);
        else
            ForkJoinPool.common.externalPush(this);
        return this;
    }
    
    /**
     * Returns the result of the computation once it is done, waiting via doJoin() if needed.
     * If the completion status (masked with DONE_MASK) is anything other than NORMAL,
     * reportException(s) is called to report the abnormal completion.
     *
     * @return the computed result
     */
    public final V join() {
        int s;
        if ((s = doJoin() & DONE_MASK) != NORMAL)
            reportException(s);
        return getRawResult();
    }
    
    /**
     * Commences performing this task via doInvoke(), awaits its completion if necessary, and
     * returns its result. Abnormal completion is reported through reportException exactly as
     * in join().
     *
     * @return the computed result
     */
    public final V invoke() {
        int s;
        if ((s = doInvoke() & DONE_MASK) != NORMAL)
            reportException(s);
        return getRawResult();
    }
    
    /**
     * Runs all tasks in the collection and waits for them, rethrowing the first failure recorded.
     *
     * <p>A collection that is not a RandomAccess List is copied to an array and handled by the
     * array overload. For a RandomAccess List:
     * <ul>
     *   <li>tasks at indices last..1 are forked, in reverse order;</li>
     *   <li>the task at index 0 is run in the calling thread via doInvoke();</li>
     *   <li>tasks 1..last are then joined in order.</li>
     * </ul>
     * A null element records a NullPointerException. A status below NORMAL is treated as
     * abnormal completion, and that task's exception is recorded. Once a failure is recorded,
     * the remaining tasks in the join loop are cancelled instead of joined. The recorded
     * failure is rethrown at the end.
     *
     * @param tasks the tasks to run
     * @return {@code tasks} itself
     */
    public static <T extends ForkJoinTask<?>> Collection<T> invokeAll(Collection<T> tasks) {
        if (!(tasks instanceof RandomAccess) || !(tasks instanceof List<?>)) {
            invokeAll(tasks.toArray(new ForkJoinTask<?>[tasks.size()]));
            return tasks;
        }
        @SuppressWarnings("unchecked")
        List<? extends ForkJoinTask<?>> ts = (List<? extends ForkJoinTask<?>>) tasks;
        Throwable ex = null;
        int last = ts.size() - 1;
        // Fork tasks last..1; run task 0 in the current thread.
        for (int i = last; i >= 0; --i) {
            ForkJoinTask<?> t = ts.get(i);
            if (t == null) {
                if (ex == null)
                    ex = new NullPointerException();
            }
            else if (i != 0)
                t.fork();
            else if (t.doInvoke() < NORMAL && ex == null)
                ex = t.getException();
        }
        // Join tasks 1..last, or cancel them once a failure has been recorded.
        for (int i = 1; i <= last; ++i) {
            ForkJoinTask<?> t = ts.get(i);
            if (t != null) {
                if (ex != null)
                    t.cancel(false);
                else if (t.doJoin() < NORMAL)
                    ex = t.getException();
            }
        }
        if (ex != null)
            rethrow(ex);
        return tasks;
    }
}

2) RecursiveTask

/**
 * A {@link ForkJoinTask} whose work, supplied by overriding {@link #compute()}, yields a value.
 *
 * <p>{@link #exec()} stores the value returned by {@code compute()} in {@link #result}. That
 * stored value is what {@link #getRawResult()} hands back, e.g. from {@code join()} or
 * {@code invoke()}.
 *
 * @param <V> the type of the computed result
 */
public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
    private static final long serialVersionUID = 5232453952276485270L;

    /** Value produced by the most recent {@link #compute()} call (or set via setRawResult). */
    V result;

    /**
     * The work performed by this task; subclasses override this with their computation.
     *
     * @return the result of the computation
     */
    protected abstract V compute();

    /** Returns the stored result. */
    public final V getRawResult() {
        return this.result;
    }

    /** Replaces the stored result with {@code newResult}. */
    protected final void setRawResult(V newResult) {
        this.result = newResult;
    }

    /** Runs {@link #compute()}, records its value, and always reports normal completion. */
    protected final boolean exec() {
        final V computed = compute();
        this.result = computed;
        return true;
    }
}

3) RecursiveAction

public abstract class RecursiveAction extends ForkJoinTask<Void> {
    private static final long serialVersionUID = 5232453952276485070L;

    protected abstract void compute(); // 要重写的方法

    /**
     * 永远为 Null
     */
    public final Void getRawResult() { return null; }

    /**
     * 输入参数只能为 null
     */
    protected final void setRawResult(Void mustBeNull) { }

    /**
     * Implements execution conventions for RecursiveActions.
     */
    protected final boolean exec() {
        compute();
        return true;
    }
}
上次编辑于: 2022/10/10 下午8:43:48
贡献者: liuxianzhishou