github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/conn_executor_prepare.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package sql
    12  
    13  import (
    14  	"context"
    15  	"fmt"
    16  
    17  	"github.com/cockroachdb/cockroach/pkg/kv"
    18  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    19  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
    20  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
    21  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    22  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    23  	"github.com/cockroachdb/cockroach/pkg/util/fsm"
    24  	"github.com/cockroachdb/cockroach/pkg/util/log"
    25  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    26  	"github.com/cockroachdb/errors"
    27  	"github.com/lib/pq/oid"
    28  )
    29  
    30  func (ex *connExecutor) execPrepare(
    31  	ctx context.Context, parseCmd PrepareStmt,
    32  ) (fsm.Event, fsm.EventPayload) {
    33  
    34  	retErr := func(err error) (fsm.Event, fsm.EventPayload) {
    35  		return ex.makeErrEvent(err, parseCmd.AST)
    36  	}
    37  
    38  	// The anonymous statement can be overwritten.
    39  	if parseCmd.Name != "" {
    40  		if _, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[parseCmd.Name]; ok {
    41  			err := pgerror.Newf(
    42  				pgcode.DuplicatePreparedStatement,
    43  				"prepared statement %q already exists", parseCmd.Name,
    44  			)
    45  			return retErr(err)
    46  		}
    47  	} else {
    48  		// Deallocate the unnamed statement, if it exists.
    49  		ex.deletePreparedStmt(ctx, "")
    50  	}
    51  
    52  	ps, err := ex.addPreparedStmt(
    53  		ctx,
    54  		parseCmd.Name,
    55  		Statement{Statement: parseCmd.Statement},
    56  		parseCmd.TypeHints,
    57  		PreparedStatementOriginWire,
    58  	)
    59  	if err != nil {
    60  		return retErr(err)
    61  	}
    62  
    63  	// Convert the inferred SQL types back to an array of pgwire Oids.
    64  	if len(ps.TypeHints) > pgwirebase.MaxPreparedStatementArgs {
    65  		return retErr(
    66  			pgwirebase.NewProtocolViolationErrorf(
    67  				"more than %d arguments to prepared statement: %d",
    68  				pgwirebase.MaxPreparedStatementArgs, len(ps.TypeHints)))
    69  	}
    70  	inferredTypes := make([]oid.Oid, len(ps.Types))
    71  	copy(inferredTypes, parseCmd.RawTypeHints)
    72  
    73  	for i := range ps.Types {
    74  		// OID to Datum is not a 1-1 mapping (for example, int4 and int8
    75  		// both map to TypeInt), so we need to maintain the types sent by
    76  		// the client.
    77  		if inferredTypes[i] == 0 {
    78  			t, _ := ps.ValueType(tree.PlaceholderIdx(i))
    79  			inferredTypes[i] = t.Oid()
    80  		}
    81  	}
    82  	// Remember the inferred placeholder types so they can be reported on
    83  	// Describe.
    84  	ps.InferredTypes = inferredTypes
    85  	return nil, nil
    86  }
    87  
    88  // addPreparedStmt creates a new PreparedStatement with the provided name using
    89  // the given query. The new prepared statement is added to the connExecutor and
    90  // also returned. It is illegal to call this when a statement with that name
    91  // already exists (even for anonymous prepared statements).
    92  //
    93  // placeholderHints are used to assist in inferring placeholder types.
    94  func (ex *connExecutor) addPreparedStmt(
    95  	ctx context.Context,
    96  	name string,
    97  	stmt Statement,
    98  	placeholderHints tree.PlaceholderTypes,
    99  	origin PreparedStatementOrigin,
   100  ) (*PreparedStatement, error) {
   101  	if _, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[name]; ok {
   102  		panic(fmt.Sprintf("prepared statement already exists: %q", name))
   103  	}
   104  
   105  	// Prepare the query. This completes the typing of placeholders.
   106  	prepared, err := ex.prepare(ctx, stmt, placeholderHints, origin)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	if err := prepared.memAcc.Grow(ctx, int64(len(name))); err != nil {
   112  		return nil, err
   113  	}
   114  	ex.extraTxnState.prepStmtsNamespace.prepStmts[name] = prepared
   115  	return prepared, nil
   116  }
   117  
