github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cli/sql_util.go (about)

     1  // Copyright 2015 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package cli
    12  
    13  import (
    14  	"context"
    15  	"database/sql/driver"
    16  	"fmt"
    17  	"io"
    18  	"net/url"
    19  	"reflect"
    20  	"strconv"
    21  	"strings"
    22  	"time"
    23  	"unicode"
    24  	"unicode/utf8"
    25  
    26  	"github.com/cockroachdb/cockroach-go/crdb"
    27  	"github.com/cockroachdb/cockroach/pkg/build"
    28  	"github.com/cockroachdb/cockroach/pkg/roachpb"
    29  	"github.com/cockroachdb/cockroach/pkg/security"
    30  	"github.com/cockroachdb/cockroach/pkg/sql/lex"
    31  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    32  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    33  	"github.com/cockroachdb/cockroach/pkg/util/envutil"
    34  	"github.com/cockroachdb/cockroach/pkg/util/log"
    35  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    36  	"github.com/cockroachdb/cockroach/pkg/util/version"
    37  	"github.com/cockroachdb/errors"
    38  	"github.com/lib/pq"
    39  )
    40  
// sqlConnI is the set of driver-level capabilities that sqlConn needs
// from the underlying connection: a plain driver.Conn that can also
// execute and query statements directly. The connection returned by
// the lib/pq connector is asserted to this interface in ensureConn.
type sqlConnI interface {
	driver.Conn
	//lint:ignore SA1019 TODO(mjibson): clean this up to use go1.8 APIs
	driver.Execer
	//lint:ignore SA1019 TODO(mjibson): clean this up to use go1.8 APIs
	driver.Queryer
}
    48  
