github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/pgwire/pgwire_test.go (about)

     1  // Copyright 2015 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package pgwire_test
    12  
    13  import (
    14  	"context"
    15  	gosql "database/sql"
    16  	"database/sql/driver"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"io/ioutil"
    21  	"net"
    22  	"net/url"
    23  	"os"
    24  	"path/filepath"
    25  	"reflect"
    26  	"runtime"
    27  	"strconv"
    28  	"strings"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/cockroachdb/cockroach/pkg/base"
    33  	"github.com/cockroachdb/cockroach/pkg/security"
    34  	"github.com/cockroachdb/cockroach/pkg/server"
    35  	"github.com/cockroachdb/cockroach/pkg/server/telemetry"
    36  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire"
    37  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    38  	"github.com/cockroachdb/cockroach/pkg/testutils"
    39  	"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
    40  	"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
    41  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    42  	"github.com/cockroachdb/cockroach/pkg/util/log"
    43  	"github.com/cockroachdb/errors"
    44  	"github.com/jackc/pgx"
    45  	"github.com/jackc/pgx/pgproto3"
    46  	"github.com/lib/pq"
    47  )
    48  
    49  func wrongArgCountString(want, got int) string {
    50  	return fmt.Sprintf("sql: expected %d arguments, got %d", want, got)
    51  }
    52  
    53  func trivialQuery(pgURL url.URL) error {
    54  	db, err := gosql.Open("postgres", pgURL.String())
    55  	if err != nil {
    56  		return err
    57  	}
    58  	defer db.Close()
    59  	{
    60  		_, err := db.Exec("SELECT 1")
    61  		return err
    62  	}
    63  }
    64  
    65  // TestPGWireDrainClient makes sure that in draining mode, the server refuses
    66  // new connections and allows sessions with ongoing transactions to finish.
    67  func TestPGWireDrainClient(t *testing.T) {
    68  	defer leaktest.AfterTest(t)()
    69  	params := base.TestServerArgs{Insecure: true}
    70  	s, _, _ := serverutils.StartServer(t, params)
    71  
    72  	ctx := context.Background()
    73  	defer s.Stopper().Stop(ctx)
    74  
    75  	host, port, err := net.SplitHostPort(s.ServingSQLAddr())
    76  	if err != nil {
    77  		t.Fatal(err)
    78  	}
    79  
    80  	pgBaseURL := url.URL{
    81  		Scheme:   "postgres",
    82  		Host:     net.JoinHostPort(host, port),
    83  		User:     url.User(security.RootUser),
    84  		RawQuery: "sslmode=disable",
    85  	}
    86  
    87  	db, err := gosql.Open("postgres", pgBaseURL.String())
    88  	if err != nil {
    89  		t.Fatal(err)
    90  	}
    91  	defer db.Close()
    92  
    93  	txn, err := db.Begin()
    94  	if err != nil {
    95  		t.Fatal(err)
    96  	}
    97  
    98  	// Draining runs in a separate goroutine since it won't return until the
    99  	// connection with an ongoing transaction finishes.
   100  	errChan := make(chan error)
   101  	go func() {
   102  		defer close(errChan)
   103  		errChan <- func() error {
   104  			return s.(*server.TestServer).DrainClients(ctx)
   105  		}()
   106  	}()
   107  
   108  	// Ensure server is in draining mode and rejects new connections.
   109  	testutils.SucceedsSoon(t, func() error {
   110  		if err := trivialQuery(pgBaseURL); !testutils.IsError(err, pgwire.ErrDrainingNewConn) {
   111  			return errors.Errorf("unexpected error: %v", err)
   112  		}
   113  		return nil
   114  	})
   115  
   116  	if _, err := txn.Exec("SELECT 1"); err != nil {
   117  		t.Fatal(err)
   118  	}
   119  	if err := txn.Commit(); err != nil {
   120  		t.Fatal(err)
   121  	}
   122  
   123  	for err := range errChan {
   124  		if err != nil {
   125  			t.Fatal(err)
   126  		}
   127  	}
   128  
   129  	if !s.(*server.TestServer).PGServer().IsDraining() {
   130  		t.Fatal("server should be draining, but is not")
   131  	}
   132  }
   133  
   134  // TestPGWireDrainOngoingTxns tests that connections with open transactions are
   135  // canceled when they go on for too long.
   136  func TestPGWireDrainOngoingTxns(t *testing.T) {
   137  	defer leaktest.AfterTest(t)()
   138  	params := base.TestServerArgs{Insecure: true}
   139  	s, _, _ := serverutils.StartServer(t, params)
   140  	defer s.Stopper().Stop(context.Background())
   141  
   142  	host, port, err := net.SplitHostPort(s.ServingSQLAddr())
   143  	if err != nil {
   144  		t.Fatal(err)
   145  	}
   146  
   147  	pgBaseURL := url.URL{
   148  		Scheme:   "postgres",
   149  		Host:     net.JoinHostPort(host, port),
   150  		User:     url.User(security.RootUser),
   151  		RawQuery: "sslmode=disable",
   152  	}
   153  
   154  	db, err := gosql.Open("postgres", pgBaseURL.String())
   155  	if err != nil {
   156  		t.Fatal(err)
   157  	}
   158  	defer db.Close()
   159  
   160  	pgServer := s.(*server.TestServer).PGServer()
   161  
   162  	// Make sure that the server reports correctly the case in which a
   163  	// connection did not respond to cancellation in time.
   164  	t.Run("CancelResponseFailure", func(t *testing.T) {
   165  		txn, err := db.Begin()
   166  		if err != nil {
   167  			t.Fatal(err)
   168  		}
   169  
   170  		// Overwrite the pgServer's cancel map to avoid race conditions in
   171  		// which the connection is canceled and closes itself before the
   172  		// pgServer stops waiting for connections to respond to cancellation.
   173  		realCancels := pgServer.OverwriteCancelMap()
   174  
   175  		// Set draining with no drainWait or cancelWait timeout. The expected
   176  		// behavior is that the ongoing session is immediately canceled but
   177  		// since we overwrote the context.CancelFunc, this cancellation will
   178  		// not have any effect. The pgServer will not bother to wait for the
   179  		// connection to close properly and should notify the caller that a
   180  		// session did not respond to cancellation.
   181  		if err := pgServer.DrainImpl(
   182  			0 /* drainWait */, 0, /* cancelWait */
   183  		); !testutils.IsError(err, "some sessions did not respond to cancellation") {
   184  			t.Fatalf("unexpected error: %v", err)
   185  		}
   186  
   187  		// Actually cancel the connection.
   188  		for _, cancel := range realCancels {
   189  			cancel()
   190  		}
   191  
   192  		// Make sure that the connection was disrupted. A retry loop is needed
   193  		// because we must wait (since we told the pgServer not to) until the
   194  		// connection registers the cancellation and closes itself.
   195  		testutils.SucceedsSoon(t, func() error {
   196  			if _, err := txn.Exec("SELECT 1"); !errors.Is(err, driver.ErrBadConn) {
   197  				return errors.Errorf("unexpected error: %v", err)
   198  			}
   199  			return nil
   200  		})
   201  
   202  		if err := txn.Commit(); !errors.Is(err, driver.ErrBadConn) {
   203  			t.Fatalf("unexpected error: %v", err)
   204  		}
   205  
   206  		pgServer.Undrain()
   207  	})
   208  
   209  	// Make sure that a connection gets canceled and correctly responds to this
   210  	// cancellation by closing itself.
   211  	t.Run("CancelResponseSuccess", func(t *testing.T) {
   212  		txn, err := db.Begin()
   213  		if err != nil {
   214  			t.Fatal(err)
   215  		}
   216  
   217  		// Set draining with no drainWait timeout and a 2s cancelWait timeout.
   218  		// The expected behavior is for the pgServer to immediately cancel any
   219  		// ongoing sessions and wait for 2s for the cancellation to take effect.
   220  		if err := pgServer.DrainImpl(
   221  			0 /* drainWait */, 2*time.Second, /* cancelWait */
   222  		); err != nil {
   223  			t.Fatal(err)
   224  		}
   225  
   226  		if err := txn.Commit(); err == nil ||
   227  			(!errors.Is(err, driver.ErrBadConn) &&
   228  				!strings.Contains(err.Error(), "connection reset by peer")) {
   229  			t.Fatalf("unexpected error: %v", err)
   230  		}
   231  
   232  		pgServer.Undrain()
   233  	})
   234  }
   235  
   236  // We want to ensure that despite use of errors.{Wrap,Wrapf}, we are surfacing a
   237  // pq.Error.
   238  func TestPGUnwrapError(t *testing.T) {
   239  	defer leaktest.AfterTest(t)()
   240  
   241  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
   242  	defer s.Stopper().Stop(context.Background())
   243  
   244  	pgURL, cleanupFn := sqlutils.PGUrl(t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser))
   245  	defer cleanupFn()
   246  
   247  	db, err := gosql.Open("postgres", pgURL.String())
   248  	if err != nil {
   249  		t.Fatal(err)
   250  	}
   251  	defer db.Close()
   252  
   253  	// This is just a statement that is known to utilize errors.Wrap.
   254  	stmt := "SELECT COALESCE(2, 'foo')"
   255  
   256  	if _, err := db.Exec(stmt); err == nil {
   257  		t.Fatalf("expected %s to error", stmt)
   258  	} else {
   259  		if !errors.HasType(err, (*pq.Error)(nil)) {
   260  			t.Fatalf("pgwire should be surfacing a pq.Error")
   261  		}
   262  	}
   263  }
   264  
   265  func TestPGPrepareFail(t *testing.T) {
   266  	defer leaktest.AfterTest(t)()
   267  
   268  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
   269  	defer s.Stopper().Stop(context.Background())
   270  
   271  	pgURL, cleanupFn := sqlutils.PGUrl(t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser))
   272  	defer cleanupFn()
   273  
   274  	db, err := gosql.Open("postgres", pgURL.String())
   275  	if err != nil {
   276  		t.Fatal(err)
   277  	}
   278  	defer db.Close()
   279  
   280  	testFailures := map[string]string{
   281  		"SELECT $1 = $1":                            "pq: could not determine data type of placeholder $1",
   282  		"SELECT $1":                                 "pq: could not determine data type of placeholder $1",
   283  		"SELECT $1 + $1":                            "pq: could not determine data type of placeholder $1",
   284  		"SELECT CASE WHEN TRUE THEN $1 END":         "pq: could not determine data type of placeholder $1",
   285  		"SELECT CASE WHEN TRUE THEN $1 ELSE $2 END": "pq: could not determine data type of placeholder $1",
   286  		"SELECT $1 > 0 AND NOT $1":                  "pq: placeholder $1 already has type int, cannot assign bool",
   287  		"CREATE TABLE $1 (id INT)":                  "pq: at or near \"1\": syntax error",
   288  		"UPDATE d.t SET s = i + $1":                 "pq: unsupported binary operator: <int> + <anyelement> (desired <string>)",
   289  		"SELECT $0 > 0":                             "pq: lexical error: placeholder index must be between 1 and 65536",
   290  		"SELECT $2 > 0":                             "pq: could not determine data type of placeholder $1",
   291  		"SELECT 3 + CASE (4) WHEN 4 THEN $1 END":    "pq: could not determine data type of placeholder $1",
   292  		"SELECT ($1 + $1) + current_date()":         "pq: could not determine data type of placeholder $1",
   293  		"SELECT $1 + $2, $2::FLOAT":                 "pq: could not determine data type of placeholder $1",
   294  		"SELECT $1[2]":                              "pq: could not determine data type of placeholder $1",
   295  		"SELECT ($1 + 2) + ($1 + 2.5::FLOAT)":       "pq: unsupported binary operator: <int> + <float>",
   296  	}
   297  
   298  	if _, err := db.Exec(`CREATE DATABASE d; CREATE TABLE d.t (i INT, s STRING, d INT)`); err != nil {
   299  		t.Fatal(err)
   300  	}
   301  
   302  	for query, reason := range testFailures {
   303  		if stmt, err := db.Prepare(query); err == nil {
   304  			t.Errorf("expected error: %s", query)
   305  			if err := stmt.Close(); err != nil {
   306  				t.Fatal(err)
   307  			}
   308  		} else if err.Error() != reason {
   309  			t.Errorf(`%s: got: %q, expected: %q`, query, err, reason)
   310  		}
   311  	}
   312  }
   313  
   314  // Run a Prepare referencing a table created or dropped in the same
   315  // transaction.
   316  func TestPGPrepareWithCreateDropInTxn(t *testing.T) {
   317  	defer leaktest.AfterTest(t)()
   318  
   319  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
   320  	defer s.Stopper().Stop(context.Background())
   321  
   322  	pgURL, cleanupFn := sqlutils.PGUrl(t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser))
   323  	defer cleanupFn()
   324  
   325  	db, err := gosql.Open("postgres", pgURL.String())
   326  	if err != nil {
   327  		t.Fatal(err)
   328  	}
   329  	defer db.Close()
   330  
   331  	{
   332  		tx, err := db.Begin()
   333  		if err != nil {
   334  			t.Fatal(err)
   335  		}
   336  
   337  		if _, err := tx.Exec(`
   338  	CREATE DATABASE d;
   339  	CREATE TABLE d.kv (k VARCHAR PRIMARY KEY, v VARCHAR);
   340  `); err != nil {
   341  			t.Fatal(err)
   342  		}
   343  
   344  		stmt, err := tx.Prepare(`INSERT INTO d.kv (k,v) VALUES ($1, $2);`)
   345  		if err != nil {
   346  			t.Fatal(err)
   347  		}
   348  
   349  		res, err := stmt.Exec('a', 'b')
   350  		if err != nil {
   351  			t.Fatal(err)
   352  		}
   353  		stmt.Close()
   354  		affected, err := res.RowsAffected()
   355  		if err != nil {
   356  			t.Fatal(err)
   357  		}
   358  		if affected != 1 {
   359  			t.Fatalf("unexpected number of rows affected: %d", affected)
   360  		}
   361  
   362  		if err := tx.Commit(); err != nil {
   363  			t.Fatal(err)
   364  		}
   365  	}
   366  
   367  	{
   368  		tx, err := db.Begin()
   369  		if err != nil {
   370  			t.Fatal(err)
   371  		}
   372  
   373  		if _, err := tx.Exec(`
   374  	DROP TABLE d.kv;
   375  `); err != nil {
   376  			t.Fatal(err)
   377  		}
   378  
   379  		if _, err := tx.Prepare(`
   380  INSERT INTO d.kv (k,v) VALUES ($1, $2);
   381  `); !testutils.IsError(err, "relation \"d.kv\" does not exist") {
   382  			t.Fatalf("err = %v", err)
   383  		}
   384  
   385  		if err := tx.Rollback(); err != nil {
   386  			t.Fatal(err)
   387  		}
   388  	}
   389  }
   390  
   391  type preparedQueryTest struct {
   392  	qargs   []interface{}
   393  	results [][]interface{}
   394  	others  int
   395  	error   string
   396  	// preparedError determines the error to expect upon stmt.Query()
   397  	// (executing a prepared statement), as opposed to db.Query()
   398  	// (direct query without prepare). If left empty, error above is
   399  	// used for both.
   400  	preparedError string
   401  }
   402  
   403  func (p preparedQueryTest) SetArgs(v ...interface{}) preparedQueryTest {
   404  	p.qargs = v
   405  	return p
   406  }
   407  
   408  func (p preparedQueryTest) Results(v ...interface{}) preparedQueryTest {
   409  	p.results = append(p.results, v)
   410  	return p
   411  }
   412  
   413  func (p preparedQueryTest) Others(o int) preparedQueryTest {
   414  	p.others = o
   415  	return p
   416  }
   417  
   418  func (p preparedQueryTest) Error(err string) preparedQueryTest {
   419  	p.error = err
   420  	return p
   421  }
   422  
   423  func (p preparedQueryTest) PreparedError(err string) preparedQueryTest {
   424  	p.preparedError = err
   425  	return p
   426  }
   427  
   428  func TestPGPreparedQuery(t *testing.T) {
   429  	defer leaktest.AfterTest(t)()
   430  	var baseTest preparedQueryTest
   431  
   432  	queryTests := []struct {
   433  		sql   string
   434  		ptest []preparedQueryTest
   435  	}{
   436  		{"SELECT $1 > 0", []preparedQueryTest{
   437  			baseTest.SetArgs(1).Results(true),
   438  			baseTest.SetArgs("1").Results(true),
   439  			baseTest.SetArgs(1.1).Error(`pq: error in argument for $1: strconv.ParseInt: parsing "1.1": invalid syntax`).Results(true),
   440  			baseTest.SetArgs("1.0").Error(`pq: error in argument for $1: strconv.ParseInt: parsing "1.0": invalid syntax`),
   441  			baseTest.SetArgs(true).Error(`pq: error in argument for $1: strconv.ParseInt: parsing "true": invalid syntax`),
   442  		}},
   443  		{"SELECT ($1) > 0", []preparedQueryTest{
   444  			baseTest.SetArgs(1).Results(true),
   445  			baseTest.SetArgs(-1).Results(false),
   446  		}},
   447  		{"SELECT ((($1))) > 0", []preparedQueryTest{
   448  			baseTest.SetArgs(1).Results(true),
   449  			baseTest.SetArgs(-1).Results(false),
   450  		}},
   451  		{"SELECT TRUE AND $1", []preparedQueryTest{
   452  			baseTest.SetArgs(true).Results(true),
   453  			baseTest.SetArgs(false).Results(false),
   454  			baseTest.SetArgs(1).Results(true),
   455  			baseTest.SetArgs("").Error(`pq: error in argument for $1: strconv.ParseBool: parsing "": invalid syntax`),
   456  			// Make sure we can run another after a failure.
   457  			baseTest.SetArgs(true).Results(true),
   458  		}},
   459  		{"SELECT $1::bool", []preparedQueryTest{
   460  			baseTest.SetArgs(true).Results(true),
   461  			baseTest.SetArgs("true").Results(true),
   462  			baseTest.SetArgs("false").Results(false),
   463  			baseTest.SetArgs("1").Results(true),
   464  			baseTest.SetArgs(2).Error(`pq: error in argument for $1: strconv.ParseBool: parsing "2": invalid syntax`),
   465  			baseTest.SetArgs(3.1).Error(`pq: error in argument for $1: strconv.ParseBool: parsing "3.1": invalid syntax`),
   466  			baseTest.SetArgs("").Error(`pq: error in argument for $1: strconv.ParseBool: parsing "": invalid syntax`),
   467  		}},
   468  		{"SELECT CASE 40+2 WHEN 42 THEN 51 ELSE $1::INT END", []preparedQueryTest{
   469  			baseTest.Error(
   470  				"pq: no value provided for placeholder: $1",
   471  			).PreparedError(
   472  				wrongArgCountString(1, 0),
   473  			),
   474  		}},
   475  		{"SELECT $1::int > $2::float", []preparedQueryTest{
   476  			baseTest.SetArgs(2, 1).Results(true),
   477  			baseTest.SetArgs("2", 1).Results(true),
   478  			baseTest.SetArgs(1, "2").Results(false),
   479  			baseTest.SetArgs("2", "1.0").Results(true),
   480  			baseTest.SetArgs("2.0", "1").Error(`pq: error in argument for $1: strconv.ParseInt: parsing "2.0": invalid syntax`),
   481  			baseTest.SetArgs(2.1, 1).Error(`pq: error in argument for $1: strconv.ParseInt: parsing "2.1": invalid syntax`),
   482  		}},
   483  		{"SELECT greatest($1, 0, $2), $2", []preparedQueryTest{
   484  			baseTest.SetArgs(1, -1).Results(1, -1),
   485  			baseTest.SetArgs(-1, 10).Results(10, 10),
   486  			baseTest.SetArgs("-2", "-1").Results(0, -1),
   487  			baseTest.SetArgs(1, 2.1).Error(`pq: error in argument for $2: strconv.ParseInt: parsing "2.1": invalid syntax`),
   488  		}},
   489  		{"SELECT $1::int, $1::float", []preparedQueryTest{
   490  			baseTest.SetArgs(1).Results(1, 1.0),
   491  			baseTest.SetArgs("1").Results(1, 1.0),
   492  		}},
   493  		{"SELECT 3 + $1, $1 + $2", []preparedQueryTest{
   494  			baseTest.SetArgs("1", "2").Results(4, 3),
   495  			baseTest.SetArgs(3, "4").Results(6, 7),
   496  			baseTest.SetArgs(0, "a").Error(`pq: error in argument for $2: strconv.ParseInt: parsing "a": invalid syntax`),
   497  		}},
   498  		// Check for name resolution.
   499  		{"SELECT count(*)", []preparedQueryTest{
   500  			baseTest.Results(1),
   501  		}},
   502  		{"SELECT CASE WHEN $1 THEN 1-$3 WHEN $2 THEN 1+$3 END", []preparedQueryTest{
   503  			baseTest.SetArgs(true, false, 2).Results(-1),
   504  			baseTest.SetArgs(false, true, 3).Results(4),
   505  			baseTest.SetArgs(false, false, 2).Results(gosql.NullBool{}),
   506  		}},
   507  		{"SELECT CASE 1 WHEN $1 THEN $2 ELSE 2 END", []preparedQueryTest{
   508  			baseTest.SetArgs(1, 3).Results(3),
   509  			baseTest.SetArgs(2, 3).Results(2),
   510  			baseTest.SetArgs(true, 0).Error(`pq: error in argument for $1: strconv.ParseInt: parsing "true": invalid syntax`),
   511  		}},
   512  		{"SELECT $1[2] LIKE 'b'", []preparedQueryTest{
   513  			baseTest.SetArgs(pq.Array([]string{"a", "b", "c"})).Results(true),
   514  			baseTest.SetArgs(pq.Array([]gosql.NullString{{String: "a", Valid: true}, {Valid: false}, {String: "c", Valid: true}})).Results(gosql.NullBool{Valid: false}),
   515  		}},
   516  		{"SET application_name = $1", []preparedQueryTest{
   517  			baseTest.SetArgs("hello world"),
   518  		}},
   519  		{"SET CLUSTER SETTING cluster.organization = $1", []preparedQueryTest{
   520  			baseTest.SetArgs("hello world"),
   521  		}},
   522  		{"SHOW DATABASE", []preparedQueryTest{
   523  			baseTest.Results("defaultdb"),
   524  		}},
   525  		{"SELECT descriptor FROM system.descriptor WHERE descriptor != $1 LIMIT 1", []preparedQueryTest{
   526  			baseTest.SetArgs([]byte("abc")).Results([]byte("\x12!\n\x06system\x10\x01\x1a\x15\n\t\n\x05admin\x100\n\b\n\x04root\x100")),
   527  		}},
   528  		{"SHOW COLUMNS FROM system.users", []preparedQueryTest{
   529  			baseTest.
   530  				Results("username", "STRING", false, gosql.NullBool{}, "", "{primary}", false).
   531  				Results("hashedPassword", "BYTES", true, gosql.NullBool{}, "", "{}", false).
   532  				Results("isRole", "BOOL", false, false, "", "{}", false),
   533  		}},
   534  		{"SHOW DATABASES", []preparedQueryTest{
   535  			baseTest.Results("d").Results("defaultdb").Results("postgres").Results("system"),
   536  		}},
   537  		{"SHOW GRANTS ON system.users", []preparedQueryTest{
   538  			baseTest.Results("system", "public", "users", sqlbase.AdminRole, "DELETE").
   539  				Results("system", "public", "users", sqlbase.AdminRole, "GRANT").
   540  				Results("system", "public", "users", sqlbase.AdminRole, "INSERT").
   541  				Results("system", "public", "users", sqlbase.AdminRole, "SELECT").
   542  				Results("system", "public", "users", sqlbase.AdminRole, "UPDATE").
   543  				Results("system", "public", "users", security.RootUser, "DELETE").
   544  				Results("system", "public", "users", security.RootUser, "GRANT").
   545  				Results("system", "public", "users", security.RootUser, "INSERT").
   546  				Results("system", "public", "users", security.RootUser, "SELECT").
   547  				Results("system", "public", "users", security.RootUser, "UPDATE"),
   548  		}},
   549  		{"SHOW INDEXES FROM system.users", []preparedQueryTest{
   550  			baseTest.Results("users", "primary", false, 1, "username", "ASC", false, false),
   551  		}},
   552  		{"SHOW TABLES FROM system", []preparedQueryTest{
   553  			baseTest.Results("public", "comments", "table").Others(26),
   554  		}},
   555  		{"SHOW SCHEMAS FROM system", []preparedQueryTest{
   556  			baseTest.Results("crdb_internal").Others(4),
   557  		}},
   558  		{"SHOW CONSTRAINTS FROM system.users", []preparedQueryTest{
   559  			baseTest.Results("users", "primary", "PRIMARY KEY", "PRIMARY KEY (username ASC)", true),
   560  		}},
   561  		{"SHOW TIME ZONE", []preparedQueryTest{
   562  			baseTest.Results("UTC"),
   563  		}},
   564  		{"CREATE USER IF NOT EXISTS $1 WITH PASSWORD $2", []preparedQueryTest{
   565  			baseTest.SetArgs("abc", "def"),
   566  			baseTest.SetArgs("woo", "waa"),
   567  		}},
   568  		{"ALTER USER IF EXISTS $1 WITH PASSWORD $2", []preparedQueryTest{
   569  			baseTest.SetArgs("abc", "def"),
   570  			baseTest.SetArgs("woo", "waa"),
   571  		}},
   572  		{"SHOW USERS", []preparedQueryTest{
   573  			baseTest.Results("abc", "", "{}").Results("admin", "CREATEROLE", "{}").
   574  				Results("root", "CREATEROLE", "{admin}").Results("woo", "", "{}"),
   575  		}},
   576  		{"DROP USER $1", []preparedQueryTest{
   577  			baseTest.SetArgs("abc"),
   578  			baseTest.SetArgs("woo"),
   579  		}},
   580  		{"SELECT (SELECT 1+$1)", []preparedQueryTest{
   581  			baseTest.SetArgs(1).Results(2),
   582  		}},
   583  		{"SELECT CASE WHEN $1 THEN $2 ELSE 3 END", []preparedQueryTest{
   584  			baseTest.SetArgs(true, 2).Results(2),
   585  			baseTest.SetArgs(false, 2).Results(3),
   586  		}},
   587  		{"SELECT CASE WHEN TRUE THEN 1 ELSE $1 END", []preparedQueryTest{
   588  			baseTest.SetArgs(2).Results(1),
   589  		}},
   590  		{"SELECT CASE $1 WHEN 1 THEN 1 END", []preparedQueryTest{
   591  			baseTest.SetArgs(1).Results(1),
   592  			baseTest.SetArgs(2).Results(gosql.NullInt64{}),
   593  		}},
   594  		{"SELECT $1::timestamp, $2::date", []preparedQueryTest{
   595  			baseTest.SetArgs("2001-01-02 03:04:05", "2006-07-08").Results(
   596  				time.Date(2001, 1, 2, 3, 4, 5, 0, time.FixedZone("", 0)),
   597  				time.Date(2006, 7, 8, 0, 0, 0, 0, time.FixedZone("", 0)),
   598  			),
   599  		}},
   600  		{"SELECT $1::date, $2::timestamp", []preparedQueryTest{
   601  			baseTest.SetArgs(
   602  				time.Date(2006, 7, 8, 0, 0, 0, 9, time.FixedZone("", 0)),
   603  				time.Date(2001, 1, 2, 3, 4, 5, 6000, time.FixedZone("", 0)),
   604  			).Results(
   605  				time.Date(2006, 7, 8, 0, 0, 0, 0, time.FixedZone("", 0)),
   606  				time.Date(2001, 1, 2, 3, 4, 5, 6000, time.FixedZone("", 0)),
   607  			),
   608  		}},
   609  		{"INSERT INTO d.ts VALUES($1, $2) RETURNING *", []preparedQueryTest{
   610  			baseTest.SetArgs("2001-01-02 03:04:05", "2006-07-08").Results(
   611  				time.Date(2001, 1, 2, 3, 4, 5, 0, time.FixedZone("", 0)),
   612  				time.Date(2006, 7, 8, 0, 0, 0, 0, time.FixedZone("", 0)),
   613  			),
   614  		}},
   615  		{"INSERT INTO d.ts VALUES(current_timestamp(), $1) RETURNING b", []preparedQueryTest{
   616  			baseTest.SetArgs("2006-07-08").Results(
   617  				time.Date(2006, 7, 8, 0, 0, 0, 0, time.FixedZone("", 0)),
   618  			),
   619  		}},
   620  		{"INSERT INTO d.ts VALUES(statement_timestamp(), $1) RETURNING b", []preparedQueryTest{
   621  			baseTest.SetArgs("2006-07-08").Results(
   622  				time.Date(2006, 7, 8, 0, 0, 0, 0, time.FixedZone("", 0)),
   623  			),
   624  		}},
   625  		{"INSERT INTO d.ts (a) VALUES ($1) RETURNING a", []preparedQueryTest{
   626  			baseTest.SetArgs(
   627  				time.Date(2006, 7, 8, 0, 0, 0, 123000, time.FixedZone("", 0)),
   628  			).Results(
   629  				time.Date(2006, 7, 8, 0, 0, 0, 123000, time.FixedZone("", 0)),
   630  			),
   631  		}},
   632  		{"INSERT INTO d.T VALUES ($1) RETURNING 1", []preparedQueryTest{
   633  			baseTest.SetArgs(1).Results(1),
   634  			baseTest.SetArgs(nil).Results(1),
   635  		}},
   636  		{"INSERT INTO d.T VALUES ($1::INT) RETURNING 1", []preparedQueryTest{
   637  			baseTest.SetArgs(1).Results(1),
   638  		}},
   639  		{"INSERT INTO d.T VALUES ($1) RETURNING $1", []preparedQueryTest{
   640  			baseTest.SetArgs(1).Results(1),
   641  			baseTest.SetArgs(3).Results(3),
   642  		}},
   643  		{"INSERT INTO d.T VALUES ($1) RETURNING $1, 1 + $1", []preparedQueryTest{
   644  			baseTest.SetArgs(1).Results(1, 2),
   645  			baseTest.SetArgs(3).Results(3, 4),
   646  		}},
   647  		{"INSERT INTO d.T VALUES (greatest(42, $1)) RETURNING a", []preparedQueryTest{
   648  			baseTest.SetArgs(40).Results(42),
   649  			baseTest.SetArgs(45).Results(45),
   650  		}},
   651  		// TODO(justin): match this with the optimizer. Currently we only report
   652  		// one placeholder not being filled in, since we only detect so at eval
   653  		// time, #26901.
   654  		// {"SELECT a FROM d.T WHERE a = $1 AND (SELECT a >= $2 FROM d.T WHERE a = $1)",  []preparedQueryTest{
   655  		// 	baseTest.SetArgs(10, 5).Results(10),
   656  		// 	baseTest.Error(
   657  		// 		"pq: no value provided for placeholders: $1, $2",
   658  		// 	).PreparedError(
   659  		// 		wrongArgCountString(2, 0),
   660  		// 	),
   661  		// }},
   662  		{"SELECT * FROM (VALUES (1), (2), (3), (4)) AS foo (a) LIMIT $1 OFFSET $2", []preparedQueryTest{
   663  			baseTest.SetArgs(1, 0).Results(1),
   664  			baseTest.SetArgs(1, 1).Results(2),
   665  			baseTest.SetArgs(1, 2).Results(3),
   666  		}},
   667  		{"SELECT * FROM (VALUES (1), (2), (3), (4)) AS foo (a) FETCH FIRST $1 ROWS ONLY OFFSET $2 ROWS", []preparedQueryTest{
   668  			baseTest.SetArgs(1, 0).Results(1),
   669  			baseTest.SetArgs(1, 1).Results(2),
   670  			baseTest.SetArgs(1, 2).Results(3),
   671  		}},
   672  		{"SELECT 3 + CASE (4) WHEN 4 THEN $1 ELSE 42 END", []preparedQueryTest{
   673  			baseTest.SetArgs(12).Results(15),
   674  			baseTest.SetArgs(-12).Results(-9),
   675  		}},
   676  		{"SELECT DATE '2001-01-02' + ($1 + $1:::int)", []preparedQueryTest{
   677  			baseTest.SetArgs(12).Results("2001-01-26T00:00:00Z"),
   678  		}},
   679  		// Hint for INT type to distinguish from ~INET functionality.
   680  		{"SELECT to_hex(~(~$1:::INT))", []preparedQueryTest{
   681  			baseTest.SetArgs(12).Results("c"),
   682  		}},
   683  		{"SELECT $1::INT", []preparedQueryTest{
   684  			baseTest.SetArgs(12).Results(12),
   685  		}},
   686  		{"SELECT ANNOTATE_TYPE($1, int)", []preparedQueryTest{
   687  			baseTest.SetArgs(12).Results(12),
   688  		}},
   689  		{"SELECT $1 + $2, ANNOTATE_TYPE($2, float)", []preparedQueryTest{
   690  			baseTest.SetArgs(12, 23).Results(35, 23),
   691  		}},
   692  		{"INSERT INTO d.T VALUES ($1 + 1) RETURNING a", []preparedQueryTest{
   693  			baseTest.SetArgs(1).Results(2),
   694  			baseTest.SetArgs(11).Results(12),
   695  		}},
   696  		{"INSERT INTO d.T VALUES (-$1) RETURNING a", []preparedQueryTest{
   697  			baseTest.SetArgs(1).Results(-1),
   698  			baseTest.SetArgs(-999).Results(999),
   699  		}},
   700  		{"INSERT INTO d.two (a, b) VALUES (~$1, $1 + $2) RETURNING a, b", []preparedQueryTest{
   701  			baseTest.SetArgs(5, 6).Results(-6, 11),
   702  		}},
   703  		{"INSERT INTO d.str (s) VALUES (left($1, 3)) RETURNING s", []preparedQueryTest{
   704  			baseTest.SetArgs("abcdef").Results("abc"),
   705  			baseTest.SetArgs("123456").Results("123"),
   706  		}},
   707  		{"INSERT INTO d.str (b) VALUES (COALESCE($1, 'strLit')) RETURNING b", []preparedQueryTest{
   708  			baseTest.SetArgs(nil).Results("strLit"),
   709  			baseTest.SetArgs("123456").Results("123456"),
   710  		}},
   711  		{"INSERT INTO d.intStr VALUES ($1, 'hello ' || $1::TEXT) RETURNING *", []preparedQueryTest{
   712  			baseTest.SetArgs(123).Results(123, "hello 123"),
   713  		}},
   714  		{"SELECT * from d.T WHERE a = ANY($1)", []preparedQueryTest{
   715  			baseTest.SetArgs(pq.Array([]int{10})).Results(10),
   716  		}},
   717  		{"SELECT s from (VALUES ('foo'), ('bar')) as t(s) WHERE s = ANY($1)", []preparedQueryTest{
   718  			baseTest.SetArgs(pq.StringArray([]string{"foo"})).Results("foo"),
   719  		}},
   720  		// #13725
   721  		{"SELECT * FROM d.emptynorows", []preparedQueryTest{
   722  			baseTest.SetArgs(),
   723  		}},
   724  		{"SELECT * FROM d.emptyrows", []preparedQueryTest{
   725  			baseTest.SetArgs().Results().Results().Results(),
   726  		}},
   727  		// #14238
   728  		{"EXPLAIN SELECT 1", []preparedQueryTest{
   729  			baseTest.SetArgs().
   730  				Results("", "distributed", "false").
   731  				Results("", "vectorized", "false").
   732  				Results("values", "", "").
   733  				Results("", "size", "1 column, 1 row"),
   734  		}},
   735  		// #14245
   736  		{"SELECT 1::oid = $1", []preparedQueryTest{
   737  			baseTest.SetArgs(1).Results(true),
   738  			baseTest.SetArgs(2).Results(false),
   739  			baseTest.SetArgs("1").Results(true),
   740  			baseTest.SetArgs("2").Results(false),
   741  		}},
   742  		{"SELECT * FROM d.pg_catalog.pg_class WHERE relnamespace = $1", []preparedQueryTest{
   743  			baseTest.SetArgs(1),
   744  		}},
   745  		{"SELECT $1::UUID", []preparedQueryTest{
   746  			baseTest.SetArgs("63616665-6630-3064-6465-616462656562").Results("63616665-6630-3064-6465-616462656562"),
   747  		}},
   748  		{"SELECT $1::INET", []preparedQueryTest{
   749  			baseTest.SetArgs("192.168.0.1/32").Results("192.168.0.1"),
   750  		}},
   751  		{"SELECT $1::TIME", []preparedQueryTest{
   752  			baseTest.SetArgs("12:00:00").Results("0000-01-01T12:00:00Z"),
   753  		}},
   754  		{"SELECT $1::TIMETZ", []preparedQueryTest{
   755  			baseTest.SetArgs("12:00:00+0330").Results("0000-01-01T12:00:00+03:30"),
   756  		}},
   757  		{"SELECT $1::GEOGRAPHY", []preparedQueryTest{
   758  			baseTest.SetArgs("POINT(1.0 1.0)").Results("0101000020E6100000000000000000F03F000000000000F03F"),
   759  		}},
   760  		{"SELECT $1::GEOMETRY", []preparedQueryTest{
   761  			baseTest.SetArgs("POINT(1.0 1.0)").Results("0101000000000000000000F03F000000000000F03F"),
   762  		}},
   763  		{"SELECT $1:::FLOAT[]", []preparedQueryTest{
   764  			baseTest.SetArgs("{}").Results("{}"),
   765  			baseTest.SetArgs("{1.0,2.0,3.0}").Results("{1.0,2.0,3.0}"),
   766  		}},
   767  		{"SELECT $1:::DECIMAL[]", []preparedQueryTest{
   768  			baseTest.SetArgs("{1.000}").Results("{1.000}"),
   769  		}},
   770  		{"SELECT $1:::STRING[]", []preparedQueryTest{
   771  			baseTest.SetArgs(`{aaa}`).Results(`{aaa}`),
   772  			baseTest.SetArgs(`{"aaa"}`).Results(`{aaa}`),
   773  			baseTest.SetArgs(`{aaa,bbb,ccc}`).Results(`{aaa,bbb,ccc}`),
   774  		}},
   775  		{"SELECT $1:::JSON", []preparedQueryTest{
   776  			baseTest.SetArgs(`true`).Results(`true`),
   777  			baseTest.SetArgs(`"hello"`).Results(`"hello"`),
   778  		}},
   779  		{"SELECT $1:::BIT(4)", []preparedQueryTest{
   780  			baseTest.SetArgs(`1101`).Results(`1101`),
   781  		}},
   782  		{"SELECT $1:::VARBIT", []preparedQueryTest{
   783  			baseTest.SetArgs(`1101`).Results(`1101`),
   784  			baseTest.SetArgs(`1101001`).Results(`1101001`),
   785  		}},
   786  		{"SELECT $1::INT[]", []preparedQueryTest{
   787  			baseTest.SetArgs(pq.Array([]int64{10})).Results(pq.Array([]int64{10})),
   788  		}},
   789  		{"INSERT INTO d.arr VALUES($1, $2)", []preparedQueryTest{
   790  			baseTest.SetArgs(pq.Array([]int64{}), pq.Array([]string{})),
   791  		}},
   792  		{"EXPERIMENTAL SCRUB TABLE system.locations", []preparedQueryTest{
   793  			baseTest.SetArgs(),
   794  		}},
   795  		{"ALTER RANGE liveness CONFIGURE ZONE = $1", []preparedQueryTest{
   796  			baseTest.SetArgs("num_replicas: 1"),
   797  		}},
   798  		{"ALTER RANGE liveness CONFIGURE ZONE USING num_replicas = $1", []preparedQueryTest{
   799  			baseTest.SetArgs(1),
   800  		}},
   801  		{"ALTER RANGE liveness CONFIGURE ZONE = $1", []preparedQueryTest{
   802  			baseTest.SetArgs(gosql.NullString{}),
   803  		}},
   804  		{"TRUNCATE TABLE d.str", []preparedQueryTest{
   805  			baseTest.SetArgs(),
   806  		}},
   807  
   808  		// TODO(nvanbenschoten): Same class of limitation as that in logic_test/typing:
   809  		//   Nested constants are not exposed to the same constant type resolution rules
   810  		//   as top-level constants, and instead are simply resolved to their natural type.
   811  		//{"SELECT (CASE a WHEN 10 THEN 'one' WHEN 11 THEN (CASE 'en' WHEN 'en' THEN $1 END) END) AS ret FROM d.T ORDER BY ret DESC LIMIT 2",  []preparedQueryTest{
   812  		// 	baseTest.SetArgs("hello").Results("one").Results("hello"),
   813  		//}},
   814  	}
   815  
   816  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
   817  	defer s.Stopper().Stop(context.Background())
   818  
   819  	pgURL, cleanupFn := sqlutils.PGUrl(t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser))
   820  	defer cleanupFn()
   821  
   822  	db, err := gosql.Open("postgres", pgURL.String())
   823  	if err != nil {
   824  		t.Fatal(err)
   825  	}
   826  	defer db.Close()
   827  
   828  	runTests := func(
   829  		t *testing.T,
   830  		query string,
   831  		prepared bool,
   832  		tests []preparedQueryTest,
   833  		queryFunc func(...interface{}) (*gosql.Rows, error),
   834  	) {
   835  		for idx, test := range tests {
   836  			t.Run(fmt.Sprintf("%d", idx), func(t *testing.T) {
   837  				if testing.Verbose() || log.V(1) {
   838  					log.Infof(context.Background(), "query: %s", query)
   839  				}
   840  				rows, err := queryFunc(test.qargs...)
   841  				if err != nil {
   842  					if test.error == "" {
   843  						t.Errorf("%s: %v: unexpected error: %s", query, test.qargs, err)
   844  					} else {
   845  						expectedErr := test.error
   846  						if prepared && test.preparedError != "" {
   847  							expectedErr = test.preparedError
   848  						}
   849  						if err.Error() != expectedErr {
   850  							t.Errorf("%s: %v: expected error: %s, got %s", query, test.qargs, expectedErr, err)
   851  						}
   852  					}
   853  					return
   854  				}
   855  				defer rows.Close()
   856  
   857  				if test.error != "" {
   858  					t.Fatalf("expected error: %s: %v", query, test.qargs)
   859  				}
   860  
   861  				for _, expected := range test.results {
   862  					if !rows.Next() {
   863  						t.Fatalf("expected row: %s: %v", query, test.qargs)
   864  					}
   865  					dst := make([]interface{}, len(expected))
   866  					for i, d := range expected {
   867  						dst[i] = reflect.New(reflect.TypeOf(d)).Interface()
   868  					}
   869  					if err := rows.Scan(dst...); err != nil {
   870  						t.Error(err)
   871  					}
   872  					for i, d := range dst {
   873  						dst[i] = reflect.Indirect(reflect.ValueOf(d)).Interface()
   874  					}
   875  					if len(dst) > 0 && len(expected) > 0 && !reflect.DeepEqual(dst, expected) {
   876  						t.Errorf("%s: %v: expected %v, got %v", query, test.qargs, expected, dst)
   877  					}
   878  				}
   879  				for rows.Next() {
   880  					if test.others > 0 {
   881  						test.others--
   882  						continue
   883  					}
   884  					cols, err := rows.Columns()
   885  					if err != nil {
   886  						t.Errorf("%s: %s", query, err)
   887  						continue
   888  					}
   889  					// Unexpected line. Get and print out the details.
   890  					dst := make([]interface{}, len(cols))
   891  					for i := range dst {
   892  						dst[i] = new(interface{})
   893  					}
   894  					if err := rows.Scan(dst...); err != nil {
   895  						t.Errorf("%s: %s", query, err)
   896  						continue
   897  					}
   898  					b, err := json.Marshal(dst)
   899  					if err != nil {
   900  						t.Errorf("%s: %s", query, err)
   901  						continue
   902  					}
   903  					t.Errorf("%s: unexpected row: %s", query, b)
   904  				}
   905  				if test.others > 0 {
   906  					t.Fatalf("%s: expected %d more rows", query, test.others)
   907  				}
   908  			})
   909  		}
   910  	}
   911  
   912  	initStmt := `
   913  CREATE DATABASE d;
   914  CREATE TABLE d.t (a INT);
   915  INSERT INTO d.t VALUES (10),(11);
   916  CREATE TABLE d.ts (a TIMESTAMP, b DATE);
   917  CREATE TABLE d.two (a INT, b INT);
   918  CREATE TABLE d.intStr (a INT, s STRING);
   919  CREATE TABLE d.str (s STRING, b BYTES);
   920  CREATE TABLE d.arr (a INT[], b TEXT[]);
   921  CREATE TABLE d.emptynorows (); -- zero columns, zero rows
   922  CREATE TABLE d.emptyrows (x INT);
   923  INSERT INTO d.emptyrows VALUES (1),(2),(3);
   924  ALTER TABLE d.emptyrows DROP COLUMN x; -- zero columns, 3 rows
   925  `
   926  	if _, err := db.Exec(initStmt); err != nil {
   927  		t.Fatal(err)
   928  	}
   929  
   930  	t.Run("exec", func(t *testing.T) {
   931  		for _, test := range queryTests {
   932  			query := test.sql
   933  			tests := test.ptest
   934  			t.Run(query, func(t *testing.T) {
   935  				runTests(t, query, false, tests, func(args ...interface{}) (*gosql.Rows, error) {
   936  					return db.Query(query, args...)
   937  				})
   938  			})
   939  		}
   940  	})
   941  
   942  	t.Run("prepare", func(t *testing.T) {
   943  		for _, test := range queryTests {
   944  			query := test.sql
   945  			tests := test.ptest
   946  			t.Run(query, func(t *testing.T) {
   947  				if stmt, err := db.Prepare(query); err != nil {
   948  					t.Errorf("%s: prepare error: %s", query, err)
   949  				} else {
   950  					defer stmt.Close()
   951  
   952  					runTests(t, query, true, tests, stmt.Query)
   953  				}
   954  			})
   955  		}
   956  	})
   957  }
   958  
   959  type preparedExecTest struct {
   960  	qargs           []interface{}
   961  	rowsAffected    int64
   962  	error           string
   963  	rowsAffectedErr string
   964  }
   965  
   966  func (p preparedExecTest) SetArgs(v ...interface{}) preparedExecTest {
   967  	p.qargs = v
   968  	return p
   969  }
   970  
   971  func (p preparedExecTest) RowsAffected(rowsAffected int64) preparedExecTest {
   972  	p.rowsAffected = rowsAffected
   973  	return p
   974  }
   975  
   976  func (p preparedExecTest) Error(err string) preparedExecTest {
   977  	p.error = err
   978  	return p
   979  }
   980  
   981  func (p preparedExecTest) RowsAffectedErr(err string) preparedExecTest {
   982  	p.rowsAffectedErr = err
   983  	return p
   984  }
   985  
   986  // Verify that bound dates are evaluated using session timezone.
   987  func TestPGPrepareDate(t *testing.T) {
   988  	defer leaktest.AfterTest(t)()
   989  	s, db, _ := serverutils.StartServer(t, base.TestServerArgs{})
   990  	defer s.Stopper().Stop(context.Background())
   991  
   992  	if _, err := db.Exec("CREATE TABLE test (t TIMESTAMPTZ)"); err != nil {
   993  		t.Fatal(err)
   994  	}
   995  
   996  	if _, err := db.Exec("SET TIME ZONE +08"); err != nil {
   997  		t.Fatal(err)
   998  	}
   999  
  1000  	stmt, err := db.Prepare("INSERT INTO test VALUES ($1)")
  1001  	if err != nil {
  1002  		t.Fatal(err)
  1003  	}
  1004  
  1005  	if _, err := stmt.Exec("2018-01-01 12:34:56"); err != nil {
  1006  		t.Fatal(err)
  1007  	}
  1008  
  1009  	// Reset to UTC for the query.
  1010  	if _, err := db.Exec("SET TIME ZONE UTC"); err != nil {
  1011  		t.Fatal(err)
  1012  	}
  1013  
  1014  	var ts time.Time
  1015  	if err := db.QueryRow("SELECT t FROM test").Scan(&ts); err != nil {
  1016  		t.Fatal(err)
  1017  	}
  1018  
  1019  	exp := time.Date(2018, 1, 1, 4, 34, 56, 0, time.UTC)
  1020  	if !exp.Equal(ts) {
  1021  		t.Fatalf("expected %s, got %s", exp, ts)
  1022  	}
  1023  }
  1024  
  1025  func TestPGPreparedExec(t *testing.T) {
  1026  	defer leaktest.AfterTest(t)()
  1027  	var baseTest preparedExecTest
  1028  	execTests := []struct {
  1029  		query string
  1030  		tests []preparedExecTest
  1031  	}{
  1032  		{
  1033  			"CREATE DATABASE d",
  1034  			[]preparedExecTest{
  1035  				baseTest,
  1036  			},
  1037  		},
  1038  		{
  1039  			"CREATE TABLE d.public.t (i INT, s STRING, d INT)",
  1040  			[]preparedExecTest{
  1041  				baseTest,
  1042  				baseTest.Error(`pq: relation "t" already exists`),
  1043  			},
  1044  		},
  1045  		{
  1046  			"INSERT INTO d.public.t VALUES ($1, $2, $3)",
  1047  			[]preparedExecTest{
  1048  				baseTest.SetArgs(1, "one", 2).RowsAffected(1),
  1049  				baseTest.SetArgs("two", 2, 2).Error(`pq: error in argument for $1: strconv.ParseInt: parsing "two": invalid syntax`),
  1050  			},
  1051  		},
  1052  		{
  1053  			"UPDATE d.public.t SET s = $1, i = i + $2, d = 1 + $3 WHERE i = $4",
  1054  			[]preparedExecTest{
  1055  				baseTest.SetArgs(4, 3, 2, 1).RowsAffected(1),
  1056  			},
  1057  		},
  1058  		{
  1059  			"UPDATE d.public.t SET i = $1 WHERE (i, s) = ($2, $3)",
  1060  			[]preparedExecTest{
  1061  				baseTest.SetArgs(8, 4, "4").RowsAffected(1),
  1062  			},
  1063  		},
  1064  		{
  1065  			"DELETE FROM d.public.t WHERE s = $1 and i = $2 and d = 2 + $3",
  1066  			[]preparedExecTest{
  1067  				baseTest.SetArgs(1, 2, 3).RowsAffected(0),
  1068  			},
  1069  		},
  1070  		{
  1071  			"INSERT INTO d.public.t VALUES ($1), ($2)",
  1072  			[]preparedExecTest{
  1073  				baseTest.SetArgs(1, 2).RowsAffected(2),
  1074  			},
  1075  		},
  1076  		{
  1077  			"INSERT INTO d.public.t VALUES ($1), ($2) RETURNING $3 + 1",
  1078  			[]preparedExecTest{
  1079  				baseTest.SetArgs(3, 4, 5).RowsAffected(2),
  1080  			},
  1081  		},
  1082  		{
  1083  			"UPDATE d.public.t SET i = CASE WHEN $1 THEN i-$3 WHEN $2 THEN i+$3 END",
  1084  			[]preparedExecTest{
  1085  				baseTest.SetArgs(true, true, 3).RowsAffected(5),
  1086  			},
  1087  		},
  1088  		{
  1089  			"UPDATE d.public.t SET i = CASE i WHEN $1 THEN i-$3 WHEN $2 THEN i+$3 END",
  1090  			[]preparedExecTest{
  1091  				baseTest.SetArgs(1, 2, 3).RowsAffected(5),
  1092  			},
  1093  		},
  1094  		{
  1095  			"UPDATE d.public.t SET d = CASE WHEN TRUE THEN $1 END",
  1096  			[]preparedExecTest{
  1097  				baseTest.SetArgs(2).RowsAffected(5),
  1098  			},
  1099  		},
  1100  		{
  1101  			"DELETE FROM d.public.t RETURNING $1+1",
  1102  			[]preparedExecTest{
  1103  				baseTest.SetArgs(1).RowsAffected(5),
  1104  			},
  1105  		},
  1106  		{
  1107  			"DROP TABLE d.public.t",
  1108  			[]preparedExecTest{
  1109  				baseTest,
  1110  				baseTest.Error(`pq: relation "d.public.t" does not exist`),
  1111  			},
  1112  		},
  1113  		{
  1114  			"CREATE TABLE d.public.t AS SELECT $1+1 AS x",
  1115  			[]preparedExecTest{
  1116  				baseTest.SetArgs(1),
  1117  			},
  1118  		},
  1119  		{
  1120  			"CREATE TABLE d.public.types (i int, f float, s string, b bytes, d date, m timestamp, z timestamp with time zone, n interval, o bool, e decimal)",
  1121  			[]preparedExecTest{
  1122  				baseTest,
  1123  			},
  1124  		},
  1125  		{
  1126  			"INSERT INTO d.public.types VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)",
  1127  			[]preparedExecTest{
  1128  				baseTest.RowsAffected(1).SetArgs(
  1129  					int64(0),
  1130  					float64(0),
  1131  					"",
  1132  					[]byte{},
  1133  					time.Time{}, // date
  1134  					time.Time{}, // timestamp
  1135  					time.Time{}, // timestamptz
  1136  					time.Hour.String(),
  1137  					true,
  1138  					"0.0", // decimal
  1139  				),
  1140  			},
  1141  		},
  1142  		{
  1143  			"DROP DATABASE d CASCADE",
  1144  			[]preparedExecTest{
  1145  				baseTest,
  1146  			},
  1147  		},
  1148  		{
  1149  			"CANCEL JOB $1",
  1150  			[]preparedExecTest{
  1151  				baseTest.SetArgs(123).Error("pq: job with ID 123 does not exist"),
  1152  			},
  1153  		},
  1154  		{
  1155  			"CANCEL JOBS SELECT $1",
  1156  			[]preparedExecTest{
  1157  				baseTest.SetArgs(123).Error("pq: job with ID 123 does not exist"),
  1158  			},
  1159  		},
  1160  		{
  1161  			"RESUME JOB $1",
  1162  			[]preparedExecTest{
  1163  				baseTest.SetArgs(123).Error("pq: job with ID 123 does not exist"),
  1164  			},
  1165  		},
  1166  		{
  1167  			"RESUME JOBS SELECT $1",
  1168  			[]preparedExecTest{
  1169  				baseTest.SetArgs(123).Error("pq: job with ID 123 does not exist"),
  1170  			},
  1171  		},
  1172  		{
  1173  			"PAUSE JOB $1",
  1174  			[]preparedExecTest{
  1175  				baseTest.SetArgs(123).Error("pq: job with ID 123 does not exist"),
  1176  			},
  1177  		},
  1178  		{
  1179  			"PAUSE JOBS SELECT $1",
  1180  			[]preparedExecTest{
  1181  				baseTest.SetArgs(123).Error("pq: job with ID 123 does not exist"),
  1182  			},
  1183  		},
  1184  		{
  1185  			"CANCEL QUERY $1",
  1186  			[]preparedExecTest{
  1187  				baseTest.SetArgs("01").Error("pq: could not cancel query 00000000000000000000000000000001: query ID 00000000000000000000000000000001 not found"),
  1188  			},
  1189  		},
  1190  		{
  1191  			"CANCEL QUERIES SELECT $1",
  1192  			[]preparedExecTest{
  1193  				baseTest.SetArgs("01").Error("pq: could not cancel query 00000000000000000000000000000001: query ID 00000000000000000000000000000001 not found"),
  1194  			},
  1195  		},
  1196  		{
  1197  			"CANCEL SESSION $1",
  1198  			[]preparedExecTest{
  1199  				baseTest.SetArgs("01").Error("pq: could not cancel session 00000000000000000000000000000001: session ID 00000000000000000000000000000001 not found"),
  1200  			},
  1201  		},
  1202  		{
  1203  			"CANCEL SESSIONS SELECT $1",
  1204  			[]preparedExecTest{
  1205  				baseTest.SetArgs("01").Error("pq: could not cancel session 00000000000000000000000000000001: session ID 00000000000000000000000000000001 not found"),
  1206  			},
  1207  		},
  1208  		// An empty string is valid in postgres.
  1209  		{
  1210  			"",
  1211  			[]preparedExecTest{
  1212  				baseTest.RowsAffectedErr("no RowsAffected available after the empty statement"),
  1213  			},
  1214  		},
  1215  		// Empty statements are permitted.
  1216  		{
  1217  			";",
  1218  			[]preparedExecTest{
  1219  				baseTest.RowsAffectedErr("no RowsAffected available after the empty statement"),
  1220  			},
  1221  		},
  1222  		// Any number of empty statements are permitted with a single statement
  1223  		// anywhere.
  1224  		{
  1225  			"; ; SET DATABASE = system; ;",
  1226  			[]preparedExecTest{
  1227  				baseTest,
  1228  			},
  1229  		},
  1230  	}
  1231  
  1232  	s, db, _ := serverutils.StartServer(t, base.TestServerArgs{})
  1233  	defer s.Stopper().Stop(context.Background())
  1234  
  1235  	runTests := func(
  1236  		t *testing.T, query string, tests []preparedExecTest, execFunc func(...interface{},
  1237  		) (gosql.Result, error)) {
  1238  		for idx, test := range tests {
  1239  			t.Run(fmt.Sprintf("%d", idx), func(t *testing.T) {
  1240  				if testing.Verbose() || log.V(1) {
  1241  					log.Infof(context.Background(), "exec: %s", query)
  1242  				}
  1243  				if result, err := execFunc(test.qargs...); err != nil {
  1244  					if test.error == "" {
  1245  						t.Errorf("%s: %v: unexpected error: %s", query, test.qargs, err)
  1246  					} else if err.Error() != test.error {
  1247  						t.Errorf("%s: %v: expected error: %s, got %s", query, test.qargs, test.error, err)
  1248  					}
  1249  				} else {
  1250  					rowsAffected, err := result.RowsAffected()
  1251  					if !testutils.IsError(err, test.rowsAffectedErr) {
  1252  						t.Errorf("%s: %v: expected %q, got %v", query, test.qargs, test.rowsAffectedErr, err)
  1253  					} else if rowsAffected != test.rowsAffected {
  1254  						t.Errorf("%s: %v: expected %v, got %v", query, test.qargs, test.rowsAffected, rowsAffected)
  1255  					}
  1256  				}
  1257  			})
  1258  		}
  1259  	}
  1260  
  1261  	t.Run("exec", func(t *testing.T) {
  1262  		for _, execTest := range execTests {
  1263  			t.Run(execTest.query, func(t *testing.T) {
  1264  				runTests(t, execTest.query, execTest.tests, func(args ...interface{}) (gosql.Result, error) {
  1265  					return db.Exec(execTest.query, args...)
  1266  				})
  1267  			})
  1268  		}
  1269  	})
  1270  
  1271  	t.Run("prepare", func(t *testing.T) {
  1272  		for _, execTest := range execTests {
  1273  			t.Run(execTest.query, func(t *testing.T) {
  1274  				if testing.Verbose() || log.V(1) {
  1275  					log.Infof(context.Background(), "prepare: %s", execTest.query)
  1276  				}
  1277  				if stmt, err := db.Prepare(execTest.query); err != nil {
  1278  					t.Errorf("%s: prepare error: %s", execTest.query, err)
  1279  				} else {
  1280  					defer stmt.Close()
  1281  
  1282  					runTests(t, execTest.query, execTest.tests, stmt.Exec)
  1283  				}
  1284  			})
  1285  		}
  1286  	})
  1287  }
  1288  
  1289  // Names should be qualified automatically during Prepare when a database name
  1290  // was given in the connection string.
  1291  func TestPGPrepareNameQual(t *testing.T) {
  1292  	defer leaktest.AfterTest(t)()
  1293  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
  1294  	defer s.Stopper().Stop(context.Background())
  1295  
  1296  	pgURL, cleanupFn := sqlutils.PGUrl(t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser))
  1297  	defer cleanupFn()
  1298  
  1299  	db, err := gosql.Open("postgres", pgURL.String())
  1300  	if err != nil {
  1301  		t.Fatal(err)
  1302  	}
  1303  	defer db.Close()
  1304  
  1305  	if _, err := db.Exec(`CREATE DATABASE IF NOT EXISTS testing`); err != nil {
  1306  		t.Fatal(err)
  1307  	}
  1308  
  1309  	pgURL.Path = "/testing"
  1310  	db2, err := gosql.Open("postgres", pgURL.String())
  1311  	if err != nil {
  1312  		t.Fatal(err)
  1313  	}
  1314  	defer db2.Close()
  1315  
  1316  	statements := []string{
  1317  		`CREATE TABLE IF NOT EXISTS f (v INT)`,
  1318  		`INSERT INTO f VALUES (42)`,
  1319  		`SELECT * FROM f`,
  1320  		`DELETE FROM f WHERE v = 42`,
  1321  		`DROP TABLE IF EXISTS f`,
  1322  	}
  1323  
  1324  	for _, stmtString := range statements {
  1325  		if _, err = db2.Exec(stmtString); err != nil {
  1326  			t.Fatal(err)
  1327  		}
  1328  
  1329  		stmt, err := db2.Prepare(stmtString)
  1330  		if err != nil {
  1331  			t.Fatal(err)
  1332  		}
  1333  
  1334  		if _, err = stmt.Exec(); err != nil {
  1335  			t.Fatal(err)
  1336  		}
  1337  	}
  1338  }
  1339  
  1340  // TestPGPrepareInvalidate ensures that changing table schema triggers recompile
  1341  // of a prepared query.
  1342  func TestPGPrepareInvalidate(t *testing.T) {
  1343  	defer leaktest.AfterTest(t)()
  1344  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
  1345  	defer s.Stopper().Stop(context.Background())
  1346  
  1347  	pgURL, cleanupFn := sqlutils.PGUrl(t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser))
  1348  	defer cleanupFn()
  1349  
  1350  	db, err := gosql.Open("postgres", pgURL.String())
  1351  	if err != nil {
  1352  		t.Fatal(err)
  1353  	}
  1354  	defer db.Close()
  1355  
  1356  	testCases := []struct {
  1357  		stmt    string
  1358  		prep    bool
  1359  		numCols int
  1360  	}{
  1361  		{
  1362  			stmt: `CREATE DATABASE IF NOT EXISTS testing`,
  1363  		},
  1364  		{
  1365  			stmt: `CREATE TABLE IF NOT EXISTS ab (a INT PRIMARY KEY, b INT)`,
  1366  		},
  1367  		{
  1368  			stmt:    `INSERT INTO ab (a, b) VALUES (1, 10)`,
  1369  			prep:    true,
  1370  			numCols: 2,
  1371  		},
  1372  		{
  1373  			stmt:    `ALTER TABLE ab ADD COLUMN c INT`,
  1374  			numCols: 3,
  1375  		},
  1376  		{
  1377  			stmt:    `ALTER TABLE ab DROP COLUMN c`,
  1378  			numCols: 2,
  1379  		},
  1380  	}
  1381  
  1382  	var prep *gosql.Stmt
  1383  	for _, tc := range testCases {
  1384  		if _, err = db.Exec(tc.stmt); err != nil {
  1385  			t.Fatal(err)
  1386  		}
  1387  
  1388  		// Create the prepared statement.
  1389  		if tc.prep {
  1390  			if prep, err = db.Prepare(`SELECT * FROM ab WHERE b=10`); err != nil {
  1391  				t.Fatal(err)
  1392  			}
  1393  		}
  1394  
  1395  		if prep != nil {
  1396  			rows, _ := prep.Query()
  1397  			defer rows.Close()
  1398  			cols, _ := rows.Columns()
  1399  			if len(cols) != tc.numCols {
  1400  				t.Fatalf("expected %d cols, got %d cols", tc.numCols, len(cols))
  1401  			}
  1402  		}
  1403  	}
  1404  }
  1405  
  1406  // A DDL should return "CommandComplete", not "EmptyQuery" Response.
  1407  func TestCmdCompleteVsEmptyStatements(t *testing.T) {
  1408  	defer leaktest.AfterTest(t)()
  1409  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
  1410  	defer s.Stopper().Stop(context.Background())
  1411  
  1412  	pgURL, cleanupFn := sqlutils.PGUrl(
  1413  		t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser))
  1414  	defer cleanupFn()
  1415  
  1416  	db, err := gosql.Open("postgres", pgURL.String())
  1417  	if err != nil {
  1418  		t.Fatal(err)
  1419  	}
  1420  	defer db.Close()
  1421  
  1422  	// lib/pq handles the empty query response by returning a nil driver.Result.
  1423  	// Unfortunately gosql.Exec wraps that, nil or not, in a gosql.Result which doesn't
  1424  	// expose the underlying driver.Result.
  1425  	// gosql.Result does however have methods which attempt to dereference the underlying
  1426  	// driver.Result and can thus be used to determine if it is nil.
  1427  	// TODO(dt): This would be prettier and generate better failures with testify/assert's helpers.
  1428  
  1429  	// Result of a DDL (command complete) yields a non-nil underlying driver result.
  1430  	nonempty, err := db.Exec(`CREATE DATABASE IF NOT EXISTS testing`)
  1431  	if err != nil {
  1432  		t.Fatal(err)
  1433  	}
  1434  	_, _ = nonempty.RowsAffected() // should not panic if lib/pq returned a non-nil result.
  1435  
  1436  	empty, err := db.Exec(" ; ; ;")
  1437  	if err != nil {
  1438  		t.Fatal(err)
  1439  	}
  1440  	rows, err := empty.RowsAffected()
  1441  	if rows != 0 {
  1442  		t.Fatalf("expected 0 rows, got %d", rows)
  1443  	}
  1444  	if err == nil {
  1445  		t.Fatal("expected error")
  1446  	}
  1447  }
  1448  
  1449  // Unfortunately lib/pq doesn't expose returned command tags directly, but we can test
  1450  // the methods where it depends on their values (Begin, Commit, RowsAffected for INSERTs).
  1451  func TestPGCommandTags(t *testing.T) {
  1452  	defer leaktest.AfterTest(t)()
  1453  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
  1454  	defer s.Stopper().Stop(context.Background())
  1455  
  1456  	pgURL, cleanupFn := sqlutils.PGUrl(t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser))
  1457  	defer cleanupFn()
  1458  
  1459  	db, err := gosql.Open("postgres", pgURL.String())
  1460  	if err != nil {
  1461  		t.Fatal(err)
  1462  	}
  1463  	defer db.Close()
  1464  	if _, err := db.Exec(`CREATE DATABASE IF NOT EXISTS testing`); err != nil {
  1465  		t.Fatal(err)
  1466  	}
  1467  	if _, err := db.Exec(`CREATE TABLE testing.tags (k INT PRIMARY KEY, v INT)`); err != nil {
  1468  		t.Fatal(err)
  1469  	}
  1470  
  1471  	// Begin will error if the returned tag is not BEGIN.
  1472  	tx, err := db.Begin()
  1473  	if err != nil {
  1474  		t.Fatal(err)
  1475  	}
  1476  
  1477  	// Commit also checks the correct tag is returned.
  1478  	if err := tx.Commit(); err != nil {
  1479  		t.Fatal(err)
  1480  	}
  1481  
  1482  	tx, err = db.Begin()
  1483  	if err != nil {
  1484  		t.Fatal(err)
  1485  	}
  1486  	if _, err := tx.Exec("INSERT INTO testing.tags VALUES (4, 1)"); err != nil {
  1487  		t.Fatal(err)
  1488  	}
  1489  	// Rollback also checks the correct tag is returned.
  1490  	if err := tx.Rollback(); err != nil {
  1491  		t.Fatal(err)
  1492  	}
  1493  
  1494  	tx, err = db.Begin()
  1495  	if err != nil {
  1496  		t.Fatal(err)
  1497  	}
  1498  	// An error will abort the server's transaction.
  1499  	if _, err := tx.Exec("INSERT INTO testing.tags VALUES (4, 1), (4, 1)"); err == nil {
  1500  		t.Fatal("expected an error on duplicate k")
  1501  	}
  1502  	// Rollback, even of an aborted txn, should also return the correct tag.
  1503  	if err := tx.Rollback(); err != nil {
  1504  		t.Fatal(err)
  1505  	}
  1506  
  1507  	// lib/pq has a special-case for INSERT (due to oids), so test insert and update statements.
  1508  	res, err := db.Exec("INSERT INTO testing.tags VALUES (1, 1), (2, 2)")
  1509  	if err != nil {
  1510  		t.Fatal(err)
  1511  	}
  1512  	affected, err := res.RowsAffected()
  1513  	if err != nil {
  1514  		t.Fatal(err)
  1515  	}
  1516  	if affected != 2 {
  1517  		t.Fatal("unexpected number of rows affected:", affected)
  1518  	}
  1519  
  1520  	res, err = db.Exec("INSERT INTO testing.tags VALUES (3, 3)")
  1521  	if err != nil {
  1522  		t.Fatal(err)
  1523  	}
  1524  	affected, err = res.RowsAffected()
  1525  	if err != nil {
  1526  		t.Fatal(err)
  1527  	}
  1528  	if affected != 1 {
  1529  		t.Fatal("unexpected number of rows affected:", affected)
  1530  	}
  1531  
  1532  	res, err = db.Exec("UPDATE testing.tags SET v = 3")
  1533  	if err != nil {
  1534  		t.Fatal(err)
  1535  	}
  1536  	affected, err = res.RowsAffected()
  1537  	if err != nil {
  1538  		t.Fatal(err)
  1539  	}
  1540  	if affected != 3 {
  1541  		t.Fatal("unexpected number of rows affected:", affected)
  1542  	}
  1543  }
  1544  
  1545  // checkSQLNetworkMetrics returns the server's pgwire bytesIn/bytesOut and an
  1546  // error if the bytesIn/bytesOut don't satisfy the given minimums and maximums.
  1547  func checkSQLNetworkMetrics(
  1548  	s serverutils.TestServerInterface, minBytesIn, minBytesOut, maxBytesIn, maxBytesOut int64,
  1549  ) (int64, int64, error) {
  1550  	if err := s.WriteSummaries(); err != nil {
  1551  		return -1, -1, err
  1552  	}
  1553  
  1554  	bytesIn := s.MustGetSQLNetworkCounter(pgwire.MetaBytesIn.Name)
  1555  	bytesOut := s.MustGetSQLNetworkCounter(pgwire.MetaBytesOut.Name)
  1556  	if a, min := bytesIn, minBytesIn; a < min {
  1557  		return bytesIn, bytesOut, errors.Errorf("bytesin %d < expected min %d", a, min)
  1558  	}
  1559  	if a, min := bytesOut, minBytesOut; a < min {
  1560  		return bytesIn, bytesOut, errors.Errorf("bytesout %d < expected min %d", a, min)
  1561  	}
  1562  	if a, max := bytesIn, maxBytesIn; a > max {
  1563  		return bytesIn, bytesOut, errors.Errorf("bytesin %d > expected max %d", a, max)
  1564  	}
  1565  	if a, max := bytesOut, maxBytesOut; a > max {
  1566  		return bytesIn, bytesOut, errors.Errorf("bytesout %d > expected max %d", a, max)
  1567  	}
  1568  	return bytesIn, bytesOut, nil
  1569  }
  1570  
  1571  func TestSQLNetworkMetrics(t *testing.T) {
  1572  	defer leaktest.AfterTest(t)()
  1573  
  1574  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
  1575  	defer s.Stopper().Stop(context.Background())
  1576  
  1577  	// Setup pgwire client.
  1578  	pgURL, cleanupFn := sqlutils.PGUrl(
  1579  		t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser))
  1580  	defer cleanupFn()
  1581  
  1582  	const minbytes = 20
  1583  	const maxbytes = 2 * 1024
  1584  
  1585  	// Make sure we're starting at 0.
  1586  	if _, _, err := checkSQLNetworkMetrics(s, 0, 0, 0, 0); err != nil {
  1587  		t.Fatal(err)
  1588  	}
  1589  
  1590  	// A single query should give us some I/O.
  1591  	if err := trivialQuery(pgURL); err != nil {
  1592  		t.Fatal(err)
  1593  	}
  1594  	bytesIn, bytesOut, err := checkSQLNetworkMetrics(s, minbytes, minbytes, maxbytes, maxbytes)
  1595  	if err != nil {
  1596  		t.Fatal(err)
  1597  	}
  1598  	if err := trivialQuery(pgURL); err != nil {
  1599  		t.Fatal(err)
  1600  	}
  1601  
  1602  	// A second query should give us more I/O.
  1603  	_, _, err = checkSQLNetworkMetrics(s, bytesIn+minbytes, bytesOut+minbytes, maxbytes, maxbytes)
  1604  	if err != nil {
  1605  		t.Fatal(err)
  1606  	}
  1607  
  1608  	// Verify connection counter.
  1609  	expectConns := func(n int) {
  1610  		testutils.SucceedsSoon(t, func() error {
  1611  			if conns := s.MustGetSQLNetworkCounter(pgwire.MetaConns.Name); conns != int64(n) {
  1612  				return errors.Errorf("connections %d != expected %d", conns, n)
  1613  			}
  1614  			return nil
  1615  		})
  1616  	}
  1617  
  1618  	var conns [10]*gosql.DB
  1619  	for i := range conns {
  1620  		var err error
  1621  		if conns[i], err = gosql.Open("postgres", pgURL.String()); err != nil {
  1622  			t.Fatal(err)
  1623  		}
  1624  		defer conns[i].Close()
  1625  
  1626  		rows, err := conns[i].Query("SELECT 1")
  1627  		if err != nil {
  1628  			t.Fatal(err)
  1629  		}
  1630  		rows.Close()
  1631  		expectConns(i + 1)
  1632  	}
  1633  
  1634  	for i := len(conns) - 1; i >= 0; i-- {
  1635  		conns[i].Close()
  1636  		expectConns(i)
  1637  	}
  1638  }
  1639  
  1640  func TestPGWireOverUnixSocket(t *testing.T) {
  1641  	defer leaktest.AfterTest(t)()
  1642  
  1643  	if runtime.GOOS == "windows" {
  1644  		t.Skip("unix sockets not support on windows")
  1645  	}
  1646  
  1647  	// We need a temp directory in which we'll create the unix socket.
  1648  	//
  1649  	// On BSD, binding to a socket is limited to a path length of 104 characters
  1650  	// (including the NUL terminator). In glibc, this limit is 108 characters.
  1651  	//
  1652  	// macOS has a tendency to produce very long temporary directory names, so
  1653  	// we are careful to keep all the constants involved short.
  1654  	tempDir, err := ioutil.TempDir("", "PGSQL")
  1655  	if err != nil {
  1656  		t.Fatal(err)
  1657  	}
  1658  	defer func() { _ = os.RemoveAll(tempDir) }()
  1659  
  1660  	const port = "6"
  1661  
  1662  	socketFile := filepath.Join(tempDir, ".s.PGSQL."+port)
  1663  
  1664  	params := base.TestServerArgs{
  1665  		Insecure:   true,
  1666  		SocketFile: socketFile,
  1667  	}
  1668  	s, _, _ := serverutils.StartServer(t, params)
  1669  	defer s.Stopper().Stop(context.Background())
  1670  
  1671  	// We can't pass socket paths as url.Host to libpq, use ?host=/... instead.
  1672  	options := url.Values{
  1673  		"host": []string{tempDir},
  1674  	}
  1675  	pgURL := url.URL{
  1676  		Scheme:   "postgres",
  1677  		User:     url.User(security.RootUser),
  1678  		Host:     net.JoinHostPort("", port),
  1679  		RawQuery: options.Encode(),
  1680  	}
  1681  	if err := trivialQuery(pgURL); err != nil {
  1682  		t.Fatal(err)
  1683  	}
  1684  }
  1685  
  1686  func TestPGWireResultChange(t *testing.T) {
  1687  	defer leaktest.AfterTest(t)()
  1688  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
  1689  	defer s.Stopper().Stop(context.Background())
  1690  
  1691  	pgURL, cleanupFn := sqlutils.PGUrl(t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser))
  1692  	defer cleanupFn()
  1693  
  1694  	db, err := gosql.Open("postgres", pgURL.String())
  1695  	if err != nil {
  1696  		t.Fatal(err)
  1697  	}
  1698  	defer db.Close()
  1699  	if _, err := db.Exec(`CREATE DATABASE testing`); err != nil {
  1700  		t.Fatal(err)
  1701  	}
  1702  	if _, err := db.Exec(`CREATE TABLE testing.f (v INT)`); err != nil {
  1703  		t.Fatal(err)
  1704  	}
  1705  	stmt, err := db.Prepare(`SELECT * FROM testing.f`)
  1706  	if err != nil {
  1707  		t.Fatal(err)
  1708  	}
  1709  	if _, err := db.Exec(`ALTER TABLE testing.f ADD COLUMN u int`); err != nil {
  1710  		t.Fatal(err)
  1711  	}
  1712  	if _, err := db.Exec(`INSERT INTO testing.f VALUES (1, 2)`); err != nil {
  1713  		t.Fatal(err)
  1714  	}
  1715  	if _, err := stmt.Exec(); !testutils.IsError(err, "must not change result type") {
  1716  		t.Fatalf("unexpected error: %v", err)
  1717  	}
  1718  	if err := stmt.Close(); err != nil {
  1719  		t.Fatal(err)
  1720  	}
  1721  
  1722  	// Test that an INSERT RETURNING will not commit data.
  1723  	stmt, err = db.Prepare(`INSERT INTO testing.f VALUES ($1, $2) RETURNING *`)
  1724  	if err != nil {
  1725  		t.Fatal(err)
  1726  	}
  1727  	if _, err := db.Exec(`ALTER TABLE testing.f ADD COLUMN t int`); err != nil {
  1728  		t.Fatal(err)
  1729  	}
  1730  	var count int
  1731  	if err := db.QueryRow(`SELECT count(*) FROM testing.f`).Scan(&count); err != nil {
  1732  		t.Fatal(err)
  1733  	}
  1734  	if _, err := stmt.Exec(3, 4); !testutils.IsError(err, "must not change result type") {
  1735  		t.Fatalf("unexpected error: %v", err)
  1736  	}
  1737  	if err := stmt.Close(); err != nil {
  1738  		t.Fatal(err)
  1739  	}
  1740  	var countAfter int
  1741  	if err := db.QueryRow(`SELECT count(*) FROM testing.f`).Scan(&countAfter); err != nil {
  1742  		t.Fatal(err)
  1743  	}
  1744  	if count != countAfter {
  1745  		t.Fatalf("expected %d rows, got %d", count, countAfter)
  1746  	}
  1747  }
  1748  
  1749  func TestSessionParameters(t *testing.T) {
  1750  	defer leaktest.AfterTest(t)()
  1751  
  1752  	params := base.TestServerArgs{Insecure: true}
  1753  	s, _, _ := serverutils.StartServer(t, params)
  1754  
  1755  	ctx := context.Background()
  1756  	defer s.Stopper().Stop(ctx)
  1757  
  1758  	host, ports, _ := net.SplitHostPort(s.ServingSQLAddr())
  1759  	port, _ := strconv.Atoi(ports)
  1760  
  1761  	connCfg := pgx.ConnConfig{
  1762  		Host:      host,
  1763  		Port:      uint16(port),
  1764  		User:      security.RootUser,
  1765  		TLSConfig: nil, // insecure
  1766  		Logger:    pgxTestLogger{},
  1767  	}
  1768  
  1769  	testData := []struct {
  1770  		varName        string
  1771  		val            string
  1772  		expectedStatus bool
  1773  		expectedSet    bool
  1774  		expectedErr    string
  1775  	}{
  1776  		// Unknown parameters are tolerated without error (a warning will be logged).
  1777  		{"foo", "bar", false, false, ``},
  1778  		// Known parameters are checked to actually be set, even session vars which
  1779  		// are not valid server status params can be set.
  1780  		{"extra_float_digits", "3", false, true, ``},
  1781  		{"extra_float_digits", "-3", false, true, ``},
  1782  		{"distsql", "off", false, true, ``},
  1783  		{"distsql", "auto", false, true, ``},
  1784  		// Case does not matter to set, but the server will reply with special cased
  1785  		// variables.
  1786  		{"timezone", "Europe/Paris", false, true, ``},
  1787  		{"TimeZone", "Europe/Amsterdam", true, true, ``},
  1788  		{"datestyle", "ISO, MDY", false, true, ``},
  1789  		{"DateStyle", "ISO, MDY", true, true, ``},
  1790  		// Known parameters that definitely cannot be set will cause an error.
  1791  		{"server_version", "bar", false, false, `parameter "server_version" cannot be changed.*55P02`},
  1792  		// Erroneous values are also rejected.
  1793  		{"extra_float_digits", "42", false, false, `42 is outside the valid range for parameter "extra_float_digits".*22023`},
  1794  		{"datestyle", "woo", false, false, `invalid value for parameter "DateStyle".*22023`},
  1795  	}
  1796  
  1797  	for _, test := range testData {
  1798  		t.Run(test.varName+"="+test.val, func(t *testing.T) {
  1799  			cfg := connCfg
  1800  			cfg.RuntimeParams = map[string]string{test.varName: test.val}
  1801  			db, err := pgx.Connect(cfg)
  1802  			t.Logf("conn error: %v", err)
  1803  			if !testutils.IsError(err, test.expectedErr) {
  1804  				t.Fatalf("expected %q, got %v", test.expectedErr, err)
  1805  			}
  1806  			if err != nil {
  1807  				return
  1808  			}
  1809  			defer func() { _ = db.Close() }()
  1810  
  1811  			for k, v := range db.RuntimeParams {
  1812  				t.Logf("received runtime param %s = %q", k, v)
  1813  			}
  1814  
  1815  			// If the session var is also a valid status param, then check
  1816  			// the requested value was processed.
  1817  			if test.expectedStatus {
  1818  				serverVal := db.RuntimeParams[test.varName]
  1819  				if serverVal != test.val {
  1820  					t.Fatalf("initial server status %v: got %q, expected %q",
  1821  						test.varName, serverVal, test.val)
  1822  				}
  1823  			}
  1824  
  1825  			// Check the value also inside the session.
  1826  			rows, err := db.Query("SHOW " + test.varName)
  1827  			if err != nil {
  1828  				// Check that the value was not expected to be settable.
  1829  				// (The set was ignored).
  1830  				if !test.expectedSet && strings.Contains(err.Error(), "unrecognized configuration parameter") {
  1831  					return
  1832  				}
  1833  				t.Fatal(err)
  1834  			}
  1835  			// Check that the value set was the value sent by the client.
  1836  			if !rows.Next() {
  1837  				t.Fatal("too short")
  1838  			}
  1839  			if err := rows.Err(); err != nil {
  1840  				t.Fatal(err)
  1841  			}
  1842  			var gotVal string
  1843  			if err := rows.Scan(&gotVal); err != nil {
  1844  				t.Fatal(err)
  1845  			}
  1846  			if rows.Next() {
  1847  				_ = rows.Scan(&gotVal)
  1848  				t.Fatalf("expected no more rows, got %v", gotVal)
  1849  			}
  1850  			t.Logf("server says %s = %q", test.varName, gotVal)
  1851  			if gotVal != test.val {
  1852  				t.Fatalf("expected %q, got %q", test.val, gotVal)
  1853  			}
  1854  		})
  1855  	}
  1856  }
  1857  
  1858  type pgxTestLogger struct{}
  1859  
  1860  func (l pgxTestLogger) Log(level pgx.LogLevel, msg string, data map[string]interface{}) {
  1861  	log.Infof(context.Background(), "pgx log [%s] %s - %s", level, msg, data)
  1862  }
  1863  
  1864  // pgxTestLogger implements pgx.Logger.
  1865  var _ pgx.Logger = pgxTestLogger{}
  1866  
  1867  func TestCancelRequest(t *testing.T) {
  1868  	defer leaktest.AfterTest(t)()
  1869  
  1870  	testutils.RunTrueAndFalse(t, "insecure", func(t *testing.T, insecure bool) {
  1871  		params := base.TestServerArgs{Insecure: insecure}
  1872  		s, _, _ := serverutils.StartServer(t, params)
  1873  
  1874  		ctx := context.Background()
  1875  		defer s.Stopper().Stop(ctx)
  1876  
  1877  		var d net.Dialer
  1878  		conn, err := d.DialContext(ctx, "tcp", s.ServingSQLAddr())
  1879  		if err != nil {
  1880  			t.Fatal(err)
  1881  		}
  1882  		defer conn.Close()
  1883  
  1884  		// Reset telemetry so we get a deterministic count below.
  1885  		_ = telemetry.GetFeatureCounts(telemetry.Raw, telemetry.ResetCounts)
  1886  
  1887  		fe, err := pgproto3.NewFrontend(conn, conn)
  1888  		if err != nil {
  1889  			t.Fatal(err)
  1890  		}
  1891  		// versionCancel is the special code sent as header for cancel requests.
  1892  		// See: https://www.postgresql.org/docs/current/protocol-message-formats.html
  1893  		// and the explanation in server.go.
  1894  		const versionCancel = 80877102
  1895  		if err := fe.Send(&pgproto3.StartupMessage{ProtocolVersion: versionCancel}); err != nil {
  1896  			t.Fatal(err)
  1897  		}
  1898  		if _, err := fe.Receive(); err != io.EOF {
  1899  			t.Fatalf("unexpected: %v", err)
  1900  		}
  1901  		if count := telemetry.GetRawFeatureCounts()["pgwire.unimplemented.cancel_request"]; count != 1 {
  1902  			t.Fatalf("expected 1 cancel request, got %d", count)
  1903  		}
  1904  	})
  1905  }
  1906  
  1907  func TestFailPrepareFailsTxn(t *testing.T) {
  1908  	defer leaktest.AfterTest(t)()
  1909  
  1910  	s, _, _ := serverutils.StartServer(t, base.TestServerArgs{})
  1911  	defer s.Stopper().Stop(context.Background())
  1912  
  1913  	pgURL, cleanupFn := sqlutils.PGUrl(t, s.ServingSQLAddr(), t.Name(), url.User(security.RootUser))
  1914  	defer cleanupFn()
  1915  
  1916  	db, err := gosql.Open("postgres", pgURL.String())
  1917  	if err != nil {
  1918  		t.Fatal(err)
  1919  	}
  1920  	defer db.Close()
  1921  
  1922  	tx, err := db.Begin()
  1923  	if err != nil {
  1924  		t.Fatal(err)
  1925  	}
  1926  	if _, err := tx.Prepare("select fail"); err == nil {
  1927  		t.Fatal("Got no error, expected one")
  1928  	}
  1929  
  1930  	// This should also fail, since the txn should be destroyed.
  1931  	if _, err := tx.Query("select 1"); err == nil {
  1932  		t.Fatal("got no error, expected one")
  1933  	}
  1934  	if err := tx.Rollback(); err != nil {
  1935  		t.Fatal(err)
  1936  	}
  1937  }