github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/flowinfra/flow_registry_test.go (about) 1 // Copyright 2016 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package flowinfra 12 13 import ( 14 "context" 15 "math" 16 "sync" 17 "testing" 18 "time" 19 20 "github.com/cockroachdb/cockroach/pkg/sql/execinfra" 21 "github.com/cockroachdb/cockroach/pkg/sql/execinfrapb" 22 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" 23 "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" 24 "github.com/cockroachdb/cockroach/pkg/testutils" 25 "github.com/cockroachdb/cockroach/pkg/testutils/distsqlutils" 26 "github.com/cockroachdb/cockroach/pkg/util/leaktest" 27 "github.com/cockroachdb/cockroach/pkg/util/timeutil" 28 "github.com/cockroachdb/cockroach/pkg/util/uuid" 29 "github.com/cockroachdb/errors" 30 ) 31 32 // lookupFlow returns the registered flow with the given ID. If no such flow is 33 // registered, waits until it gets registered - up to the given timeout. If the 34 // timeout elapses and the flow is not registered, the bool return value will be 35 // false. 36 func lookupFlow(fr *FlowRegistry, fid execinfrapb.FlowID, timeout time.Duration) Flow { 37 fr.Lock() 38 defer fr.Unlock() 39 entry := fr.getEntryLocked(fid) 40 if entry.flow != nil { 41 return entry.flow 42 } 43 entry = fr.waitForFlowLocked(context.Background(), fid, timeout) 44 if entry == nil { 45 return nil 46 } 47 return entry.flow 48 } 49 50 // lookupStreamInfo returns a stream entry from a FlowRegistry. If either the 51 // flow or the streams are missing, an error is returned. 52 // 53 // A copy of the registry's InboundStreamInfo is returned so it can be accessed 54 // without locking. 55 func lookupStreamInfo( 56 fr *FlowRegistry, fid execinfrapb.FlowID, sid execinfrapb.StreamID, 57 ) (InboundStreamInfo, error) { 58 fr.Lock() 59 defer fr.Unlock() 60 entry := fr.getEntryLocked(fid) 61 if entry.flow == nil { 62 return InboundStreamInfo{}, errors.Errorf("missing flow entry: %s", fid) 63 } 64 si, ok := entry.inboundStreams[sid] 65 if !ok { 66 return InboundStreamInfo{}, errors.Errorf("missing stream entry: %d", sid) 67 } 68 return *si, nil 69 } 70 71 func TestFlowRegistry(t *testing.T) { 72 defer leaktest.AfterTest(t)() 73 reg := NewFlowRegistry(0) 74 75 id1 := execinfrapb.FlowID{UUID: uuid.MakeV4()} 76 f1 := &FlowBase{} 77 78 id2 := execinfrapb.FlowID{UUID: uuid.MakeV4()} 79 f2 := &FlowBase{} 80 81 id3 := execinfrapb.FlowID{UUID: uuid.MakeV4()} 82 f3 := &FlowBase{} 83 84 id4 := execinfrapb.FlowID{UUID: uuid.MakeV4()} 85 f4 := &FlowBase{} 86 87 // A basic duration; needs to be significantly larger than possible delays 88 // in scheduling goroutines. 89 jiffy := 10 * time.Millisecond 90 91 // -- Lookup, register, lookup, unregister, lookup. -- 92 93 if f := lookupFlow(reg, id1, 0); f != nil { 94 t.Error("looked up unregistered flow") 95 } 96 97 const flowStreamTimeout = 10 * time.Second 98 99 ctx := context.Background() 100 if err := reg.RegisterFlow( 101 ctx, id1, f1, nil /* inboundStreams */, flowStreamTimeout, 102 ); err != nil { 103 t.Fatal(err) 104 } 105 106 if f := lookupFlow(reg, id1, 0); f != f1 { 107 t.Error("couldn't lookup previously registered flow") 108 } 109 110 reg.UnregisterFlow(id1) 111 112 if f := lookupFlow(reg, id1, 0); f != nil { 113 t.Error("looked up unregistered flow") 114 } 115 116 // -- Lookup with timeout, register in the meantime. -- 117 118 go func() { 119 time.Sleep(jiffy) 120 if err := reg.RegisterFlow( 121 ctx, id1, f1, nil /* inboundStreams */, flowStreamTimeout, 122 ); err != nil { 123 t.Error(err) 124 } 125 }() 126 127 if f := lookupFlow(reg, id1, 10*jiffy); f != f1 { 128 t.Error("couldn't lookup registered flow (with wait)") 129 } 130 131 if f := lookupFlow(reg, id1, 0); f != f1 { 132 t.Error("couldn't lookup registered flow") 133 } 134 135 // -- Multiple lookups before register. -- 136 137 var wg sync.WaitGroup 138 wg.Add(2) 139 140 go func() { 141 if f := lookupFlow(reg, id2, 10*jiffy); f != f2 { 142 t.Error("couldn't lookup registered flow (with wait)") 143 } 144 wg.Done() 145 }() 146 147 go func() { 148 if f := lookupFlow(reg, id2, 10*jiffy); f != f2 { 149 t.Error("couldn't lookup registered flow (with wait)") 150 } 151 wg.Done() 152 }() 153 154 time.Sleep(jiffy) 155 if err := reg.RegisterFlow( 156 ctx, id2, f2, nil /* inboundStreams */, flowStreamTimeout, 157 ); err != nil { 158 t.Fatal(err) 159 } 160 wg.Wait() 161 162 // -- Multiple lookups, with the first one failing. -- 163 164 var wg1 sync.WaitGroup 165 var wg2 sync.WaitGroup 166 167 wg1.Add(1) 168 wg2.Add(1) 169 go func() { 170 if f := lookupFlow(reg, id3, jiffy); f != nil { 171 t.Error("expected lookup to fail") 172 } 173 wg1.Done() 174 }() 175 176 go func() { 177 if f := lookupFlow(reg, id3, 10*jiffy); f != f3 { 178 t.Error("couldn't lookup registered flow (with wait)") 179 } 180 wg2.Done() 181 }() 182 183 wg1.Wait() 184 if err := reg.RegisterFlow( 185 ctx, id3, f3, nil /* inboundStreams */, flowStreamTimeout, 186 ); err != nil { 187 t.Fatal(err) 188 } 189 wg2.Wait() 190 191 // -- Lookup with huge timeout, register in the meantime. -- 192 193 go func() { 194 time.Sleep(jiffy) 195 if err := reg.RegisterFlow( 196 ctx, id4, f4, nil /* inboundStreams */, flowStreamTimeout, 197 ); err != nil { 198 t.Error(err) 199 } 200 }() 201 202 // This should return in a jiffy. 203 if f := lookupFlow(reg, id4, time.Hour); f != f4 { 204 t.Error("couldn't lookup registered flow (with wait)") 205 } 206 } 207 208 // Test that, if inbound streams are not connected within the timeout, errors 209 // are propagated to their consumers and future attempts to connect them fail. 210 func TestStreamConnectionTimeout(t *testing.T) { 211 defer leaktest.AfterTest(t)() 212 reg := NewFlowRegistry(0) 213 214 jiffy := time.Nanosecond 215 216 // Register a flow with a very low timeout. After it times out, we'll attempt 217 // to connect a stream, but it'll be too late. 218 id1 := execinfrapb.FlowID{UUID: uuid.MakeV4()} 219 f1 := &FlowBase{} 220 streamID1 := execinfrapb.StreamID(1) 221 consumer := &distsqlutils.RowBuffer{} 222 wg := &sync.WaitGroup{} 223 wg.Add(1) 224 inboundStreams := map[execinfrapb.StreamID]*InboundStreamInfo{ 225 streamID1: {receiver: RowInboundStreamHandler{consumer}, waitGroup: wg}, 226 } 227 if err := reg.RegisterFlow( 228 context.Background(), id1, f1, inboundStreams, jiffy, 229 ); err != nil { 230 t.Fatal(err) 231 } 232 233 testutils.SucceedsSoon(t, func() error { 234 si, err := lookupStreamInfo(reg, id1, streamID1) 235 if err != nil { 236 t.Fatal(err) 237 } 238 if !si.canceled { 239 return errors.Errorf("not timed out yet") 240 } 241 return nil 242 }) 243 244 testutils.SucceedsSoon(t, func() error { 245 if !consumer.ProducerClosed() { 246 return errors.New("expected consumer to have been closed when the flow timed out") 247 } 248 return nil 249 }) 250 251 // Create a dummy server stream to pass to ConnectInboundStream. 252 serverStream, _ /* clientStream */, cleanup, err := createDummyStream() 253 if err != nil { 254 t.Fatal(err) 255 } 256 defer cleanup() 257 258 _, _, _, err = reg.ConnectInboundStream(context.Background(), id1, streamID1, serverStream, jiffy) 259 if !testutils.IsError(err, "came too late") { 260 t.Fatalf("expected %q, got: %v", "came too late", err) 261 } 262 263 // Unregister the flow. Subsequent attempts to connect a stream should result 264 // in a different error than before. 265 reg.UnregisterFlow(id1) 266 _, _, _, err = reg.ConnectInboundStream(context.Background(), id1, streamID1, serverStream, jiffy) 267 if !testutils.IsError(err, "not found") { 268 t.Fatalf("expected %q, got: %v", "not found", err) 269 } 270 } 271 272 // Test that the FlowRegistry send the correct handshake messages: 273 // - if an inbound stream arrives to the registry before the consumer is 274 // scheduled, then a Handshake message informing that the consumer is not yet 275 // connected is sent; 276 // - once the consumer connects, another Handshake message is sent. 277 func TestHandshake(t *testing.T) { 278 defer leaktest.AfterTest(t)() 279 280 reg := NewFlowRegistry(0) 281 282 tests := []struct { 283 name string 284 consumerConnectedEarly bool 285 }{ 286 { 287 name: "consumer early", 288 consumerConnectedEarly: true, 289 }, 290 { 291 name: "consumer late", 292 consumerConnectedEarly: false, 293 }, 294 } 295 for _, tc := range tests { 296 t.Run(tc.name, func(t *testing.T) { 297 flowID := execinfrapb.FlowID{UUID: uuid.MakeV4()} 298 streamID := execinfrapb.StreamID(1) 299 300 serverStream, clientStream, cleanup, err := createDummyStream() 301 if err != nil { 302 t.Fatal(err) 303 } 304 defer cleanup() 305 306 connectProducer := func() { 307 // Simulate a producer connecting to the server. This should be called 308 // async because the consumer is not yet there and ConnectInboundStream 309 // is blocking. 310 if _, _, _, err := reg.ConnectInboundStream( 311 context.Background(), flowID, streamID, serverStream, time.Hour, 312 ); err != nil { 313 t.Error(err) 314 } 315 } 316 connectConsumer := func() { 317 f1 := &FlowBase{} 318 consumer := &distsqlutils.RowBuffer{} 319 wg := &sync.WaitGroup{} 320 wg.Add(1) 321 inboundStreams := map[execinfrapb.StreamID]*InboundStreamInfo{ 322 streamID: {receiver: RowInboundStreamHandler{consumer}, waitGroup: wg}, 323 } 324 if err := reg.RegisterFlow( 325 context.Background(), flowID, f1, inboundStreams, time.Hour, /* timeout */ 326 ); err != nil { 327 t.Fatal(err) 328 } 329 } 330 331 // If the consumer is supposed to be connected early, then we connect the 332 // consumer and then we connect the producer. Otherwise, we connect the 333 // producer and expect a first handshake and only then we connect the 334 // consumer. 335 if tc.consumerConnectedEarly { 336 connectConsumer() 337 go connectProducer() 338 } else { 339 go connectProducer() 340 341 // Expect the client (the producer) to receive a Handshake saying that the 342 // consumer is not connected yet. 343 consumerSignal, err := clientStream.Recv() 344 if err != nil { 345 t.Fatal(err) 346 } 347 if consumerSignal.Handshake == nil { 348 t.Fatalf("expected handshake, got: %+v", consumerSignal) 349 } 350 if consumerSignal.Handshake.ConsumerScheduled { 351 t.Fatal("expected !ConsumerScheduled") 352 } 353 354 connectConsumer() 355 } 356 357 // Now expect another Handshake message telling the producer that the consumer 358 // has connected. 359 consumerSignal, err := clientStream.Recv() 360 if err != nil { 361 t.Fatal(err) 362 } 363 if consumerSignal.Handshake == nil { 364 t.Fatalf("expected handshake, got: %+v", consumerSignal) 365 } 366 if !consumerSignal.Handshake.ConsumerScheduled { 367 t.Fatal("expected ConsumerScheduled") 368 } 369 }) 370 } 371 } 372 373 // TestFlowRegistryDrain verifies a FlowRegistry's draining behavior. See 374 // subtests for more details. 375 func TestFlowRegistryDrain(t *testing.T) { 376 defer leaktest.AfterTest(t)() 377 378 ctx := context.Background() 379 reg := NewFlowRegistry(0) 380 381 flow := &FlowBase{} 382 id := execinfrapb.FlowID{UUID: uuid.MakeV4()} 383 registerFlow := func(t *testing.T, id execinfrapb.FlowID) { 384 t.Helper() 385 if err := reg.RegisterFlow( 386 ctx, id, flow, nil /* inboundStreams */, 0, /* timeout */ 387 ); err != nil { 388 t.Fatal(err) 389 } 390 } 391 392 // WaitForFlow verifies that Drain waits for a flow to finish within the 393 // timeout. 394 t.Run("WaitForFlow", func(t *testing.T) { 395 registerFlow(t, id) 396 drainDone := make(chan struct{}) 397 go func() { 398 reg.Drain(math.MaxInt64 /* flowDrainWait */, 0 /* minFlowDrainWait */, nil /* reporter */) 399 drainDone <- struct{}{} 400 }() 401 // Be relatively sure that the FlowRegistry is draining. 402 time.Sleep(time.Microsecond) 403 reg.UnregisterFlow(id) 404 <-drainDone 405 reg.Undrain() 406 }) 407 408 // DrainTimeout verifies that Drain returns once the timeout expires. 409 t.Run("DrainTimeout", func(t *testing.T) { 410 registerFlow(t, id) 411 reg.Drain(0 /* flowDrainWait */, 0 /* minFlowDrainWait */, nil /* reporter */) 412 reg.UnregisterFlow(id) 413 reg.Undrain() 414 }) 415 416 // AcceptNewFlow verifies that a FlowRegistry continues accepting flows 417 // while draining. 418 t.Run("AcceptNewFlow", func(t *testing.T) { 419 registerFlow(t, id) 420 drainDone := make(chan struct{}) 421 go func() { 422 reg.Drain(math.MaxInt64 /* flowDrainWait */, 0 /* minFlowDrainWait */, nil /* reporter */) 423 drainDone <- struct{}{} 424 }() 425 // Be relatively sure that the FlowRegistry is draining. 426 time.Sleep(time.Microsecond) 427 newFlowID := execinfrapb.FlowID{UUID: uuid.MakeV4()} 428 registerFlow(t, newFlowID) 429 reg.UnregisterFlow(id) 430 select { 431 case <-drainDone: 432 t.Fatal("finished draining before unregistering new flow") 433 default: 434 } 435 reg.UnregisterFlow(newFlowID) 436 <-drainDone 437 // The registry should not accept new flows once it has finished draining. 438 if err := reg.RegisterFlow( 439 ctx, id, flow, nil /* inboundStreams */, 0, /* timeout */ 440 ); !testutils.IsError(err, "draining") { 441 t.Fatalf("unexpected error: %v", err) 442 } 443 reg.Undrain() 444 }) 445 446 // MinFlowWait verifies that the FlowRegistry waits a minimum amount of time 447 // for incoming flows to be registered. 448 t.Run("MinFlowWait", func(t *testing.T) { 449 // Case in which draining is initiated with zero running flows. 450 drainDone := make(chan struct{}) 451 // Register a flow right before the FlowRegistry waits for 452 // minFlowDrainWait. Use an errChan because draining is performed from 453 // another goroutine and cannot call t.Fatal. 454 errChan := make(chan error) 455 reg.testingRunBeforeDrainSleep = func() { 456 if err := reg.RegisterFlow( 457 ctx, id, flow, nil /* inboundStreams */, 0, /* timeout */ 458 ); err != nil { 459 errChan <- err 460 } 461 errChan <- nil 462 } 463 defer func() { reg.testingRunBeforeDrainSleep = nil }() 464 go func() { 465 reg.Drain(math.MaxInt64 /* flowDrainWait */, 0 /* minFlowDrainWait */, nil /* reporter */) 466 drainDone <- struct{}{} 467 }() 468 if err := <-errChan; err != nil { 469 t.Fatal(err) 470 } 471 reg.UnregisterFlow(id) 472 <-drainDone 473 reg.Undrain() 474 475 // Case in which a running flow finishes before the minimum wait time. We 476 // attempt to register another flow after the completion of the first flow 477 // to simulate an incoming flow that is registered during the minimum wait 478 // time. However, it is possible to unregister the first flow after the 479 // minimum wait time has passed, in which case we simply verify that the 480 // FlowRegistry drain process has lasted at least the required wait time. 481 registerFlow(t, id) 482 reg.testingRunBeforeDrainSleep = func() { 483 if err := reg.RegisterFlow( 484 ctx, id, flow, nil /* inboundStreams */, 0, /* timeout */ 485 ); err != nil { 486 errChan <- err 487 } 488 errChan <- nil 489 } 490 minFlowDrainWait := 10 * time.Millisecond 491 start := timeutil.Now() 492 go func() { 493 reg.Drain(math.MaxInt64 /* flowDrainWait */, minFlowDrainWait, nil /* reporter */) 494 drainDone <- struct{}{} 495 }() 496 // Be relatively sure that the FlowRegistry is draining. 497 time.Sleep(time.Microsecond) 498 reg.UnregisterFlow(id) 499 select { 500 case <-drainDone: 501 if timeutil.Since(start) < minFlowDrainWait { 502 t.Fatal("flow registry did not wait at least minFlowDrainWait") 503 } 504 return 505 case err := <-errChan: 506 if err != nil { 507 t.Fatal(err) 508 } 509 } 510 reg.UnregisterFlow(id) 511 <-drainDone 512 513 reg.Undrain() 514 }) 515 } 516 517 // TestInboundStreamTimeoutIsRetryable verifies that a failure from an inbound 518 // stream to connect in a timeout is considered retryable by 519 // pgerror.IsSQLRetryableError. 520 // TODO(asubiotto): This error should also be considered retryable by clients. 521 func TestInboundStreamTimeoutIsRetryable(t *testing.T) { 522 defer leaktest.AfterTest(t)() 523 524 fr := NewFlowRegistry(0) 525 wg := sync.WaitGroup{} 526 rc := &execinfra.RowChannel{} 527 rc.InitWithBufSizeAndNumSenders(sqlbase.OneIntCol, 1 /* chanBufSize */, 1 /* numSenders */) 528 inboundStreams := map[execinfrapb.StreamID]*InboundStreamInfo{ 529 0: { 530 receiver: RowInboundStreamHandler{rc}, 531 waitGroup: &wg, 532 }, 533 } 534 wg.Add(1) 535 if err := fr.RegisterFlow( 536 context.Background(), execinfrapb.FlowID{}, &FlowBase{}, inboundStreams, 0, /* timeout */ 537 ); err != nil { 538 t.Fatal(err) 539 } 540 wg.Wait() 541 if _, meta := rc.Next(); meta == nil { 542 t.Fatal("expected error but got no meta") 543 } else if !pgerror.IsSQLRetryableError(meta.Err) { 544 t.Fatalf("unexpected error: %v", meta.Err) 545 } 546 } 547 548 // TestTimeoutPushDoesntBlockRegister verifies that in the case of a timeout 549 // error, we are still able to register flows while Pushing the error (#34041). 550 func TestTimeoutPushDoesntBlockRegister(t *testing.T) { 551 defer leaktest.AfterTest(t)() 552 553 ctx := context.Background() 554 fr := NewFlowRegistry(0) 555 // pushChan is used to be able to tell when a Push on the RowBuffer has 556 // occurred. 557 pushChan := make(chan *execinfrapb.ProducerMetadata) 558 rc := distsqlutils.NewRowBuffer( 559 sqlbase.OneIntCol, 560 nil, /* rows */ 561 distsqlutils.RowBufferArgs{ 562 OnPush: func(_ sqlbase.EncDatumRow, meta *execinfrapb.ProducerMetadata) { 563 pushChan <- meta 564 <-pushChan 565 }, 566 }, 567 ) 568 569 wg := sync.WaitGroup{} 570 wg.Add(1) 571 inboundStreams := map[execinfrapb.StreamID]*InboundStreamInfo{ 572 0: { 573 receiver: RowInboundStreamHandler{rc}, 574 waitGroup: &wg, 575 }, 576 } 577 578 // RegisterFlow with an immediate timeout. 579 if err := fr.RegisterFlow( 580 ctx, execinfrapb.FlowID{}, &FlowBase{}, inboundStreams, 0, /* timeout */ 581 ); err != nil { 582 t.Fatal(err) 583 } 584 585 // Ensure RegisterFlow performs a Push. 586 meta := <-pushChan 587 if !testutils.IsError(meta.Err, errNoInboundStreamConnection.Error()) { 588 t.Fatalf("unexpected err %v, expected %s", meta.Err, errNoInboundStreamConnection) 589 } 590 591 // Attempt to register a flow. Note that this flow has no inbound streams, so 592 // Pushing to the RowBuffer is unexpected. 593 if err := fr.RegisterFlow( 594 ctx, execinfrapb.FlowID{UUID: uuid.MakeV4()}, &FlowBase{}, nil /* inboundStreams */, time.Hour, /* timeout */ 595 ); err != nil { 596 t.Fatal(err) 597 } 598 599 // Unblock the first RegisterFlow. 600 close(pushChan) 601 } 602 603 // TestFlowCancelPartiallyBlocked tests that cancellation messages can propagate 604 // into a flow even if one of the inbound streams are blocked (#35859). 605 func TestFlowCancelPartiallyBlocked(t *testing.T) { 606 defer leaktest.AfterTest(t)() 607 608 ctx := context.Background() 609 fr := NewFlowRegistry(0) 610 left := &execinfra.RowChannel{} 611 left.InitWithBufSizeAndNumSenders(nil /* types */, 1, 1) 612 right := &execinfra.RowChannel{} 613 right.InitWithBufSizeAndNumSenders(nil /* types */, 1, 1) 614 615 wgLeft := sync.WaitGroup{} 616 wgLeft.Add(1) 617 wgRight := sync.WaitGroup{} 618 wgRight.Add(1) 619 inboundStreams := map[execinfrapb.StreamID]*InboundStreamInfo{ 620 0: { 621 receiver: RowInboundStreamHandler{left}, 622 waitGroup: &wgLeft, 623 }, 624 1: { 625 receiver: RowInboundStreamHandler{right}, 626 waitGroup: &wgRight, 627 }, 628 } 629 630 // Fill up the left, so pushes to it block. 631 left.Push(nil, &execinfrapb.ProducerMetadata{}) 632 633 // RegisterFlow with an immediate timeout. 634 flow := &FlowBase{ 635 FlowCtx: execinfra.FlowCtx{ 636 ID: execinfrapb.FlowID{UUID: uuid.FastMakeV4()}, 637 }, 638 inboundStreams: inboundStreams, 639 flowRegistry: fr, 640 } 641 if err := fr.RegisterFlow( 642 ctx, flow.ID, flow, inboundStreams, 10*time.Second, /* timeout */ 643 ); err != nil { 644 t.Fatal(err) 645 } 646 647 flow.cancel() 648 649 // Reading from the right shouldn't block and should immediately return a 650 // flow canceled error. 651 652 _, meta := right.Next() 653 if !errors.Is(meta.Err, sqlbase.QueryCanceledError) { 654 t.Fatal("expected query canceled, found", meta.Err) 655 } 656 657 // Read from the left to unblock the canceler, assert that the next 658 // message is the query canceled message as well. 659 660 _, _ = left.Next() 661 _, meta = left.Next() 662 if !errors.Is(meta.Err, sqlbase.QueryCanceledError) { 663 t.Fatal("expected query canceled, found", meta.Err) 664 } 665 }