// sqlConn wraps a lazily-opened lib/pq connection together with the
// client-side state needed to reconnect and to report server
// metadata and notices.
type sqlConn struct {
	// url is the connection URL handed to pq.NewConnector.
	url string
	// conn is the underlying driver connection. It is nil until
	// ensureConn opens it, and is reset to nil by Close.
	conn sqlConnI
	// reconnecting is set when a driver.ErrBadConn was observed; it
	// makes the next ensureConn warn the user and try to restore the
	// current database. It is cleared once ensureConn succeeds.
	reconnecting bool

	// pendingNotices holds the notices received from the server that
	// have not been printed yet. flushNotices prints and clears them.
	pendingNotices []*pq.Error

	// delayNotices, if set, makes notices accumulate for printing
	// when the SQL execution completes. The default (false)
	// indicates that notices must be printed as soon as they are received.
	// This is used by the Query() interface to avoid interleaving
	// notices with result rows.
	delayNotices bool

	// dbName is the last known current database, to be reconfigured in
	// case of automatic reconnects.
	dbName string

	serverVersion string // build.Info.Tag (short version, like 1.0.3)
	serverBuild   string // build.Info.Short (version, platform, etc summary)

	// clusterID and clusterOrganization are the last known
	// corresponding values from the server. clusterID is used to
	// report any change upon (re)connects.
	clusterID           string
	clusterOrganization string
}
    76  
    77  // initialSQLConnectionError signals to the error decorator in
    78  // error.go that we're failing during the initial connection set-up.
    79  type initialSQLConnectionError struct {
    80  	err error
    81  }
    82  
    83  // Error implements the error interface.
    84  func (i *initialSQLConnectionError) Error() string { return i.err.Error() }
    85  
    86  // Cause implements causer.
    87  func (i *initialSQLConnectionError) Cause() error { return i.err }
    88  
    89  // Format implements fmt.Formatter.
    90  func (i *initialSQLConnectionError) Format(s fmt.State, verb rune) { errors.FormatError(i, s, verb) }
    91  
    92  // FormatError implements errors.Formatter.
    93  func (i *initialSQLConnectionError) FormatError(p errors.Printer) error {
    94  	if p.Detail() {
    95  		p.Print("error while establishing the SQL session")
    96  	}
    97  	return i.err
    98  }
    99  
   100  // wrapConnError detects TCP EOF errors during the initial SQL handshake.
   101  // These are translated to a message "perhaps this is not a CockroachDB node"
   102  // at the top level.
   103  // EOF errors later in the SQL session should not be wrapped in that way,
   104  // because by that time we've established that the server is indeed a SQL
   105  // server.
   106  func wrapConnError(err error) error {
   107  	errMsg := err.Error()
   108  	if errMsg == "EOF" || errMsg == "unexpected EOF" {
   109  		return &initialSQLConnectionError{err}
   110  	}
   111  	return err
   112  }
   113  
   114  func (c *sqlConn) flushNotices() {
   115  	for _, notice := range c.pendingNotices {
   116  		cliOutputError(stderr, notice, true /*showSeverity*/, false /*verbose*/)
   117  	}
   118  	c.pendingNotices = nil
   119  	c.delayNotices = false
   120  }
   121  
   122  func (c *sqlConn) handleNotice(notice *pq.Error) {
   123  	c.pendingNotices = append(c.pendingNotices, notice)
   124  	if !c.delayNotices {
   125  		c.flushNotices()
   126  	}
   127  }
   128  
   129  func (c *sqlConn) ensureConn() error {
   130  	if c.conn == nil {
   131  		if c.reconnecting && cliCtx.isInteractive {
   132  			fmt.Fprintf(stderr, "warning: connection lost!\n"+
   133  				"opening new connection: all session settings will be lost\n")
   134  		}
   135  		base, err := pq.NewConnector(c.url)
   136  		if err != nil {
   137  			return wrapConnError(err)
   138  		}
   139  		// Add a notice handler - re-use the cliOutputError function in this case.
   140  		connector := pq.ConnectorWithNoticeHandler(base, func(notice *pq.Error) {
   141  			c.handleNotice(notice)
   142  		})
   143  		// TODO(cli): we can't thread ctx through ensureConn usages, as it needs
   144  		// to follow the gosql.DB interface. We should probably look at initializing
   145  		// connections only once instead. The context is only used for dialing.
   146  		conn, err := connector.Connect(context.TODO())
   147  		if err != nil {
   148  			return wrapConnError(err)
   149  		}
   150  		if c.reconnecting && c.dbName != "" {
   151  			// Attempt to reset the current database.
   152  			if _, err := conn.(sqlConnI).Exec(
   153  				`SET DATABASE = `+tree.NameStringP(&c.dbName), nil,
   154  			); err != nil {
   155  				fmt.Fprintf(stderr, "warning: unable to restore current database: %v\n", err)
   156  			}
   157  		}
   158  		c.conn = conn.(sqlConnI)
   159  		if err := c.checkServerMetadata(); err != nil {
   160  			c.Close()
   161  			return wrapConnError(err)
   162  		}
   163  		c.reconnecting = false
   164  	}
   165  	return nil
   166  }
   167  
   168  func (c *sqlConn) getServerMetadata() (
   169  	nodeID roachpb.NodeID,
   170  	version, clusterID string,
   171  	err error,
   172  ) {
   173  	// Retrieve the node ID and server build info.
   174  	rows, err := c.Query("SELECT * FROM crdb_internal.node_build_info", nil)
   175  	if errors.Is(err, driver.ErrBadConn) {
   176  		return 0, "", "", err
   177  	}
   178  	if err != nil {
   179  		return 0, "", "", err
   180  	}
   181  	defer func() { _ = rows.Close() }()
   182  
   183  	// Read the node_build_info table as an array of strings.
   184  	rowVals, err := getAllRowStrings(rows, true /* showMoreChars */)
   185  	if err != nil || len(rowVals) == 0 || len(rowVals[0]) != 3 {
   186  		return 0, "", "", errors.New("incorrect data while retrieving the server version")
   187  	}
   188  
   189  	// Extract the version fields from the query results.
   190  	var v10fields [5]string
   191  	for _, row := range rowVals {
   192  		switch row[1] {
   193  		case "ClusterID":
   194  			clusterID = row[2]
   195  		case "Version":
   196  			version = row[2]
   197  		case "Build":
   198  			c.serverBuild = row[2]
   199  		case "Organization":
   200  			c.clusterOrganization = row[2]
   201  			id, err := strconv.Atoi(row[0])
   202  			if err != nil {
   203  				return 0, "", "", errors.New("incorrect data while retrieving node id")
   204  			}
   205  			nodeID = roachpb.NodeID(id)
   206  
   207  			// Fields for v1.0 compatibility.
   208  		case "Distribution":
   209  			v10fields[0] = row[2]
   210  		case "Tag":
   211  			v10fields[1] = row[2]
   212  		case "Platform":
   213  			v10fields[2] = row[2]
   214  		case "Time":
   215  			v10fields[3] = row[2]
   216  		case "GoVersion":
   217  			v10fields[4] = row[2]
   218  		}
   219  	}
   220  
   221  	if version == "" {
   222  		// The "Version" field was not present, this indicates a v1.0
   223  		// CockroachDB. Use that below.
   224  		version = "v1.0-" + v10fields[1]
   225  		c.serverBuild = fmt.Sprintf("CockroachDB %s %s (%s, built %s, %s)",
   226  			v10fields[0], version, v10fields[2], v10fields[3], v10fields[4])
   227  	}
   228  	return nodeID, version, clusterID, nil
   229  }
   230  
   231  // checkServerMetadata reports the server version and cluster ID
   232  // upon the initial connection or if either has changed since
   233  // the last connection, based on the last known values in the sqlConn
   234  // struct.
   235  func (c *sqlConn) checkServerMetadata() error {
   236  	if !cliCtx.isInteractive {
   237  		// Version reporting is just noise if the user is not present to
   238  		// change their mind upon seeing the information.
   239  		return nil
   240  	}
   241  
   242  	_, newServerVersion, newClusterID, err := c.getServerMetadata()
   243  	if errors.Is(err, driver.ErrBadConn) {
   244  		return err
   245  	}
   246  	if err != nil {
   247  		// It is not an error that the server version cannot be retrieved.
   248  		fmt.Fprintf(stderr, "warning: unable to retrieve the server's version: %s\n", err)
   249  	}
   250  
   251  	// Report the server version only if it the revision has been
   252  	// fetched successfully, and the revision has changed since the last
   253  	// connection.
   254  	if newServerVersion != c.serverVersion {
   255  		c.serverVersion = newServerVersion
   256  
   257  		isSame := ""
   258  		// We compare just the version (`build.Info.Tag`), whereas we *display* the
   259  		// the full build summary (version, platform, etc) string
   260  		// (`build.Info.Short()`). This is because we don't care if they're
   261  		// different platforms/build tools/timestamps. The important bit exposed by
   262  		// a version mismatch is the wire protocol and SQL dialect.
   263  		client := build.GetInfo()
   264  		if c.serverVersion != client.Tag {
   265  			fmt.Println("# Client version:", client.Short())
   266  		} else {
   267  			isSame = " (same version as client)"
   268  		}
   269  		fmt.Printf("# Server version: %s%s\n", c.serverBuild, isSame)
   270  
   271  		sv, err := version.Parse(c.serverVersion)
   272  		if err == nil {
   273  			cv, err := version.Parse(client.Tag)
   274  			if err == nil {
   275  				if sv.Compare(cv) == -1 { // server ver < client ver
   276  					fmt.Fprintln(stderr, "\nwarning: server version older than client! "+
   277  						"proceed with caution; some features may not be available.\n")
   278  				}
   279  			}
   280  		}
   281  	}
   282  
   283  	// Report the cluster ID only if it it could be fetched
   284  	// successfully, and it has changed since the last connection.
   285  	if old := c.clusterID; newClusterID != c.clusterID {
   286  		c.clusterID = newClusterID
   287  		if old != "" {
   288  			return errors.Errorf("the cluster ID has changed!\nPrevious ID: %s\nNew ID: %s",
   289  				old, newClusterID)
   290  		}
   291  		c.clusterID = newClusterID
   292  		fmt.Println("# Cluster ID:", c.clusterID)
   293  		if c.clusterOrganization != "" {
   294  			fmt.Println("# Organization:", c.clusterOrganization)
   295  		}
   296  	}
   297  
   298  	return nil
   299  }
   300  
   301  // requireServerVersion returns an error if the version of the connected server
   302  // is not at least the given version.
   303  func (c *sqlConn) requireServerVersion(required *version.Version) error {
   304  	_, versionString, _, err := c.getServerMetadata()
   305  	if err != nil {
   306  		return err
   307  	}
   308  	vers, err := version.Parse(versionString)
   309  	if err != nil {
   310  		return fmt.Errorf("unable to parse server version %q", versionString)
   311  	}
   312  	if !vers.AtLeast(required) {
   313  		return fmt.Errorf("incompatible client and server versions (detected server version: %s, required: %s)",
   314  			vers, required)
   315  	}
   316  	return nil
   317  }
   318  
   319  // getServerValue retrieves the first driverValue returned by the
   320  // given sql query. If the query fails or does not return a single
   321  // column, `false` is returned in the second result.
   322  func (c *sqlConn) getServerValue(what, sql string) (driver.Value, bool) {
   323  	var dbVals [1]driver.Value
   324  
   325  	rows, err := c.Query(sql, nil)
   326  	if err != nil {
   327  		fmt.Fprintf(stderr, "warning: error retrieving the %s: %v\n", what, err)
   328  		return nil, false
   329  	}
   330  	defer func() { _ = rows.Close() }()
   331  
   332  	if len(rows.Columns()) == 0 {
   333  		fmt.Fprintf(stderr, "warning: cannot get the %s\n", what)
   334  		return nil, false
   335  	}
   336  
   337  	err = rows.Next(dbVals[:])
   338  	if err != nil {
   339  		fmt.Fprintf(stderr, "warning: invalid %s: %v\n", what, err)
   340  		return nil, false
   341  	}
   342  
   343  	return dbVals[0], true
   344  }
   345  
