💠 ReentrantReadWriteLock
2022年10月10日
- Java
💠 ReentrantReadWriteLock
1. 类注释
- 将 AQS 的 state 分为 读锁[高 16 位] & 写锁[低 16 位] 两部分
- 读锁
- 在 无写锁 || 当前线程持有写锁 的情况下,可尝试获锁
- 共享锁
- 写锁
- 只有在 没有任何线程持有读锁(包括当前线程自己,即不支持读锁升级为写锁)且 没有其他线程持有写锁 的情况下,才可尝试获锁
- 独占 & 可重入 锁
- 读锁
- 简单使用
/**
 * Example dictionary guarded by a {@link ReentrantReadWriteLock}: lookups take the
 * read lock, mutations take the write lock. Every method releases its lock in a
 * {@code finally} block so an exception in the map operation cannot leave it held.
 */
class RWDictionary {
    private final Map<String, Data> m = new TreeMap<String, Data>();
    private final ReentrantReadWriteLock rwl = new ReentrantReadWriteLock();
    private final Lock r = rwl.readLock();  // read half of the read-write lock
    private final Lock w = rwl.writeLock(); // write half of the read-write lock

    /**
     * Looks up a single entry under the read lock.
     *
     * @param key the key to look up
     * @return the value mapped to {@code key}, or {@code null} if there is none
     */
    public Data get(String key) {
        r.lock();
        try {
            return m.get(key);
        } finally {
            r.unlock();
        }
    }

    /**
     * Returns a snapshot of all keys, taken under the read lock.
     *
     * @return a new array holding every key currently in the map
     */
    public String[] allKeys() {
        r.lock();
        try {
            // The no-argument toArray() returns Object[], which is not assignable to
            // String[]; the typed overload yields a real String[].
            return m.keySet().toArray(new String[0]);
        } finally {
            r.unlock();
        }
    }

    /**
     * Stores an entry under the write lock.
     *
     * @param key   the key to store under
     * @param value the value to associate with {@code key}
     * @return the value previously mapped to {@code key}, or {@code null} if there was none
     */
    public Data put(String key, Data value) {
        w.lock();
        try {
            return m.put(key, value);
        } finally {
            w.unlock();
        }
    }

    /** Removes every entry under the write lock. */
    public void clear() {
        w.lock();
        try {
            m.clear();
        } finally {
            w.unlock();
        }
    }
}
2. 类图
// Class declaration only: a ReadWriteLock implementation that is also Serializable.
// The body is elided here.
public class ReentrantReadWriteLock implements ReadWriteLock, java.io.Serializable {
// ...
}
3. 属性
private static final long serialVersionUID = -6992448646407690164L;
private final ReentrantReadWriteLock.ReadLock readerLock; // read lock, returned by readLock()
private final ReentrantReadWriteLock.WriteLock writerLock; // write lock, returned by writeLock()
final Sync sync; // synchronizer assigned in the constructor (a FairSync or a NonfairSync)
// Returns the write-lock view of this read-write lock.
public ReentrantReadWriteLock.WriteLock writeLock() { return writerLock; }
// Returns the read-lock view of this read-write lock.
public ReentrantReadWriteLock.ReadLock readLock() { return readerLock; }
4. 构造函数
/**
 * Creates a read-write lock with the default, non-fair policy by delegating to
 * {@code ReentrantReadWriteLock(false)}.
 */
public ReentrantReadWriteLock() {
this(false);
}
/**
 * Creates a read-write lock with the requested ordering policy.
 *
 * @param fair {@code true} to back the lock with a {@code FairSync},
 *             {@code false} to back it with a {@code NonfairSync}
 */
public ReentrantReadWriteLock(boolean fair) {
    if (fair) {
        sync = new FairSync();
    } else {
        sync = new NonfairSync();
    }
    // ReadLock and WriteLock copy lock.sync in their constructors, so sync must be set first.
    readerLock = new ReadLock(this);
    writerLock = new WriteLock(this);
}
5. 内部类
Sync
/**
* Synchronization implementation for ReentrantReadWriteLock.
* Subclassed into fair and nonfair versions.
*/
abstract static class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 6317671515068378041L;
/*
 * The int state inherited from AQS is split into two halves: the upper 16 bits
 * hold the read (shared) count and the lower 16 bits hold the write (exclusive) count.
 */
