🟩 ForkJoinPool
2022年10月10日
- Java
🟩 ForkJoinPool
1. 类注释
- Fork/Join 框架的核心是两个类:`ForkJoinPool` 与 `ForkJoinTask`
`ForkJoinPool`:负责任务的提交、实现工作窃取算法、管理工作线程,并提供任务状态与执行情况等信息。
- 工作窃取算法:
  - 任务被分配给多个工作线程执行,每个工作线程各自对应一条工作队列;当某个线程自己队列中的任务全部执行完后,它会从其他线程工作队列的另一端(`base` 端)窃取任务来执行,从而减少线程空闲
  - 工作队列是由数组实现的双端队列
`ForkJoinTask`:主要提供 `fork()` 与 `join()` 方法。`ForkJoinTask` 有两个抽象子类:`RecursiveTask` 与 `RecursiveAction`。
- 使用时需要继承上述某个抽象子类,并重写其中的 `compute()` 方法,以实现满足业务需求的自定义逻辑
使用
/**
 * Demonstrates the Fork/Join framework: sums the integers 0..999 with a {@link RecursiveTask}
 * running in the common pool and prints it next to a plain sequential sum for comparison.
 */
public class ForkJoinTest {
    public static void main(String[] args) throws Exception {
        // ArrayList rather than LinkedList: CalculateTask reads elements by index, and
        // LinkedList.get(i) is O(n), which would make the whole computation quadratic.
        List<Integer> list = new java.util.ArrayList<>();
        int sum = 0;
        for (int i = 0; i < 1000; i++) { // fill the list and compute the expected sum
            list.add(i);
            sum += i;
        }
        CalculateTask task = new CalculateTask(0, list.size(), list); // task covering the whole list
        Future<Integer> future = ForkJoinPool.commonPool().submit(task); // run it in the common pool
        System.out.println("sum=" + sum + ",Fork/Join result=" + future.get()); // print both results
    }

    /**
     * Sums the elements of {@code list} in the half-open index range [start, end).
     * Ranges of at least {@link #THRESHOLD} elements are split in half and processed in parallel.
     * The list should support fast indexed access (e.g. ArrayList); it is only read, never modified.
     */
    static class CalculateTask extends RecursiveTask<Integer> {
        /** Ranges shorter than this are summed sequentially instead of being split further. */
        private static final int THRESHOLD = 200;

        private final Integer start;
        private final Integer end;
        private final List<Integer> list; // shared, read-only source of the values to sum

        /**
         * @param start first index to include
         * @param end   index one past the last index to include
         * @param list  values to sum
         */
        public CalculateTask(Integer start, Integer end, List<Integer> list) {
            this.start = start;
            this.end = end;
            this.list = list;
        }