// sqlTxnShim implements the crdb.Tx interface.
//
// It exists to support crdb.ExecuteInTx. Normally, we'd hand crdb.ExecuteInTx
// a sql.Txn, but sqlConn predates go1.8's support for multiple result sets and
// so deals directly with the lib/pq driver. See #14964.
type sqlTxnShim struct {
	// conn is the connection on which the transaction statements are
	// executed.
	conn *sqlConn
}

// Compile-time check that sqlTxnShim satisfies crdb.Tx.
var _ crdb.Tx = sqlTxnShim{}
   356  
   357  func (t sqlTxnShim) Commit(context.Context) error {
   358  	return t.conn.Exec(`COMMIT`, nil)
   359  }
   360  
   361  func (t sqlTxnShim) Rollback(context.Context) error {
   362  	return t.conn.Exec(`ROLLBACK`, nil)
   363  }
   364  
   365  func (t sqlTxnShim) Exec(_ context.Context, query string, values ...interface{}) error {
   366  	if len(values) != 0 {
   367  		panic(fmt.Sprintf("sqlTxnShim.ExecContext must not be called with values"))
   368  	}
   369  	return t.conn.Exec(query, nil)
   370  }
   371  
   372  // ExecTxn runs fn inside a transaction and retries it as needed.
   373  // On non-retryable failures, the transaction is aborted and rolled
   374  // back; on success, the transaction is committed.
   375  //
   376  // NOTE: the supplied closure should not have external side
   377  // effects beyond changes to the database.
   378  func (c *sqlConn) ExecTxn(fn func(*sqlConn) error) (err error) {
   379  	if err := c.Exec(`BEGIN`, nil); err != nil {
   380  		return err
   381  	}
   382  	return crdb.ExecuteInTx(context.TODO(), sqlTxnShim{c}, func() error {
   383  		return fn(c)
   384  	})
   385  }
   386  
   387  func (c *sqlConn) Exec(query string, args []driver.Value) error {
   388  	if err := c.ensureConn(); err != nil {
   389  		return err
   390  	}
   391  	if sqlCtx.echo {
   392  		fmt.Fprintln(stderr, ">", query)
   393  	}
   394  	_, err := c.conn.Exec(query, args)
   395  	c.flushNotices()
   396  	if errors.Is(err, driver.ErrBadConn) {
   397  		c.reconnecting = true
   398  		c.Close()
   399  	}
   400  	return err
   401  }
   402  
   403  func (c *sqlConn) Query(query string, args []driver.Value) (*sqlRows, error) {
   404  	if err := c.ensureConn(); err != nil {
   405  		return nil, err
   406  	}
   407  	if sqlCtx.echo {
   408  		fmt.Fprintln(stderr, ">", query)
   409  	}
   410  	rows, err := c.conn.Query(query, args)
   411  	if errors.Is(err, driver.ErrBadConn) {
   412  		c.reconnecting = true
   413  		c.Close()
   414  	}
   415  	if err != nil {
   416  		return nil, err
   417  	}
   418  	return &sqlRows{rows: rows.(sqlRowsI), conn: c}, nil
   419  }
   420  
   421  func (c *sqlConn) QueryRow(query string, args []driver.Value) ([]driver.Value, error) {
   422  	rows, err := makeQuery(query, args...)(c)
   423  	if err != nil {
   424  		return nil, err
   425  	}
   426  	defer func() { _ = rows.Close() }()
   427  	vals := make([]driver.Value, len(rows.Columns()))
   428  	err = rows.Next(vals)
   429  
   430  	// Assert that there is just one row.
   431  	if err == nil {
   432  		nextVals := make([]driver.Value, len(rows.Columns()))
   433  		nextErr := rows.Next(nextVals)
   434  		if nextErr != io.EOF {
   435  			if nextErr != nil {
   436  				return nil, err
   437  			}
   438  			return nil, fmt.Errorf("programming error: %q: expected just 1 row of result, got more", query)
   439  		}
   440  	}
   441  
   442  	return vals, err
   443  }
   444  
   445  func (c *sqlConn) Close() {
   446  	c.flushNotices()
   447  	if c.conn != nil {
   448  		err := c.conn.Close()
   449  		if err != nil && !errors.Is(err, driver.ErrBadConn) {
   450  			log.Infof(context.TODO(), "%v", err)
   451  		}
   452  		c.conn = nil
   453  	}
   454  }
   455  
