github.com/jackc/pgx/v5@v5.5.5/pgproto3/frontend.go (about) 1 package pgproto3 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "fmt" 8 "io" 9 ) 10 11 // Frontend acts as a client for the PostgreSQL wire protocol version 3. 12 type Frontend struct { 13 cr *chunkReader 14 w io.Writer 15 16 // tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced 17 // before it is actually transmitted (i.e. before Flush). It is safe to change this variable when the Frontend is 18 // idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq. 19 tracer *tracer 20 21 wbuf []byte 22 encodeError error 23 24 // Backend message flyweights 25 authenticationOk AuthenticationOk 26 authenticationCleartextPassword AuthenticationCleartextPassword 27 authenticationMD5Password AuthenticationMD5Password 28 authenticationGSS AuthenticationGSS 29 authenticationGSSContinue AuthenticationGSSContinue 30 authenticationSASL AuthenticationSASL 31 authenticationSASLContinue AuthenticationSASLContinue 32 authenticationSASLFinal AuthenticationSASLFinal 33 backendKeyData BackendKeyData 34 bindComplete BindComplete 35 closeComplete CloseComplete 36 commandComplete CommandComplete 37 copyBothResponse CopyBothResponse 38 copyData CopyData 39 copyInResponse CopyInResponse 40 copyOutResponse CopyOutResponse 41 copyDone CopyDone 42 dataRow DataRow 43 emptyQueryResponse EmptyQueryResponse 44 errorResponse ErrorResponse 45 functionCallResponse FunctionCallResponse 46 noData NoData 47 noticeResponse NoticeResponse 48 notificationResponse NotificationResponse 49 parameterDescription ParameterDescription 50 parameterStatus ParameterStatus 51 parseComplete ParseComplete 52 readyForQuery ReadyForQuery 53 rowDescription RowDescription 54 portalSuspended PortalSuspended 55 56 bodyLen int 57 msgType byte 58 partialMsg bool 59 authType uint32 60 } 61 62 // NewFrontend creates a new Frontend. 63 func NewFrontend(r io.Reader, w io.Writer) *Frontend { 64 cr := newChunkReader(r, 0) 65 return &Frontend{cr: cr, w: w} 66 } 67 68 // Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error 69 // encountered will be returned from Flush. 70 // 71 // Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods 72 // such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an 73 // extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden 74 // behind an interface. 75 func (f *Frontend) Send(msg FrontendMessage) { 76 if f.encodeError != nil { 77 return 78 } 79 80 prevLen := len(f.wbuf) 81 newBuf, err := msg.Encode(f.wbuf) 82 if err != nil { 83 f.encodeError = err 84 return 85 } 86 f.wbuf = newBuf 87 88 if f.tracer != nil { 89 f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg) 90 } 91 } 92 93 // Flush writes any pending messages to the backend (i.e. the server). 94 func (f *Frontend) Flush() error { 95 if err := f.encodeError; err != nil { 96 f.encodeError = nil 97 f.wbuf = f.wbuf[:0] 98 return &writeError{err: err, safeToRetry: true} 99 } 100 101 if len(f.wbuf) == 0 { 102 return nil 103 } 104 105 n, err := f.w.Write(f.wbuf) 106 107 const maxLen = 1024 108 if len(f.wbuf) > maxLen { 109 f.wbuf = make([]byte, 0, maxLen) 110 } else { 111 f.wbuf = f.wbuf[:0] 112 } 113 114 if err != nil { 115 return &writeError{err: err, safeToRetry: n == 0} 116 } 117 118 return nil 119 } 120 121 // Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function 122 // PQtrace. 123 func (f *Frontend) Trace(w io.Writer, options TracerOptions) { 124 f.tracer = &tracer{ 125 w: w, 126 buf: &bytes.Buffer{}, 127 TracerOptions: options, 128 } 129 } 130 131 // Untrace stops tracing. 132 func (f *Frontend) Untrace() { 133 f.tracer = nil 134 } 135 136 // SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any 137 // error encountered will be returned from Flush. 138 func (f *Frontend) SendBind(msg *Bind) { 139 if f.encodeError != nil { 140 return 141 } 142 143 prevLen := len(f.wbuf) 144 newBuf, err := msg.Encode(f.wbuf) 145 if err != nil { 146 f.encodeError = err 147 return 148 } 149 f.wbuf = newBuf 150 151 if f.tracer != nil { 152 f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg) 153 } 154 } 155 156 // SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any 157 // error encountered will be returned from Flush. 158 func (f *Frontend) SendParse(msg *Parse) { 159 if f.encodeError != nil { 160 return 161 } 162 163 prevLen := len(f.wbuf) 164 newBuf, err := msg.Encode(f.wbuf) 165 if err != nil { 166 f.encodeError = err 167 return 168 } 169 f.wbuf = newBuf 170 171 if f.tracer != nil { 172 f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg) 173 } 174 } 175 176 // SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any 177 // error encountered will be returned from Flush. 178 func (f *Frontend) SendClose(msg *Close) { 179 if f.encodeError != nil { 180 return 181 } 182 183 prevLen := len(f.wbuf) 184 newBuf, err := msg.Encode(f.wbuf) 185 if err != nil { 186 f.encodeError = err 187 return 188 } 189 f.wbuf = newBuf 190 191 if f.tracer != nil { 192 f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg) 193 } 194 } 195 196 // SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is 197 // called. Any error encountered will be returned from Flush. 198 func (f *Frontend) SendDescribe(msg *Describe) { 199 if f.encodeError != nil { 200 return 201 } 202 203 prevLen := len(f.wbuf) 204 newBuf, err := msg.Encode(f.wbuf) 205 if err != nil { 206 f.encodeError = err 207 return 208 } 209 f.wbuf = newBuf 210 211 if f.tracer != nil { 212 f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg) 213 } 214 } 215 216 // SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called. 217 // Any error encountered will be returned from Flush. 218 func (f *Frontend) SendExecute(msg *Execute) { 219 if f.encodeError != nil { 220 return 221 } 222 223 prevLen := len(f.wbuf) 224 newBuf, err := msg.Encode(f.wbuf) 225 if err != nil { 226 f.encodeError = err 227 return 228 } 229 f.wbuf = newBuf 230 231 if f.tracer != nil { 232 f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) 233 } 234 } 235 236 // SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any 237 // error encountered will be returned from Flush. 238 func (f *Frontend) SendSync(msg *Sync) { 239 if f.encodeError != nil { 240 return 241 } 242 243 prevLen := len(f.wbuf) 244 newBuf, err := msg.Encode(f.wbuf) 245 if err != nil { 246 f.encodeError = err 247 return 248 } 249 f.wbuf = newBuf 250 251 if f.tracer != nil { 252 f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg) 253 } 254 } 255 256 // SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any 257 // error encountered will be returned from Flush. 258 func (f *Frontend) SendQuery(msg *Query) { 259 if f.encodeError != nil { 260 return 261 } 262 263 prevLen := len(f.wbuf) 264 newBuf, err := msg.Encode(f.wbuf) 265 if err != nil { 266 f.encodeError = err 267 return 268 } 269 f.wbuf = newBuf 270 271 if f.tracer != nil { 272 f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg) 273 } 274 } 275 276 // SendUnbufferedEncodedCopyData immediately sends an encoded CopyData message to the backend (i.e. the server). This method 277 // is more efficient than sending a CopyData message with Send as the message data is not copied to the internal buffer 278 // before being written out. The internal buffer is flushed before the message is sent. 279 func (f *Frontend) SendUnbufferedEncodedCopyData(msg []byte) error { 280 err := f.Flush() 281 if err != nil { 282 return err 283 } 284 285 n, err := f.w.Write(msg) 286 if err != nil { 287 return &writeError{err: err, safeToRetry: n == 0} 288 } 289 290 if f.tracer != nil { 291 f.tracer.traceCopyData('F', int32(len(msg)-1), &CopyData{}) 292 } 293 294 return nil 295 } 296 297 func translateEOFtoErrUnexpectedEOF(err error) error { 298 if err == io.EOF { 299 return io.ErrUnexpectedEOF 300 } 301 return err 302 } 303 304 // Receive receives a message from the backend. The returned message is only valid until the next call to Receive. 305 func (f *Frontend) Receive() (BackendMessage, error) { 306 if !f.partialMsg { 307 header, err := f.cr.Next(5) 308 if err != nil { 309 return nil, translateEOFtoErrUnexpectedEOF(err) 310 } 311 312 f.msgType = header[0] 313 314 msgLength := int(binary.BigEndian.Uint32(header[1:])) 315 if msgLength < 4 { 316 return nil, fmt.Errorf("invalid message length: %d", msgLength) 317 } 318 319 f.bodyLen = msgLength - 4 320 f.partialMsg = true 321 } 322 323 msgBody, err := f.cr.Next(f.bodyLen) 324 if err != nil { 325 return nil, translateEOFtoErrUnexpectedEOF(err) 326 } 327 328 f.partialMsg = false 329 330 var msg BackendMessage 331 switch f.msgType { 332 case '1': 333 msg = &f.parseComplete 334 case '2': 335 msg = &f.bindComplete 336 case '3': 337 msg = &f.closeComplete 338 case 'A': 339 msg = &f.notificationResponse 340 case 'c': 341 msg = &f.copyDone 342 case 'C': 343 msg = &f.commandComplete 344 case 'd': 345 msg = &f.copyData 346 case 'D': 347 msg = &f.dataRow 348 case 'E': 349 msg = &f.errorResponse 350 case 'G': 351 msg = &f.copyInResponse 352 case 'H': 353 msg = &f.copyOutResponse 354 case 'I': 355 msg = &f.emptyQueryResponse 356 case 'K': 357 msg = &f.backendKeyData 358 case 'n': 359 msg = &f.noData 360 case 'N': 361 msg = &f.noticeResponse 362 case 'R': 363 var err error 364 msg, err = f.findAuthenticationMessageType(msgBody) 365 if err != nil { 366 return nil, err 367 } 368 case 's': 369 msg = &f.portalSuspended 370 case 'S': 371 msg = &f.parameterStatus 372 case 't': 373 msg = &f.parameterDescription 374 case 'T': 375 msg = &f.rowDescription 376 case 'V': 377 msg = &f.functionCallResponse 378 case 'W': 379 msg = &f.copyBothResponse 380 case 'Z': 381 msg = &f.readyForQuery 382 default: 383 return nil, fmt.Errorf("unknown message type: %c", f.msgType) 384 } 385 386 err = msg.Decode(msgBody) 387 if err != nil { 388 return nil, err 389 } 390 391 if f.tracer != nil { 392 f.tracer.traceMessage('B', int32(5+len(msgBody)), msg) 393 } 394 395 return msg, nil 396 } 397 398 // Authentication message type constants. 399 // See src/include/libpq/pqcomm.h for all 400 // constants. 401 const ( 402 AuthTypeOk = 0 403 AuthTypeCleartextPassword = 3 404 AuthTypeMD5Password = 5 405 AuthTypeSCMCreds = 6 406 AuthTypeGSS = 7 407 AuthTypeGSSCont = 8 408 AuthTypeSSPI = 9 409 AuthTypeSASL = 10 410 AuthTypeSASLContinue = 11 411 AuthTypeSASLFinal = 12 412 ) 413 414 func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) { 415 if len(src) < 4 { 416 return nil, errors.New("authentication message too short") 417 } 418 f.authType = binary.BigEndian.Uint32(src[:4]) 419 420 switch f.authType { 421 case AuthTypeOk: 422 return &f.authenticationOk, nil 423 case AuthTypeCleartextPassword: 424 return &f.authenticationCleartextPassword, nil 425 case AuthTypeMD5Password: 426 return &f.authenticationMD5Password, nil 427 case AuthTypeSCMCreds: 428 return nil, errors.New("AuthTypeSCMCreds is unimplemented") 429 case AuthTypeGSS: 430 return &f.authenticationGSS, nil 431 case AuthTypeGSSCont: 432 return &f.authenticationGSSContinue, nil 433 case AuthTypeSSPI: 434 return nil, errors.New("AuthTypeSSPI is unimplemented") 435 case AuthTypeSASL: 436 return &f.authenticationSASL, nil 437 case AuthTypeSASLContinue: 438 return &f.authenticationSASLContinue, nil 439 case AuthTypeSASLFinal: 440 return &f.authenticationSASLFinal, nil 441 default: 442 return nil, fmt.Errorf("unknown authentication type: %d", f.authType) 443 } 444 } 445 446 // GetAuthType returns the authType used in the current state of the frontend. 447 // See SetAuthType for more information. 448 func (f *Frontend) GetAuthType() uint32 { 449 return f.authType 450 } 451 452 func (f *Frontend) ReadBufferLen() int { 453 return f.cr.wp - f.cr.rp 454 }