        /** @return the sum of list[start..end); 0 for an empty (or inverted) range */
        @Override
        protected Integer compute() {
            int lo = start;
            int hi = end;
            if (hi - lo < THRESHOLD) {
                int sum = 0;
                for (int i = lo; i < hi; i++) {
                    sum += list.get(i);
                }
                return sum;
            }
            // Divide and conquer: split the range into two halves.
            int middle = (lo + hi) >>> 1;
            CalculateTask left = new CalculateTask(lo, middle, list);
            CalculateTask right = new CalculateTask(middle, hi, list);
            left.fork(); // schedule the left half asynchronously
            // Compute the right half in the current thread instead of forking it as well:
            // forking both halves would leave this thread idle while it waits in join().
            int rightSum = right.compute();
            return left.join() + rightSum;
        }
    }
}
2. 类图
// Abbreviated excerpt: only the class declaration is shown. The fields used by the
// constructors below (workerNamePrefix, factory, ueh, config, ctl, ...) are elided.
// @sun.misc.Contended is the JDK-internal annotation used to guard against false sharing.
@sun.misc.Contended
public class ForkJoinPool extends AbstractExecutorService {
}
3. 构造函数
// Internal constructor; the public constructors below delegate to it (directly or indirectly).
// It performs no argument validation itself — that is done by the callers.
private ForkJoinPool(int parallelism, // target parallelism level
ForkJoinWorkerThreadFactory factory, // factory used to create worker threads
UncaughtExceptionHandler handler, // handler for exceptions escaping worker threads (NOT a rejection policy); may be null
int mode, // queue mode: FIFO_QUEUE (asyncMode) or LIFO_QUEUE
String workerNamePrefix) { // prefix for worker thread names
this.workerNamePrefix = workerNamePrefix;
this.factory = factory;
this.ueh = handler;
// Pack the (SMASK-masked) parallelism and the queue mode into a single int.
this.config = (parallelism & SMASK) | mode;
long np = (long)(-parallelism); // offset ctl counts
// Both packed count fields of ctl (AC and TC) start at -parallelism, i.e. the counts are stored
// relative to the target parallelism. Presumably AC/TC are the active/total worker counts — confirm
// against the full JDK source.
this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}
// Creates a pool with the given parallelism, worker-thread factory, uncaught-exception handler
// and queue mode. asyncMode == true selects FIFO_QUEUE, false selects LIFO_QUEUE.
// Worker threads are named with the prefix "ForkJoinPool-<poolId>-worker-".
public ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
boolean asyncMode) {
this(checkParallelism(parallelism), // validate the parallelism value (helper not shown)
checkFactory(factory), // presumably throws NullPointerException if factory == null — helper not shown
handler,
asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
"ForkJoinPool-" + nextPoolId() + "-worker-");
checkPermission(); // security permission check (helper not shown)
}
// Creates a pool with the given parallelism, the default worker-thread factory,
// no uncaught-exception handler and asyncMode = false (LIFO_QUEUE).
public ForkJoinPool(int parallelism) {
this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
}
// Creates a pool whose parallelism is min(MAX_CAP, number of available processors), using the
// default worker-thread factory, no uncaught-exception handler and asyncMode = false (LIFO_QUEUE).
public ForkJoinPool() {
this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
defaultForkJoinWorkerThreadFactory,
null,
false);
}
/**
 * Returns the common pool instance.
 * This pool is statically constructed; its run state is unaffected by attempts to
 * {@link #shutdown} or {@link #shutdownNow}. However, this pool and any ongoing processing
 * are automatically terminated upon program {@link System#exit}. Any program that relies on
 * asynchronous task processing to complete before program termination should invoke
 * {@code commonPool().}{@link #awaitQuiescence awaitQuiescence} before exit.
 */
