From d88bba07dde6998b12b294ae95347defc4d5ecf8 Mon Sep 17 00:00:00 2001 From: dudaodong Date: Mon, 21 Apr 2025 10:49:48 +0800 Subject: [PATCH] feat: add TryKeyedLocker --- concurrency/keyed_locker.go | 57 ++++++++++++++++++++++++++++++++ concurrency/keyed_locker_test.go | 53 +++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+) diff --git a/concurrency/keyed_locker.go b/concurrency/keyed_locker.go index 8e7c555..a921228 100644 --- a/concurrency/keyed_locker.go +++ b/concurrency/keyed_locker.go @@ -191,3 +191,60 @@ func (l *RWKeyedLocker[K]) release(entry *rwLockEntry, rawKey K) { entry.timer.Store(timer) } } + +// TryKeyedLocker is a non-blocking version of KeyedLocker. +// It allows for trying to acquire a lock without blocking if the lock is already held. +type TryKeyedLocker[K comparable] struct { + mu sync.Mutex + locks map[K]*casMutex +} + +// NewTryKeyedLocker creates a new TryKeyedLocker. +func NewTryKeyedLocker[K comparable]() *TryKeyedLocker[K] { + return &TryKeyedLocker[K]{locks: make(map[K]*casMutex)} +} + +// TryLock tries to acquire a lock for the specified key. +// It returns true if the lock was acquired, false otherwise. +func (l *TryKeyedLocker[K]) TryLock(key K) bool { + l.mu.Lock() + + lock, ok := l.locks[key] + if !ok { + lock = &casMutex{} + l.locks[key] = lock + } + l.mu.Unlock() + + return lock.TryLock() +} + +// Unlock releases the lock for the specified key. +func (l *TryKeyedLocker[K]) Unlock(key K) { + l.mu.Lock() + defer l.mu.Unlock() + + lock, ok := l.locks[key] + if ok { + lock.Unlock() + if lock.lock == 0 { + delete(l.locks, key) + } + } +} + +// casMutex is a simple mutex that uses atomic operations to provide a non-blocking lock. +type casMutex struct { + lock int32 +} + +// TryLock tries to acquire the lock without blocking. +// It returns true if the lock was acquired, false otherwise. +func (m *casMutex) TryLock() bool { + return atomic.CompareAndSwapInt32(&m.lock, 0, 1) +} + +// Unlock releases the lock. +func (m *casMutex) Unlock() { + atomic.StoreInt32(&m.lock, 0) +} diff --git a/concurrency/keyed_locker_test.go b/concurrency/keyed_locker_test.go index 3e669a9..5c07752 100644 --- a/concurrency/keyed_locker_test.go +++ b/concurrency/keyed_locker_test.go @@ -175,3 +175,56 @@ func TestRWKeyedLocker_LockTimeout(t *testing.T) { assert.IsNotNil(err) } + +func TestTryKeyedLocker_SimpleLockUnlock(t *testing.T) { + t.Parallel() + assert := internal.NewAssert(t, "TestTryKeyedLocker_SimpleLockUnlock") + + locker := NewTryKeyedLocker[string]() + + ok := locker.TryLock("key1") + assert.Equal(true, ok) + + ok = locker.TryLock("key1") + assert.Equal(false, ok) + + locker.Unlock("key1") + + ok = locker.TryLock("key1") + assert.Equal(true, ok) + + locker.Unlock("key1") +} + +func TestTryKeyedLocker_ParallelTry(t *testing.T) { + t.Parallel() + assert := internal.NewAssert(t, "TestTryKeyedLocker_ParallelTry") + + locker := NewTryKeyedLocker[string]() + + var wg sync.WaitGroup + var mu sync.Mutex + var count int + + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + ok := locker.TryLock("key" + strconv.Itoa(i)) + mu.Lock() + if ok { + count++ + } + mu.Unlock() + time.Sleep(10 * time.Millisecond) + if ok { + locker.Unlock("key" + strconv.Itoa(i)) + } + }(i) + } + + wg.Wait() + + assert.Equal(5, count) + assert.Equal(0, len(locker.locks)) +}