github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/pgwire/conn_test.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package pgwire 12 13 import ( 14 "bytes" 15 "context" 16 gosql "database/sql" 17 "fmt" 18 "io" 19 "io/ioutil" 20 "net" 21 "net/url" 22 "strconv" 23 "strings" 24 "sync" 25 "testing" 26 "time" 27 28 "github.com/cockroachdb/cockroach/pkg/base" 29 "github.com/cockroachdb/cockroach/pkg/security" 30 "github.com/cockroachdb/cockroach/pkg/sql" 31 "github.com/cockroachdb/cockroach/pkg/sql/lex" 32 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" 33 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" 34 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" 35 "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" 36 "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" 37 "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" 38 "github.com/cockroachdb/cockroach/pkg/sql/sqlutil" 39 "github.com/cockroachdb/cockroach/pkg/testutils" 40 "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" 41 "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" 42 "github.com/cockroachdb/cockroach/pkg/util" 43 "github.com/cockroachdb/cockroach/pkg/util/leaktest" 44 "github.com/cockroachdb/cockroach/pkg/util/log" 45 "github.com/cockroachdb/cockroach/pkg/util/metric" 46 "github.com/cockroachdb/cockroach/pkg/util/mon" 47 "github.com/cockroachdb/cockroach/pkg/util/stop" 48 "github.com/cockroachdb/errors" 49 "github.com/jackc/pgx" 50 "github.com/jackc/pgx/pgproto3" 51 "github.com/stretchr/testify/require" 52 "golang.org/x/sync/errgroup" 53 ) 54 55 // Test the conn struct: check that it marshalls the correct commands to the 56 // stmtBuf. 57 // 58 // This test is weird because it aims to be a "unit test" for the conn with 59 // minimal dependencies, but it needs a producer speaking the pgwire protocol 60 // on the other end of the connection. We use the pgx Postgres driver for this. 61 // We're going to simulate a client sending various commands to the server. We 62 // don't have proper execution of those commands in this test, so we synthesize 63 // responses. 64 // 65 // This test depends on recognizing the queries sent by pgx when it opens a 66 // connection. If that set of queries changes, this test will probably fail 67 // complaining that the stmtBuf has an unexpected entry in it. 68 func TestConn(t *testing.T) { 69 defer leaktest.AfterTest(t)() 70 71 // The test server is used only incidentally by this test: this is not the 72 // server that the client will connect to; we just use it on the side to 73 // execute some metadata queries that pgx sends whenever it opens a 74 // connection. 75 s, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true, UseDatabase: "system"}) 76 defer s.Stopper().Stop(context.Background()) 77 78 // Start a pgwire "server". 79 addr := util.TestAddr 80 ln, err := net.Listen(addr.Network(), addr.String()) 81 if err != nil { 82 t.Fatal(err) 83 } 84 serverAddr := ln.Addr() 85 log.Infof(context.Background(), "started listener on %s", serverAddr) 86 87 var g errgroup.Group 88 ctx := context.Background() 89 90 var clientWG sync.WaitGroup 91 clientWG.Add(1) 92 93 g.Go(func() error { 94 return client(ctx, serverAddr, &clientWG) 95 }) 96 97 // Wait for the client to connect and perform the handshake. 98 conn, err := waitForClientConn(ln) 99 if err != nil { 100 t.Fatal(err) 101 } 102 103 // Run the conn's loop in the background - it will push commands to the 104 // buffer. 105 serveCtx, stopServe := context.WithCancel(ctx) 106 g.Go(func() error { 107 conn.serveImpl( 108 serveCtx, 109 func() bool { return false }, /* draining */ 110 // sqlServer - nil means don't create a command processor and a write side of the conn 111 nil, 112 mon.BoundAccount{}, /* reserved */ 113 authOptions{testingSkipAuth: true}, 114 s.Stopper()) 115 return nil 116 }) 117 defer stopServe() 118 119 if err := processPgxStartup(ctx, s, conn); err != nil { 120 t.Fatal(err) 121 } 122 123 // Now we'll expect to receive the commands corresponding to the operations in 124 // client(). 125 rd := sql.MakeStmtBufReader(&conn.stmtBuf) 126 expectExecStmt(ctx, t, "SELECT 1", &rd, conn, queryStringComplete) 127 expectSync(ctx, t, &rd) 128 expectExecStmt(ctx, t, "SELECT 2", &rd, conn, queryStringComplete) 129 expectSync(ctx, t, &rd) 130 expectPrepareStmt(ctx, t, "p1", "SELECT 'p1'", &rd, conn) 131 expectDescribeStmt(ctx, t, "p1", pgwirebase.PrepareStatement, &rd, conn) 132 expectSync(ctx, t, &rd) 133 expectBindStmt(ctx, t, "p1", &rd, conn) 134 expectExecPortal(ctx, t, "", &rd, conn) 135 // Check that a query string with multiple queries sent using the simple 136 // protocol is broken up. 137 expectSync(ctx, t, &rd) 138 expectExecStmt(ctx, t, "SELECT 4", &rd, conn, queryStringIncomplete) 139 expectExecStmt(ctx, t, "SELECT 5", &rd, conn, queryStringIncomplete) 140 expectExecStmt(ctx, t, "SELECT 6", &rd, conn, queryStringComplete) 141 expectSync(ctx, t, &rd) 142 143 // Check that the batching works like the client intended. 144 145 // pgx wraps batchs in transactions. 146 expectExecStmt(ctx, t, "BEGIN TRANSACTION", &rd, conn, queryStringComplete) 147 expectSync(ctx, t, &rd) 148 expectPrepareStmt(ctx, t, "", "SELECT 7", &rd, conn) 149 expectBindStmt(ctx, t, "", &rd, conn) 150 expectDescribeStmt(ctx, t, "", pgwirebase.PreparePortal, &rd, conn) 151 expectExecPortal(ctx, t, "", &rd, conn) 152 expectPrepareStmt(ctx, t, "", "SELECT 8", &rd, conn) 153 // Now we'll send an error, in the middle of this batch. pgx will stop waiting 154 // for results for commands in the batch. We'll then test that seeking to the 155 // next batch advances us to the correct statement. 156 if err := finishQuery(generateError, conn); err != nil { 157 t.Fatal(err) 158 } 159 // We're about to seek to the next batch but, as per seek's contract, seeking 160 // can only be called when there is something in the buffer. Since the buffer 161 // is filled concurrently with this code, we call CurCmd to ensure that 162 // there's something in there. 163 if _, err := rd.CurCmd(); err != nil { 164 t.Fatal(err) 165 } 166 // Skip all the remaining messages in the batch. 167 if err := rd.SeekToNextBatch(); err != nil { 168 t.Fatal(err) 169 } 170 // We got to the COMMIT that pgx pushed to match the BEGIN it generated for 171 // the batch. 172 expectSync(ctx, t, &rd) 173 expectExecStmt(ctx, t, "COMMIT TRANSACTION", &rd, conn, queryStringComplete) 174 expectSync(ctx, t, &rd) 175 expectExecStmt(ctx, t, "SELECT 9", &rd, conn, queryStringComplete) 176 expectSync(ctx, t, &rd) 177 178 // Test that parse error turns into SendError. 179 expectSendError(ctx, t, pgcode.Syntax, &rd, conn) 180 181 clientWG.Done() 182 183 if err := g.Wait(); err != nil { 184 t.Fatal(err) 185 } 186 } 187 188 // processPgxStartup processes the first few queries that the pgx driver 189 // automatically sends on a new connection that has been established. 190 func processPgxStartup(ctx context.Context, s serverutils.TestServerInterface, c *conn) error { 191 rd := sql.MakeStmtBufReader(&c.stmtBuf) 192 193 for { 194 cmd, err := rd.CurCmd() 195 if err != nil { 196 log.Errorf(ctx, "CurCmd error: %v", err) 197 return err 198 } 199 200 if _, ok := cmd.(sql.Sync); ok { 201 log.Infof(ctx, "advancing Sync") 202 rd.AdvanceOne() 203 continue 204 } 205 206 exec, ok := cmd.(sql.ExecStmt) 207 if !ok { 208 log.Infof(ctx, "stop wait at: %v", cmd) 209 return nil 210 } 211 query := exec.AST.String() 212 if !strings.HasPrefix(query, "SELECT t.oid") { 213 log.Infof(ctx, "stop wait at query: %s", query) 214 return nil 215 } 216 if err := execQuery(ctx, query, s, c); err != nil { 217 log.Errorf(ctx, "execQuery %s error: %v", query, err) 218 return err 219 } 220 log.Infof(ctx, "executed query: %s", query) 221 rd.AdvanceOne() 222 } 223 } 224 225 // execQuery executes a query on the passed-in server and send the results on c. 226 func execQuery( 227 ctx context.Context, query string, s serverutils.TestServerInterface, c *conn, 228 ) error { 229 rows, cols, err := s.InternalExecutor().(sqlutil.InternalExecutor).QueryWithCols( 230 ctx, "test", nil, /* txn */ 231 sqlbase.InternalExecutorSessionDataOverride{User: security.RootUser, Database: "system"}, 232 query, 233 ) 234 if err != nil { 235 return err 236 } 237 return sendResult(ctx, c, cols, rows) 238 } 239 240 func client(ctx context.Context, serverAddr net.Addr, wg *sync.WaitGroup) error { 241 host, ports, err := net.SplitHostPort(serverAddr.String()) 242 if err != nil { 243 return err 244 } 245 port, err := strconv.Atoi(ports) 246 if err != nil { 247 return err 248 } 249 conn, err := pgx.Connect( 250 pgx.ConnConfig{ 251 Logger: pgxTestLogger{}, 252 Host: host, 253 Port: uint16(port), 254 User: "root", 255 // Setting this so that the queries sent by pgx to initialize the 256 // connection are not using prepared statements. That simplifies the 257 // scaffolding of the test. 258 PreferSimpleProtocol: true, 259 Database: "system", 260 }) 261 if err != nil { 262 return err 263 } 264 265 if _, err := conn.Exec("select 1"); err != nil { 266 return err 267 } 268 if _, err := conn.Exec("select 2"); err != nil { 269 return err 270 } 271 if _, err := conn.Prepare("p1", "select 'p1'"); err != nil { 272 return err 273 } 274 if _, err := conn.ExecEx( 275 ctx, "p1", 276 // We set these options because apparently that's how I tell pgx that it 277 // should check whether "p1" is a prepared statement. 278 &pgx.QueryExOptions{SimpleProtocol: false}); err != nil { 279 return err 280 } 281 282 // Send a group of statements as one query string using the simple protocol. 283 // We'll check that we receive them one by one, but marked as a batch. 284 if _, err := conn.Exec("select 4; select 5; select 6;"); err != nil { 285 return err 286 } 287 288 batch := conn.BeginBatch() 289 batch.Queue("select 7", nil, nil, nil) 290 batch.Queue("select 8", nil, nil, nil) 291 if err := batch.Send(context.Background(), &pgx.TxOptions{}); err != nil { 292 return err 293 } 294 if err := batch.Close(); err != nil { 295 // Swallow the error that we injected. 296 if !strings.Contains(err.Error(), "injected") { 297 return err 298 } 299 } 300 301 if _, err := conn.Exec("select 9"); err != nil { 302 return err 303 } 304 if _, err := conn.Exec("bogus statement failing to parse"); err != nil { 305 return err 306 } 307 308 wg.Wait() 309 310 return conn.Close() 311 } 312 313 // waitForClientConn blocks until a client connects and performs the pgwire 314 // handshake. This emulates what pgwire.Server does. 315 func waitForClientConn(ln net.Listener) (*conn, error) { 316 conn, err := ln.Accept() 317 if err != nil { 318 return nil, err 319 } 320 321 var buf pgwirebase.ReadBuffer 322 _, err = buf.ReadUntypedMsg(conn) 323 if err != nil { 324 return nil, err 325 } 326 version, err := buf.GetUint32() 327 if err != nil { 328 return nil, err 329 } 330 if version != version30 { 331 return nil, errors.Errorf("unexpected protocol version: %d", version) 332 } 333 334 // Consume the connection options. 335 if _, err := parseClientProvidedSessionParameters(context.Background(), nil, &buf); err != nil { 336 return nil, err 337 } 338 339 metrics := makeServerMetrics(sql.MemoryMetrics{} /* sqlMemMetrics */, metric.TestSampleInterval) 340 pgwireConn := newConn(conn, sql.SessionArgs{ConnResultsBufferSize: 16 << 10}, &metrics, nil) 341 return pgwireConn, nil 342 } 343 344 func makeTestingConvCfg() sessiondata.DataConversionConfig { 345 return sessiondata.DataConversionConfig{ 346 Location: time.UTC, 347 BytesEncodeFormat: lex.BytesEncodeHex, 348 } 349 } 350 351 // sendResult serializes a set of rows in pgwire format and sends them on a 352 // connection. 353 // 354 // TODO(andrei): Tests using this should probably switch to using the similar 355 // routines in the connection once conn learns how to write rows. 356 func sendResult( 357 ctx context.Context, c *conn, cols sqlbase.ResultColumns, rows []tree.Datums, 358 ) error { 359 if err := c.writeRowDescription(ctx, cols, nil /* formatCodes */, c.conn); err != nil { 360 return err 361 } 362 363 defaultConv := makeTestingConvCfg() 364 for _, row := range rows { 365 c.msgBuilder.initMsg(pgwirebase.ServerMsgDataRow) 366 c.msgBuilder.putInt16(int16(len(row))) 367 for _, col := range row { 368 c.msgBuilder.writeTextDatum(ctx, col, defaultConv) 369 } 370 371 if err := c.msgBuilder.finishMsg(c.conn); err != nil { 372 return err 373 } 374 } 375 376 return finishQuery(execute, c) 377 } 378 379 type executeType int 380 381 const ( 382 queryStringComplete executeType = iota 383 queryStringIncomplete 384 ) 385 386 func expectExecStmt( 387 ctx context.Context, t *testing.T, expSQL string, rd *sql.StmtBufReader, c *conn, typ executeType, 388 ) { 389 t.Helper() 390 cmd, err := rd.CurCmd() 391 if err != nil { 392 t.Fatal(err) 393 } 394 rd.AdvanceOne() 395 396 es, ok := cmd.(sql.ExecStmt) 397 if !ok { 398 t.Fatalf("expected command ExecStmt, got: %T (%+v)", cmd, cmd) 399 } 400 401 if es.AST.String() != expSQL { 402 t.Fatalf("expected %s, got %s", expSQL, es.AST.String()) 403 } 404 405 if es.ParseStart == (time.Time{}) { 406 t.Fatalf("ParseStart not filled in") 407 } 408 if es.ParseEnd == (time.Time{}) { 409 t.Fatalf("ParseEnd not filled in") 410 } 411 if typ == queryStringComplete { 412 if err := finishQuery(execute, c); err != nil { 413 t.Fatal(err) 414 } 415 } else { 416 if err := finishQuery(cmdComplete, c); err != nil { 417 t.Fatal(err) 418 } 419 } 420 } 421 422 func expectPrepareStmt( 423 ctx context.Context, t *testing.T, expName string, expSQL string, rd *sql.StmtBufReader, c *conn, 424 ) { 425 t.Helper() 426 cmd, err := rd.CurCmd() 427 if err != nil { 428 t.Fatal(err) 429 } 430 rd.AdvanceOne() 431 432 pr, ok := cmd.(sql.PrepareStmt) 433 if !ok { 434 t.Fatalf("expected command PrepareStmt, got: %T (%+v)", cmd, cmd) 435 } 436 437 if pr.Name != expName { 438 t.Fatalf("expected name %s, got %s", expName, pr.Name) 439 } 440 441 if pr.AST.String() != expSQL { 442 t.Fatalf("expected %s, got %s", expSQL, pr.AST.String()) 443 } 444 445 if err := finishQuery(prepare, c); err != nil { 446 t.Fatal(err) 447 } 448 } 449 450 func expectDescribeStmt( 451 ctx context.Context, 452 t *testing.T, 453 expName string, 454 expType pgwirebase.PrepareType, 455 rd *sql.StmtBufReader, 456 c *conn, 457 ) { 458 t.Helper() 459 cmd, err := rd.CurCmd() 460 if err != nil { 461 t.Fatal(err) 462 } 463 rd.AdvanceOne() 464 465 desc, ok := cmd.(sql.DescribeStmt) 466 if !ok { 467 t.Fatalf("expected command DescribeStmt, got: %T (%+v)", cmd, cmd) 468 } 469 470 if desc.Name != expName { 471 t.Fatalf("expected name %s, got %s", expName, desc.Name) 472 } 473 474 if desc.Type != expType { 475 t.Fatalf("expected type %s, got %s", expType, desc.Type) 476 } 477 478 if err := finishQuery(describe, c); err != nil { 479 t.Fatal(err) 480 } 481 } 482 483 func expectBindStmt( 484 ctx context.Context, t *testing.T, expName string, rd *sql.StmtBufReader, c *conn, 485 ) { 486 t.Helper() 487 cmd, err := rd.CurCmd() 488 if err != nil { 489 t.Fatal(err) 490 } 491 rd.AdvanceOne() 492 493 bd, ok := cmd.(sql.BindStmt) 494 if !ok { 495 t.Fatalf("expected command BindStmt, got: %T (%+v)", cmd, cmd) 496 } 497 498 if bd.PreparedStatementName != expName { 499 t.Fatalf("expected name %s, got %s", expName, bd.PreparedStatementName) 500 } 501 502 if err := finishQuery(bind, c); err != nil { 503 t.Fatal(err) 504 } 505 } 506 507 func expectSync(ctx context.Context, t *testing.T, rd *sql.StmtBufReader) { 508 t.Helper() 509 cmd, err := rd.CurCmd() 510 if err != nil { 511 t.Fatal(err) 512 } 513 rd.AdvanceOne() 514 515 _, ok := cmd.(sql.Sync) 516 if !ok { 517 t.Fatalf("expected command Sync, got: %T (%+v)", cmd, cmd) 518 } 519 } 520 521 func expectExecPortal( 522 ctx context.Context, t *testing.T, expName string, rd *sql.StmtBufReader, c *conn, 523 ) { 524 t.Helper() 525 cmd, err := rd.CurCmd() 526 if err != nil { 527 t.Fatal(err) 528 } 529 rd.AdvanceOne() 530 531 ep, ok := cmd.(sql.ExecPortal) 532 if !ok { 533 t.Fatalf("expected command ExecPortal, got: %T (%+v)", cmd, cmd) 534 } 535 536 if ep.Name != expName { 537 t.Fatalf("expected name %s, got %s", expName, ep.Name) 538 } 539 540 if err := finishQuery(execPortal, c); err != nil { 541 t.Fatal(err) 542 } 543 } 544 545 func expectSendError( 546 ctx context.Context, t *testing.T, pgErrCode string, rd *sql.StmtBufReader, c *conn, 547 ) { 548 t.Helper() 549 cmd, err := rd.CurCmd() 550 if err != nil { 551 t.Fatal(err) 552 } 553 rd.AdvanceOne() 554 555 se, ok := cmd.(sql.SendError) 556 if !ok { 557 t.Fatalf("expected command SendError, got: %T (%+v)", cmd, cmd) 558 } 559 560 if code := pgerror.GetPGCode(se.Err); code != pgErrCode { 561 t.Fatalf("expected code %s, got: %s", pgErrCode, code) 562 } 563 564 if err := finishQuery(execPortal, c); err != nil { 565 t.Fatal(err) 566 } 567 } 568 569 type finishType int 570 571 const ( 572 execute finishType = iota 573 // cmdComplete is like execute, except that it marks the completion of a query 574 // in a larger query string and so no ReadyForQuery message should be sent. 575 cmdComplete 576 prepare 577 bind 578 describe 579 execPortal 580 generateError 581 ) 582 583 // Send a CommandComplete/ReadyForQuery to signal that the rows are done. 584 func finishQuery(t finishType, c *conn) error { 585 var skipFinish bool 586 587 switch t { 588 case execPortal: 589 fallthrough 590 case cmdComplete: 591 fallthrough 592 case execute: 593 c.msgBuilder.initMsg(pgwirebase.ServerMsgCommandComplete) 594 // HACK: This message is supposed to contains a command tag but this test is 595 // not sure about how to produce one and it works without it. 596 c.msgBuilder.nullTerminate() 597 case prepare: 598 // pgx doesn't send a Sync in between prepare (Parse protocol message) and 599 // the subsequent Describe, so we're not going to send a ReadyForQuery 600 // below. 601 c.msgBuilder.initMsg(pgwirebase.ServerMsgParseComplete) 602 case describe: 603 skipFinish = true 604 if err := c.writeRowDescription( 605 context.Background(), nil /* columns */, nil /* formatCodes */, c.conn, 606 ); err != nil { 607 return err 608 } 609 case bind: 610 // pgx doesn't send a Sync mesage in between Bind and Execute, so we're not 611 // going to send a ReadyForQuery below. 612 c.msgBuilder.initMsg(pgwirebase.ServerMsgBindComplete) 613 case generateError: 614 c.msgBuilder.initMsg(pgwirebase.ServerMsgErrorResponse) 615 c.msgBuilder.putErrFieldMsg(pgwirebase.ServerErrFieldSeverity) 616 c.msgBuilder.writeTerminatedString("ERROR") 617 c.msgBuilder.putErrFieldMsg(pgwirebase.ServerErrFieldMsgPrimary) 618 c.msgBuilder.writeTerminatedString("injected") 619 c.msgBuilder.nullTerminate() 620 if err := c.msgBuilder.finishMsg(c.conn); err != nil { 621 return err 622 } 623 } 624 625 if !skipFinish { 626 if err := c.msgBuilder.finishMsg(c.conn); err != nil { 627 return err 628 } 629 } 630 631 if t != cmdComplete && t != bind && t != prepare { 632 c.msgBuilder.initMsg(pgwirebase.ServerMsgReady) 633 c.msgBuilder.writeByte('I') // transaction status: no txn 634 if err := c.msgBuilder.finishMsg(c.conn); err != nil { 635 return err 636 } 637 } 638 return nil 639 } 640 641 type pgxTestLogger struct{} 642 643 func (l pgxTestLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) { 644 log.Infof(context.Background(), "pgx log [%s] %s - %s", level, msg, data) 645 } 646 647 // pgxTestLogger implements pgx.Logger. 648 var _ pgx.Logger = pgxTestLogger{} 649 650 // Test that closing a pgwire connection causes transactions to be rolled back 651 // and release their locks. 652 func TestConnCloseReleasesLocks(t *testing.T) { 653 defer leaktest.AfterTest(t)() 654 // We're going to test closing the connection in both the Open and Aborted 655 // state. 656 testutils.RunTrueAndFalse(t, "open state", func(t *testing.T, open bool) { 657 s, _, _ := serverutils.StartServer(t, base.TestServerArgs{}) 658 ctx := context.Background() 659 defer s.Stopper().Stop(ctx) 660 661 pgURL, cleanupFunc := sqlutils.PGUrl( 662 t, s.ServingSQLAddr(), "testConnClose" /* prefix */, url.User(security.RootUser), 663 ) 664 defer cleanupFunc() 665 db, err := gosql.Open("postgres", pgURL.String()) 666 require.NoError(t, err) 667 defer db.Close() 668 669 r := sqlutils.MakeSQLRunner(db) 670 r.Exec(t, "CREATE DATABASE test") 671 r.Exec(t, "CREATE TABLE test.t (x int primary key)") 672 673 pgxConfig, err := pgx.ParseConnectionString(pgURL.String()) 674 if err != nil { 675 t.Fatal(err) 676 } 677 678 conn, err := pgx.Connect(pgxConfig) 679 require.NoError(t, err) 680 tx, err := conn.Begin() 681 require.NoError(t, err) 682 _, err = tx.Exec("INSERT INTO test.t(x) values (1)") 683 require.NoError(t, err) 684 readCh := make(chan error) 685 go func() { 686 conn2, err := pgx.Connect(pgxConfig) 687 require.NoError(t, err) 688 _, err = conn2.Exec("SELECT * FROM test.t") 689 readCh <- err 690 }() 691 692 select { 693 case err := <-readCh: 694 t.Fatalf("unexpected read unblocked: %v", err) 695 case <-time.After(10 * time.Millisecond): 696 } 697 698 if !open { 699 _, err = tx.Exec("bogus") 700 require.NotNil(t, err) 701 } 702 err = conn.Close() 703 require.NoError(t, err) 704 select { 705 case readErr := <-readCh: 706 require.NoError(t, readErr) 707 case <-time.After(10 * time.Second): 708 t.Fatal("read not unblocked in a timely manner") 709 } 710 }) 711 } 712 713 // Test that closing a client connection such that producing results rows 714 // encounters network errors doesn't crash the server (#23694). 715 // 716 // We'll run a query that produces a bunch of rows and close the connection as 717 // soon as the client received anything, this way ensuring that: 718 // a) the query started executing when the connection is closed, and so it's 719 // likely to observe a network error and not a context cancelation, and 720 // b) the connection's server-side results buffer has overflowed, and so 721 // attempting to produce results (through CommandResult.AddRow()) observes 722 // network errors. 723 func TestConnCloseWhileProducingRows(t *testing.T) { 724 defer leaktest.AfterTest(t)() 725 s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) 726 ctx := context.Background() 727 defer s.Stopper().Stop(ctx) 728 729 // Disable results buffering. 730 if _, err := db.Exec( 731 `SET CLUSTER SETTING sql.defaults.results_buffer.size = '0'`, 732 ); err != nil { 733 t.Fatal(err) 734 } 735 pgURL, cleanupFunc := sqlutils.PGUrl( 736 t, s.ServingSQLAddr(), "testConnClose" /* prefix */, url.User(security.RootUser), 737 ) 738 defer cleanupFunc() 739 noBufferDB, err := gosql.Open("postgres", pgURL.String()) 740 if err != nil { 741 t.Fatal(err) 742 } 743 defer noBufferDB.Close() 744 745 r := sqlutils.MakeSQLRunner(noBufferDB) 746 r.Exec(t, "CREATE DATABASE test") 747 r.Exec(t, "CREATE TABLE test.test AS SELECT * FROM generate_series(1,100)") 748 749 pgxConfig, err := pgx.ParseConnectionString(pgURL.String()) 750 if err != nil { 751 t.Fatal(err) 752 } 753 // We test both with and without DistSQL, as the way that network errors are 754 // observed depends on the engine. 755 testutils.RunTrueAndFalse(t, "useDistSQL", func(t *testing.T, useDistSQL bool) { 756 conn, err := pgx.Connect(pgxConfig) 757 if err != nil { 758 t.Fatal(err) 759 } 760 var query string 761 if useDistSQL { 762 query = `SET DISTSQL = 'always'` 763 } else { 764 query = `SET DISTSQL = 'off'` 765 } 766 if _, err := conn.Exec(query); err != nil { 767 t.Fatal(err) 768 } 769 rows, err := conn.Query("SELECT * FROM test.test") 770 if err != nil { 771 t.Fatal(err) 772 } 773 if hasResults := rows.Next(); !hasResults { 774 t.Fatal("expected results") 775 } 776 if err := conn.Close(); err != nil { 777 t.Fatal(err) 778 } 779 }) 780 } 781 782 // TestMaliciousInputs verifies that known malicious inputs sent to 783 // a v3Conn don't crash the server. 784 func TestMaliciousInputs(t *testing.T) { 785 defer leaktest.AfterTest(t)() 786 787 ctx := context.Background() 788 789 for _, tc := range [][]byte{ 790 // This byte string sends a pgwirebase.ClientMsgClose message type. When 791 // ReadBuffer.readUntypedMsg is called, the 4 bytes is subtracted 792 // from the size, leaving a 0-length ReadBuffer. Following this, 793 // handleClose is called with the empty buffer, which calls 794 // getPrepareType. Previously, getPrepareType would crash on an 795 // empty buffer. This is now fixed. 796 {byte(pgwirebase.ClientMsgClose), 0x00, 0x00, 0x00, 0x04}, 797 // This byte string exploited the same bug using a pgwirebase.ClientMsgDescribe 798 // message type. 799 {byte(pgwirebase.ClientMsgDescribe), 0x00, 0x00, 0x00, 0x04}, 800 // This would cause ReadBuffer.getInt16 to overflow, resulting in a 801 // negative value being used for an allocation size. 802 {byte(pgwirebase.ClientMsgParse), 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0xff, 0xff}, 803 } { 804 t.Run("", func(t *testing.T) { 805 w, r := net.Pipe() 806 defer w.Close() 807 defer r.Close() 808 809 go func() { 810 // This io.Copy will discard all bytes from w until w is closed. 811 // This is needed because sends on the net.Pipe are synchronous, so 812 // the conn will block if we don't read whatever it tries to send. 813 // The reason this works is that ioutil.devNull implements ReadFrom 814 // as an infinite loop, so it will Read continuously until it hits an 815 // error (on w.Close()). 816 _, _ = io.Copy(ioutil.Discard, w) 817 }() 818 819 errChan := make(chan error, 1) 820 go func() { 821 // Write the malicious data. 822 if _, err := w.Write(tc); err != nil { 823 errChan <- err 824 return 825 } 826 827 // Sync and terminate if a panic did not occur to stop the server. 828 // We append a 4-byte trailer to each to signify a zero length message. See 829 // lib/pq.conn.sendSimpleMessage for a similar approach to simple messages. 830 _, _ = w.Write([]byte{byte(pgwirebase.ClientMsgSync), 0x00, 0x00, 0x00, 0x04}) 831 _, _ = w.Write([]byte{byte(pgwirebase.ClientMsgTerminate), 0x00, 0x00, 0x00, 0x04}) 832 close(errChan) 833 }() 834 835 stopper := stop.NewStopper() 836 defer stopper.Stop(ctx) 837 838 sqlMetrics := sql.MakeMemMetrics("test" /* endpoint */, time.Second /* histogramWindow */) 839 metrics := makeServerMetrics(sqlMetrics, time.Second /* histogramWindow */) 840 841 conn := newConn( 842 // ConnResultsBufferBytes - really small so that it overflows 843 // when we produce a few results. 844 r, sql.SessionArgs{ConnResultsBufferSize: 10}, &metrics, 845 nil, 846 ) 847 // Ignore the error from serveImpl. There might be one when the client 848 // sends malformed input. 849 conn.serveImpl( 850 ctx, 851 func() bool { return false }, /* draining */ 852 nil, /* sqlServer */ 853 mon.BoundAccount{}, /* reserved */ 854 authOptions{testingSkipAuth: true}, 855 stopper, 856 ) 857 if err := <-errChan; err != nil { 858 t.Fatal(err) 859 } 860 }) 861 } 862 } 863 864 // TestReadTimeoutConn asserts that a readTimeoutConn performs reads normally 865 // and exits with an appropriate error when exit conditions are satisfied. 866 func TestReadTimeoutConnExits(t *testing.T) { 867 defer leaktest.AfterTest(t)() 868 // Cannot use net.Pipe because deadlines are not supported. 869 ln, err := net.Listen(util.TestAddr.Network(), util.TestAddr.String()) 870 if err != nil { 871 t.Fatal(err) 872 } 873 log.Infof(context.Background(), "started listener on %s", ln.Addr()) 874 defer func() { 875 if err := ln.Close(); err != nil { 876 t.Fatal(err) 877 } 878 }() 879 880 ctx, cancel := context.WithCancel(context.Background()) 881 expectedRead := []byte("expectedRead") 882 883 // Start a goroutine that performs reads using a readTimeoutConn. 884 errChan := make(chan error) 885 go func() { 886 defer close(errChan) 887 errChan <- func() error { 888 c, err := ln.Accept() 889 if err != nil { 890 return err 891 } 892 defer c.Close() 893 894 readTimeoutConn := newReadTimeoutConn(c, func() error { return ctx.Err() }) 895 // Assert that reads are performed normally. 896 readBytes := make([]byte, len(expectedRead)) 897 if _, err := readTimeoutConn.Read(readBytes); err != nil { 898 return err 899 } 900 if !bytes.Equal(readBytes, expectedRead) { 901 return errors.Errorf("expected %v got %v", expectedRead, readBytes) 902 } 903 904 // The main goroutine will cancel the context, which should abort 905 // this read with an appropriate error. 906 _, err = readTimeoutConn.Read(make([]byte, 1)) 907 return err 908 }() 909 }() 910 911 c, err := net.Dial(ln.Addr().Network(), ln.Addr().String()) 912 if err != nil { 913 t.Fatal(err) 914 } 915 defer c.Close() 916 917 if _, err := c.Write(expectedRead); err != nil { 918 t.Fatal(err) 919 } 920 921 select { 922 case err := <-errChan: 923 t.Fatalf("goroutine unexpectedly returned: %v", err) 924 default: 925 } 926 cancel() 927 if err := <-errChan; !errors.Is(err, context.Canceled) { 928 t.Fatalf("unexpected error: %v", err) 929 } 930 } 931 932 func TestConnResultsBufferSize(t *testing.T) { 933 defer leaktest.AfterTest(t)() 934 s, db, _ := serverutils.StartServer(t, base.TestServerArgs{}) 935 defer s.Stopper().Stop(context.Background()) 936 937 // Check that SHOW results_buffer_size correctly exposes the value when it 938 // inherits the default. 939 { 940 var size string 941 require.NoError(t, db.QueryRow(`SHOW results_buffer_size`).Scan(&size)) 942 require.Equal(t, `16384`, size) 943 } 944 945 pgURL, cleanup := sqlutils.PGUrl(t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser)) 946 defer cleanup() 947 q := pgURL.Query() 948 949 q.Add(`results_buffer_size`, `foo`) 950 pgURL.RawQuery = q.Encode() 951 { 952 errDB, err := gosql.Open("postgres", pgURL.String()) 953 require.NoError(t, err) 954 defer errDB.Close() 955 _, err = errDB.Exec(`SELECT 1`) 956 require.EqualError(t, err, 957 `pq: error parsing results_buffer_size option value 'foo' as bytes`) 958 } 959 960 q.Del(`results_buffer_size`) 961 q.Add(`results_buffer_size`, `-1`) 962 pgURL.RawQuery = q.Encode() 963 { 964 errDB, err := gosql.Open("postgres", pgURL.String()) 965 require.NoError(t, err) 966 defer errDB.Close() 967 _, err = errDB.Exec(`SELECT 1`) 968 require.EqualError(t, err, `pq: results_buffer_size option value '-1' cannot be negative`) 969 } 970 971 // Set the results_buffer_size to a very small value, eliminating buffering. 972 q.Del(`results_buffer_size`) 973 q.Add(`results_buffer_size`, `2`) 974 pgURL.RawQuery = q.Encode() 975 976 noBufferDB, err := gosql.Open("postgres", pgURL.String()) 977 require.NoError(t, err) 978 defer noBufferDB.Close() 979 980 var size string 981 require.NoError(t, noBufferDB.QueryRow(`SHOW results_buffer_size`).Scan(&size)) 982 require.Equal(t, `2`, size) 983 984 // Run a query that immediately returns one result and then pauses for a 985 // long time while computing the second. 986 rows, err := noBufferDB.Query( 987 `SELECT a, if(a = 1, pg_sleep(99999), false) from (VALUES (0), (1)) AS foo (a)`) 988 require.NoError(t, err) 989 990 // Verify that the first result has been flushed. 991 require.True(t, rows.Next()) 992 var a int 993 var b bool 994 require.NoError(t, rows.Scan(&a, &b)) 995 require.Equal(t, 0, a) 996 require.False(t, b) 997 } 998 999 // Test that closing a connection while authentication was ongoing cancels the 1000 // auhentication process. In other words, this checks that the server is reading 1001 // from the connection while authentication is ongoing and so it reacts to the 1002 // connection closing. 1003 func TestConnCloseCancelsAuth(t *testing.T) { 1004 defer leaktest.AfterTest(t)() 1005 authBlocked := make(chan struct{}) 1006 s, _, _ := serverutils.StartServer(t, 1007 base.TestServerArgs{ 1008 Insecure: true, 1009 Knobs: base.TestingKnobs{ 1010 PGWireTestingKnobs: &sql.PGWireTestingKnobs{ 1011 AuthHook: func(ctx context.Context) error { 1012 // Notify the test. 1013 authBlocked <- struct{}{} 1014 // Wait for context cancelation. 1015 <-ctx.Done() 1016 // Notify the test. 1017 close(authBlocked) 1018 return fmt.Errorf("test auth canceled") 1019 }, 1020 }, 1021 }, 1022 }) 1023 ctx := context.Background() 1024 defer s.Stopper().Stop(ctx) 1025 1026 // We're going to open a client connection and do the minimum so that the 1027 // server gets to the authentication phase, where it will block. 1028 conn, err := net.Dial("tcp", s.ServingSQLAddr()) 1029 if err != nil { 1030 t.Fatal(err) 1031 } 1032 fe, err := pgproto3.NewFrontend(conn, conn) 1033 if err != nil { 1034 t.Fatal(err) 1035 } 1036 if err := fe.Send(&pgproto3.StartupMessage{ProtocolVersion: version30}); err != nil { 1037 t.Fatal(err) 1038 } 1039 1040 // Wait for server to block the auth. 1041 <-authBlocked 1042 // Close the connection. This is supposed to unblock the auth by canceling its 1043 // ctx. 1044 if err := conn.Close(); err != nil { 1045 t.Fatal(err) 1046 } 1047 // Check that the auth process indeed noticed the cancelation. 1048 <-authBlocked 1049 }