public static ForkJoinPool commonPool() {
return common;
}
4. 内部类
// Factory through which a pool obtains its worker threads (see the factory parameter of the constructors).
public static interface ForkJoinWorkerThreadFactory {
/**
* Returns a new worker thread operating in the given pool.
*
* @param pool the pool this thread will work in
* @return the new worker thread
*/
public ForkJoinWorkerThread newThread(ForkJoinPool pool);
}
// Factory that simply constructs a plain ForkJoinWorkerThread for the given pool.
// Presumably this is the type behind defaultForkJoinWorkerThreadFactory used by the constructors — confirm.
static final class DefaultForkJoinWorkerThreadFactory implements ForkJoinWorkerThreadFactory {
public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
return new ForkJoinWorkerThread(pool);
}
}
// Empty placeholder task: it is created already completed (status = NORMAL), has no result,
// and exec() does no work and reports completion.
static final class EmptyTask extends ForkJoinTask<Void> {
private static final long serialVersionUID = -7721805057305804111L;
EmptyTask() { status = ForkJoinTask.NORMAL; } // force done
public final Void getRawResult() { return null; }
public final void setRawResult(Void x) {}
public final boolean exec() { return true; }
}
// Queue supporting both work stealing and external task submission: an array-backed
// double-ended queue, which makes work stealing convenient. Fields only; methods are elided.
@sun.misc.Contended // guards against false sharing
static final class WorkQueue {
volatile int scanState; // versioned, <0: inactive; odd:scanning
int stackPred; // pool stack (ctl) predecessor
int nsteals; // number of steals
int hint; // randomization and stealer index hint
int config; // pool index and mode
volatile int qlock; // 1: locked, < 0: terminate; else 0
volatile int base; // index used by poll operations
int top; // index used by push operations
ForkJoinTask<?>[] array; // task array backing this queue
final ForkJoinPool pool; // the containing pool (may be null)
final ForkJoinWorkerThread owner; // owning thread or null if shared
volatile Thread parker; // == owner during call to park; else null
volatile ForkJoinTask<?> currentJoin; // task currently being joined
volatile ForkJoinTask<?> currentSteal; // task currently stolen (being run after a steal)
}
5. 常用方法
// Performs the given task and returns its result upon completion.
// Throws NullPointerException if task is null; an abnormal completion of the task is
// reported by task.join().
public <T> T invoke(ForkJoinTask<T> task) {
if (task == null)
throw new NullPointerException();
externalPush(task); // submit the task into a work queue from outside the pool
return task.join();
}
// Arranges for (asynchronous) execution of the given task; does not wait for it.
// Throws NullPointerException if task is null.
public void execute(ForkJoinTask<?> task) {
if (task == null)
throw new NullPointerException();
externalPush(task);
}
// Arranges for (asynchronous) execution of the given Runnable; does not wait for it.
// A Runnable that is already a ForkJoinTask is pushed as-is; any other Runnable is wrapped
// in a ForkJoinTask.RunnableExecuteAction. Throws NullPointerException if task is null.
public void execute(Runnable task) {
if (task == null)
throw new NullPointerException();
ForkJoinTask<?> job;
if (task instanceof ForkJoinTask<?>) // avoid re-wrap
job = (ForkJoinTask<?>) task;
else
job = new ForkJoinTask.RunnableExecuteAction(task);
externalPush(job);
}
// Submits the given task for execution and returns that same task, which also serves as its Future.
// Throws NullPointerException if task is null.
public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
if (task == null)
throw new NullPointerException();
externalPush(task);
return task;
}
// Wraps the Callable in a ForkJoinTask.AdaptedCallable, submits it, and returns the wrapper as a Future.
// NOTE(review): no explicit null check here — presumably the AdaptedCallable constructor rejects null; confirm.
public <T> ForkJoinTask<T> submit(Callable<T> task) {
ForkJoinTask<T> job = new ForkJoinTask.AdaptedCallable<T>(task);
externalPush(job);
return job;
}
// Wraps the Runnable in a ForkJoinTask.AdaptedRunnable that yields the given result, submits it,
// and returns the wrapper as a Future.
// NOTE(review): no explicit null check here — presumably the AdaptedRunnable constructor rejects null; confirm.
public <T> ForkJoinTask<T> submit(Runnable task, T result) {
ForkJoinTask<T> job = new ForkJoinTask.AdaptedRunnable<T>(task, result);
externalPush(job);
return job;
}
// Submits the given Runnable and returns a ForkJoinTask that serves as its Future.
// A Runnable that is already a ForkJoinTask is submitted and returned as-is; any other Runnable is
// wrapped in a ForkJoinTask.AdaptedRunnableAction. Throws NullPointerException if task is null.
public ForkJoinTask<?> submit(Runnable task) {
if (task == null)
throw new NullPointerException();
ForkJoinTask<?> job;
if (task instanceof ForkJoinTask<?>) // avoid re-wrap
job = (ForkJoinTask<?>) task;
else
job = new ForkJoinTask.AdaptedRunnableAction(task);
externalPush(job);
return job;
}
// Submits every Callable for execution and waits until all of them have completed, then returns
// their Futures in iteration order. (It does not "wake" anything: each Callable is wrapped in a
// ForkJoinTask.AdaptedCallable and pushed into the pool.)
// If the method exits abnormally (an exception while wrapping, submitting or waiting), every Future
// created so far is cancelled with cancel(false) in the finally block.
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) {
// In previous versions of this class, this method constructed
// a task to run ForkJoinTask.invokeAll, but now external
// invocation of multiple tasks is at least as efficient.
ArrayList<Future<T>> futures = new ArrayList<>(tasks.size());
boolean done = false;
try {
for (Callable<T> t : tasks) {
ForkJoinTask<T> f = new ForkJoinTask.AdaptedCallable<T>(t);
futures.add(f);
externalPush(f);
}
// Wait for each task via quietlyJoin(); per the ForkJoinTask docs it does not throw the task's
// exception, so individual failures are observed later through Future.get().
for (int i = 0, size = futures.size(); i < size; i++)
((ForkJoinTask<?>)futures.get(i)).quietlyJoin();
done = true;
return futures;
} finally {
if (!done)
for (int i = 0, size = futures.size(); i < size; i++)
futures.get(i).cancel(false);
}
}
6. 任务类
类图
- ForkJoinTask 为抽象基类
- RecursiveTask & RecursiveAction 均为 ForkJoinTask 的抽象子类
- 我们要使用时,需要继承 RecursiveTask | RecursiveAction 作为 实际任务类
- RecursiveTask & RecursiveAction 的区别在于:
- RecursiveTask 用于计算有返回结果的任务
- RecursiveAction 用于计算没有返回结果的任务
1) ForkJoinTask
/**
 * Abbreviated excerpt of ForkJoinTask showing only fork(), join(), invoke() and the
 * Collection overload of invokeAll(); all other members are elided.
 */
