github.com/koko1123/flow-go-1@v0.29.6/utils/unittest/unittest.go (about) 1 package unittest 2 3 import ( 4 "encoding/json" 5 "math" 6 "math/rand" 7 "os" 8 "os/exec" 9 "regexp" 10 "strings" 11 "sync" 12 "testing" 13 "time" 14 15 "github.com/dgraph-io/badger/v3" 16 "github.com/rs/zerolog" 17 "github.com/stretchr/testify/assert" 18 "github.com/stretchr/testify/require" 19 20 "github.com/koko1123/flow-go-1/model/flow" 21 "github.com/koko1123/flow-go-1/module" 22 "github.com/koko1123/flow-go-1/module/util" 23 "github.com/koko1123/flow-go-1/network" 24 cborcodec "github.com/koko1123/flow-go-1/network/codec/cbor" 25 "github.com/koko1123/flow-go-1/network/slashing" 26 "github.com/koko1123/flow-go-1/network/topology" 27 ) 28 29 type SkipReason int 30 31 const ( 32 TEST_FLAKY SkipReason = iota + 1 // flaky 33 TEST_TODO // not fully implemented or broken and needs to be fixed 34 TEST_REQUIRES_GCP_ACCESS // requires the environment to be configured with GCP credentials 35 TEST_DEPRECATED // uses code that has been deprecated / disabled 36 TEST_LONG_RUNNING // long running 37 TEST_RESOURCE_INTENSIVE // resource intensive test 38 ) 39 40 func (s SkipReason) String() string { 41 switch s { 42 case TEST_FLAKY: 43 return "TEST_FLAKY" 44 case TEST_TODO: 45 return "TEST_TODO" 46 case TEST_REQUIRES_GCP_ACCESS: 47 return "TEST_REQUIRES_GCP_ACCESS" 48 case TEST_DEPRECATED: 49 return "TEST_DEPRECATED" 50 case TEST_LONG_RUNNING: 51 return "TEST_LONG_RUNNING" 52 case TEST_RESOURCE_INTENSIVE: 53 return "TEST_RESOURCE_INTENSIVE" 54 } 55 return "UNKNOWN" 56 } 57 58 func (s SkipReason) MarshalJSON() ([]byte, error) { 59 return json.Marshal(s.String()) 60 } 61 62 func parseSkipReason(reason string) SkipReason { 63 switch reason { 64 case "TEST_FLAKY": 65 return TEST_FLAKY 66 case "TEST_TODO": 67 return TEST_TODO 68 case "TEST_REQUIRES_GCP_ACCESS": 69 return TEST_REQUIRES_GCP_ACCESS 70 case "TEST_DEPRECATED": 71 return TEST_DEPRECATED 72 case "TEST_LONG_RUNNING": 73 return TEST_LONG_RUNNING 74 case "TEST_RESOURCE_INTENSIVE": 75 return TEST_RESOURCE_INTENSIVE 76 default: 77 return 0 78 } 79 } 80 81 func ParseSkipReason(output string) (SkipReason, bool) { 82 // match output like: 83 // " test_file.go:123: SKIP [TEST_REASON]: message\n" 84 r := regexp.MustCompile(`(?s)^\s+[a-zA-Z0-9_\-]+\.go:[0-9]+: SKIP \[([A-Z_]+)]: .*$`) 85 matches := r.FindStringSubmatch(output) 86 87 if len(matches) == 2 { 88 skipReason := parseSkipReason(matches[1]) 89 if skipReason != 0 { 90 return skipReason, true 91 } 92 } 93 94 return 0, false 95 } 96 97 func SkipUnless(t *testing.T, reason SkipReason, message string) { 98 t.Helper() 99 if os.Getenv(reason.String()) == "" { 100 t.Skipf("SKIP [%s]: %s", reason.String(), message) 101 } 102 } 103 104 type SkipBenchmarkReason int 105 106 const ( 107 BENCHMARK_EXPERIMENT SkipBenchmarkReason = iota + 1 108 ) 109 110 func (s SkipBenchmarkReason) String() string { 111 switch s { 112 case BENCHMARK_EXPERIMENT: 113 return "BENCHMARK_EXPERIMENT" 114 } 115 return "UNKNOWN" 116 } 117 118 func SkipBenchmarkUnless(b *testing.B, reason SkipBenchmarkReason, message string) { 119 b.Helper() 120 if os.Getenv(reason.String()) == "" { 121 b.Skip(message) 122 } 123 } 124 125 func ExpectPanic(expectedMsg string, t *testing.T) { 126 if r := recover(); r != nil { 127 err := r.(error) 128 if err.Error() != expectedMsg { 129 t.Errorf("expected %v to be %v", err, expectedMsg) 130 } 131 return 132 } 133 t.Errorf("Expected to panic with `%s`, but did not panic", expectedMsg) 134 } 135 136 // AssertReturnsBefore asserts that the given function returns before the 137 // duration expires. 138 func AssertReturnsBefore(t *testing.T, f func(), duration time.Duration, msgAndArgs ...interface{}) { 139 done := make(chan struct{}) 140 141 go func() { 142 f() 143 close(done) 144 }() 145 146 select { 147 case <-time.After(duration): 148 t.Log("function did not return in time") 149 assert.Fail(t, "function did not close in time", msgAndArgs...) 150 case <-done: 151 return 152 } 153 } 154 155 // AssertClosesBefore asserts that the given channel closes before the 156 // duration expires. 157 func AssertClosesBefore(t assert.TestingT, done <-chan struct{}, duration time.Duration, msgAndArgs ...interface{}) { 158 select { 159 case <-time.After(duration): 160 assert.Fail(t, "channel did not return in time", msgAndArgs...) 161 case <-done: 162 return 163 } 164 } 165 166 func AssertFloatEqual(t *testing.T, expected, actual float64, message string) { 167 tolerance := .00001 168 if !(math.Abs(expected-actual) < tolerance) { 169 assert.Equal(t, expected, actual, message) 170 } 171 } 172 173 // AssertNotClosesBefore asserts that the given channel does not close before the duration expires. 174 func AssertNotClosesBefore(t assert.TestingT, done <-chan struct{}, duration time.Duration, msgAndArgs ...interface{}) { 175 select { 176 case <-time.After(duration): 177 return 178 case <-done: 179 assert.Fail(t, "channel closed before timeout", msgAndArgs...) 180 } 181 } 182 183 // RequireReturnsBefore requires that the given function returns before the 184 // duration expires. 185 func RequireReturnsBefore(t testing.TB, f func(), duration time.Duration, message string) { 186 done := make(chan struct{}) 187 188 go func() { 189 f() 190 close(done) 191 }() 192 193 RequireCloseBefore(t, done, duration, message+": function did not return on time") 194 } 195 196 // RequireComponentsDoneBefore invokes the done method of each of the input components concurrently, and 197 // fails the test if any components shutdown takes longer than the specified duration. 198 func RequireComponentsDoneBefore(t testing.TB, duration time.Duration, components ...module.ReadyDoneAware) { 199 done := util.AllDone(components...) 200 RequireCloseBefore(t, done, duration, "failed to shutdown all components on time") 201 } 202 203 // RequireComponentsReadyBefore invokes the ready method of each of the input components concurrently, and 204 // fails the test if any components startup takes longer than the specified duration. 205 func RequireComponentsReadyBefore(t testing.TB, duration time.Duration, components ...module.ReadyDoneAware) { 206 ready := util.AllReady(components...) 207 RequireCloseBefore(t, ready, duration, "failed to start all components on time") 208 } 209 210 // RequireCloseBefore requires that the given channel returns before the 211 // duration expires. 212 func RequireCloseBefore(t testing.TB, c <-chan struct{}, duration time.Duration, message string) { 213 select { 214 case <-time.After(duration): 215 require.Fail(t, "could not close done channel on time: "+message) 216 case <-c: 217 return 218 } 219 } 220 221 // RequireClosed is a test helper function that fails the test if channel `ch` is not closed. 222 func RequireClosed(t *testing.T, ch <-chan struct{}, message string) { 223 select { 224 case <-ch: 225 default: 226 require.Fail(t, "channel is not closed: "+message) 227 } 228 } 229 230 // RequireConcurrentCallsReturnBefore is a test helper that runs function `f` count-many times concurrently, 231 // and requires all invocations to return within duration. 232 func RequireConcurrentCallsReturnBefore(t *testing.T, f func(), count int, duration time.Duration, message string) { 233 wg := &sync.WaitGroup{} 234 for i := 0; i < count; i++ { 235 wg.Add(1) 236 go func() { 237 f() 238 wg.Done() 239 }() 240 } 241 242 RequireReturnsBefore(t, wg.Wait, duration, message) 243 } 244 245 // RequireNeverReturnBefore is a test helper that tries invoking function `f` and fails the test if either: 246 // - function `f` is not invoked within 1 second. 247 // - function `f` returns before specified `duration`. 248 // 249 // It also returns a channel that is closed once the function `f` returns and hence its openness can evaluate 250 // return status of function `f` for intervals longer than duration. 251 func RequireNeverReturnBefore(t *testing.T, f func(), duration time.Duration, message string) <-chan struct{} { 252 ch := make(chan struct{}) 253 wg := sync.WaitGroup{} 254 wg.Add(1) 255 256 go func() { 257 wg.Done() 258 f() 259 close(ch) 260 }() 261 262 // requires function invoked within next 1 second 263 RequireReturnsBefore(t, wg.Wait, 1*time.Second, "could not invoke the function: "+message) 264 265 // requires function never returns within duration 266 RequireNeverClosedWithin(t, ch, duration, "unexpected return: "+message) 267 268 return ch 269 } 270 271 // RequireNeverClosedWithin is a test helper function that fails the test if channel `ch` is closed before the 272 // determined duration. 273 func RequireNeverClosedWithin(t *testing.T, ch <-chan struct{}, duration time.Duration, message string) { 274 select { 275 case <-time.After(duration): 276 case <-ch: 277 require.Fail(t, "channel closed before timeout: "+message) 278 } 279 } 280 281 // RequireNotClosed is a test helper function that fails the test if channel `ch` is closed. 282 func RequireNotClosed(t *testing.T, ch <-chan struct{}, message string) { 283 select { 284 case <-ch: 285 require.Fail(t, "channel is closed: "+message) 286 default: 287 } 288 } 289 290 // AssertErrSubstringMatch asserts that two errors match with substring 291 // checking on the Error method (`expected` must be a substring of `actual`, to 292 // account for the actual error being wrapped). Fails the test if either error 293 // is nil. 294 // 295 // NOTE: This should only be used in cases where `errors.Is` cannot be, like 296 // when errors are transmitted over the network without type information. 297 func AssertErrSubstringMatch(t testing.TB, expected, actual error) { 298 require.NotNil(t, expected) 299 require.NotNil(t, actual) 300 assert.True( 301 t, 302 strings.Contains(actual.Error(), expected.Error()) || strings.Contains(expected.Error(), actual.Error()), 303 "expected error: '%s', got: '%s'", expected.Error(), actual.Error(), 304 ) 305 } 306 307 func TempDir(t testing.TB) string { 308 dir, err := os.MkdirTemp("", "flow-testing-temp-") 309 require.NoError(t, err) 310 return dir 311 } 312 313 func RunWithTempDir(t testing.TB, f func(string)) { 314 dbDir := TempDir(t) 315 defer func() { 316 require.NoError(t, os.RemoveAll(dbDir)) 317 }() 318 f(dbDir) 319 } 320 321 func badgerDB(t testing.TB, dir string, create func(badger.Options) (*badger.DB, error)) *badger.DB { 322 opts := badger. 323 DefaultOptions(dir) 324 db, err := create(opts) 325 require.NoError(t, err) 326 return db 327 } 328 329 func BadgerDB(t testing.TB, dir string) *badger.DB { 330 return badgerDB(t, dir, badger.Open) 331 } 332 333 func TypedBadgerDB(t testing.TB, dir string, create func(badger.Options) (*badger.DB, error)) *badger.DB { 334 return badgerDB(t, dir, create) 335 } 336 337 func RunWithBadgerDB(t testing.TB, f func(*badger.DB)) { 338 RunWithTempDir(t, func(dir string) { 339 db := BadgerDB(t, dir) 340 defer func() { 341 assert.NoError(t, db.Close()) 342 }() 343 f(db) 344 }) 345 } 346 347 // RunWithTypedBadgerDB creates a Badger DB that is passed to f and closed 348 // after f returns. The extra create parameter allows passing in a database 349 // constructor function which instantiates a database with a particular type 350 // marker, for testing storage modules which require a backed with a particular 351 // type. 352 func RunWithTypedBadgerDB(t testing.TB, create func(badger.Options) (*badger.DB, error), f func(*badger.DB)) { 353 RunWithTempDir(t, func(dir string) { 354 db := badgerDB(t, dir, create) 355 defer func() { 356 assert.NoError(t, db.Close()) 357 }() 358 f(db) 359 }) 360 } 361 362 func TempBadgerDB(t testing.TB) (*badger.DB, string) { 363 dir := TempDir(t) 364 db := BadgerDB(t, dir) 365 return db, dir 366 } 367 368 func Concurrently(n int, f func(int)) { 369 var wg sync.WaitGroup 370 for i := 0; i < n; i++ { 371 wg.Add(1) 372 go func(i int) { 373 f(i) 374 wg.Done() 375 }(i) 376 } 377 wg.Wait() 378 } 379 380 // AssertEqualBlocksLenAndOrder asserts that both a segment of blocks have the same len and blocks are in the same order 381 func AssertEqualBlocksLenAndOrder(t *testing.T, expectedBlocks, actualSegmentBlocks []*flow.Block) { 382 assert.Equal(t, flow.GetIDs(expectedBlocks), flow.GetIDs(actualSegmentBlocks)) 383 } 384 385 // NetworkCodec returns cbor codec. 386 func NetworkCodec() network.Codec { 387 return cborcodec.NewCodec() 388 } 389 390 // NetworkTopology returns the default topology for testing purposes. 391 func NetworkTopology() network.Topology { 392 return topology.NewFullyConnectedTopology() 393 } 394 395 // CrashTest safely tests functions that crash (as the expected behavior) by checking that running the function creates an error and 396 // an expected error message. 397 func CrashTest(t *testing.T, scenario func(*testing.T), expectedErrorMsg string) { 398 CrashTestWithExpectedStatus(t, scenario, expectedErrorMsg, 1) 399 } 400 401 // CrashTestWithExpectedStatus checks for the test crashing with a specific exit code. 402 func CrashTestWithExpectedStatus( 403 t *testing.T, 404 scenario func(*testing.T), 405 expectedErrorMsg string, 406 expectedStatus ...int, 407 ) { 408 require.NotNil(t, scenario) 409 require.NotEmpty(t, expectedStatus) 410 411 if os.Getenv("CRASH_TEST") == "1" { 412 scenario(t) 413 return 414 } 415 416 cmd := exec.Command(os.Args[0], "-test.run="+t.Name()) 417 cmd.Env = append(os.Environ(), "CRASH_TEST=1") 418 419 outBytes, err := cmd.Output() 420 // expect error from run 421 require.Error(t, err) 422 423 // expect specific status codes 424 // require.Contains(t, expectedStatus, cmd.ProcessState.ExitCode()) 425 426 // expect logger.Fatal() message to be pushed to stdout 427 outStr := string(outBytes) 428 require.Contains(t, outStr, expectedErrorMsg) 429 } 430 431 // GenerateRandomStringWithLen returns a string of random alpha characters of the provided length 432 func GenerateRandomStringWithLen(commentLen uint) string { 433 const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 434 bytes := make([]byte, commentLen) 435 for i := range bytes { 436 bytes[i] = letterBytes[rand.Intn(len(letterBytes))] 437 } 438 return string(bytes) 439 } 440 441 // NetworkSlashingViolationsConsumer returns a slashing violations consumer for network middleware 442 func NetworkSlashingViolationsConsumer(logger zerolog.Logger, metrics module.NetworkSecurityMetrics) slashing.ViolationsConsumer { 443 return slashing.NewSlashingViolationsConsumer(logger, metrics) 444 }