// prepare prepares the given statement and returns the resulting
// PreparedStatement. It does not register the result in any namespace; see
// addPreparedStmt for that.
//
// placeholderHints may contain partial type information for placeholders.
// prepare will populate the missing types. It can be nil, in which case a
// hint slice of length stmt.NumPlaceholders is allocated.
//
// If stmt has no AST, a PreparedStatement is still returned, without any
// planning having been done. On error the returned PreparedStatement is nil.
// The returned PreparedStatement needs to be close()d once it's no longer in
// use.
func (ex *connExecutor) prepare(
	ctx context.Context,
	stmt Statement,
	placeholderHints tree.PlaceholderTypes,
	origin PreparedStatementOrigin,
) (*PreparedStatement, error) {
	if placeholderHints == nil {
		placeholderHints = make(tree.PlaceholderTypes, stmt.NumPlaceholders)
	}

	prepared := &PreparedStatement{
		PrepareMetadata: sqlbase.PrepareMetadata{
			PlaceholderTypesInfo: tree.PlaceholderTypesInfo{
				TypeHints: placeholderHints,
			},
		},
		memAcc:   ex.sessionMon.MakeBoundAccount(),
		refCount: 1,

		createdAt: timeutil.Now(),
		origin:    origin,
	}
	// NB: if we start caching the plan, we'll want to keep around the memory
	// account used for the plan, rather than clearing it.
	//
	// NOTE(review): this deferred Clear also runs after the
	// Grow(prepared.MemoryEstimate()) at the end of this function. As a result
	// that estimate appears to be released as soon as prepare returns, and only
	// what callers grow afterwards (e.g. len(name) in addPreparedStmt) stays
	// charged. Confirm whether that is intended.
	defer prepared.memAcc.Clear(ctx)

	if stmt.AST == nil {
		return prepared, nil
	}
	prepared.Statement = stmt.Statement

	// Point to the prepared state, which can be further populated during query
	// preparation. stmt is this function's own copy, and the planner set up
	// below is given a pointer to it.
	stmt.Prepared = prepared

	if err := tree.ProcessPlaceholderAnnotations(&ex.planner.semaCtx, stmt.AST, placeholderHints); err != nil {
		return nil, err
	}

	// Preparing needs a transaction because it needs to retrieve db/table
	// descriptors for type checking. If we already have an open transaction for
	// this planner, use it. Using the user's transaction here is critical for
	// proper deadlock detection. At the time of writing, it is the case that any
	// data read on behalf of this transaction is not cached for use in other
	// transactions. It's critical that this fact remain true but nothing really
	// enforces it. If there is no open transaction, preparation runs inside a
	// new one created by ex.server.cfg.DB.Txn, which presumably finishes it
	// before returning.

	var flags planFlags
	// prepare resets ex.planner for txn, points it at stmt, and runs
	// populatePrepared, capturing the resulting plan flags. It resets all
	// planner state on every call, because under DB.Txn it may be invoked again
	// after a retriable error.
	prepare := func(ctx context.Context, txn *kv.Txn) (err error) {
		ex.statsCollector.reset(&ex.server.sqlStats, ex.appStats, &ex.phaseTimes)
		p := &ex.planner
		ex.resetPlanner(ctx, p, txn, ex.server.cfg.Clock.PhysicalTime() /* stmtTS */)
		p.stmt = &stmt
		p.semaCtx.Annotations = tree.MakeAnnotations(stmt.NumAnnotations)
		flags, err = ex.populatePrepared(ctx, txn, placeholderHints, p)
		return err
	}

	if txn := ex.state.mu.txn; txn != nil && txn.IsOpen() {
		// Use the existing transaction.
		if err := prepare(ctx, txn); err != nil {
			return nil, err
		}
	} else {
		// Use a new transaction. This will handle retriable errors here rather
		// than bubbling them up to the connExecutor state machine.
		if err := ex.server.cfg.DB.Txn(ctx, prepare); err != nil {
			return nil, err
		}
	}

	// Account for the memory used by this prepared statement.
	if err := prepared.memAcc.Grow(ctx, prepared.MemoryEstimate()); err != nil {
		return nil, err
	}
	ex.updateOptCounters(flags)
	return prepared, nil
}
   204  
   205  // populatePrepared analyzes and type-checks the query and populates
   206  // stmt.Prepared.
   207  func (ex *connExecutor) populatePrepared(
   208  	ctx context.Context, txn *kv.Txn, placeholderHints tree.PlaceholderTypes, p *planner,
   209  ) (planFlags, error) {
   210  	if before := ex.server.cfg.TestingKnobs.BeforePrepare; before != nil {
   211  		if err := before(ctx, ex.planner.stmt.String(), txn); err != nil {
   212  			return 0, err
   213  		}
   214  	}
   215  	stmt := p.stmt
   216  	if err := p.semaCtx.Placeholders.Init(stmt.NumPlaceholders, placeholderHints); err != nil {
   217  		return 0, err
   218  	}
   219  	p.extendedEvalCtx.PrepareOnly = true
   220  
   221  	protoTS, err := p.isAsOf(ctx, stmt.AST)
   222  	if err != nil {
   223  		return 0, err
   224  	}
   225  	if protoTS != nil {
   226  		p.semaCtx.AsOfTimestamp = protoTS
   227  		txn.SetFixedTimestamp(ctx, *protoTS)
   228  	}
   229  
   230  	// PREPARE has a limited subset of statements it can be run with. Postgres
   231  	// only allows SELECT, INSERT, UPDATE, DELETE and VALUES statements to be
   232  	// prepared.
   233  	// See: https://www.postgresql.org/docs/current/static/sql-prepare.html
   234  	// However, we allow a large number of additional statements.
   235  	// As of right now, the optimizer only works on SELECT statements and will
   236  	// fallback for all others, so this should be safe for the foreseeable
   237  	// future.
   238  	flags, err := p.prepareUsingOptimizer(ctx)
   239  	if err != nil {
   240  		log.VEventf(ctx, 1, "optimizer prepare failed: %v", err)
   241  		return 0, err
   242  	}
   243  	log.VEvent(ctx, 2, "optimizer prepare succeeded")
   244  	// stmt.Prepared fields have been populated.
   245  	return flags, nil
   246  }
   247  
   248  func (ex *connExecutor) execBind(
   249  	ctx context.Context, bindCmd BindStmt,
   250  ) (fsm.Event, fsm.EventPayload) {
   251  
   252  	retErr := func(err error) (fsm.Event, fsm.EventPayload) {
   253  		return eventNonRetriableErr{IsCommit: fsm.False}, eventNonRetriableErrPayload{err: err}
   254  	}
   255  
   256  	portalName := bindCmd.PortalName
   257  	// The unnamed portal can be freely overwritten.
   258  	if portalName != "" {
   259  		if _, ok := ex.extraTxnState.prepStmtsNamespace.portals[portalName]; ok {
   260  			return retErr(pgerror.Newf(
   261  				pgcode.DuplicateCursor, "portal %q already exists", portalName))
   262  		}
   263  	} else {
   264  		// Deallocate the unnamed portal, if it exists.
   265  		ex.deletePortal(ctx, "")
   266  	}
   267  
   268  	ps, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[bindCmd.PreparedStatementName]
   269  	if !ok {
   270  		return retErr(pgerror.Newf(
   271  			pgcode.InvalidSQLStatementName,
   272  			"unknown prepared statement %q", bindCmd.PreparedStatementName))
   273  	}
   274  
   275  	numQArgs := uint16(len(ps.InferredTypes))
   276  
   277  	// Decode the arguments, except for internal queries for which we just verify
   278  	// that the arguments match what's expected.
   279  	qargs := make(tree.QueryArguments, numQArgs)
   280  	if bindCmd.internalArgs != nil {
   281  		if len(bindCmd.internalArgs) != int(numQArgs) {
   282  			return retErr(
   283  				pgwirebase.NewProtocolViolationErrorf(
   284  					"expected %d arguments, got %d", numQArgs, len(bindCmd.internalArgs)))
   285  		}
   286  		for i, datum := range bindCmd.internalArgs {
   287  			t := ps.InferredTypes[i]
   288  			if oid := datum.ResolvedType().Oid(); datum != tree.DNull && oid != t {
   289  				return retErr(
   290  					pgwirebase.NewProtocolViolationErrorf(
   291  						"for argument %d expected OID %d, got %d", i, t, oid))
   292  			}
   293  			qargs[i] = datum
   294  		}
   295  	} else {
   296  		qArgFormatCodes := bindCmd.ArgFormatCodes
   297  
   298  		// If there is only one format code, then that format code is used to decode all the
   299  		// arguments. But if the number of format codes provided does not match the number of
   300  		// arguments AND it's not a single format code then we cannot infer what format to use to
   301  		// decode all of the arguments.
   302  		if len(qArgFormatCodes) != 1 && len(qArgFormatCodes) != int(numQArgs) {
   303  			return retErr(pgwirebase.NewProtocolViolationErrorf(
   304  				"wrong number of format codes specified: %d for %d arguments",
   305  				len(qArgFormatCodes), numQArgs))
   306  		}
   307  
   308  		// If a single format code is provided and there is more than one argument to be decoded,
   309  		// then expand qArgFormatCodes to the number of arguments provided.
   310  		// If the number of format codes matches the number of arguments then nothing needs to be
   311  		// done.
   312  		if len(qArgFormatCodes) == 1 && numQArgs > 1 {
   313  			fmtCode := qArgFormatCodes[0]
   314  			qArgFormatCodes = make([]pgwirebase.FormatCode, numQArgs)
   315  			for i := range qArgFormatCodes {
   316  				qArgFormatCodes[i] = fmtCode
   317  			}
   318  		}
   319  
   320  		if len(bindCmd.Args) != int(numQArgs) {
   321  			return retErr(
   322  				pgwirebase.NewProtocolViolationErrorf(
   323  					"expected %d arguments, got %d", numQArgs, len(bindCmd.Args)))
   324  		}
   325  
   326  		ptCtx := tree.NewParseTimeContext(ex.state.sqlTimestamp.In(ex.sessionData.DataConversion.Location))
   327  
   328  		for i, arg := range bindCmd.Args {
   329  			k := tree.PlaceholderIdx(i)
   330  			t := ps.InferredTypes[i]
   331  			if arg == nil {
   332  				// nil indicates a NULL argument value.
   333  				qargs[k] = tree.DNull
   334  			} else {
   335  				d, err := pgwirebase.DecodeOidDatum(ptCtx, t, qArgFormatCodes[i], arg)
   336  				if err != nil {
   337  					return retErr(pgerror.Wrapf(err, pgcode.ProtocolViolation,
   338  						"error in argument for %s", k))
   339  				}
   340  				qargs[k] = d
   341  			}
   342  		}
   343  	}
   344  
   345  	numCols := len(ps.Columns)
   346  	if (len(bindCmd.OutFormats) > 1) && (len(bindCmd.OutFormats) != numCols) {
   347  		return retErr(pgwirebase.NewProtocolViolationErrorf(
   348  			"expected 1 or %d for number of format codes, got %d",
   349  			numCols, len(bindCmd.OutFormats)))
   350  	}
   351  
   352  	columnFormatCodes := bindCmd.OutFormats
   353  	if len(bindCmd.OutFormats) == 1 && numCols > 1 {
   354  		// Apply the format code to every column.
   355  		columnFormatCodes = make([]pgwirebase.FormatCode, numCols)
   356  		for i := 0; i < numCols; i++ {
   357  			columnFormatCodes[i] = bindCmd.OutFormats[0]
   358  		}
   359  	}
   360  
   361  	// Create the new PreparedPortal.
   362  	if err := ex.addPortal(
   363  		ctx, portalName, bindCmd.PreparedStatementName, ps, qargs, columnFormatCodes,
   364  	); err != nil {
   365  		return retErr(err)
   366  	}
   367  
   368  	if log.V(2) {
   369  		log.Infof(ctx, "portal: %q for %q, args %q, formats %q",
   370  			portalName, ps.Statement, qargs, columnFormatCodes)
   371  	}
   372  
   373  	return nil, nil
   374  }
   375  
   376  // addPortal creates a new PreparedPortal on the connExecutor.
   377  //
   378  // It is illegal to call this when a portal with that name already exists (even
   379  // for anonymous portals).
   380  func (ex *connExecutor) addPortal(
   381  	ctx context.Context,
   382  	portalName string,
   383  	psName string,
   384  	stmt *PreparedStatement,
   385  	qargs tree.QueryArguments,
   386  	outFormats []pgwirebase.FormatCode,
   387  ) error {
   388  	if _, ok := ex.extraTxnState.prepStmtsNamespace.portals[portalName]; ok {
   389  		panic(fmt.Sprintf("portal already exists: %q", portalName))
   390  	}
   391  
   392  	portal, err := ex.newPreparedPortal(ctx, portalName, stmt, qargs, outFormats)
   393  	if err != nil {
   394  		return err
   395  	}
   396  
   397  	ex.extraTxnState.prepStmtsNamespace.portals[portalName] = portal
   398  	return nil
   399  }
   400  
   401  func (ex *connExecutor) deletePreparedStmt(ctx context.Context, name string) {
   402  	ps, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[name]
   403  	if !ok {
   404  		return
   405  	}
   406  	ps.decRef(ctx)
   407  	delete(ex.extraTxnState.prepStmtsNamespace.prepStmts, name)
   408  }
   409  
   410  func (ex *connExecutor) deletePortal(ctx context.Context, name string) {
   411  	portal, ok := ex.extraTxnState.prepStmtsNamespace.portals[name]
   412  	if !ok {
   413  		return
   414  	}
   415  	portal.decRef(ctx)
   416  	delete(ex.extraTxnState.prepStmtsNamespace.portals, name)
   417  }
   418  
   419  func (ex *connExecutor) execDelPrepStmt(
   420  	ctx context.Context, delCmd DeletePreparedStmt,
   421  ) (fsm.Event, fsm.EventPayload) {
   422  	switch delCmd.Type {
   423  	case pgwirebase.PrepareStatement:
   424  		_, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[delCmd.Name]
   425  		if !ok {
   426  			// The spec says "It is not an error to issue Close against a nonexistent
   427  			// statement or portal name". See
   428  			// https://www.postgresql.org/docs/current/static/protocol-flow.html.
   429  			break
   430  		}
   431  
   432  		ex.deletePreparedStmt(ctx, delCmd.Name)
   433  	case pgwirebase.PreparePortal:
   434  		_, ok := ex.extraTxnState.prepStmtsNamespace.portals[delCmd.Name]
   435  		if !ok {
   436  			break
   437  		}
   438  		ex.deletePortal(ctx, delCmd.Name)
   439  	default:
   440  		panic(fmt.Sprintf("unknown del type: %s", delCmd.Type))
   441  	}
   442  	return nil, nil
   443  }
   444  
   445  func (ex *connExecutor) execDescribe(
   446  	ctx context.Context, descCmd DescribeStmt, res DescribeResult,
   447  ) (fsm.Event, fsm.EventPayload) {
   448  
   449  	retErr := func(err error) (fsm.Event, fsm.EventPayload) {
   450  		return eventNonRetriableErr{IsCommit: fsm.False}, eventNonRetriableErrPayload{err: err}
   451  	}
   452  
   453  	switch descCmd.Type {
   454  	case pgwirebase.PrepareStatement:
   455  		ps, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[descCmd.Name]
   456  		if !ok {
   457  			return retErr(pgerror.Newf(
   458  				pgcode.InvalidSQLStatementName,
   459  				"unknown prepared statement %q", descCmd.Name))
   460  		}
   461  
   462  		res.SetInferredTypes(ps.InferredTypes)
   463  
   464  		if stmtHasNoData(ps.AST) {
   465  			res.SetNoDataRowDescription()
   466  		} else {
   467  			res.SetPrepStmtOutput(ctx, ps.Columns)
   468  		}
   469  	case pgwirebase.PreparePortal:
   470  		portal, ok := ex.extraTxnState.prepStmtsNamespace.portals[descCmd.Name]
   471  		if !ok {
   472  			return retErr(pgerror.Newf(
   473  				pgcode.InvalidCursorName, "unknown portal %q", descCmd.Name))
   474  		}
   475  
   476  		if stmtHasNoData(portal.Stmt.AST) {
   477  			res.SetNoDataRowDescription()
   478  		} else {
   479  			res.SetPortalOutput(ctx, portal.Stmt.Columns, portal.OutFormats)
   480  		}
   481  	default:
   482  		return retErr(errors.AssertionFailedf(
   483  			"unknown describe type: %s", errors.Safe(descCmd.Type)))
   484  	}
   485  	return nil, nil
   486  }