github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/utils/unittest/unittest.go (about)

     1  package unittest
     2  
     3  import (
     4  	"encoding/json"
     5  	"math"
     6  	"math/rand"
     7  	"os"
     8  	"os/exec"
     9  	"path"
    10  	"regexp"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/cockroachdb/pebble"
    18  	"github.com/dgraph-io/badger/v2"
    19  	"github.com/libp2p/go-libp2p/core/peer"
    20  	"github.com/onflow/crypto"
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/stretchr/testify/require"
    23  
    24  	"github.com/onflow/flow-go/model/flow"
    25  	"github.com/onflow/flow-go/module"
    26  	"github.com/onflow/flow-go/module/util"
    27  	"github.com/onflow/flow-go/network"
    28  	cborcodec "github.com/onflow/flow-go/network/codec/cbor"
    29  	"github.com/onflow/flow-go/network/p2p/keyutils"
    30  	"github.com/onflow/flow-go/network/topology"
    31  )
    32  
    33  type SkipReason int
    34  
    35  const (
    36  	TEST_FLAKY               SkipReason = iota + 1 // flaky
    37  	TEST_TODO                                      // not fully implemented or broken and needs to be fixed
    38  	TEST_REQUIRES_GCP_ACCESS                       // requires the environment to be configured with GCP credentials
    39  	TEST_DEPRECATED                                // uses code that has been deprecated / disabled
    40  	TEST_LONG_RUNNING                              // long running
    41  	TEST_RESOURCE_INTENSIVE                        // resource intensive test
    42  )
    43  
    44  func (s SkipReason) String() string {
    45  	switch s {
    46  	case TEST_FLAKY:
    47  		return "TEST_FLAKY"
    48  	case TEST_TODO:
    49  		return "TEST_TODO"
    50  	case TEST_REQUIRES_GCP_ACCESS:
    51  		return "TEST_REQUIRES_GCP_ACCESS"
    52  	case TEST_DEPRECATED:
    53  		return "TEST_DEPRECATED"
    54  	case TEST_LONG_RUNNING:
    55  		return "TEST_LONG_RUNNING"
    56  	case TEST_RESOURCE_INTENSIVE:
    57  		return "TEST_RESOURCE_INTENSIVE"
    58  	}
    59  	return "UNKNOWN"
    60  }
    61  
    62  func (s SkipReason) MarshalJSON() ([]byte, error) {
    63  	return json.Marshal(s.String())
    64  }
    65  
    66  func parseSkipReason(reason string) SkipReason {
    67  	switch reason {
    68  	case "TEST_FLAKY":
    69  		return TEST_FLAKY
    70  	case "TEST_TODO":
    71  		return TEST_TODO
    72  	case "TEST_REQUIRES_GCP_ACCESS":
    73  		return TEST_REQUIRES_GCP_ACCESS
    74  	case "TEST_DEPRECATED":
    75  		return TEST_DEPRECATED
    76  	case "TEST_LONG_RUNNING":
    77  		return TEST_LONG_RUNNING
    78  	case "TEST_RESOURCE_INTENSIVE":
    79  		return TEST_RESOURCE_INTENSIVE
    80  	default:
    81  		return 0
    82  	}
    83  }
    84  
    85  func ParseSkipReason(output string) (SkipReason, bool) {
    86  	// match output like:
    87  	// "    test_file.go:123: SKIP [TEST_REASON]: message\n"
    88  	r := regexp.MustCompile(`(?s)^\s+[a-zA-Z0-9_\-]+\.go:[0-9]+: SKIP \[([A-Z_]+)]: .*$`)
    89  	matches := r.FindStringSubmatch(output)
    90  
    91  	if len(matches) == 2 {
    92  		skipReason := parseSkipReason(matches[1])
    93  		if skipReason != 0 {
    94  			return skipReason, true
    95  		}
    96  	}
    97  
    98  	return 0, false
    99  }
   100  
   101  func SkipUnless(t *testing.T, reason SkipReason, message string) {
   102  	t.Helper()
   103  	if os.Getenv(reason.String()) == "" {
   104  		t.Skipf("SKIP [%s]: %s", reason.String(), message)
   105  	}
   106  }
   107  
   108  type SkipBenchmarkReason int
   109  
   110  const (
   111  	BENCHMARK_EXPERIMENT SkipBenchmarkReason = iota + 1
   112  )
   113  
   114  func (s SkipBenchmarkReason) String() string {
   115  	switch s {
   116  	case BENCHMARK_EXPERIMENT:
   117  		return "BENCHMARK_EXPERIMENT"
   118  	}
   119  	return "UNKNOWN"
   120  }
   121  
   122  func SkipBenchmarkUnless(b *testing.B, reason SkipBenchmarkReason, message string) {
   123  	b.Helper()
   124  	if os.Getenv(reason.String()) == "" {
   125  		b.Skip(message)
   126  	}
   127  }
   128  
   129  // AssertReturnsBefore asserts that the given function returns before the
   130  // duration expires.
   131  func AssertReturnsBefore(t *testing.T, f func(), duration time.Duration, msgAndArgs ...interface{}) bool {
   132  	done := make(chan struct{})
   133  
   134  	go func() {
   135  		f()
   136  		close(done)
   137  	}()
   138  
   139  	select {
   140  	case <-time.After(duration):
   141  		t.Log("function did not return in time")
   142  		assert.Fail(t, "function did not close in time", msgAndArgs...)
   143  	case <-done:
   144  		return true
   145  	}
   146  	return false
   147  }
   148  
   149  // ClosedChannel returns a closed channel.
   150  func ClosedChannel() <-chan struct{} {
   151  	ch := make(chan struct{})
   152  	close(ch)
   153  	return ch
   154  }
   155  
   156  // AssertClosesBefore asserts that the given channel closes before the
   157  // duration expires.
   158  func AssertClosesBefore(t assert.TestingT, done <-chan struct{}, duration time.Duration, msgAndArgs ...interface{}) {
   159  	select {
   160  	case <-time.After(duration):
   161  		assert.Fail(t, "channel did not return in time", msgAndArgs...)
   162  	case <-done:
   163  		return
   164  	}
   165  }
   166  
   167  func AssertFloatEqual(t *testing.T, expected, actual float64, message string) {
   168  	tolerance := .00001
   169  	if !(math.Abs(expected-actual) < tolerance) {
   170  		assert.Equal(t, expected, actual, message)
   171  	}
   172  }
   173  
   174  // AssertNotClosesBefore asserts that the given channel does not close before the duration expires.
   175  func AssertNotClosesBefore(t assert.TestingT, done <-chan struct{}, duration time.Duration, msgAndArgs ...interface{}) {
   176  	select {
   177  	case <-time.After(duration):
   178  		return
   179  	case <-done:
   180  		assert.Fail(t, "channel closed before timeout", msgAndArgs...)
   181  	}
   182  }
   183  
   184  // RequireReturnsBefore requires that the given function returns before the
   185  // duration expires.
   186  func RequireReturnsBefore(t testing.TB, f func(), duration time.Duration, message string) {
   187  	done := make(chan struct{})
   188  
   189  	go func() {
   190  		f()
   191  		close(done)
   192  	}()
   193  
   194  	RequireCloseBefore(t, done, duration, message+": function did not return on time")
   195  }
   196  
   197  // RequireComponentsDoneBefore invokes the done method of each of the input components concurrently, and
   198  // fails the test if any components shutdown takes longer than the specified duration.
   199  func RequireComponentsDoneBefore(t testing.TB, duration time.Duration, components ...module.ReadyDoneAware) {
   200  	done := util.AllDone(components...)
   201  	RequireCloseBefore(t, done, duration, "failed to shutdown all components on time")
   202  }
   203  
   204  // RequireComponentsReadyBefore invokes the ready method of each of the input components concurrently, and
   205  // fails the test if any components startup takes longer than the specified duration.
   206  func RequireComponentsReadyBefore(t testing.TB, duration time.Duration, components ...module.ReadyDoneAware) {
   207  	ready := util.AllReady(components...)
   208  	RequireCloseBefore(t, ready, duration, "failed to start all components on time")
   209  }
   210  
   211  // RequireCloseBefore requires that the given channel returns before the
   212  // duration expires.
   213  func RequireCloseBefore(t testing.TB, c <-chan struct{}, duration time.Duration, message string) {
   214  	select {
   215  	case <-time.After(duration):
   216  		require.Fail(t, "could not close done channel on time: "+message)
   217  	case <-c:
   218  		return
   219  	}
   220  }
   221  
   222  // RequireClosed is a test helper function that fails the test if channel `ch` is not closed.
   223  func RequireClosed(t *testing.T, ch <-chan struct{}, message string) {
   224  	select {
   225  	case <-ch:
   226  	default:
   227  		require.Fail(t, "channel is not closed: "+message)
   228  	}
   229  }
   230  
   231  // RequireConcurrentCallsReturnBefore is a test helper that runs function `f` count-many times concurrently,
   232  // and requires all invocations to return within duration.
   233  func RequireConcurrentCallsReturnBefore(t *testing.T, f func(), count int, duration time.Duration, message string) {
   234  	wg := &sync.WaitGroup{}
   235  	for i := 0; i < count; i++ {
   236  		wg.Add(1)
   237  		go func() {
   238  			f()
   239  			wg.Done()
   240  		}()
   241  	}
   242  
   243  	RequireReturnsBefore(t, wg.Wait, duration, message)
   244  }
   245  
   246  // RequireNeverReturnBefore is a test helper that tries invoking function `f` and fails the test if either:
   247  // - function `f` is not invoked within 1 second.
   248  // - function `f` returns before specified `duration`.
   249  //
   250  // It also returns a channel that is closed once the function `f` returns and hence its openness can evaluate
   251  // return status of function `f` for intervals longer than duration.
   252  func RequireNeverReturnBefore(t *testing.T, f func(), duration time.Duration, message string) <-chan struct{} {
   253  	ch := make(chan struct{})
   254  	wg := sync.WaitGroup{}
   255  	wg.Add(1)
   256  
   257  	go func() {
   258  		wg.Done()
   259  		f()
   260  		close(ch)
   261  	}()
   262  
   263  	// requires function invoked within next 1 second
   264  	RequireReturnsBefore(t, wg.Wait, 1*time.Second, "could not invoke the function: "+message)
   265  
   266  	// requires function never returns within duration
   267  	RequireNeverClosedWithin(t, ch, duration, "unexpected return: "+message)
   268  
   269  	return ch
   270  }
   271  
   272  // RequireNeverClosedWithin is a test helper function that fails the test if channel `ch` is closed before the
   273  // determined duration.
   274  func RequireNeverClosedWithin(t *testing.T, ch <-chan struct{}, duration time.Duration, message string) {
   275  	select {
   276  	case <-time.After(duration):
   277  	case <-ch:
   278  		require.Fail(t, "channel closed before timeout: "+message)
   279  	}
   280  }
   281  
   282  // RequireNotClosed is a test helper function that fails the test if channel `ch` is closed.
   283  func RequireNotClosed(t *testing.T, ch <-chan struct{}, message string) {
   284  	select {
   285  	case <-ch:
   286  		require.Fail(t, "channel is closed: "+message)
   287  	default:
   288  	}
   289  }
   290  
   291  // AssertErrSubstringMatch asserts that two errors match with substring
   292  // checking on the Error method (`expected` must be a substring of `actual`, to
   293  // account for the actual error being wrapped). Fails the test if either error
   294  // is nil.
   295  //
   296  // NOTE: This should only be used in cases where `errors.Is` cannot be, like
   297  // when errors are transmitted over the network without type information.
   298  func AssertErrSubstringMatch(t testing.TB, expected, actual error) {
   299  	require.NotNil(t, expected)
   300  	require.NotNil(t, actual)
   301  	assert.True(
   302  		t,
   303  		strings.Contains(actual.Error(), expected.Error()) || strings.Contains(expected.Error(), actual.Error()),
   304  		"expected error: '%s', got: '%s'", expected.Error(), actual.Error(),
   305  	)
   306  }
   307  
   308  func TempDir(t testing.TB) string {
   309  	dir, err := os.MkdirTemp("", "flow-testing-temp-")
   310  	require.NoError(t, err)
   311  	return dir
   312  }
   313  
   314  func RunWithTempDir(t testing.TB, f func(string)) {
   315  	dbDir := TempDir(t)
   316  	defer func() {
   317  		require.NoError(t, os.RemoveAll(dbDir))
   318  	}()
   319  	f(dbDir)
   320  }
   321  
   322  func badgerDB(t testing.TB, dir string, create func(badger.Options) (*badger.DB, error)) *badger.DB {
   323  	opts := badger.
   324  		DefaultOptions(dir).
   325  		WithKeepL0InMemory(true).
   326  		WithLogger(nil)
   327  	db, err := create(opts)
   328  	require.NoError(t, err)
   329  	return db
   330  }
   331  
   332  func BadgerDB(t testing.TB, dir string) *badger.DB {
   333  	return badgerDB(t, dir, badger.Open)
   334  }
   335  
   336  func TypedBadgerDB(t testing.TB, dir string, create func(badger.Options) (*badger.DB, error)) *badger.DB {
   337  	return badgerDB(t, dir, create)
   338  }
   339  
   340  func RunWithBadgerDB(t testing.TB, f func(*badger.DB)) {
   341  	RunWithTempDir(t, func(dir string) {
   342  		db := BadgerDB(t, dir)
   343  		defer func() {
   344  			assert.NoError(t, db.Close())
   345  		}()
   346  		f(db)
   347  	})
   348  }
   349  
   350  // RunWithTypedBadgerDB creates a Badger DB that is passed to f and closed
   351  // after f returns. The extra create parameter allows passing in a database
   352  // constructor function which instantiates a database with a particular type
   353  // marker, for testing storage modules which require a backed with a particular
   354  // type.
   355  func RunWithTypedBadgerDB(t testing.TB, create func(badger.Options) (*badger.DB, error), f func(*badger.DB)) {
   356  	RunWithTempDir(t, func(dir string) {
   357  		db := badgerDB(t, dir, create)
   358  		defer func() {
   359  			assert.NoError(t, db.Close())
   360  		}()
   361  		f(db)
   362  	})
   363  }
   364  
   365  func TempBadgerDB(t testing.TB) (*badger.DB, string) {
   366  	dir := TempDir(t)
   367  	db := BadgerDB(t, dir)
   368  	return db, dir
   369  }
   370  
   371  func TempPebblePath(t *testing.T) string {
   372  	return path.Join(TempDir(t), "pebble"+strconv.Itoa(rand.Int())+".db")
   373  }
   374  
   375  func TempPebbleDBWithOpts(t testing.TB, opts *pebble.Options) (*pebble.DB, string) {
   376  	// create random path string for parallelization
   377  	dbpath := path.Join(TempDir(t), "pebble"+strconv.Itoa(rand.Int())+".db")
   378  	db, err := pebble.Open(dbpath, opts)
   379  	require.NoError(t, err)
   380  	return db, dbpath
   381  }
   382  
   383  func Concurrently(n int, f func(int)) {
   384  	var wg sync.WaitGroup
   385  	for i := 0; i < n; i++ {
   386  		wg.Add(1)
   387  		go func(i int) {
   388  			f(i)
   389  			wg.Done()
   390  		}(i)
   391  	}
   392  	wg.Wait()
   393  }
   394  
   395  // AssertEqualBlocksLenAndOrder asserts that both a segment of blocks have the same len and blocks are in the same order
   396  func AssertEqualBlocksLenAndOrder(t *testing.T, expectedBlocks, actualSegmentBlocks []*flow.Block) {
   397  	assert.Equal(t, flow.GetIDs(expectedBlocks), flow.GetIDs(actualSegmentBlocks))
   398  }
   399  
   400  // NetworkCodec returns cbor codec.
   401  func NetworkCodec() network.Codec {
   402  	return cborcodec.NewCodec()
   403  }
   404  
   405  // NetworkTopology returns the default topology for testing purposes.
   406  func NetworkTopology() network.Topology {
   407  	return topology.NewFullyConnectedTopology()
   408  }
   409  
   410  // CrashTest safely tests functions that crash (as the expected behavior) by checking that running the function creates an error and
   411  // an expected error message.
   412  func CrashTest(t *testing.T, scenario func(*testing.T), expectedErrorMsg string) {
   413  	CrashTestWithExpectedStatus(t, scenario, expectedErrorMsg, 1)
   414  }
   415  
   416  // CrashTestWithExpectedStatus checks for the test crashing with a specific exit code.
   417  func CrashTestWithExpectedStatus(
   418  	t *testing.T,
   419  	scenario func(*testing.T),
   420  	expectedErrorMsg string,
   421  	expectedStatus ...int,
   422  ) {
   423  	require.NotNil(t, scenario)
   424  	require.NotEmpty(t, expectedStatus)
   425  
   426  	if os.Getenv("CRASH_TEST") == "1" {
   427  		scenario(t)
   428  		return
   429  	}
   430  
   431  	cmd := exec.Command(os.Args[0], "-test.run="+t.Name())
   432  	cmd.Env = append(os.Environ(), "CRASH_TEST=1")
   433  
   434  	outBytes, err := cmd.Output()
   435  	// expect error from run
   436  	require.Error(t, err)
   437  
   438  	// expect specific status codes
   439  	require.Contains(t, expectedStatus, cmd.ProcessState.ExitCode())
   440  
   441  	// expect logger.Fatal() message to be pushed to stdout
   442  	outStr := string(outBytes)
   443  	require.Contains(t, outStr, expectedErrorMsg)
   444  }
   445  
   446  // GenerateRandomStringWithLen returns a string of random alpha characters of the provided length
   447  func GenerateRandomStringWithLen(commentLen uint) string {
   448  	const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
   449  	bytes := make([]byte, commentLen)
   450  	for i := range bytes {
   451  		bytes[i] = letterBytes[rand.Intn(len(letterBytes))]
   452  	}
   453  	return string(bytes)
   454  }
   455  
   456  // PeerIdFixture creates a random and unique peer ID (libp2p node ID).
   457  func PeerIdFixture(tb testing.TB) peer.ID {
   458  	peerID, err := peerIDFixture()
   459  	require.NoError(tb, err)
   460  	return peerID
   461  }
   462  
   463  func peerIDFixture() (peer.ID, error) {
   464  	key, err := generateNetworkingKey(IdentifierFixture())
   465  	if err != nil {
   466  		return "", err
   467  	}
   468  	pubKey, err := keyutils.LibP2PPublicKeyFromFlow(key.PublicKey())
   469  	if err != nil {
   470  		return "", err
   471  	}
   472  
   473  	peerID, err := peer.IDFromPublicKey(pubKey)
   474  	if err != nil {
   475  		return "", err
   476  	}
   477  
   478  	return peerID, nil
   479  }
   480  
   481  // generateNetworkingKey generates a Flow ECDSA key using the given seed
   482  func generateNetworkingKey(s flow.Identifier) (crypto.PrivateKey, error) {
   483  	seed := make([]byte, crypto.KeyGenSeedMinLen)
   484  	copy(seed, s[:])
   485  	return crypto.GeneratePrivateKey(crypto.ECDSASecp256k1, seed)
   486  }
   487  
   488  // PeerIdFixtures creates random and unique peer IDs (libp2p node IDs).
   489  func PeerIdFixtures(t *testing.T, n int) []peer.ID {
   490  	peerIDs := make([]peer.ID, n)
   491  	for i := 0; i < n; i++ {
   492  		peerIDs[i] = PeerIdFixture(t)
   493  	}
   494  	return peerIDs
   495  }