github.com/blend/go-sdk@v1.20220411.3/proxyprotocol/proxy_protocol_header.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package proxyprotocol 9 10 import ( 11 "bytes" 12 "encoding/binary" 13 "errors" 14 "fmt" 15 "io" 16 "math" 17 "net" 18 "strconv" 19 ) 20 21 // Protocol Headers 22 var ( 23 SIGV1 = []byte{'\x50', '\x52', '\x4F', '\x58', '\x59'} 24 SIGV2 = []byte{'\x0D', '\x0A', '\x0D', '\x0A', '\x00', '\x0D', '\x0A', '\x51', '\x55', '\x49', '\x54', '\x0A'} 25 ) 26 27 // Errors 28 var ( 29 ErrCantReadVersion1Header = errors.New("proxyproto: can't read version 1 header") 30 ErrVersion1HeaderTooLong = errors.New("proxyproto: version 1 header must be 107 bytes or less") 31 ErrLineMustEndWithCrlf = errors.New("proxyproto: version 1 header is invalid, must end with \\r\\n") 32 ErrCantReadProtocolVersionAndCommand = errors.New("proxyproto: can't read proxy protocol version and command") 33 ErrCantReadAddressFamilyAndProtocol = errors.New("proxyproto: can't read address family or protocol") 34 ErrCantReadLength = errors.New("proxyproto: can't read length") 35 ErrCantResolveSourceUnixAddress = errors.New("proxyproto: can't resolve source Unix address") 36 ErrCantResolveDestinationUnixAddress = errors.New("proxyproto: can't resolve destination Unix address") 37 ErrNoProxyProtocol = errors.New("proxyproto: proxy protocol signature not present") 38 ErrUnknownProxyProtocolVersion = errors.New("proxyproto: unknown proxy protocol version") 39 ErrUnsupportedProtocolVersionAndCommand = errors.New("proxyproto: unsupported proxy protocol version and command") 40 ErrUnsupportedAddressFamilyAndProtocol = errors.New("proxyproto: unsupported address family and protocol") 41 ErrInvalidLength = errors.New("proxyproto: invalid length") 42 ErrInvalidAddress = errors.New("proxyproto: invalid address") 43 ErrInvalidPortNumber = 
errors.New("proxyproto: invalid port number") 44 ErrSuperfluousProxyHeader = errors.New("proxyproto: upstream connection sent PROXY header but isn't allowed to send one") 45 ) 46 47 // Header is the placeholder for proxy protocol header. 48 type Header struct { 49 Version byte 50 Command ProtocolVersionAndCommand 51 TransportProtocol AddressFamilyAndProtocol 52 SourceAddr net.Addr 53 DestinationAddr net.Addr 54 rawTLVs []byte 55 } 56 57 // TCPAddrs returns the tcp addresses for the proxy protocol header. 58 func (header *Header) TCPAddrs() (sourceAddr, destAddr *net.TCPAddr, ok bool) { 59 if !header.TransportProtocol.IsStream() { 60 return nil, nil, false 61 } 62 sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) 63 destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) 64 return sourceAddr, destAddr, sourceOK && destOK 65 } 66 67 // UDPAddrs returns the udp addresses for the proxy protocol header. 68 func (header *Header) UDPAddrs() (sourceAddr, destAddr *net.UDPAddr, ok bool) { 69 if !header.TransportProtocol.IsDatagram() { 70 return nil, nil, false 71 } 72 sourceAddr, sourceOK := header.SourceAddr.(*net.UDPAddr) 73 destAddr, destOK := header.DestinationAddr.(*net.UDPAddr) 74 return sourceAddr, destAddr, sourceOK && destOK 75 } 76 77 // UnixAddrs returns the uds addresses for the proxy protocol header. 78 func (header *Header) UnixAddrs() (sourceAddr, destAddr *net.UnixAddr, ok bool) { 79 if !header.TransportProtocol.IsUnix() { 80 return nil, nil, false 81 } 82 sourceAddr, sourceOK := header.SourceAddr.(*net.UnixAddr) 83 destAddr, destOK := header.DestinationAddr.(*net.UnixAddr) 84 return sourceAddr, destAddr, sourceOK && destOK 85 } 86 87 // IPs returns the ip addresses for the proxy protocol header. 
88 func (header *Header) IPs() (sourceIP, destIP net.IP, ok bool) { 89 if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { 90 return sourceAddr.IP, destAddr.IP, true 91 } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { 92 return sourceAddr.IP, destAddr.IP, true 93 } else { 94 return nil, nil, false 95 } 96 } 97 98 // Ports returns the ports for the proxy protocol header. 99 func (header *Header) Ports() (sourcePort, destPort int, ok bool) { 100 if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { 101 return sourceAddr.Port, destAddr.Port, true 102 } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { 103 return sourceAddr.Port, destAddr.Port, true 104 } else { 105 return 0, 0, false 106 } 107 } 108 109 // EqualTo returns true if headers are equivalent, false otherwise. 110 // Deprecated: use EqualsTo instead. This method will eventually be removed. 111 func (header *Header) EqualTo(otherHeader *Header) bool { 112 return header.EqualsTo(otherHeader) 113 } 114 115 // EqualsTo returns true if headers are equivalent, false otherwise. 116 func (header *Header) EqualsTo(otherHeader *Header) bool { 117 if otherHeader == nil { 118 return false 119 } 120 // TLVs only exist for version 2 121 if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) { 122 return false 123 } 124 if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol { 125 return false 126 } 127 // Return early for header with ProtocolVersionAndCommandLocal command, which contains no address information 128 if header.Command == ProtocolVersionAndCommandLocal { 129 return true 130 } 131 return header.SourceAddr.String() == otherHeader.SourceAddr.String() && 132 header.DestinationAddr.String() == otherHeader.DestinationAddr.String() 133 } 134 135 // WriteTo renders a proxy protocol header in a format and writes it to an io.Writer. 
136 func (header *Header) WriteTo(w io.Writer) (int64, error) { 137 buf, err := header.Format() 138 if err != nil { 139 return 0, err 140 } 141 return bytes.NewBuffer(buf).WriteTo(w) 142 } 143 144 // Format renders a proxy protocol header in a format to write over the wire. 145 func (header *Header) Format() ([]byte, error) { 146 switch header.Version { 147 case 1: 148 return header.formatVersion1() 149 case 2: 150 return header.formatVersion2() 151 default: 152 return nil, ErrUnknownProxyProtocolVersion 153 } 154 } 155 156 // TLVs returns the TLVs stored into this header, if they exist. TLVs are optional for v2 of the protocol. 157 func (header *Header) TLVs() ([]TLV, error) { 158 return SplitTLVs(header.rawTLVs) 159 } 160 161 // SetTLVs sets the TLVs stored in this header. This method replaces any 162 // previous TLV. 163 func (header *Header) SetTLVs(tlvs []TLV) error { 164 raw, err := JoinTLVs(tlvs) 165 if err != nil { 166 return err 167 } 168 header.rawTLVs = raw 169 return nil 170 } 171 172 const ( 173 crlf = "\r\n" 174 separator = " " 175 ) 176 177 func (header *Header) formatVersion1() ([]byte, error) { 178 // As of version 1, only "TCP4" ( \x54 \x43 \x50 \x34 ) for TCP over IPv4, 179 // and "TCP6" ( \x54 \x43 \x50 \x36 ) for TCP over IPv6 are allowed. 
180 var proto string 181 switch header.TransportProtocol { 182 case AddressFamilyAndProtocolTCPv4: 183 proto = "TCP4" 184 case AddressFamilyAndProtocolTCPv6: 185 proto = "TCP6" 186 default: 187 // Unknown connection (short form) 188 return []byte("PROXY UNKNOWN" + crlf), nil 189 } 190 191 sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) 192 destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) 193 if !sourceOK || !destOK { 194 return nil, ErrInvalidAddress 195 } 196 197 sourceIP, destIP := sourceAddr.IP, destAddr.IP 198 switch header.TransportProtocol { 199 case AddressFamilyAndProtocolTCPv4: 200 sourceIP = sourceIP.To4() 201 destIP = destIP.To4() 202 case AddressFamilyAndProtocolTCPv6: 203 sourceIP = sourceIP.To16() 204 destIP = destIP.To16() 205 } 206 if sourceIP == nil || destIP == nil { 207 return nil, ErrInvalidAddress 208 } 209 210 buf := bytes.NewBuffer(make([]byte, 0, 108)) 211 buf.Write(SIGV1) 212 buf.WriteString(separator) 213 buf.WriteString(proto) 214 buf.WriteString(separator) 215 buf.WriteString(sourceIP.String()) 216 buf.WriteString(separator) 217 buf.WriteString(destIP.String()) 218 buf.WriteString(separator) 219 buf.WriteString(strconv.Itoa(sourceAddr.Port)) 220 buf.WriteString(separator) 221 buf.WriteString(strconv.Itoa(destAddr.Port)) 222 buf.WriteString(crlf) 223 224 return buf.Bytes(), nil 225 } 226 227 var ( 228 lengthUnspec = uint16(0) 229 lengthV4 = uint16(12) 230 lengthV6 = uint16(36) 231 lengthUnix = uint16(216) 232 lengthUnspecBytes = func() []byte { 233 a := make([]byte, 2) 234 binary.BigEndian.PutUint16(a, lengthUnspec) 235 return a 236 }() 237 lengthV4Bytes = func() []byte { 238 a := make([]byte, 2) 239 binary.BigEndian.PutUint16(a, lengthV4) 240 return a 241 }() 242 lengthV6Bytes = func() []byte { 243 a := make([]byte, 2) 244 binary.BigEndian.PutUint16(a, lengthV6) 245 return a 246 }() 247 lengthUnixBytes = func() []byte { 248 a := make([]byte, 2) 249 binary.BigEndian.PutUint16(a, lengthUnix) 250 return a 251 }() 252 
errUint16Overflow = errors.New("proxyproto: uint16 overflow") 253 ) 254 255 func (header *Header) formatVersion2() ([]byte, error) { 256 var buf bytes.Buffer 257 buf.Write(SIGV2) 258 buf.WriteByte(header.Command.toByte()) 259 buf.WriteByte(header.TransportProtocol.toByte()) 260 if header.TransportProtocol.IsUnspec() { 261 // For UNSPEC, write no addresses and ports but only TLVs if they are present 262 hdrLen, err := addTLVLen(lengthUnspecBytes, len(header.rawTLVs)) 263 if err != nil { 264 return nil, err 265 } 266 buf.Write(hdrLen) 267 } else { 268 var addrSrc, addrDst []byte 269 if header.TransportProtocol.IsIPv4() { 270 hdrLen, err := addTLVLen(lengthV4Bytes, len(header.rawTLVs)) 271 if err != nil { 272 return nil, err 273 } 274 buf.Write(hdrLen) 275 sourceIP, destIP, _ := header.IPs() 276 addrSrc = sourceIP.To4() 277 addrDst = destIP.To4() 278 } else if header.TransportProtocol.IsIPv6() { 279 hdrLen, err := addTLVLen(lengthV6Bytes, len(header.rawTLVs)) 280 if err != nil { 281 return nil, err 282 } 283 buf.Write(hdrLen) 284 sourceIP, destIP, _ := header.IPs() 285 addrSrc = sourceIP.To16() 286 addrDst = destIP.To16() 287 } else if header.TransportProtocol.IsUnix() { 288 buf.Write(lengthUnixBytes) 289 sourceAddr, destAddr, ok := header.UnixAddrs() 290 if !ok { 291 return nil, ErrInvalidAddress 292 } 293 addrSrc = formatUnixName(sourceAddr.Name) 294 addrDst = formatUnixName(destAddr.Name) 295 } 296 297 if addrSrc == nil || addrDst == nil { 298 return nil, ErrInvalidAddress 299 } 300 buf.Write(addrSrc) 301 buf.Write(addrDst) 302 303 if sourcePort, destPort, ok := header.Ports(); ok { 304 portBytes := make([]byte, 2) 305 306 binary.BigEndian.PutUint16(portBytes, uint16(sourcePort)) 307 buf.Write(portBytes) 308 309 binary.BigEndian.PutUint16(portBytes, uint16(destPort)) 310 buf.Write(portBytes) 311 } 312 } 313 314 if len(header.rawTLVs) > 0 { 315 buf.Write(header.rawTLVs) 316 } 317 318 return buf.Bytes(), nil 319 } 320 321 // ProtocolVersionAndCommand represents the 
// command in proxy protocol v2.
// Command doesn't exist in v1 but it should be set since other parts of
// this library may rely on it for determining connection details.
type ProtocolVersionAndCommand byte

const (
	// ProtocolVersionAndCommandLocal represents the ProtocolVersionAndCommandLocal command in v2 or UNKNOWN transport in v1,
	// in which case no address information is expected.
	ProtocolVersionAndCommandLocal ProtocolVersionAndCommand = 0x20
	// ProtocolVersionAndCommandProxy represents the PROXY command in v2 or transport is not UNKNOWN in v1,
	// in which case valid local/remote address and port information is expected.
	ProtocolVersionAndCommandProxy ProtocolVersionAndCommand = 0x21
)

// supportedCommand is the set of command values declared above; a lookup
// for any other value yields false.
var supportedCommand = map[ProtocolVersionAndCommand]bool{
	ProtocolVersionAndCommandLocal: true,
	ProtocolVersionAndCommandProxy: true,
}

// IsLocal reports whether the command in v2 is LOCAL or the transport in v1
// is UNKNOWN, i.e. when no address information is expected.
func (pvc ProtocolVersionAndCommand) IsLocal() bool {
	return pvc == ProtocolVersionAndCommandLocal
}

// IsProxy reports whether the command in v2 is PROXY or the transport in v1
// is not UNKNOWN, i.e. when valid local/remote address and port information
// is expected.
func (pvc ProtocolVersionAndCommand) IsProxy() bool {
	return pvc == ProtocolVersionAndCommandProxy
}

// IsUnspec returns true if the command is unspecified, false otherwise.
353 func (pvc ProtocolVersionAndCommand) IsUnspec() bool { 354 return !(pvc.IsLocal() || pvc.IsProxy()) 355 } 356 357 func (pvc ProtocolVersionAndCommand) toByte() byte { 358 if pvc.IsLocal() { 359 return byte(ProtocolVersionAndCommandLocal) 360 } else if pvc.IsProxy() { 361 return byte(ProtocolVersionAndCommandProxy) 362 } 363 364 return byte(ProtocolVersionAndCommandLocal) 365 } 366 367 // AddressFamilyAndProtocol represents address family and transport protocol. 368 type AddressFamilyAndProtocol byte 369 370 // Address family and protocol constants 371 const ( 372 AddressFamilyAndProtocolUnknown AddressFamilyAndProtocol = '\x00' 373 AddressFamilyAndProtocolTCPv4 AddressFamilyAndProtocol = '\x11' 374 AddressFamilyAndProtocolUDPv4 AddressFamilyAndProtocol = '\x12' 375 AddressFamilyAndProtocolTCPv6 AddressFamilyAndProtocol = '\x21' 376 AddressFamilyAndProtocolUDPv6 AddressFamilyAndProtocol = '\x22' 377 AddressFamilyAndProtocolUnixStream AddressFamilyAndProtocol = '\x31' 378 AddressFamilyAndProtocolUnixDatagram AddressFamilyAndProtocol = '\x32' 379 ) 380 381 // IsIPv4 returns true if the address family is IPv4 (AF_INET4), false otherwise. 382 func (ap AddressFamilyAndProtocol) IsIPv4() bool { 383 return 0x10 == ap&0xF0 384 } 385 386 // IsIPv6 returns true if the address family is IPv6 (AF_INET6), false otherwise. 387 func (ap AddressFamilyAndProtocol) IsIPv6() bool { 388 return 0x20 == ap&0xF0 389 } 390 391 // IsUnix returns true if the address family is UNIX (AF_UNIX), false otherwise. 392 func (ap AddressFamilyAndProtocol) IsUnix() bool { 393 return 0x30 == ap&0xF0 394 } 395 396 // IsStream returns true if the transport protocol is TCP or STREAM (SOCK_STREAM), false otherwise. 397 func (ap AddressFamilyAndProtocol) IsStream() bool { 398 return 0x01 == ap&0x0F 399 } 400 401 // IsDatagram returns true if the transport protocol is UDP or DGRAM (SOCK_DGRAM), false otherwise. 
402 func (ap AddressFamilyAndProtocol) IsDatagram() bool { 403 return 0x02 == ap&0x0F 404 } 405 406 // IsUnspec returns true if the transport protocol or address family is unspecified, false otherwise. 407 func (ap AddressFamilyAndProtocol) IsUnspec() bool { 408 return (0x00 == ap&0xF0) || (0x00 == ap&0x0F) 409 } 410 411 func (ap AddressFamilyAndProtocol) toByte() byte { 412 if ap.IsIPv4() && ap.IsStream() { 413 return byte(AddressFamilyAndProtocolTCPv4) 414 } else if ap.IsIPv4() && ap.IsDatagram() { 415 return byte(AddressFamilyAndProtocolUDPv4) 416 } else if ap.IsIPv6() && ap.IsStream() { 417 return byte(AddressFamilyAndProtocolTCPv6) 418 } else if ap.IsIPv6() && ap.IsDatagram() { 419 return byte(AddressFamilyAndProtocolUDPv6) 420 } else if ap.IsUnix() && ap.IsStream() { 421 return byte(AddressFamilyAndProtocolUnixStream) 422 } else if ap.IsUnix() && ap.IsDatagram() { 423 return byte(AddressFamilyAndProtocolUnixDatagram) 424 } 425 426 return byte(AddressFamilyAndProtocolUnknown) 427 } 428 429 // addTLVLen adds the length of the TLV to the header length or errors on uint16 overflow. 
430 func addTLVLen(cur []byte, tlvLen int) ([]byte, error) { 431 if tlvLen == 0 { 432 return cur, nil 433 } 434 curLen := binary.BigEndian.Uint16(cur) 435 newLen := int(curLen) + tlvLen 436 if newLen >= 1<<16 { 437 return nil, errUint16Overflow 438 } 439 a := make([]byte, 2) 440 binary.BigEndian.PutUint16(a, uint16(newLen)) 441 return a, nil 442 } 443 444 /* 445 const ( 446 // Section 2.2 447 PP2_TYPE_ALPN PP2Type = 0x01 448 PP2_TYPE_AUTHORITY PP2Type = 0x02 449 PP2_TYPE_CRC32C PP2Type = 0x03 450 PP2_TYPE_NOOP PP2Type = 0x04 451 PP2_TYPE_UNIQUE_ID PP2Type = 0x05 452 PP2_TYPE_SSL PP2Type = 0x20 453 PP2_SUBTYPE_SSL_VERSION PP2Type = 0x21 454 PP2_SUBTYPE_SSL_CN PP2Type = 0x22 455 PP2_SUBTYPE_SSL_CIPHER PP2Type = 0x23 456 PP2_SUBTYPE_SSL_SIG_ALG PP2Type = 0x24 457 PP2_SUBTYPE_SSL_KEY_ALG PP2Type = 0x25 458 PP2_TYPE_NETNS PP2Type = 0x30 459 460 // Section 2.2.7, reserved types 461 PP2_TYPE_MIN_CUSTOM PP2Type = 0xE0 462 PP2_TYPE_MAX_CUSTOM PP2Type = 0xEF 463 PP2_TYPE_MIN_EXPERIMENT PP2Type = 0xF0 464 PP2_TYPE_MAX_EXPERIMENT PP2Type = 0xF7 465 PP2_TYPE_MIN_FUTURE PP2Type = 0xF8 466 PP2_TYPE_MAX_FUTURE PP2Type = 0xFF 467 ) 468 */ 469 470 // Proxy Protocol Type 2 constants 471 const ( 472 PP2TypeNoop PP2Type = 0x04 473 PP2TypeAuthority PP2Type = 0x02 474 ) 475 476 // Error constants 477 var ( 478 ErrTruncatedTLV = errors.New("proxyproto: truncated TLV") 479 ErrMalformedTLV = errors.New("proxyproto: malformed TLV Value") 480 ErrIncompatibleTLV = errors.New("proxyproto: incompatible TLV type") 481 ) 482 483 // PP2Type is the proxy protocol v2 type 484 type PP2Type byte 485 486 // TLV is a uninterpreted Type-Length-Value for V2 protocol, see section 2.2 487 type TLV struct { 488 Type PP2Type 489 Value []byte 490 } 491 492 // SplitTLVs splits the Type-Length-Value vector, returns the vector or an error. 
493 func SplitTLVs(raw []byte) ([]TLV, error) { 494 var tlvs []TLV 495 for i := 0; i < len(raw); { 496 tlv := TLV{ 497 Type: PP2Type(raw[i]), 498 } 499 if len(raw)-i <= 2 { 500 return nil, ErrTruncatedTLV 501 } 502 tlvLen := int(binary.BigEndian.Uint16(raw[i+1 : i+3])) // Max length = 65K 503 i += 3 504 if i+tlvLen > len(raw) { 505 return nil, ErrTruncatedTLV 506 } 507 // Ignore no-op padding 508 if tlv.Type != PP2TypeNoop { 509 tlv.Value = make([]byte, tlvLen) 510 copy(tlv.Value, raw[i:i+tlvLen]) 511 } 512 i += tlvLen 513 tlvs = append(tlvs, tlv) 514 } 515 return tlvs, nil 516 } 517 518 // JoinTLVs joins multiple Type-Length-Value records. 519 func JoinTLVs(tlvs []TLV) ([]byte, error) { 520 var raw []byte 521 for _, tlv := range tlvs { 522 if len(tlv.Value) > math.MaxUint16 { 523 return nil, fmt.Errorf("proxyproto: cannot format TLV %v with length %d", tlv.Type, len(tlv.Value)) 524 } 525 var length [2]byte 526 binary.BigEndian.PutUint16(length[:], uint16(len(tlv.Value))) 527 raw = append(raw, byte(tlv.Type)) 528 raw = append(raw, length[:]...) 529 raw = append(raw, tlv.Value...) 530 } 531 return raw, nil 532 } 533 534 func formatUnixName(name string) []byte { 535 n := int(lengthUnix) / 2 536 if len(name) >= n { 537 return []byte(name[:n]) 538 } 539 pad := make([]byte, n-len(name)) 540 return append([]byte(name), pad...) 541 }