static final int SHARED_SHIFT = 16; // number of bits the shared count is shifted left by
static final int SHARED_UNIT = (1 << SHARED_SHIFT); // adding this to state raises the shared count by one
static final int MAX_COUNT = (1 << SHARED_SHIFT) - 1; // 65535, the limit checked for both the read and the write count
static final int EXCLUSIVE_MASK = (1 << SHARED_SHIFT) - 1; // low 16 bits set; selects the exclusive count
/** Returns the read (shared) count: the upper 16 bits of {@code c}, via an unsigned right shift. */
static int sharedCount(int c) { return c >>> SHARED_SHIFT; }
/** Returns the write (exclusive) count: the lower 16 bits of {@code c}, with the upper bits masked off. */
static int exclusiveCount(int c) { return c & EXCLUSIVE_MASK; }
/**
 * Read-hold counter for one thread: {@code count} is incremented on each read
 * acquire and decremented on each read release, and {@code tid} is the id of the
 * thread that created the counter.
 */
static final class HoldCounter {
int count = 0;
// Use id, not reference, to avoid garbage retention
final long tid = getThreadId(Thread.currentThread());
}
/**
 * ThreadLocal subclass that gives each thread its own {@link HoldCounter}.
 */
static final class ThreadLocalHoldCounter extends ThreadLocal<HoldCounter> {
    /**
     * Supplies the counter used when the calling thread has no value stored yet.
     *
     * @return a new {@code HoldCounter} with a count of 0
     */
    @Override
    public HoldCounter initialValue() {
        return new HoldCounter();
    }
}
/**
 * Per-thread read hold counters. Assigned only in the constructor and in
 * readObject; a thread's entry is removed once its count has dropped to 0.
 */
private transient ThreadLocalHoldCounter readHolds;
/**
 * Hold counter of the last thread that acquired the read lock through the
 * HoldCounter path (i.e. not as firstReader). Checked before falling back to a
 * readHolds lookup.
 */
private transient HoldCounter cachedHoldCounter;
/**
 * Thread recorded when an acquire found the shared count at 0; set back to null
 * when that thread releases its last read hold.
 */
private transient Thread firstReader = null;
private transient int firstReaderHoldCount; // number of read holds owned by firstReader
// -----------------------------------------------------------------------
// Constructor: creates the thread-local table of read hold counters.
Sync() {
readHolds = new ThreadLocalHoldCounter();
setState(getState()); // ensures visibility of readHolds
}
// ---------------------------- method ----------------------------------
/**
 * Policy hook: returns true if the current thread should block when it tries to
 * acquire the read lock. Implemented by the fair and non-fair subclasses.
 */
abstract boolean readerShouldBlock();
/**
 * Policy hook: returns true if the current thread should block when it tries to
 * acquire the write lock. Implemented by the fair and non-fair subclasses.
 */
