github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/pgwire/server.go (about)

     1  // Copyright 2015 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package pgwire
    12  
    13  import (
    14  	"context"
    15  	"crypto/tls"
    16  	"io"
    17  	"net"
    18  	"strings"
    19  	"sync/atomic"
    20  	"time"
    21  
    22  	"github.com/cockroachdb/cockroach/pkg/base"
    23  	"github.com/cockroachdb/cockroach/pkg/server/telemetry"
    24  	"github.com/cockroachdb/cockroach/pkg/settings"
    25  	"github.com/cockroachdb/cockroach/pkg/settings/cluster"
    26  	"github.com/cockroachdb/cockroach/pkg/sql"
    27  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/hba"
    28  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    29  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
    30  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
    31  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    32  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    33  	"github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry"
    34  	"github.com/cockroachdb/cockroach/pkg/util/contextutil"
    35  	"github.com/cockroachdb/cockroach/pkg/util/envutil"
    36  	"github.com/cockroachdb/cockroach/pkg/util/humanizeutil"
    37  	"github.com/cockroachdb/cockroach/pkg/util/log"
    38  	"github.com/cockroachdb/cockroach/pkg/util/metric"
    39  	"github.com/cockroachdb/cockroach/pkg/util/mon"
    40  	"github.com/cockroachdb/cockroach/pkg/util/stop"
    41  	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
    42  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    43  	"github.com/cockroachdb/errors"
    44  	"github.com/cockroachdb/logtags"
    45  )
    46  
// connResultsBufferSize is the cluster setting
// "sql.defaults.results_buffer.size". It holds the default size of the
// buffer that accumulates results before they are sent to the client.
//
// ATTENTION: After changing this value in a unit test, you probably want to
// open a new connection pool since the connections in the existing one are not
// affected.
//
// The "results_buffer_size" connection parameter can be used to override this
// default for an individual connection.
var connResultsBufferSize = settings.RegisterPublicByteSizeSetting(
	"sql.defaults.results_buffer.size",
	"default size of the buffer that accumulates results for a statement or a batch "+
		"of statements before they are sent to the client. This can be overridden on "+
		"an individual connection with the 'results_buffer_size' parameter. Note that auto-retries "+
		"generally only happen while no results have been delivered to the client, so "+
		"reducing this size can increase the number of retriable errors a client "+
		"receives. On the other hand, increasing the buffer size can increase the "+
		"delay until the client receives the first result row. "+
		"Updating the setting only affects new connections. "+
		"Setting to 0 disables any buffering.",
	16<<10, // 16 KiB
)
    66  
// logConnAuth is the cluster setting that enables logging of SQL client
// connect and disconnect events. It defaults to false.
var logConnAuth = settings.RegisterPublicBoolSetting(
	sql.ConnAuditingClusterSettingName,
	"if set, log SQL client connect and disconnect events (note: may hinder performance on loaded nodes)",
	false)

// logSessionAuth is the cluster setting that enables logging of SQL session
// login and disconnection events. It defaults to false.
var logSessionAuth = settings.RegisterPublicBoolSetting(
	sql.AuthAuditingClusterSettingName,
	"if set, log SQL session login/disconnection events (note: may hinder performance on loaded nodes)",
	false)
    76  
// Error messages used by the pgwire server. Despite the Err prefix these are
// plain string constants, not error values; the code in this file wraps them
// into pgerror errors at the point of use.
const (
	// ErrSSLRequired is returned when a client attempts to connect to a
	// secure server in cleartext.
	ErrSSLRequired = "node is running secure mode, SSL connection required"

	// ErrDrainingNewConn is returned when a client attempts to connect to a server
	// which is not accepting client connections.
	ErrDrainingNewConn = "server is not accepting clients"
	// ErrDrainingExistingConn is returned when a connection is shut down because
	// the server is draining.
	ErrDrainingExistingConn = "server is shutting down"
)
    89  
// Fully-qualified names for metrics.
var (
	// MetaConns describes the gauge tracking the number of currently active
	// SQL connections.
	MetaConns = metric.Metadata{
		Name:        "sql.conns",
		Help:        "Number of active sql connections",
		Measurement: "Connections",
		Unit:        metric.Unit_COUNT,
	}
	// MetaNewConns describes the counter of SQL connections created.
	MetaNewConns = metric.Metadata{
		Name:        "sql.new_conns",
		Help:        "Counter of the number of sql connections created",
		Measurement: "Connections",
		Unit:        metric.Unit_COUNT,
	}
	// MetaBytesIn describes the counter of SQL bytes received from clients.
	MetaBytesIn = metric.Metadata{
		Name:        "sql.bytesin",
		Help:        "Number of sql bytes received",
		Measurement: "SQL Bytes",
		Unit:        metric.Unit_BYTES,
	}
	// MetaBytesOut describes the counter of SQL bytes sent to clients.
	MetaBytesOut = metric.Metadata{
		Name:        "sql.bytesout",
		Help:        "Number of sql bytes sent",
		Measurement: "SQL Bytes",
		Unit:        metric.Unit_BYTES,
	}
)
   117  
