github.com/grailbio/base@v0.0.11/sync/once/once_test.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package once
     6  
     7  import (
     8  	"errors"
     9  	"sync/atomic"
    10  	"testing"
    11  
    12  	"github.com/grailbio/base/traverse"
    13  )
    14  
    15  func TestTaskOnceConcurrency(t *testing.T) {
    16  	const (
    17  		N      = 10
    18  		resets = 2
    19  	)
    20  	var (
    21  		o     Task
    22  		count int32
    23  	)
    24  	for r := 0; r < resets; r++ {
    25  		err := traverse.Each(N, func(_ int) error {
    26  			return o.Do(func() error {
    27  				atomic.AddInt32(&count, 1)
    28  				return nil
    29  			})
    30  		})
    31  		if err != nil {
    32  			t.Fatal(err)
    33  		}
    34  		if got, want := atomic.LoadInt32(&count), int32(r+1); got != want {
    35  			t.Errorf("got %v, want %v", got, want)
    36  		}
    37  		if got, want := o.Done(), true; got != want {
    38  			t.Errorf("got %v, want %v", got, want)
    39  		}
    40  		o.Reset()
    41  		if got, want := o.Done(), false; got != want {
    42  			t.Errorf("got %v, want %v", got, want)
    43  		}
    44  	}
    45  }
    46  
    47  func TestMapOnceConcurrency(t *testing.T) {
    48  	const N = 10
    49  	var (
    50  		once  Map
    51  		count uint32
    52  	)
    53  	err := traverse.Each(N, func(jobIdx int) error {
    54  		return once.Do(123, func() error {
    55  			atomic.AddUint32(&count, 1)
    56  			return nil
    57  		})
    58  	})
    59  	if err != nil {
    60  		t.Fatal(err)
    61  	}
    62  	if got, want := count, uint32(1); got != want {
    63  		t.Errorf("got %v, want %v", got, want)
    64  	}
    65  }
    66  
    67  func TestTaskOnceError(t *testing.T) {
    68  	var (
    69  		once     Map
    70  		expected = errors.New("expected error")
    71  	)
    72  	err := once.Do(123, func() error { return expected })
    73  	if got, want := err, expected; got != want {
    74  		t.Errorf("got %v, want %v", got, want)
    75  	}
    76  	err = once.Do(123, func() error { panic("should not be called") })
    77  	if got, want := err, expected; got != want {
    78  		t.Errorf("got %v, want %v", got, want)
    79  	}
    80  	err = once.Do(124, func() error { return nil })
    81  	if err != nil {
    82  		t.Errorf("unexpected error %v", err)
    83  	}
    84  }