github.com/blend/go-sdk@v1.20220411.3/proxyprotocol/proxy_protocol_header_test.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package proxyprotocol 9 10 import ( 11 "bytes" 12 "errors" 13 "net" 14 "reflect" 15 "testing" 16 ) 17 18 // Stuff to be used in both versions tests. 19 20 const ( 21 noProtocol = "There is no spoon" 22 ip4Addr = "127.0.0.1" 23 ip6Addr = "::1" 24 ip6LongAddr = "1234:5678:9abc:def0:cafe:babe:dead:2bad" 25 port = 65533 26 invalidPort = 99999 27 ) 28 29 var ( 30 v4ip = net.ParseIP(ip4Addr).To4() 31 v6ip = net.ParseIP(ip6Addr).To16() 32 33 v4addr net.Addr = &net.TCPAddr{IP: v4ip, Port: port} 34 v6addr net.Addr = &net.TCPAddr{IP: v6ip, Port: port} 35 36 v4UDPAddr net.Addr = &net.UDPAddr{IP: v4ip, Port: port} 37 v6UDPAddr net.Addr = &net.UDPAddr{IP: v6ip, Port: port} 38 39 unixStreamAddr net.Addr = &net.UnixAddr{Net: "unix", Name: "socket"} 40 unixDatagramAddr net.Addr = &net.UnixAddr{Net: "unixgram", Name: "socket"} 41 42 errReadIntentionallyBroken = errors.New("read is intentionally broken") 43 ) 44 45 func TestEqualsTo(t *testing.T) { 46 var headersEqual = []struct { 47 this, that *Header 48 expected bool 49 }{ 50 { 51 &Header{ 52 Version: 1, 53 Command: ProtocolVersionAndCommandProxy, 54 TransportProtocol: AddressFamilyAndProtocolTCPv4, 55 SourceAddr: &net.TCPAddr{ 56 IP: net.ParseIP("10.1.1.1"), 57 Port: 1000, 58 }, 59 DestinationAddr: &net.TCPAddr{ 60 IP: net.ParseIP("20.2.2.2"), 61 Port: 2000, 62 }, 63 }, 64 nil, 65 false, 66 }, 67 { 68 &Header{ 69 Version: 1, 70 Command: ProtocolVersionAndCommandProxy, 71 TransportProtocol: AddressFamilyAndProtocolTCPv4, 72 SourceAddr: &net.TCPAddr{ 73 IP: net.ParseIP("10.1.1.1"), 74 Port: 1000, 75 }, 76 DestinationAddr: &net.TCPAddr{ 77 IP: net.ParseIP("20.2.2.2"), 78 Port: 2000, 79 }, 80 }, 81 &Header{ 82 Version: 2, 83 Command: ProtocolVersionAndCommandProxy, 84 TransportProtocol: AddressFamilyAndProtocolTCPv4, 85 SourceAddr: &net.TCPAddr{ 86 IP: net.ParseIP("10.1.1.1"), 87 Port: 1000, 88 }, 89 DestinationAddr: &net.TCPAddr{ 90 IP: net.ParseIP("20.2.2.2"), 91 Port: 2000, 92 }, 93 }, 94 false, 95 }, 96 { 97 &Header{ 98 Version: 1, 99 Command: ProtocolVersionAndCommandProxy, 100 TransportProtocol: AddressFamilyAndProtocolTCPv4, 101 SourceAddr: &net.TCPAddr{ 102 IP: net.ParseIP("10.1.1.1"), 103 Port: 1000, 104 }, 105 DestinationAddr: &net.TCPAddr{ 106 IP: net.ParseIP("20.2.2.2"), 107 Port: 2000, 108 }, 109 }, 110 &Header{ 111 Version: 1, 112 Command: ProtocolVersionAndCommandProxy, 113 TransportProtocol: AddressFamilyAndProtocolTCPv4, 114 SourceAddr: &net.TCPAddr{ 115 IP: net.ParseIP("10.1.1.1"), 116 Port: 1000, 117 }, 118 DestinationAddr: &net.TCPAddr{ 119 IP: net.ParseIP("20.2.2.2"), 120 Port: 2000, 121 }, 122 }, 123 true, 124 }, 125 } 126 127 for _, tt := range headersEqual { 128 if actual := tt.this.EqualsTo(tt.that); actual != tt.expected { 129 t.Fatalf("expected %t, actual %t", tt.expected, actual) 130 } 131 } 132 } 133 134 // This is here just because of coveralls 135 func TestEqualTo(t *testing.T) { 136 TestEqualsTo(t) 137 } 138 139 func TestGetters(t *testing.T) { 140 var tests = []struct { 141 name string 142 header *Header 143 tcpSourceAddr, tcpDestAddr *net.TCPAddr 144 udpSourceAddr, udpDestAddr *net.UDPAddr 145 unixSourceAddr, unixDestAddr *net.UnixAddr 146 ipSource, ipDest net.IP 147 portSource, portDest int 148 }{ 149 { 150 name: "AddressFamilyAndProtocolTCPv4", 151 header: &Header{ 152 Version: 1, 153 Command: ProtocolVersionAndCommandProxy, 154 TransportProtocol: AddressFamilyAndProtocolTCPv4, 155 SourceAddr: &net.TCPAddr{ 156 IP: net.ParseIP("10.1.1.1"), 157 Port: 1000, 158 }, 159 DestinationAddr: &net.TCPAddr{ 160 IP: net.ParseIP("20.2.2.2"), 161 Port: 2000, 162 }, 163 }, 164 tcpSourceAddr: &net.TCPAddr{ 165 IP: net.ParseIP("10.1.1.1"), 166 Port: 1000, 167 }, 168 tcpDestAddr: &net.TCPAddr{ 169 IP: net.ParseIP("20.2.2.2"), 170 Port: 2000, 171 }, 172 ipSource: net.ParseIP("10.1.1.1"), 173 ipDest: net.ParseIP("20.2.2.2"), 174 portSource: 1000, 175 portDest: 2000, 176 }, 177 { 178 name: "UDPv4", 179 header: &Header{ 180 Version: 2, 181 Command: ProtocolVersionAndCommandProxy, 182 TransportProtocol: AddressFamilyAndProtocolUDPv6, 183 SourceAddr: &net.UDPAddr{ 184 IP: net.ParseIP("10.1.1.1"), 185 Port: 1000, 186 }, 187 DestinationAddr: &net.UDPAddr{ 188 IP: net.ParseIP("20.2.2.2"), 189 Port: 2000, 190 }, 191 }, 192 udpSourceAddr: &net.UDPAddr{ 193 IP: net.ParseIP("10.1.1.1"), 194 Port: 1000, 195 }, 196 udpDestAddr: &net.UDPAddr{ 197 IP: net.ParseIP("20.2.2.2"), 198 Port: 2000, 199 }, 200 ipSource: net.ParseIP("10.1.1.1"), 201 ipDest: net.ParseIP("20.2.2.2"), 202 portSource: 1000, 203 portDest: 2000, 204 }, 205 { 206 name: "UnixStream", 207 header: &Header{ 208 Version: 2, 209 Command: ProtocolVersionAndCommandProxy, 210 TransportProtocol: AddressFamilyAndProtocolUnixStream, 211 SourceAddr: &net.UnixAddr{ 212 Net: "unix", 213 Name: "src", 214 }, 215 DestinationAddr: &net.UnixAddr{ 216 Net: "unix", 217 Name: "dst", 218 }, 219 }, 220 unixSourceAddr: &net.UnixAddr{ 221 Net: "unix", 222 Name: "src", 223 }, 224 unixDestAddr: &net.UnixAddr{ 225 Net: "unix", 226 Name: "dst", 227 }, 228 }, 229 { 230 name: "UnixDatagram", 231 header: &Header{ 232 Version: 2, 233 Command: ProtocolVersionAndCommandProxy, 234 TransportProtocol: AddressFamilyAndProtocolUnixDatagram, 235 SourceAddr: &net.UnixAddr{ 236 Net: "unix", 237 Name: "src", 238 }, 239 DestinationAddr: &net.UnixAddr{ 240 Net: "unix", 241 Name: "dst", 242 }, 243 }, 244 unixSourceAddr: &net.UnixAddr{ 245 Net: "unix", 246 Name: "src", 247 }, 248 unixDestAddr: &net.UnixAddr{ 249 Net: "unix", 250 Name: "dst", 251 }, 252 }, 253 { 254 name: "Unspec", 255 header: &Header{ 256 Version: 1, 257 Command: ProtocolVersionAndCommandProxy, 258 TransportProtocol: AddressFamilyAndProtocolUnknown, 259 }, 260 }, 261 } 262 263 for _, test := range tests { 264 t.Run(test.name, func(t *testing.T) { 265 tcpSourceAddr, tcpDestAddr, _ := test.header.TCPAddrs() 266 if test.tcpSourceAddr != nil && !reflect.DeepEqual(tcpSourceAddr, test.tcpSourceAddr) { 267 t.Errorf("TCPAddrs() source = %v, want %v", tcpSourceAddr, test.tcpSourceAddr) 268 } 269 if test.tcpDestAddr != nil && !reflect.DeepEqual(tcpDestAddr, test.tcpDestAddr) { 270 t.Errorf("TCPAddrs() dest = %v, want %v", tcpDestAddr, test.tcpDestAddr) 271 } 272 273 udpSourceAddr, udpDestAddr, _ := test.header.UDPAddrs() 274 if test.udpSourceAddr != nil && !reflect.DeepEqual(udpSourceAddr, test.udpSourceAddr) { 275 t.Errorf("TCPAddrs() source = %v, want %v", udpSourceAddr, test.udpSourceAddr) 276 } 277 if test.udpDestAddr != nil && !reflect.DeepEqual(udpDestAddr, test.udpDestAddr) { 278 t.Errorf("TCPAddrs() dest = %v, want %v", udpDestAddr, test.udpDestAddr) 279 } 280 281 unixSourceAddr, unixDestAddr, _ := test.header.UnixAddrs() 282 if test.unixSourceAddr != nil && !reflect.DeepEqual(unixSourceAddr, test.unixSourceAddr) { 283 t.Errorf("UnixAddrs() source = %v, want %v", unixSourceAddr, test.unixSourceAddr) 284 } 285 if test.unixDestAddr != nil && !reflect.DeepEqual(unixDestAddr, test.unixDestAddr) { 286 t.Errorf("UnixAddrs() dest = %v, want %v", unixDestAddr, test.unixDestAddr) 287 } 288 289 ipSource, ipDest, _ := test.header.IPs() 290 if test.ipSource != nil && !ipSource.Equal(test.ipSource) { 291 t.Errorf("IPs() source = %v, want %v", ipSource, test.ipSource) 292 } 293 if test.ipDest != nil && !ipDest.Equal(test.ipDest) { 294 t.Errorf("IPs() dest = %v, want %v", ipDest, test.ipDest) 295 } 296 297 portSource, portDest, _ := test.header.Ports() 298 if test.portSource != 0 && portSource != test.portSource { 299 t.Errorf("Ports() source = %v, want %v", portSource, test.portSource) 300 } 301 if test.portDest != 0 && portDest != test.portDest { 302 t.Errorf("Ports() dest = %v, want %v", portDest, test.portDest) 303 } 304 }) 305 } 306 } 307 308 func TestSetTLVs(t *testing.T) { 309 tests := []struct { 310 header *Header 311 name string 312 tlvs []TLV 313 expectErr bool 314 }{ 315 { 316 name: "add authority TLV", 317 header: &Header{ 318 Version: 1, 319 Command: ProtocolVersionAndCommandProxy, 320 TransportProtocol: AddressFamilyAndProtocolTCPv4, 321 SourceAddr: &net.TCPAddr{ 322 IP: net.ParseIP("10.1.1.1"), 323 Port: 1000, 324 }, 325 DestinationAddr: &net.TCPAddr{ 326 IP: net.ParseIP("20.2.2.2"), 327 Port: 2000, 328 }, 329 }, 330 tlvs: []TLV{{ 331 Type: PP2TypeAuthority, 332 Value: []byte("example.org"), 333 }}, 334 }, 335 { 336 name: "add too long TLV", 337 header: &Header{ 338 Version: 1, 339 Command: ProtocolVersionAndCommandProxy, 340 TransportProtocol: AddressFamilyAndProtocolTCPv4, 341 SourceAddr: &net.TCPAddr{ 342 IP: net.ParseIP("10.1.1.1"), 343 Port: 1000, 344 }, 345 DestinationAddr: &net.TCPAddr{ 346 IP: net.ParseIP("20.2.2.2"), 347 Port: 2000, 348 }, 349 }, 350 tlvs: []TLV{{ 351 Type: PP2TypeAuthority, 352 Value: append(bytes.Repeat([]byte("a"), 0xFFFF), []byte(".example.org")...), 353 }}, 354 expectErr: true, 355 }, 356 } 357 for _, tt := range tests { 358 err := tt.header.SetTLVs(tt.tlvs) 359 if err != nil && !tt.expectErr { 360 t.Fatalf("shouldn't have thrown error %q", err.Error()) 361 } 362 } 363 } 364 365 func TestWriteTo(t *testing.T) { 366 var buf bytes.Buffer 367 368 validHeader := &Header{ 369 Version: 1, 370 Command: ProtocolVersionAndCommandProxy, 371 TransportProtocol: AddressFamilyAndProtocolTCPv4, 372 SourceAddr: &net.TCPAddr{ 373 IP: net.ParseIP("10.1.1.1"), 374 Port: 1000, 375 }, 376 DestinationAddr: &net.TCPAddr{ 377 IP: net.ParseIP("20.2.2.2"), 378 Port: 2000, 379 }, 380 } 381 382 if _, err := validHeader.WriteTo(&buf); err != nil { 383 t.Fatalf("shouldn't have thrown error %q", err.Error()) 384 } 385 386 invalidHeader := &Header{ 387 SourceAddr: &net.TCPAddr{ 388 IP: net.ParseIP("10.1.1.1"), 389 Port: 1000, 390 }, 391 DestinationAddr: &net.TCPAddr{ 392 IP: net.ParseIP("20.2.2.2"), 393 Port: 2000, 394 }, 395 } 396 397 if _, err := invalidHeader.WriteTo(&buf); err == nil { 398 t.Fatalf("should have thrown error %q", err.Error()) 399 } 400 } 401 402 func TestFormat(t *testing.T) { 403 validHeader := &Header{ 404 Version: 1, 405 Command: ProtocolVersionAndCommandProxy, 406 TransportProtocol: AddressFamilyAndProtocolTCPv4, 407 SourceAddr: &net.TCPAddr{ 408 IP: net.ParseIP("10.1.1.1"), 409 Port: 1000, 410 }, 411 DestinationAddr: &net.TCPAddr{ 412 IP: net.ParseIP("20.2.2.2"), 413 Port: 2000, 414 }, 415 } 416 417 if _, err := validHeader.Format(); err != nil { 418 t.Fatalf("shouldn't have thrown error %q", err.Error()) 419 } 420 } 421 422 func TestFormatInvalid(t *testing.T) { 423 tests := []struct { 424 name string 425 header *Header 426 err error 427 }{ 428 { 429 name: "invalidVersion", 430 header: &Header{ 431 Version: 3, 432 Command: ProtocolVersionAndCommandProxy, 433 TransportProtocol: AddressFamilyAndProtocolTCPv4, 434 SourceAddr: v4addr, 435 DestinationAddr: v4addr, 436 }, 437 err: ErrUnknownProxyProtocolVersion, 438 }, 439 { 440 name: "v2MismatchTCPv4_UDPv4", 441 header: &Header{ 442 Version: 2, 443 Command: ProtocolVersionAndCommandProxy, 444 TransportProtocol: AddressFamilyAndProtocolTCPv4, 445 SourceAddr: v4UDPAddr, 446 DestinationAddr: v4addr, 447 }, 448 err: ErrInvalidAddress, 449 }, 450 { 451 name: "v2MismatchTCPv4_TCPv6", 452 header: &Header{ 453 Version: 2, 454 Command: ProtocolVersionAndCommandProxy, 455 TransportProtocol: AddressFamilyAndProtocolTCPv4, 456 SourceAddr: v4addr, 457 DestinationAddr: v6addr, 458 }, 459 err: ErrInvalidAddress, 460 }, 461 { 462 name: "v2MismatchUnixStream_TCPv4", 463 header: &Header{ 464 Version: 2, 465 Command: ProtocolVersionAndCommandProxy, 466 TransportProtocol: AddressFamilyAndProtocolUnixStream, 467 SourceAddr: v4addr, 468 DestinationAddr: unixStreamAddr, 469 }, 470 err: ErrInvalidAddress, 471 }, 472 { 473 name: "v1MismatchTCPv4_TCPv6", 474 header: &Header{ 475 Version: 1, 476 Command: ProtocolVersionAndCommandProxy, 477 TransportProtocol: AddressFamilyAndProtocolTCPv4, 478 SourceAddr: v6addr, 479 DestinationAddr: v4addr, 480 }, 481 err: ErrInvalidAddress, 482 }, 483 { 484 name: "v1MismatchTCPv4_UDPv4", 485 header: &Header{ 486 Version: 1, 487 Command: ProtocolVersionAndCommandProxy, 488 TransportProtocol: AddressFamilyAndProtocolTCPv4, 489 SourceAddr: v4UDPAddr, 490 DestinationAddr: v4addr, 491 }, 492 err: ErrInvalidAddress, 493 }, 494 } 495 496 for _, test := range tests { 497 t.Run(test.name, func(t *testing.T) { 498 if _, err := test.header.Format(); err == nil { 499 t.Errorf("Header.Format() succeeded, want an error") 500 } else if err != test.err { 501 t.Errorf("Header.Format() = %q, want %q", err, test.err) 502 } 503 }) 504 } 505 }