github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/pgwire/command_result.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package pgwire
    12  
    13  import (
    14  	"context"
    15  	"fmt"
    16  
    17  	"github.com/cockroachdb/cockroach/pkg/server/telemetry"
    18  	"github.com/cockroachdb/cockroach/pkg/sql"
    19  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
    20  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    21  	"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
    22  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    23  	"github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry"
    24  	"github.com/cockroachdb/errors"
    25  	"github.com/lib/pq/oid"
    26  )
    27  
// completionMsgType selects which completion message, if any,
// commandResult.Close buffers for the client once a result finishes without
// an error.
type completionMsgType int

const (
	// The zero value is deliberately left unnamed: a commandResult whose typ
	// was never set falls into the "unknown type" panic in Close.
	_ completionMsgType = iota
	// commandComplete makes Close buffer a CommandComplete message carrying the
	// tag built from the statement tag, statement type and rows affected.
	commandComplete
	// bindComplete makes Close buffer a BindComplete message.
	bindComplete
	// closeComplete makes Close buffer a CloseComplete message.
	closeComplete
	// parseComplete makes Close buffer a ParseComplete message.
	parseComplete
	// emptyQueryResponse makes Close buffer an EmptyQueryResponse message.
	emptyQueryResponse
	// readyForQuery makes Close buffer a ReadyForQuery message with the
	// transaction status indicator and then flush the connection.
	readyForQuery
	// flush makes Close flush the connection without buffering a message.
	flush
	// Some commands, like Describe, don't need a completion message.
	noCompletionMsg
)
    42  
// commandResult is an implementation of sql.CommandResult that streams a
// command's results over a pgwire network connection.
type commandResult struct {
	// conn is the parent connection of this commandResult.
	conn *conn
	// conv indicates the conversion settings for SQL values.
	conv sessiondata.DataConversionConfig
	// pos identifies the position of the command within the connection.
	pos sql.CmdPos
	// flushBeforeCloseFuncs contains a list of functions that Close runs, in
	// order, before it buffers the completion message. They are skipped when
	// the result carries an error.
	flushBeforeCloseFuncs []func(ctx context.Context) error

	// err is the error recorded through SetError. It is only serialized to the
	// connection when Close is called.
	err error
	// errExpected, if set, enforces that an error had been set when Close is
	// called.
	errExpected bool

	// typ selects the completion message that Close buffers.
	typ completionMsgType
	// If typ == commandComplete, this is the tag to be written in the
	// CommandComplete message.
	cmdCompleteTag string

	// stmtType is the type of the statement being executed; it is passed to
	// cookTag when the CommandComplete tag is built.
	stmtType tree.StatementType
	// descOpt controls whether SetColumns writes a row description
	// (only when it equals sql.NeedRowDesc).
	descOpt sql.RowDescOpt
	// rowsAffected is incremented by one for every AddRow call and by n for
	// every IncrementRowsAffected(n) call.
	rowsAffected int

	// formatCodes describes the encoding of each column of result rows. It is nil
	// for statements not returning rows (or for results for commands other than
	// executing statements). It can also be nil for queries returning rows,
	// meaning that all columns will be encoded in the text format (this is the
	// case for queries executed through the simple protocol). Otherwise, it needs
	// to have an entry for every column.
	formatCodes []pgwirebase.FormatCode

	// oids holds the Oid of each result column, indexed by column position,
	// similar to formatCodes (except oids must always be set). It is populated
	// by SetColumns.
	oids []oid.Oid

	// bufferingDisabled is conditionally set during planning of certain
	// statements. When set, AddRow flushes the connection after every row.
	bufferingDisabled bool

	// released is set when the command result has been released so that its
	// memory can be reused. It is also used to assert against use-after-free
	// errors.
	released bool
}

