💠 ReentrantReadWriteLock

吞佛童子 2022年10月10日
  • Java
  • concurrency
大约 6 分钟

💠 ReentrantReadWriteLock

1. 类注释

  • 将 AQS 的 state 分为 读锁[高 16 位] & 写锁[低 16 位] 两部分
    • 读锁
      • 在没有写锁,或写锁正由当前线程持有(允许锁降级:持有写锁时再获取读锁)的情况下,才可尝试获锁
      • 共享锁
    • 写锁
      • 只有在没有任何线程持有读锁(包括当前线程自己——读锁不能升级为写锁),且没有其他线程持有写锁的情况下,才可尝试获锁
      • 独占 & 可重入
  • 简单使用
 /**
  * Example: a sorted map guarded by a ReentrantReadWriteLock.
  * Lookups share the read lock; mutations take the exclusive write lock.
  */
 class RWDictionary {
   private final Map<String, Data> m = new TreeMap<>();

   private final ReentrantReadWriteLock rwl = new ReentrantReadWriteLock();
   private final Lock r = rwl.readLock(); // read half of the read/write lock
   private final Lock w = rwl.writeLock(); // write half of the read/write lock

   /** Looks up {@code key} while holding the read lock. */
   public Data get(String key) {
     r.lock();
     try { return m.get(key); }
     finally { r.unlock(); }
   }

   /**
    * Returns a snapshot of all keys, taken while holding the read lock.
    * The no-arg {@code keySet().toArray()} returns {@code Object[]}, which does
    * not match the {@code String[]} return type; the typed overload does.
    */
   public String[] allKeys() {
     r.lock();
     try { return m.keySet().toArray(new String[0]); }
     finally { r.unlock(); }
   }

   /** Stores {@code value} under {@code key} while holding the write lock. */
   public Data put(String key, Data value) {
     w.lock();
     try { return m.put(key, value); }
     finally { w.unlock(); }
   }

   /** Removes every entry while holding the write lock. */
   public void clear() {
     w.lock();
     try { m.clear(); }
     finally { w.unlock(); }
   }
 }

2. 类图

/**
 * Class-declaration excerpt: the lock implements ReadWriteLock and is
 * Serializable. Its members are elided here and excerpted in the sections below.
 */
public class ReentrantReadWriteLock implements ReadWriteLock, java.io.Serializable {
    // ... (members elided in this excerpt)
    }

img_6.png


3. 属性

    private static final long serialVersionUID = -6992448646407690164L;

    /** Read-lock view returned by readLock(); created in the constructor. */
    private final ReentrantReadWriteLock.ReadLock readerLock; // read lock
    /** Write-lock view returned by writeLock(); created in the constructor. */
    private final ReentrantReadWriteLock.WriteLock writerLock; // write lock

    /** Synchronizer shared by both views: a FairSync or a NonfairSync, chosen by the constructor. */
    final Sync sync;
    
    /** Returns the exclusive (write) view of this lock. */
    public ReentrantReadWriteLock.WriteLock writeLock() {
        return this.writerLock;
    }

    /** Returns the shared (read) view of this lock. */
    public ReentrantReadWriteLock.ReadLock readLock() {
        return this.readerLock;
    }

4. 构造函数

    /**
     * Creates a read/write lock that uses the nonfair ordering policy.
     */
    public ReentrantReadWriteLock() {
        this(false);
    }

    /**
     * Creates a read/write lock with the requested ordering policy.
     *
     * @param fair {@code true} for a FairSync-backed lock, {@code false} for NonfairSync
     */
    public ReentrantReadWriteLock(boolean fair) {
        // sync must be assigned first: both lock views read lock.sync when constructed
        if (fair) {
            sync = new FairSync();
        } else {
            sync = new NonfairSync();
        }
        readerLock = new ReadLock(this);
        writerLock = new WriteLock(this);
    }

5. 内部类