const (
	// The below constants can occur during the first message a client
	// sends to the server. There are two categories: protocol version and
	// request code. The protocol version is (major version number << 16)
	// + minor version number. Request codes are (1234 << 16) + 5678 + N,
	// where N started at 0 and is increased by 1 for every new request
	// code added, which happens rarely during major or minor Postgres
	// releases.
	//
	// See: https://www.postgresql.org/docs/current/protocol-message-formats.html
	//
	// version30 is protocol version 3.0, the only protocol version that
	// ServeConn accepts for a regular SQL session. versionCancel,
	// versionSSL and versionGSSENC are the request codes of the
	// CancelRequest, SSLRequest and GSSENCRequest messages respectively.

	version30     = 196608   // (3 << 16) + 0
	versionCancel = 80877102 // (1234 << 16) + 5678
	versionSSL    = 80877103 // (1234 << 16) + 5679
	versionGSSENC = 80877104 // (1234 << 16) + 5680
)
   134  
// cancelMaxWait is the amount of time a draining server gives to sessions to
// react to cancellation and return before a forceful shutdown. Drain passes
// it to drainImpl as the cancelWait argument.
const cancelMaxWait = 1 * time.Second

// baseSQLMemoryBudget is the amount of memory pre-allocated in each connection.
// It can be overridden with the COCKROACH_BASE_SQL_MEMORY_BUDGET environment
// variable; the default is 2.1 times mon.DefaultPoolAllocationSize.
var baseSQLMemoryBudget = envutil.EnvOrDefaultInt64("COCKROACH_BASE_SQL_MEMORY_BUDGET",
	int64(2.1*float64(mon.DefaultPoolAllocationSize)))

// connReservationBatchSize determines for how many connections memory
// is pre-reserved at once. MakeServer multiplies it by baseSQLMemoryBudget
// when configuring the connection monitor.
var connReservationBatchSize = 5

// Single-byte replies written to the client in response to an SSL request:
// 'S' when the server is about to upgrade the connection to TLS, 'N' when no
// TLS configuration is available.
var (
	sslSupported   = []byte{'S'}
	sslUnsupported = []byte{'N'}
)
   151  
// cancelChanMap tracks the connections that can be canceled by a drain. Each
// key is a channel that is closed when the corresponding connection is done;
// each value is the function that cancels that connection's context.
type cancelChanMap map[chan struct{}]context.CancelFunc
   155  
// Server implements the server side of the PostgreSQL wire protocol.
type Server struct {
	// AmbientCtx is the ambient logging context supplied to MakeServer.
	AmbientCtx log.AmbientContext
	// cfg is the base configuration; this file consults it for the Insecure
	// flag and for the server TLS configuration.
	cfg        *base.Config
	// SQLServer is the SQL server created by MakeServer on top of
	// sqlMemoryPool. It is started by Start.
	SQLServer  *sql.Server
	// execCfg is the executor configuration; this file uses it to reach the
	// cluster settings, the auth logger, the internal executor and the pgwire
	// testing knobs.
	execCfg    *sql.ExecutorConfig

	// metrics holds the pgwire-level metrics (connections, bytes in/out,
	// memory usage).
	metrics ServerMetrics

	mu struct {
		syncutil.Mutex
		// connCancelMap entries represent connections started when the server
		// was not draining. Each value is a function that can be called to
		// cancel the associated connection. The corresponding key is a channel
		// that is closed when the connection is done.
		connCancelMap cancelChanMap
		// draining is true while the server refuses new connections. It is
		// set by drainImpl and reset by Undrain.
		draining      bool
	}

	// auth holds an HBA configuration protected by its own lock.
	// NOTE(review): presumably the currently active authentication
	// configuration returned by GetAuthenticationConfiguration; it is
	// maintained outside of this section of the file.
	auth struct {
		syncutil.RWMutex
		conf *hba.Conf
	}

	// sqlMemoryPool is the "sql" memory monitor; it is started as a child of
	// the parent monitor given to MakeServer and is handed to SQLServer.
	sqlMemoryPool mon.BytesMonitor
	// connMonitor is the "conn" memory monitor, a child of sqlMemoryPool. It
	// is used to pre-reserve baseSQLMemoryBudget bytes for each connection.
	connMonitor   mon.BytesMonitor

	// stopper is the stopper received in Start.
	stopper *stop.Stopper

	// testingLogEnabled is used in unit tests in this package to
	// force-enable conn/auth logging without dancing around the
	// asynchronicity of cluster settings. It is accessed atomically.
	testingLogEnabled int32
}
   190  
