1
0
mirror of https://github.com/duke-git/lancet.git synced 2026-02-04 12:52:28 +08:00

feat: a simple enum implementation

This commit is contained in:
dudaodong
2025-09-30 14:56:25 +08:00
parent 7d4b9510a2
commit 309b07ae8a
2 changed files with 445 additions and 0 deletions

282
enum/enum.go Normal file
View File

@@ -0,0 +1,282 @@
// Copyright 2025 dudaodong@gmail.com. All rights resulterved.
// Use of this source code is governed by MIT license
// Package enum provides a simple enum implementation.
package enum
import (
"encoding/json"
"fmt"
"reflect"
"sync"
)
// Enum defines a common enum interface.
type Enum[T comparable] interface {
Value() T
String() string
Valid() bool
}
// Item defines a common enum item struct implement Enum interface.
type Item[T comparable] struct {
value T
name string
}
func NewItem[T comparable](value T, name string) *Item[T] {
return &Item[T]{value: value, name: name}
}
// NewWithItems creates enum items from value-name pairs.
// It requires an even number of arguments, where each pair consists of a value of type T and a string name.
// If the number of arguments is odd or if any argument does not match the expected type, it will panic.
func NewItems[T comparable](args ...any) []*Item[T] {
if len(args) == 0 {
return []*Item[T]{}
}
if len(args)%2 != 0 {
panic("New requires even number of arguments (value, name pairs)")
}
items := make([]*Item[T], 0, len(args)/2)
for i := 0; i < len(args); i += 2 {
value, ok := args[i].(T)
if !ok {
panic(fmt.Sprintf("argument %d is not of type %T", i, value))
}
name, ok := args[i+1].(string)
if !ok {
panic(fmt.Sprintf("argument %d is not of type string", i+1))
}
items = append(items, &Item[T]{value: value, name: name})
}
return items
}
func (e *Item[T]) Value() T {
return e.value
}
func (e *Item[T]) Name() string {
return e.name
}
// Valid checks if the enum item is valid. If a custom check function is provided, it will be used to validate the value.
func (e *Item[T]) Valid(check ...func(T) bool) bool {
if len(check) > 0 {
return check[0](e.value) && e.name != ""
}
var zero T
return e.value != zero && e.name != ""
}
// MarshalJSON implements the json.Marshaler interface.
func (e *Item[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(map[string]any{
"value": e.value,
"name": e.name,
})
}
// UnmarshalJSON implements the json.Unmarshaler interface.
func (e *Item[T]) UnmarshalJSON(data []byte) error {
type alias struct {
Value any `json:"value"`
Name string `json:"name"`
}
var temp alias
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
var v T
rv := reflect.TypeOf(v)
switch rv.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
val, ok := temp.Value.(float64)
if !ok {
return fmt.Errorf("invalid type for value, want int family")
}
converted := reflect.ValueOf(int64(val)).Convert(rv)
e.value = converted.Interface().(T)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
val, ok := temp.Value.(float64)
if !ok {
return fmt.Errorf("invalid type for value, want uint family")
}
converted := reflect.ValueOf(uint64(val)).Convert(rv)
e.value = converted.Interface().(T)
case reflect.Float32, reflect.Float64:
val, ok := temp.Value.(float64)
if !ok {
return fmt.Errorf("invalid type for value, want float family")
}
converted := reflect.ValueOf(val).Convert(rv)
e.value = converted.Interface().(T)
case reflect.String:
val, ok := temp.Value.(string)
if !ok {
return fmt.Errorf("invalid type for value, want string")
}
e.value = any(val).(T)
case reflect.Bool:
val, ok := temp.Value.(bool)
if !ok {
return fmt.Errorf("invalid type for value, want bool")
}
e.value = any(val).(T)
default:
// 枚举类型底层通常是 int可以尝试 float64->int64->底层类型
val, ok := temp.Value.(float64)
if ok {
converted := reflect.ValueOf(int64(val)).Convert(rv)
e.value = converted.Interface().(T)
} else {
val2, ok2 := temp.Value.(T)
if !ok2 {
return fmt.Errorf("invalid type for value")
}
e.value = val2
}
}
e.name = temp.Name
return nil
}
// Registry defines a common enum registry struct.
type Registry[T comparable] struct {
mu sync.RWMutex
values map[T]*Item[T]
names map[string]*Item[T]
items []*Item[T]
}
func NewRegistry[T comparable](items ...*Item[T]) *Registry[T] {
r := &Registry[T]{
values: make(map[T]*Item[T]),
names: make(map[string]*Item[T]),
items: make([]*Item[T], 0, len(items)),
}
r.Add(items...)
return r
}
// Add adds enum items to the registry.
func (r *Registry[T]) Add(items ...*Item[T]) {
r.mu.Lock()
defer r.mu.Unlock()
for _, item := range items {
if _, exists := r.values[item.value]; exists {
continue
}
if _, exists := r.names[item.name]; exists {
continue
}
r.values[item.value] = item
r.names[item.name] = item
r.items = append(r.items, item)
}
}
// Remove removes an enum item from the registry by its value.
func (r *Registry[T]) Remove(value T) bool {
r.mu.Lock()
defer r.mu.Unlock()
item, ok := r.values[value]
if !ok {
return false
}
delete(r.values, value)
delete(r.names, item.name)
for i, it := range r.items {
if it.value == value {
r.items = append(r.items[:i], r.items[i+1:]...)
break
}
}
return true
}
// Update updates the name of an enum item in the registry by its value.
func (r *Registry[T]) Update(value T, newName string) bool {
r.mu.Lock()
defer r.mu.Unlock()
item, ok := r.values[value]
if !ok {
return false
}
delete(r.names, item.name)
item.name = newName
r.names[newName] = item
return true
}
// GetByValue retrieves an enum item by its value.
func (r *Registry[T]) GetByValue(value T) (*Item[T], bool) {
r.mu.RLock()
defer r.mu.RUnlock()
item, ok := r.values[value]
return item, ok
}
// GetByName retrieves an enum item by its name.
func (r *Registry[T]) GetByName(name string) (*Item[T], bool) {
r.mu.RLock()
defer r.mu.RUnlock()
item, ok := r.names[name]
return item, ok
}
// Items returns a slice of all enum items in the registry.
func (r *Registry[T]) Items() []*Item[T] {
r.mu.RLock()
defer r.mu.RUnlock()
result := make([]*Item[T], len(r.items))
copy(result, r.items)
return result
}
// Contains checks if an enum item with the given value exists in the registry.
func (r *Registry[T]) Contains(value T) bool {
_, ok := r.GetByValue(value)
return ok
}
// Validate checks if the given value is a valid enum item in the registry.
func (r *Registry[T]) Validate(value T) error {
if !r.Contains(value) {
return fmt.Errorf("invalid enum value: %v", value)
}
return nil
}
// ValidateAll checks if all given values are valid enum items in the registry.
func (r *Registry[T]) ValidateAll(values ...T) error {
for _, value := range values {
if err := r.Validate(value); err != nil {
return err
}
}
return nil
}
// Size returns the number of enum items in the registry.
func (r *Registry[T]) Size() int {
r.mu.RLock()
defer r.mu.RUnlock()
return len(r.items)
}

