github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/workerpool/async_pool_test.go (about)

     1  // Copyright 2020 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package workerpool
    15  
    16  import (
    17  	"context"
    18  	"math/rand"
    19  	"sync"
    20  	"sync/atomic"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/pingcap/errors"
    25  	"github.com/pingcap/log"
    26  	"github.com/stretchr/testify/require"
    27  	"golang.org/x/sync/errgroup"
    28  )
    29  
    30  func TestBasic(t *testing.T) {
    31  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
    32  	defer cancel()
    33  
    34  	errg, ctx := errgroup.WithContext(ctx)
    35  
    36  	pool := newDefaultAsyncPoolImpl(4)
    37  	errg.Go(func() error {
    38  		return pool.Run(ctx)
    39  	})
    40  
    41  	var sum int32
    42  	var wg sync.WaitGroup
    43  	for i := 0; i < 100; i++ {
    44  		wg.Add(1)
    45  		finalI := i
    46  		err := pool.Go(ctx, func() {
    47  			time.Sleep(time.Millisecond * time.Duration(rand.Int()%100))
    48  			atomic.AddInt32(&sum, int32(finalI+1))
    49  			wg.Done()
    50  		})
    51  		require.Nil(t, err)
    52  	}
    53  
    54  	wg.Wait()
    55  	require.Equal(t, sum, int32(5050))
    56  
    57  	cancel()
    58  	err := errg.Wait()
    59  	require.Regexp(t, "context canceled", err)
    60  }
    61  
    62  func TestEventuallyRun(t *testing.T) {
    63  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
    64  	defer cancel()
    65  
    66  	errg, ctx := errgroup.WithContext(ctx)
    67  	loopCtx, cancelLoop := context.WithCancel(ctx)
    68  	defer cancelLoop()
    69  
    70  	pool := newDefaultAsyncPoolImpl(4)
    71  	errg.Go(func() error {
    72  		defer cancelLoop()
    73  		for i := 0; i < 10; i++ {
    74  			log.Info("running pool")
    75  			err := runForDuration(ctx, time.Millisecond*500, func(ctx context.Context) error {
    76  				return pool.Run(ctx)
    77  			})
    78  			if err != nil {
    79  				return errors.Trace(err)
    80  			}
    81  		}
    82  		return nil
    83  	})
    84  
    85  	var sum int32
    86  	var sumExpected int32
    87  loop:
    88  	for i := 0; ; i++ {
    89  		select {
    90  		case <-loopCtx.Done():
    91  			break loop
    92  		default:
    93  		}
    94  		finalI := i
    95  		err := pool.Go(loopCtx, func() {
    96  			if rand.Int()%128 == 0 {
    97  				time.Sleep(2 * time.Millisecond)
    98  			}
    99  			atomic.AddInt32(&sum, int32(finalI+1))
   100  		})
   101  		if err != nil {
   102  			require.Regexp(t, "context canceled", err.Error())
   103  		} else {
   104  			sumExpected += int32(i + 1)
   105  		}
   106  	}
   107  
   108  	cancel()
   109  	err := errg.Wait()
   110  	require.Nil(t, err)
   111  	require.Equal(t, sum, sumExpected)
   112  }
   113  
   114  func runForDuration(ctx context.Context, duration time.Duration, f func(ctx context.Context) error) error {
   115  	timedCtx, cancel := context.WithTimeout(ctx, duration)
   116  	defer cancel()
   117  
   118  	errCh := make(chan error)
   119  	go func() {
   120  		errCh <- f(timedCtx)
   121  	}()
   122  
   123  	select {
   124  	case <-ctx.Done():
   125  		return ctx.Err()
   126  	case err := <-errCh:
   127  		if errors.Cause(err) == context.DeadlineExceeded {
   128  			return nil
   129  		}
   130  		return errors.Trace(err)
   131  	}
   132  }