github.com/koko1123/flow-go-1@v0.29.6/utils/unittest/unittest.go (about)

     1  package unittest
     2  
     3  import (
     4  	"encoding/json"
     5  	"math"
     6  	"math/rand"
     7  	"os"
     8  	"os/exec"
     9  	"regexp"
    10  	"strings"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/dgraph-io/badger/v3"
    16  	"github.com/rs/zerolog"
    17  	"github.com/stretchr/testify/assert"
    18  	"github.com/stretchr/testify/require"
    19  
    20  	"github.com/koko1123/flow-go-1/model/flow"
    21  	"github.com/koko1123/flow-go-1/module"
    22  	"github.com/koko1123/flow-go-1/module/util"
    23  	"github.com/koko1123/flow-go-1/network"
    24  	cborcodec "github.com/koko1123/flow-go-1/network/codec/cbor"
    25  	"github.com/koko1123/flow-go-1/network/slashing"
    26  	"github.com/koko1123/flow-go-1/network/topology"
    27  )
    28  
    29  type SkipReason int
    30  
    31  const (
    32  	TEST_FLAKY               SkipReason = iota + 1 // flaky
    33  	TEST_TODO                                      // not fully implemented or broken and needs to be fixed
    34  	TEST_REQUIRES_GCP_ACCESS                       // requires the environment to be configured with GCP credentials
    35  	TEST_DEPRECATED                                // uses code that has been deprecated / disabled
    36  	TEST_LONG_RUNNING                              // long running
    37  	TEST_RESOURCE_INTENSIVE                        // resource intensive test
    38  )
    39  
    40  func (s SkipReason) String() string {
    41  	switch s {
    42  	case TEST_FLAKY:
    43  		return "TEST_FLAKY"
    44  	case TEST_TODO:
    45  		return "TEST_TODO"
    46  	case TEST_REQUIRES_GCP_ACCESS:
    47  		return "TEST_REQUIRES_GCP_ACCESS"
    48  	case TEST_DEPRECATED:
    49  		return "TEST_DEPRECATED"
    50  	case TEST_LONG_RUNNING:
    51  		return "TEST_LONG_RUNNING"
    52  	case TEST_RESOURCE_INTENSIVE:
    53  		return "TEST_RESOURCE_INTENSIVE"
    54  	}
    55  	return "UNKNOWN"
    56  }
    57  
    58  func (s SkipReason) MarshalJSON() ([]byte, error) {
    59  	return json.Marshal(s.String())
    60  }
    61  
    62  func parseSkipReason(reason string) SkipReason {
    63  	switch reason {
    64  	case "TEST_FLAKY":
    65  		return TEST_FLAKY
    66  	case "TEST_TODO":
    67  		return TEST_TODO
    68  	case "TEST_REQUIRES_GCP_ACCESS":
    69  		return TEST_REQUIRES_GCP_ACCESS
    70  	case "TEST_DEPRECATED":
    71  		return TEST_DEPRECATED
    72  	case "TEST_LONG_RUNNING":
    73  		return TEST_LONG_RUNNING
    74  	case "TEST_RESOURCE_INTENSIVE":
    75  		return TEST_RESOURCE_INTENSIVE
    76  	default:
    77  		return 0
    78  	}
    79  }
    80  
    81  func ParseSkipReason(output string) (SkipReason, bool) {
    82  	// match output like:
    83  	// "    test_file.go:123: SKIP [TEST_REASON]: message\n"
    84  	r := regexp.MustCompile(`(?s)^\s+[a-zA-Z0-9_\-]+\.go:[0-9]+: SKIP \[([A-Z_]+)]: .*$`)
    85  	matches := r.FindStringSubmatch(output)
    86  
    87  	if len(matches) == 2 {
    88  		skipReason := parseSkipReason(matches[1])
    89  		if skipReason != 0 {
    90  			return skipReason, true
    91  		}
    92  	}
    93  
    94  	return 0, false
    95  }
    96  
    97  func SkipUnless(t *testing.T, reason SkipReason, message string) {
    98  	t.Helper()
    99  	if os.Getenv(reason.String()) == "" {
   100  		t.Skipf("SKIP [%s]: %s", reason.String(), message)
   101  	}
   102  }
   103  
   104  type SkipBenchmarkReason int
   105  
   106  const (
   107  	BENCHMARK_EXPERIMENT SkipBenchmarkReason = iota + 1
   108  )
   109  
   110  func (s SkipBenchmarkReason) String() string {
   111  	switch s {
   112  	case BENCHMARK_EXPERIMENT:
   113  		return "BENCHMARK_EXPERIMENT"
   114  	}
   115  	return "UNKNOWN"
   116  }
   117  
   118  func SkipBenchmarkUnless(b *testing.B, reason SkipBenchmarkReason, message string) {
   119  	b.Helper()
   120  	if os.Getenv(reason.String()) == "" {
   121  		b.Skip(message)
   122  	}
   123  }
   124  
   125  func ExpectPanic(expectedMsg string, t *testing.T) {
   126  	if r := recover(); r != nil {
   127  		err := r.(error)
   128  		if err.Error() != expectedMsg {
   129  			t.Errorf("expected %v to be %v", err, expectedMsg)
   130  		}
   131  		return
   132  	}
   133  	t.Errorf("Expected to panic with `%s`, but did not panic", expectedMsg)
   134  }
   135  
   136  // AssertReturnsBefore asserts that the given function returns before the
   137  // duration expires.
   138  func AssertReturnsBefore(t *testing.T, f func(), duration time.Duration, msgAndArgs ...interface{}) {
   139  	done := make(chan struct{})
   140  
   141  	go func() {
   142  		f()
   143  		close(done)
   144  	}()
   145  
   146  	select {
   147  	case <-time.After(duration):
   148  		t.Log("function did not return in time")
   149  		assert.Fail(t, "function did not close in time", msgAndArgs...)
   150  	case <-done:
   151  		return
   152  	}
   153  }
   154  
   155  // AssertClosesBefore asserts that the given channel closes before the
   156  // duration expires.
   157  func AssertClosesBefore(t assert.TestingT, done <-chan struct{}, duration time.Duration, msgAndArgs ...interface{}) {
   158  	select {
   159  	case <-time.After(duration):
   160  		assert.Fail(t, "channel did not return in time", msgAndArgs...)
   161  	case <-done:
   162  		return
   163  	}
   164  }
   165  
   166  func AssertFloatEqual(t *testing.T, expected, actual float64, message string) {
   167  	tolerance := .00001
   168  	if !(math.Abs(expected-actual) < tolerance) {
   169  		assert.Equal(t, expected, actual, message)
   170  	}
   171  }
   172  
   173  // AssertNotClosesBefore asserts that the given channel does not close before the duration expires.
   174  func AssertNotClosesBefore(t assert.TestingT, done <-chan struct{}, duration time.Duration, msgAndArgs ...interface{}) {
   175  	select {
   176  	case <-time.After(duration):
   177  		return
   178  	case <-done:
   179  		assert.Fail(t, "channel closed before timeout", msgAndArgs...)
   180  	}
   181  }
   182  
   183  // RequireReturnsBefore requires that the given function returns before the
   184  // duration expires.
   185  func RequireReturnsBefore(t testing.TB, f func(), duration time.Duration, message string) {
   186  	done := make(chan struct{})
   187  
   188  	go func() {
   189  		f()
   190  		close(done)
   191  	}()
   192  
   193  	RequireCloseBefore(t, done, duration, message+": function did not return on time")
   194  }
   195  
   196  // RequireComponentsDoneBefore invokes the done method of each of the input components concurrently, and
   197  // fails the test if any components shutdown takes longer than the specified duration.
   198  func RequireComponentsDoneBefore(t testing.TB, duration time.Duration, components ...module.ReadyDoneAware) {
   199  	done := util.AllDone(components...)
   200  	RequireCloseBefore(t, done, duration, "failed to shutdown all components on time")
   201  }
   202  
   203  // RequireComponentsReadyBefore invokes the ready method of each of the input components concurrently, and
   204  // fails the test if any components startup takes longer than the specified duration.
   205  func RequireComponentsReadyBefore(t testing.TB, duration time.Duration, components ...module.ReadyDoneAware) {
   206  	ready := util.AllReady(components...)
   207  	RequireCloseBefore(t, ready, duration, "failed to start all components on time")
   208  }
   209  
   210  // RequireCloseBefore requires that the given channel returns before the
   211  // duration expires.
   212  func RequireCloseBefore(t testing.TB, c <-chan struct{}, duration time.Duration, message string) {
   213  	select {
   214  	case <-time.After(duration):
   215  		require.Fail(t, "could not close done channel on time: "+message)
   216  	case <-c:
   217  		return
   218  	}
   219  }
   220  
   221  // RequireClosed is a test helper function that fails the test if channel `ch` is not closed.
   222  func RequireClosed(t *testing.T, ch <-chan struct{}, message string) {
   223  	select {
   224  	case <-ch:
   225  	default:
   226  		require.Fail(t, "channel is not closed: "+message)
   227  	}
   228  }
   229  
   230  // RequireConcurrentCallsReturnBefore is a test helper that runs function `f` count-many times concurrently,
   231  // and requires all invocations to return within duration.
   232  func RequireConcurrentCallsReturnBefore(t *testing.T, f func(), count int, duration time.Duration, message string) {
   233  	wg := &sync.WaitGroup{}
   234  	for i := 0; i < count; i++ {
   235  		wg.Add(1)
   236  		go func() {
   237  			f()
   238  			wg.Done()
   239  		}()
   240  	}
   241  
   242  	RequireReturnsBefore(t, wg.Wait, duration, message)
   243  }
   244  
   245  // RequireNeverReturnBefore is a test helper that tries invoking function `f` and fails the test if either:
   246  // - function `f` is not invoked within 1 second.
   247  // - function `f` returns before specified `duration`.
   248  //
   249  // It also returns a channel that is closed once the function `f` returns and hence its openness can evaluate
   250  // return status of function `f` for intervals longer than duration.
   251  func RequireNeverReturnBefore(t *testing.T, f func(), duration time.Duration, message string) <-chan struct{} {
   252  	ch := make(chan struct{})
   253  	wg := sync.WaitGroup{}
   254  	wg.Add(1)
   255  
   256  	go func() {
   257  		wg.Done()
   258  		f()
   259  		close(ch)
   260  	}()
   261  
   262  	// requires function invoked within next 1 second
   263  	RequireReturnsBefore(t, wg.Wait, 1*time.Second, "could not invoke the function: "+message)
   264  
   265  	// requires function never returns within duration
   266  	RequireNeverClosedWithin(t, ch, duration, "unexpected return: "+message)
   267  
   268  	return ch
   269  }
   270  
   271  // RequireNeverClosedWithin is a test helper function that fails the test if channel `ch` is closed before the
   272  // determined duration.
   273  func RequireNeverClosedWithin(t *testing.T, ch <-chan struct{}, duration time.Duration, message string) {
   274  	select {
   275  	case <-time.After(duration):
   276  	case <-ch:
   277  		require.Fail(t, "channel closed before timeout: "+message)
   278  	}
   279  }
   280  
   281  // RequireNotClosed is a test helper function that fails the test if channel `ch` is closed.
   282  func RequireNotClosed(t *testing.T, ch <-chan struct{}, message string) {
   283  	select {
   284  	case <-ch:
   285  		require.Fail(t, "channel is closed: "+message)
   286  	default:
   287  	}
   288  }
   289  
   290  // AssertErrSubstringMatch asserts that two errors match with substring
   291  // checking on the Error method (`expected` must be a substring of `actual`, to
   292  // account for the actual error being wrapped). Fails the test if either error
   293  // is nil.
   294  //
   295  // NOTE: This should only be used in cases where `errors.Is` cannot be, like
   296  // when errors are transmitted over the network without type information.
   297  func AssertErrSubstringMatch(t testing.TB, expected, actual error) {
   298  	require.NotNil(t, expected)
   299  	require.NotNil(t, actual)
   300  	assert.True(
   301  		t,
   302  		strings.Contains(actual.Error(), expected.Error()) || strings.Contains(expected.Error(), actual.Error()),
   303  		"expected error: '%s', got: '%s'", expected.Error(), actual.Error(),
   304  	)
   305  }
   306  
   307  func TempDir(t testing.TB) string {
   308  	dir, err := os.MkdirTemp("", "flow-testing-temp-")
   309  	require.NoError(t, err)
   310  	return dir
   311  }
   312  
   313  func RunWithTempDir(t testing.TB, f func(string)) {
   314  	dbDir := TempDir(t)
   315  	defer func() {
   316  		require.NoError(t, os.RemoveAll(dbDir))
   317  	}()
   318  	f(dbDir)
   319  }
   320  
   321  func badgerDB(t testing.TB, dir string, create func(badger.Options) (*badger.DB, error)) *badger.DB {
   322  	opts := badger.
   323  		DefaultOptions(dir)
   324  	db, err := create(opts)
   325  	require.NoError(t, err)
   326  	return db
   327  }
   328  
   329  func BadgerDB(t testing.TB, dir string) *badger.DB {
   330  	return badgerDB(t, dir, badger.Open)
   331  }
   332  
   333  func TypedBadgerDB(t testing.TB, dir string, create func(badger.Options) (*badger.DB, error)) *badger.DB {
   334  	return badgerDB(t, dir, create)
   335  }
   336  
   337  func RunWithBadgerDB(t testing.TB, f func(*badger.DB)) {
   338  	RunWithTempDir(t, func(dir string) {
   339  		db := BadgerDB(t, dir)
   340  		defer func() {
   341  			assert.NoError(t, db.Close())
   342  		}()
   343  		f(db)
   344  	})
   345  }
   346  
   347  // RunWithTypedBadgerDB creates a Badger DB that is passed to f and closed
   348  // after f returns. The extra create parameter allows passing in a database
   349  // constructor function which instantiates a database with a particular type
   350  // marker, for testing storage modules which require a backed with a particular
   351  // type.
   352  func RunWithTypedBadgerDB(t testing.TB, create func(badger.Options) (*badger.DB, error), f func(*badger.DB)) {
   353  	RunWithTempDir(t, func(dir string) {
   354  		db := badgerDB(t, dir, create)
   355  		defer func() {
   356  			assert.NoError(t, db.Close())
   357  		}()
   358  		f(db)
   359  	})
   360  }
   361  
   362  func TempBadgerDB(t testing.TB) (*badger.DB, string) {
   363  	dir := TempDir(t)
   364  	db := BadgerDB(t, dir)
   365  	return db, dir
   366  }
   367  
   368  func Concurrently(n int, f func(int)) {
   369  	var wg sync.WaitGroup
   370  	for i := 0; i < n; i++ {
   371  		wg.Add(1)
   372  		go func(i int) {
   373  			f(i)
   374  			wg.Done()
   375  		}(i)
   376  	}
   377  	wg.Wait()
   378  }
   379  
   380  // AssertEqualBlocksLenAndOrder asserts that both a segment of blocks have the same len and blocks are in the same order
   381  func AssertEqualBlocksLenAndOrder(t *testing.T, expectedBlocks, actualSegmentBlocks []*flow.Block) {
   382  	assert.Equal(t, flow.GetIDs(expectedBlocks), flow.GetIDs(actualSegmentBlocks))
   383  }
   384  
   385  // NetworkCodec returns cbor codec.
   386  func NetworkCodec() network.Codec {
   387  	return cborcodec.NewCodec()
   388  }
   389  
   390  // NetworkTopology returns the default topology for testing purposes.
   391  func NetworkTopology() network.Topology {
   392  	return topology.NewFullyConnectedTopology()
   393  }
   394  
   395  // CrashTest safely tests functions that crash (as the expected behavior) by checking that running the function creates an error and
   396  // an expected error message.
   397  func CrashTest(t *testing.T, scenario func(*testing.T), expectedErrorMsg string) {
   398  	CrashTestWithExpectedStatus(t, scenario, expectedErrorMsg, 1)
   399  }
   400  
   401  // CrashTestWithExpectedStatus checks for the test crashing with a specific exit code.
   402  func CrashTestWithExpectedStatus(
   403  	t *testing.T,
   404  	scenario func(*testing.T),
   405  	expectedErrorMsg string,
   406  	expectedStatus ...int,
   407  ) {
   408  	require.NotNil(t, scenario)
   409  	require.NotEmpty(t, expectedStatus)
   410  
   411  	if os.Getenv("CRASH_TEST") == "1" {
   412  		scenario(t)
   413  		return
   414  	}
   415  
   416  	cmd := exec.Command(os.Args[0], "-test.run="+t.Name())
   417  	cmd.Env = append(os.Environ(), "CRASH_TEST=1")
   418  
   419  	outBytes, err := cmd.Output()
   420  	// expect error from run
   421  	require.Error(t, err)
   422  
   423  	// expect specific status codes
   424  	// require.Contains(t, expectedStatus, cmd.ProcessState.ExitCode())
   425  
   426  	// expect logger.Fatal() message to be pushed to stdout
   427  	outStr := string(outBytes)
   428  	require.Contains(t, outStr, expectedErrorMsg)
   429  }
   430  
   431  // GenerateRandomStringWithLen returns a string of random alpha characters of the provided length
   432  func GenerateRandomStringWithLen(commentLen uint) string {
   433  	const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
   434  	bytes := make([]byte, commentLen)
   435  	for i := range bytes {
   436  		bytes[i] = letterBytes[rand.Intn(len(letterBytes))]
   437  	}
   438  	return string(bytes)
   439  }
   440  
   441  // NetworkSlashingViolationsConsumer returns a slashing violations consumer for network middleware
   442  func NetworkSlashingViolationsConsumer(logger zerolog.Logger, metrics module.NetworkSecurityMetrics) slashing.ViolationsConsumer {
   443  	return slashing.NewSlashingViolationsConsumer(logger, metrics)
   444  }