// sqlRowsI is the set of capabilities that sqlRows needs from the
// rows object returned by the driver. The result of the driver's
// Query call is asserted to this interface in sqlConn.Query.
type sqlRowsI interface {
	driver.RowsColumnTypeScanType
	// Result returns the driver result associated with the rows.
	Result() driver.Result
	// Tag returns the statement tag (e.g. "INSERT", "SELECT").
	Tag() string

	// Go 1.8 multiple result set interfaces.
	// TODO(mjibson): clean this up after 1.8 is released.
	HasNextResultSet() bool
	NextResultSet() error
}
   466  
// sqlRows wraps the driver-level rows of a query together with the
// sqlConn that produced them, so that notices can be flushed and a
// broken connection can be flagged while the rows are consumed.
type sqlRows struct {
	// rows is the underlying driver rows object.
	rows sqlRowsI
	// conn is the connection on which the query was issued.
	conn *sqlConn
}
   471  
   472  func (r *sqlRows) Columns() []string {
   473  	return r.rows.Columns()
   474  }
   475  
   476  func (r *sqlRows) Result() driver.Result {
   477  	return r.rows.Result()
   478  }
   479  
   480  func (r *sqlRows) Tag() string {
   481  	return r.rows.Tag()
   482  }
   483  
   484  func (r *sqlRows) Close() error {
   485  	r.conn.flushNotices()
   486  	err := r.rows.Close()
   487  	if errors.Is(err, driver.ErrBadConn) {
   488  		r.conn.reconnecting = true
   489  		r.conn.Close()
   490  	}
   491  	return err
   492  }
   493  
   494  // Next populates values with the next row of results. []byte values are copied
   495  // so that subsequent calls to Next and Close do not mutate values. This
   496  // makes it slower than theoretically possible but the safety concerns
   497  // (since this is unobvious and unexpected behavior) outweigh.
   498  func (r *sqlRows) Next(values []driver.Value) error {
   499  	err := r.rows.Next(values)
   500  	if errors.Is(err, driver.ErrBadConn) {
   501  		r.conn.reconnecting = true
   502  		r.conn.Close()
   503  	}
   504  	for i, v := range values {
   505  		if b, ok := v.([]byte); ok {
   506  			values[i] = append([]byte{}, b...)
   507  		}
   508  	}
   509  	// After the first row was received, we want to delay all
   510  	// further notices until the end of execution.
   511  	r.conn.delayNotices = true
   512  	return err
   513  }
   514  
   515  // NextResultSet prepares the next result set for reading.
   516  func (r *sqlRows) NextResultSet() (bool, error) {
   517  	if !r.rows.HasNextResultSet() {
   518  		return false, nil
   519  	}
   520  	return true, r.rows.NextResultSet()
   521  }
   522  
   523  func (r *sqlRows) ColumnTypeScanType(index int) reflect.Type {
   524  	return r.rows.ColumnTypeScanType(index)
   525  }
   526  
   527  func makeSQLConn(url string) *sqlConn {
   528  	return &sqlConn{
   529  		url: url,
   530  	}
   531  }
   532  
