github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/pgwire/conn.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package pgwire 12 13 import ( 14 "bufio" 15 "bytes" 16 "context" 17 "fmt" 18 "io" 19 "net" 20 "strconv" 21 "sync" 22 "sync/atomic" 23 "time" 24 25 "github.com/cockroachdb/cockroach/pkg/security" 26 "github.com/cockroachdb/cockroach/pkg/settings" 27 "github.com/cockroachdb/cockroach/pkg/sql" 28 "github.com/cockroachdb/cockroach/pkg/sql/parser" 29 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" 30 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" 31 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" 32 "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" 33 "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" 34 "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" 35 "github.com/cockroachdb/cockroach/pkg/sql/sqltelemetry" 36 "github.com/cockroachdb/cockroach/pkg/sql/types" 37 "github.com/cockroachdb/cockroach/pkg/util/log" 38 "github.com/cockroachdb/cockroach/pkg/util/mon" 39 "github.com/cockroachdb/cockroach/pkg/util/stop" 40 "github.com/cockroachdb/cockroach/pkg/util/timeutil" 41 "github.com/cockroachdb/cockroach/pkg/util/tracing" 42 "github.com/cockroachdb/errors" 43 "github.com/cockroachdb/logtags" 44 "github.com/lib/pq/oid" 45 ) 46 47 // conn implements a pgwire network connection (version 3 of the protocol, 48 // implemented by Postgres v7.4 and later). conn.serve() reads protocol 49 // messages, transforms them into commands that it pushes onto a StmtBuf (where 50 // they'll be picked up and executed by the connExecutor). 
// The connExecutor produces results for the commands, which are delivered to
// the client through the sql.ClientComm interface, implemented by this conn
// (code is in command_result.go).
type conn struct {
	// conn is the network connection being served. serveImpl replaces it with a
	// wrapper (see newReadTimeoutConn) and closes it before returning.
	conn net.Conn

	// sessionArgs are the session parameters this connection was created with;
	// they are handed to the SQL server when the session is set up.
	sessionArgs sql.SessionArgs
	// metrics receives the byte counters updated while serving the connection.
	metrics *ServerMetrics

	// rd is a buffered reader consuming conn. All reads from conn go through
	// this.
	rd bufio.Reader

	// parser is used to avoid allocating a parser each time.
	parser parser.Parser

	// stmtBuf is populated with commands queued for execution by this conn.
	stmtBuf sql.StmtBuf

	// res is used to avoid allocations in the conn's ClientComm implementation.
	res commandResult

	// err is an error, accessed atomically. It represents any error encountered
	// while accessing the underlying network connection. This can be read via
	// GetErr() by anybody. If it is found to be != nil, the conn is no longer to
	// be used.
	err atomic.Value

	// writerState groups together all aspects of the write-side state of the
	// connection.
	writerState struct {
		// fi holds flush bookkeeping; newConn points fi.buf at buf below.
		fi flushInfo
		// buf contains command results (rows, etc.) until they're flushed to the
		// network connection.
		buf bytes.Buffer
		// tagBuf is scratch space; presumably used when formatting command tags
		// (its users are not in this part of the file) — TODO confirm.
		tagBuf [64]byte
	}

	// readBuf holds the protocol message currently being decoded.
	readBuf pgwirebase.ReadBuffer
	// msgBuilder is used to assemble outgoing protocol messages.
	msgBuilder writeBuffer

	// sv gives access to the cluster settings consulted by this conn (e.g. by
	// authLogEnabled and bufferNotice).
	sv *settings.Values

	// testingLogEnabled is used in unit tests in this package to
	// force-enable auth logging without dancing around the
	// asynchronicity of cluster settings.
	testingLogEnabled bool
}

// serveConn creates a conn that will serve the netConn. It returns once the
// network connection is closed.
//
// Internally, a connExecutor will be created to execute commands. Commands read
// from the network are buffered in a stmtBuf which is consumed by the
// connExecutor. The connExecutor produces results which are buffered and
// sometimes synchronously flushed to the network.
//
// The reader goroutine (this one) outlives the connExecutor's goroutine (the
// "processor goroutine").
// However, they can both signal each other to stop. Here's how the different
// cases work:
// 1) The reader receives a ClientMsgTerminate protocol packet: the reader
// closes the stmtBuf and also cancels the command processing context. These
// actions will prompt the command processor to finish.
// 2) The reader gets a read error from the network connection: like above, the
// reader closes the command processor.
// 3) The reader's context is canceled (happens when the server is draining but
// the connection was busy and hasn't quit yet): the reader notices the canceled
// context and, like above, closes the processor.
// 4) The processor encounters an error. This error can come from various fatal
// conditions encountered internally by the processor, or from a network
// communication error encountered while flushing results to the network.
// The processor will cancel the reader's context and terminate.
// Note that query processing errors are different; they don't cause the
// termination of the connection.
//
// Draining notes:
//
// The reader notices that the server is draining by polling the draining()
// closure passed to serveImpl. At that point, the reader delegates the
// responsibility of closing the connection to the statement processor: it will
// push a DrainRequest to the stmtBuf, which signals the processor to quit ASAP.
// The processor will quit immediately upon seeing that command if it's not
// currently in a transaction. If it is in a transaction, it will wait until the
// first time a Sync command is processed outside of a transaction - the logic
// being that we want to stop when we're both outside transactions and outside
// batches.
138 func (s *Server) serveConn( 139 ctx context.Context, 140 netConn net.Conn, 141 sArgs sql.SessionArgs, 142 reserved mon.BoundAccount, 143 authOpt authOptions, 144 ) { 145 sArgs.RemoteAddr = netConn.RemoteAddr() 146 147 if log.V(2) { 148 log.Infof(ctx, "new connection with options: %+v", sArgs) 149 } 150 151 c := newConn(netConn, sArgs, &s.metrics, &s.execCfg.Settings.SV) 152 c.testingLogEnabled = atomic.LoadInt32(&s.testingLogEnabled) > 0 153 154 // Do the reading of commands from the network. 155 c.serveImpl(ctx, s.IsDraining, s.SQLServer, reserved, authOpt, s.stopper) 156 } 157 158 func newConn( 159 netConn net.Conn, sArgs sql.SessionArgs, metrics *ServerMetrics, sv *settings.Values, 160 ) *conn { 161 c := &conn{ 162 conn: netConn, 163 sessionArgs: sArgs, 164 metrics: metrics, 165 rd: *bufio.NewReader(netConn), 166 sv: sv, 167 } 168 c.stmtBuf.Init() 169 c.res.released = true 170 c.writerState.fi.buf = &c.writerState.buf 171 c.writerState.fi.lastFlushed = -1 172 c.writerState.fi.cmdStarts = make(map[sql.CmdPos]int) 173 c.msgBuilder.init(metrics.BytesOutCount) 174 175 return c 176 } 177 178 func (c *conn) setErr(err error) { 179 c.err.Store(err) 180 } 181 182 func (c *conn) GetErr() error { 183 err := c.err.Load() 184 if err != nil { 185 return err.(error) 186 } 187 return nil 188 } 189 190 func (c *conn) authLogEnabled() bool { 191 return c.testingLogEnabled || logSessionAuth.Get(c.sv) 192 } 193 194 // serveImpl continuously reads from the network connection and pushes execution 195 // instructions into a sql.StmtBuf, from where they'll be processed by a command 196 // "processor" goroutine (a connExecutor). 197 // The method returns when the pgwire termination message is received, when 198 // network communication fails, when the server is draining or when ctx is 199 // canceled (which also happens when draining (but not from the get-go), and 200 // when the processor encounters a fatal error). 
201 // 202 // serveImpl always closes the network connection before returning. 203 // 204 // sqlServer is used to create the command processor. As a special facility for 205 // tests, sqlServer can be nil, in which case the command processor and the 206 // write-side of the connection will not be created. 207 func (c *conn) serveImpl( 208 ctx context.Context, 209 draining func() bool, 210 sqlServer *sql.Server, 211 reserved mon.BoundAccount, 212 authOpt authOptions, 213 stopper *stop.Stopper, 214 ) { 215 defer func() { _ = c.conn.Close() }() 216 217 ctx = logtags.AddTag(ctx, "user", c.sessionArgs.User) 218 219 inTestWithoutSQL := sqlServer == nil 220 var authLogger *log.SecondaryLogger 221 if !inTestWithoutSQL { 222 authLogger = sqlServer.GetExecutorConfig().AuthLogger 223 sessionStart := timeutil.Now() 224 defer func() { 225 if c.authLogEnabled() { 226 authLogger.Logf(ctx, "session terminated; duration: %s", timeutil.Now().Sub(sessionStart)) 227 } 228 }() 229 } 230 231 // NOTE: We're going to write a few messages to the connection in this method, 232 // for the handshake. After that, all writes are done async, in the 233 // startWriter() goroutine. 234 235 ctx, cancelConn := context.WithCancel(ctx) 236 defer cancelConn() // This calms the linter that wants these callbacks to always be called. 237 238 var sentDrainSignal bool 239 // The net.Conn is switched to a conn that exits if the ctx is canceled. 240 c.conn = newReadTimeoutConn(c.conn, func() error { 241 // If the context was canceled, it's time to stop reading. Either a 242 // higher-level server or the command processor have canceled us. 243 if ctx.Err() != nil { 244 return ctx.Err() 245 } 246 // If the server is draining, we'll let the processor know by pushing a 247 // DrainRequest. This will make the processor quit whenever it finds a good 248 // time. 
249 if !sentDrainSignal && draining() { 250 _ /* err */ = c.stmtBuf.Push(ctx, sql.DrainRequest{}) 251 sentDrainSignal = true 252 } 253 return nil 254 }) 255 c.rd = *bufio.NewReader(c.conn) 256 257 // the authPipe below logs authentication messages iff its auth 258 // logger is non-nil. We define this here. 259 var sessionAuthLogger *log.SecondaryLogger 260 if !inTestWithoutSQL && c.authLogEnabled() { 261 sessionAuthLogger = authLogger 262 } 263 264 // We'll build an authPipe to communicate with the authentication process. 265 authPipe := newAuthPipe(c, sessionAuthLogger) 266 var authenticator authenticatorIO = authPipe 267 268 // procCh is the channel on which we'll receive the termination signal from 269 // the command processor. 270 var procCh <-chan error 271 272 if sqlServer != nil { 273 // Spawn the command processing goroutine, which also handles connection 274 // authentication). It will notify us when it's done through procCh, and 275 // we'll also interact with the authentication process through ac. 276 var ac AuthConn = authPipe 277 procCh = c.processCommandsAsync(ctx, authOpt, ac, sqlServer, reserved, cancelConn) 278 } else { 279 // sqlServer == nil means we are in a local test. In this case 280 // we only need the minimum to make pgx happy. 281 var err error 282 for param, value := range testingStatusReportParams { 283 if err := c.sendParamStatus(param, value); err != nil { 284 break 285 } 286 } 287 if err != nil { 288 reserved.Close(ctx) 289 return 290 } 291 var ac AuthConn = authPipe 292 // Simulate auth succeeding. 293 ac.AuthOK(fixedIntSizer{size: types.Int}) 294 dummyCh := make(chan error) 295 close(dummyCh) 296 procCh = dummyCh 297 // An initial readyForQuery message is part of the handshake. 
298 c.msgBuilder.initMsg(pgwirebase.ServerMsgReady) 299 c.msgBuilder.writeByte(byte(sql.IdleTxnBlock)) 300 if err := c.msgBuilder.finishMsg(c.conn); err != nil { 301 reserved.Close(ctx) 302 return 303 } 304 } 305 306 var err error 307 var terminateSeen bool 308 var doingExtendedQueryMessage bool 309 310 // We need an intSizer, which we're ultimately going to get from the 311 // authenticator once authentication succeeds (because it will actually be a 312 // ConnectionHandler). Until then, we unfortunately still need some intSizer 313 // because we technically might enqueue parsed statements in the statement 314 // buffer even before authentication succeeds (because we need this go routine 315 // to keep reading from the network connection while authentication is in 316 // progress in order to react to the connection closing). 317 var intSizer unqualifiedIntSizer = fixedIntSizer{size: types.Int} 318 var authDone bool 319 Loop: 320 for { 321 var typ pgwirebase.ClientMessageType 322 var n int 323 typ, n, err = c.readBuf.ReadTypedMsg(&c.rd) 324 c.metrics.BytesInCount.Inc(int64(n)) 325 if err != nil { 326 break Loop 327 } 328 timeReceived := timeutil.Now() 329 log.VEventf(ctx, 2, "pgwire: processing %s", typ) 330 331 if !authDone { 332 if typ == pgwirebase.ClientMsgPassword { 333 var pwd []byte 334 if pwd, err = c.readBuf.GetBytes(n - 4); err != nil { 335 break Loop 336 } 337 // Pass the data to the authenticator. This hopefully causes it to finish 338 // authentication in the background and give us an intSizer when we loop 339 // around. 340 if err = authenticator.sendPwdData(pwd); err != nil { 341 break Loop 342 } 343 continue 344 } 345 // Wait for the auth result. 346 intSizer, err = authenticator.authResult() 347 if err != nil { 348 // The error has already been sent to the client. 
349 break Loop 350 } else { 351 authDone = true 352 } 353 } 354 355 switch typ { 356 case pgwirebase.ClientMsgPassword: 357 // This messages are only acceptable during the auth phase, handled above. 358 err = pgwirebase.NewProtocolViolationErrorf("unexpected authentication data") 359 _ /* err */ = writeErr( 360 ctx, &sqlServer.GetExecutorConfig().Settings.SV, err, 361 &c.msgBuilder, &c.writerState.buf) 362 break Loop 363 case pgwirebase.ClientMsgSimpleQuery: 364 if doingExtendedQueryMessage { 365 if err = c.stmtBuf.Push( 366 ctx, 367 sql.SendError{ 368 Err: pgwirebase.NewProtocolViolationErrorf( 369 "SimpleQuery not allowed while in extended protocol mode"), 370 }, 371 ); err != nil { 372 break 373 } 374 } 375 if err = c.handleSimpleQuery( 376 ctx, &c.readBuf, timeReceived, intSizer.GetUnqualifiedIntSize(), 377 ); err != nil { 378 break 379 } 380 err = c.stmtBuf.Push(ctx, sql.Sync{}) 381 382 case pgwirebase.ClientMsgExecute: 383 doingExtendedQueryMessage = true 384 err = c.handleExecute(ctx, &c.readBuf, timeReceived) 385 386 case pgwirebase.ClientMsgParse: 387 doingExtendedQueryMessage = true 388 err = c.handleParse(ctx, &c.readBuf, intSizer.GetUnqualifiedIntSize()) 389 390 case pgwirebase.ClientMsgDescribe: 391 doingExtendedQueryMessage = true 392 err = c.handleDescribe(ctx, &c.readBuf) 393 394 case pgwirebase.ClientMsgBind: 395 doingExtendedQueryMessage = true 396 err = c.handleBind(ctx, &c.readBuf) 397 398 case pgwirebase.ClientMsgClose: 399 doingExtendedQueryMessage = true 400 err = c.handleClose(ctx, &c.readBuf) 401 402 case pgwirebase.ClientMsgTerminate: 403 terminateSeen = true 404 break Loop 405 406 case pgwirebase.ClientMsgSync: 407 doingExtendedQueryMessage = false 408 // We're starting a batch here. If the client continues using the extended 409 // protocol and encounters an error, everything until the next sync 410 // message has to be skipped. 
See: 411 // https://www.postgresql.org/docs/current/10/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY 412 413 err = c.stmtBuf.Push(ctx, sql.Sync{}) 414 415 case pgwirebase.ClientMsgFlush: 416 doingExtendedQueryMessage = true 417 err = c.handleFlush(ctx) 418 419 case pgwirebase.ClientMsgCopyData, pgwirebase.ClientMsgCopyDone, pgwirebase.ClientMsgCopyFail: 420 // We're supposed to ignore these messages, per the protocol spec. This 421 // state will happen when an error occurs on the server-side during a copy 422 // operation: the server will send an error and a ready message back to 423 // the client, and must then ignore further copy messages. See: 424 // https://github.com/postgres/postgres/blob/6e1dd2773eb60a6ab87b27b8d9391b756e904ac3/src/backend/tcop/postgres.c#L4295 425 426 default: 427 err = c.stmtBuf.Push( 428 ctx, 429 sql.SendError{Err: pgwirebase.NewUnrecognizedMsgTypeErr(typ)}) 430 } 431 if err != nil { 432 break Loop 433 } 434 } 435 436 // We're done reading data from the client, so make the communication 437 // goroutine stop. Depending on what that goroutine is currently doing (or 438 // blocked on), we cancel and close all the possible channels to make sure we 439 // tickle it in the right way. 440 441 // Signal command processing to stop. It might be the case that the processor 442 // canceled our context and that's how we got here; in that case, this will 443 // be a no-op. 444 c.stmtBuf.Close() 445 // Cancel the processor's context. 446 cancelConn() 447 // In case the authenticator is blocked on waiting for data from the client, 448 // tell it that there's no more data coming. This is a no-op if authentication 449 // was completed already. 450 authenticator.noMorePwdData() 451 452 // Wait for the processor goroutine to finish, if it hasn't already. We're 453 // ignoring the error we get from it, as we have no use for it. 
It might be a 454 // connection error, or a context cancelation error case this goroutine is the 455 // one that triggered the execution to stop. 456 <-procCh 457 458 if terminateSeen { 459 return 460 } 461 // If we're draining, let the client know by piling on an AdminShutdownError 462 // and flushing the buffer. 463 if draining() { 464 // TODO(andrei): I think sending this extra error to the client if we also 465 // sent another error for the last query (like a context canceled) is a bad 466 // idead; see #22630. I think we should find a way to return the 467 // AdminShutdown error as the only result of the query. 468 _ /* err */ = writeErr(ctx, &sqlServer.GetExecutorConfig().Settings.SV, 469 newAdminShutdownErr(ErrDrainingExistingConn), &c.msgBuilder, &c.writerState.buf) 470 _ /* n */, _ /* err */ = c.writerState.buf.WriteTo(c.conn) 471 } 472 } 473 474 // unqualifiedIntSizer is used by a conn to get the SQL session's current int size 475 // setting. 476 // 477 // It's a restriction on the ConnectionHandler type. 478 type unqualifiedIntSizer interface { 479 // GetUnqualifiedIntSize returns the size that the parser should consider for an 480 // unqualified INT. 481 GetUnqualifiedIntSize() *types.T 482 } 483 484 type fixedIntSizer struct { 485 size *types.T 486 } 487 488 func (f fixedIntSizer) GetUnqualifiedIntSize() *types.T { 489 return f.size 490 } 491 492 // processCommandsAsync spawns a goroutine that authenticates the connection and 493 // then processes commands from c.stmtBuf. 494 // 495 // It returns a channel that will be signaled when this goroutine is done. 496 // Whatever error is returned on that channel has already been written to the 497 // client connection, if applicable. 498 // 499 // If authentication fails, this goroutine finishes and, as always, cancelConn 500 // is called. 501 // 502 // Args: 503 // ac: An interface used by the authentication process to receive password data 504 // and to ultimately declare the authentication successful. 
505 // reserved: Reserved memory. This method takes ownership. 506 // cancelConn: A function to be called when this goroutine exits. Its goal is to 507 // cancel the connection's context, thus stopping the connection's goroutine. 508 // The returned channel is also closed before this goroutine dies, but the 509 // connection's goroutine is not expected to be reading from that channel 510 // (instead, it's expected to always be monitoring the network connection). 511 func (c *conn) processCommandsAsync( 512 ctx context.Context, 513 authOpt authOptions, 514 ac AuthConn, 515 sqlServer *sql.Server, 516 reserved mon.BoundAccount, 517 cancelConn func(), 518 ) <-chan error { 519 // reservedOwned is true while we own reserved, false when we pass ownership 520 // away. 521 reservedOwned := true 522 retCh := make(chan error, 1) 523 go func() { 524 var retErr error 525 var connHandler sql.ConnectionHandler 526 var authOK bool 527 var connCloseAuthHandler func() 528 defer func() { 529 // Release resources, if we still own them. 530 if reservedOwned { 531 reserved.Close(ctx) 532 } 533 // Notify the connection's goroutine that we're terminating. The 534 // connection might know already, as it might have triggered this 535 // goroutine's finish, but it also might be us that we're triggering the 536 // connection's death. This context cancelation serves to interrupt a 537 // network read on the connection's goroutine. 538 cancelConn() 539 540 pgwireKnobs := sqlServer.GetExecutorConfig().PGWireTestingKnobs 541 if pgwireKnobs != nil && pgwireKnobs.CatchPanics { 542 if r := recover(); r != nil { 543 // Catch the panic and return it to the client as an error. 544 if err, ok := r.(error); ok { 545 // Mask the cause but keep the details. 546 retErr = errors.Handled(err) 547 } else { 548 retErr = errors.Newf("%+v", r) 549 } 550 retErr = pgerror.WithCandidateCode(retErr, pgcode.CrashShutdown) 551 // Add a prefix. This also adds a stack trace. 
552 retErr = errors.Wrap(retErr, "caught fatal error") 553 _ = writeErr( 554 ctx, &sqlServer.GetExecutorConfig().Settings.SV, retErr, 555 &c.msgBuilder, &c.writerState.buf) 556 _ /* n */, _ /* err */ = c.writerState.buf.WriteTo(c.conn) 557 c.stmtBuf.Close() 558 // Send a ready for query to make sure the client can react. 559 // TODO(andrei, jordan): Why are we sending this exactly? 560 c.bufferReadyForQuery('I') 561 } 562 } 563 if !authOK { 564 ac.AuthFail(retErr) 565 } 566 if connCloseAuthHandler != nil { 567 connCloseAuthHandler() 568 } 569 // Inform the connection goroutine of success or failure. 570 retCh <- retErr 571 }() 572 573 // Authenticate the connection. 574 if connCloseAuthHandler, retErr = c.handleAuthentication( 575 ctx, ac, authOpt, sqlServer.GetExecutorConfig(), 576 ); retErr != nil { 577 // Auth failed or some other error. 578 return 579 } 580 581 // Inform the client of the default session settings. 582 connHandler, retErr = c.sendInitialConnData(ctx, sqlServer) 583 if retErr != nil { 584 return 585 } 586 // Signal the connection was established to the authenticator. 587 ac.AuthOK(connHandler) 588 // Mark the authentication as succeeded in case a panic 589 // is thrown below and we need to report to the client 590 // using the defer above. 591 authOK = true 592 593 // Now actually process commands. 594 reservedOwned = false // We're about to pass ownership away. 
595 retErr = sqlServer.ServeConn(ctx, connHandler, reserved, cancelConn) 596 }() 597 return retCh 598 } 599 600 func (c *conn) sendParamStatus(param, value string) error { 601 c.msgBuilder.initMsg(pgwirebase.ServerMsgParameterStatus) 602 c.msgBuilder.writeTerminatedString(param) 603 c.msgBuilder.writeTerminatedString(value) 604 return c.msgBuilder.finishMsg(c.conn) 605 } 606 607 func (c *conn) bufferParamStatus(param, value string) error { 608 c.msgBuilder.initMsg(pgwirebase.ServerMsgParameterStatus) 609 c.msgBuilder.writeTerminatedString(param) 610 c.msgBuilder.writeTerminatedString(value) 611 return c.msgBuilder.finishMsg(&c.writerState.buf) 612 } 613 614 func (c *conn) bufferNotice(ctx context.Context, noticeErr error) error { 615 c.msgBuilder.initMsg(pgwirebase.ServerMsgNoticeResponse) 616 return writeErrFields(ctx, c.sv, noticeErr, &c.msgBuilder, &c.writerState.buf) 617 } 618 619 func (c *conn) sendInitialConnData( 620 ctx context.Context, sqlServer *sql.Server, 621 ) (sql.ConnectionHandler, error) { 622 connHandler, err := sqlServer.SetupConn( 623 ctx, c.sessionArgs, &c.stmtBuf, c, c.metrics.SQLMemMetrics) 624 if err != nil { 625 _ /* err */ = writeErr( 626 ctx, &sqlServer.GetExecutorConfig().Settings.SV, err, &c.msgBuilder, c.conn) 627 return sql.ConnectionHandler{}, err 628 } 629 630 // Send the initial "status parameters" to the client. This 631 // overlaps partially with session variables. The client wants to 632 // see the values that result from the combination of server-side 633 // defaults with client-provided values. 634 // For details see: https://www.postgresql.org/docs/10/static/libpq-status.html 635 for _, param := range statusReportParams { 636 param := param 637 value := connHandler.GetParamStatus(ctx, param) 638 if err := c.sendParamStatus(param, value); err != nil { 639 return sql.ConnectionHandler{}, err 640 } 641 } 642 // The two following status parameters have no equivalent session 643 // variable. 
644 if err := c.sendParamStatus("session_authorization", c.sessionArgs.User); err != nil { 645 return sql.ConnectionHandler{}, err 646 } 647 648 // TODO(knz): this should retrieve the admin status during 649 // authentication using the roles table, instead of using a 650 // simple/naive username match. 651 isSuperUser := c.sessionArgs.User == security.RootUser 652 superUserVal := "off" 653 if isSuperUser { 654 superUserVal = "on" 655 } 656 if err := c.sendParamStatus("is_superuser", superUserVal); err != nil { 657 return sql.ConnectionHandler{}, err 658 } 659 660 // An initial readyForQuery message is part of the handshake. 661 c.msgBuilder.initMsg(pgwirebase.ServerMsgReady) 662 c.msgBuilder.writeByte(byte(sql.IdleTxnBlock)) 663 if err := c.msgBuilder.finishMsg(c.conn); err != nil { 664 return sql.ConnectionHandler{}, err 665 } 666 return connHandler, nil 667 } 668 669 // An error is returned iff the statement buffer has been closed. In that case, 670 // the connection should be considered toast. 671 func (c *conn) handleSimpleQuery( 672 ctx context.Context, 673 buf *pgwirebase.ReadBuffer, 674 timeReceived time.Time, 675 unqualifiedIntSize *types.T, 676 ) error { 677 query, err := buf.GetString() 678 if err != nil { 679 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 680 } 681 682 tracing.AnnotateTrace() 683 684 startParse := timeutil.Now() 685 stmts, err := c.parser.ParseWithInt(query, unqualifiedIntSize) 686 if err != nil { 687 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 688 } 689 endParse := timeutil.Now() 690 691 if len(stmts) == 0 { 692 return c.stmtBuf.Push( 693 ctx, sql.ExecStmt{ 694 Statement: parser.Statement{}, 695 TimeReceived: timeReceived, 696 ParseStart: startParse, 697 ParseEnd: endParse, 698 }) 699 } 700 701 for i := range stmts { 702 // The CopyFrom statement is special. 
We need to detect it so we can hand 703 // control of the connection, through the stmtBuf, to a copyMachine, and 704 // block this network routine until control is passed back. 705 if cp, ok := stmts[i].AST.(*tree.CopyFrom); ok { 706 if len(stmts) != 1 { 707 // NOTE(andrei): I don't know if Postgres supports receiving a COPY 708 // together with other statements in the "simple" protocol, but I'd 709 // rather not worry about it since execution of COPY is special - it 710 // takes control over the connection. 711 return c.stmtBuf.Push( 712 ctx, 713 sql.SendError{ 714 Err: pgwirebase.NewProtocolViolationErrorf( 715 "COPY together with other statements in a query string is not supported"), 716 }) 717 } 718 copyDone := sync.WaitGroup{} 719 copyDone.Add(1) 720 if err := c.stmtBuf.Push(ctx, sql.CopyIn{Conn: c, Stmt: cp, CopyDone: ©Done}); err != nil { 721 return err 722 } 723 copyDone.Wait() 724 return nil 725 } 726 727 if err := c.stmtBuf.Push( 728 ctx, 729 sql.ExecStmt{ 730 Statement: stmts[i], 731 TimeReceived: timeReceived, 732 ParseStart: startParse, 733 ParseEnd: endParse, 734 }); err != nil { 735 return err 736 } 737 } 738 return nil 739 } 740 741 // An error is returned iff the statement buffer has been closed. In that case, 742 // the connection should be considered toast. 743 func (c *conn) handleParse( 744 ctx context.Context, buf *pgwirebase.ReadBuffer, nakedIntSize *types.T, 745 ) error { 746 name, err := buf.GetString() 747 if err != nil { 748 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 749 } 750 query, err := buf.GetString() 751 if err != nil { 752 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 753 } 754 // The client may provide type information for (some of) the placeholders. 
755 numQArgTypes, err := buf.GetUint16() 756 if err != nil { 757 return err 758 } 759 inTypeHints := make([]oid.Oid, numQArgTypes) 760 for i := range inTypeHints { 761 typ, err := buf.GetUint32() 762 if err != nil { 763 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 764 } 765 inTypeHints[i] = oid.Oid(typ) 766 } 767 768 startParse := timeutil.Now() 769 stmts, err := c.parser.ParseWithInt(query, nakedIntSize) 770 if err != nil { 771 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 772 } 773 if len(stmts) > 1 { 774 err := pgerror.WrongNumberOfPreparedStatements(len(stmts)) 775 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 776 } 777 var stmt parser.Statement 778 if len(stmts) == 1 { 779 stmt = stmts[0] 780 } 781 // len(stmts) == 0 results in a nil (empty) statement. 782 783 if len(inTypeHints) > stmt.NumPlaceholders { 784 err := pgwirebase.NewProtocolViolationErrorf( 785 "received too many type hints: %d vs %d placeholders in query", 786 len(inTypeHints), stmt.NumPlaceholders, 787 ) 788 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 789 } 790 791 var sqlTypeHints tree.PlaceholderTypes 792 if len(inTypeHints) > 0 { 793 // Prepare the mapping of SQL placeholder names to types. Pre-populate it with 794 // the type hints received from the client, if any. 795 sqlTypeHints = make(tree.PlaceholderTypes, stmt.NumPlaceholders) 796 for i, t := range inTypeHints { 797 if t == 0 { 798 continue 799 } 800 v, ok := types.OidToType[t] 801 if !ok { 802 err := pgwirebase.NewProtocolViolationErrorf("unknown oid type: %v", t) 803 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 804 } 805 sqlTypeHints[i] = v 806 } 807 } 808 809 endParse := timeutil.Now() 810 811 if _, ok := stmt.AST.(*tree.CopyFrom); ok { 812 // We don't support COPY in extended protocol because it'd be complicated: 813 // it wouldn't be the preparing, but the execution that would need to 814 // execute the copyMachine. 
815 // Also note that COPY FROM in extended mode seems to be quite broken in 816 // Postgres too: 817 // https://www.postgresql.org/message-id/flat/CAMsr%2BYGvp2wRx9pPSxaKFdaObxX8DzWse%2BOkWk2xpXSvT0rq-g%40mail.gmail.com#CAMsr+YGvp2wRx9pPSxaKFdaObxX8DzWse+OkWk2xpXSvT0rq-g@mail.gmail.com 818 return c.stmtBuf.Push(ctx, sql.SendError{Err: fmt.Errorf("CopyFrom not supported in extended protocol mode")}) 819 } 820 821 return c.stmtBuf.Push( 822 ctx, 823 sql.PrepareStmt{ 824 Name: name, 825 Statement: stmt, 826 TypeHints: sqlTypeHints, 827 RawTypeHints: inTypeHints, 828 ParseStart: startParse, 829 ParseEnd: endParse, 830 }) 831 } 832 833 // An error is returned iff the statement buffer has been closed. In that case, 834 // the connection should be considered toast. 835 func (c *conn) handleDescribe(ctx context.Context, buf *pgwirebase.ReadBuffer) error { 836 typ, err := buf.GetPrepareType() 837 if err != nil { 838 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 839 } 840 name, err := buf.GetString() 841 if err != nil { 842 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 843 } 844 return c.stmtBuf.Push( 845 ctx, 846 sql.DescribeStmt{ 847 Name: name, 848 Type: typ, 849 }) 850 } 851 852 // An error is returned iff the statement buffer has been closed. In that case, 853 // the connection should be considered toast. 854 func (c *conn) handleClose(ctx context.Context, buf *pgwirebase.ReadBuffer) error { 855 typ, err := buf.GetPrepareType() 856 if err != nil { 857 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 858 } 859 name, err := buf.GetString() 860 if err != nil { 861 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 862 } 863 return c.stmtBuf.Push( 864 ctx, 865 sql.DeletePreparedStmt{ 866 Name: name, 867 Type: typ, 868 }) 869 } 870 871 // If no format codes are provided then all arguments/result-columns use 872 // the default format, text. 
873 var formatCodesAllText = []pgwirebase.FormatCode{pgwirebase.FormatText} 874 875 // handleBind queues instructions for creating a portal from a prepared 876 // statement. 877 // An error is returned iff the statement buffer has been closed. In that case, 878 // the connection should be considered toast. 879 func (c *conn) handleBind(ctx context.Context, buf *pgwirebase.ReadBuffer) error { 880 portalName, err := buf.GetString() 881 if err != nil { 882 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 883 } 884 statementName, err := buf.GetString() 885 if err != nil { 886 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 887 } 888 889 // From the docs on number of argument format codes to bind: 890 // This can be zero to indicate that there are no arguments or that the 891 // arguments all use the default format (text); or one, in which case the 892 // specified format code is applied to all arguments; or it can equal the 893 // actual number of arguments. 894 // http://www.postgresql.org/docs/current/static/protocol-message-formats.html 895 numQArgFormatCodes, err := buf.GetUint16() 896 if err != nil { 897 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 898 } 899 var qArgFormatCodes []pgwirebase.FormatCode 900 switch numQArgFormatCodes { 901 case 0: 902 // No format codes means all arguments are passed as text. 903 qArgFormatCodes = formatCodesAllText 904 case 1: 905 // `1` means read one code and apply it to every argument. 906 ch, err := buf.GetUint16() 907 if err != nil { 908 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 909 } 910 code := pgwirebase.FormatCode(ch) 911 if code == pgwirebase.FormatText { 912 qArgFormatCodes = formatCodesAllText 913 } else { 914 qArgFormatCodes = []pgwirebase.FormatCode{code} 915 } 916 default: 917 qArgFormatCodes = make([]pgwirebase.FormatCode, numQArgFormatCodes) 918 // Read one format code for each argument and apply it to that argument. 
919 for i := range qArgFormatCodes { 920 ch, err := buf.GetUint16() 921 if err != nil { 922 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 923 } 924 qArgFormatCodes[i] = pgwirebase.FormatCode(ch) 925 } 926 } 927 928 numValues, err := buf.GetUint16() 929 if err != nil { 930 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 931 } 932 qargs := make([][]byte, numValues) 933 for i := 0; i < int(numValues); i++ { 934 plen, err := buf.GetUint32() 935 if err != nil { 936 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 937 } 938 if int32(plen) == -1 { 939 // The argument is a NULL value. 940 qargs[i] = nil 941 continue 942 } 943 b, err := buf.GetBytes(int(plen)) 944 if err != nil { 945 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 946 } 947 qargs[i] = b 948 } 949 950 // From the docs on number of result-column format codes to bind: 951 // This can be zero to indicate that there are no result columns or that 952 // the result columns should all use the default format (text); or one, in 953 // which case the specified format code is applied to all result columns 954 // (if any); or it can equal the actual number of result columns of the 955 // query. 956 // http://www.postgresql.org/docs/current/static/protocol-message-formats.html 957 numColumnFormatCodes, err := buf.GetUint16() 958 if err != nil { 959 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 960 } 961 var columnFormatCodes []pgwirebase.FormatCode 962 switch numColumnFormatCodes { 963 case 0: 964 // All columns will use the text format. 965 columnFormatCodes = formatCodesAllText 966 case 1: 967 // All columns will use the one specified format. 
968 ch, err := buf.GetUint16() 969 if err != nil { 970 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 971 } 972 code := pgwirebase.FormatCode(ch) 973 if code == pgwirebase.FormatText { 974 columnFormatCodes = formatCodesAllText 975 } else { 976 columnFormatCodes = []pgwirebase.FormatCode{code} 977 } 978 default: 979 columnFormatCodes = make([]pgwirebase.FormatCode, numColumnFormatCodes) 980 // Read one format code for each column and apply it to that column. 981 for i := range columnFormatCodes { 982 ch, err := buf.GetUint16() 983 if err != nil { 984 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 985 } 986 columnFormatCodes[i] = pgwirebase.FormatCode(ch) 987 } 988 } 989 return c.stmtBuf.Push( 990 ctx, 991 sql.BindStmt{ 992 PreparedStatementName: statementName, 993 PortalName: portalName, 994 Args: qargs, 995 ArgFormatCodes: qArgFormatCodes, 996 OutFormats: columnFormatCodes, 997 }) 998 } 999 1000 // An error is returned iff the statement buffer has been closed. In that case, 1001 // the connection should be considered toast. 1002 func (c *conn) handleExecute( 1003 ctx context.Context, buf *pgwirebase.ReadBuffer, timeReceived time.Time, 1004 ) error { 1005 portalName, err := buf.GetString() 1006 if err != nil { 1007 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 1008 } 1009 limit, err := buf.GetUint32() 1010 if err != nil { 1011 return c.stmtBuf.Push(ctx, sql.SendError{Err: err}) 1012 } 1013 return c.stmtBuf.Push(ctx, sql.ExecPortal{ 1014 Name: portalName, 1015 TimeReceived: timeReceived, 1016 Limit: int(limit), 1017 }) 1018 } 1019 1020 func (c *conn) handleFlush(ctx context.Context) error { 1021 return c.stmtBuf.Push(ctx, sql.Flush{}) 1022 } 1023 1024 // BeginCopyIn is part of the pgwirebase.Conn interface. 
1025 func (c *conn) BeginCopyIn(ctx context.Context, columns []sqlbase.ResultColumn) error { 1026 c.msgBuilder.initMsg(pgwirebase.ServerMsgCopyInResponse) 1027 c.msgBuilder.writeByte(byte(pgwirebase.FormatText)) 1028 c.msgBuilder.putInt16(int16(len(columns))) 1029 for range columns { 1030 c.msgBuilder.putInt16(int16(pgwirebase.FormatText)) 1031 } 1032 return c.msgBuilder.finishMsg(c.conn) 1033 } 1034 1035 // SendCommandComplete is part of the pgwirebase.Conn interface. 1036 func (c *conn) SendCommandComplete(tag []byte) error { 1037 c.bufferCommandComplete(tag) 1038 return nil 1039 } 1040 1041 // Rd is part of the pgwirebase.Conn interface. 1042 func (c *conn) Rd() pgwirebase.BufferedReader { 1043 return &pgwireReader{conn: c} 1044 } 1045 1046 // flushInfo encapsulates information about what results have been flushed to 1047 // the network. 1048 type flushInfo struct { 1049 // buf is a reference to writerState.buf. 1050 buf *bytes.Buffer 1051 // lastFlushed indicates the highest command for which results have been 1052 // flushed. The command may have further results in the buffer that haven't 1053 // been flushed. 1054 lastFlushed sql.CmdPos 1055 // map from CmdPos to the index of the buffer where the results for the 1056 // respective result begins. 1057 cmdStarts map[sql.CmdPos]int 1058 } 1059 1060 // registerCmd updates cmdStarts when the first result for a new command is 1061 // received. 1062 func (fi *flushInfo) registerCmd(pos sql.CmdPos) { 1063 if _, ok := fi.cmdStarts[pos]; ok { 1064 return 1065 } 1066 fi.cmdStarts[pos] = fi.buf.Len() 1067 } 1068 1069 func cookTag(tagStr string, buf []byte, stmtType tree.StatementType, rowsAffected int) []byte { 1070 if tagStr == "INSERT" { 1071 // From the postgres docs (49.5. Message Formats): 1072 // `INSERT oid rows`... oid is the object ID of the inserted row if 1073 // rows is 1 and the target table has OIDs; otherwise oid is 0. 1074 tagStr = "INSERT 0" 1075 } 1076 tag := append(buf, tagStr...) 
1077 1078 switch stmtType { 1079 case tree.RowsAffected: 1080 tag = append(tag, ' ') 1081 tag = strconv.AppendInt(tag, int64(rowsAffected), 10) 1082 1083 case tree.Rows: 1084 tag = append(tag, ' ') 1085 tag = strconv.AppendUint(tag, uint64(rowsAffected), 10) 1086 1087 case tree.Ack, tree.DDL: 1088 if tagStr == "SELECT" { 1089 tag = append(tag, ' ') 1090 tag = strconv.AppendInt(tag, int64(rowsAffected), 10) 1091 } 1092 1093 case tree.CopyIn: 1094 // Nothing to do. The CommandComplete message has been sent elsewhere. 1095 panic(fmt.Sprintf("CopyIn statements should have been handled elsewhere " + 1096 "and not produce results")) 1097 default: 1098 panic(fmt.Sprintf("unexpected result type %v", stmtType)) 1099 } 1100 1101 return tag 1102 } 1103 1104 // bufferRow serializes a row and adds it to the buffer. 1105 // 1106 // formatCodes describes the desired encoding for each column. It can be nil, in 1107 // which case all columns are encoded using the text encoding. Otherwise, it 1108 // needs to contain an entry for every column. 
1109 func (c *conn) bufferRow( 1110 ctx context.Context, 1111 row tree.Datums, 1112 formatCodes []pgwirebase.FormatCode, 1113 conv sessiondata.DataConversionConfig, 1114 oids []oid.Oid, 1115 ) { 1116 c.msgBuilder.initMsg(pgwirebase.ServerMsgDataRow) 1117 c.msgBuilder.putInt16(int16(len(row))) 1118 for i, col := range row { 1119 fmtCode := pgwirebase.FormatText 1120 if formatCodes != nil { 1121 fmtCode = formatCodes[i] 1122 } 1123 switch fmtCode { 1124 case pgwirebase.FormatText: 1125 c.msgBuilder.writeTextDatum(ctx, col, conv) 1126 case pgwirebase.FormatBinary: 1127 c.msgBuilder.writeBinaryDatum(ctx, col, conv.Location, oids[i]) 1128 default: 1129 c.msgBuilder.setError(errors.Errorf("unsupported format code %s", fmtCode)) 1130 } 1131 } 1132 if err := c.msgBuilder.finishMsg(&c.writerState.buf); err != nil { 1133 panic(fmt.Sprintf("unexpected err from buffer: %s", err)) 1134 } 1135 } 1136 1137 func (c *conn) bufferReadyForQuery(txnStatus byte) { 1138 c.msgBuilder.initMsg(pgwirebase.ServerMsgReady) 1139 c.msgBuilder.writeByte(txnStatus) 1140 if err := c.msgBuilder.finishMsg(&c.writerState.buf); err != nil { 1141 panic(fmt.Sprintf("unexpected err from buffer: %s", err)) 1142 } 1143 } 1144 1145 func (c *conn) bufferParseComplete() { 1146 c.msgBuilder.initMsg(pgwirebase.ServerMsgParseComplete) 1147 if err := c.msgBuilder.finishMsg(&c.writerState.buf); err != nil { 1148 panic(fmt.Sprintf("unexpected err from buffer: %s", err)) 1149 } 1150 } 1151 1152 func (c *conn) bufferBindComplete() { 1153 c.msgBuilder.initMsg(pgwirebase.ServerMsgBindComplete) 1154 if err := c.msgBuilder.finishMsg(&c.writerState.buf); err != nil { 1155 panic(fmt.Sprintf("unexpected err from buffer: %s", err)) 1156 } 1157 } 1158 1159 func (c *conn) bufferCloseComplete() { 1160 c.msgBuilder.initMsg(pgwirebase.ServerMsgCloseComplete) 1161 if err := c.msgBuilder.finishMsg(&c.writerState.buf); err != nil { 1162 panic(fmt.Sprintf("unexpected err from buffer: %s", err)) 1163 } 1164 } 1165 1166 func (c *conn) 
bufferCommandComplete(tag []byte) { 1167 c.msgBuilder.initMsg(pgwirebase.ServerMsgCommandComplete) 1168 c.msgBuilder.write(tag) 1169 c.msgBuilder.nullTerminate() 1170 if err := c.msgBuilder.finishMsg(&c.writerState.buf); err != nil { 1171 panic(fmt.Sprintf("unexpected err from buffer: %s", err)) 1172 } 1173 } 1174 1175 func (c *conn) bufferPortalSuspended() { 1176 c.msgBuilder.initMsg(pgwirebase.ServerMsgPortalSuspended) 1177 if err := c.msgBuilder.finishMsg(&c.writerState.buf); err != nil { 1178 panic(fmt.Sprintf("unexpected err from buffer: %s", err)) 1179 } 1180 } 1181 1182 func (c *conn) bufferErr(ctx context.Context, err error) { 1183 if err := writeErr(ctx, c.sv, 1184 err, &c.msgBuilder, &c.writerState.buf); err != nil { 1185 panic(fmt.Sprintf("unexpected err from buffer: %s", err)) 1186 } 1187 } 1188 1189 func (c *conn) bufferEmptyQueryResponse() { 1190 c.msgBuilder.initMsg(pgwirebase.ServerMsgEmptyQuery) 1191 if err := c.msgBuilder.finishMsg(&c.writerState.buf); err != nil { 1192 panic(fmt.Sprintf("unexpected err from buffer: %s", err)) 1193 } 1194 } 1195 1196 func writeErr( 1197 ctx context.Context, sv *settings.Values, err error, msgBuilder *writeBuffer, w io.Writer, 1198 ) error { 1199 // Record telemetry for the error. 1200 sqltelemetry.RecordError(ctx, err, sv) 1201 msgBuilder.initMsg(pgwirebase.ServerMsgErrorResponse) 1202 return writeErrFields(ctx, sv, err, msgBuilder, w) 1203 } 1204 1205 func writeErrFields( 1206 ctx context.Context, sv *settings.Values, err error, msgBuilder *writeBuffer, w io.Writer, 1207 ) error { 1208 // Now send the error to the client. 
1209 pgErr := pgerror.Flatten(err) 1210 1211 msgBuilder.putErrFieldMsg(pgwirebase.ServerErrFieldSeverity) 1212 msgBuilder.writeTerminatedString(pgErr.Severity) 1213 1214 msgBuilder.putErrFieldMsg(pgwirebase.ServerErrFieldSQLState) 1215 msgBuilder.writeTerminatedString(pgErr.Code) 1216 1217 if pgErr.Detail != "" { 1218 msgBuilder.putErrFieldMsg(pgwirebase.ServerErrFileldDetail) 1219 msgBuilder.writeTerminatedString(pgErr.Detail) 1220 } 1221 1222 if pgErr.Hint != "" { 1223 msgBuilder.putErrFieldMsg(pgwirebase.ServerErrFileldHint) 1224 msgBuilder.writeTerminatedString(pgErr.Hint) 1225 } 1226 1227 if pgErr.Source != nil { 1228 errCtx := pgErr.Source 1229 if errCtx.File != "" { 1230 msgBuilder.putErrFieldMsg(pgwirebase.ServerErrFieldSrcFile) 1231 msgBuilder.writeTerminatedString(errCtx.File) 1232 } 1233 1234 if errCtx.Line > 0 { 1235 msgBuilder.putErrFieldMsg(pgwirebase.ServerErrFieldSrcLine) 1236 msgBuilder.writeTerminatedString(strconv.Itoa(int(errCtx.Line))) 1237 } 1238 1239 if errCtx.Function != "" { 1240 msgBuilder.putErrFieldMsg(pgwirebase.ServerErrFieldSrcFunction) 1241 msgBuilder.writeTerminatedString(errCtx.Function) 1242 } 1243 } 1244 1245 msgBuilder.putErrFieldMsg(pgwirebase.ServerErrFieldMsgPrimary) 1246 msgBuilder.writeTerminatedString(pgErr.Message) 1247 1248 msgBuilder.nullTerminate() 1249 return msgBuilder.finishMsg(w) 1250 } 1251 1252 func (c *conn) bufferParamDesc(types []oid.Oid) { 1253 c.msgBuilder.initMsg(pgwirebase.ServerMsgParameterDescription) 1254 c.msgBuilder.putInt16(int16(len(types))) 1255 for _, t := range types { 1256 c.msgBuilder.putInt32(int32(t)) 1257 } 1258 if err := c.msgBuilder.finishMsg(&c.writerState.buf); err != nil { 1259 panic(fmt.Sprintf("unexpected err from buffer: %s", err)) 1260 } 1261 } 1262 1263 func (c *conn) bufferNoDataMsg() { 1264 c.msgBuilder.initMsg(pgwirebase.ServerMsgNoData) 1265 if err := c.msgBuilder.finishMsg(&c.writerState.buf); err != nil { 1266 panic(fmt.Sprintf("unexpected err from buffer: %s", err)) 1267 } 
1268 } 1269 1270 // writeRowDescription writes a row description to the given writer. 1271 // 1272 // formatCodes specifies the format for each column. It can be nil, in which 1273 // case all columns will use FormatText. 1274 // 1275 // If an error is returned, it has also been saved on c.err. 1276 func (c *conn) writeRowDescription( 1277 ctx context.Context, 1278 columns []sqlbase.ResultColumn, 1279 formatCodes []pgwirebase.FormatCode, 1280 w io.Writer, 1281 ) error { 1282 c.msgBuilder.initMsg(pgwirebase.ServerMsgRowDescription) 1283 c.msgBuilder.putInt16(int16(len(columns))) 1284 for i, column := range columns { 1285 if log.V(2) { 1286 log.Infof(ctx, "pgwire: writing column %s of type: %s", column.Name, column.Typ) 1287 } 1288 c.msgBuilder.writeTerminatedString(column.Name) 1289 typ := pgTypeForParserType(column.Typ) 1290 c.msgBuilder.putInt32(int32(column.TableID)) // Table OID (optional). 1291 c.msgBuilder.putInt16(int16(column.PGAttributeNum)) // Column attribute ID (optional). 1292 c.msgBuilder.putInt32(int32(typ.oid)) 1293 c.msgBuilder.putInt16(int16(typ.size)) 1294 c.msgBuilder.putInt32(column.GetTypeModifier()) // Type modifier 1295 if formatCodes == nil { 1296 c.msgBuilder.putInt16(int16(pgwirebase.FormatText)) 1297 } else { 1298 c.msgBuilder.putInt16(int16(formatCodes[i])) 1299 } 1300 } 1301 if err := c.msgBuilder.finishMsg(w); err != nil { 1302 c.setErr(err) 1303 return err 1304 } 1305 return nil 1306 } 1307 1308 // Flush is part of the ClientComm interface. 1309 // 1310 // In case conn.err is set, this is a no-op - the previous err is returned. 1311 func (c *conn) Flush(pos sql.CmdPos) error { 1312 // Check that there were no previous network errors. If there were, we'd 1313 // probably also fail the write below, but this check is here to make 1314 // absolutely sure that we don't send some results after we previously had 1315 // failed to send others. 
1316 if err := c.GetErr(); err != nil { 1317 return err 1318 } 1319 1320 c.writerState.fi.lastFlushed = pos 1321 c.writerState.fi.cmdStarts = make(map[sql.CmdPos]int) 1322 1323 _ /* n */, err := c.writerState.buf.WriteTo(c.conn) 1324 if err != nil { 1325 c.setErr(err) 1326 return err 1327 } 1328 return nil 1329 } 1330 1331 // maybeFlush flushes the buffer to the network connection if it exceeded 1332 // sessionArgs.ConnResultsBufferSize. 1333 func (c *conn) maybeFlush(pos sql.CmdPos) (bool, error) { 1334 if int64(c.writerState.buf.Len()) <= c.sessionArgs.ConnResultsBufferSize { 1335 return false, nil 1336 } 1337 return true, c.Flush(pos) 1338 } 1339 1340 // LockCommunication is part of the ClientComm interface. 1341 // 1342 // The current implementation of conn writes results to the network 1343 // synchronously, as they are produced (modulo buffering). Therefore, there's 1344 // nothing to "lock" - communication is naturally blocked as the command 1345 // processor won't write any more results. 1346 func (c *conn) LockCommunication() sql.ClientLock { 1347 return &clientConnLock{flushInfo: &c.writerState.fi} 1348 } 1349 1350 // clientConnLock is the connection's implementation of sql.ClientLock. It lets 1351 // the sql module lock the flushing of results and find out what has already 1352 // been flushed. 1353 type clientConnLock struct { 1354 *flushInfo 1355 } 1356 1357 var _ sql.ClientLock = &clientConnLock{} 1358 1359 // Close is part of the sql.ClientLock interface. 1360 func (cl *clientConnLock) Close() { 1361 // Nothing to do. See LockCommunication note. 1362 } 1363 1364 // ClientPos is part of the sql.ClientLock interface. 1365 func (cl *clientConnLock) ClientPos() sql.CmdPos { 1366 return cl.lastFlushed 1367 } 1368 1369 // RTrim is part of the sql.ClientLock interface. 
1370 func (cl *clientConnLock) RTrim(ctx context.Context, pos sql.CmdPos) { 1371 if pos <= cl.lastFlushed { 1372 panic(fmt.Sprintf("asked to trim to pos: %d, below the last flush: %d", pos, cl.lastFlushed)) 1373 } 1374 idx, ok := cl.cmdStarts[pos] 1375 if !ok { 1376 // If we don't have a start index for pos yet, it must be that no results 1377 // for it yet have been produced yet. 1378 idx = cl.buf.Len() 1379 } 1380 // Remove everything from the buffer after idx. 1381 cl.buf.Truncate(idx) 1382 // Update cmdStarts: delete commands that were trimmed. 1383 for p := range cl.cmdStarts { 1384 if p >= pos { 1385 delete(cl.cmdStarts, p) 1386 } 1387 } 1388 } 1389 1390 // CreateStatementResult is part of the sql.ClientComm interface. 1391 func (c *conn) CreateStatementResult( 1392 stmt tree.Statement, 1393 descOpt sql.RowDescOpt, 1394 pos sql.CmdPos, 1395 formatCodes []pgwirebase.FormatCode, 1396 conv sessiondata.DataConversionConfig, 1397 limit int, 1398 portalName string, 1399 implicitTxn bool, 1400 ) sql.CommandResult { 1401 return c.newCommandResult(descOpt, pos, stmt, formatCodes, conv, limit, portalName, implicitTxn) 1402 } 1403 1404 // CreateSyncResult is part of the sql.ClientComm interface. 1405 func (c *conn) CreateSyncResult(pos sql.CmdPos) sql.SyncResult { 1406 return c.newMiscResult(pos, readyForQuery) 1407 } 1408 1409 // CreateFlushResult is part of the sql.ClientComm interface. 1410 func (c *conn) CreateFlushResult(pos sql.CmdPos) sql.FlushResult { 1411 return c.newMiscResult(pos, flush) 1412 } 1413 1414 // CreateDrainResult is part of the sql.ClientComm interface. 1415 func (c *conn) CreateDrainResult(pos sql.CmdPos) sql.DrainResult { 1416 return c.newMiscResult(pos, noCompletionMsg) 1417 } 1418 1419 // CreateBindResult is part of the sql.ClientComm interface. 1420 func (c *conn) CreateBindResult(pos sql.CmdPos) sql.BindResult { 1421 return c.newMiscResult(pos, bindComplete) 1422 } 1423 1424 // CreatePrepareResult is part of the sql.ClientComm interface. 
1425 func (c *conn) CreatePrepareResult(pos sql.CmdPos) sql.ParseResult { 1426 return c.newMiscResult(pos, parseComplete) 1427 } 1428 1429 // CreateDescribeResult is part of the sql.ClientComm interface. 1430 func (c *conn) CreateDescribeResult(pos sql.CmdPos) sql.DescribeResult { 1431 return c.newMiscResult(pos, noCompletionMsg) 1432 } 1433 1434 // CreateEmptyQueryResult is part of the sql.ClientComm interface. 1435 func (c *conn) CreateEmptyQueryResult(pos sql.CmdPos) sql.EmptyQueryResult { 1436 return c.newMiscResult(pos, emptyQueryResponse) 1437 } 1438 1439 // CreateDeleteResult is part of the sql.ClientComm interface. 1440 func (c *conn) CreateDeleteResult(pos sql.CmdPos) sql.DeleteResult { 1441 return c.newMiscResult(pos, closeComplete) 1442 } 1443 1444 // CreateErrorResult is part of the sql.ClientComm interface. 1445 func (c *conn) CreateErrorResult(pos sql.CmdPos) sql.ErrorResult { 1446 res := c.newMiscResult(pos, noCompletionMsg) 1447 res.errExpected = true 1448 return res 1449 } 1450 1451 // CreateCopyInResult is part of the sql.ClientComm interface. 1452 func (c *conn) CreateCopyInResult(pos sql.CmdPos) sql.CopyInResult { 1453 return c.newMiscResult(pos, noCompletionMsg) 1454 } 1455 1456 // pgwireReader is an io.Reader that wraps a conn, maintaining its metrics as 1457 // it is consumed. 1458 type pgwireReader struct { 1459 conn *conn 1460 } 1461 1462 // pgwireReader implements the pgwirebase.BufferedReader interface. 1463 var _ pgwirebase.BufferedReader = &pgwireReader{} 1464 1465 // Read is part of the pgwirebase.BufferedReader interface. 1466 func (r *pgwireReader) Read(p []byte) (int, error) { 1467 n, err := r.conn.rd.Read(p) 1468 r.conn.metrics.BytesInCount.Inc(int64(n)) 1469 return n, err 1470 } 1471 1472 // ReadString is part of the pgwirebase.BufferedReader interface. 
1473 func (r *pgwireReader) ReadString(delim byte) (string, error) { 1474 s, err := r.conn.rd.ReadString(delim) 1475 r.conn.metrics.BytesInCount.Inc(int64(len(s))) 1476 return s, err 1477 } 1478 1479 // ReadByte is part of the pgwirebase.BufferedReader interface. 1480 func (r *pgwireReader) ReadByte() (byte, error) { 1481 b, err := r.conn.rd.ReadByte() 1482 if err == nil { 1483 r.conn.metrics.BytesInCount.Inc(1) 1484 } 1485 return b, err 1486 } 1487 1488 // statusReportParams is a list of session variables that are also 1489 // reported as server run-time parameters in the pgwire connection 1490 // initialization. 1491 // 1492 // The standard PostgreSQL status vars are listed here: 1493 // https://www.postgresql.org/docs/10/static/libpq-status.html 1494 var statusReportParams = []string{ 1495 "server_version", 1496 "server_encoding", 1497 "client_encoding", 1498 "application_name", 1499 // Note: is_superuser and session_authorization are handled 1500 // specially in serveImpl(). 1501 "DateStyle", 1502 "IntervalStyle", 1503 "TimeZone", 1504 "integer_datetimes", 1505 "standard_conforming_strings", 1506 "crdb_version", // CockroachDB extension. 1507 } 1508 1509 // testingStatusReportParams is the minimum set of status parameters 1510 // needed to make pgx tests in the local package happy. 1511 var testingStatusReportParams = map[string]string{ 1512 "client_encoding": "UTF8", 1513 "standard_conforming_strings": "on", 1514 } 1515 1516 // readTimeoutConn overloads net.Conn.Read by periodically calling 1517 // checkExitConds() and aborting the read if an error is returned. 1518 type readTimeoutConn struct { 1519 net.Conn 1520 // checkExitConds is called periodically by Read(). If it returns an error, 1521 // the Read() returns that error. Future calls to Read() are allowed, in which 1522 // case checkExitConds() will be called again. 
1523 checkExitConds func() error 1524 } 1525 1526 func newReadTimeoutConn(c net.Conn, checkExitConds func() error) net.Conn { 1527 // net.Pipe does not support setting deadlines. See 1528 // https://github.com/golang/go/blob/go1.7.4/src/net/pipe.go#L57-L67 1529 // 1530 // TODO(andrei): starting with Go 1.10, pipes are supposed to support 1531 // timeouts, so this should go away when we upgrade the compiler. 1532 if c.LocalAddr().Network() == "pipe" { 1533 return c 1534 } 1535 return &readTimeoutConn{ 1536 Conn: c, 1537 checkExitConds: checkExitConds, 1538 } 1539 } 1540 1541 func (c *readTimeoutConn) Read(b []byte) (int, error) { 1542 // readTimeout is the amount of time ReadTimeoutConn should wait on a 1543 // read before checking for exit conditions. The tradeoff is between the 1544 // time it takes to react to session context cancellation and the overhead 1545 // of waking up and checking for exit conditions. 1546 const readTimeout = 1 * time.Second 1547 1548 // Remove the read deadline when returning from this function to avoid 1549 // unexpected behavior. 1550 defer func() { _ = c.SetReadDeadline(time.Time{}) }() 1551 for { 1552 if err := c.checkExitConds(); err != nil { 1553 return 0, err 1554 } 1555 if err := c.SetReadDeadline(timeutil.Now().Add(readTimeout)); err != nil { 1556 return 0, err 1557 } 1558 n, err := c.Conn.Read(b) 1559 // Continue if the error is due to timing out. 1560 if ne := (net.Error)(nil); errors.As(err, &ne) && ne.Timeout() { 1561 continue 1562 } 1563 return n, err 1564 } 1565 }