abstract boolean writerShouldBlock();
// Overrides the AQS hook: releases `releases` holds of the write lock.
// Throws IllegalMonitorStateException if the calling thread is not the exclusive owner.
// Returns true only when the exclusive count reaches 0, i.e. the write lock is fully released.
protected final boolean tryRelease(int releases) {
if (!isHeldExclusively())
throw new IllegalMonitorStateException();
int nextc = getState() - releases;
boolean free = exclusiveCount(nextc) == 0;
if (free)
setExclusiveOwnerThread(null); // owner is cleared before the new state is written
setState(nextc); // plain write rather than CAS; the caller was verified above to be the exclusive owner
return free;
}
// Overrides the AQS hook: tries to acquire the write (exclusive) lock.
// Returns true on success and false when the lock cannot be taken right now.
// Throws Error when a reentrant acquire would push the write count past MAX_COUNT.
protected final boolean tryAcquire(int acquires) {
Thread current = Thread.currentThread();
int c = getState();
int w = exclusiveCount(c); // write (exclusive) count encoded in c
if (c != 0) {
// c != 0: some lock is held, either read or write.
// w == 0: only read locks are held, so the write lock cannot be taken
// (this also covers read locks held by the current thread itself).
// current != getExclusiveOwnerThread(): the write lock is held by a different thread.
if (w == 0 || current != getExclusiveOwnerThread())
return false;
if (w + exclusiveCount(acquires) > MAX_COUNT)
throw new Error("Maximum lock count exceeded");
// Reaching here means the current thread already holds the write lock (reentrant acquire).
// The write count lives in the low 16 bits, so it is updated by plain addition.
setState(c + acquires);
return true;
}
// c == 0: no lock of either kind is held.
if (writerShouldBlock() ||
!compareAndSetState(c, c + acquires)) // the policy says to wait, or the CAS on state failed
return false;
// Write lock acquired: record the owner.
setExclusiveOwnerThread(current);
return true;
}
// Overrides the AQS hook: releases one read hold of the calling thread.
// Returns true when the whole state is 0 after the release (no read or write holds remain), false otherwise.
// Throws IllegalMonitorStateException when the calling thread's hold counter shows no read hold to release.
protected final boolean tryReleaseShared(int unused) {
Thread current = Thread.currentThread();
if (firstReader == current) { // caller is the recorded first reader: its holds are tracked in firstReaderHoldCount
if (firstReaderHoldCount == 1)
firstReader = null;
else
firstReaderHoldCount--;
} else { // caller is not the first reader: its holds are tracked in a HoldCounter
HoldCounter rh = cachedHoldCounter;
if (rh == null || rh.tid != getThreadId(current))
rh = readHolds.get(); // the cached counter is missing or belongs to another thread
int count = rh.count;
if (count <= 1) {
readHolds.remove(); // last hold (or none at all): drop the thread-local entry
if (count <= 0)
throw unmatchedUnlockException();
}
--rh.count;
}
for (;;) { // retry until the CAS on state succeeds
int c = getState();
int nextc = c - SHARED_UNIT;
if (compareAndSetState(c, nextc)) // lowers the shared count by one
return nextc == 0;
}
}
/**
 * Builds the exception that tryReleaseShared throws when the calling thread's
 * read hold count is not positive.
 *
 * @return a new {@code IllegalMonitorStateException} carrying the fixed message
 */
private IllegalMonitorStateException unmatchedUnlockException() {
    final String message = "attempt to unlock read lock, not locked by current thread";
    return new IllegalMonitorStateException(message);
}
// Overrides the AQS hook: tries to acquire the read (shared) lock.
// Returns -1 when the write lock is held by another thread and 1 when the fast path succeeds;
// otherwise returns the result of fullTryAcquireShared.
protected final int tryAcquireShared(int unused) {
Thread current = Thread.currentThread();
int c = getState();
// The write lock is held, and not by the current thread: the read lock cannot be taken.
if (exclusiveCount(c) != 0 &&
getExclusiveOwnerThread() != current)
return -1;
int r = sharedCount(c); // current read (shared) count
if (!readerShouldBlock() &&
r < MAX_COUNT &&
compareAndSetState(c, c + SHARED_UNIT)) { // fast path: a single CAS attempt to add one read hold
// Record the new hold for the current thread.
if (r == 0) {
firstReader = current;
firstReaderHoldCount = 1;
} else if (firstReader == current) {
firstReaderHoldCount++;
} else {
HoldCounter rh = cachedHoldCounter;
if (rh == null || rh.tid != getThreadId(current))
cachedHoldCounter = rh = readHolds.get(); // cache miss: load this thread's counter and cache it
else if (rh.count == 0)
readHolds.set(rh); // cached counter is this thread's but its count is 0; store it in readHolds again
rh.count++;
}
return 1;
}
// Fast path not taken (readerShouldBlock() was true, the count was at MAX_COUNT, or the CAS failed):
// fall back to the retry loop.
return fullTryAcquireShared(current);
}
/**
 * Retry-loop version of the read acquire, used when the fast path in
 * tryAcquireShared did not succeed. Loops until it either gives up or its CAS on
 * state succeeds.
 * Returns 1 on success and -1 when the read lock cannot be taken now.
 * Throws Error when the read count is already MAX_COUNT.
 */