// sqlConnTimeout is the default SQL connect timeout. This can also be
// set using `connect_timeout` in the connection URL. The default of
// 15 seconds is chosen to exceed the default password retrieval
// timeout (system.user_login.timeout).
//
// It is kept as a string because it is inserted verbatim as the
// `connect_timeout` URL option. It can be overridden with the
// COCKROACH_CONNECT_TIMEOUT environment variable.
var sqlConnTimeout = envutil.EnvOrDefaultString("COCKROACH_CONNECT_TIMEOUT", "15")
   538  
// defaultSQLDb describes how a missing database part in the SQL
// connection string is processed when creating a client connection.
// It is the type of the defaultMode argument of makeSQLClient.
type defaultSQLDb int

const (
	// useSystemDb means that a missing database will be overridden with
	// "system".
	useSystemDb defaultSQLDb = iota
	// useDefaultDb means that a missing database will be left as-is so
	// that the server can default to "defaultdb".
	useDefaultDb
)
   551  
   552  // makeSQLClient connects to the database using the connection
   553  // settings set by the command-line flags.
   554  // If a password is needed, it also prompts for the password.
   555  //
   556  // If forceSystemDB is set, it also connects it to the `system`
   557  // database. The --database flag or database part in the URL is then
   558  // ignored.
   559  //
   560  // The appName given as argument is added to the URL even if --url is
   561  // specified, but only if the URL didn't already specify
   562  // application_name. It is prefixed with '$ ' to mark it as internal.
   563  func makeSQLClient(appName string, defaultMode defaultSQLDb) (*sqlConn, error) {
   564  	baseURL, err := cliCtx.makeClientConnURL()
   565  	if err != nil {
   566  		return nil, err
   567  	}
   568  
   569  	if defaultMode == useSystemDb && baseURL.Path == "" {
   570  		// Override the target database. This is because the current
   571  		// database can influence the output of CLI commands, and in the
   572  		// case where the database is missing it will default server-wise to
   573  		// `defaultdb` which may not exist.
   574  		baseURL.Path = "system"
   575  	}
   576  
   577  	// If there is no user in the URL already, fill in the default user.
   578  	if baseURL.User.Username() == "" {
   579  		baseURL.User = url.User(security.RootUser)
   580  	}
   581  
   582  	options, err := url.ParseQuery(baseURL.RawQuery)
   583  	if err != nil {
   584  		return nil, err
   585  	}
   586  
   587  	// Insecure connections are insecure and should never see a password. Reject
   588  	// one that may be present in the URL already.
   589  	if options.Get("sslmode") == "disable" {
   590  		if _, pwdSet := baseURL.User.Password(); pwdSet {
   591  			return nil, errors.Errorf("cannot specify a password in URL with an insecure connection")
   592  		}
   593  	} else {
   594  		if options.Get("sslcert") == "" || options.Get("sslkey") == "" {
   595  			// If there's no password in the URL yet and we don't have a client
   596  			// certificate, ask for it and populate it in the URL.
   597  			if _, pwdSet := baseURL.User.Password(); !pwdSet {
   598  				pwd, err := security.PromptForPassword()
   599  				if err != nil {
   600  					return nil, err
   601  				}
   602  				baseURL.User = url.UserPassword(baseURL.User.Username(), pwd)
   603  			}
   604  		}
   605  	}
   606  
   607  	// Load the application name. It's not a command-line flag, so
   608  	// anything already in the URL should take priority.
   609  	if options.Get("application_name") == "" && appName != "" {
   610  		options.Set("application_name", sqlbase.ReportableAppNamePrefix+appName)
   611  	}
   612  
   613  	// Set a connection timeout if none is provided already. This
   614  	// ensures that if the server was not initialized or there is some
   615  	// network issue, the client will not be left to hang forever.
   616  	//
   617  	// This is a lib/pq feature.
   618  	if options.Get("connect_timeout") == "" {
   619  		options.Set("connect_timeout", sqlConnTimeout)
   620  	}
   621  
   622  	baseURL.RawQuery = options.Encode()
   623  	sqlURL := baseURL.String()
   624  
   625  	if log.V(2) {
   626  		log.Infof(context.Background(), "connecting with URL: %s", sqlURL)
   627  	}
   628  
   629  	return makeSQLConn(sqlURL), nil
   630  }
   631  
