diff --git a/concurrency/channel.go b/concurrency/channel.go index a7e3b68..13617a0 100644 --- a/concurrency/channel.go +++ b/concurrency/channel.go @@ -4,6 +4,8 @@ // Package concurrency contain some functions to support concurrent programming. eg, goroutine, channel, async. package concurrency +import "context" + // Channel is a logic object which can generate or manipulate go channel // all methods of Channel are in the book tilted《Concurrency in Go》 type Channel struct { @@ -15,7 +17,7 @@ func NewChannel() *Channel { } // Generate a data of type any chan, put param `values` into the chan -func (c *Channel) Generate(done <-chan any, values ...any) <-chan any { +func (c *Channel) Generate(ctx context.Context, values ...any) <-chan any { dataStream := make(chan any) go func() { @@ -23,7 +25,7 @@ func (c *Channel) Generate(done <-chan any, values ...any) <-chan any { for _, v := range values { select { - case <-done: + case <-ctx.Done(): return case dataStream <- v: } @@ -35,7 +37,7 @@ func (c *Channel) Generate(done <-chan any, values ...any) <-chan any { // Repeat return a data of type any chan, put param `values` into the chan repeatly, // until close the `done` chan -func (c *Channel) Repeat(done <-chan any, values ...any) <-chan any { +func (c *Channel) Repeat(ctx context.Context, values ...any) <-chan any { dataStream := make(chan any) go func() { @@ -43,7 +45,7 @@ func (c *Channel) Repeat(done <-chan any, values ...any) <-chan any { for { for _, v := range values { select { - case <-done: + case <-ctx.Done(): return case dataStream <- v: } @@ -55,14 +57,14 @@ func (c *Channel) Repeat(done <-chan any, values ...any) <-chan any { // RepeatFn return a chan, excutes fn repeatly, and put the result into retruned chan // until close the `done` channel -func (c *Channel) RepeatFn(done <-chan any, fn func() any) <-chan any { +func (c *Channel) RepeatFn(ctx context.Context, fn func() any) <-chan any { dataStream := make(chan any) go func() { defer close(dataStream) for { select { - case <-done: + case <-ctx.Done(): return case dataStream <- fn(): } @@ -72,7 +74,7 @@ func (c *Channel) RepeatFn(done <-chan any, fn func() any) <-chan any { } // Take return a chan whose values are tahken from another chan -func (c *Channel) Take(done <-chan any, valueStream <-chan any, number int) <-chan any { +func (c *Channel) Take(ctx context.Context, valueStream <-chan any, number int) <-chan any { takeStream := make(chan any) go func() { @@ -80,7 +82,7 @@ func (c *Channel) Take(done <-chan any, valueStream <-chan any, number int) <-ch for i := 0; i < number; i++ { select { - case <-done: + case <-ctx.Done(): return case takeStream <- <-valueStream: } diff --git a/concurrency/channel_test.go b/concurrency/channel_test.go index 289e07d..a07e4b2 100644 --- a/concurrency/channel_test.go +++ b/concurrency/channel_test.go @@ -1,6 +1,7 @@ package concurrency import ( + "context" "testing" "github.com/duke-git/lancet/v2/internal" @@ -9,11 +10,11 @@ import ( func TestGenerate(t *testing.T) { assert := internal.NewAssert(t, "TestGenerate") - done := make(chan any) - defer close(done) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() c := NewChannel() - intStream := c.Generate(done, 1, 2, 3) + intStream := c.Generate(ctx, 1, 2, 3) // for v := range intStream { // t.Log(v) //1, 2, 3 @@ -26,11 +27,11 @@ func TestGenerate(t *testing.T) { func TestRepeat(t *testing.T) { assert := internal.NewAssert(t, "TestRepeat") - done := make(chan any) - defer close(done) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() c := NewChannel() - intStream := c.Take(done, c.Repeat(done, 1, 2), 5) + intStream := c.Take(ctx, c.Repeat(ctx, 1, 2), 5) // for v := range intStream { // t.Log(v) //1, 2, 1, 2, 1 @@ -45,15 +46,15 @@ func TestRepeat(t *testing.T) { func TestRepeatFn(t *testing.T) { assert := internal.NewAssert(t, "TestRepeatFn") - done := make(chan any) - defer close(done) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() fn := func() any { s := "a" return s } c := NewChannel() - dataStream := c.Take(done, c.RepeatFn(done, fn), 3) + dataStream := c.Take(ctx, c.RepeatFn(ctx, fn), 3) // for v := range dataStream { // t.Log(v) //a, a, a @@ -67,8 +68,8 @@ func TestRepeatFn(t *testing.T) { func TestTake(t *testing.T) { assert := internal.NewAssert(t, "TestRepeat") - done := make(chan any) - defer close(done) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() numbers := make(chan any, 5) numbers <- 1 @@ -79,7 +80,7 @@ func TestTake(t *testing.T) { defer close(numbers) c := NewChannel() - intStream := c.Take(done, numbers, 3) + intStream := c.Take(ctx, numbers, 3) // for v := range intStream { // t.Log(v) //1, 2, 3