163
enum/enum_test.go Normal file
View File

@@ -0,0 +1,163 @@
// Copyright 2025 dudaodong@gmail.com. All rights resulterved.
// Use of this source code is governed by MIT license
package enum
import (
"testing"
"github.com/duke-git/lancet/v2/internal"
)
type Status int
const (
Unknown Status = iota
Active
Inactive
)
func TestNewItem(t *testing.T) {
t.Parallel()
assert := internal.NewAssert(t, "TestNewItem")
items := NewItems[Status](
Active, "Active",
Inactive, "Inactive",
)
assert.Equal(2, len(items))
assert.Equal(Active, items[0].Value())
assert.Equal("Active", items[0].Name())
assert.Equal(Inactive, items[1].Value())
assert.Equal("Inactive", items[1].Name())
}
func TestItem_Valid(t *testing.T) {
t.Parallel()
assert := internal.NewAssert(t, "TestItem_Valid")
item := NewItem(Active, "Active")
assert.Equal(true, item.Valid())
invalidItem := NewItem(Unknown, "")
assert.Equal(false, invalidItem.Valid())
}
func TestItem_MarshalJSON(t *testing.T) {
t.Parallel()
assert := internal.NewAssert(t, "TestItem_MarshalJSON")
item := NewItem(Active, "Active")
data, err := item.MarshalJSON()
assert.IsNil(err)
assert.Equal("{\"name\":\"Active\",\"value\":1}", string(data))
var unmarshaledItem Item[Status]
err = unmarshaledItem.UnmarshalJSON(data)
assert.IsNil(err)
assert.Equal(item.Value(), unmarshaledItem.Value())
assert.Equal(item.Name(), unmarshaledItem.Name())
}
func TestRegistry_AddAndGet(t *testing.T) {
t.Parallel()
assert := internal.NewAssert(t, "TestRegistry_AddAndGet")
registry := NewRegistry[Status]()
item1 := NewItem(Active, "Active")
item2 := NewItem(Inactive, "Inactive")
registry.Add(item1, item2)
assert.Equal(2, registry.Size())
item, ok := registry.GetByValue(Active)
assert.Equal(true, ok)
assert.Equal("Active", item.Name())
item, ok = registry.GetByName("Inactive")
assert.Equal(true, ok)
assert.Equal(Inactive, item.Value())
}
func TestRegistry_Remove(t *testing.T) {
t.Parallel()
assert := internal.NewAssert(t, "TestRegistry_Remove")
registry := NewRegistry[Status]()
item1 := NewItem(Active, "Active")
item2 := NewItem(Inactive, "Inactive")
registry.Add(item1, item2)
assert.Equal(2, registry.Size())
removed := registry.Remove(Active)
assert.Equal(true, removed)
assert.Equal(1, registry.Size())
_, ok := registry.GetByValue(Active)
assert.Equal(false, ok)
}
func TestRegistry_Update(t *testing.T) {
t.Parallel()
assert := internal.NewAssert(t, "TestRegistry_Update")
registry := NewRegistry[Status]()
item1 := NewItem(Active, "Active")
registry.Add(item1)
updated := registry.Update(Active, "Activated")
assert.Equal(true, updated)
item, ok := registry.GetByValue(Active)
assert.Equal(true, ok)
assert.Equal("Activated", item.Name())
}
func TestRegistry_Contains(t *testing.T) {
t.Parallel()
assert := internal.NewAssert(t, "TestRegistry_Contains")
registry := NewRegistry[Status]()
item1 := NewItem(Active, "Active")
registry.Add(item1)
assert.Equal(true, registry.Contains(Active))
assert.Equal(false, registry.Contains(Inactive))
}
func TestRegistry_Validate(t *testing.T) {
t.Parallel()
assert := internal.NewAssert(t, "TestRegistry_Validate")
registry := NewRegistry[Status]()
item1 := NewItem(Active, "Active")
item2 := NewItem(Inactive, "Inactive")
registry.Add(item1, item2)
err := registry.Validate(Active)
assert.IsNil(err)
err = registry.Validate(Inactive)
assert.IsNil(err)
err = registry.Validate(Unknown)
assert.IsNotNil(err)
}
func TestRegistry_ValidateAll(t *testing.T) {
t.Parallel()
assert := internal.NewAssert(t, "TestRegistry_ValidateAll")
registry := NewRegistry[Status]()
item1 := NewItem(Active, "Active")
item2 := NewItem(Inactive, "Inactive")
registry.Add(item1, item2)
err := registry.ValidateAll(Active, Inactive)
assert.IsNil(err)
err = registry.ValidateAll(Active, Unknown)
assert.IsNotNil(err)
}