github.com/onsi/gomega@v1.32.0/internal/async_assertion.go (about) 1 package internal 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "reflect" 8 "runtime" 9 "sync" 10 "time" 11 12 "github.com/onsi/gomega/format" 13 "github.com/onsi/gomega/types" 14 ) 15 16 var errInterface = reflect.TypeOf((*error)(nil)).Elem() 17 var gomegaType = reflect.TypeOf((*types.Gomega)(nil)).Elem() 18 var contextType = reflect.TypeOf(new(context.Context)).Elem() 19 20 type formattedGomegaError interface { 21 FormattedGomegaError() string 22 } 23 24 type asyncPolledActualError struct { 25 message string 26 } 27 28 func (err *asyncPolledActualError) Error() string { 29 return err.message 30 } 31 32 func (err *asyncPolledActualError) FormattedGomegaError() string { 33 return err.message 34 } 35 36 type contextWithAttachProgressReporter interface { 37 AttachProgressReporter(func() string) func() 38 } 39 40 type asyncGomegaHaltExecutionError struct{} 41 42 func (a asyncGomegaHaltExecutionError) GinkgoRecoverShouldIgnoreThisPanic() {} 43 func (a asyncGomegaHaltExecutionError) Error() string { 44 return `An assertion has failed in a goroutine. You should call 45 46 defer GinkgoRecover() 47 48 at the top of the goroutine that caused this panic. This will allow Ginkgo and Gomega to correctly capture and manage this panic.` 49 } 50 51 type AsyncAssertionType uint 52 53 const ( 54 AsyncAssertionTypeEventually AsyncAssertionType = iota 55 AsyncAssertionTypeConsistently 56 ) 57 58 func (at AsyncAssertionType) String() string { 59 switch at { 60 case AsyncAssertionTypeEventually: 61 return "Eventually" 62 case AsyncAssertionTypeConsistently: 63 return "Consistently" 64 } 65 return "INVALID ASYNC ASSERTION TYPE" 66 } 67 68 type AsyncAssertion struct { 69 asyncType AsyncAssertionType 70 71 actualIsFunc bool 72 actual interface{} 73 argsToForward []interface{} 74 75 timeoutInterval time.Duration 76 pollingInterval time.Duration 77 mustPassRepeatedly int 78 ctx context.Context 79 offset int 80 g *Gomega 81 } 82 83 func NewAsyncAssertion(asyncType AsyncAssertionType, actualInput interface{}, g *Gomega, timeoutInterval time.Duration, pollingInterval time.Duration, mustPassRepeatedly int, ctx context.Context, offset int) *AsyncAssertion { 84 out := &AsyncAssertion{ 85 asyncType: asyncType, 86 timeoutInterval: timeoutInterval, 87 pollingInterval: pollingInterval, 88 mustPassRepeatedly: mustPassRepeatedly, 89 offset: offset, 90 ctx: ctx, 91 g: g, 92 } 93 94 out.actual = actualInput 95 if actualInput != nil && reflect.TypeOf(actualInput).Kind() == reflect.Func { 96 out.actualIsFunc = true 97 } 98 99 return out 100 } 101 102 func (assertion *AsyncAssertion) WithOffset(offset int) types.AsyncAssertion { 103 assertion.offset = offset 104 return assertion 105 } 106 107 func (assertion *AsyncAssertion) WithTimeout(interval time.Duration) types.AsyncAssertion { 108 assertion.timeoutInterval = interval 109 return assertion 110 } 111 112 func (assertion *AsyncAssertion) WithPolling(interval time.Duration) types.AsyncAssertion { 113 assertion.pollingInterval = interval 114 return assertion 115 } 116 117 func (assertion *AsyncAssertion) Within(timeout time.Duration) types.AsyncAssertion { 118 assertion.timeoutInterval = timeout 119 return assertion 120 } 121 122 func (assertion *AsyncAssertion) ProbeEvery(interval time.Duration) types.AsyncAssertion { 123 assertion.pollingInterval = interval 124 return assertion 125 } 126 127 func (assertion *AsyncAssertion) WithContext(ctx context.Context) types.AsyncAssertion { 128 assertion.ctx = ctx 129 return assertion 130 } 131 132 func (assertion *AsyncAssertion) WithArguments(argsToForward ...interface{}) types.AsyncAssertion { 133 assertion.argsToForward = argsToForward 134 return assertion 135 } 136 137 func (assertion *AsyncAssertion) MustPassRepeatedly(count int) types.AsyncAssertion { 138 assertion.mustPassRepeatedly = count 139 return assertion 140 } 141 142 func (assertion *AsyncAssertion) Should(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool { 143 assertion.g.THelper() 144 vetOptionalDescription("Asynchronous assertion", optionalDescription...) 145 return assertion.match(matcher, true, optionalDescription...) 146 } 147 148 func (assertion *AsyncAssertion) ShouldNot(matcher types.GomegaMatcher, optionalDescription ...interface{}) bool { 149 assertion.g.THelper() 150 vetOptionalDescription("Asynchronous assertion", optionalDescription...) 151 return assertion.match(matcher, false, optionalDescription...) 152 } 153 154 func (assertion *AsyncAssertion) buildDescription(optionalDescription ...interface{}) string { 155 switch len(optionalDescription) { 156 case 0: 157 return "" 158 case 1: 159 if describe, ok := optionalDescription[0].(func() string); ok { 160 return describe() + "\n" 161 } 162 } 163 return fmt.Sprintf(optionalDescription[0].(string), optionalDescription[1:]...) + "\n" 164 } 165 166 func (assertion *AsyncAssertion) processReturnValues(values []reflect.Value) (interface{}, error) { 167 if len(values) == 0 { 168 return nil, &asyncPolledActualError{ 169 message: fmt.Sprintf("The function passed to %s did not return any values", assertion.asyncType), 170 } 171 } 172 173 actual := values[0].Interface() 174 if _, ok := AsPollingSignalError(actual); ok { 175 return actual, actual.(error) 176 } 177 178 var err error 179 for i, extraValue := range values[1:] { 180 extra := extraValue.Interface() 181 if extra == nil { 182 continue 183 } 184 if _, ok := AsPollingSignalError(extra); ok { 185 return actual, extra.(error) 186 } 187 extraType := reflect.TypeOf(extra) 188 zero := reflect.Zero(extraType).Interface() 189 if reflect.DeepEqual(extra, zero) { 190 continue 191 } 192 if i == len(values)-2 && extraType.Implements(errInterface) { 193 err = extra.(error) 194 } 195 if err == nil { 196 err = &asyncPolledActualError{ 197 message: fmt.Sprintf("The function passed to %s had an unexpected non-nil/non-zero return value at index %d:\n%s", assertion.asyncType, i+1, format.Object(extra, 1)), 198 } 199 } 200 } 201 202 return actual, err 203 } 204 205 func (assertion *AsyncAssertion) invalidFunctionError(t reflect.Type) error { 206 return fmt.Errorf(`The function passed to %s had an invalid signature of %s. Functions passed to %s must either: 207 208 (a) have return values or 209 (b) take a Gomega interface as their first argument and use that Gomega instance to make assertions. 210 211 You can learn more at https://onsi.github.io/gomega/#eventually 212 `, assertion.asyncType, t, assertion.asyncType) 213 } 214 215 func (assertion *AsyncAssertion) noConfiguredContextForFunctionError() error { 216 return fmt.Errorf(`The function passed to %s requested a context.Context, but no context has been provided. Please pass one in using %s().WithContext(). 217 218 You can learn more at https://onsi.github.io/gomega/#eventually 219 `, assertion.asyncType, assertion.asyncType) 220 } 221 222 func (assertion *AsyncAssertion) argumentMismatchError(t reflect.Type, numProvided int) error { 223 have := "have" 224 if numProvided == 1 { 225 have = "has" 226 } 227 return fmt.Errorf(`The function passed to %s has signature %s takes %d arguments but %d %s been provided. Please use %s().WithArguments() to pass the corect set of arguments. 228 229 You can learn more at https://onsi.github.io/gomega/#eventually 230 `, assertion.asyncType, t, t.NumIn(), numProvided, have, assertion.asyncType) 231 } 232 233 func (assertion *AsyncAssertion) invalidMustPassRepeatedlyError(reason string) error { 234 return fmt.Errorf(`Invalid use of MustPassRepeatedly with %s %s 235 236 You can learn more at https://onsi.github.io/gomega/#eventually 237 `, assertion.asyncType, reason) 238 } 239 240 func (assertion *AsyncAssertion) buildActualPoller() (func() (interface{}, error), error) { 241 if !assertion.actualIsFunc { 242 return func() (interface{}, error) { return assertion.actual, nil }, nil 243 } 244 actualValue := reflect.ValueOf(assertion.actual) 245 actualType := reflect.TypeOf(assertion.actual) 246 numIn, numOut, isVariadic := actualType.NumIn(), actualType.NumOut(), actualType.IsVariadic() 247 248 if numIn == 0 && numOut == 0 { 249 return nil, assertion.invalidFunctionError(actualType) 250 } 251 takesGomega, takesContext := false, false 252 if numIn > 0 { 253 takesGomega, takesContext = actualType.In(0).Implements(gomegaType), actualType.In(0).Implements(contextType) 254 } 255 if takesGomega && numIn > 1 && actualType.In(1).Implements(contextType) { 256 takesContext = true 257 } 258 if takesContext && len(assertion.argsToForward) > 0 && reflect.TypeOf(assertion.argsToForward[0]).Implements(contextType) { 259 takesContext = false 260 } 261 if !takesGomega && numOut == 0 { 262 return nil, assertion.invalidFunctionError(actualType) 263 } 264 if takesContext && assertion.ctx == nil { 265 return nil, assertion.noConfiguredContextForFunctionError() 266 } 267 268 var assertionFailure error 269 inValues := []reflect.Value{} 270 if takesGomega { 271 inValues = append(inValues, reflect.ValueOf(NewGomega(assertion.g.DurationBundle).ConfigureWithFailHandler(func(message string, callerSkip ...int) { 272 skip := 0 273 if len(callerSkip) > 0 { 274 skip = callerSkip[0] 275 } 276 _, file, line, _ := runtime.Caller(skip + 1) 277 assertionFailure = &asyncPolledActualError{ 278 message: fmt.Sprintf("The function passed to %s failed at %s:%d with:\n%s", assertion.asyncType, file, line, message), 279 } 280 // we throw an asyncGomegaHaltExecutionError so that defer GinkgoRecover() can catch this error if the user makes an assertion in a goroutine 281 panic(asyncGomegaHaltExecutionError{}) 282 }))) 283 } 284 if takesContext { 285 inValues = append(inValues, reflect.ValueOf(assertion.ctx)) 286 } 287 for _, arg := range assertion.argsToForward { 288 inValues = append(inValues, reflect.ValueOf(arg)) 289 } 290 291 if !isVariadic && numIn != len(inValues) { 292 return nil, assertion.argumentMismatchError(actualType, len(inValues)) 293 } else if isVariadic && len(inValues) < numIn-1 { 294 return nil, assertion.argumentMismatchError(actualType, len(inValues)) 295 } 296 297 if assertion.mustPassRepeatedly != 1 && assertion.asyncType != AsyncAssertionTypeEventually { 298 return nil, assertion.invalidMustPassRepeatedlyError("it can only be used with Eventually") 299 } 300 if assertion.mustPassRepeatedly < 1 { 301 return nil, assertion.invalidMustPassRepeatedlyError("parameter can't be < 1") 302 } 303 304 return func() (actual interface{}, err error) { 305 var values []reflect.Value 306 assertionFailure = nil 307 defer func() { 308 if numOut == 0 && takesGomega { 309 actual = assertionFailure 310 } else { 311 actual, err = assertion.processReturnValues(values) 312 _, isAsyncError := AsPollingSignalError(err) 313 if assertionFailure != nil && !isAsyncError { 314 err = assertionFailure 315 } 316 } 317 if e := recover(); e != nil { 318 if _, isAsyncError := AsPollingSignalError(e); isAsyncError { 319 err = e.(error) 320 } else if assertionFailure == nil { 321 panic(e) 322 } 323 } 324 }() 325 values = actualValue.Call(inValues) 326 return 327 }, nil 328 } 329 330 func (assertion *AsyncAssertion) afterTimeout() <-chan time.Time { 331 if assertion.timeoutInterval >= 0 { 332 return time.After(assertion.timeoutInterval) 333 } 334 335 if assertion.asyncType == AsyncAssertionTypeConsistently { 336 return time.After(assertion.g.DurationBundle.ConsistentlyDuration) 337 } else { 338 if assertion.ctx == nil { 339 return time.After(assertion.g.DurationBundle.EventuallyTimeout) 340 } else { 341 return nil 342 } 343 } 344 } 345 346 func (assertion *AsyncAssertion) afterPolling() <-chan time.Time { 347 if assertion.pollingInterval >= 0 { 348 return time.After(assertion.pollingInterval) 349 } 350 if assertion.asyncType == AsyncAssertionTypeConsistently { 351 return time.After(assertion.g.DurationBundle.ConsistentlyPollingInterval) 352 } else { 353 return time.After(assertion.g.DurationBundle.EventuallyPollingInterval) 354 } 355 } 356 357 func (assertion *AsyncAssertion) matcherSaysStopTrying(matcher types.GomegaMatcher, value interface{}) bool { 358 if assertion.actualIsFunc || types.MatchMayChangeInTheFuture(matcher, value) { 359 return false 360 } 361 return true 362 } 363 364 func (assertion *AsyncAssertion) pollMatcher(matcher types.GomegaMatcher, value interface{}) (matches bool, err error) { 365 defer func() { 366 if e := recover(); e != nil { 367 if _, isAsyncError := AsPollingSignalError(e); isAsyncError { 368 err = e.(error) 369 } else { 370 panic(e) 371 } 372 } 373 }() 374 375 matches, err = matcher.Match(value) 376 377 return 378 } 379 380 func (assertion *AsyncAssertion) match(matcher types.GomegaMatcher, desiredMatch bool, optionalDescription ...interface{}) bool { 381 timer := time.Now() 382 timeout := assertion.afterTimeout() 383 lock := sync.Mutex{} 384 385 var matches, hasLastValidActual bool 386 var actual, lastValidActual interface{} 387 var actualErr, matcherErr error 388 var oracleMatcherSaysStop bool 389 390 assertion.g.THelper() 391 392 pollActual, buildActualPollerErr := assertion.buildActualPoller() 393 if buildActualPollerErr != nil { 394 assertion.g.Fail(buildActualPollerErr.Error(), 2+assertion.offset) 395 return false 396 } 397 398 actual, actualErr = pollActual() 399 if actualErr == nil { 400 lastValidActual = actual 401 hasLastValidActual = true 402 oracleMatcherSaysStop = assertion.matcherSaysStopTrying(matcher, actual) 403 matches, matcherErr = assertion.pollMatcher(matcher, actual) 404 } 405 406 renderError := func(preamble string, err error) string { 407 message := "" 408 if pollingSignalErr, ok := AsPollingSignalError(err); ok { 409 message = err.Error() 410 for _, attachment := range pollingSignalErr.Attachments { 411 message += fmt.Sprintf("\n%s:\n", attachment.Description) 412 message += format.Object(attachment.Object, 1) 413 } 414 } else { 415 message = preamble + "\n" + format.Object(err, 1) 416 } 417 return message 418 } 419 420 messageGenerator := func() string { 421 // can be called out of band by Ginkgo if the user requests a progress report 422 lock.Lock() 423 defer lock.Unlock() 424 message := "" 425 426 if actualErr == nil { 427 if matcherErr == nil { 428 if desiredMatch != matches { 429 if desiredMatch { 430 message += matcher.FailureMessage(actual) 431 } else { 432 message += matcher.NegatedFailureMessage(actual) 433 } 434 } else { 435 if assertion.asyncType == AsyncAssertionTypeConsistently { 436 message += "There is no failure as the matcher passed to Consistently has not yet failed" 437 } else { 438 message += "There is no failure as the matcher passed to Eventually succeeded on its most recent iteration" 439 } 440 } 441 } else { 442 var fgErr formattedGomegaError 443 if errors.As(actualErr, &fgErr) { 444 message += fgErr.FormattedGomegaError() + "\n" 445 } else { 446 message += renderError(fmt.Sprintf("The matcher passed to %s returned the following error:", assertion.asyncType), matcherErr) 447 } 448 } 449 } else { 450 var fgErr formattedGomegaError 451 if errors.As(actualErr, &fgErr) { 452 message += fgErr.FormattedGomegaError() + "\n" 453 } else { 454 message += renderError(fmt.Sprintf("The function passed to %s returned the following error:", assertion.asyncType), actualErr) 455 } 456 if hasLastValidActual { 457 message += fmt.Sprintf("\nAt one point, however, the function did return successfully.\nYet, %s failed because", assertion.asyncType) 458 _, e := matcher.Match(lastValidActual) 459 if e != nil { 460 message += renderError(" the matcher returned the following error:", e) 461 } else { 462 message += " the matcher was not satisfied:\n" 463 if desiredMatch { 464 message += matcher.FailureMessage(lastValidActual) 465 } else { 466 message += matcher.NegatedFailureMessage(lastValidActual) 467 } 468 } 469 } 470 } 471 472 description := assertion.buildDescription(optionalDescription...) 473 return fmt.Sprintf("%s%s", description, message) 474 } 475 476 fail := func(preamble string) { 477 assertion.g.THelper() 478 assertion.g.Fail(fmt.Sprintf("%s after %.3fs.\n%s", preamble, time.Since(timer).Seconds(), messageGenerator()), 3+assertion.offset) 479 } 480 481 var contextDone <-chan struct{} 482 if assertion.ctx != nil { 483 contextDone = assertion.ctx.Done() 484 if v, ok := assertion.ctx.Value("GINKGO_SPEC_CONTEXT").(contextWithAttachProgressReporter); ok { 485 detach := v.AttachProgressReporter(messageGenerator) 486 defer detach() 487 } 488 } 489 490 // Used to count the number of times in a row a step passed 491 passedRepeatedlyCount := 0 492 for { 493 var nextPoll <-chan time.Time = nil 494 var isTryAgainAfterError = false 495 496 for _, err := range []error{actualErr, matcherErr} { 497 if pollingSignalErr, ok := AsPollingSignalError(err); ok { 498 if pollingSignalErr.IsStopTrying() { 499 fail("Told to stop trying") 500 return false 501 } 502 if pollingSignalErr.IsTryAgainAfter() { 503 nextPoll = time.After(pollingSignalErr.TryAgainDuration()) 504 isTryAgainAfterError = true 505 } 506 } 507 } 508 509 if actualErr == nil && matcherErr == nil && matches == desiredMatch { 510 if assertion.asyncType == AsyncAssertionTypeEventually { 511 passedRepeatedlyCount += 1 512 if passedRepeatedlyCount == assertion.mustPassRepeatedly { 513 return true 514 } 515 } 516 } else if !isTryAgainAfterError { 517 if assertion.asyncType == AsyncAssertionTypeConsistently { 518 fail("Failed") 519 return false 520 } 521 // Reset the consecutive pass count 522 passedRepeatedlyCount = 0 523 } 524 525 if oracleMatcherSaysStop { 526 if assertion.asyncType == AsyncAssertionTypeEventually { 527 fail("No future change is possible. Bailing out early") 528 return false 529 } else { 530 return true 531 } 532 } 533 534 if nextPoll == nil { 535 nextPoll = assertion.afterPolling() 536 } 537 538 select { 539 case <-nextPoll: 540 a, e := pollActual() 541 lock.Lock() 542 actual, actualErr = a, e 543 lock.Unlock() 544 if actualErr == nil { 545 lock.Lock() 546 lastValidActual = actual 547 hasLastValidActual = true 548 lock.Unlock() 549 oracleMatcherSaysStop = assertion.matcherSaysStopTrying(matcher, actual) 550 m, e := assertion.pollMatcher(matcher, actual) 551 lock.Lock() 552 matches, matcherErr = m, e 553 lock.Unlock() 554 } 555 case <-contextDone: 556 err := context.Cause(assertion.ctx) 557 if err != nil && err != context.Canceled { 558 fail(fmt.Sprintf("Context was cancelled (cause: %s)", err)) 559 } else { 560 fail("Context was cancelled") 561 } 562 return false 563 case <-timeout: 564 if assertion.asyncType == AsyncAssertionTypeEventually { 565 fail("Timed out") 566 return false 567 } else { 568 if isTryAgainAfterError { 569 fail("Timed out while waiting on TryAgainAfter") 570 return false 571 } 572 return true 573 } 574 } 575 } 576 }