github.com/jackc/pgx/v5@v5.5.5/pgproto3/backend.go (about) 1 package pgproto3 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "fmt" 7 "io" 8 ) 9 10 // Backend acts as a server for the PostgreSQL wire protocol version 3. 11 type Backend struct { 12 cr *chunkReader 13 w io.Writer 14 15 // tracer is used to trace messages when Send or Receive is called. This means an outbound message is traced 16 // before it is actually transmitted (i.e. before Flush). 17 tracer *tracer 18 19 wbuf []byte 20 encodeError error 21 22 // Frontend message flyweights 23 bind Bind 24 cancelRequest CancelRequest 25 _close Close 26 copyFail CopyFail 27 copyData CopyData 28 copyDone CopyDone 29 describe Describe 30 execute Execute 31 flush Flush 32 functionCall FunctionCall 33 gssEncRequest GSSEncRequest 34 parse Parse 35 query Query 36 sslRequest SSLRequest 37 startupMessage StartupMessage 38 sync Sync 39 terminate Terminate 40 41 bodyLen int 42 maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error. 43 msgType byte 44 partialMsg bool 45 authType uint32 46 } 47 48 const ( 49 minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. 50 maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. 51 ) 52 53 // NewBackend creates a new Backend. 54 func NewBackend(r io.Reader, w io.Writer) *Backend { 55 cr := newChunkReader(r, 0) 56 return &Backend{cr: cr, w: w} 57 } 58 59 // Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error 60 // encountered will be returned from Flush. 61 func (b *Backend) Send(msg BackendMessage) { 62 if b.encodeError != nil { 63 return 64 } 65 66 prevLen := len(b.wbuf) 67 newBuf, err := msg.Encode(b.wbuf) 68 if err != nil { 69 b.encodeError = err 70 return 71 } 72 b.wbuf = newBuf 73 74 if b.tracer != nil { 75 b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg) 76 } 77 } 78 79 // Flush writes any pending messages to the frontend (i.e. the client). 80 func (b *Backend) Flush() error { 81 if err := b.encodeError; err != nil { 82 b.encodeError = nil 83 b.wbuf = b.wbuf[:0] 84 return &writeError{err: err, safeToRetry: true} 85 } 86 87 n, err := b.w.Write(b.wbuf) 88 89 const maxLen = 1024 90 if len(b.wbuf) > maxLen { 91 b.wbuf = make([]byte, 0, maxLen) 92 } else { 93 b.wbuf = b.wbuf[:0] 94 } 95 96 if err != nil { 97 return &writeError{err: err, safeToRetry: n == 0} 98 } 99 100 return nil 101 } 102 103 // Trace starts tracing the message traffic to w. It writes in a similar format to that produced by the libpq function 104 // PQtrace. 105 func (b *Backend) Trace(w io.Writer, options TracerOptions) { 106 b.tracer = &tracer{ 107 w: w, 108 buf: &bytes.Buffer{}, 109 TracerOptions: options, 110 } 111 } 112 113 // Untrace stops tracing. 114 func (b *Backend) Untrace() { 115 b.tracer = nil 116 } 117 118 // ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method 119 // because the initial connection message is "special" and does not include the message type as the first byte. This 120 // will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest. 121 func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) { 122 buf, err := b.cr.Next(4) 123 if err != nil { 124 return nil, err 125 } 126 msgSize := int(binary.BigEndian.Uint32(buf) - 4) 127 128 if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen { 129 return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize) 130 } 131 132 buf, err = b.cr.Next(msgSize) 133 if err != nil { 134 return nil, translateEOFtoErrUnexpectedEOF(err) 135 } 136 137 code := binary.BigEndian.Uint32(buf) 138 139 switch code { 140 case ProtocolVersionNumber: 141 err = b.startupMessage.Decode(buf) 142 if err != nil { 143 return nil, err 144 } 145 return &b.startupMessage, nil 146 case sslRequestNumber: 147 err = b.sslRequest.Decode(buf) 148 if err != nil { 149 return nil, err 150 } 151 return &b.sslRequest, nil 152 case cancelRequestCode: 153 err = b.cancelRequest.Decode(buf) 154 if err != nil { 155 return nil, err 156 } 157 return &b.cancelRequest, nil 158 case gssEncReqNumber: 159 err = b.gssEncRequest.Decode(buf) 160 if err != nil { 161 return nil, err 162 } 163 return &b.gssEncRequest, nil 164 default: 165 return nil, fmt.Errorf("unknown startup message code: %d", code) 166 } 167 } 168 169 // Receive receives a message from the frontend. The returned message is only valid until the next call to Receive. 170 func (b *Backend) Receive() (FrontendMessage, error) { 171 if !b.partialMsg { 172 header, err := b.cr.Next(5) 173 if err != nil { 174 return nil, translateEOFtoErrUnexpectedEOF(err) 175 } 176 177 b.msgType = header[0] 178 b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 179 if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen { 180 return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen} 181 } 182 b.partialMsg = true 183 } 184 185 var msg FrontendMessage 186 switch b.msgType { 187 case 'B': 188 msg = &b.bind 189 case 'C': 190 msg = &b._close 191 case 'D': 192 msg = &b.describe 193 case 'E': 194 msg = &b.execute 195 case 'F': 196 msg = &b.functionCall 197 case 'f': 198 msg = &b.copyFail 199 case 'd': 200 msg = &b.copyData 201 case 'c': 202 msg = &b.copyDone 203 case 'H': 204 msg = &b.flush 205 case 'P': 206 msg = &b.parse 207 case 'p': 208 switch b.authType { 209 case AuthTypeSASL: 210 msg = &SASLInitialResponse{} 211 case AuthTypeSASLContinue: 212 msg = &SASLResponse{} 213 case AuthTypeSASLFinal: 214 msg = &SASLResponse{} 215 case AuthTypeGSS, AuthTypeGSSCont: 216 msg = &GSSResponse{} 217 case AuthTypeCleartextPassword, AuthTypeMD5Password: 218 fallthrough 219 default: 220 // to maintain backwards compatibility 221 msg = &PasswordMessage{} 222 } 223 case 'Q': 224 msg = &b.query 225 case 'S': 226 msg = &b.sync 227 case 'X': 228 msg = &b.terminate 229 default: 230 return nil, fmt.Errorf("unknown message type: %c", b.msgType) 231 } 232 233 msgBody, err := b.cr.Next(b.bodyLen) 234 if err != nil { 235 return nil, translateEOFtoErrUnexpectedEOF(err) 236 } 237 238 b.partialMsg = false 239 240 err = msg.Decode(msgBody) 241 if err != nil { 242 return nil, err 243 } 244 245 if b.tracer != nil { 246 b.tracer.traceMessage('F', int32(5+len(msgBody)), msg) 247 } 248 249 return msg, nil 250 } 251 252 // SetAuthType sets the authentication type in the backend. 253 // Since multiple message types can start with 'p', SetAuthType allows 254 // contextual identification of FrontendMessages. For example, in the 255 // PG message flow documentation for PasswordMessage: 256 // 257 // Byte1('p') 258 // 259 // Identifies the message as a password response. Note that this is also used for 260 // GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from 261 // the context. 262 // 263 // Since the Frontend does not know about the state of a backend, it is important 264 // to call SetAuthType() after an authentication request is received by the Frontend. 265 func (b *Backend) SetAuthType(authType uint32) error { 266 switch authType { 267 case AuthTypeOk, 268 AuthTypeCleartextPassword, 269 AuthTypeMD5Password, 270 AuthTypeSCMCreds, 271 AuthTypeGSS, 272 AuthTypeGSSCont, 273 AuthTypeSSPI, 274 AuthTypeSASL, 275 AuthTypeSASLContinue, 276 AuthTypeSASLFinal: 277 b.authType = authType 278 default: 279 return fmt.Errorf("authType not recognized: %d", authType) 280 } 281 282 return nil 283 } 284 285 // SetMaxBodyLen sets the maximum length of a message body in octets. If a message body exceeds this length, Receive will return 286 // an error. This is useful for protecting against malicious clients that send large messages with the intent of 287 // causing memory exhaustion. 288 // The default value is 0. 289 // If maxBodyLen is 0, then no maximum is enforced. 290 func (b *Backend) SetMaxBodyLen(maxBodyLen int) { 291 b.maxBodyLen = maxBodyLen 292 }