// queryFunc is a deferred query: given a connection, it runs the query
// it was built for and returns the resulting rows. makeQuery builds
// such functions.
type queryFunc func(conn *sqlConn) (*sqlRows, error)
   633  
   634  func makeQuery(query string, parameters ...driver.Value) queryFunc {
   635  	return func(conn *sqlConn) (*sqlRows, error) {
   636  		// driver.Value is an alias for interface{}, but must adhere to a restricted
   637  		// set of types when being passed to driver.Queryer.Query (see
   638  		// driver.IsValue). We use driver.DefaultParameterConverter to perform the
   639  		// necessary conversion. This is usually taken care of by the sql package,
   640  		// but we have to do so manually because we're talking directly to the
   641  		// driver.
   642  		for i := range parameters {
   643  			var err error
   644  			parameters[i], err = driver.DefaultParameterConverter.ConvertValue(parameters[i])
   645  			if err != nil {
   646  				return nil, err
   647  			}
   648  		}
   649  		return conn.Query(query, parameters)
   650  	}
   651  }
   652  
   653  // runQuery takes a 'query' with optional 'parameters'.
   654  // It runs the sql query and returns a list of columns names and a list of rows.
   655  func runQuery(conn *sqlConn, fn queryFunc, showMoreChars bool) ([]string, [][]string, error) {
   656  	rows, err := fn(conn)
   657  	if err != nil {
   658  		return nil, nil, err
   659  	}
   660  
   661  	defer func() { _ = rows.Close() }()
   662  	return sqlRowsToStrings(rows, showMoreChars)
   663  }
   664  
   665  // handleCopyError ensures the user is properly informed when they issue
   666  // a COPY statement somewhere in their input.
   667  func handleCopyError(conn *sqlConn, err error) error {
   668  	if !strings.HasPrefix(err.Error(), "pq: unknown response for simple query: 'G'") {
   669  		return err
   670  	}
   671  
   672  	// The COPY statement has hosed the connection by putting the
   673  	// protocol in a state that lib/pq cannot understand any more. Reset
   674  	// it.
   675  	conn.Close()
   676  	conn.reconnecting = true
   677  	return errors.New("woops! COPY has confused this client! Suggestion: use 'psql' for COPY")
   678  }
   679  
// tagsWithRowsAffected is the set of statement tags for which the
// RowsAffected value should be reported to the user. It is consulted
// by runQueryAndFormatResults when a statement produced no rows.
var tagsWithRowsAffected = map[string]struct{}{
	"INSERT":    {},
	"UPDATE":    {},
	"DELETE":    {},
	"DROP USER": {},
	// This one is used with e.g. CREATE TABLE AS (other SELECT
	// statements have type Rows, not RowsAffected).
	"SELECT": {},
}
   691  
