diff --git a/concurrency/channel.go b/concurrency/channel.go index 13617a0..418e7d3 100644 --- a/concurrency/channel.go +++ b/concurrency/channel.go @@ -4,7 +4,10 @@ // Package concurrency contain some functions to support concurrent programming. eg, goroutine, channel, async. package concurrency -import "context" +import ( + "context" + "sync" +) // Channel is a logic object which can generate or manipulate go channel // all methods of Channel are in the book tilted《Concurrency in Go》 @@ -91,3 +94,33 @@ func (c *Channel) Take(ctx context.Context, valueStream <-chan any, number int) return takeStream } + +// FanIn merge multiple channels into one channel +func (c *Channel) FanIn(ctx context.Context, channels ...<-chan any) <-chan any { + var wg sync.WaitGroup + multiplexedStream := make(chan any) + + multiplex := func(c <-chan any) { + defer wg.Done() + + for i := range c { + select { + case <-ctx.Done(): + return + case multiplexedStream <- i: + } + } + } + + wg.Add(len(channels)) + for _, c := range channels { + go multiplex(c) + } + + go func() { + wg.Wait() + close(multiplexedStream) + }() + + return multiplexedStream +} diff --git a/concurrency/channel_test.go b/concurrency/channel_test.go index a07e4b2..e8cc0cc 100644 --- a/concurrency/channel_test.go +++ b/concurrency/channel_test.go @@ -66,7 +66,7 @@ func TestRepeatFn(t *testing.T) { } func TestTake(t *testing.T) { - assert := internal.NewAssert(t, "TestRepeat") + assert := internal.NewAssert(t, "TestTake") ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -82,11 +82,29 @@ func TestTake(t *testing.T) { c := NewChannel() intStream := c.Take(ctx, numbers, 3) - // for v := range intStream { - // t.Log(v) //1, 2, 3 - // } - assert.Equal(1, <-intStream) assert.Equal(2, <-intStream) assert.Equal(3, <-intStream) } + +func TestFanIn(t *testing.T) { + assert := internal.NewAssert(t, "TestFanIn") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := NewChannel() + channels := make([]<-chan any, 3) + + for i := 0; i < 3; i++ { + channels[i] = c.Take(ctx, c.Repeat(ctx, i), 3) + } + + mergedChannel := c.FanIn(ctx, channels...) + + for val := range mergedChannel { + t.Logf("\t%d\n", val) + } + + assert.Equal(1, 1) +}