Sync

    /**
     * Synchronization implementation for ReentrantReadWriteLock.
     * Subclassed into fair and nonfair versions.
     *
     * A single AQS int state encodes both locks. The high 16 bits count read
     * holds, summed over all threads. The low 16 bits count the reentrant
     * write holds of the owning thread.
     */
    abstract static class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 6317671515068378041L;

        /*
         * Split of the AQS int state: high 16 bits = shared (read) hold count,
         * low 16 bits = exclusive (write) hold count.
         */
        static final int SHARED_SHIFT   = 16;
        // amount added to state for one read hold (a 1 in the high half)
        static final int SHARED_UNIT    = (1 << SHARED_SHIFT);
        // maximum value of either 16-bit half (65535)
        static final int MAX_COUNT      = (1 << SHARED_SHIFT) - 1;
        // mask selecting the low 16 bits (the write hold count)
        static final int EXCLUSIVE_MASK = (1 << SHARED_SHIFT) - 1;

        /** Read hold count in c: the high 16 bits, i.e. c unsigned-shifted right by 16. */
        static int sharedCount(int c)    { return c >>> SHARED_SHIFT; }
        /** Write hold count in c: the low 16 bits, with all high bits masked off. */
        static int exclusiveCount(int c) { return c & EXCLUSIVE_MASK; }

        /**
         * Per-thread read hold counter. It records the thread id rather than a
         * Thread reference.
         */
        static final class HoldCounter {
            int count = 0;
            // Use id, not reference, to avoid garbage retention
            final long tid = getThreadId(Thread.currentThread());
        }
        /**
         * ThreadLocal that supplies a fresh HoldCounter the first time a thread
         * calls get().
         */
        static final class ThreadLocalHoldCounter extends ThreadLocal<HoldCounter> {
            public HoldCounter initialValue() {
                return new HoldCounter();
            }
        }
        /**
         * Read holds of the current thread. Assigned only in the constructor
         * and in readObject. A thread's entry is removed once its count drops to 0.
         */
        private transient ThreadLocalHoldCounter readHolds;
        /**
         * Cache of the HoldCounter most recently used by a (non-first) reader.
         * It is checked by thread id before falling back to a readHolds lookup.
         */
        private transient HoldCounter cachedHoldCounter;

        /**
         * Thread that acquired the read lock while the read count was 0.
         * It is cleared when that thread releases its last read hold. Its count
         * is kept in a plain field, so no ThreadLocal lookup is needed.
         */
        private transient Thread firstReader = null;
        private transient int firstReaderHoldCount; // read holds owned by firstReader

        // -----------------------------------------------------------------------

        // Constructor
        Sync() {
            readHolds = new ThreadLocalHoldCounter();
            setState(getState()); // ensures visibility of readHolds
        }

        // ---------------------------- method ----------------------------------

        /**
         * Returns true if the current thread, when trying to take the read
         * lock, should block according to the fairness policy of the subclass.
         */
        abstract boolean readerShouldBlock();

        /**
         * Returns true if the current thread, when trying to take the write
         * lock, should block according to the fairness policy of the subclass.
         */
        abstract boolean writerShouldBlock();

        // AQS hook: release write holds; returns true once the write count reaches 0.
        protected final boolean tryRelease(int releases) {
            if (!isHeldExclusively())
                throw new IllegalMonitorStateException();
            int nextc = getState() - releases;
            boolean free = exclusiveCount(nextc) == 0;
            if (free)
                setExclusiveOwnerThread(null);
            // plain setState suffices: only the owning thread gets past the check above
            setState(nextc);
            return free;
        }

        // AQS hook: try to acquire the write (exclusive) lock.
        protected final boolean tryAcquire(int acquires) {
            Thread current = Thread.currentThread();
            int c = getState();
            int w = exclusiveCount(c); // write hold count encoded in c
            if (c != 0) {
                // c != 0: some lock is held, either read or write.
                // w == 0: only read holds exist (from any thread, including this
                //   one), so the write lock cannot be taken (no read->write upgrade).
                // current != getExclusiveOwnerThread(): another thread owns the
                //   write lock, so it cannot be taken either.
                if (w == 0 || current != getExclusiveOwnerThread())
                    return false;
                if (w + exclusiveCount(acquires) > MAX_COUNT)
                    throw new Error("Maximum lock count exceeded");
                // Reentrant write acquire by the current owner. The write count
                // lives in the low 16 bits, so adding acquires updates it directly.
                setState(c + acquires);
                return true;
            }
            // c == 0: no lock of either kind is held
            if (writerShouldBlock() ||
                !compareAndSetState(c, c + acquires)) // CAS on state lost a race
                return false;
            // write lock acquired
            setExclusiveOwnerThread(current);
            return true;
        }

        // AQS hook: release one read hold. Returns true only if the whole state
        // (read and write counts) is 0 afterwards.
        protected final boolean tryReleaseShared(int unused) {
            Thread current = Thread.currentThread();
            if (firstReader == current) { // current thread is firstReader: update firstReaderHoldCount
                if (firstReaderHoldCount == 1)
                    firstReader = null;
                else
                    firstReaderHoldCount--;
            } else { // not firstReader: use the cached HoldCounter or the thread-local one
                HoldCounter rh = cachedHoldCounter;
                if (rh == null || rh.tid != getThreadId(current))
                    rh = readHolds.get();
                int count = rh.count;
                if (count <= 1) {
                    // last hold (or none): drop the ThreadLocal entry;
                    // count <= 0 means this thread holds no read lock at all
                    readHolds.remove();
                    if (count <= 0)
                        throw unmatchedUnlockException();
                }
                --rh.count;
            }
            // CAS loop: retry until one read hold is removed from the shared half
            for (;;) {
                int c = getState();
                int nextc = c - SHARED_UNIT;
                if (compareAndSetState(c, nextc)) // decrement the read count
                    return nextc == 0;
            }
        }

        private IllegalMonitorStateException unmatchedUnlockException() {
            return new IllegalMonitorStateException(
                "attempt to unlock read lock, not locked by current thread");
        }

        // AQS hook: try to acquire the read (shared) lock. Returns 1 on success
        // and -1 on failure.
        protected final int tryAcquireShared(int unused) {
            Thread current = Thread.currentThread();
            int c = getState();
            // Another thread owns the write lock: reading is not allowed, return -1.
            if (exclusiveCount(c) != 0 &&
                getExclusiveOwnerThread() != current)
                return -1;

            int r = sharedCount(c); // current read hold count
            if (!readerShouldBlock() &&
                r < MAX_COUNT &&
                compareAndSetState(c, c + SHARED_UNIT)) { // fast path: CAS succeeded, return 1
                // record the new hold in the per-thread bookkeeping
                if (r == 0) {
                    firstReader = current;
                    firstReaderHoldCount = 1;
                } else if (firstReader == current) {
                    firstReaderHoldCount++;
                } else {
                    HoldCounter rh = cachedHoldCounter;
                    if (rh == null || rh.tid != getThreadId(current))
                        cachedHoldCounter = rh = readHolds.get();
                    else if (rh.count == 0)
                        // cached counter was removed from readHolds when it hit 0; reinstall it
                        readHolds.set(rh);
                    rh.count++;
                }
                return 1;
            }
            // Fast path not taken: the reader should block, the count is saturated,
            // or the CAS lost a race. Fall back to the full retry loop.
            return fullTryAcquireShared(current);
        }

        /**
         * Full retry loop for read acquisition. It handles CAS misses and
         * reentrant reads that the fast path in tryAcquireShared does not.
         * Returns 1 on success and -1 on failure.
         */
        final int fullTryAcquireShared(Thread current) {
            HoldCounter rh = null;
            for (;;) {
                int c = getState();
                if (exclusiveCount(c) != 0) {
                    if (getExclusiveOwnerThread() != current)
                        return -1;
                    // else this thread owns the write lock and may take the read
                    // lock (blocking here would wait on itself)
                } else if (readerShouldBlock()) { // policy says block: fail unless this read is reentrant
                    if (firstReader == current) {
                        // assert firstReaderHoldCount > 0;
                    } else {
                        if (rh == null) {
                            rh = cachedHoldCounter;
                            if (rh == null || rh.tid != getThreadId(current)) {
                                rh = readHolds.get();
                                if (rh.count == 0)
                                    readHolds.remove();
                            }
                        }
                        if (rh.count == 0)
                            return -1;
                    }
                }
                if (sharedCount(c) == MAX_COUNT)
                    throw new Error("Maximum lock count exceeded");
                if (compareAndSetState(c, c + SHARED_UNIT)) { // CAS to add one read hold
                    if (sharedCount(c) == 0) {
                        firstReader = current;
                        firstReaderHoldCount = 1;
                    } else if (firstReader == current) {
                        firstReaderHoldCount++;
                    } else {
                        if (rh == null)
                            rh = cachedHoldCounter;
                        if (rh == null || rh.tid != getThreadId(current))
                            rh = readHolds.get();
                        else if (rh.count == 0)
                            readHolds.set(rh);
                        rh.count++;
                        cachedHoldCounter = rh; // cache for release
                    }
                    return 1;
                }
            }
        }

        /**
         * Tries to take the write lock once. Unlike tryAcquire, it does not
         * consult writerShouldBlock and always acquires a single hold.
         */
        final boolean tryWriteLock() {
            Thread current = Thread.currentThread();
            int c = getState();
            if (c != 0) {
                int w = exclusiveCount(c);
                if (w == 0 || current != getExclusiveOwnerThread())
                    return false;
                if (w == MAX_COUNT)
                    throw new Error("Maximum lock count exceeded");
            }
            if (!compareAndSetState(c, c + 1))
                return false;
            setExclusiveOwnerThread(current);
            return true;
        }

        /**
         * Tries to take the read lock. Unlike tryAcquireShared, it does not
         * consult readerShouldBlock. It retries CAS misses in a loop and fails
         * only when another thread owns the write lock.
         */
        final boolean tryReadLock() {
            Thread current = Thread.currentThread();
            for (;;) {
                int c = getState();
                if (exclusiveCount(c) != 0 &&
                    getExclusiveOwnerThread() != current)
                    return false;
                int r = sharedCount(c);
                if (r == MAX_COUNT)
                    throw new Error("Maximum lock count exceeded");
                if (compareAndSetState(c, c + SHARED_UNIT)) {
                    if (r == 0) {
                        firstReader = current;
                        firstReaderHoldCount = 1;
                    } else if (firstReader == current) {
                        firstReaderHoldCount++;
                    } else {
                        HoldCounter rh = cachedHoldCounter;
                        if (rh == null || rh.tid != getThreadId(current))
                            cachedHoldCounter = rh = readHolds.get();
                        else if (rh.count == 0)
                            readHolds.set(rh);
                        rh.count++;
                    }
                    return true;
                }
            }
        }

        // True if the current thread is the recorded exclusive (write) owner.
        protected final boolean isHeldExclusively() {
            return getExclusiveOwnerThread() == Thread.currentThread();
        }

        // Methods relayed to outer class

        final ConditionObject newCondition() {
            return new ConditionObject();
        }

        // Write-lock owner, or null when no write hold exists.
        final Thread getOwner() {
            return ((exclusiveCount(getState()) == 0) ? null : getExclusiveOwnerThread());
        }

        // Total read holds across all threads.
        final int getReadLockCount() {
            return sharedCount(getState());
        }

        final boolean isWriteLocked() {
            return exclusiveCount(getState()) != 0;
        }

        // Write holds of the current thread (0 if it is not the owner).
        final int getWriteHoldCount() {
            return isHeldExclusively() ? exclusiveCount(getState()) : 0;
        }

        // Read holds of the current thread: checks firstReader, then the cache,
        // then readHolds.
        final int getReadHoldCount() {
            if (getReadLockCount() == 0)
                return 0;

            Thread current = Thread.currentThread();
            if (firstReader == current)
                return firstReaderHoldCount;

            HoldCounter rh = cachedHoldCounter;
            if (rh != null && rh.tid == getThreadId(current))
                return rh.count;

            int count = readHolds.get().count;
            // get() may have just created an empty entry; don't leave it behind
            if (count == 0) readHolds.remove();
            return count;
        }

        /**
         * Deserialization hook: recreates the transient readHolds and resets
         * the state to unlocked.
         */
        private void readObject(java.io.ObjectInputStream s)
            throws java.io.IOException, ClassNotFoundException {
            s.defaultReadObject();
            readHolds = new ThreadLocalHoldCounter();
            setState(0); // reset to unlocked state
        }

        final int getCount() { return getState(); }
    }

