github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/pgwire/conn_test.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package pgwire
    12  
    13  import (
    14  	"bytes"
    15  	"context"
    16  	gosql "database/sql"
    17  	"fmt"
    18  	"io"
    19  	"io/ioutil"
    20  	"net"
    21  	"net/url"
    22  	"strconv"
    23  	"strings"
    24  	"sync"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/cockroachdb/cockroach/pkg/base"
    29  	"github.com/cockroachdb/cockroach/pkg/security"
    30  	"github.com/cockroachdb/cockroach/pkg/sql"
    31  	"github.com/cockroachdb/cockroach/pkg/sql/lex"
    32  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    33  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
    34  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
    35  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    36  	"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
    37  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    38  	"github.com/cockroachdb/cockroach/pkg/sql/sqlutil"
    39  	"github.com/cockroachdb/cockroach/pkg/testutils"
    40  	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
    41  	"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
    42  	"github.com/cockroachdb/cockroach/pkg/util"
    43  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    44  	"github.com/cockroachdb/cockroach/pkg/util/log"
    45  	"github.com/cockroachdb/cockroach/pkg/util/metric"
    46  	"github.com/cockroachdb/cockroach/pkg/util/mon"
    47  	"github.com/cockroachdb/cockroach/pkg/util/stop"
    48  	"github.com/cockroachdb/errors"
    49  	"github.com/jackc/pgx"
    50  	"github.com/jackc/pgx/pgproto3"
    51  	"github.com/stretchr/testify/require"
    52  	"golang.org/x/sync/errgroup"
    53  )
    54  
    55  // Test the conn struct: check that it marshalls the correct commands to the
    56  // stmtBuf.
    57  //
    58  // This test is weird because it aims to be a "unit test" for the conn with
    59  // minimal dependencies, but it needs a producer speaking the pgwire protocol
    60  // on the other end of the connection. We use the pgx Postgres driver for this.
    61  // We're going to simulate a client sending various commands to the server. We
    62  // don't have proper execution of those commands in this test, so we synthesize
    63  // responses.
    64  //
    65  // This test depends on recognizing the queries sent by pgx when it opens a
    66  // connection. If that set of queries changes, this test will probably fail
    67  // complaining that the stmtBuf has an unexpected entry in it.
    68  func TestConn(t *testing.T) {
    69  	defer leaktest.AfterTest(t)()
    70  
    71  	// The test server is used only incidentally by this test: this is not the
    72  	// server that the client will connect to; we just use it on the side to
    73  	// execute some metadata queries that pgx sends whenever it opens a
    74  	// connection.
    75  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{Insecure: true, UseDatabase: "system"})
    76  	defer s.Stopper().Stop(context.Background())
    77  
    78  	// Start a pgwire "server".
    79  	addr := util.TestAddr
    80  	ln, err := net.Listen(addr.Network(), addr.String())
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	serverAddr := ln.Addr()
    85  	log.Infof(context.Background(), "started listener on %s", serverAddr)
    86  
    87  	var g errgroup.Group
    88  	ctx := context.Background()
    89  
    90  	var clientWG sync.WaitGroup
    91  	clientWG.Add(1)
    92  
    93  	g.Go(func() error {
    94  		return client(ctx, serverAddr, &clientWG)
    95  	})
    96  
    97  	// Wait for the client to connect and perform the handshake.
    98  	conn, err := waitForClientConn(ln)
    99  	if err != nil {
   100  		t.Fatal(err)
   101  	}
   102  
   103  	// Run the conn's loop in the background - it will push commands to the
   104  	// buffer.
   105  	serveCtx, stopServe := context.WithCancel(ctx)
   106  	g.Go(func() error {
   107  		conn.serveImpl(
   108  			serveCtx,
   109  			func() bool { return false }, /* draining */
   110  			// sqlServer - nil means don't create a command processor and a write side of the conn
   111  			nil,
   112  			mon.BoundAccount{}, /* reserved */
   113  			authOptions{testingSkipAuth: true},
   114  			s.Stopper())
   115  		return nil
   116  	})
   117  	defer stopServe()
   118  
   119  	if err := processPgxStartup(ctx, s, conn); err != nil {
   120  		t.Fatal(err)
   121  	}
   122  
   123  	// Now we'll expect to receive the commands corresponding to the operations in
   124  	// client().
   125  	rd := sql.MakeStmtBufReader(&conn.stmtBuf)
   126  	expectExecStmt(ctx, t, "SELECT 1", &rd, conn, queryStringComplete)
   127  	expectSync(ctx, t, &rd)
   128  	expectExecStmt(ctx, t, "SELECT 2", &rd, conn, queryStringComplete)
   129  	expectSync(ctx, t, &rd)
   130  	expectPrepareStmt(ctx, t, "p1", "SELECT 'p1'", &rd, conn)
   131  	expectDescribeStmt(ctx, t, "p1", pgwirebase.PrepareStatement, &rd, conn)
   132  	expectSync(ctx, t, &rd)
   133  	expectBindStmt(ctx, t, "p1", &rd, conn)
   134  	expectExecPortal(ctx, t, "", &rd, conn)
   135  	// Check that a query string with multiple queries sent using the simple
   136  	// protocol is broken up.
   137  	expectSync(ctx, t, &rd)
   138  	expectExecStmt(ctx, t, "SELECT 4", &rd, conn, queryStringIncomplete)
   139  	expectExecStmt(ctx, t, "SELECT 5", &rd, conn, queryStringIncomplete)
   140  	expectExecStmt(ctx, t, "SELECT 6", &rd, conn, queryStringComplete)
   141  	expectSync(ctx, t, &rd)
   142  
   143  	// Check that the batching works like the client intended.
   144  
   145  	// pgx wraps batchs in transactions.
   146  	expectExecStmt(ctx, t, "BEGIN TRANSACTION", &rd, conn, queryStringComplete)
   147  	expectSync(ctx, t, &rd)
   148  	expectPrepareStmt(ctx, t, "", "SELECT 7", &rd, conn)
   149  	expectBindStmt(ctx, t, "", &rd, conn)
   150  	expectDescribeStmt(ctx, t, "", pgwirebase.PreparePortal, &rd, conn)
   151  	expectExecPortal(ctx, t, "", &rd, conn)
   152  	expectPrepareStmt(ctx, t, "", "SELECT 8", &rd, conn)
   153  	// Now we'll send an error, in the middle of this batch. pgx will stop waiting
   154  	// for results for commands in the batch. We'll then test that seeking to the
   155  	// next batch advances us to the correct statement.
   156  	if err := finishQuery(generateError, conn); err != nil {
   157  		t.Fatal(err)
   158  	}
   159  	// We're about to seek to the next batch but, as per seek's contract, seeking
   160  	// can only be called when there is something in the buffer. Since the buffer
   161  	// is filled concurrently with this code, we call CurCmd to ensure that
   162  	// there's something in there.
   163  	if _, err := rd.CurCmd(); err != nil {
   164  		t.Fatal(err)
   165  	}
   166  	// Skip all the remaining messages in the batch.
   167  	if err := rd.SeekToNextBatch(); err != nil {
   168  		t.Fatal(err)
   169  	}
   170  	// We got to the COMMIT that pgx pushed to match the BEGIN it generated for
   171  	// the batch.
   172  	expectSync(ctx, t, &rd)
   173  	expectExecStmt(ctx, t, "COMMIT TRANSACTION", &rd, conn, queryStringComplete)
   174  	expectSync(ctx, t, &rd)
   175  	expectExecStmt(ctx, t, "SELECT 9", &rd, conn, queryStringComplete)
   176  	expectSync(ctx, t, &rd)
   177  
   178  	// Test that parse error turns into SendError.
   179  	expectSendError(ctx, t, pgcode.Syntax, &rd, conn)
   180  
   181  	clientWG.Done()
   182  
   183  	if err := g.Wait(); err != nil {
   184  		t.Fatal(err)
   185  	}
   186  }
   187  
   188  // processPgxStartup processes the first few queries that the pgx driver
   189  // automatically sends on a new connection that has been established.
   190  func processPgxStartup(ctx context.Context, s serverutils.TestServerInterface, c *conn) error {
   191  	rd := sql.MakeStmtBufReader(&c.stmtBuf)
   192  
   193  	for {
   194  		cmd, err := rd.CurCmd()
   195  		if err != nil {
   196  			log.Errorf(ctx, "CurCmd error: %v", err)
   197  			return err
   198  		}
   199  
   200  		if _, ok := cmd.(sql.Sync); ok {
   201  			log.Infof(ctx, "advancing Sync")
   202  			rd.AdvanceOne()
   203  			continue
   204  		}
   205  
   206  		exec, ok := cmd.(sql.ExecStmt)
   207  		if !ok {
   208  			log.Infof(ctx, "stop wait at: %v", cmd)
   209  			return nil
   210  		}
   211  		query := exec.AST.String()
   212  		if !strings.HasPrefix(query, "SELECT t.oid") {
   213  			log.Infof(ctx, "stop wait at query: %s", query)
   214  			return nil
   215  		}
   216  		if err := execQuery(ctx, query, s, c); err != nil {
   217  			log.Errorf(ctx, "execQuery %s error: %v", query, err)
   218  			return err
   219  		}
   220  		log.Infof(ctx, "executed query: %s", query)
   221  		rd.AdvanceOne()
   222  	}
   223  }
   224  
   225  // execQuery executes a query on the passed-in server and send the results on c.
   226  func execQuery(
   227  	ctx context.Context, query string, s serverutils.TestServerInterface, c *conn,
   228  ) error {
   229  	rows, cols, err := s.InternalExecutor().(sqlutil.InternalExecutor).QueryWithCols(
   230  		ctx, "test", nil, /* txn */
   231  		sqlbase.InternalExecutorSessionDataOverride{User: security.RootUser, Database: "system"},
   232  		query,
   233  	)
   234  	if err != nil {
   235  		return err
   236  	}
   237  	return sendResult(ctx, c, cols, rows)
   238  }
   239  
   240  func client(ctx context.Context, serverAddr net.Addr, wg *sync.WaitGroup) error {
   241  	host, ports, err := net.SplitHostPort(serverAddr.String())
   242  	if err != nil {
   243  		return err
   244  	}
   245  	port, err := strconv.Atoi(ports)
   246  	if err != nil {
   247  		return err
   248  	}
   249  	conn, err := pgx.Connect(
   250  		pgx.ConnConfig{
   251  			Logger: pgxTestLogger{},
   252  			Host:   host,
   253  			Port:   uint16(port),
   254  			User:   "root",
   255  			// Setting this so that the queries sent by pgx to initialize the
   256  			// connection are not using prepared statements. That simplifies the
   257  			// scaffolding of the test.
   258  			PreferSimpleProtocol: true,
   259  			Database:             "system",
   260  		})
   261  	if err != nil {
   262  		return err
   263  	}
   264  
   265  	if _, err := conn.Exec("select 1"); err != nil {
   266  		return err
   267  	}
   268  	if _, err := conn.Exec("select 2"); err != nil {
   269  		return err
   270  	}
   271  	if _, err := conn.Prepare("p1", "select 'p1'"); err != nil {
   272  		return err
   273  	}
   274  	if _, err := conn.ExecEx(
   275  		ctx, "p1",
   276  		// We set these options because apparently that's how I tell pgx that it
   277  		// should check whether "p1" is a prepared statement.
   278  		&pgx.QueryExOptions{SimpleProtocol: false}); err != nil {
   279  		return err
   280  	}
   281  
   282  	// Send a group of statements as one query string using the simple protocol.
   283  	// We'll check that we receive them one by one, but marked as a batch.
   284  	if _, err := conn.Exec("select 4; select 5; select 6;"); err != nil {
   285  		return err
   286  	}
   287  
   288  	batch := conn.BeginBatch()
   289  	batch.Queue("select 7", nil, nil, nil)
   290  	batch.Queue("select 8", nil, nil, nil)
   291  	if err := batch.Send(context.Background(), &pgx.TxOptions{}); err != nil {
   292  		return err
   293  	}
   294  	if err := batch.Close(); err != nil {
   295  		// Swallow the error that we injected.
   296  		if !strings.Contains(err.Error(), "injected") {
   297  			return err
   298  		}
   299  	}
   300  
   301  	if _, err := conn.Exec("select 9"); err != nil {
   302  		return err
   303  	}
   304  	if _, err := conn.Exec("bogus statement failing to parse"); err != nil {
   305  		return err
   306  	}
   307  
   308  	wg.Wait()
   309  
   310  	return conn.Close()
   311  }
   312  
   313  // waitForClientConn blocks until a client connects and performs the pgwire
   314  // handshake. This emulates what pgwire.Server does.
   315  func waitForClientConn(ln net.Listener) (*conn, error) {
   316  	conn, err := ln.Accept()
   317  	if err != nil {
   318  		return nil, err
   319  	}
   320  
   321  	var buf pgwirebase.ReadBuffer
   322  	_, err = buf.ReadUntypedMsg(conn)
   323  	if err != nil {
   324  		return nil, err
   325  	}
   326  	version, err := buf.GetUint32()
   327  	if err != nil {
   328  		return nil, err
   329  	}
   330  	if version != version30 {
   331  		return nil, errors.Errorf("unexpected protocol version: %d", version)
   332  	}
   333  
   334  	// Consume the connection options.
   335  	if _, err := parseClientProvidedSessionParameters(context.Background(), nil, &buf); err != nil {
   336  		return nil, err
   337  	}
   338  
   339  	metrics := makeServerMetrics(sql.MemoryMetrics{} /* sqlMemMetrics */, metric.TestSampleInterval)
   340  	pgwireConn := newConn(conn, sql.SessionArgs{ConnResultsBufferSize: 16 << 10}, &metrics, nil)
   341  	return pgwireConn, nil
   342  }
   343  
   344  func makeTestingConvCfg() sessiondata.DataConversionConfig {
   345  	return sessiondata.DataConversionConfig{
   346  		Location:          time.UTC,
   347  		BytesEncodeFormat: lex.BytesEncodeHex,
   348  	}
   349  }
   350  
   351  // sendResult serializes a set of rows in pgwire format and sends them on a
   352  // connection.
   353  //
   354  // TODO(andrei): Tests using this should probably switch to using the similar
   355  // routines in the connection once conn learns how to write rows.
   356  func sendResult(
   357  	ctx context.Context, c *conn, cols sqlbase.ResultColumns, rows []tree.Datums,
   358  ) error {
   359  	if err := c.writeRowDescription(ctx, cols, nil /* formatCodes */, c.conn); err != nil {
   360  		return err
   361  	}
   362  
   363  	defaultConv := makeTestingConvCfg()
   364  	for _, row := range rows {
   365  		c.msgBuilder.initMsg(pgwirebase.ServerMsgDataRow)
   366  		c.msgBuilder.putInt16(int16(len(row)))
   367  		for _, col := range row {
   368  			c.msgBuilder.writeTextDatum(ctx, col, defaultConv)
   369  		}
   370  
   371  		if err := c.msgBuilder.finishMsg(c.conn); err != nil {
   372  			return err
   373  		}
   374  	}
   375  
   376  	return finishQuery(execute, c)
   377  }
   378  
   379  type executeType int
   380  
   381  const (
   382  	queryStringComplete executeType = iota
   383  	queryStringIncomplete
   384  )
   385  
   386  func expectExecStmt(
   387  	ctx context.Context, t *testing.T, expSQL string, rd *sql.StmtBufReader, c *conn, typ executeType,
   388  ) {
   389  	t.Helper()
   390  	cmd, err := rd.CurCmd()
   391  	if err != nil {
   392  		t.Fatal(err)
   393  	}
   394  	rd.AdvanceOne()
   395  
   396  	es, ok := cmd.(sql.ExecStmt)
   397  	if !ok {
   398  		t.Fatalf("expected command ExecStmt, got: %T (%+v)", cmd, cmd)
   399  	}
   400  
   401  	if es.AST.String() != expSQL {
   402  		t.Fatalf("expected %s, got %s", expSQL, es.AST.String())
   403  	}
   404  
   405  	if es.ParseStart == (time.Time{}) {
   406  		t.Fatalf("ParseStart not filled in")
   407  	}
   408  	if es.ParseEnd == (time.Time{}) {
   409  		t.Fatalf("ParseEnd not filled in")
   410  	}
   411  	if typ == queryStringComplete {
   412  		if err := finishQuery(execute, c); err != nil {
   413  			t.Fatal(err)
   414  		}
   415  	} else {
   416  		if err := finishQuery(cmdComplete, c); err != nil {
   417  			t.Fatal(err)
   418  		}
   419  	}
   420  }
   421  
   422  func expectPrepareStmt(
   423  	ctx context.Context, t *testing.T, expName string, expSQL string, rd *sql.StmtBufReader, c *conn,
   424  ) {
   425  	t.Helper()
   426  	cmd, err := rd.CurCmd()
   427  	if err != nil {
   428  		t.Fatal(err)
   429  	}
   430  	rd.AdvanceOne()
   431  
   432  	pr, ok := cmd.(sql.PrepareStmt)
   433  	if !ok {
   434  		t.Fatalf("expected command PrepareStmt, got: %T (%+v)", cmd, cmd)
   435  	}
   436  
   437  	if pr.Name != expName {
   438  		t.Fatalf("expected name %s, got %s", expName, pr.Name)
   439  	}
   440  
   441  	if pr.AST.String() != expSQL {
   442  		t.Fatalf("expected %s, got %s", expSQL, pr.AST.String())
   443  	}
   444  
   445  	if err := finishQuery(prepare, c); err != nil {
   446  		t.Fatal(err)
   447  	}
   448  }
   449  
   450  func expectDescribeStmt(
   451  	ctx context.Context,
   452  	t *testing.T,
   453  	expName string,
   454  	expType pgwirebase.PrepareType,
   455  	rd *sql.StmtBufReader,
   456  	c *conn,
   457  ) {
   458  	t.Helper()
   459  	cmd, err := rd.CurCmd()
   460  	if err != nil {
   461  		t.Fatal(err)
   462  	}
   463  	rd.AdvanceOne()
   464  
   465  	desc, ok := cmd.(sql.DescribeStmt)
   466  	if !ok {
   467  		t.Fatalf("expected command DescribeStmt, got: %T (%+v)", cmd, cmd)
   468  	}
   469  
   470  	if desc.Name != expName {
   471  		t.Fatalf("expected name %s, got %s", expName, desc.Name)
   472  	}
   473  
   474  	if desc.Type != expType {
   475  		t.Fatalf("expected type %s, got %s", expType, desc.Type)
   476  	}
   477  
   478  	if err := finishQuery(describe, c); err != nil {
   479  		t.Fatal(err)
   480  	}
   481  }
   482  
   483  func expectBindStmt(
   484  	ctx context.Context, t *testing.T, expName string, rd *sql.StmtBufReader, c *conn,
   485  ) {
   486  	t.Helper()
   487  	cmd, err := rd.CurCmd()
   488  	if err != nil {
   489  		t.Fatal(err)
   490  	}
   491  	rd.AdvanceOne()
   492  
   493  	bd, ok := cmd.(sql.BindStmt)
   494  	if !ok {
   495  		t.Fatalf("expected command BindStmt, got: %T (%+v)", cmd, cmd)
   496  	}
   497  
   498  	if bd.PreparedStatementName != expName {
   499  		t.Fatalf("expected name %s, got %s", expName, bd.PreparedStatementName)
   500  	}
   501  
   502  	if err := finishQuery(bind, c); err != nil {
   503  		t.Fatal(err)
   504  	}
   505  }
   506  
   507  func expectSync(ctx context.Context, t *testing.T, rd *sql.StmtBufReader) {
   508  	t.Helper()
   509  	cmd, err := rd.CurCmd()
   510  	if err != nil {
   511  		t.Fatal(err)
   512  	}
   513  	rd.AdvanceOne()
   514  
   515  	_, ok := cmd.(sql.Sync)
   516  	if !ok {
   517  		t.Fatalf("expected command Sync, got: %T (%+v)", cmd, cmd)
   518  	}
   519  }
   520  
   521  func expectExecPortal(
   522  	ctx context.Context, t *testing.T, expName string, rd *sql.StmtBufReader, c *conn,
   523  ) {
   524  	t.Helper()
   525  	cmd, err := rd.CurCmd()
   526  	if err != nil {
   527  		t.Fatal(err)
   528  	}
   529  	rd.AdvanceOne()
   530  
   531  	ep, ok := cmd.(sql.ExecPortal)
   532  	if !ok {
   533  		t.Fatalf("expected command ExecPortal, got: %T (%+v)", cmd, cmd)
   534  	}
   535  
   536  	if ep.Name != expName {
   537  		t.Fatalf("expected name %s, got %s", expName, ep.Name)
   538  	}
   539  
   540  	if err := finishQuery(execPortal, c); err != nil {
   541  		t.Fatal(err)
   542  	}
   543  }
   544  
   545  func expectSendError(
   546  	ctx context.Context, t *testing.T, pgErrCode string, rd *sql.StmtBufReader, c *conn,
   547  ) {
   548  	t.Helper()
   549  	cmd, err := rd.CurCmd()
   550  	if err != nil {
   551  		t.Fatal(err)
   552  	}
   553  	rd.AdvanceOne()
   554  
   555  	se, ok := cmd.(sql.SendError)
   556  	if !ok {
   557  		t.Fatalf("expected command SendError, got: %T (%+v)", cmd, cmd)
   558  	}
   559  
   560  	if code := pgerror.GetPGCode(se.Err); code != pgErrCode {
   561  		t.Fatalf("expected code %s, got: %s", pgErrCode, code)
   562  	}
   563  
   564  	if err := finishQuery(execPortal, c); err != nil {
   565  		t.Fatal(err)
   566  	}
   567  }
   568  
   569  type finishType int
   570  
   571  const (
   572  	execute finishType = iota
   573  	// cmdComplete is like execute, except that it marks the completion of a query
   574  	// in a larger query string and so no ReadyForQuery message should be sent.
   575  	cmdComplete
   576  	prepare
   577  	bind
   578  	describe
   579  	execPortal
   580  	generateError
   581  )
   582  
   583  // Send a CommandComplete/ReadyForQuery to signal that the rows are done.
   584  func finishQuery(t finishType, c *conn) error {
   585  	var skipFinish bool
   586  
   587  	switch t {
   588  	case execPortal:
   589  		fallthrough
   590  	case cmdComplete:
   591  		fallthrough
   592  	case execute:
   593  		c.msgBuilder.initMsg(pgwirebase.ServerMsgCommandComplete)
   594  		// HACK: This message is supposed to contains a command tag but this test is
   595  		// not sure about how to produce one and it works without it.
   596  		c.msgBuilder.nullTerminate()
   597  	case prepare:
   598  		// pgx doesn't send a Sync in between prepare (Parse protocol message) and
   599  		// the subsequent Describe, so we're not going to send a ReadyForQuery
   600  		// below.
   601  		c.msgBuilder.initMsg(pgwirebase.ServerMsgParseComplete)
   602  	case describe:
   603  		skipFinish = true
   604  		if err := c.writeRowDescription(
   605  			context.Background(), nil /* columns */, nil /* formatCodes */, c.conn,
   606  		); err != nil {
   607  			return err
   608  		}
   609  	case bind:
   610  		// pgx doesn't send a Sync mesage in between Bind and Execute, so we're not
   611  		// going to send a ReadyForQuery below.
   612  		c.msgBuilder.initMsg(pgwirebase.ServerMsgBindComplete)
   613  	case generateError:
   614  		c.msgBuilder.initMsg(pgwirebase.ServerMsgErrorResponse)
   615  		c.msgBuilder.putErrFieldMsg(pgwirebase.ServerErrFieldSeverity)
   616  		c.msgBuilder.writeTerminatedString("ERROR")
   617  		c.msgBuilder.putErrFieldMsg(pgwirebase.ServerErrFieldMsgPrimary)
   618  		c.msgBuilder.writeTerminatedString("injected")
   619  		c.msgBuilder.nullTerminate()
   620  		if err := c.msgBuilder.finishMsg(c.conn); err != nil {
   621  			return err
   622  		}
   623  	}
   624  
   625  	if !skipFinish {
   626  		if err := c.msgBuilder.finishMsg(c.conn); err != nil {
   627  			return err
   628  		}
   629  	}
   630  
   631  	if t != cmdComplete && t != bind && t != prepare {
   632  		c.msgBuilder.initMsg(pgwirebase.ServerMsgReady)
   633  		c.msgBuilder.writeByte('I') // transaction status: no txn
   634  		if err := c.msgBuilder.finishMsg(c.conn); err != nil {
   635  			return err
   636  		}
   637  	}
   638  	return nil
   639  }
   640  
   641  type pgxTestLogger struct{}
   642  
   643  func (l pgxTestLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
   644  	log.Infof(context.Background(), "pgx log [%s] %s - %s", level, msg, data)
   645  }
   646  
   647  // pgxTestLogger implements pgx.Logger.
   648  var _ pgx.Logger = pgxTestLogger{}
   649  
   650  // Test that closing a pgwire connection causes transactions to be rolled back
   651  // and release their locks.
   652  func TestConnCloseReleasesLocks(t *testing.T) {
   653  	defer leaktest.AfterTest(t)()
   654  	// We're going to test closing the connection in both the Open and Aborted
   655  	// state.
   656  	testutils.RunTrueAndFalse(t, "open state", func(t *testing.T, open bool) {
   657  		s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
   658  		ctx := context.Background()
   659  		defer s.Stopper().Stop(ctx)
   660  
   661  		pgURL, cleanupFunc := sqlutils.PGUrl(
   662  			t, s.ServingSQLAddr(), "testConnClose" /* prefix */, url.User(security.RootUser),
   663  		)
   664  		defer cleanupFunc()
   665  		db, err := gosql.Open("postgres", pgURL.String())
   666  		require.NoError(t, err)
   667  		defer db.Close()
   668  
   669  		r := sqlutils.MakeSQLRunner(db)
   670  		r.Exec(t, "CREATE DATABASE test")
   671  		r.Exec(t, "CREATE TABLE test.t (x int primary key)")
   672  
   673  		pgxConfig, err := pgx.ParseConnectionString(pgURL.String())
   674  		if err != nil {
   675  			t.Fatal(err)
   676  		}
   677  
   678  		conn, err := pgx.Connect(pgxConfig)
   679  		require.NoError(t, err)
   680  		tx, err := conn.Begin()
   681  		require.NoError(t, err)
   682  		_, err = tx.Exec("INSERT INTO test.t(x) values (1)")
   683  		require.NoError(t, err)
   684  		readCh := make(chan error)
   685  		go func() {
   686  			conn2, err := pgx.Connect(pgxConfig)
   687  			require.NoError(t, err)
   688  			_, err = conn2.Exec("SELECT * FROM test.t")
   689  			readCh <- err
   690  		}()
   691  
   692  		select {
   693  		case err := <-readCh:
   694  			t.Fatalf("unexpected read unblocked: %v", err)
   695  		case <-time.After(10 * time.Millisecond):
   696  		}
   697  
   698  		if !open {
   699  			_, err = tx.Exec("bogus")
   700  			require.NotNil(t, err)
   701  		}
   702  		err = conn.Close()
   703  		require.NoError(t, err)
   704  		select {
   705  		case readErr := <-readCh:
   706  			require.NoError(t, readErr)
   707  		case <-time.After(10 * time.Second):
   708  			t.Fatal("read not unblocked in a timely manner")
   709  		}
   710  	})
   711  }
   712  
   713  // Test that closing a client connection such that producing results rows
   714  // encounters network errors doesn't crash the server (#23694).
   715  //
   716  // We'll run a query that produces a bunch of rows and close the connection as
   717  // soon as the client received anything, this way ensuring that:
   718  // a) the query started executing when the connection is closed, and so it's
   719  // likely to observe a network error and not a context cancelation, and
   720  // b) the connection's server-side results buffer has overflowed, and so
   721  // attempting to produce results (through CommandResult.AddRow()) observes
   722  // network errors.
   723  func TestConnCloseWhileProducingRows(t *testing.T) {
   724  	defer leaktest.AfterTest(t)()
   725  	s, db, _ := serverutils.StartServer(t, base.TestServerArgs{})
   726  	ctx := context.Background()
   727  	defer s.Stopper().Stop(ctx)
   728  
   729  	// Disable results buffering.
   730  	if _, err := db.Exec(
   731  		`SET CLUSTER SETTING sql.defaults.results_buffer.size = '0'`,
   732  	); err != nil {
   733  		t.Fatal(err)
   734  	}
   735  	pgURL, cleanupFunc := sqlutils.PGUrl(
   736  		t, s.ServingSQLAddr(), "testConnClose" /* prefix */, url.User(security.RootUser),
   737  	)
   738  	defer cleanupFunc()
   739  	noBufferDB, err := gosql.Open("postgres", pgURL.String())
   740  	if err != nil {
   741  		t.Fatal(err)
   742  	}
   743  	defer noBufferDB.Close()
   744  
   745  	r := sqlutils.MakeSQLRunner(noBufferDB)
   746  	r.Exec(t, "CREATE DATABASE test")
   747  	r.Exec(t, "CREATE TABLE test.test AS SELECT * FROM generate_series(1,100)")
   748  
   749  	pgxConfig, err := pgx.ParseConnectionString(pgURL.String())
   750  	if err != nil {
   751  		t.Fatal(err)
   752  	}
   753  	// We test both with and without DistSQL, as the way that network errors are
   754  	// observed depends on the engine.
   755  	testutils.RunTrueAndFalse(t, "useDistSQL", func(t *testing.T, useDistSQL bool) {
   756  		conn, err := pgx.Connect(pgxConfig)
   757  		if err != nil {
   758  			t.Fatal(err)
   759  		}
   760  		var query string
   761  		if useDistSQL {
   762  			query = `SET DISTSQL = 'always'`
   763  		} else {
   764  			query = `SET DISTSQL = 'off'`
   765  		}
   766  		if _, err := conn.Exec(query); err != nil {
   767  			t.Fatal(err)
   768  		}
   769  		rows, err := conn.Query("SELECT * FROM test.test")
   770  		if err != nil {
   771  			t.Fatal(err)
   772  		}
   773  		if hasResults := rows.Next(); !hasResults {
   774  			t.Fatal("expected results")
   775  		}
   776  		if err := conn.Close(); err != nil {
   777  			t.Fatal(err)
   778  		}
   779  	})
   780  }
   781  
   782  // TestMaliciousInputs verifies that known malicious inputs sent to
   783  // a v3Conn don't crash the server.
   784  func TestMaliciousInputs(t *testing.T) {
   785  	defer leaktest.AfterTest(t)()
   786  
   787  	ctx := context.Background()
   788  
   789  	for _, tc := range [][]byte{
   790  		// This byte string sends a pgwirebase.ClientMsgClose message type. When
   791  		// ReadBuffer.readUntypedMsg is called, the 4 bytes is subtracted
   792  		// from the size, leaving a 0-length ReadBuffer. Following this,
   793  		// handleClose is called with the empty buffer, which calls
   794  		// getPrepareType. Previously, getPrepareType would crash on an
   795  		// empty buffer. This is now fixed.
   796  		{byte(pgwirebase.ClientMsgClose), 0x00, 0x00, 0x00, 0x04},
   797  		// This byte string exploited the same bug using a pgwirebase.ClientMsgDescribe
   798  		// message type.
   799  		{byte(pgwirebase.ClientMsgDescribe), 0x00, 0x00, 0x00, 0x04},
   800  		// This would cause ReadBuffer.getInt16 to overflow, resulting in a
   801  		// negative value being used for an allocation size.
   802  		{byte(pgwirebase.ClientMsgParse), 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0xff, 0xff},
   803  	} {
   804  		t.Run("", func(t *testing.T) {
   805  			w, r := net.Pipe()
   806  			defer w.Close()
   807  			defer r.Close()
   808  
   809  			go func() {
   810  				// This io.Copy will discard all bytes from w until w is closed.
   811  				// This is needed because sends on the net.Pipe are synchronous, so
   812  				// the conn will block if we don't read whatever it tries to send.
   813  				// The reason this works is that ioutil.devNull implements ReadFrom
   814  				// as an infinite loop, so it will Read continuously until it hits an
   815  				// error (on w.Close()).
   816  				_, _ = io.Copy(ioutil.Discard, w)
   817  			}()
   818  
   819  			errChan := make(chan error, 1)
   820  			go func() {
   821  				// Write the malicious data.
   822  				if _, err := w.Write(tc); err != nil {
   823  					errChan <- err
   824  					return
   825  				}
   826  
   827  				// Sync and terminate if a panic did not occur to stop the server.
   828  				// We append a 4-byte trailer to each to signify a zero length message. See
   829  				// lib/pq.conn.sendSimpleMessage for a similar approach to simple messages.
   830  				_, _ = w.Write([]byte{byte(pgwirebase.ClientMsgSync), 0x00, 0x00, 0x00, 0x04})
   831  				_, _ = w.Write([]byte{byte(pgwirebase.ClientMsgTerminate), 0x00, 0x00, 0x00, 0x04})
   832  				close(errChan)
   833  			}()
   834  
   835  			stopper := stop.NewStopper()
   836  			defer stopper.Stop(ctx)
   837  
   838  			sqlMetrics := sql.MakeMemMetrics("test" /* endpoint */, time.Second /* histogramWindow */)
   839  			metrics := makeServerMetrics(sqlMetrics, time.Second /* histogramWindow */)
   840  
   841  			conn := newConn(
   842  				// ConnResultsBufferBytes - really small so that it overflows
   843  				// when we produce a few results.
   844  				r, sql.SessionArgs{ConnResultsBufferSize: 10}, &metrics,
   845  				nil,
   846  			)
   847  			// Ignore the error from serveImpl. There might be one when the client
   848  			// sends malformed input.
   849  			conn.serveImpl(
   850  				ctx,
   851  				func() bool { return false }, /* draining */
   852  				nil,                          /* sqlServer */
   853  				mon.BoundAccount{},           /* reserved */
   854  				authOptions{testingSkipAuth: true},
   855  				stopper,
   856  			)
   857  			if err := <-errChan; err != nil {
   858  				t.Fatal(err)
   859  			}
   860  		})
   861  	}
   862  }
   863  
   864  // TestReadTimeoutConn asserts that a readTimeoutConn performs reads normally
   865  // and exits with an appropriate error when exit conditions are satisfied.
   866  func TestReadTimeoutConnExits(t *testing.T) {
   867  	defer leaktest.AfterTest(t)()
   868  	// Cannot use net.Pipe because deadlines are not supported.
   869  	ln, err := net.Listen(util.TestAddr.Network(), util.TestAddr.String())
   870  	if err != nil {
   871  		t.Fatal(err)
   872  	}
   873  	log.Infof(context.Background(), "started listener on %s", ln.Addr())
   874  	defer func() {
   875  		if err := ln.Close(); err != nil {
   876  			t.Fatal(err)
   877  		}
   878  	}()
   879  
   880  	ctx, cancel := context.WithCancel(context.Background())
   881  	expectedRead := []byte("expectedRead")
   882  
   883  	// Start a goroutine that performs reads using a readTimeoutConn.
   884  	errChan := make(chan error)
   885  	go func() {
   886  		defer close(errChan)
   887  		errChan <- func() error {
   888  			c, err := ln.Accept()
   889  			if err != nil {
   890  				return err
   891  			}
   892  			defer c.Close()
   893  
   894  			readTimeoutConn := newReadTimeoutConn(c, func() error { return ctx.Err() })
   895  			// Assert that reads are performed normally.
   896  			readBytes := make([]byte, len(expectedRead))
   897  			if _, err := readTimeoutConn.Read(readBytes); err != nil {
   898  				return err
   899  			}
   900  			if !bytes.Equal(readBytes, expectedRead) {
   901  				return errors.Errorf("expected %v got %v", expectedRead, readBytes)
   902  			}
   903  
   904  			// The main goroutine will cancel the context, which should abort
   905  			// this read with an appropriate error.
   906  			_, err = readTimeoutConn.Read(make([]byte, 1))
   907  			return err
   908  		}()
   909  	}()
   910  
   911  	c, err := net.Dial(ln.Addr().Network(), ln.Addr().String())
   912  	if err != nil {
   913  		t.Fatal(err)
   914  	}
   915  	defer c.Close()
   916  
   917  	if _, err := c.Write(expectedRead); err != nil {
   918  		t.Fatal(err)
   919  	}
   920  
   921  	select {
   922  	case err := <-errChan:
   923  		t.Fatalf("goroutine unexpectedly returned: %v", err)
   924  	default:
   925  	}
   926  	cancel()
   927  	if err := <-errChan; !errors.Is(err, context.Canceled) {
   928  		t.Fatalf("unexpected error: %v", err)
   929  	}
   930  }
   931  
   932  func TestConnResultsBufferSize(t *testing.T) {
   933  	defer leaktest.AfterTest(t)()
   934  	s, db, _ := serverutils.StartServer(t, base.TestServerArgs{})
   935  	defer s.Stopper().Stop(context.Background())
   936  
   937  	// Check that SHOW results_buffer_size correctly exposes the value when it
   938  	// inherits the default.
   939  	{
   940  		var size string
   941  		require.NoError(t, db.QueryRow(`SHOW results_buffer_size`).Scan(&size))
   942  		require.Equal(t, `16384`, size)
   943  	}
   944  
   945  	pgURL, cleanup := sqlutils.PGUrl(t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser))
   946  	defer cleanup()
   947  	q := pgURL.Query()
   948  
   949  	q.Add(`results_buffer_size`, `foo`)
   950  	pgURL.RawQuery = q.Encode()
   951  	{
   952  		errDB, err := gosql.Open("postgres", pgURL.String())
   953  		require.NoError(t, err)
   954  		defer errDB.Close()
   955  		_, err = errDB.Exec(`SELECT 1`)
   956  		require.EqualError(t, err,
   957  			`pq: error parsing results_buffer_size option value 'foo' as bytes`)
   958  	}
   959  
   960  	q.Del(`results_buffer_size`)
   961  	q.Add(`results_buffer_size`, `-1`)
   962  	pgURL.RawQuery = q.Encode()
   963  	{
   964  		errDB, err := gosql.Open("postgres", pgURL.String())
   965  		require.NoError(t, err)
   966  		defer errDB.Close()
   967  		_, err = errDB.Exec(`SELECT 1`)
   968  		require.EqualError(t, err, `pq: results_buffer_size option value '-1' cannot be negative`)
   969  	}
   970  
   971  	// Set the results_buffer_size to a very small value, eliminating buffering.
   972  	q.Del(`results_buffer_size`)
   973  	q.Add(`results_buffer_size`, `2`)
   974  	pgURL.RawQuery = q.Encode()
   975  
   976  	noBufferDB, err := gosql.Open("postgres", pgURL.String())
   977  	require.NoError(t, err)
   978  	defer noBufferDB.Close()
   979  
   980  	var size string
   981  	require.NoError(t, noBufferDB.QueryRow(`SHOW results_buffer_size`).Scan(&size))
   982  	require.Equal(t, `2`, size)
   983  
   984  	// Run a query that immediately returns one result and then pauses for a
   985  	// long time while computing the second.
   986  	rows, err := noBufferDB.Query(
   987  		`SELECT a, if(a = 1, pg_sleep(99999), false) from (VALUES (0), (1)) AS foo (a)`)
   988  	require.NoError(t, err)
   989  
   990  	// Verify that the first result has been flushed.
   991  	require.True(t, rows.Next())
   992  	var a int
   993  	var b bool
   994  	require.NoError(t, rows.Scan(&a, &b))
   995  	require.Equal(t, 0, a)
   996  	require.False(t, b)
   997  }
   998  
   999  // Test that closing a connection while authentication was ongoing cancels the
  1000  // auhentication process. In other words, this checks that the server is reading
  1001  // from the connection while authentication is ongoing and so it reacts to the
  1002  // connection closing.
  1003  func TestConnCloseCancelsAuth(t *testing.T) {
  1004  	defer leaktest.AfterTest(t)()
  1005  	authBlocked := make(chan struct{})
  1006  	s, _, _ := serverutils.StartServer(t,
  1007  		base.TestServerArgs{
  1008  			Insecure: true,
  1009  			Knobs: base.TestingKnobs{
  1010  				PGWireTestingKnobs: &sql.PGWireTestingKnobs{
  1011  					AuthHook: func(ctx context.Context) error {
  1012  						// Notify the test.
  1013  						authBlocked <- struct{}{}
  1014  						// Wait for context cancelation.
  1015  						<-ctx.Done()
  1016  						// Notify the test.
  1017  						close(authBlocked)
  1018  						return fmt.Errorf("test auth canceled")
  1019  					},
  1020  				},
  1021  			},
  1022  		})
  1023  	ctx := context.Background()
  1024  	defer s.Stopper().Stop(ctx)
  1025  
  1026  	// We're going to open a client connection and do the minimum so that the
  1027  	// server gets to the authentication phase, where it will block.
  1028  	conn, err := net.Dial("tcp", s.ServingSQLAddr())
  1029  	if err != nil {
  1030  		t.Fatal(err)
  1031  	}
  1032  	fe, err := pgproto3.NewFrontend(conn, conn)
  1033  	if err != nil {
  1034  		t.Fatal(err)
  1035  	}
  1036  	if err := fe.Send(&pgproto3.StartupMessage{ProtocolVersion: version30}); err != nil {
  1037  		t.Fatal(err)
  1038  	}
  1039  
  1040  	// Wait for server to block the auth.
  1041  	<-authBlocked
  1042  	// Close the connection. This is supposed to unblock the auth by canceling its
  1043  	// ctx.
  1044  	if err := conn.Close(); err != nil {
  1045  		t.Fatal(err)
  1046  	}
  1047  	// Check that the auth process indeed noticed the cancelation.
  1048  	<-authBlocked
  1049  }