// ServerMetrics is the set of metrics for the pgwire server.
type ServerMetrics struct {
	// BytesInCount counts the bytes received from clients (see MetaBytesIn).
	BytesInCount   *metric.Counter
	// BytesOutCount counts the bytes sent to clients (see MetaBytesOut).
	BytesOutCount  *metric.Counter
	// Conns is the number of currently active connections (see MetaConns).
	Conns          *metric.Gauge
	// NewConns counts the connections created (see MetaNewConns).
	NewConns       *metric.Counter
	// ConnMemMetrics are the memory metrics attached to the "conn" monitor.
	ConnMemMetrics sql.MemoryMetrics
	// SQLMemMetrics are the memory metrics attached to the "sql" monitor;
	// they are supplied by the caller of MakeServer.
	SQLMemMetrics  sql.MemoryMetrics
}
   200  
   201  func makeServerMetrics(
   202  	sqlMemMetrics sql.MemoryMetrics, histogramWindow time.Duration,
   203  ) ServerMetrics {
   204  	return ServerMetrics{
   205  		BytesInCount:   metric.NewCounter(MetaBytesIn),
   206  		BytesOutCount:  metric.NewCounter(MetaBytesOut),
   207  		Conns:          metric.NewGauge(MetaConns),
   208  		NewConns:       metric.NewCounter(MetaNewConns),
   209  		ConnMemMetrics: sql.MakeMemMetrics("conns", histogramWindow),
   210  		SQLMemMetrics:  sqlMemMetrics,
   211  	}
   212  }
   213  
// noteworthySQLMemoryUsageBytes is the minimum size tracked by the
// client SQL pool before the pool starts explicitly logging overall
// usage growth in the log. It can be overridden with the
// COCKROACH_NOTEWORTHY_SQL_MEMORY_USAGE environment variable and defaults
// to 100 MiB.
var noteworthySQLMemoryUsageBytes = envutil.EnvOrDefaultInt64("COCKROACH_NOTEWORTHY_SQL_MEMORY_USAGE", 100*1024*1024)

// noteworthyConnMemoryUsageBytes is the minimum size tracked by the
// connection monitor before the monitor starts explicitly logging overall
// usage growth in the log. It can be overridden with the
// COCKROACH_NOTEWORTHY_CONN_MEMORY_USAGE environment variable and defaults
// to 2 MiB.
var noteworthyConnMemoryUsageBytes = envutil.EnvOrDefaultInt64("COCKROACH_NOTEWORTHY_CONN_MEMORY_USAGE", 2*1024*1024)
   223  
// MakeServer creates a Server.
//
// It sets up two memory monitors: the "sql" pool, which is a child of
// parentMemoryMonitor and backs the SQL server, and the "conn" monitor, which
// is a child of the "sql" pool and is used for per-connection reservations.
// It also registers a callback on the connAuthConf setting.
//
// Start() needs to be called on the Server so it begins processing.
func MakeServer(
	ambientCtx log.AmbientContext,
	cfg *base.Config,
	st *cluster.Settings,
	sqlMemMetrics sql.MemoryMetrics,
	parentMemoryMonitor *mon.BytesMonitor,
	histogramWindow time.Duration,
	executorConfig *sql.ExecutorConfig,
) *Server {
	server := &Server{
		AmbientCtx: ambientCtx,
		cfg:        cfg,
		execCfg:    executorConfig,
		metrics:    makeServerMetrics(sqlMemMetrics, histogramWindow),
	}
	// The "sql" pool must be started before it is handed to the SQL server
	// and before it is used as the parent of the connection monitor below.
	server.sqlMemoryPool = mon.MakeMonitor("sql",
		mon.MemoryResource,
		server.metrics.SQLMemMetrics.CurBytesCount,
		server.metrics.SQLMemMetrics.MaxBytesHist,
		0, noteworthySQLMemoryUsageBytes, st)
	server.sqlMemoryPool.Start(context.Background(), parentMemoryMonitor, mon.BoundAccount{})
	server.SQLServer = sql.NewServer(executorConfig, &server.sqlMemoryPool)

	// The connection monitor is configured with a size covering
	// connReservationBatchSize connections' worth of baseSQLMemoryBudget.
	server.connMonitor = mon.MakeMonitor("conn",
		mon.MemoryResource,
		server.metrics.ConnMemMetrics.CurBytesCount,
		server.metrics.ConnMemMetrics.MaxBytesHist,
		int64(connReservationBatchSize)*baseSQLMemoryBudget, noteworthyConnMemoryUsageBytes, st)
	server.connMonitor.Start(context.Background(), &server.sqlMemoryPool, mon.BoundAccount{})

	server.mu.Lock()
	server.mu.connCancelMap = make(cancelChanMap)
	server.mu.Unlock()

	// Reload the local authentication configuration whenever the
	// connAuthConf setting changes.
	connAuthConf.SetOnChange(&st.SV,
		func() {
			loadLocalAuthConfigUponRemoteSettingChange(
				ambientCtx.AnnotateCtx(context.Background()), server, st)
		})

	return server
}
   269  
   270  // Match returns true if rd appears to be a Postgres connection.
   271  func Match(rd io.Reader) bool {
   272  	var buf pgwirebase.ReadBuffer
   273  	_, err := buf.ReadUntypedMsg(rd)
   274  	if err != nil {
   275  		return false
   276  	}
   277  	version, err := buf.GetUint32()
   278  	if err != nil {
   279  		return false
   280  	}
   281  	return version == version30 || version == versionSSL || version == versionCancel || version == versionGSSENC
   282  }
   283  
