From e27df00fa8016550c454bd25a634a5b4b8b68ba5 Mon Sep 17 00:00:00 2001 From: Tuuuuuuuuu Date: Tue, 14 Jan 2025 16:19:45 +0800 Subject: [PATCH] make Bridge not block in the first stream that not closed (#288) * not block in the first channel * make Bridge not block in the first stream that not closed * Bridge with test --- concurrency/channel.go | 19 +++++++++++-------- concurrency/channel_example_test.go | 19 ++++++++++++------- concurrency/channel_test.go | 12 ++++++++---- 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/concurrency/channel.go b/concurrency/channel.go index c610971..876d14e 100644 --- a/concurrency/channel.go +++ b/concurrency/channel.go @@ -157,10 +157,10 @@ func (c *Channel[T]) Tee(ctx context.Context, in <-chan T) (<-chan T, <-chan T) // Play: https://go.dev/play/p/qmWSy1NVF-Y func (c *Channel[T]) Bridge(ctx context.Context, chanStream <-chan <-chan T) <-chan T { valStream := make(chan T) - go func() { defer close(valStream) - + wg := sync.WaitGroup{} + defer wg.Wait() for { var stream <-chan T select { @@ -169,19 +169,22 @@ func (c *Channel[T]) Bridge(ctx context.Context, chanStream <-chan <-chan T) <-c return } stream = maybeStream + wg.Add(1) case <-ctx.Done(): return } - for val := range c.OrDone(ctx, stream) { - select { - case valStream <- val: - case <-ctx.Done(): + go func() { + defer wg.Done() + for val := range c.OrDone(ctx, stream) { + select { + case valStream <- val: + case <-ctx.Done(): + } } - } + }() } }() - return valStream } diff --git a/concurrency/channel_example_test.go b/concurrency/channel_example_test.go index 9d4686a..dc72281 100644 --- a/concurrency/channel_example_test.go +++ b/concurrency/channel_example_test.go @@ -168,7 +168,8 @@ func ExampleChannel_Tee() { func ExampleChannel_Bridge() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - + m1 := make(map[int]int) + m2 := make(map[int]int) c := NewChannel[int]() genVals := func() <-chan <-chan int { out := make(chan (<-chan int)) @@ -177,6 +178,7 @@ func ExampleChannel_Bridge() { for i := 1; i <= 5; i++ { stream := make(chan int, 1) stream <- i + m1[i]++ close(stream) out <- stream } @@ -185,12 +187,15 @@ func ExampleChannel_Bridge() { } for v := range c.Bridge(ctx, genVals()) { - fmt.Println(v) + m2[v]++ + } + for k, v := range m1 { + fmt.Println(m2[k] == v) } // Output: - // 1 - // 2 - // 3 - // 4 - // 5 + // true + // true + // true + // true + // true } diff --git a/concurrency/channel_test.go b/concurrency/channel_test.go index e977bcf..4c9da23 100644 --- a/concurrency/channel_test.go +++ b/concurrency/channel_test.go @@ -169,7 +169,8 @@ func TestTee(t *testing.T) { func TestBridge(t *testing.T) { t.Parallel() assert := internal.NewAssert(t, "TestBridge") - + m1 := make(map[int]int) + m2 := make(map[int]int) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -181,6 +182,7 @@ func TestBridge(t *testing.T) { for i := 0; i < 10; i++ { stream := make(chan int, 1) stream <- i + m1[i]++ close(stream) chanStream <- stream } @@ -188,9 +190,11 @@ func TestBridge(t *testing.T) { return chanStream } - index := 0 for val := range c.Bridge(ctx, genVals()) { - assert.Equal(index, val) - index++ + m2[val]++ + } + + for k, v := range m1 { + assert.Equal(m2[k], v) } }