final int fullTryAcquireShared(Thread current) {
HoldCounter rh = null;
for (;;) {
int c = getState();
if (exclusiveCount(c) != 0) {
if (getExclusiveOwnerThread() != current)
return -1; // write lock held by another thread
} else if (readerShouldBlock()) { // policy says readers should block; only a reentrant reader may continue
if (firstReader == current) {
// assert firstReaderHoldCount > 0;
} else {
if (rh == null) {
rh = cachedHoldCounter;
if (rh == null || rh.tid != getThreadId(current)) {
rh = readHolds.get();
if (rh.count == 0)
readHolds.remove(); // do not leave a zero-count entry behind in the thread-local
}
}
if (rh.count == 0)
return -1; // the current thread holds no read lock, so it is not reentrant: give up
}
}
if (sharedCount(c) == MAX_COUNT)
throw new Error("Maximum lock count exceeded");
if (compareAndSetState(c, c + SHARED_UNIT)) { // CAS adds one read hold to state; on failure the loop retries
if (sharedCount(c) == 0) {
firstReader = current;
firstReaderHoldCount = 1;
} else if (firstReader == current) {
firstReaderHoldCount++;
} else {
if (rh == null)
rh = cachedHoldCounter;
if (rh == null || rh.tid != getThreadId(current))
rh = readHolds.get();
else if (rh.count == 0)
readHolds.set(rh); // counter is this thread's but its count is 0; store it in readHolds again
rh.count++;
cachedHoldCounter = rh; // cache for release
}
return 1;
}
}
}
/**
 * Makes a single attempt to acquire the write lock. Performs the same checks as
 * tryAcquire except that writerShouldBlock() is not consulted.
 * Returns true on success; false when read locks are held, when another thread
 * holds the write lock, or when the CAS on state fails.
 * Throws Error when the current thread's write count is already MAX_COUNT.
 */
final boolean tryWriteLock() {
Thread current = Thread.currentThread();
int c = getState();
if (c != 0) {
int w = exclusiveCount(c);
if (w == 0 || current != getExclusiveOwnerThread())
return false;
if (w == MAX_COUNT)
throw new Error("Maximum lock count exceeded");
}
if (!compareAndSetState(c, c + 1)) // adds one write hold (the write count is in the low bits)
return false;
setExclusiveOwnerThread(current);
return true;
}
/**
 * Tries to acquire the read lock without consulting readerShouldBlock(), which is
 * the difference from tryAcquireShared. Retries while its CAS on state fails.
 * Returns true on success and false when the write lock is held by another thread.
 * Throws Error when the read count is already MAX_COUNT.
 */
final boolean tryReadLock() {
Thread current = Thread.currentThread();
for (;;) {
int c = getState();
if (exclusiveCount(c) != 0 &&
getExclusiveOwnerThread() != current)
return false;
int r = sharedCount(c); // current read (shared) count
if (r == MAX_COUNT)
throw new Error("Maximum lock count exceeded");
if (compareAndSetState(c, c + SHARED_UNIT)) {
// Record the new hold for the current thread.
if (r == 0) {
firstReader = current;
firstReaderHoldCount = 1;
} else if (firstReader == current) {
firstReaderHoldCount++;
} else {
HoldCounter rh = cachedHoldCounter;
if (rh == null || rh.tid != getThreadId(current))
cachedHoldCounter = rh = readHolds.get();
else if (rh.count == 0)
readHolds.set(rh);
rh.count++;
}
return true;
}
}
}
/**
 * Reports whether the calling thread is the recorded exclusive owner.
 *
 * @return {@code true} if {@code getExclusiveOwnerThread()} is the current thread
 */
protected final boolean isHeldExclusively() {
    final Thread owner = getExclusiveOwnerThread();
    return owner == Thread.currentThread();
}
// Methods relayed to outer class
/**
 * Creates a new {@code ConditionObject}.
 *
 * @return the newly created condition
 */
final ConditionObject newCondition() {
    final ConditionObject condition = new ConditionObject();
    return condition;
}
/**
 * Returns the thread that holds the write lock.
 *
 * @return the exclusive owner thread, or {@code null} when the write count is 0
 */
