go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/proxyproto/header_test.go (about) 1 /* 2 3 Copyright (c) 2023 - Present. Will Charczuk. All rights reserved. 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository. 5 6 */ 7 8 package proxyproto 9 10 import ( 11 "bytes" 12 "errors" 13 "net" 14 "reflect" 15 "testing" 16 ) 17 18 // Stuff to be used in both versions tests. 19 20 const ( 21 noProtocol = "There is no spoon" 22 ip4Addr = "127.0.0.1" 23 ip6Addr = "::1" 24 ip6LongAddr = "1234:5678:9abc:def0:cafe:babe:dead:2bad" 25 port = 65533 26 invalidPort = 99999 27 ) 28 29 var ( 30 v4ip = net.ParseIP(ip4Addr).To4() 31 v6ip = net.ParseIP(ip6Addr).To16() 32 33 v4addr net.Addr = &net.TCPAddr{IP: v4ip, Port: port} 34 v6addr net.Addr = &net.TCPAddr{IP: v6ip, Port: port} 35 36 v4UDPAddr net.Addr = &net.UDPAddr{IP: v4ip, Port: port} 37 v6UDPAddr net.Addr = &net.UDPAddr{IP: v6ip, Port: port} 38 39 unixStreamAddr net.Addr = &net.UnixAddr{Net: "unix", Name: "socket"} 40 unixDatagramAddr net.Addr = &net.UnixAddr{Net: "unixgram", Name: "socket"} 41 42 errReadIntentionallyBroken = errors.New("read is intentionally broken") 43 ) 44 45 func Test_Header_EqualsTo(t *testing.T) { 46 var headersEqual = []struct { 47 this, that *Header 48 expected bool 49 }{ 50 { 51 &Header{ 52 Version: 1, 53 Command: ProtocolVersionAndCommandProxy, 54 TransportProtocol: AddressFamilyAndProtocolTCPv4, 55 SourceAddr: &net.TCPAddr{ 56 IP: net.ParseIP("10.1.1.1"), 57 Port: 1000, 58 }, 59 DestinationAddr: &net.TCPAddr{ 60 IP: net.ParseIP("20.2.2.2"), 61 Port: 2000, 62 }, 63 }, 64 nil, 65 false, 66 }, 67 { 68 &Header{ 69 Version: 1, 70 Command: ProtocolVersionAndCommandProxy, 71 TransportProtocol: AddressFamilyAndProtocolTCPv4, 72 SourceAddr: &net.TCPAddr{ 73 IP: net.ParseIP("10.1.1.1"), 74 Port: 1000, 75 }, 76 DestinationAddr: &net.TCPAddr{ 77 IP: net.ParseIP("20.2.2.2"), 78 Port: 2000, 79 }, 80 }, 81 &Header{ 82 Version: 2, 83 Command: ProtocolVersionAndCommandProxy, 84 TransportProtocol: AddressFamilyAndProtocolTCPv4, 85 SourceAddr: &net.TCPAddr{ 86 IP: net.ParseIP("10.1.1.1"), 87 Port: 1000, 88 }, 89 DestinationAddr: &net.TCPAddr{ 90 IP: net.ParseIP("20.2.2.2"), 91 Port: 2000, 92 }, 93 }, 94 false, 95 }, 96 { 97 &Header{ 98 Version: 1, 99 Command: ProtocolVersionAndCommandProxy, 100 TransportProtocol: AddressFamilyAndProtocolTCPv4, 101 SourceAddr: &net.TCPAddr{ 102 IP: net.ParseIP("10.1.1.1"), 103 Port: 1000, 104 }, 105 DestinationAddr: &net.TCPAddr{ 106 IP: net.ParseIP("20.2.2.2"), 107 Port: 2000, 108 }, 109 }, 110 &Header{ 111 Version: 1, 112 Command: ProtocolVersionAndCommandProxy, 113 TransportProtocol: AddressFamilyAndProtocolTCPv4, 114 SourceAddr: &net.TCPAddr{ 115 IP: net.ParseIP("10.1.1.1"), 116 Port: 1000, 117 }, 118 DestinationAddr: &net.TCPAddr{ 119 IP: net.ParseIP("20.2.2.2"), 120 Port: 2000, 121 }, 122 }, 123 true, 124 }, 125 } 126 127 for _, tt := range headersEqual { 128 if actual := tt.this.EqualsTo(tt.that); actual != tt.expected { 129 t.Fatalf("expected %t, actual %t", tt.expected, actual) 130 } 131 } 132 } 133 134 func Test_Header_getters(t *testing.T) { 135 var tests = []struct { 136 name string 137 header *Header 138 tcpSourceAddr, tcpDestAddr *net.TCPAddr 139 udpSourceAddr, udpDestAddr *net.UDPAddr 140 unixSourceAddr, unixDestAddr *net.UnixAddr 141 ipSource, ipDest net.IP 142 portSource, portDest int 143 }{ 144 { 145 name: "AddressFamilyAndProtocolTCPv4", 146 header: &Header{ 147 Version: 1, 148 Command: ProtocolVersionAndCommandProxy, 149 TransportProtocol: AddressFamilyAndProtocolTCPv4, 150 SourceAddr: &net.TCPAddr{ 151 IP: net.ParseIP("10.1.1.1"), 152 Port: 1000, 153 }, 154 DestinationAddr: &net.TCPAddr{ 155 IP: net.ParseIP("20.2.2.2"), 156 Port: 2000, 157 }, 158 }, 159 tcpSourceAddr: &net.TCPAddr{ 160 IP: net.ParseIP("10.1.1.1"), 161 Port: 1000, 162 }, 163 tcpDestAddr: &net.TCPAddr{ 164 IP: net.ParseIP("20.2.2.2"), 165 Port: 2000, 166 }, 167 ipSource: net.ParseIP("10.1.1.1"), 168 ipDest: net.ParseIP("20.2.2.2"), 169 portSource: 1000, 170 portDest: 2000, 171 }, 172 { 173 name: "UDPv4", 174 header: &Header{ 175 Version: 2, 176 Command: ProtocolVersionAndCommandProxy, 177 TransportProtocol: AddressFamilyAndProtocolUDPv6, 178 SourceAddr: &net.UDPAddr{ 179 IP: net.ParseIP("10.1.1.1"), 180 Port: 1000, 181 }, 182 DestinationAddr: &net.UDPAddr{ 183 IP: net.ParseIP("20.2.2.2"), 184 Port: 2000, 185 }, 186 }, 187 udpSourceAddr: &net.UDPAddr{ 188 IP: net.ParseIP("10.1.1.1"), 189 Port: 1000, 190 }, 191 udpDestAddr: &net.UDPAddr{ 192 IP: net.ParseIP("20.2.2.2"), 193 Port: 2000, 194 }, 195 ipSource: net.ParseIP("10.1.1.1"), 196 ipDest: net.ParseIP("20.2.2.2"), 197 portSource: 1000, 198 portDest: 2000, 199 }, 200 { 201 name: "UnixStream", 202 header: &Header{ 203 Version: 2, 204 Command: ProtocolVersionAndCommandProxy, 205 TransportProtocol: AddressFamilyAndProtocolUnixStream, 206 SourceAddr: &net.UnixAddr{ 207 Net: "unix", 208 Name: "src", 209 }, 210 DestinationAddr: &net.UnixAddr{ 211 Net: "unix", 212 Name: "dst", 213 }, 214 }, 215 unixSourceAddr: &net.UnixAddr{ 216 Net: "unix", 217 Name: "src", 218 }, 219 unixDestAddr: &net.UnixAddr{ 220 Net: "unix", 221 Name: "dst", 222 }, 223 }, 224 { 225 name: "UnixDatagram", 226 header: &Header{ 227 Version: 2, 228 Command: ProtocolVersionAndCommandProxy, 229 TransportProtocol: AddressFamilyAndProtocolUnixDatagram, 230 SourceAddr: &net.UnixAddr{ 231 Net: "unix", 232 Name: "src", 233 }, 234 DestinationAddr: &net.UnixAddr{ 235 Net: "unix", 236 Name: "dst", 237 }, 238 }, 239 unixSourceAddr: &net.UnixAddr{ 240 Net: "unix", 241 Name: "src", 242 }, 243 unixDestAddr: &net.UnixAddr{ 244 Net: "unix", 245 Name: "dst", 246 }, 247 }, 248 { 249 name: "Unspec", 250 header: &Header{ 251 Version: 1, 252 Command: ProtocolVersionAndCommandProxy, 253 TransportProtocol: AddressFamilyAndProtocolUnknown, 254 }, 255 }, 256 } 257 258 for _, test := range tests { 259 t.Run(test.name, func(t *testing.T) { 260 tcpSourceAddr, tcpDestAddr, _ := test.header.TCPAddrs() 261 if test.tcpSourceAddr != nil && !reflect.DeepEqual(tcpSourceAddr, test.tcpSourceAddr) { 262 t.Errorf("TCPAddrs() source = %v, want %v", tcpSourceAddr, test.tcpSourceAddr) 263 } 264 if test.tcpDestAddr != nil && !reflect.DeepEqual(tcpDestAddr, test.tcpDestAddr) { 265 t.Errorf("TCPAddrs() dest = %v, want %v", tcpDestAddr, test.tcpDestAddr) 266 } 267 268 udpSourceAddr, udpDestAddr, _ := test.header.UDPAddrs() 269 if test.udpSourceAddr != nil && !reflect.DeepEqual(udpSourceAddr, test.udpSourceAddr) { 270 t.Errorf("TCPAddrs() source = %v, want %v", udpSourceAddr, test.udpSourceAddr) 271 } 272 if test.udpDestAddr != nil && !reflect.DeepEqual(udpDestAddr, test.udpDestAddr) { 273 t.Errorf("TCPAddrs() dest = %v, want %v", udpDestAddr, test.udpDestAddr) 274 } 275 276 unixSourceAddr, unixDestAddr, _ := test.header.UnixAddrs() 277 if test.unixSourceAddr != nil && !reflect.DeepEqual(unixSourceAddr, test.unixSourceAddr) { 278 t.Errorf("UnixAddrs() source = %v, want %v", unixSourceAddr, test.unixSourceAddr) 279 } 280 if test.unixDestAddr != nil && !reflect.DeepEqual(unixDestAddr, test.unixDestAddr) { 281 t.Errorf("UnixAddrs() dest = %v, want %v", unixDestAddr, test.unixDestAddr) 282 } 283 284 ipSource, ipDest, _ := test.header.IPs() 285 if test.ipSource != nil && !ipSource.Equal(test.ipSource) { 286 t.Errorf("IPs() source = %v, want %v", ipSource, test.ipSource) 287 } 288 if test.ipDest != nil && !ipDest.Equal(test.ipDest) { 289 t.Errorf("IPs() dest = %v, want %v", ipDest, test.ipDest) 290 } 291 292 portSource, portDest, _ := test.header.Ports() 293 if test.portSource != 0 && portSource != test.portSource { 294 t.Errorf("Ports() source = %v, want %v", portSource, test.portSource) 295 } 296 if test.portDest != 0 && portDest != test.portDest { 297 t.Errorf("Ports() dest = %v, want %v", portDest, test.portDest) 298 } 299 }) 300 } 301 } 302 303 func Test_Header_SetTLVs(t *testing.T) { 304 tests := []struct { 305 header *Header 306 name string 307 tlvs []TLV 308 expectErr bool 309 }{ 310 { 311 name: "add authority TLV", 312 header: &Header{ 313 Version: 1, 314 Command: ProtocolVersionAndCommandProxy, 315 TransportProtocol: AddressFamilyAndProtocolTCPv4, 316 SourceAddr: &net.TCPAddr{ 317 IP: net.ParseIP("10.1.1.1"), 318 Port: 1000, 319 }, 320 DestinationAddr: &net.TCPAddr{ 321 IP: net.ParseIP("20.2.2.2"), 322 Port: 2000, 323 }, 324 }, 325 tlvs: []TLV{{ 326 Type: PP2TypeAuthority, 327 Value: []byte("example.org"), 328 }}, 329 }, 330 { 331 name: "add too long TLV", 332 header: &Header{ 333 Version: 1, 334 Command: ProtocolVersionAndCommandProxy, 335 TransportProtocol: AddressFamilyAndProtocolTCPv4, 336 SourceAddr: &net.TCPAddr{ 337 IP: net.ParseIP("10.1.1.1"), 338 Port: 1000, 339 }, 340 DestinationAddr: &net.TCPAddr{ 341 IP: net.ParseIP("20.2.2.2"), 342 Port: 2000, 343 }, 344 }, 345 tlvs: []TLV{{ 346 Type: PP2TypeAuthority, 347 Value: append(bytes.Repeat([]byte("a"), 0xFFFF), []byte(".example.org")...), 348 }}, 349 expectErr: true, 350 }, 351 } 352 for _, tt := range tests { 353 err := tt.header.SetTLVs(tt.tlvs) 354 if err != nil && !tt.expectErr { 355 t.Fatalf("shouldn't have thrown error %q", err.Error()) 356 } 357 } 358 } 359 360 func Test_Header_WriteTo(t *testing.T) { 361 var buf bytes.Buffer 362 363 validHeader := &Header{ 364 Version: 1, 365 Command: ProtocolVersionAndCommandProxy, 366 TransportProtocol: AddressFamilyAndProtocolTCPv4, 367 SourceAddr: &net.TCPAddr{ 368 IP: net.ParseIP("10.1.1.1"), 369 Port: 1000, 370 }, 371 DestinationAddr: &net.TCPAddr{ 372 IP: net.ParseIP("20.2.2.2"), 373 Port: 2000, 374 }, 375 } 376 377 if _, err := validHeader.WriteTo(&buf); err != nil { 378 t.Fatalf("shouldn't have thrown error %q", err.Error()) 379 } 380 381 invalidHeader := &Header{ 382 SourceAddr: &net.TCPAddr{ 383 IP: net.ParseIP("10.1.1.1"), 384 Port: 1000, 385 }, 386 DestinationAddr: &net.TCPAddr{ 387 IP: net.ParseIP("20.2.2.2"), 388 Port: 2000, 389 }, 390 } 391 392 if _, err := invalidHeader.WriteTo(&buf); err == nil { 393 t.Fatalf("should have thrown error %q", err.Error()) 394 } 395 } 396 397 func Test_Header_Format(t *testing.T) { 398 validHeader := &Header{ 399 Version: 1, 400 Command: ProtocolVersionAndCommandProxy, 401 TransportProtocol: AddressFamilyAndProtocolTCPv4, 402 SourceAddr: &net.TCPAddr{ 403 IP: net.ParseIP("10.1.1.1"), 404 Port: 1000, 405 }, 406 DestinationAddr: &net.TCPAddr{ 407 IP: net.ParseIP("20.2.2.2"), 408 Port: 2000, 409 }, 410 } 411 412 if _, err := validHeader.Format(); err != nil { 413 t.Fatalf("shouldn't have thrown error %q", err.Error()) 414 } 415 } 416 417 func Test_Header_Format_invalid(t *testing.T) { 418 tests := []struct { 419 name string 420 header *Header 421 err error 422 }{ 423 { 424 name: "invalidVersion", 425 header: &Header{ 426 Version: 3, 427 Command: ProtocolVersionAndCommandProxy, 428 TransportProtocol: AddressFamilyAndProtocolTCPv4, 429 SourceAddr: v4addr, 430 DestinationAddr: v4addr, 431 }, 432 err: ErrUnknownProxyProtocolVersion, 433 }, 434 { 435 name: "v2MismatchTCPv4_UDPv4", 436 header: &Header{ 437 Version: 2, 438 Command: ProtocolVersionAndCommandProxy, 439 TransportProtocol: AddressFamilyAndProtocolTCPv4, 440 SourceAddr: v4UDPAddr, 441 DestinationAddr: v4addr, 442 }, 443 err: ErrInvalidAddress, 444 }, 445 { 446 name: "v2MismatchTCPv4_TCPv6", 447 header: &Header{ 448 Version: 2, 449 Command: ProtocolVersionAndCommandProxy, 450 TransportProtocol: AddressFamilyAndProtocolTCPv4, 451 SourceAddr: v4addr, 452 DestinationAddr: v6addr, 453 }, 454 err: ErrInvalidAddress, 455 }, 456 { 457 name: "v2MismatchUnixStream_TCPv4", 458 header: &Header{ 459 Version: 2, 460 Command: ProtocolVersionAndCommandProxy, 461 TransportProtocol: AddressFamilyAndProtocolUnixStream, 462 SourceAddr: v4addr, 463 DestinationAddr: unixStreamAddr, 464 }, 465 err: ErrInvalidAddress, 466 }, 467 { 468 name: "v1MismatchTCPv4_TCPv6", 469 header: &Header{ 470 Version: 1, 471 Command: ProtocolVersionAndCommandProxy, 472 TransportProtocol: AddressFamilyAndProtocolTCPv4, 473 SourceAddr: v6addr, 474 DestinationAddr: v4addr, 475 }, 476 err: ErrInvalidAddress, 477 }, 478 { 479 name: "v1MismatchTCPv4_UDPv4", 480 header: &Header{ 481 Version: 1, 482 Command: ProtocolVersionAndCommandProxy, 483 TransportProtocol: AddressFamilyAndProtocolTCPv4, 484 SourceAddr: v4UDPAddr, 485 DestinationAddr: v4addr, 486 }, 487 err: ErrInvalidAddress, 488 }, 489 } 490 491 for _, test := range tests { 492 t.Run(test.name, func(t *testing.T) { 493 if _, err := test.header.Format(); err == nil { 494 t.Errorf("Header.Format() succeeded, want an error") 495 } else if err != test.err { 496 t.Errorf("Header.Format() = %q, want %q", err, test.err) 497 } 498 }) 499 } 500 }