mirror of
https://github.com/duke-git/lancet.git
synced 2026-02-04 12:52:28 +08:00
feat: add KeyedLocker
This commit is contained in:
93
concurrency/keyed_locker.go
Normal file
93
concurrency/keyed_locker.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
// Copyright 2025 dudaodong@gmail.com. All rights reserved.
|
||||||
|
// Use of this source code is governed by MIT license
|
||||||
|
|
||||||
|
// Package concurrency contain some functions to support concurrent programming. eg, goroutine, channel, locker.
|
||||||
|
|
||||||
|
package concurrency
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KeyedLocker is a simple implementation of a keyed locker that allows for non-blocking lock acquisition.
|
||||||
|
type KeyedLocker[K comparable] struct {
|
||||||
|
locks sync.Map
|
||||||
|
ttl time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type lockEntry struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
ref int32
|
||||||
|
timer atomic.Pointer[time.Timer]
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewKeyedLocker creates a new KeyedLocker with the specified TTL for lock expiration.
|
||||||
|
// The TTL is used to automatically release locks that are no longer held.
|
||||||
|
func NewKeyedLocker[K comparable](ttl time.Duration) *KeyedLocker[K] {
|
||||||
|
return &KeyedLocker[K]{ttl: ttl}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do acquires a lock for the specified key and executes the provided function.
|
||||||
|
// It returns an error if the context is canceled before the function completes.
|
||||||
|
func (l *KeyedLocker[K]) Do(ctx context.Context, key K, fn func()) error {
|
||||||
|
entry := l.acquire(key)
|
||||||
|
defer l.release(key, entry, key)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
entry.mu.Lock()
|
||||||
|
defer entry.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
default:
|
||||||
|
fn()
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-done:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *KeyedLocker[K]) acquire(key K) *lockEntry {
|
||||||
|
lock, _ := l.locks.LoadOrStore(key, &lockEntry{})
|
||||||
|
entry := lock.(*lockEntry)
|
||||||
|
|
||||||
|
atomic.AddInt32(&entry.ref, 1)
|
||||||
|
if t := entry.timer.Swap(nil); t != nil {
|
||||||
|
t.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *KeyedLocker[K]) release(key K, entry *lockEntry, rawKey K) {
|
||||||
|
if atomic.AddInt32(&entry.ref, -1) == 0 {
|
||||||
|
entry.mu.Lock()
|
||||||
|
defer entry.mu.Unlock()
|
||||||
|
|
||||||
|
if entry.ref == 0 {
|
||||||
|
if t := entry.timer.Swap(nil); t != nil {
|
||||||
|
t.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
l.locks.Delete(rawKey)
|
||||||
|
} else {
|
||||||
|
if entry.timer.Load() == nil {
|
||||||
|
t := time.AfterFunc(l.ttl, func() {
|
||||||
|
l.release(key, entry, rawKey)
|
||||||
|
})
|
||||||
|
entry.timer.Store(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
106
concurrency/keyed_locker_test.go
Normal file
106
concurrency/keyed_locker_test.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package concurrency
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/duke-git/lancet/v2/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKeyedLocker_SerialExecutionSameKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
assert := internal.NewAssert(t, "TestKeyedLocker_SerialExecutionSameKey")
|
||||||
|
|
||||||
|
locker := NewKeyedLocker[string](100 * time.Millisecond)
|
||||||
|
|
||||||
|
var result []int
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
err := locker.Do(context.Background(), "key1", func() {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
result = append(result, i)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.IsNil(err)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
assert.Equal(5, len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyedLocker_ParallelExecutionDifferentKeys(t *testing.T) {
|
||||||
|
locker := NewKeyedLocker[string](100 * time.Millisecond)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
key := "key" + strconv.Itoa(i)
|
||||||
|
locker.Do(context.Background(), key, func() {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
})
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if elapsed > 100*time.Millisecond {
|
||||||
|
t.Errorf("parallel execution took too long: %s", elapsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyedLocker_ContextTimeout(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
assert := internal.NewAssert(t, "TestKeyedLocker_ContextTimeout")
|
||||||
|
|
||||||
|
locker := NewKeyedLocker[string](100 * time.Millisecond)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Lock key before calling
|
||||||
|
go func() {
|
||||||
|
_ = locker.Do(context.Background(), "key-timeout", func() {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Millisecond) // ensure lock is acquired first
|
||||||
|
|
||||||
|
err := locker.Do(ctx, "key-timeout", func() {
|
||||||
|
t.Error("should not execute")
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.IsNotNil(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyedLocker_LockReleaseAfterTTL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
assert := internal.NewAssert(t, "TestKeyedLocker_LockReleaseAfterTTL")
|
||||||
|
|
||||||
|
locker := NewKeyedLocker[string](50 * time.Millisecond)
|
||||||
|
|
||||||
|
err := locker.Do(context.Background(), "ttl-key", func() {})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for TTL to pass
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
err = locker.Do(context.Background(), "ttl-key", func() {})
|
||||||
|
assert.IsNil(err)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user