github.com/reusee/pr2@v0.0.0-20230630035947-72a20ff5e864/consume_test.go (about)

     1  package pr2
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"sync"
     8  	"sync/atomic"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/reusee/e5"
    13  )
    14  
    15  func TestConsume(
    16  	t *testing.T,
    17  ) {
    18  
    19  	t.Run("normal", func(t *testing.T) {
    20  		var c int64
    21  		put, wait := Consume(
    22  			context.Background(),
    23  			8,
    24  			func(_ int, _ int64) error {
    25  				atomic.AddInt64(&c, 1)
    26  				return nil
    27  			},
    28  		)
    29  		var numPut int64
    30  		for i := 0; i < 8; i++ {
    31  			if put(int64(i)) {
    32  				numPut++
    33  			}
    34  		}
    35  		if err := wait(true); err != nil {
    36  			t.Fatal(err)
    37  		}
    38  		if c != 8 {
    39  			t.Fatalf("got %d", c)
    40  		}
    41  		if c != numPut {
    42  			t.Fatalf("got %d, %d", c, numPut)
    43  		}
    44  	})
    45  
    46  	t.Run("multiple wait", func(t *testing.T) {
    47  		var c int64
    48  		put, wait := Consume(
    49  			context.Background(),
    50  			8,
    51  			func(_ int, _ int64) error {
    52  				atomic.AddInt64(&c, 1)
    53  				return nil
    54  			},
    55  		)
    56  
    57  		var numPut int64
    58  		for i := 0; i < 8; i++ {
    59  			if put(int64(i)) {
    60  				numPut++
    61  			}
    62  		}
    63  		if err := wait(false); err != nil {
    64  			t.Fatal(err)
    65  		}
    66  		if c != 8 {
    67  			t.Fatalf("got %d", c)
    68  		}
    69  		if c != numPut {
    70  			t.Fatalf("got %d, %d", c, numPut)
    71  		}
    72  
    73  		for i := 0; i < 8; i++ {
    74  			if put(int64(i)) {
    75  				numPut++
    76  			}
    77  		}
    78  		if err := wait(false); err != nil {
    79  			t.Fatal(err)
    80  		}
    81  		if c != 16 {
    82  			t.Fatalf("got %d", 16)
    83  		}
    84  		if c != numPut {
    85  			t.Fatalf("got %d, %d", c, numPut)
    86  		}
    87  	})
    88  
    89  	t.Run("cancel before put", func(t *testing.T) {
    90  		var c int64
    91  		wg := NewWaitGroup(context.Background())
    92  		put, wait := Consume(
    93  			wg,
    94  			8,
    95  			func(_ int, _ int64) error {
    96  				atomic.AddInt64(&c, 1)
    97  				return nil
    98  			},
    99  		)
   100  		wg.Cancel()
   101  		var numPut int
   102  		for i := 0; i < 8; i++ {
   103  			if put(int64(i)) {
   104  				numPut++
   105  			}
   106  		}
   107  		if err := wait(true); err != nil {
   108  			t.Fatal(err)
   109  		}
   110  		if c != 0 {
   111  			t.Fatalf("got %d", c)
   112  		}
   113  		if numPut != 0 {
   114  			t.Fatalf("got %d", numPut)
   115  		}
   116  	})
   117  
   118  	t.Run("close before put", func(t *testing.T) {
   119  		var c int64
   120  		put, wait := Consume(
   121  			context.Background(),
   122  			8,
   123  			func(_ int, _ int64) error {
   124  				atomic.AddInt64(&c, 1)
   125  				return nil
   126  			},
   127  		)
   128  		if err := wait(true); err != nil {
   129  			t.Fatal(err)
   130  		}
   131  		var numPut int64
   132  		for i := 0; i < 8; i++ {
   133  			if put(int64(i)) {
   134  				numPut++
   135  			}
   136  		}
   137  		if err := wait(true); err != nil {
   138  			t.Fatal(err)
   139  		}
   140  		if c != 0 {
   141  			t.Fatalf("got %d", c)
   142  		}
   143  		if c != numPut {
   144  			t.Fatalf("got %d", numPut)
   145  		}
   146  	})
   147  
   148  	t.Run("concurrent put and close", func(t *testing.T) {
   149  		var c int64
   150  		put, wait := Consume(
   151  			context.Background(),
   152  			8,
   153  			func(_ int, _ int64) error {
   154  				atomic.AddInt64(&c, 1)
   155  				return nil
   156  			},
   157  		)
   158  		var numPut int64
   159  		wg := new(sync.WaitGroup)
   160  		for i := 0; i < 128; i++ {
   161  			i := i
   162  			wg.Add(1)
   163  			go func() {
   164  				defer wg.Done()
   165  				if put(int64(i)) {
   166  					atomic.AddInt64(&numPut, 1)
   167  				}
   168  			}()
   169  		}
   170  		if err := wait(true); err != nil {
   171  			t.Fatal(err)
   172  		}
   173  		wg.Wait()
   174  		if a, b := atomic.LoadInt64(&numPut), atomic.LoadInt64(&c); a != b {
   175  			t.Fatalf("got %d, %d", a, b)
   176  		}
   177  	})
   178  
   179  	t.Run("concurrent put and cancel", func(t *testing.T) {
   180  		waitGroup := NewWaitGroup(context.Background())
   181  		var c int64
   182  		put, wait := Consume(
   183  			waitGroup,
   184  			8,
   185  			func(_ int, _ int64) error {
   186  				atomic.AddInt64(&c, 1)
   187  				return nil
   188  			},
   189  		)
   190  		var numPut int64
   191  		wg := new(sync.WaitGroup)
   192  		for i := 0; i < 128; i++ {
   193  			i := i
   194  			wg.Add(1)
   195  			go func() {
   196  				defer wg.Done()
   197  				if put(int64(i)) {
   198  					atomic.AddInt64(&numPut, 1)
   199  				}
   200  			}()
   201  		}
   202  		waitGroup.Cancel()
   203  		if err := wait(true); err != nil {
   204  			t.Fatal(err)
   205  		}
   206  		wg.Wait()
   207  		if a, b := atomic.LoadInt64(&numPut), atomic.LoadInt64(&c); a != b {
   208  			t.Fatalf("got %d, %d", a, b)
   209  		}
   210  	})
   211  
   212  	t.Run("fn error", func(t *testing.T) {
   213  		var c int64
   214  		put, wait := Consume(
   215  			context.Background(),
   216  			8,
   217  			func(_ int, v int64) error {
   218  				atomic.AddInt64(&c, 1)
   219  				if v == 3 {
   220  					return errors.New("foo")
   221  				}
   222  				return nil
   223  			},
   224  		)
   225  		var numPut int64
   226  		for i := 0; i < 8; i++ {
   227  			if put(int64(i)) {
   228  				numPut++
   229  			}
   230  		}
   231  		err := wait(true)
   232  		if err == nil {
   233  			t.Fatal()
   234  		}
   235  		if err.Error() != "foo" {
   236  			t.Fatalf("got %v", err)
   237  		}
   238  		if numPut != c {
   239  			t.Fatalf("got %d, %d", numPut, c)
   240  		}
   241  		time.Sleep(time.Millisecond * 10)
   242  	})
   243  
   244  	t.Run("multiple fn error", func(t *testing.T) {
   245  		var c int64
   246  		put, wait := Consume(
   247  			context.Background(),
   248  			128,
   249  			func(_ int, v int64) error {
   250  				atomic.AddInt64(&c, 1)
   251  				if v > 1 {
   252  					return errors.New("foo")
   253  				}
   254  				return nil
   255  			},
   256  		)
   257  		var numPut int64
   258  		for i := 0; i < 128; i++ {
   259  			if put(int64(i)) {
   260  				numPut++
   261  			}
   262  		}
   263  		err := wait(true)
   264  		if err == nil {
   265  			t.Fatal()
   266  		}
   267  		if err.Error() != "foo" {
   268  			t.Fatalf("got %v", err)
   269  		}
   270  		if numPut != c {
   271  			t.Fatalf("got %d, %d", numPut, c)
   272  		}
   273  	})
   274  
   275  	t.Run("put in func", func(t *testing.T) {
   276  		var c int64
   277  		var put Put[int64]
   278  		put, wait := Consume(
   279  			context.Background(),
   280  			8,
   281  			func(_ int, n int64) error {
   282  				atomic.AddInt64(&c, 1)
   283  				if n == 0 {
   284  					return nil
   285  				}
   286  				put(n - 1)
   287  				return nil
   288  			},
   289  		)
   290  		put(2048)
   291  		if err := wait(false); err != nil {
   292  			t.Fatal(err)
   293  		}
   294  		if c != 2049 {
   295  			t.Fatalf("got %d", c)
   296  		}
   297  	})
   298  
   299  }
   300  
   301  func TestConsumeE4Error(t *testing.T) {
   302  	put, wait := Consume(context.Background(), 8, func(_ int, v func() error) error {
   303  		return v()
   304  	})
   305  	put(func() error {
   306  		return func() error {
   307  			e5.Check(io.EOF)
   308  			return nil
   309  		}()
   310  	})
   311  	err := wait(true)
   312  	if !errors.Is(err, io.EOF) {
   313  		t.Fatal()
   314  	}
   315  }