github.com/reusee/pr2@v0.0.0-20230630035947-72a20ff5e864/wait_group_test.go (about)

     1  package pr2
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"sync/atomic"
     7  	"testing"
     8  
     9  	"github.com/reusee/e5"
    10  )
    11  
    12  func TestWaitGroup(t *testing.T) {
    13  
    14  	t.Run("single", func(t *testing.T) {
    15  		wg := NewWaitGroup(context.Background())
    16  		n := 128
    17  		var c int64
    18  		for i := 0; i < n; i++ {
    19  			wg.Go(func() {
    20  				<-wg.Done()
    21  				atomic.AddInt64(&c, 1)
    22  			})
    23  		}
    24  		wg.Cancel()
    25  		wg.Wait()
    26  		if c != int64(n) {
    27  			t.Fatal()
    28  		}
    29  	})
    30  
    31  	t.Run("tree", func(t *testing.T) {
    32  		wg := NewWaitGroup(context.Background())
    33  		var c int64
    34  		n := 128
    35  		m := 8
    36  		for i := 0; i < m; i++ {
    37  			subWg := NewWaitGroup(wg)
    38  			go func() {
    39  				for i := 0; i < n; i++ {
    40  					subWg.Go(func() {
    41  						<-subWg.Done()
    42  						atomic.AddInt64(&c, 1)
    43  					})
    44  				}
    45  				subWg.Cancel()
    46  				subWg.Wait()
    47  			}()
    48  		}
    49  		wg.Wait()
    50  		if c != int64(n*m) {
    51  			t.Fatal()
    52  		}
    53  	})
    54  
    55  	t.Run("cancel", func(t *testing.T) {
    56  		var num int
    57  		wg := NewWaitGroup(context.Background())
    58  		wg.Go(func() {
    59  			<-wg.Done()
    60  			num++
    61  		})
    62  		wg.Cancel()
    63  		err := wg.Err()
    64  		if !errors.Is(err, context.Canceled) {
    65  			t.Fatal()
    66  		}
    67  		func() {
    68  			var err error
    69  			defer func() {
    70  				if err == nil {
    71  					t.Fatal("shoule throw error")
    72  				}
    73  				if !errors.Is(err, context.Canceled) {
    74  					t.Fatal()
    75  				}
    76  				wg.Wait()
    77  				if num != 1 {
    78  					t.Fatal()
    79  				}
    80  			}()
    81  			defer e5.Handle(&err)
    82  			wg.Add()
    83  		}()
    84  	})
    85  
    86  	t.Run("get", func(t *testing.T) {
    87  		wg := NewWaitGroup(context.Background())
    88  		wg2 := GetWaitGroup(wg)
    89  		if wg2 != wg {
    90  			t.Fatal()
    91  		}
    92  
    93  		wg2 = GetWaitGroup(context.Background())
    94  		if wg2 != nil {
    95  			t.Fatal()
    96  		}
    97  	})
    98  
    99  }