NonfairSync

    /**
     * Nonfair synchronizer. Writers never defer to queued threads. Readers
     * defer only when the apparent head of the queue is a waiting writer.
     */
    static final class NonfairSync extends Sync {
        private static final long serialVersionUID = -8159625535654395037L;

        /**
         * Heuristic against indefinite writer starvation: a new reader blocks
         * if the thread that momentarily appears to head the queue (if any) is
         * a waiting writer. The effect is only probabilistic. A new reader
         * will not block when the waiting writer is queued behind other
         * readers that have not yet drained from the queue.
         */
        final boolean readerShouldBlock() {
            return apparentlyFirstQueuedIsExclusive();
        }

        /** Writers may always barge in nonfair mode. */
        final boolean writerShouldBlock() {
            return false; // writers can always barge
        }
    }

FairSync

    /**
     * Fair synchronizer. Readers and writers both defer to any thread that
     * hasQueuedPredecessors() reports as queued ahead of them.
     */
    static final class FairSync extends Sync {
        private static final long serialVersionUID = -2274990926593161451L;

        /** A reader blocks while an earlier thread is waiting in the queue. */
        final boolean readerShouldBlock() {
            return hasQueuedPredecessors();
        }

        /** A writer blocks while an earlier thread is waiting in the queue. */
        final boolean writerShouldBlock() {
            return hasQueuedPredecessors();
        }
    }