// runQueryAndFormatResults runs the query described by fn on conn and
// writes the formatted output to 'w'.
//
// Every result set produced by the query is rendered in turn. For
// statements that yield no rows, either the statement tag, the tag
// followed by the affected row count, or "OK" is printed instead.
// When sqlCtx.showTimes is set, the time of each result is printed
// after it. An error from running the query goes through
// handleCopyError; the rows object is closed before returning.
func runQueryAndFormatResults(conn *sqlConn, w io.Writer, fn queryFunc) error {
	startTime := timeutil.Now()
	rows, err := fn(conn)
	if err != nil {
		return handleCopyError(conn, err)
	}
	defer func() {
		_ = rows.Close()
	}()
	// One iteration per result set.
	for {
		// lib/pq is not able to tell us before the first call to Next()
		// whether a statement returns either
		// - a rows result set with zero rows (e.g. SELECT on an empty table), or
		// - no rows result set, but a valid value for RowsAffected (e.g. INSERT), or
		// - doesn't return any rows whatsoever (e.g. SET).
		//
		// To distinguish them we must go through Next() somehow, which is what the
		// render() function does. So we ask render() to call this noRowsHook
		// when Next() has completed its work and no rows were observed, to decide
		// what to do.
		//
		// The hook returns true when it has printed the outcome itself,
		// and false when the reporter should handle the empty result.
		noRowsHook := func() (bool, error) {
			res := rows.Result()
			if ra, ok := res.(driver.RowsAffected); ok {
				nRows, err := ra.RowsAffected()
				if err != nil {
					return false, err
				}

				// This may be either something like INSERT with a valid
				// RowsAffected value, or a statement like SET. The pq driver
				// uses both driver.RowsAffected for both.  So we need to be a
				// little more manual.
				tag := rows.Tag()
				if tag == "SELECT" && nRows == 0 {
					// As explained above, the pq driver unhelpfully does not
					// distinguish between a statement returning zero rows and a
					// statement returning an affected row count of zero.
					// noRowsHook is called non-discriminatingly for both
					// situations.
					//
					// TODO(knz): meanwhile, there are rare, non-SELECT
					// statements that have tag "SELECT" but are legitimately of
					// type RowsAffected. CREATE TABLE AS is one. pq's inability
					// to distinguish those two cases means that any non-SELECT
					// statement that legitimately returns 0 rows affected, and
					// for which the user would expect to see "SELECT 0", will
					// be incorrectly displayed as an empty row result set
					// instead. This needs to be addressed by ensuring pq can
					// distinguish the two cases, or switching to an entirely
					// different driver altogether.
					//
					return false, nil
				} else if _, ok := tagsWithRowsAffected[tag]; ok {
					// INSERT, DELETE, etc.: print the row count.
					// NOTE(review): the row count was already fetched above;
					// this second call looks redundant.
					nRows, err := ra.RowsAffected()
					if err != nil {
						return false, err
					}
					fmt.Fprintf(w, "%s %d\n", tag, nRows)
				} else {
					// SET, etc.: just print the tag, or OK if there's no tag.
					if tag == "" {
						tag = "OK"
					}
					fmt.Fprintln(w, tag)
				}
				return true, nil
			}
			// Other cases: this is a statement with a rows result set, but
			// zero rows (e.g. SELECT on empty table). Let the reporter
			// handle it.
			return false, nil
		}

		cols := getColumnStrings(rows, true)
		reporter, cleanup, err := makeReporter(w)
		if err != nil {
			return err
		}

		// completedHook records the time at which it gets called by
		// render(); that time is used for the report below.
		var queryCompleteTime time.Time
		completedHook := func() { queryCompleteTime = timeutil.Now() }

		// The rendering runs in its own function so that the reporter's
		// cleanup, if any, runs as soon as this result set is rendered
		// rather than when the whole query is done.
		if err := func() error {
			if cleanup != nil {
				defer cleanup()
			}
			return render(reporter, w, cols, newRowIter(rows, true), completedHook, noRowsHook)
		}(); err != nil {
			return err
		}

		if sqlCtx.showTimes {
			// Present the time since the last result, or since the
			// beginning of execution. Currently the execution engine makes
			// all the work upfront so most of the time is accounted for by
			// the 1st result; this is subject to change once CockroachDB
			// evolves to stream results as statements are executed.
			fmt.Fprintf(w, "\nTime: %s\n", queryCompleteTime.Sub(startTime))
			// Make users better understand any discrepancy they observe.
			renderDelay := timeutil.Now().Sub(queryCompleteTime)
			if renderDelay >= 1*time.Second {
				fmt.Fprintf(w,
					"Note: an additional delay of %s was spent formatting the results.\n"+
						"You can use \\set display_format to change the formatting.\n",
					renderDelay)
			}
			fmt.Fprintln(w)
			// Reset the clock. We ignore the rendering time.
			startTime = timeutil.Now()
		}

		// Move on to the next result set, or stop if there is none.
		if more, err := rows.NextResultSet(); err != nil {
			return err
		} else if !more {
			return nil
		}
	}
}
   813  
   814  // sqlRowsToStrings turns 'rows' into a list of rows, each of which
   815  // is a list of column values.
   816  // 'rows' should be closed by the caller.
   817  // It returns the header row followed by all data rows.
   818  // If both the header row and list of rows are empty, it means no row
   819  // information was returned (eg: statement was not a query).
   820  // If showMoreChars is true, then more characters are not escaped.
   821  func sqlRowsToStrings(rows *sqlRows, showMoreChars bool) ([]string, [][]string, error) {
   822  	cols := getColumnStrings(rows, showMoreChars)
   823  	allRows, err := getAllRowStrings(rows, showMoreChars)
   824  	if err != nil {
   825  		return nil, nil, err
   826  	}
   827  	return cols, allRows, nil
   828  }
   829  
   830  func getColumnStrings(rows *sqlRows, showMoreChars bool) []string {
   831  	srcCols := rows.Columns()
   832  	cols := make([]string, len(srcCols))
   833  	for i, c := range srcCols {
   834  		cols[i] = formatVal(c, showMoreChars, showMoreChars)
   835  	}
   836  	return cols
   837  }
   838  
   839  func getAllRowStrings(rows *sqlRows, showMoreChars bool) ([][]string, error) {
   840  	var allRows [][]string
   841  
   842  	for {
   843  		rowStrings, err := getNextRowStrings(rows, showMoreChars)
   844  		if err != nil {
   845  			return nil, err
   846  		}
   847  		if rowStrings == nil {
   848  			break
   849  		}
   850  		allRows = append(allRows, rowStrings)
   851  	}
   852  
   853  	return allRows, nil
   854  }
   855  
   856  func getNextRowStrings(rows *sqlRows, showMoreChars bool) ([]string, error) {
   857  	cols := rows.Columns()
   858  	var vals []driver.Value
   859  	if len(cols) > 0 {
   860  		vals = make([]driver.Value, len(cols))
   861  	}
   862  
   863  	err := rows.Next(vals)
   864  	if err == io.EOF {
   865  		return nil, nil
   866  	}
   867  	if err != nil {
   868  		return nil, err
   869  	}
   870  
   871  	rowStrings := make([]string, len(cols))
   872  	for i, v := range vals {
   873  		rowStrings[i] = formatVal(v, showMoreChars, showMoreChars)
   874  	}
   875  	return rowStrings, nil
   876  }
   877  
   878  func isNotPrintableASCII(r rune) bool { return r < 0x20 || r > 0x7e || r == '"' || r == '\\' }
   879  func isNotGraphicUnicode(r rune) bool { return !unicode.IsGraphic(r) }
   880  func isNotGraphicUnicodeOrTabOrNewline(r rune) bool {
   881  	return r != '\t' && r != '\n' && !unicode.IsGraphic(r)
   882  }
   883  
   884  func formatVal(val driver.Value, showPrintableUnicode bool, showNewLinesAndTabs bool) string {
   885  	switch t := val.(type) {
   886  	case nil:
   887  		return "NULL"
   888  	case string:
   889  		if showPrintableUnicode {
   890  			pred := isNotGraphicUnicode
   891  			if showNewLinesAndTabs {
   892  				pred = isNotGraphicUnicodeOrTabOrNewline
   893  			}
   894  			if utf8.ValidString(t) && strings.IndexFunc(t, pred) == -1 {
   895  				return t
   896  			}
   897  		} else {
   898  			if strings.IndexFunc(t, isNotPrintableASCII) == -1 {
   899  				return t
   900  			}
   901  		}
   902  		s := fmt.Sprintf("%+q", t)
   903  		// Strip the start and final quotes. The surrounding display
   904  		// format (e.g. CSV/TSV) will add its own quotes.
   905  		return s[1 : len(s)-1]
   906  
   907  	case []byte:
   908  		// Format the bytes as per bytea_output = escape.
   909  		//
   910  		// We use the "escape" format here because it enables printing
   911  		// readable strings as-is -- the default hex format would always
   912  		// render as hexadecimal digits. The escape format is also more
   913  		// compact.
   914  		//
   915  		// TODO(knz): this formatting is unfortunate/incorrect, and exists
   916  		// only because lib/pq incorrectly interprets the bytes received
   917  		// from the server. The proper behavior would be for the driver to
   918  		// not interpret the bytes and for us here to print that as-is, so
   919  		// that we can let the user see and control the result using
   920  		// `bytea_output`.
   921  		return lex.EncodeByteArrayToRawBytes(string(t),
   922  			lex.BytesEncodeEscape, false /* skipHexPrefix */)
   923  
   924  	case time.Time:
   925  		return t.Format(tree.TimestampOutputFormat)
   926  	}
   927  
   928  	return fmt.Sprint(val)
   929  }