// Start makes the Server ready for serving connections. It records the
// stopper on the Server and starts the underlying SQL server with it.
func (s *Server) Start(ctx context.Context, stopper *stop.Stopper) {
	s.stopper = stopper
	s.SQLServer.Start(ctx, stopper)
}
   289  
   290  // IsDraining returns true if the server is not currently accepting
   291  // connections.
   292  func (s *Server) IsDraining() bool {
   293  	s.mu.Lock()
   294  	defer s.mu.Unlock()
   295  	return s.mu.draining
   296  }
   297  
   298  // Metrics returns the set of metrics structs.
   299  func (s *Server) Metrics() (res []interface{}) {
   300  	return []interface{}{
   301  		&s.metrics,
   302  		&s.SQLServer.Metrics.StartedStatementCounters,
   303  		&s.SQLServer.Metrics.ExecutedStatementCounters,
   304  		&s.SQLServer.Metrics.EngineMetrics,
   305  		&s.SQLServer.InternalMetrics.StartedStatementCounters,
   306  		&s.SQLServer.InternalMetrics.ExecutedStatementCounters,
   307  		&s.SQLServer.InternalMetrics.EngineMetrics,
   308  	}
   309  }
   310  
// Drain prevents new connections from being served and waits for drainWait for
// open connections to terminate before canceling them.
// An error will be returned when connections that have been canceled have not
// responded to this cancellation and closed themselves in time (within
// cancelMaxWait). The server will remain in draining state, though open
// connections may continue to exist.
// The RFC on drain modes has more information regarding the specifics of
// what will happen to connections in different states:
// https://github.com/cockroachdb/cockroach/blob/master/docs/RFCS/20160425_drain_modes.md
//
// The reporter callback, if non-nil, is called on a best effort basis
// to report work that needed to be done and which may or may not have
// been done by the time this call returns. See the explanation in
// pkg/server/drain.go for details.
func (s *Server) Drain(drainWait time.Duration, reporter func(int, string)) error {
	return s.drainImpl(drainWait, cancelMaxWait, reporter)
}
   328  
   329  // Undrain switches the server back to the normal mode of operation in which
   330  // connections are accepted.
   331  func (s *Server) Undrain() {
   332  	s.mu.Lock()
   333  	s.setDrainingLocked(false)
   334  	s.mu.Unlock()
   335  }
   336  
   337  // setDrainingLocked sets the server's draining state and returns whether the
   338  // state changed (i.e. drain != s.mu.draining). s.mu must be locked.
   339  func (s *Server) setDrainingLocked(drain bool) bool {
   340  	if s.mu.draining == drain {
   341  		return false
   342  	}
   343  	s.mu.draining = drain
   344  	return true
   345  }
   346  