ReadLock

    /**
     * The read (shared) half of the lock, backed by the outer lock's Sync.
     * Conditions are not supported.
     */
    public static class ReadLock implements Lock, java.io.Serializable {
        private static final long serialVersionUID = -5992448646407690164L;
        private final Sync sync;

        /**
         * Binds this view to the synchronizer of the given outer lock.
         *
         * @param lock the outer lock whose sync this view delegates to
         */
        protected ReadLock(ReentrantReadWriteLock lock) {
            this.sync = lock.sync;
        }

        // ----- Lock interface -----

        /** Acquires one shared hold via AQS acquireShared. */
        public void lock() {
            sync.acquireShared(1);
        }

        /** Acquires one shared hold, aborting with InterruptedException if interrupted. */
        public void lockInterruptibly() throws InterruptedException {
            sync.acquireSharedInterruptibly(1);
        }

        /** Single attempt that bypasses readerShouldBlock (see Sync.tryReadLock). */
        public boolean tryLock() {
            return sync.tryReadLock();
        }

        /** Tries to acquire one shared hold within the given timeout. */
        public boolean tryLock(long timeout, TimeUnit unit)
                throws InterruptedException {
            final long timeoutNanos = unit.toNanos(timeout);
            return sync.tryAcquireSharedNanos(1, timeoutNanos);
        }

        /** Releases one shared hold. */
        public void unlock() {
            sync.releaseShared(1);
        }

        /** Read locks have no conditions; always throws. */
        public Condition newCondition() {
            throw new UnsupportedOperationException();
        }

        /** Identity string followed by the total read hold count. */
        public String toString() {
            final int readCount = sync.getReadLockCount();
            final String identity = super.toString();
            return identity + "[Read locks = " + readCount + "]";
        }
    }

