1
0
mirror of https://github.com/duke-git/lancet.git synced 2026-02-04 12:52:28 +08:00

feat: add ConcurrentMap

This commit is contained in:
dudaodong
2023-07-24 17:10:45 +08:00
parent fe0264f628
commit fe0cb04137
2 changed files with 272 additions and 0 deletions

154
maputil/concurrentmap.go Normal file
View File

@@ -0,0 +1,154 @@
// Copyright 2021 dudaodong@gmail.com. All rights reserved.
// Use of this source code is governed by MIT license
// Package maputil includes some functions to manipulate map.
package maputil
import (
"fmt"
"sync"
)
const defaultShardCount = 32
// ConcurrentMap is like map, but is safe for concurrent use by multiple goroutines.
type ConcurrentMap[K comparable, V any] struct {
shardCount uint64
locks []sync.RWMutex
maps []map[K]V
}
// NewConcurrentMap create a ConcurrentMap with specific shard count.
func NewConcurrentMap[K comparable, V any](shardCount int) *ConcurrentMap[K, V] {
if shardCount <= 0 {
shardCount = defaultShardCount
}
cm := &ConcurrentMap[K, V]{
shardCount: uint64(shardCount),
locks: make([]sync.RWMutex, shardCount),
maps: make([]map[K]V, shardCount),
}
for i := range cm.maps {
cm.maps[i] = make(map[K]V)
}
return cm
}
// Set the value for a key.
// Play: todo
func (cm *ConcurrentMap[K, V]) Set(key K, value V) {
shard := cm.getShard(key)
cm.locks[shard].Lock()
cm.maps[shard][key] = value
cm.locks[shard].Unlock()
}
// Get the value stored in the map for a key, or nil if no.
// Play: todo
func (cm *ConcurrentMap[K, V]) Get(key K) (V, bool) {
shard := cm.getShard(key)
cm.locks[shard].RLock()
value, ok := cm.maps[shard][key]
cm.locks[shard].RUnlock()
return value, ok
}
// GetOrSet returns the existing value for the key if present.
// Otherwise, it sets and returns the given value.
// Play: todo
func (cm *ConcurrentMap[K, V]) GetOrSet(key K, value V) (actual V, ok bool) {
shard := cm.getShard(key)
cm.locks[shard].RLock()
if actual, ok := cm.maps[shard][key]; ok {
cm.locks[shard].RUnlock()
return actual, ok
}
cm.locks[shard].RUnlock()
// lock again
cm.locks[shard].Lock()
if actual, ok = cm.maps[shard][key]; ok {
cm.locks[shard].Unlock()
return
}
cm.maps[shard][key] = value
cm.locks[shard].Unlock()
return value, ok
}
// Delete the value for a key.
// Play: todo
func (cm *ConcurrentMap[K, V]) Delete(key K) {
shard := cm.getShard(key)
cm.locks[shard].Lock()
delete(cm.maps[shard], key)
cm.locks[shard].Unlock()
}
// GetAndDelete returns the existing value for the key if present and then delete the value for the key.
// Otherwise, do nothing, just return false
// Play: todo
func (cm *ConcurrentMap[K, V]) GetAndDelete(key K) (actual V, ok bool) {
shard := cm.getShard(key)
cm.locks[shard].RLock()
if actual, ok = cm.maps[shard][key]; ok {
cm.locks[shard].RUnlock()
cm.Delete(key)
return
}
cm.locks[shard].RUnlock()
return actual, false
}
// Has checks if map has the value for a key.
// Play: todo
func (cm *ConcurrentMap[K, V]) Has(key K) bool {
_, ok := cm.Get(key)
return ok
}
// Range calls iterator sequentially for each key and value present in each of the shards in the map.
// If iterator returns false, range stops the iteration.
func (cm *ConcurrentMap[K, V]) Range(iterator func(key K, value V) bool) {
for shard := range cm.locks {
cm.locks[shard].RLock()
for k, v := range cm.maps[shard] {
if !iterator(k, v) {
cm.locks[shard].RUnlock()
return
}
}
cm.locks[shard].RUnlock()
}
}
// getShard get shard by a key.
func (cm *ConcurrentMap[K, V]) getShard(key K) uint64 {
hash := fnv32(fmt.Sprintf("%v", key))
return uint64(hash) % cm.shardCount
}
func fnv32(key string) uint32 {
hash := uint32(2166136261)
const prime32 = uint32(16777619)
keyLength := len(key)
for i := 0; i < keyLength; i++ {
hash *= prime32
hash ^= uint32(key[i])
}
return hash
}

View File

@@ -0,0 +1,118 @@
package maputil
import (
"fmt"
"sync"
"testing"
"github.com/duke-git/lancet/v2/internal"
)
func TestConcurrentMap_Set_Get(t *testing.T) {
assert := internal.NewAssert(t, "TestConcurrentMap_Set_Get")
cm := NewConcurrentMap[string, int](100)
var wg1 sync.WaitGroup
wg1.Add(10)
for i := 0; i < 10; i++ {
go func(n int) {
cm.Set(fmt.Sprintf("%d", n), n)
wg1.Done()
}(i)
}
wg1.Wait()
var wg2 sync.WaitGroup
wg2.Add(10)
for j := 0; j < 10; j++ {
go func(n int) {
val, ok := cm.Get(fmt.Sprintf("%d", n))
assert.Equal(n, val)
assert.Equal(true, ok)
wg2.Done()
}(j)
}
wg2.Wait()
}
func TestConcurrentMap_GetOrSet(t *testing.T) {
assert := internal.NewAssert(t, "TestConcurrentMap_GetOrSet")
cm := NewConcurrentMap[string, int](100)
for i := 0; i < 5; i++ {
go func(n int) {
val, ok := cm.GetOrSet(fmt.Sprintf("%d", n), n)
assert.Equal(n, val)
assert.Equal(false, ok)
}(i)
}
for j := 0; j < 5; j++ {
go func(n int) {
val, ok := cm.Get(fmt.Sprintf("%d", n))
assert.Equal(n, val)
assert.Equal(true, ok)
}(j)
}
}
func TestConcurrentMap_Delete(t *testing.T) {
assert := internal.NewAssert(t, "TestConcurrentMap_Delete")
cm := NewConcurrentMap[string, int](100)
var wg1 sync.WaitGroup
wg1.Add(10)
for i := 0; i < 10; i++ {
go func(n int) {
cm.Set(fmt.Sprintf("%d", n), n)
wg1.Done()
}(i)
}
wg1.Wait()
var wg2 sync.WaitGroup
wg2.Add(10)
for i := 0; i < 10; i++ {
go func(n int) {
cm.Delete(fmt.Sprintf("%d", n))
wg2.Done()
}(i)
}
wg2.Wait()
for j := 0; j < 10; j++ {
go func(n int) {
_, ok := cm.Get(fmt.Sprintf("%d", n))
assert.Equal(false, ok)
}(j)
}
}
func TestConcurrentMap_GetAndDelete(t *testing.T) {
assert := internal.NewAssert(t, "TestConcurrentMap_GetAndDelete")
cm := NewConcurrentMap[string, int](100)
for i := 0; i < 10; i++ {
go func(n int) {
cm.Set(fmt.Sprintf("%d", n), n)
}(i)
}
for j := 0; j < 10; j++ {
go func(n int) {
val, ok := cm.GetAndDelete(fmt.Sprintf("%d", n))
assert.Equal(n, val)
assert.Equal(true, ok)
_, ok = cm.Get(fmt.Sprintf("%d", n))
assert.Equal(false, ok)
}(j)
}
}