// drainImpl drains the SQL clients.
//
// The server is first switched to draining mode. drainImpl then waits up to
// drainWait for the connections that were registered at that moment to
// finish on their own. After that, unless the server has been undrained in
// the meantime, the contexts of those connections are canceled and drainImpl
// waits up to cancelWait for them to finish; an error is returned if they do
// not finish within cancelWait.
//
// drainImpl returns nil immediately if the server was already draining or
// if no connection was registered.
//
// The reporter callback, if non-nil, is called on a best effort basis
// to report work that needed to be done and which may or may not have
// been done by the time this call returns. See the explanation in
// pkg/server/drain.go for details.
func (s *Server) drainImpl(
	drainWait time.Duration, cancelWait time.Duration, reporter func(int, string),
) error {
	// This anonymous function returns a copy of s.mu.connCancelMap if there are
	// any active connections to cancel. We will only attempt to cancel
	// connections that were active at the moment the draining switch happened.
	// It is enough to do this because:
	// 1) If no new connections are added to the original map all connections
	// will be canceled.
	// 2) If new connections are added to the original map, it follows that they
	// were added when s.mu.draining = false, thus not requiring cancellation.
	// These connections are not our responsibility and will be handled when the
	// server starts draining again.
	connCancelMap := func() cancelChanMap {
		s.mu.Lock()
		defer s.mu.Unlock()
		if !s.setDrainingLocked(true) {
			// We are already draining.
			return nil
		}
		connCancelMap := make(cancelChanMap)
		for done, cancel := range s.mu.connCancelMap {
			connCancelMap[done] = cancel
		}
		return connCancelMap
	}()
	// A nil map (already draining) and an empty map (no connections) both
	// have length zero: in either case there is nothing to wait for.
	if len(connCancelMap) == 0 {
		return nil
	}
	if reporter != nil {
		// Report progress to the Drain RPC.
		reporter(len(connCancelMap), "SQL clients")
	}

	// Spin off a goroutine that waits for all connections to signal that they
	// are done and reports it on allConnsDone. The main goroutine signals this
	// goroutine to stop work through quitWaitingForConns.
	allConnsDone := make(chan struct{})
	quitWaitingForConns := make(chan struct{})
	// Closing quitWaitingForConns on return guarantees that the goroutine
	// below terminates even if some connections never finish.
	defer close(quitWaitingForConns)
	go func() {
		defer close(allConnsDone)
		for done := range connCancelMap {
			select {
			case <-done:
			case <-quitWaitingForConns:
				return
			}
		}
	}()

	// Wait for all connections to finish up to drainWait.
	select {
	case <-time.After(drainWait):
	case <-allConnsDone:
	}

	// Cancel the contexts of all sessions if the server is still in draining
	// mode.
	if stop := func() bool {
		s.mu.Lock()
		defer s.mu.Unlock()
		if !s.mu.draining {
			// The server was undrained while we were waiting; leave the
			// connections alone.
			return true
		}
		for _, cancel := range connCancelMap {
			// There is a possibility that different drain calls have
			// overlapping connCancelMaps, but context.CancelFunc calls are
			// idempotent.
			cancel()
		}
		return false
	}(); stop {
		return nil
	}

	// Give the canceled connections up to cancelWait to finish.
	select {
	case <-time.After(cancelWait):
		return errors.Errorf("some sessions did not respond to cancellation within %s", cancelWait)
	case <-allConnsDone:
	}
	return nil
}
   442  
// SocketType indicates the connection type. This is an optimization to
// prevent a comparison against conn.LocalAddr().Network().
//
// Since the underlying type is bool, there are exactly two socket types.
type SocketType bool

const (
	// SocketTCP is used for TCP sockets. The standard.
	SocketTCP SocketType = true
	// SocketUnix is used for unix datagram sockets.
	SocketUnix SocketType = false
)
   453  
   454  func (s SocketType) asConnType() (hba.ConnType, error) {
   455  	switch s {
   456  	case SocketTCP:
   457  		return hba.ConnHostNoSSL, nil
   458  	case SocketUnix:
   459  		return hba.ConnLocal, nil
   460  	default:
   461  		return 0, errors.AssertionFailedf("unimplemented socket type: %v", errors.Safe(s))
   462  	}
   463  }
   464  
   465  func (s *Server) connLogEnabled() bool {
   466  	return atomic.LoadInt32(&s.testingLogEnabled) != 0 || logConnAuth.Get(&s.execCfg.Settings.SV)
   467  }
   468  
