diff --git a/concurrency/keyed_locker.go b/concurrency/keyed_locker.go new file mode 100644 index 0000000..8d3d7f7 --- /dev/null +++ b/concurrency/keyed_locker.go @@ -0,0 +1,93 @@ +// Copyright 2025 dudaodong@gmail.com. All rights reserved. +// Use of this source code is governed by MIT license + +// Package concurrency contain some functions to support concurrent programming. eg, goroutine, channel, locker. + +package concurrency + +import ( + "context" + "sync" + "sync/atomic" + "time" +) + +// KeyedLocker is a simple implementation of a keyed locker that allows for non-blocking lock acquisition. +type KeyedLocker[K comparable] struct { + locks sync.Map + ttl time.Duration +} + +type lockEntry struct { + mu sync.Mutex + ref int32 + timer atomic.Pointer[time.Timer] +} + +// NewKeyedLocker creates a new KeyedLocker with the specified TTL for lock expiration. +// The TTL is used to automatically release locks that are no longer held. +func NewKeyedLocker[K comparable](ttl time.Duration) *KeyedLocker[K] { + return &KeyedLocker[K]{ttl: ttl} +} + +// Do acquires a lock for the specified key and executes the provided function. +// It returns an error if the context is canceled before the function completes. +func (l *KeyedLocker[K]) Do(ctx context.Context, key K, fn func()) error { + entry := l.acquire(key) + defer l.release(key, entry, key) + + done := make(chan struct{}) + + go func() { + entry.mu.Lock() + defer entry.mu.Unlock() + + select { + case <-ctx.Done(): + default: + fn() + } + close(done) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + return nil + } +} + +func (l *KeyedLocker[K]) acquire(key K) *lockEntry { + lock, _ := l.locks.LoadOrStore(key, &lockEntry{}) + entry := lock.(*lockEntry) + + atomic.AddInt32(&entry.ref, 1) + if t := entry.timer.Swap(nil); t != nil { + t.Stop() + } + + return entry +} + +func (l *KeyedLocker[K]) release(key K, entry *lockEntry, rawKey K) { + if atomic.AddInt32(&entry.ref, -1) == 0 { + entry.mu.Lock() + defer entry.mu.Unlock() + + if entry.ref == 0 { + if t := entry.timer.Swap(nil); t != nil { + t.Stop() + } + + l.locks.Delete(rawKey) + } else { + if entry.timer.Load() == nil { + t := time.AfterFunc(l.ttl, func() { + l.release(key, entry, rawKey) + }) + entry.timer.Store(t) + } + } + } +} diff --git a/concurrency/keyed_locker_test.go b/concurrency/keyed_locker_test.go new file mode 100644 index 0000000..799d6f8 --- /dev/null +++ b/concurrency/keyed_locker_test.go @@ -0,0 +1,106 @@ +package concurrency + +import ( + "context" + "strconv" + "sync" + "testing" + "time" + + "github.com/duke-git/lancet/v2/internal" +) + +func TestKeyedLocker_SerialExecutionSameKey(t *testing.T) { + t.Parallel() + assert := internal.NewAssert(t, "TestKeyedLocker_SerialExecutionSameKey") + + locker := NewKeyedLocker[string](100 * time.Millisecond) + + var result []int + var mu sync.Mutex + + wg := sync.WaitGroup{} + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + err := locker.Do(context.Background(), "key1", func() { + time.Sleep(10 * time.Millisecond) + mu.Lock() + defer mu.Unlock() + result = append(result, i) + }) + + assert.IsNil(err) + }(i) + } + wg.Wait() + + assert.Equal(5, len(result)) +} + +func TestKeyedLocker_ParallelExecutionDifferentKeys(t *testing.T) { + locker := NewKeyedLocker[string](100 * time.Millisecond) + + start := time.Now() + wg := sync.WaitGroup{} + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + key := "key" + strconv.Itoa(i) + locker.Do(context.Background(), key, func() { + time.Sleep(50 * time.Millisecond) + }) + }(i) + } + wg.Wait() + elapsed := time.Since(start) + + if elapsed > 100*time.Millisecond { + t.Errorf("parallel execution took too long: %s", elapsed) + } +} + +func TestKeyedLocker_ContextTimeout(t *testing.T) { + t.Parallel() + assert := internal.NewAssert(t, "TestKeyedLocker_ContextTimeout") + + locker := NewKeyedLocker[string](100 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + // Lock key before calling + go func() { + _ = locker.Do(context.Background(), "key-timeout", func() { + time.Sleep(50 * time.Millisecond) + }) + }() + + time.Sleep(1 * time.Millisecond) // ensure lock is acquired first + + err := locker.Do(ctx, "key-timeout", func() { + t.Error("should not execute") + }) + + assert.IsNotNil(err) +} + +func TestKeyedLocker_LockReleaseAfterTTL(t *testing.T) { + t.Parallel() + assert := internal.NewAssert(t, "TestKeyedLocker_LockReleaseAfterTTL") + + locker := NewKeyedLocker[string](50 * time.Millisecond) + + err := locker.Do(context.Background(), "ttl-key", func() {}) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + // Wait for TTL to pass + time.Sleep(100 * time.Millisecond) + + err = locker.Do(context.Background(), "ttl-key", func() {}) + assert.IsNil(err) +}