// Compile-time check that *commandResult implements sql.CommandResult.
var _ sql.CommandResult = &commandResult{}
    93  
    94  // Close is part of the CommandResult interface.
    95  func (r *commandResult) Close(ctx context.Context, t sql.TransactionStatusIndicator) {
    96  	r.assertNotReleased()
    97  	defer r.release()
    98  	if r.errExpected && r.err == nil {
    99  		panic("expected err to be set on result by Close, but wasn't")
   100  	}
   101  
   102  	r.conn.writerState.fi.registerCmd(r.pos)
   103  	if r.err != nil {
   104  		r.conn.bufferErr(ctx, r.err)
   105  		return
   106  	}
   107  
   108  	for _, f := range r.flushBeforeCloseFuncs {
   109  		if err := f(ctx); err != nil {
   110  			panic(fmt.Sprintf("unexpected err when closing: %s", err))
   111  		}
   112  	}
   113  	r.flushBeforeCloseFuncs = nil
   114  
   115  	// Send a completion message, specific to the type of result.
   116  	switch r.typ {
   117  	case commandComplete:
   118  		tag := cookTag(
   119  			r.cmdCompleteTag, r.conn.writerState.tagBuf[:0], r.stmtType, r.rowsAffected,
   120  		)
   121  		r.conn.bufferCommandComplete(tag)
   122  	case parseComplete:
   123  		r.conn.bufferParseComplete()
   124  	case bindComplete:
   125  		r.conn.bufferBindComplete()
   126  	case closeComplete:
   127  		r.conn.bufferCloseComplete()
   128  	case readyForQuery:
   129  		r.conn.bufferReadyForQuery(byte(t))
   130  		// The error is saved on conn.err.
   131  		_ /* err */ = r.conn.Flush(r.pos)
   132  	case emptyQueryResponse:
   133  		r.conn.bufferEmptyQueryResponse()
   134  	case flush:
   135  		// The error is saved on conn.err.
   136  		_ /* err */ = r.conn.Flush(r.pos)
   137  	case noCompletionMsg:
   138  		// nothing to do
   139  	default:
   140  		panic(fmt.Sprintf("unknown type: %v", r.typ))
   141  	}
   142  }
   143  
   144  // Discard is part of the CommandResult interface.
   145  func (r *commandResult) Discard() {
   146  	r.assertNotReleased()
   147  	defer r.release()
   148  }
   149  
   150  // Err is part of the CommandResult interface.
   151  func (r *commandResult) Err() error {
   152  	r.assertNotReleased()
   153  	return r.err
   154  }
   155  
   156  // SetError is part of the CommandResult interface.
   157  //
   158  // We're not going to write any bytes to the buffer in order to support future
   159  // SetError() calls. The error will only be serialized at Close() time.
   160  func (r *commandResult) SetError(err error) {
   161  	r.assertNotReleased()
   162  	r.err = err
   163  }
   164  
   165  // AddRow is part of the CommandResult interface.
   166  func (r *commandResult) AddRow(ctx context.Context, row tree.Datums) error {
   167  	r.assertNotReleased()
   168  	if r.err != nil {
   169  		panic(fmt.Sprintf("can't call AddRow after having set error: %s",
   170  			r.err))
   171  	}
   172  	r.conn.writerState.fi.registerCmd(r.pos)
   173  	if err := r.conn.GetErr(); err != nil {
   174  		return err
   175  	}
   176  	if r.err != nil {
   177  		panic("can't send row after error")
   178  	}
   179  	r.rowsAffected++
   180  
   181  	r.conn.bufferRow(ctx, row, r.formatCodes, r.conv, r.oids)
   182  	var err error
   183  	if r.bufferingDisabled {
   184  		err = r.conn.Flush(r.pos)
   185  	} else {
   186  		_ /* flushed */, err = r.conn.maybeFlush(r.pos)
   187  	}
   188  	return err
   189  }
   190  
   191  // DisableBuffering is part of the CommandResult interface.
   192  func (r *commandResult) DisableBuffering() {
   193  	r.assertNotReleased()
   194  	r.bufferingDisabled = true
   195  }
   196  
   197  // AppendParamStatusUpdate is part of the CommandResult interface.
   198  func (r *commandResult) AppendParamStatusUpdate(param string, val string) {
   199  	r.flushBeforeCloseFuncs = append(
   200  		r.flushBeforeCloseFuncs,
   201  		func(ctx context.Context) error { return r.conn.bufferParamStatus(param, val) },
   202  	)
   203  }
   204  
   205  // AppendNotice is part of the CommandResult interface.
   206  func (r *commandResult) AppendNotice(noticeErr error) {
   207  	r.flushBeforeCloseFuncs = append(
   208  		r.flushBeforeCloseFuncs,
   209  		func(ctx context.Context) error {
   210  			return r.conn.bufferNotice(ctx, noticeErr)
   211  		},
   212  	)
   213  }
   214  
   215  // SetColumns is part of the CommandResult interface.
   216  func (r *commandResult) SetColumns(ctx context.Context, cols sqlbase.ResultColumns) {
   217  	r.assertNotReleased()
   218  	r.conn.writerState.fi.registerCmd(r.pos)
   219  	if r.descOpt == sql.NeedRowDesc {
   220  		_ /* err */ = r.conn.writeRowDescription(ctx, cols, r.formatCodes, &r.conn.writerState.buf)
   221  	}
   222  	r.oids = make([]oid.Oid, len(cols))
   223  	for i, col := range cols {
   224  		r.oids[i] = col.Typ.Oid()
   225  	}
   226  }
   227  
   228  // SetInferredTypes is part of the DescribeResult interface.
   229  func (r *commandResult) SetInferredTypes(types []oid.Oid) {
   230  	r.assertNotReleased()
   231  	r.conn.writerState.fi.registerCmd(r.pos)
   232  	r.conn.bufferParamDesc(types)
   233  }
   234  
   235  // SetNoDataRowDescription is part of the DescribeResult interface.
   236  func (r *commandResult) SetNoDataRowDescription() {
   237  	r.assertNotReleased()
   238  	r.conn.writerState.fi.registerCmd(r.pos)
   239  	r.conn.bufferNoDataMsg()
   240  }
   241  
   242  // SetPrepStmtOutput is part of the DescribeResult interface.
   243  func (r *commandResult) SetPrepStmtOutput(ctx context.Context, cols sqlbase.ResultColumns) {
   244  	r.assertNotReleased()
   245  	r.conn.writerState.fi.registerCmd(r.pos)
   246  	_ /* err */ = r.conn.writeRowDescription(ctx, cols, nil /* formatCodes */, &r.conn.writerState.buf)
   247  }
   248  
   249  // SetPortalOutput is part of the DescribeResult interface.
   250  func (r *commandResult) SetPortalOutput(
   251  	ctx context.Context, cols sqlbase.ResultColumns, formatCodes []pgwirebase.FormatCode,
   252  ) {
   253  	r.assertNotReleased()
   254  	r.conn.writerState.fi.registerCmd(r.pos)
   255  	_ /* err */ = r.conn.writeRowDescription(ctx, cols, formatCodes, &r.conn.writerState.buf)
   256  }
   257  
   258  // IncrementRowsAffected is part of the CommandResult interface.
   259  func (r *commandResult) IncrementRowsAffected(n int) {
   260  	r.assertNotReleased()
   261  	r.rowsAffected += n
   262  }
   263  
   264  // RowsAffected is part of the CommandResult interface.
   265  func (r *commandResult) RowsAffected() int {
   266  	r.assertNotReleased()
   267  	return r.rowsAffected
   268  }
   269  
   270  // ResetStmtType is part of the CommandResult interface.
   271  func (r *commandResult) ResetStmtType(stmt tree.Statement) {
   272  	r.assertNotReleased()
   273  	r.stmtType = stmt.StatementType()
   274  	r.cmdCompleteTag = stmt.StatementTag()
   275  }
   276  
   277  // release frees the commandResult and allows its memory to be reused.
   278  func (r *commandResult) release() {
   279  	*r = commandResult{released: true}
   280  }
   281  
   282  // assertNotReleased asserts that the commandResult is not being used after
   283  // being freed by one of the methods in the CommandResultClose interface. The
   284  // assertion can have false negatives, where it fails to detect a use-after-free
   285  // condition, but will never result in a false positive.
   286  func (r *commandResult) assertNotReleased() {
   287  	if r.released {
   288  		panic("commandResult used after being released")
   289  	}
   290  }
   291  
   292  func (c *conn) allocCommandResult() *commandResult {
   293  	r := &c.res
   294  	if r.released {
   295  		r.released = false
   296  	} else {
   297  		// In practice, each conn only ever uses a single commandResult at a
   298  		// time, so we could make this panic. However, doing so would entail
   299  		// complicating the ClientComm interface, which doesn't seem worth it.
   300  		r = new(commandResult)
   301  	}
   302  	return r
   303  }
   304  
   305  func (c *conn) newCommandResult(
   306  	descOpt sql.RowDescOpt,
   307  	pos sql.CmdPos,
   308  	stmt tree.Statement,
   309  	formatCodes []pgwirebase.FormatCode,
   310  	conv sessiondata.DataConversionConfig,
   311  	limit int,
   312  	portalName string,
   313  	implicitTxn bool,
   314  ) sql.CommandResult {
   315  	r := c.allocCommandResult()
   316  	*r = commandResult{
   317  		conn:           c,
   318  		conv:           conv,
   319  		pos:            pos,
   320  		typ:            commandComplete,
   321  		cmdCompleteTag: stmt.StatementTag(),
   322  		stmtType:       stmt.StatementType(),
   323  		descOpt:        descOpt,
   324  		formatCodes:    formatCodes,
   325  	}
   326  	if limit == 0 {
   327  		return r
   328  	}
   329  	telemetry.Inc(sqltelemetry.PortalWithLimitRequestCounter)
   330  	return &limitedCommandResult{
   331  		limit:         limit,
   332  		portalName:    portalName,
   333  		implicitTxn:   implicitTxn,
   334  		commandResult: r,
   335  	}
   336  }
   337  
   338  func (c *conn) newMiscResult(pos sql.CmdPos, typ completionMsgType) *commandResult {
   339  	r := c.allocCommandResult()
   340  	*r = commandResult{
   341  		conn: c,
   342  		pos:  pos,
   343  		typ:  typ,
   344  	}
   345  	return r
   346  }
   347  