public abstract class ForkJoinTask<V> implements Future<V>, Serializable {
// Arranges asynchronous execution of this task and returns this task.
// Called from a ForkJoinWorkerThread, the task is pushed onto that worker's own work queue;
// called from any other thread, it is pushed into the common pool.
public final ForkJoinTask<V> fork() {
Thread t;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this);
else
ForkJoinPool.common.externalPush(this);
return this;
}
// Returns the result of the computation when it is done.
// If the completion status is not NORMAL, reportException(s) is called — presumably it throws the
// task's exception (or a cancellation exception); confirm against the full JDK source.
public final V join() {
int s;
if ((s = doJoin() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}
// Commences performing this task, awaits its completion if necessary, and returns its result.
// Abnormal completion is reported the same way as in join().
public final V invoke() {
int s;
if ((s = doInvoke() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}
// Forks all given tasks, runs the first one in the calling thread, waits for the rest, and returns tasks.
// - A collection that is not a RandomAccess List is copied to an array and handled by the array
//   overload of invokeAll (not shown here).
// - Tasks at indices last..1 are forked (in that order); the task at index 0 is always run directly
//   via doInvoke() after the others have been forked.
// - A null element is recorded as a NullPointerException; only the first exception is kept.
// - Once an exception has been recorded, the remaining tasks (indices 1..last) are cancelled with
//   cancel(false) instead of being joined, and the first exception is rethrown at the end.
// - A status below NORMAL is treated as abnormal completion (presumably cancelled or exceptional).
public static <T extends ForkJoinTask<?>> Collection<T> invokeAll(Collection<T> tasks) {
if (!(tasks instanceof RandomAccess) || !(tasks instanceof List<?>)) {
invokeAll(tasks.toArray(new ForkJoinTask<?>[tasks.size()]));
return tasks;
}
@SuppressWarnings("unchecked")
List<? extends ForkJoinTask<?>> ts = (List<? extends ForkJoinTask<?>>) tasks;
Throwable ex = null;
int last = ts.size() - 1;
for (int i = last; i >= 0; --i) {
ForkJoinTask<?> t = ts.get(i);
if (t == null) {
if (ex == null)
ex = new NullPointerException();
}
else if (i != 0)
t.fork();
else if (t.doInvoke() < NORMAL && ex == null)
ex = t.getException();
}
for (int i = 1; i <= last; ++i) {
ForkJoinTask<?> t = ts.get(i);
if (t != null) {
if (ex != null)
t.cancel(false);
else if (t.doJoin() < NORMAL)
ex = t.getException();
}
}
if (ex != null)
rethrow(ex);
return tasks;
}
}
2) RecursiveTask
/**
 * A recursive, result-bearing {@link ForkJoinTask}.
 * Subclasses put their work in {@link #compute()}; the value it returns becomes the task's result.
 *
 * @param <V> the type of the result
 */
public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
    private static final long serialVersionUID = 5232453952276485270L;

    /** Value produced by compute() during exec(), or set explicitly via setRawResult(). */
    V result;

    /**
     * The main computation performed by this task; this is the method to override.
     *
     * @return the result of the computation
     */
    protected abstract V compute();

    /**
     * Runs {@link #compute()} and stores its value as the result.
     * Always reports normal completion; an exception from compute() propagates and leaves the
     * stored result untouched.
     */
    @Override
    protected final boolean exec() {
        V computed = compute();
        result = computed;
        return true;
    }

    /** @return the stored result (null until one has been computed or set) */
    @Override
    public final V getRawResult() {
        return result;
    }

    /** Overwrites the stored result. */
    @Override
    protected final void setRawResult(V value) {
        result = value;
    }
}
3) RecursiveAction
/**
 * A recursive, resultless {@link ForkJoinTask}.
 * Subclasses put their work in {@link #compute()}; the task's result is always {@code null}.
 */
public abstract class RecursiveAction extends ForkJoinTask<Void> {
    private static final long serialVersionUID = 5232453952276485070L;

    /** The main computation performed by this task; this is the method to override. */
    protected abstract void compute();

    /** Runs {@link #compute()} and reports normal completion (exceptions propagate). */
    @Override
    protected final boolean exec() {
        compute();
        return true;
    }

    /** @return always {@code null}: an action has no result */
    @Override
    public final Void getRawResult() {
        return null;
    }

    /** Ignores its argument; the only meaningful value is {@code null}. */
    @Override
    protected final void setRawResult(Void mustBeNull) {
        // intentionally empty: an action carries no result to store
    }
}