github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/breaker/breaker_test.go (about) 1 // Copyright 2022 Gravitational, Inc 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package breaker 16 17 import ( 18 "errors" 19 "net/http" 20 "testing" 21 "time" 22 23 "github.com/gravitational/trace" 24 "github.com/jonboulle/clockwork" 25 "github.com/stretchr/testify/require" 26 "google.golang.org/grpc/codes" 27 "google.golang.org/grpc/status" 28 ) 29 30 func TestCircuitBreaker_generation(t *testing.T) { 31 t.Parallel() 32 clock := clockwork.NewFakeClock() 33 34 cb, err := New(Config{ 35 Clock: clock, 36 Interval: time.Second, 37 Trip: StaticTripper(false), 38 Recover: StaticTripper(false), 39 }) 40 require.NoError(t, err) 41 42 generation, state := cb.currentState(clock.Now()) 43 require.Equal(t, uint64(1), generation) 44 require.Equal(t, StateStandby, state) 45 require.Equal(t, clock.Now().Add(time.Second), cb.expiry) 46 47 clock.Advance(500 * time.Millisecond) 48 generation, state = cb.currentState(clock.Now()) 49 require.Equal(t, uint64(1), generation) 50 require.Equal(t, StateStandby, state) 51 clock.Advance(501 * time.Millisecond) 52 generation, state = cb.currentState(clock.Now()) 53 require.Equal(t, uint64(2), generation) 54 require.Equal(t, StateStandby, state) 55 require.Equal(t, clock.Now().Add(time.Second), cb.expiry) 56 57 for i := 0; i < 1000; i++ { 58 prevGeneration, prevState := cb.currentState(clock.Now()) 59 cb.nextGeneration(clock.Now()) 60 generation, state := cb.currentState(clock.Now()) 61 require.NotEqual(t, prevGeneration, generation) 62 require.Equal(t, prevState, state) 63 } 64 65 generation, state = cb.currentState(clock.Now()) 66 require.Equal(t, uint64(1002), generation) 67 require.Equal(t, StateStandby, state) 68 } 69 70 func TestCircuitBreaker_beforeRequest(t *testing.T) { 71 t.Parallel() 72 cases := []struct { 73 desc string 74 generation uint64 75 executions uint32 76 advance time.Duration 77 state State 78 errorCheck require.ErrorAssertionFunc 79 }{ 80 { 81 desc: "standby allows execution", 82 generation: 1, 83 executions: 1, 84 state: StateStandby, 85 errorCheck: require.NoError, 86 }, 87 { 88 desc: "tripped prevents executions", 89 generation: 1, 90 executions: 0, 91 state: StateTripped, 92 errorCheck: func(t require.TestingT, err error, i ...interface{}) { 93 require.Error(t, err) 94 require.ErrorIs(t, ErrStateTripped, err) 95 }, 96 }, 97 { 98 desc: "recovering after allows executions", 99 generation: 1, 100 executions: 1, 101 state: StateRecovering, 102 advance: 3 * time.Second, 103 errorCheck: require.NoError, 104 }, 105 } 106 107 for _, tt := range cases { 108 t.Run(tt.desc, func(t *testing.T) { 109 clock := clockwork.NewFakeClock() 110 111 cb, err := New(Config{ 112 Clock: clock, 113 Interval: time.Second, 114 Trip: StaticTripper(false), 115 Recover: StaticTripper(false), 116 RecoveryLimit: 1, 117 }) 118 require.NoError(t, err) 119 cb.state = tt.state 120 121 clock.Advance(tt.advance) 122 123 generation, err := cb.beforeExecution() 124 tt.errorCheck(t, err) 125 require.Equal(t, tt.generation, generation) 126 require.Equal(t, tt.executions, cb.metrics.Executions) 127 128 }) 129 } 130 } 131 132 func TestCircuitBreaker_afterExecution(t *testing.T) { 133 t.Parallel() 134 cases := []struct { 135 desc string 136 err error 137 priorGeneration uint64 138 checkMetrics require.ValueAssertionFunc 139 trip TripFn 140 recover TripFn 141 expectedState State 142 }{ 143 { 144 desc: "successful execution", 145 priorGeneration: 1, 146 checkMetrics: func(t require.TestingT, i interface{}, i2 ...interface{}) { 147 m, ok := i.(Metrics) 148 require.True(t, ok) 149 require.Equal(t, uint32(1), m.Successes) 150 require.Equal(t, uint32(0), m.Failures) 151 }, 152 trip: StaticTripper(false), 153 recover: StaticTripper(false), 154 expectedState: StateStandby, 155 }, 156 { 157 desc: "generation change", 158 priorGeneration: 0, 159 trip: StaticTripper(false), 160 recover: StaticTripper(false), 161 checkMetrics: func(t require.TestingT, i interface{}, i2 ...interface{}) { 162 m, ok := i.(Metrics) 163 require.True(t, ok) 164 require.Equal(t, uint32(0), m.Successes) 165 require.Equal(t, uint32(0), m.Failures) 166 }, 167 expectedState: StateStandby, 168 }, 169 { 170 desc: "failed execution with out tripping", 171 priorGeneration: 1, 172 err: errors.New("failure"), 173 trip: StaticTripper(false), 174 recover: StaticTripper(false), 175 checkMetrics: func(t require.TestingT, i interface{}, i2 ...interface{}) { 176 m, ok := i.(Metrics) 177 require.True(t, ok) 178 require.Equal(t, uint32(0), m.Successes) 179 require.Equal(t, uint32(1), m.Failures) 180 }, 181 expectedState: StateStandby, 182 }, 183 { 184 desc: "failed execution causing a trip", 185 priorGeneration: 1, 186 err: errors.New("failure"), 187 trip: StaticTripper(true), 188 recover: StaticTripper(false), 189 checkMetrics: func(t require.TestingT, i interface{}, i2 ...interface{}) { 190 m, ok := i.(Metrics) 191 require.True(t, ok) 192 require.Equal(t, uint32(0), m.Successes) 193 require.Equal(t, uint32(0), m.Failures) 194 }, 195 expectedState: StateTripped, 196 }, 197 } 198 199 for _, tt := range cases { 200 t.Run(tt.desc, func(t *testing.T) { 201 clock := clockwork.NewFakeClock() 202 cb, err := New(Config{ 203 Clock: clock, 204 Interval: time.Second, 205 Trip: tt.trip, 206 Recover: tt.recover, 207 }) 208 require.NoError(t, err) 209 210 cb.afterExecution(tt.priorGeneration, nil, tt.err) 211 tt.checkMetrics(t, cb.metrics) 212 require.Equal(t, tt.expectedState, cb.state) 213 }) 214 } 215 } 216 217 func TestCircuitBreaker_success(t *testing.T) { 218 t.Parallel() 219 cases := []struct { 220 desc string 221 initialState State 222 successState State 223 expectedState State 224 recoveryLimit uint32 225 }{ 226 { 227 desc: "success in standby", 228 initialState: StateStandby, 229 successState: StateStandby, 230 expectedState: StateStandby, 231 }, 232 { 233 desc: "success in recovery below limit", 234 initialState: StateRecovering, 235 successState: StateRecovering, 236 expectedState: StateRecovering, 237 recoveryLimit: 2, 238 }, 239 { 240 desc: "success in recovery above limit", 241 initialState: StateRecovering, 242 successState: StateRecovering, 243 expectedState: StateStandby, 244 recoveryLimit: 1, 245 }, 246 } 247 248 for _, tt := range cases { 249 t.Run(tt.desc, func(t *testing.T) { 250 clock := clockwork.NewFakeClock() 251 cb, err := New(Config{ 252 Clock: clock, 253 Interval: time.Second, 254 RecoveryLimit: tt.recoveryLimit, 255 Trip: StaticTripper(false), 256 Recover: StaticTripper(false), 257 }) 258 require.NoError(t, err) 259 cb.state = tt.initialState 260 261 generation, state := cb.currentState(clock.Now()) 262 cb.successLocked(tt.successState, clock.Now()) 263 require.Equal(t, tt.expectedState, cb.state) 264 if tt.expectedState != state { 265 require.NotEqual(t, generation, cb.generation) 266 } 267 }) 268 } 269 } 270 271 func TestCircuitBreaker_failure(t *testing.T) { 272 t.Parallel() 273 cases := []struct { 274 desc string 275 initialState State 276 failureState State 277 expectedState State 278 tripFn TripFn 279 recover TripFn 280 onTrip func(ch chan bool) func() 281 tripped bool 282 requireTripped require.BoolAssertionFunc 283 }{ 284 { 285 desc: "failure in recovering transitions to tripped", 286 initialState: StateRecovering, 287 failureState: StateRecovering, 288 expectedState: StateTripped, 289 tripFn: StaticTripper(false), 290 recover: StaticTripper(true), 291 requireTripped: require.False, 292 }, 293 { 294 desc: "failure in standby without tripping", 295 initialState: StateStandby, 296 failureState: StateStandby, 297 expectedState: StateStandby, 298 tripFn: StaticTripper(false), 299 recover: StaticTripper(false), 300 requireTripped: require.False, 301 }, 302 { 303 desc: "failure in standby causes tripping", 304 initialState: StateStandby, 305 failureState: StateStandby, 306 expectedState: StateTripped, 307 tripFn: StaticTripper(true), 308 recover: StaticTripper(false), 309 requireTripped: require.True, 310 onTrip: func(ch chan bool) func() { 311 return func() { 312 ch <- true 313 } 314 }, 315 }, 316 } 317 318 for _, tt := range cases { 319 tt := tt 320 t.Run(tt.desc, func(t *testing.T) { 321 t.Parallel() 322 clock := clockwork.NewFakeClock() 323 324 if tt.onTrip == nil { 325 tt.onTrip = func(ch chan bool) func() { 326 ch <- false 327 return func() {} 328 } 329 } 330 331 trippedCh := make(chan bool, 1) 332 333 cb, err := New(Config{ 334 Clock: clock, 335 Interval: time.Second, 336 Trip: tt.tripFn, 337 OnTripped: tt.onTrip(trippedCh), 338 Recover: tt.recover, 339 }) 340 require.NoError(t, err) 341 cb.state = tt.initialState 342 343 generation, state := cb.currentState(clock.Now()) 344 cb.failureLocked(tt.failureState, clock.Now()) 345 require.Equal(t, tt.expectedState, cb.state) 346 if tt.expectedState != state { 347 require.NotEqual(t, generation, cb.generation) 348 } 349 350 tripped := <-trippedCh 351 352 tt.requireTripped(t, tripped) 353 }) 354 } 355 } 356 357 func TestCircuitBreaker_Execute(t *testing.T) { 358 t.Parallel() 359 360 clock := clockwork.NewFakeClock() 361 362 trippedCh := make(chan struct{}) 363 onTripped := func(ch chan struct{}) func() { 364 return func() { 365 ch <- struct{}{} 366 } 367 } 368 369 cb, err := New(Config{ 370 Clock: clock, 371 Interval: time.Second, 372 Trip: ConsecutiveFailureTripper(3), 373 Recover: ConsecutiveFailureTripper(1), 374 OnTripped: onTripped(trippedCh), 375 TrippedPeriod: 2 * time.Second, 376 RecoveryLimit: 2, 377 }) 378 require.NoError(t, err) 379 380 testErr := errors.New("failure") 381 errorFn := func() (interface{}, error) { return nil, testErr } 382 noErrorFn := func() (interface{}, error) { return nil, nil } 383 cases := []struct { 384 desc string 385 exec func() (interface{}, error) 386 advance time.Duration 387 errorAssertion require.ErrorAssertionFunc 388 expectedState State 389 expectedGeneration uint64 390 }{ 391 { 392 desc: "no errors remain in standby", 393 exec: noErrorFn, 394 errorAssertion: require.NoError, 395 expectedState: StateStandby, 396 expectedGeneration: 1, 397 }, 398 { 399 desc: "error below limit remain in standby", 400 exec: errorFn, 401 errorAssertion: require.Error, 402 expectedState: StateStandby, 403 expectedGeneration: 1, 404 }, 405 { 406 desc: "another error below limit remain in standby", 407 exec: errorFn, 408 errorAssertion: require.Error, 409 expectedState: StateStandby, 410 expectedGeneration: 1, 411 }, 412 { 413 desc: "last error below limit remain in standby", 414 exec: errorFn, 415 errorAssertion: require.Error, 416 expectedState: StateStandby, 417 expectedGeneration: 1, 418 }, 419 { 420 desc: "transition from standby to tripped", 421 exec: errorFn, 422 errorAssertion: require.Error, 423 expectedState: StateTripped, 424 expectedGeneration: 2, 425 }, 426 { 427 desc: "error remain tripped", 428 exec: errorFn, 429 errorAssertion: require.Error, 430 expectedState: StateTripped, 431 expectedGeneration: 2, 432 }, 433 { 434 desc: "no error remain tripped", 435 exec: noErrorFn, 436 errorAssertion: require.Error, 437 expectedState: StateTripped, 438 expectedGeneration: 2, 439 }, 440 { 441 desc: "transition from tripped to recovering", 442 exec: noErrorFn, 443 errorAssertion: require.NoError, 444 expectedState: StateRecovering, 445 expectedGeneration: 3, 446 advance: 3 * time.Second, 447 }, 448 { 449 desc: "first failed execution recovering remains in recovering", 450 exec: errorFn, 451 errorAssertion: require.Error, 452 expectedState: StateRecovering, 453 expectedGeneration: 3, 454 advance: 250 * time.Millisecond, 455 }, 456 { 457 desc: "second failed execution recovering transitions to tripped", 458 exec: errorFn, 459 errorAssertion: require.Error, 460 expectedState: StateTripped, 461 expectedGeneration: 4, 462 advance: 450 * time.Millisecond, 463 }, 464 { 465 desc: "transition from tripped to recovering", 466 exec: noErrorFn, 467 errorAssertion: require.NoError, 468 expectedState: StateRecovering, 469 expectedGeneration: 5, 470 advance: 3 * time.Second, 471 }, 472 { 473 desc: "transition from recovering to standby", 474 exec: noErrorFn, 475 errorAssertion: require.NoError, 476 expectedState: StateStandby, 477 expectedGeneration: 6, 478 advance: 450 * time.Millisecond, 479 }, 480 { 481 desc: "remain in standby while in new generation", 482 exec: noErrorFn, 483 errorAssertion: require.NoError, 484 expectedState: StateStandby, 485 expectedGeneration: 7, 486 advance: time.Minute, 487 }, 488 } 489 490 for i, tt := range cases { 491 t.Run(tt.desc, func(t *testing.T) { 492 clock.Advance(tt.advance) 493 _, err := cb.Execute(tt.exec) 494 tt.errorAssertion(t, err) 495 generation, state := cb.currentState(clock.Now()) 496 require.Equal(t, tt.expectedGeneration, generation, "incorrect generation") 497 require.Equal(t, tt.expectedState, state, "incorrect state") 498 499 if state != StateTripped && tt.expectedState == StateTripped { 500 select { 501 case <-trippedCh: 502 default: 503 t.Fatalf("step %d expected to get tripped, but wasn't", i) 504 } 505 } 506 }) 507 } 508 509 } 510 511 func TestMetrics(t *testing.T) { 512 m := Metrics{} 513 514 zero := uint32(0) 515 one := uint32(1) 516 require.Equal(t, zero, m.Executions) 517 require.Equal(t, zero, m.Successes) 518 require.Equal(t, zero, m.Failures) 519 require.Equal(t, zero, m.ConsecutiveSuccesses) 520 require.Equal(t, zero, m.ConsecutiveFailures) 521 522 m.success() 523 524 require.Equal(t, zero, m.Executions) 525 require.Equal(t, one, m.Successes) 526 require.Equal(t, zero, m.Failures) 527 require.Equal(t, one, m.ConsecutiveSuccesses) 528 require.Equal(t, zero, m.ConsecutiveFailures) 529 530 m.execute() 531 532 require.Equal(t, one, m.Executions) 533 require.Equal(t, one, m.Successes) 534 require.Equal(t, zero, m.Failures) 535 require.Equal(t, one, m.ConsecutiveSuccesses) 536 require.Equal(t, zero, m.ConsecutiveFailures) 537 538 m.failure() 539 540 require.Equal(t, one, m.Executions) 541 require.Equal(t, one, m.Successes) 542 require.Equal(t, one, m.Failures) 543 require.Equal(t, zero, m.ConsecutiveSuccesses) 544 require.Equal(t, one, m.ConsecutiveFailures) 545 546 m.reset() 547 548 require.Equal(t, zero, m.Executions) 549 require.Equal(t, zero, m.Successes) 550 require.Equal(t, zero, m.Failures) 551 require.Equal(t, zero, m.ConsecutiveSuccesses) 552 require.Equal(t, zero, m.ConsecutiveFailures) 553 } 554 555 func TestIsResponseSuccessful(t *testing.T) { 556 cases := []struct { 557 name string 558 err error 559 response *http.Response 560 assertion require.BoolAssertionFunc 561 }{ 562 { 563 name: "nil error", 564 assertion: require.True, 565 }, 566 { 567 name: "codes.Canceled error", 568 err: status.Error(codes.Canceled, ""), 569 assertion: require.False, 570 }, 571 { 572 name: "codes.Unknown error", 573 err: status.Error(codes.Unknown, ""), 574 assertion: require.False, 575 }, 576 { 577 name: "codes.Unavailable error", 578 err: status.Error(codes.Unavailable, ""), 579 assertion: require.False, 580 }, 581 { 582 name: "codes.Unavailable error", 583 err: status.Error(codes.DeadlineExceeded, ""), 584 assertion: require.False, 585 }, 586 { 587 name: "other error", 588 err: trace.NotFound("not found"), 589 assertion: require.False, 590 }, 591 { 592 name: "error", 593 err: trace.NotFound(""), 594 assertion: require.False, 595 }, 596 { 597 name: "200", 598 response: &http.Response{StatusCode: http.StatusOK}, 599 assertion: require.True, 600 }, 601 { 602 name: "500", 603 response: &http.Response{StatusCode: http.StatusBadGateway}, 604 assertion: require.False, 605 }, 606 { 607 name: "404", 608 response: &http.Response{StatusCode: http.StatusNotFound}, 609 assertion: require.True, 610 }, 611 } 612 613 for _, tt := range cases { 614 t.Run(tt.name, func(t *testing.T) { 615 tt.assertion(t, IsResponseSuccessful(tt.response, tt.err)) 616 }) 617 } 618 }