// limitedCommandResult is a commandResult that has a limit, after which calls
// to AddRow will block until the associated client connection asks for more
// rows. It essentially implements the "execute portal with limit" part of the
// Postgres protocol.
//
// This design is known to be flawed. It only supports a specific subset of the
// protocol. We only allow a portal suspension in an explicit transaction where
// the suspended portal is completely exhausted before any other pgwire command
// is executed, otherwise an error is produced. You cannot, for example,
// interleave portal executions (a portal must be executed to completion before
// another can be executed). It also breaks the software layering by adding an
// additional state machine here, instead of teaching the state machine in the
// sql package about portals.
//
// This has been done because refactoring the executor to be able to correctly
// suspend a portal will require a lot of work, and we wanted to move
// forward. The work included is things like auditing all of the defers and
// post-execution stuff (like stats collection) to have it only execute once
// per statement instead of once per portal.
type limitedCommandResult struct {
	*commandResult
	// portalName is the name of the portal being executed; follow-up commands
	// naming a different portal are rejected.
	portalName string
	// implicitTxn, if set, makes the first portal suspension end the result
	// instead of waiting for further commands.
	implicitTxn bool

	// seenTuples counts the rows added since the last portal suspension. It is
	// reset to zero every time the limit is reached.
	seenTuples int
	// limit is the number of rows after which AddRow sends a "portal
	// suspended" message and waits for the client to ask for more. It is
	// replaced by the limit of each follow-up ExecPortal command.
	limit int
}
   377  
   378  // AddRow is part of the CommandResult interface.
   379  func (r *limitedCommandResult) AddRow(ctx context.Context, row tree.Datums) error {
   380  	if err := r.commandResult.AddRow(ctx, row); err != nil {
   381  		return err
   382  	}
   383  	r.seenTuples++
   384  
   385  	if r.seenTuples == r.limit {
   386  		// If we've seen up to the limit of rows, send a "portal suspended" message
   387  		// and wait for another exec portal message.
   388  		r.conn.bufferPortalSuspended()
   389  		if err := r.conn.Flush(r.pos); err != nil {
   390  			return err
   391  		}
   392  		r.seenTuples = 0
   393  
   394  		return r.moreResultsNeeded(ctx)
   395  	}
   396  	if _ /* flushed */, err := r.conn.maybeFlush(r.pos); err != nil {
   397  		return err
   398  	}
   399  	return nil
   400  }
   401  