final Thread getOwner() {
    // State is read before the owner, in the same order as the original expression.
    if (exclusiveCount(getState()) == 0) {
        return null;
    }
    return getExclusiveOwnerThread();
}
/**
 * Returns the total number of read holds encoded in the current state.
 *
 * @return the shared count of {@code getState()}
 */
final int getReadLockCount() {
    final int state = getState();
    return sharedCount(state);
}
/**
 * Reports whether the write lock is held by any thread.
 *
 * @return {@code true} if the exclusive count of the current state is non-zero
 */
final boolean isWriteLocked() {
    final int writeHolds = exclusiveCount(getState());
    return writeHolds != 0;
}
/**
 * Returns the number of write holds owned by the calling thread.
 *
 * @return the exclusive count of the current state if the calling thread is the
 *         exclusive owner, otherwise 0
 */
final int getWriteHoldCount() {
    if (!isHeldExclusively()) {
        return 0;
    }
    return exclusiveCount(getState());
}
/**
 * Returns the number of read holds owned by the calling thread.
 *
 * <p>Lookup order: 0 if nobody holds the read lock; then the firstReader fields;
 * then the cached counter if it belongs to the caller; finally the caller's
 * thread-local counter, which is removed again when its count is 0.
 *
 * @return the calling thread's read hold count
 */
final int getReadHoldCount() {
    if (getReadLockCount() == 0) {
        return 0;
    }

    final Thread caller = Thread.currentThread();
    if (firstReader == caller) {
        return firstReaderHoldCount;
    }

    final HoldCounter cached = cachedHoldCounter;
    if (cached != null && cached.tid == getThreadId(caller)) {
        return cached.count;
    }

    final int holds = readHolds.get().count;
    if (holds == 0) {
        // get() created an entry just for this query; do not leave it behind.
        readHolds.remove();
    }
    return holds;
}
/**
 * Deserialization hook: after the default fields are read, recreates the
 * transient readHolds table and resets state to 0 (unlocked).
 */
private void readObject(java.io.ObjectInputStream s)
throws java.io.IOException, ClassNotFoundException {
s.defaultReadObject();
readHolds = new ThreadLocalHoldCounter();
setState(0); // reset to unlocked state
}
/**
 * Returns the raw synchronizer state, with read and write counts still packed together.
 *
 * @return the value of {@code getState()}
 */
final int getCount() {
    final int rawState = getState();
    return rawState;
}
}
NonfairSync
/**
* 非公平锁 实现类
*/
static final class NonfairSync extends Sync {
private static final long serialVersionUID = -8159625535654395037L;
final boolean writerShouldBlock() {
return false; // writers can always barge
}
final boolean readerShouldBlock() {
/* As a heuristic to avoid indefinite writer starvation,
* block if the thread that momentarily appears to be head
* of queue, if one exists, is a waiting writer. This is
* only a probabilistic effect since a new reader will not
* block if there is a waiting writer behind other enabled
* readers that have not yet drained from the queue.
*/
return apparentlyFirstQueuedIsExclusive();
}
}
FairSync
/**
* 公平锁 实现类
*/
static final class FairSync extends Sync {
private static final long serialVersionUID = -2274990926593161451L;
final boolean writerShouldBlock() {
return hasQueuedPredecessors();
}
final boolean readerShouldBlock() {
return hasQueuedPredecessors();
}
}
ReadLock
/**
 * The read lock: a {@link Lock} view that acquires and releases the owning
 * read-write lock's synchronizer in shared mode.
 */
public static class ReadLock implements Lock, java.io.Serializable {
    private static final long serialVersionUID = -5992448646407690164L;
    private final Sync sync;

    /**
     * Binds this view to the synchronizer of the given read-write lock.
     *
     * @param lock the owning read-write lock; its {@code sync} field is copied
     */
    protected ReadLock(ReentrantReadWriteLock lock) {
        sync = lock.sync;
    }

    // --------------------- the six Lock methods ---------------------

    /** Acquires the read lock in shared mode. */
    @Override
    public void lock() {
        sync.acquireShared(1);
    }

