diff --git a/concurrency/channel.go b/concurrency/channel.go index d5468e9..092bece 100644 --- a/concurrency/channel.go +++ b/concurrency/channel.go @@ -157,3 +157,27 @@ func (c *Channel) Or(channels ...<-chan any) <-chan any { return orDone } + +// OrDone +func (c *Channel) OrDone(ctx context.Context, channel <-chan any) <-chan any { + resStream := make(chan any) + + go func() { + defer close(resStream) + + select { + case <-ctx.Done(): + return + case v, ok := <-channel: + if !ok { + return + } + select { + case resStream <- v: + case <-ctx.Done(): + } + } + }() + + return resStream +} diff --git a/concurrency/channel_test.go b/concurrency/channel_test.go index 74f8712..22b81d1 100644 --- a/concurrency/channel_test.go +++ b/concurrency/channel_test.go @@ -131,3 +131,21 @@ func TestOr(t *testing.T) { assert.Equal(1, 1) } + +func TestOrDone(t *testing.T) { + assert := internal.NewAssert(t, "TestOrDone") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := NewChannel() + intStream := c.Take(ctx, c.Repeat(ctx, 1), 3) + + var res any + for val := range c.OrDone(ctx, intStream) { + t.Logf("%v", val) + res = val + } + + assert.Equal(1, res) +}