diff --git a/slice/slice.go b/slice/slice.go index 741d5ce..1b9ad02 100644 --- a/slice/slice.go +++ b/slice/slice.go @@ -523,18 +523,38 @@ func SortByField(slice interface{}, field string, sortType ...string) error { } // Create a less function based on the field's kind. - var less func(a, b reflect.Value) bool + var compare func(a, b reflect.Value) bool switch sf.Type.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - less = func(a, b reflect.Value) bool { return a.Int() < b.Int() } + if len(sortType) > 0 && sortType[0] == "desc" { + compare = func(a, b reflect.Value) bool { return a.Int() > b.Int() } + } else { + compare = func(a, b reflect.Value) bool { return a.Int() < b.Int() } + } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() } + if len(sortType) > 0 && sortType[0] == "desc" { + compare = func(a, b reflect.Value) bool { return a.Uint() > b.Uint() } + } else { + compare = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() } + } case reflect.Float32, reflect.Float64: - less = func(a, b reflect.Value) bool { return a.Float() < b.Float() } + if len(sortType) > 0 && sortType[0] == "desc" { + compare = func(a, b reflect.Value) bool { return a.Float() > b.Float() } + } else { + compare = func(a, b reflect.Value) bool { return a.Float() < b.Float() } + } case reflect.String: - less = func(a, b reflect.Value) bool { return a.String() < b.String() } + if len(sortType) > 0 && sortType[0] == "desc" { + compare = func(a, b reflect.Value) bool { return a.String() > b.String() } + } else { + compare = func(a, b reflect.Value) bool { return a.String() < b.String() } + } case reflect.Bool: - less = func(a, b reflect.Value) bool { return !a.Bool() && b.Bool() } + if len(sortType) > 0 && sortType[0] == "desc" { + compare = func(a, b reflect.Value) bool { return a.Bool() && !b.Bool() } + } else { + compare = func(a, b reflect.Value) bool { return !a.Bool() && b.Bool() } + } default: return fmt.Errorf("field type %s not supported", sf.Type) } @@ -548,24 +568,12 @@ func SortByField(slice interface{}, field string, sortType ...string) error { } a = a.FieldByIndex(sf.Index) b = b.FieldByIndex(sf.Index) - return less(a, b) + return compare(a, b) }) - if sortType[0] == "desc" { - reverseSlice(slice) - } return nil } -// todo remove after migration -func reverseSlice(slice interface{}) { - sv := sliceValue(slice) - swp := reflect.Swapper(sv.Interface()) - for i, j := 0, sv.Len()-1; i < j; i, j = i+1, j-1 { - swp(i, j) - } -} - // Without creates a slice excluding all given values func Without[T comparable](slice []T, values ...T) []T { if len(values) == 0 || len(slice) == 0 { diff --git a/slice/slice_test.go b/slice/slice_test.go index 9a680b6..d4f0641 100644 --- a/slice/slice_test.go +++ b/slice/slice_test.go @@ -381,7 +381,7 @@ func TestDifference(t *testing.T) { assert.Equal([]int{1, 2, 3}, Difference(s1, s2)) } -func TestSortByField(t *testing.T) { +func TestSortByFieldDesc(t *testing.T) { assert := internal.NewAssert(t, "TestSortByField") type student struct { @@ -406,6 +406,31 @@ func TestSortByField(t *testing.T) { assert.Equal(students, studentsOfSortByAge) } +func TestSortByFieldAsc(t *testing.T) { + assert := internal.NewAssert(t, "TestSortByField") + + type student struct { + name string + age int + } + students := []student{ + {"a", 10}, + {"b", 15}, + {"c", 5}, + {"d", 6}, + } + studentsOfSortByAge := []student{ + {"c", 5}, + {"d", 6}, + {"a", 10}, + {"b", 15}, + } + + err := SortByField(students, "age") + assert.IsNil(err) + + assert.Equal(students, studentsOfSortByAge) +} func TestWithout(t *testing.T) { assert := internal.NewAssert(t, "TestWithout")