// moreResultsNeeded is a restricted connection handler that waits for more
// requests for rows from the active portal, during the "execute portal" flow
// when a limit has been specified.
//
// It returns nil when the client asked for more rows from this portal (the
// limit is updated and the rows-affected count is reset),
// sql.ErrLimitedResultClosed when the portal is being closed (implicit
// transaction, or an explicit close of this portal), an error derived from
// sql.ErrLimitedResultNotSupported for any other command, and any error
// reported by the statement buffer or by flushing the connection.
func (r *limitedCommandResult) moreResultsNeeded(ctx context.Context) error {
	// In an implicit transaction, a portal suspension is immediately
	// followed by closing the portal.
	if r.implicitTxn {
		// No completion message is sent for the closed portal.
		r.typ = noCompletionMsg
		return sql.ErrLimitedResultClosed
	}

	// Keep track of the previous CmdPos so we can rewind if needed.
	// NOTE(review): this relies on AdvanceOne returning the position it
	// advanced from; confirm against the stmtBuf implementation.
	prevPos := r.conn.stmtBuf.AdvanceOne()
	for {
		cmd, curPos, err := r.conn.stmtBuf.CurCmd()
		if err != nil {
			return err
		}
		switch c := cmd.(type) {
		case sql.DeletePreparedStmt:
			// The client wants to close a portal or statement. We
			// support the case where it is exactly this
			// portal. This is done by closing the portal in
			// the same way implicit transactions do, but also
			// rewinding the stmtBuf to still point to the portal
			// close so that the state machine can do its part of
			// the cleanup. We are in effect peeking to see if the
			// next message is a delete portal.
			if c.Type != pgwirebase.PreparePortal || c.Name != r.portalName {
				telemetry.Inc(sqltelemetry.InterleavedPortalRequestCounter)
				return errors.WithDetail(sql.ErrLimitedResultNotSupported,
					"cannot close a portal while a different one is open")
			}
			r.typ = noCompletionMsg
			// Rewind to before the delete so the AdvanceOne in
			// connExecutor.execCmd ends up back on it.
			r.conn.stmtBuf.Rewind(ctx, prevPos)
			return sql.ErrLimitedResultClosed
		case sql.ExecPortal:
			// The happy case: the client wants more rows from the portal.
			if c.Name != r.portalName {
				telemetry.Inc(sqltelemetry.InterleavedPortalRequestCounter)
				return errors.WithDetail(sql.ErrLimitedResultNotSupported,
					"cannot execute a portal while a different one is open")
			}
			// Adopt the limit requested by this execution.
			r.limit = c.Limit
			// In order to get the correct command tag, we need to reset the seen rows.
			r.rowsAffected = 0
			return nil
		case sql.Sync:
			// The client wants to see a ready for query message
			// back. Send it then run the for loop again.
			r.conn.stmtBuf.AdvanceOne()
			// We can hard code InTxnBlock here because we don't
			// support implicit transactions, so we know we're in
			// a transaction.
			r.conn.bufferReadyForQuery(byte(sql.InTxnBlock))
			if err := r.conn.Flush(r.pos); err != nil {
				return err
			}
		default:
			// We got some other message, but we only support executing to completion.
			telemetry.Inc(sqltelemetry.InterleavedPortalRequestCounter)
			return errors.WithSafeDetails(sql.ErrLimitedResultNotSupported,
				"cannot perform operation %T while a different portal is open",
				errors.Safe(c))
		}
		// Only the Sync case reaches this point; remember its position as the
		// new rewind target.
		prevPos = curPos
	}
}
   471  }