    /**
     * Acquires the read lock in shared mode, allowing interruption.
     *
     * @throws InterruptedException if thrown by {@code acquireSharedInterruptibly}
     */
    @Override
    public void lockInterruptibly() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    /**
     * Tries to acquire the read lock via {@code Sync.tryReadLock()}.
     *
     * @return {@code true} if the read lock was acquired
     */
    @Override
    public boolean tryLock() {
        return sync.tryReadLock();
    }

    /**
     * Tries to acquire the read lock, giving up after the timeout.
     *
     * @param timeout how long to wait
     * @param unit    the unit of {@code timeout}; converted to nanoseconds
     * @return {@code true} if the read lock was acquired
     * @throws InterruptedException if thrown by {@code tryAcquireSharedNanos}
     */
    @Override
    public boolean tryLock(long timeout, TimeUnit unit)
            throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    /** Releases one read hold. */
    @Override
    public void unlock() {
        sync.releaseShared(1);
    }

    /**
     * Conditions are not supported on the read lock.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Condition newCondition() {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the default string form followed by the current read lock count.
     *
     * @return a string ending in {@code "[Read locks = N]"}
     */
    @Override
    public String toString() {
        int r = sync.getReadLockCount();
        return super.toString() + "[Read locks = " + r + "]";
    }
}
WriteLock
/**
 * The write lock: a {@link Lock} view that acquires and releases the owning
 * read-write lock's synchronizer in exclusive mode.
 */
public static class WriteLock implements Lock, java.io.Serializable {
    private static final long serialVersionUID = -4992448646407690164L;
    private final Sync sync;

    /**
     * Binds this view to the synchronizer of the given read-write lock.
     *
     * @param lock the owning read-write lock; its {@code sync} field is copied
     */
    protected WriteLock(ReentrantReadWriteLock lock) {
        sync = lock.sync;
    }

    /** Acquires the write lock in exclusive mode. */
    @Override
    public void lock() {
        sync.acquire(1);
    }

    /**
     * Acquires the write lock in exclusive mode, allowing interruption.
     *
     * @throws InterruptedException if thrown by {@code acquireInterruptibly}
     */
    @Override
    public void lockInterruptibly() throws InterruptedException {
        sync.acquireInterruptibly(1);
    }

    /**
     * Makes a single attempt to acquire the write lock via {@code Sync.tryWriteLock()}.
     *
     * @return {@code true} if the write lock was acquired
     */
    @Override
    public boolean tryLock() {
        return sync.tryWriteLock();
    }

    /**
     * Tries to acquire the write lock, giving up after the timeout.
     *
     * @param timeout how long to wait
     * @param unit    the unit of {@code timeout}; converted to nanoseconds
     * @return {@code true} if the write lock was acquired
     * @throws InterruptedException if thrown by {@code tryAcquireNanos}
     */
    @Override
    public boolean tryLock(long timeout, TimeUnit unit)
            throws InterruptedException {
        return sync.tryAcquireNanos(1, unit.toNanos(timeout));
    }

    /** Releases one write hold. */
    @Override
    public void unlock() {
        sync.release(1);
    }

    /**
     * Creates a condition tied to this write lock's synchronizer.
     *
     * @return the condition created by {@code Sync.newCondition()}
     */
    @Override
    public Condition newCondition() {
        return sync.newCondition();
    }

    /**
     * Returns the default string form followed by the lock status.
     *
     * @return a string ending in {@code "[Unlocked]"} or
     *         {@code "[Locked by thread NAME]"}
     */
    @Override
    public String toString() {
        Thread o = sync.getOwner();
        return super.toString() + ((o == null) ?
                "[Unlocked]" :
                "[Locked by thread " + o.getName() + "]");
    }

    /**
     * Reports whether the calling thread holds the write lock.
     *
     * @return the result of {@code Sync.isHeldExclusively()}
     */
    public boolean isHeldByCurrentThread() {
        return sync.isHeldExclusively();
    }

    /**
     * Returns the number of write holds owned by the calling thread.
     *
     * @return the result of {@code Sync.getWriteHoldCount()}
     */
    public int getHoldCount() {
        return sync.getWriteHoldCount();
    }
}