// TestingEnableConnAuthLogging is exported for use in tests. It atomically
// sets the flag that makes connLogEnabled return true regardless of the
// cluster setting. There is no way to reset the flag from this file.
func (s *Server) TestingEnableConnAuthLogging() {
	atomic.StoreInt32(&s.testingLogEnabled, 1)
}
   473  
   474  // ServeConn serves a single connection, driving the handshake process and
   475  // delegating to the appropriate connection type.
   476  //
   477  // The socketType argument is an optimization to avoid a string
   478  // compare on conn.LocalAddr().Network(). When the socket type is
   479  // unix datagram (local filesystem), SSL negotiation is disabled
   480  // even when the server is running securely with certificates.
   481  // This has the effect of forcing password auth, also in a way
   482  // compatible with postgres.
   483  //
   484  // An error is returned if the initial handshake of the connection fails.
   485  func (s *Server) ServeConn(ctx context.Context, conn net.Conn, socketType SocketType) error {
   486  	ctx, draining, onCloseFn := s.registerConn(ctx)
   487  	defer onCloseFn()
   488  
   489  	// Some bookkeeping, for security-minded administrators.
   490  	// This registers the connection to the authentication log.
   491  	connStart := timeutil.Now()
   492  	if s.connLogEnabled() {
   493  		s.execCfg.AuthLogger.Logf(ctx, "received connection")
   494  	}
   495  	defer func() {
   496  		// The duration of the session is logged at the end so that the
   497  		// reader of the log file can know how much to look back in time
   498  		// to find when the connection was opened. This is important
   499  		// because the log files may have been rotated since.
   500  		if s.connLogEnabled() {
   501  			s.execCfg.AuthLogger.Logf(ctx, "disconnected; duration: %s", timeutil.Now().Sub(connStart))
   502  		}
   503  	}()
   504  
   505  	// In any case, first check the command in the start-up message.
   506  	//
   507  	// We're assuming that a client is not willing/able to receive error
   508  	// packets before we drain that message.
   509  	version, buf, err := s.readVersion(conn)
   510  	if err != nil {
   511  		return err
   512  	}
   513  
   514  	if version == versionCancel {
   515  		// The cancel message is rather peculiar: it is sent without
   516  		// authentication, always over an unencrypted channel.
   517  		//
   518  		// Since we don't support this, close the door in the client's
   519  		// face. Make a note of that use in telemetry.
   520  		telemetry.Inc(sqltelemetry.CancelRequestCounter)
   521  		_ = conn.Close()
   522  		return nil
   523  	}
   524  
   525  	// If the server is shutting down, terminate the connection early.
   526  	if draining {
   527  		return s.sendErr(ctx, conn, newAdminShutdownErr(ErrDrainingNewConn))
   528  	}
   529  
   530  	// Compute the initial connType.
   531  	connType, err := socketType.asConnType()
   532  	if err != nil {
   533  		return err
   534  	}
   535  
   536  	// If the client requests SSL, upgrade the connection to use TLS.
   537  	var clientErr error
   538  	conn, connType, version, clientErr, err = s.maybeUpgradeToSecureConn(ctx, conn, connType, version, &buf)
   539  	if err != nil {
   540  		return err
   541  	}
   542  	if clientErr != nil {
   543  		return s.sendErr(ctx, conn, clientErr)
   544  	}
   545  	ctx = logtags.AddTag(ctx, connType.String(), nil)
   546  
   547  	// What does the client want to do?
   548  	switch version {
   549  	case version30:
   550  		// Normal SQL connection. Proceed normally below.
   551  
   552  	default:
   553  		// We don't know this protocol.
   554  		return s.sendErr(ctx, conn,
   555  			pgerror.Newf(pgcode.ProtocolViolation, "unknown protocol version %d", version))
   556  	}
   557  
   558  	// Reserve some memory for this connection using the server's monitor. This
   559  	// reduces pressure on the shared pool because the server monitor allocates in
   560  	// chunks from the shared pool and these chunks should be larger than
   561  	// baseSQLMemoryBudget.
   562  	reserved := s.connMonitor.MakeBoundAccount()
   563  	if err := reserved.Grow(ctx, baseSQLMemoryBudget); err != nil {
   564  		return errors.Errorf("unable to pre-allocate %d bytes for this connection: %v",
   565  			baseSQLMemoryBudget, err)
   566  	}
   567  
   568  	// Load the client-provided session parameters.
   569  	var sArgs sql.SessionArgs
   570  	if sArgs, err = parseClientProvidedSessionParameters(ctx, &s.execCfg.Settings.SV, &buf); err != nil {
   571  		return s.sendErr(ctx, conn, err)
   572  	}
   573  
   574  	// If a test is hooking in some authentication option, load it.
   575  	var testingAuthHook func(context.Context) error
   576  	if k := s.execCfg.PGWireTestingKnobs; k != nil {
   577  		testingAuthHook = k.AuthHook
   578  	}
   579  
   580  	// Defer the rest of the processing to the connection handler.
   581  	// This includes authentication.
   582  	s.serveConn(
   583  		ctx, conn, sArgs,
   584  		reserved,
   585  		authOptions{
   586  			connType:        connType,
   587  			insecure:        s.cfg.Insecure,
   588  			ie:              s.execCfg.InternalExecutor,
   589  			auth:            s.GetAuthenticationConfiguration(),
   590  			testingAuthHook: testingAuthHook,
   591  		})
   592  	return nil
   593  }
   594  
   595  // parseClientProvidedSessionParameters reads the incoming k/v pairs
   596  // in the startup message into a sql.SessionArgs struct.
   597  func parseClientProvidedSessionParameters(
   598  	ctx context.Context, sv *settings.Values, buf *pgwirebase.ReadBuffer,
   599  ) (sql.SessionArgs, error) {
   600  	args := sql.SessionArgs{
   601  		SessionDefaults: make(map[string]string),
   602  	}
   603  	foundBufferSize := false
   604  
   605  	for {
   606  		// Read a key-value pair from the client.
   607  		key, err := buf.GetString()
   608  		if err != nil {
   609  			return sql.SessionArgs{}, pgerror.Newf(pgcode.ProtocolViolation,
   610  				"error reading option key: %s", err)
   611  		}
   612  		if len(key) == 0 {
   613  			// End of parameter list.
   614  			break
   615  		}
   616  		value, err := buf.GetString()
   617  		if err != nil {
   618  			return sql.SessionArgs{}, pgerror.Newf(pgcode.ProtocolViolation,
   619  				"error reading option value: %s", err)
   620  		}
   621  
   622  		// Case-fold for the key for easier comparison.
   623  		key = strings.ToLower(key)
   624  
   625  		// Load the parameter.
   626  		switch key {
   627  		case "user":
   628  			// Unicode-normalize and case-fold the username.
   629  			args.User = tree.Name(value).Normalize()
   630  
   631  		case "results_buffer_size":
   632  			if args.ConnResultsBufferSize, err = humanizeutil.ParseBytes(value); err != nil {
   633  				return sql.SessionArgs{}, errors.WithSecondaryError(
   634  					pgerror.Newf(pgcode.ProtocolViolation,
   635  						"error parsing results_buffer_size option value '%s' as bytes", value), err)
   636  			}
   637  			if args.ConnResultsBufferSize < 0 {
   638  				return sql.SessionArgs{}, pgerror.Newf(pgcode.ProtocolViolation,
   639  					"results_buffer_size option value '%s' cannot be negative", value)
   640  			}
   641  			foundBufferSize = true
   642  
   643  		default:
   644  			exists, configurable := sql.IsSessionVariableConfigurable(key)
   645  
   646  			switch {
   647  			case exists && configurable:
   648  				args.SessionDefaults[key] = value
   649  
   650  			case !exists:
   651  				if _, ok := sql.UnsupportedVars[key]; ok {
   652  					counter := sqltelemetry.UnimplementedClientStatusParameterCounter(key)
   653  					telemetry.Inc(counter)
   654  				}
   655  				log.Warningf(ctx, "unknown configuration parameter: %q", key)
   656  
   657  			case !configurable:
   658  				return sql.SessionArgs{}, pgerror.Newf(pgcode.CantChangeRuntimeParam,
   659  					"parameter %q cannot be changed", key)
   660  			}
   661  		}
   662  	}
   663  
   664  	if !foundBufferSize && sv != nil {
   665  		// The client did not provide buffer_size; use the cluster setting as default.
   666  		args.ConnResultsBufferSize = connResultsBufferSize.Get(sv)
   667  	}
   668  
   669  	if _, ok := args.SessionDefaults["database"]; !ok {
   670  		// CockroachDB-specific behavior: if no database is specified,
   671  		// default to "defaultdb". In PostgreSQL this would be "postgres".
   672  		args.SessionDefaults["database"] = sqlbase.DefaultDatabaseName
   673  	}
   674  
   675  	return args, nil
   676  }
   677  
   678  // maybeUpgradeToSecureConn upgrades the connection to TLS/SSL if
   679  // requested by the client, and available in the server configuration.
   680  func (s *Server) maybeUpgradeToSecureConn(
   681  	ctx context.Context,
   682  	conn net.Conn,
   683  	connType hba.ConnType,
   684  	version uint32,
   685  	buf *pgwirebase.ReadBuffer,
   686  ) (newConn net.Conn, newConnType hba.ConnType, newVersion uint32, clientErr, serverErr error) {
   687  	// By default, this is a no-op.
   688  	newConn = conn
   689  	newConnType = connType
   690  	newVersion = version
   691  	var n int // byte counts
   692  
   693  	if version != versionSSL {
   694  		// The client did not require a SSL connection.
   695  
   696  		if !s.cfg.Insecure && connType != hba.ConnLocal {
   697  			// Currently non-SSL connections are not allowed in secure
   698  			// mode. Ideally, we want to allow this and subject it to HBA
   699  			// rules ('hostssl' vs 'hostnossl').
   700  			//
   701  			// TODO(knz): revisit this when needed.
   702  			clientErr = pgerror.New(pgcode.ProtocolViolation, ErrSSLRequired)
   703  			return
   704  		}
   705  
   706  		// Non-SSL in non-secure mode, all is well: no-op.
   707  		return
   708  	}
   709  
   710  	if connType == hba.ConnLocal {
   711  		clientErr = pgerror.New(pgcode.ProtocolViolation,
   712  			"cannot use SSL/TLS over local connections")
   713  	}
   714  
   715  	// Protocol sanity check.
   716  	if len(buf.Msg) > 0 {
   717  		serverErr = errors.Errorf("unexpected data after SSLRequest: %q", buf.Msg)
   718  		return
   719  	}
   720  
   721  	// The client has requested SSL. We're going to try and upgrade the
   722  	// connection to use TLS/SSL.
   723  
   724  	// Do we have a TLS configuration?
   725  	tlsConfig, serverErr := s.cfg.GetServerTLSConfig()
   726  	if serverErr != nil {
   727  		return
   728  	}
   729  
   730  	if tlsConfig == nil {
   731  		// We don't have a TLS configuration available, so we can't honor
   732  		// the client's request.
   733  		n, serverErr = conn.Write(sslUnsupported)
   734  		if serverErr != nil {
   735  			return
   736  		}
   737  	} else {
   738  		// We have a TLS configuration. Upgrade the connection.
   739  		n, serverErr = conn.Write(sslSupported)
   740  		if serverErr != nil {
   741  			return
   742  		}
   743  		newConn = tls.Server(conn, tlsConfig)
   744  		newConnType = hba.ConnHostSSL
   745  	}
   746  	s.metrics.BytesOutCount.Inc(int64(n))
   747  
   748  	// Finally, re-read the version/command from the client.
   749  	newVersion, *buf, serverErr = s.readVersion(newConn)
   750  	return
   751  }
   752  
   753  // registerConn registers the incoming connection to the map of active connections,
   754  // which can be canceled by a concurrent server drain. It also returns
   755  // the current draining status of the server.
   756  //
   757  // The onCloseFn() callback must be called at the end of the
   758  // connection by the caller.
   759  func (s *Server) registerConn(
   760  	ctx context.Context,
   761  ) (newCtx context.Context, draining bool, onCloseFn func()) {
   762  	onCloseFn = func() {}
   763  	newCtx = ctx
   764  	s.mu.Lock()
   765  	draining = s.mu.draining
   766  	if !draining {
   767  		var cancel context.CancelFunc
   768  		newCtx, cancel = contextutil.WithCancel(ctx)
   769  		done := make(chan struct{})
   770  		s.mu.connCancelMap[done] = cancel
   771  		onCloseFn = func() {
   772  			cancel()
   773  			close(done)
   774  			s.mu.Lock()
   775  			delete(s.mu.connCancelMap, done)
   776  			s.mu.Unlock()
   777  		}
   778  	}
   779  	s.mu.Unlock()
   780  
   781  	// If the Server is draining, we will use the connection only to send an
   782  	// error, so we don't count it in the stats. This makes sense since
   783  	// DrainClient() waits for that number to drop to zero,
   784  	// so we don't want it to oscillate unnecessarily.
   785  	if !draining {
   786  		s.metrics.NewConns.Inc(1)
   787  		s.metrics.Conns.Inc(1)
   788  		prevOnCloseFn := onCloseFn
   789  		onCloseFn = func() { prevOnCloseFn(); s.metrics.Conns.Dec(1) }
   790  	}
   791  	return
   792  }
   793  
   794  // readVersion reads the start-up message, then returns the version
   795  // code (first uint32 in message) and the buffer containing the rest
   796  // of the payload.
   797  func (s *Server) readVersion(
   798  	conn io.Reader,
   799  ) (version uint32, buf pgwirebase.ReadBuffer, err error) {
   800  	var n int
   801  	n, err = buf.ReadUntypedMsg(conn)
   802  	if err != nil {
   803  		return
   804  	}
   805  	version, err = buf.GetUint32()
   806  	if err != nil {
   807  		return
   808  	}
   809  	s.metrics.BytesInCount.Inc(int64(n))
   810  	return
   811  }
   812  
   813  // sendErr sends errors to the client during the connection startup
   814  // sequence. Later error sends during/after authentication are handled
   815  // in conn.go.
   816  func (s *Server) sendErr(ctx context.Context, conn net.Conn, err error) error {
   817  	msgBuilder := newWriteBuffer(s.metrics.BytesOutCount)
   818  	// We could, but do not, report server-side network errors while
   819  	// trying to send the client error. This is because clients that
   820  	// receive error payload are highly correlated with clients
   821  	// disconnecting abruptly.
   822  	_ /* err */ = writeErr(ctx, &s.execCfg.Settings.SV, err, msgBuilder, conn)
   823  	_ = conn.Close()
   824  	return err
   825  }
   826  
// newAdminShutdownErr returns a pgerror carrying msg with the AdminShutdown
// pg code. ServeConn uses it to reject connections while the server is
// draining.
func newAdminShutdownErr(msg string) error {
	return pgerror.New(pgcode.AdminShutdown, msg)
}