diff --git a/iterator/iterator.go b/iterator/iterator.go index 6b62ea1..1acda7e 100644 --- a/iterator/iterator.go +++ b/iterator/iterator.go @@ -11,6 +11,8 @@ package iterator import ( + "context" + "golang.org/x/exp/constraints" ) @@ -102,21 +104,15 @@ func (iter *sliceIterator[T]) HasNext() bool { func (iter *sliceIterator[T]) Next() (T, bool) { iter.index++ + ok := iter.index >= 0 && iter.index < len(iter.slice) + var item T if ok { item = iter.slice[iter.index] } - return item, ok - // if len(iter.slice) == 0 { - // var zero T - // return zero, false - // } - // iter.index++ - // item := iter.slice[0] - // iter.slice = iter.slice[1:] - // return item, true + return item, ok } // Prev implements PrevIterator. @@ -171,3 +167,41 @@ func (iter *rangeIterator[T]) Next() (T, bool) { iter.start += iter.step return num, true } + +// FromRange creates a iterator which returns the numeric range between start inclusive and end +// exclusive by the step size. start should be less than end, step shoud be positive. +func FromChannel[T any](channel <-chan T) Iterator[T] { + return &channelIterator[T]{channel: channel} +} + +type channelIterator[T any] struct { + channel <-chan T +} + +func (iter *channelIterator[T]) Next() (T, bool) { + item, ok := <-iter.channel + return item, ok +} + +func (iter *channelIterator[T]) HasNext() bool { + return len(iter.channel) == 0 +} + +// ToChannel create a new goroutine to pull items from the channel iterator to the returned channel. +func ToChannel[T any](ctx context.Context, iter Iterator[T], buffer int) <-chan T { + result := make(chan T, buffer) + + go func() { + defer close(result) + + for item, ok := iter.Next(); ok; item, ok = iter.Next() { + select { + case result <- item: + case <-ctx.Done(): + return + } + } + }() + + return result +} diff --git a/iterator/iterator_test.go b/iterator/iterator_test.go index b762b5f..b29cae9 100644 --- a/iterator/iterator_test.go +++ b/iterator/iterator_test.go @@ -5,6 +5,7 @@ package iterator import ( + "context" "testing" "github.com/duke-git/lancet/v2/internal" @@ -47,6 +48,15 @@ func TestSliceIterator(t *testing.T) { assert.Equal(false, ok) }) + t.Run("slice iterator ToSlice: ", func(t *testing.T) { + iter := FromSlice([]int{1, 2, 3, 4}) + item, _ := iter.Next() + assert.Equal(1, item) + + data := ToSlice(iter) + assert.Equal([]int{2, 3, 4}, data) + }) + } func TestRangeIterator(t *testing.T) { @@ -73,3 +83,21 @@ func TestRangeIterator(t *testing.T) { }) } + +func TestChannelIterator(t *testing.T) { + assert := internal.NewAssert(t, "TestRangeIterator") + + iter := FromSlice([]int{1, 2, 3, 4}) + + ctx, cancel := context.WithCancel(context.Background()) + iter = FromChannel(ToChannel(ctx, iter, 0)) + item, ok := iter.Next() + assert.Equal(1, item) + assert.Equal(true, ok) + assert.Equal(true, iter.HasNext()) + + cancel() + + _, ok = iter.Next() + assert.Equal(false, ok) +}