github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/conn_executor_prepare.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package sql 12 13 import ( 14 "context" 15 "fmt" 16 17 "github.com/cockroachdb/cockroach/pkg/kv" 18 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" 19 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" 20 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" 21 "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" 22 "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" 23 "github.com/cockroachdb/cockroach/pkg/util/fsm" 24 "github.com/cockroachdb/cockroach/pkg/util/log" 25 "github.com/cockroachdb/cockroach/pkg/util/timeutil" 26 "github.com/cockroachdb/errors" 27 "github.com/lib/pq/oid" 28 ) 29 30 func (ex *connExecutor) execPrepare( 31 ctx context.Context, parseCmd PrepareStmt, 32 ) (fsm.Event, fsm.EventPayload) { 33 34 retErr := func(err error) (fsm.Event, fsm.EventPayload) { 35 return ex.makeErrEvent(err, parseCmd.AST) 36 } 37 38 // The anonymous statement can be overwritten. 39 if parseCmd.Name != "" { 40 if _, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[parseCmd.Name]; ok { 41 err := pgerror.Newf( 42 pgcode.DuplicatePreparedStatement, 43 "prepared statement %q already exists", parseCmd.Name, 44 ) 45 return retErr(err) 46 } 47 } else { 48 // Deallocate the unnamed statement, if it exists. 49 ex.deletePreparedStmt(ctx, "") 50 } 51 52 ps, err := ex.addPreparedStmt( 53 ctx, 54 parseCmd.Name, 55 Statement{Statement: parseCmd.Statement}, 56 parseCmd.TypeHints, 57 PreparedStatementOriginWire, 58 ) 59 if err != nil { 60 return retErr(err) 61 } 62 63 // Convert the inferred SQL types back to an array of pgwire Oids. 64 if len(ps.TypeHints) > pgwirebase.MaxPreparedStatementArgs { 65 return retErr( 66 pgwirebase.NewProtocolViolationErrorf( 67 "more than %d arguments to prepared statement: %d", 68 pgwirebase.MaxPreparedStatementArgs, len(ps.TypeHints))) 69 } 70 inferredTypes := make([]oid.Oid, len(ps.Types)) 71 copy(inferredTypes, parseCmd.RawTypeHints) 72 73 for i := range ps.Types { 74 // OID to Datum is not a 1-1 mapping (for example, int4 and int8 75 // both map to TypeInt), so we need to maintain the types sent by 76 // the client. 77 if inferredTypes[i] == 0 { 78 t, _ := ps.ValueType(tree.PlaceholderIdx(i)) 79 inferredTypes[i] = t.Oid() 80 } 81 } 82 // Remember the inferred placeholder types so they can be reported on 83 // Describe. 84 ps.InferredTypes = inferredTypes 85 return nil, nil 86 } 87 88 // addPreparedStmt creates a new PreparedStatement with the provided name using 89 // the given query. The new prepared statement is added to the connExecutor and 90 // also returned. It is illegal to call this when a statement with that name 91 // already exists (even for anonymous prepared statements). 92 // 93 // placeholderHints are used to assist in inferring placeholder types. 94 func (ex *connExecutor) addPreparedStmt( 95 ctx context.Context, 96 name string, 97 stmt Statement, 98 placeholderHints tree.PlaceholderTypes, 99 origin PreparedStatementOrigin, 100 ) (*PreparedStatement, error) { 101 if _, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[name]; ok { 102 panic(fmt.Sprintf("prepared statement already exists: %q", name)) 103 } 104 105 // Prepare the query. This completes the typing of placeholders. 106 prepared, err := ex.prepare(ctx, stmt, placeholderHints, origin) 107 if err != nil { 108 return nil, err 109 } 110 111 if err := prepared.memAcc.Grow(ctx, int64(len(name))); err != nil { 112 return nil, err 113 } 114 ex.extraTxnState.prepStmtsNamespace.prepStmts[name] = prepared 115 return prepared, nil 116 } 117 118 // prepare prepares the given statement. 119 // 120 // placeholderHints may contain partial type information for placeholders. 121 // prepare will populate the missing types. It can be nil. 122 // 123 // The PreparedStatement is returned (or nil if there are no results). The 124 // returned PreparedStatement needs to be close()d once its no longer in use. 125 func (ex *connExecutor) prepare( 126 ctx context.Context, 127 stmt Statement, 128 placeholderHints tree.PlaceholderTypes, 129 origin PreparedStatementOrigin, 130 ) (*PreparedStatement, error) { 131 if placeholderHints == nil { 132 placeholderHints = make(tree.PlaceholderTypes, stmt.NumPlaceholders) 133 } 134 135 prepared := &PreparedStatement{ 136 PrepareMetadata: sqlbase.PrepareMetadata{ 137 PlaceholderTypesInfo: tree.PlaceholderTypesInfo{ 138 TypeHints: placeholderHints, 139 }, 140 }, 141 memAcc: ex.sessionMon.MakeBoundAccount(), 142 refCount: 1, 143 144 createdAt: timeutil.Now(), 145 origin: origin, 146 } 147 // NB: if we start caching the plan, we'll want to keep around the memory 148 // account used for the plan, rather than clearing it. 149 defer prepared.memAcc.Clear(ctx) 150 151 if stmt.AST == nil { 152 return prepared, nil 153 } 154 prepared.Statement = stmt.Statement 155 156 // Point to the prepared state, which can be further populated during query 157 // preparation. 158 stmt.Prepared = prepared 159 160 if err := tree.ProcessPlaceholderAnnotations(&ex.planner.semaCtx, stmt.AST, placeholderHints); err != nil { 161 return nil, err 162 } 163 164 // Preparing needs a transaction because it needs to retrieve db/table 165 // descriptors for type checking. If we already have an open transaction for 166 // this planner, use it. Using the user's transaction here is critical for 167 // proper deadlock detection. At the time of writing, it is the case that any 168 // data read on behalf of this transaction is not cached for use in other 169 // transactions. It's critical that this fact remain true but nothing really 170 // enforces it. If we create a new transaction (newTxn is true), we'll need to 171 // finish it before we return. 172 173 var flags planFlags 174 prepare := func(ctx context.Context, txn *kv.Txn) (err error) { 175 ex.statsCollector.reset(&ex.server.sqlStats, ex.appStats, &ex.phaseTimes) 176 p := &ex.planner 177 ex.resetPlanner(ctx, p, txn, ex.server.cfg.Clock.PhysicalTime() /* stmtTS */) 178 p.stmt = &stmt 179 p.semaCtx.Annotations = tree.MakeAnnotations(stmt.NumAnnotations) 180 flags, err = ex.populatePrepared(ctx, txn, placeholderHints, p) 181 return err 182 } 183 184 if txn := ex.state.mu.txn; txn != nil && txn.IsOpen() { 185 // Use the existing transaction. 186 if err := prepare(ctx, txn); err != nil { 187 return nil, err 188 } 189 } else { 190 // Use a new transaction. This will handle retriable errors here rather 191 // than bubbling them up to the connExecutor state machine. 192 if err := ex.server.cfg.DB.Txn(ctx, prepare); err != nil { 193 return nil, err 194 } 195 } 196 197 // Account for the memory used by this prepared statement. 198 if err := prepared.memAcc.Grow(ctx, prepared.MemoryEstimate()); err != nil { 199 return nil, err 200 } 201 ex.updateOptCounters(flags) 202 return prepared, nil 203 } 204 205 // populatePrepared analyzes and type-checks the query and populates 206 // stmt.Prepared. 207 func (ex *connExecutor) populatePrepared( 208 ctx context.Context, txn *kv.Txn, placeholderHints tree.PlaceholderTypes, p *planner, 209 ) (planFlags, error) { 210 if before := ex.server.cfg.TestingKnobs.BeforePrepare; before != nil { 211 if err := before(ctx, ex.planner.stmt.String(), txn); err != nil { 212 return 0, err 213 } 214 } 215 stmt := p.stmt 216 if err := p.semaCtx.Placeholders.Init(stmt.NumPlaceholders, placeholderHints); err != nil { 217 return 0, err 218 } 219 p.extendedEvalCtx.PrepareOnly = true 220 221 protoTS, err := p.isAsOf(ctx, stmt.AST) 222 if err != nil { 223 return 0, err 224 } 225 if protoTS != nil { 226 p.semaCtx.AsOfTimestamp = protoTS 227 txn.SetFixedTimestamp(ctx, *protoTS) 228 } 229 230 // PREPARE has a limited subset of statements it can be run with. Postgres 231 // only allows SELECT, INSERT, UPDATE, DELETE and VALUES statements to be 232 // prepared. 233 // See: https://www.postgresql.org/docs/current/static/sql-prepare.html 234 // However, we allow a large number of additional statements. 235 // As of right now, the optimizer only works on SELECT statements and will 236 // fallback for all others, so this should be safe for the foreseeable 237 // future. 238 flags, err := p.prepareUsingOptimizer(ctx) 239 if err != nil { 240 log.VEventf(ctx, 1, "optimizer prepare failed: %v", err) 241 return 0, err 242 } 243 log.VEvent(ctx, 2, "optimizer prepare succeeded") 244 // stmt.Prepared fields have been populated. 245 return flags, nil 246 } 247 248 func (ex *connExecutor) execBind( 249 ctx context.Context, bindCmd BindStmt, 250 ) (fsm.Event, fsm.EventPayload) { 251 252 retErr := func(err error) (fsm.Event, fsm.EventPayload) { 253 return eventNonRetriableErr{IsCommit: fsm.False}, eventNonRetriableErrPayload{err: err} 254 } 255 256 portalName := bindCmd.PortalName 257 // The unnamed portal can be freely overwritten. 258 if portalName != "" { 259 if _, ok := ex.extraTxnState.prepStmtsNamespace.portals[portalName]; ok { 260 return retErr(pgerror.Newf( 261 pgcode.DuplicateCursor, "portal %q already exists", portalName)) 262 } 263 } else { 264 // Deallocate the unnamed portal, if it exists. 265 ex.deletePortal(ctx, "") 266 } 267 268 ps, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[bindCmd.PreparedStatementName] 269 if !ok { 270 return retErr(pgerror.Newf( 271 pgcode.InvalidSQLStatementName, 272 "unknown prepared statement %q", bindCmd.PreparedStatementName)) 273 } 274 275 numQArgs := uint16(len(ps.InferredTypes)) 276 277 // Decode the arguments, except for internal queries for which we just verify 278 // that the arguments match what's expected. 279 qargs := make(tree.QueryArguments, numQArgs) 280 if bindCmd.internalArgs != nil { 281 if len(bindCmd.internalArgs) != int(numQArgs) { 282 return retErr( 283 pgwirebase.NewProtocolViolationErrorf( 284 "expected %d arguments, got %d", numQArgs, len(bindCmd.internalArgs))) 285 } 286 for i, datum := range bindCmd.internalArgs { 287 t := ps.InferredTypes[i] 288 if oid := datum.ResolvedType().Oid(); datum != tree.DNull && oid != t { 289 return retErr( 290 pgwirebase.NewProtocolViolationErrorf( 291 "for argument %d expected OID %d, got %d", i, t, oid)) 292 } 293 qargs[i] = datum 294 } 295 } else { 296 qArgFormatCodes := bindCmd.ArgFormatCodes 297 298 // If there is only one format code, then that format code is used to decode all the 299 // arguments. But if the number of format codes provided does not match the number of 300 // arguments AND it's not a single format code then we cannot infer what format to use to 301 // decode all of the arguments. 302 if len(qArgFormatCodes) != 1 && len(qArgFormatCodes) != int(numQArgs) { 303 return retErr(pgwirebase.NewProtocolViolationErrorf( 304 "wrong number of format codes specified: %d for %d arguments", 305 len(qArgFormatCodes), numQArgs)) 306 } 307 308 // If a single format code is provided and there is more than one argument to be decoded, 309 // then expand qArgFormatCodes to the number of arguments provided. 310 // If the number of format codes matches the number of arguments then nothing needs to be 311 // done. 312 if len(qArgFormatCodes) == 1 && numQArgs > 1 { 313 fmtCode := qArgFormatCodes[0] 314 qArgFormatCodes = make([]pgwirebase.FormatCode, numQArgs) 315 for i := range qArgFormatCodes { 316 qArgFormatCodes[i] = fmtCode 317 } 318 } 319 320 if len(bindCmd.Args) != int(numQArgs) { 321 return retErr( 322 pgwirebase.NewProtocolViolationErrorf( 323 "expected %d arguments, got %d", numQArgs, len(bindCmd.Args))) 324 } 325 326 ptCtx := tree.NewParseTimeContext(ex.state.sqlTimestamp.In(ex.sessionData.DataConversion.Location)) 327 328 for i, arg := range bindCmd.Args { 329 k := tree.PlaceholderIdx(i) 330 t := ps.InferredTypes[i] 331 if arg == nil { 332 // nil indicates a NULL argument value. 333 qargs[k] = tree.DNull 334 } else { 335 d, err := pgwirebase.DecodeOidDatum(ptCtx, t, qArgFormatCodes[i], arg) 336 if err != nil { 337 return retErr(pgerror.Wrapf(err, pgcode.ProtocolViolation, 338 "error in argument for %s", k)) 339 } 340 qargs[k] = d 341 } 342 } 343 } 344 345 numCols := len(ps.Columns) 346 if (len(bindCmd.OutFormats) > 1) && (len(bindCmd.OutFormats) != numCols) { 347 return retErr(pgwirebase.NewProtocolViolationErrorf( 348 "expected 1 or %d for number of format codes, got %d", 349 numCols, len(bindCmd.OutFormats))) 350 } 351 352 columnFormatCodes := bindCmd.OutFormats 353 if len(bindCmd.OutFormats) == 1 && numCols > 1 { 354 // Apply the format code to every column. 355 columnFormatCodes = make([]pgwirebase.FormatCode, numCols) 356 for i := 0; i < numCols; i++ { 357 columnFormatCodes[i] = bindCmd.OutFormats[0] 358 } 359 } 360 361 // Create the new PreparedPortal. 362 if err := ex.addPortal( 363 ctx, portalName, bindCmd.PreparedStatementName, ps, qargs, columnFormatCodes, 364 ); err != nil { 365 return retErr(err) 366 } 367 368 if log.V(2) { 369 log.Infof(ctx, "portal: %q for %q, args %q, formats %q", 370 portalName, ps.Statement, qargs, columnFormatCodes) 371 } 372 373 return nil, nil 374 } 375 376 // addPortal creates a new PreparedPortal on the connExecutor. 377 // 378 // It is illegal to call this when a portal with that name already exists (even 379 // for anonymous portals). 380 func (ex *connExecutor) addPortal( 381 ctx context.Context, 382 portalName string, 383 psName string, 384 stmt *PreparedStatement, 385 qargs tree.QueryArguments, 386 outFormats []pgwirebase.FormatCode, 387 ) error { 388 if _, ok := ex.extraTxnState.prepStmtsNamespace.portals[portalName]; ok { 389 panic(fmt.Sprintf("portal already exists: %q", portalName)) 390 } 391 392 portal, err := ex.newPreparedPortal(ctx, portalName, stmt, qargs, outFormats) 393 if err != nil { 394 return err 395 } 396 397 ex.extraTxnState.prepStmtsNamespace.portals[portalName] = portal 398 return nil 399 } 400 401 func (ex *connExecutor) deletePreparedStmt(ctx context.Context, name string) { 402 ps, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[name] 403 if !ok { 404 return 405 } 406 ps.decRef(ctx) 407 delete(ex.extraTxnState.prepStmtsNamespace.prepStmts, name) 408 } 409 410 func (ex *connExecutor) deletePortal(ctx context.Context, name string) { 411 portal, ok := ex.extraTxnState.prepStmtsNamespace.portals[name] 412 if !ok { 413 return 414 } 415 portal.decRef(ctx) 416 delete(ex.extraTxnState.prepStmtsNamespace.portals, name) 417 } 418 419 func (ex *connExecutor) execDelPrepStmt( 420 ctx context.Context, delCmd DeletePreparedStmt, 421 ) (fsm.Event, fsm.EventPayload) { 422 switch delCmd.Type { 423 case pgwirebase.PrepareStatement: 424 _, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[delCmd.Name] 425 if !ok { 426 // The spec says "It is not an error to issue Close against a nonexistent 427 // statement or portal name". See 428 // https://www.postgresql.org/docs/current/static/protocol-flow.html. 429 break 430 } 431 432 ex.deletePreparedStmt(ctx, delCmd.Name) 433 case pgwirebase.PreparePortal: 434 _, ok := ex.extraTxnState.prepStmtsNamespace.portals[delCmd.Name] 435 if !ok { 436 break 437 } 438 ex.deletePortal(ctx, delCmd.Name) 439 default: 440 panic(fmt.Sprintf("unknown del type: %s", delCmd.Type)) 441 } 442 return nil, nil 443 } 444 445 func (ex *connExecutor) execDescribe( 446 ctx context.Context, descCmd DescribeStmt, res DescribeResult, 447 ) (fsm.Event, fsm.EventPayload) { 448 449 retErr := func(err error) (fsm.Event, fsm.EventPayload) { 450 return eventNonRetriableErr{IsCommit: fsm.False}, eventNonRetriableErrPayload{err: err} 451 } 452 453 switch descCmd.Type { 454 case pgwirebase.PrepareStatement: 455 ps, ok := ex.extraTxnState.prepStmtsNamespace.prepStmts[descCmd.Name] 456 if !ok { 457 return retErr(pgerror.Newf( 458 pgcode.InvalidSQLStatementName, 459 "unknown prepared statement %q", descCmd.Name)) 460 } 461 462 res.SetInferredTypes(ps.InferredTypes) 463 464 if stmtHasNoData(ps.AST) { 465 res.SetNoDataRowDescription() 466 } else { 467 res.SetPrepStmtOutput(ctx, ps.Columns) 468 } 469 case pgwirebase.PreparePortal: 470 portal, ok := ex.extraTxnState.prepStmtsNamespace.portals[descCmd.Name] 471 if !ok { 472 return retErr(pgerror.Newf( 473 pgcode.InvalidCursorName, "unknown portal %q", descCmd.Name)) 474 } 475 476 if stmtHasNoData(portal.Stmt.AST) { 477 res.SetNoDataRowDescription() 478 } else { 479 res.SetPortalOutput(ctx, portal.Stmt.Columns, portal.OutFormats) 480 } 481 default: 482 return retErr(errors.AssertionFailedf( 483 "unknown describe type: %s", errors.Safe(descCmd.Type))) 484 } 485 return nil, nil 486 }