diff --git a/compare/compare.go b/compare/compare.go new file mode 100644 index 0000000..8011f56 --- /dev/null +++ b/compare/compare.go @@ -0,0 +1,57 @@ +// Copyright 2023 dudaodong@gmail.com. All rights resulterved. +// Use of this source code is governed by MIT license + +// Package compare provides a lightweight comparison function on any type. +// reference: https://github.com/stretchr/testify +package compare + +import ( + "reflect" + "time" + + "github.com/duke-git/lancet/v2/convertor" +) + +// operator type +const ( + equal = "eq" + lessThan = "lt" + greaterThan = "gt" + lessOrEqual = "le" + greaterOrEqual = "ge" +) + +var ( + timeType = reflect.TypeOf(time.Time{}) + bytesType = reflect.TypeOf([]byte{}) +) + +// Equal checks if two values are equal or not +func Equal(left, right any) bool { + return compareValue(equal, left, right) +} + +func EqualValue(left, right any) bool { + ls, rs := convertor.ToString(left), convertor.ToString(right) + return ls == rs +} + +// LessThan checks if value `left` less than value `right`. +func LessThan(left, right any) bool { + return compareValue(lessThan, left, right) +} + +// GreaterThan checks if value `left` greater than value `right` +func GreaterThan(left, right any) bool { + return compareValue(greaterThan, left, right) +} + +// LessOrEqual checks if value `left` less than or equal to value `right` +func LessOrEqual(left, right any) bool { + return compareValue(lessOrEqual, left, right) +} + +// GreaterOrEqual checks if value `left` greater than or equal to value `right` +func GreaterOrEqual(left, right any) bool { + return compareValue(greaterOrEqual, left, right) +} diff --git a/compare/compare_internal.go b/compare/compare_internal.go new file mode 100644 index 0000000..52c00ae --- /dev/null +++ b/compare/compare_internal.go @@ -0,0 +1,323 @@ +package compare + +import ( + "bytes" + "encoding/json" + "reflect" + "time" + + "github.com/duke-git/lancet/v2/convertor" +) + +func compareValue(operator string, left, right any) bool { + leftType, rightType := reflect.TypeOf(left), reflect.TypeOf(right) + + if leftType.Kind() != rightType.Kind() { + return false + } + + switch leftType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.Bool, reflect.String: + return compareBasicValue(operator, left, right) + + case reflect.Struct, reflect.Slice, reflect.Map: + return compareRefValue(operator, left, right, leftType.Kind()) + } + + return false +} + +func compareRefValue(operator string, leftObj, rightObj any, kind reflect.Kind) bool { + leftVal, rightVal := reflect.ValueOf(leftObj), reflect.ValueOf(rightObj) + + switch kind { + case reflect.Struct: + + // compare time + if leftVal.CanConvert(timeType) { + timeObj1, ok := leftObj.(time.Time) + if !ok { + timeObj1 = leftVal.Convert(timeType).Interface().(time.Time) + } + + timeObj2, ok := rightObj.(time.Time) + if !ok { + timeObj2 = rightVal.Convert(timeType).Interface().(time.Time) + } + + return compareBasicValue(operator, timeObj1.UnixNano(), timeObj2.UnixNano()) + } + + // for other struct type, only process equal operator + switch operator { + case equal: + return objectsAreEqualValues(leftObj, rightObj) + } + + case reflect.Slice: + // compare []byte + if leftVal.CanConvert(bytesType) { + bytesObj1, ok := leftObj.([]byte) + if !ok { + bytesObj1 = leftVal.Convert(bytesType).Interface().([]byte) + } + bytesObj2, ok := rightObj.([]byte) + if !ok { + bytesObj2 = rightVal.Convert(bytesType).Interface().([]byte) + } + + switch operator { + case equal: + if bytes.Compare(bytesObj1, bytesObj2) == 0 { + return true + } + case lessThan: + if bytes.Compare(bytesObj1, bytesObj2) == -1 { + return true + } + case greaterThan: + if bytes.Compare(bytesObj1, bytesObj2) == 1 { + return true + } + case lessOrEqual: + if bytes.Compare(bytesObj1, bytesObj2) <= 0 { + return true + } + case greaterOrEqual: + if bytes.Compare(bytesObj1, bytesObj2) >= 0 { + return true + } + } + + } + + // for other type slice, only process equal operator + switch operator { + case equal: + return reflect.DeepEqual(leftObj, rightObj) + } + + case reflect.Map: + // only process equal operator + switch operator { + case equal: + return reflect.DeepEqual(leftObj, rightObj) + } + } + + return false +} + +func objectsAreEqualValues(expected, actual interface{}) bool { + if objectsAreEqual(expected, actual) { + return true + } + + actualType := reflect.TypeOf(actual) + if actualType == nil { + return false + } + expectedValue := reflect.ValueOf(expected) + if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { + // Attempt comparison after type conversion + return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) + } + + return false +} + +func objectsAreEqual(expected, actual interface{}) bool { + if expected == nil || actual == nil { + return expected == actual + } + + exp, ok := expected.([]byte) + if !ok { + return reflect.DeepEqual(expected, actual) + } + + act, ok := actual.([]byte) + if !ok { + return false + } + if exp == nil || act == nil { + return exp == nil && act == nil + } + return bytes.Equal(exp, act) +} + +// compareBasic compare basic value: integer, float, string, bool +func compareBasicValue(operator string, leftValue, rightValue any) bool { + if leftValue == nil && rightValue == nil && operator == equal { + return true + } + + switch leftVal := leftValue.(type) { + case json.Number: + if left, err := leftVal.Float64(); err == nil { + switch rightVal := rightValue.(type) { + case json.Number: + if right, err := rightVal.Float64(); err == nil { + switch operator { + case equal: + if left == right { + return true + } + case lessThan: + if left < right { + return true + } + case greaterThan: + if left > right { + return true + } + case lessOrEqual: + if left <= right { + return true + } + case greaterOrEqual: + if left >= right { + return true + } + } + + } + + case float32, float64, int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64: + right, err := convertor.ToFloat(rightValue) + if err != nil { + return false + } + switch operator { + case equal: + if left == right { + return true + } + case lessThan: + if left < right { + return true + } + case greaterThan: + if left > right { + return true + } + case lessOrEqual: + if left <= right { + return true + } + case greaterOrEqual: + if left >= right { + return true + } + } + } + + } + + case float32, float64, int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64: + left, err := convertor.ToFloat(leftValue) + if err != nil { + return false + } + switch rightVal := rightValue.(type) { + case json.Number: + if right, err := rightVal.Float64(); err == nil { + switch operator { + case equal: + if left == right { + return true + } + case lessThan: + if left < right { + return true + } + case greaterThan: + if left > right { + return true + } + case lessOrEqual: + if left <= right { + return true + } + case greaterOrEqual: + if left >= right { + return true + } + } + } + case float32, float64, int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64: + right, err := convertor.ToFloat(rightValue) + if err != nil { + return false + } + + switch operator { + case equal: + if left == right { + return true + } + case lessThan: + if left < right { + return true + } + case greaterThan: + if left > right { + return true + } + case lessOrEqual: + if left <= right { + return true + } + case greaterOrEqual: + if left >= right { + return true + } + } + } + + case string: + left := leftVal + switch right := rightValue.(type) { + case string: + switch operator { + case equal: + if left == right { + return true + } + case lessThan: + if left < right { + return true + } + case greaterThan: + if left > right { + return true + } + case lessOrEqual: + if left <= right { + return true + } + case greaterOrEqual: + if left >= right { + return true + } + } + } + + case bool: + left := leftVal + switch right := rightValue.(type) { + case bool: + switch operator { + case equal: + if left == right { + return true + } + } + } + + } + + return false +} diff --git a/compare/compare_test.go b/compare/compare_test.go new file mode 100644 index 0000000..0628acf --- /dev/null +++ b/compare/compare_test.go @@ -0,0 +1,70 @@ +package compare + +import ( + "testing" + "time" + + "github.com/duke-git/lancet/v2/internal" +) + +func TestEqual(t *testing.T) { + assert := internal.NewAssert(t, "TestEqual") + + assert.Equal(true, Equal(1, 1)) + assert.Equal(true, Equal(int64(1), int64(1))) + assert.Equal(true, Equal("a", "a")) + assert.Equal(true, Equal(true, true)) + assert.Equal(true, Equal([]int{1, 2, 3}, []int{1, 2, 3})) + assert.Equal(true, Equal(map[int]string{1: "a", 2: "b"}, map[int]string{1: "a", 2: "b"})) + + assert.Equal(false, Equal(1, 2)) + assert.Equal(false, Equal(1, int64(1))) + assert.Equal(false, Equal("a", "b")) + assert.Equal(false, Equal(true, false)) + assert.Equal(false, Equal([]int{1, 2}, []int{1, 2, 3})) + assert.Equal(false, Equal(map[int]string{1: "a", 2: "b"}, map[int]string{1: "a"})) + + time1 := time.Now() + time2 := time1.Add(time.Second) + time3 := time1.Add(time.Second) + + assert.Equal(false, Equal(time1, time2)) + assert.Equal(true, Equal(time2, time3)) + + st1 := struct { + A string + B string + }{ + A: "a", + B: "b", + } + + st2 := struct { + A string + B string + }{ + A: "a", + B: "b", + } + + st3 := struct { + A string + B string + }{ + A: "a1", + B: "b", + } + + assert.Equal(true, Equal(st1, st2)) + assert.Equal(false, Equal(st1, st3)) +} + +func TestEqualValue(t *testing.T) { + assert := internal.NewAssert(t, "TestEqualValue") + + assert.Equal(true, EqualValue(1, 1)) + assert.Equal(true, EqualValue(int(1), int64(1))) + assert.Equal(true, EqualValue(1, "1")) + + assert.Equal(false, EqualValue(1, "2")) +}