WriteLock

    /**
     * The write (exclusive, reentrant) half of the lock, backed by the outer
     * lock's Sync. Condition objects are supported.
     */
    public static class WriteLock implements Lock, java.io.Serializable {
        private static final long serialVersionUID = -4992448646407690164L;
        private final Sync sync;

        /**
         * Binds this view to the synchronizer of the given outer lock.
         *
         * @param lock the outer lock whose sync this view delegates to
         */
        protected WriteLock(ReentrantReadWriteLock lock) {
            this.sync = lock.sync;
        }

        /** Acquires one exclusive hold via AQS acquire. */
        public void lock() {
            sync.acquire(1);
        }

        /** Acquires one exclusive hold, aborting with InterruptedException if interrupted. */
        public void lockInterruptibly() throws InterruptedException {
            sync.acquireInterruptibly(1);
        }

        /** Single attempt that bypasses writerShouldBlock (see Sync.tryWriteLock). */
        public boolean tryLock() {
            return sync.tryWriteLock();
        }

        /** Tries to acquire one exclusive hold within the given timeout. */
        public boolean tryLock(long timeout, TimeUnit unit)
                throws InterruptedException {
            final long timeoutNanos = unit.toNanos(timeout);
            return sync.tryAcquireNanos(1, timeoutNanos);
        }

        /** Releases one exclusive hold. */
        public void unlock() {
            sync.release(1);
        }

        /** Creates a new Condition bound to the write lock. */
        public Condition newCondition() {
            return sync.newCondition();
        }

        /** Identity string followed by the lock status and owner name, if any. */
        public String toString() {
            final Thread owner = sync.getOwner();
            final String status;
            if (owner == null) {
                status = "[Unlocked]";
            } else {
                status = "[Locked by thread " + owner.getName() + "]";
            }
            return super.toString() + status;
        }

        /** True if the current thread owns the write lock. */
        public boolean isHeldByCurrentThread() {
            return sync.isHeldExclusively();
        }

        /** Number of write holds owned by the current thread (0 if not the owner). */
        public int getHoldCount() {
            return sync.getWriteHoldCount();
        }
    }
上次编辑于: 2022/10/10 下午8:43:48
贡献者: liuxianzhishou