github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/util/syncutil/singleflight/singleflight_test.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  // Copyright 2013 The Go Authors. All rights reserved.
    12  // Use of this source code is governed by a BSD-style
    13  // license that can be found in licenses/BSD-golang.txt.
    14  
    15  // This code originated in Go's internal/singleflight package.
    16  
    17  package singleflight
    18  
    19  import (
    20  	"fmt"
    21  	"sync"
    22  	"sync/atomic"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/cockroachdb/errors"
    27  )
    28  
    29  func TestDo(t *testing.T) {
    30  	var g Group
    31  	v, _, err := g.Do("key", func() (interface{}, error) {
    32  		return "bar", nil
    33  	})
    34  	res := Result{Val: v, Err: err}
    35  	assertRes(t, res, false)
    36  }
    37  
    38  func TestDoChan(t *testing.T) {
    39  	var g Group
    40  	resC, leader := g.DoChan("key", func() (interface{}, error) {
    41  		return "bar", nil
    42  	})
    43  	if !leader {
    44  		t.Errorf("DoChan returned not leader, expected leader")
    45  	}
    46  	res := <-resC
    47  	assertRes(t, res, false)
    48  }
    49  
    50  func TestDoErr(t *testing.T) {
    51  	var g Group
    52  	someErr := errors.New("Some error")
    53  	v, _, err := g.Do("key", func() (interface{}, error) {
    54  		return nil, someErr
    55  	})
    56  	if !errors.Is(err, someErr) {
    57  		t.Errorf("Do error = %v; want someErr %v", err, someErr)
    58  	}
    59  	if v != nil {
    60  		t.Errorf("unexpected non-nil value %#v", v)
    61  	}
    62  }
    63  
// TestDoDupSuppress checks that concurrent Do calls for the same key are
// coalesced. n goroutines call Do while fn is blocked. Afterwards the test
// expects fn to have run at least once but fewer than n times.
func TestDoDupSuppress(t *testing.T) {
	var g Group
	// wg1 is released only after every goroutine has reached the line before
	// its Do call AND fn has been entered for the first time. That is n+1
	// Done calls in total: one per goroutine plus one from fn's first
	// invocation. wg2 waits for all n goroutines to return.
	var wg1, wg2 sync.WaitGroup
	// c carries the value fn returns. fn blocks on it until the test sends
	// "bar". The capacity of 1 lets each invocation put the value back for
	// any later invocation.
	c := make(chan string, 1)
	// calls counts fn invocations. fn may run on several goroutines, so the
	// counter is accessed atomically.
	var calls int32
	fn := func() (interface{}, error) {
		if atomic.AddInt32(&calls, 1) == 1 {
			// First invocation.
			wg1.Done()
		}
		v := <-c
		c <- v // pump; make available for any future calls

		time.Sleep(10 * time.Millisecond) // let more goroutines enter Do

		return v, nil
	}

	const n = 10
	// Accounts for the wg1.Done() performed by fn's first invocation.
	wg1.Add(1)
	for i := 0; i < n; i++ {
		wg1.Add(1)
		wg2.Add(1)
		go func() {
			defer wg2.Done()
			wg1.Done()
			v, _, err := g.Do("key", fn)
			if err != nil {
				t.Errorf("Do error: %v", err)
				return
			}
			if s, _ := v.(string); s != "bar" {
				t.Errorf("Do = %T %v; want %q", v, v, "bar")
			}
		}()
	}
	wg1.Wait()
	// At least one goroutine is in fn now and all of them have at
	// least reached the line before the Do.
	c <- "bar"
	wg2.Wait()
	// Fewer than n invocations means some concurrent callers received another
	// caller's in-flight result instead of running fn themselves.
	if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
		t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
	}
}
   109  
   110  func TestDoChanDupSuppress(t *testing.T) {
   111  	c := make(chan struct{})
   112  	fn := func() (interface{}, error) {
   113  		<-c
   114  		return "bar", nil
   115  	}
   116  
   117  	var g Group
   118  	resC1, leader1 := g.DoChan("key", fn)
   119  	if !leader1 {
   120  		t.Errorf("DoChan returned not leader, expected leader")
   121  	}
   122  
   123  	resC2, leader2 := g.DoChan("key", fn)
   124  	if leader2 {
   125  		t.Errorf("DoChan returned leader, expected not leader")
   126  	}
   127  
   128  	close(c)
   129  	for _, res := range []Result{<-resC1, <-resC2} {
   130  		assertRes(t, res, true)
   131  	}
   132  }
   133  
   134  func TestNumCalls(t *testing.T) {
   135  	c := make(chan struct{})
   136  	fn := func() (interface{}, error) {
   137  		<-c
   138  		return "bar", nil
   139  	}
   140  	var g Group
   141  	assertNumCalls(t, g.NumCalls("key"), 0)
   142  	resC1, _ := g.DoChan("key", fn)
   143  	assertNumCalls(t, g.NumCalls("key"), 1)
   144  	resC2, _ := g.DoChan("key", fn)
   145  	assertNumCalls(t, g.NumCalls("key"), 2)
   146  	close(c)
   147  	<-resC1
   148  	<-resC2
   149  	assertNumCalls(t, g.NumCalls("key"), 0)
   150  }
   151  
   152  func assertRes(t *testing.T, res Result, expectShared bool) {
   153  	if got, want := fmt.Sprintf("%v (%T)", res.Val, res.Val), "bar (string)"; got != want {
   154  		t.Errorf("Res.Val = %v; want %v", got, want)
   155  	}
   156  	if res.Err != nil {
   157  		t.Errorf("Res.Err = %v", res.Err)
   158  	}
   159  	if res.Shared != expectShared {
   160  		t.Errorf("Res.Shared = %t; want %t", res.Shared, expectShared)
   161  	}
   162  }
   163  
   164  func assertNumCalls(t *testing.T, actualCalls int, expectedCalls int) {
   165  	if actualCalls != expectedCalls {
   166  		t.Errorf("NumCalls = %d; want %d", actualCalls, expectedCalls)
   167  	}
   168  }