github.com/blend/go-sdk@v1.20220411.3/proxyprotocol/proxy_protocol_test.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package proxyprotocol 9 10 import ( 11 "bytes" 12 "fmt" 13 "net" 14 "testing" 15 "time" 16 ) 17 18 const ( 19 goodAddr = "127.0.0.1" 20 badAddr = "127.0.0.2" 21 errAddr = "9999.0.0.2" 22 ) 23 24 var ( 25 checkAddr string 26 ) 27 28 func TestPassthrough(t *testing.T) { 29 l, err := net.Listen("tcp", "127.0.0.1:0") 30 if err != nil { 31 t.Fatalf("err: %v", err) 32 } 33 34 pl := &Listener{Listener: l} 35 36 errors := make(chan error, 4) 37 go func() { 38 conn, err := net.Dial("tcp", pl.Addr().String()) 39 if err != nil { 40 errors <- err 41 return 42 } 43 defer conn.Close() 44 45 _, err = conn.Write([]byte("ping")) 46 if err != nil { 47 errors <- err 48 return 49 } 50 recv := make([]byte, 4) 51 _, err = conn.Read(recv) 52 if err != nil { 53 errors <- err 54 return 55 } 56 if !bytes.Equal(recv, []byte("pong")) { 57 errors <- fmt.Errorf("bad: %v", recv) 58 return 59 } 60 }() 61 62 conn, err := pl.Accept() 63 if err != nil { 64 t.Fatalf("err: %v", err) 65 } 66 defer func() { _ = conn.Close() }() 67 68 recv := make([]byte, 4) 69 _, err = conn.Read(recv) 70 if err != nil { 71 t.Fatalf("err: %v", err) 72 } 73 if !bytes.Equal(recv, []byte("ping")) { 74 t.Fatalf("bad: %v", recv) 75 } 76 77 if _, err := conn.Write([]byte("pong")); err != nil { 78 t.Fatalf("err: %v", err) 79 } 80 81 if len(errors) > 0 { 82 t.Fatal(<-errors) 83 } 84 } 85 86 func TestTimeout(t *testing.T) { 87 l, err := net.Listen("tcp", "127.0.0.1:0") 88 if err != nil { 89 t.Fatalf("err: %v", err) 90 } 91 92 clientWriteDelay := 200 * time.Millisecond 93 proxyHeaderTimeout := 50 * time.Millisecond 94 pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout} 95 96 errors := make(chan error, 4) 97 go func() { 98 conn, err := net.Dial("tcp", pl.Addr().String()) 99 if err != nil { 100 errors <- err 101 return 102 } 103 defer conn.Close() 104 105 // Do not send data for a while 106 time.Sleep(clientWriteDelay) 107 108 _, err = conn.Write([]byte("ping")) 109 if err != nil { 110 errors <- err 111 return 112 } 113 recv := make([]byte, 4) 114 _, err = conn.Read(recv) 115 if err != nil { 116 errors <- err 117 return 118 } 119 if !bytes.Equal(recv, []byte("pong")) { 120 errors <- fmt.Errorf("bad: %v", recv) 121 return 122 } 123 }() 124 125 conn, err := pl.Accept() 126 if err != nil { 127 t.Fatalf("err: %v", err) 128 } 129 defer conn.Close() 130 131 // Check the remote addr is the original 127.0.0.1 132 remoteAddrStartTime := time.Now() 133 addr := conn.RemoteAddr().(*net.TCPAddr) 134 if addr.IP.String() != "127.0.0.1" { 135 t.Fatalf("bad: %v", addr) 136 } 137 remoteAddrDuration := time.Since(remoteAddrStartTime) 138 139 // Check RemoteAddr() call did timeout 140 if remoteAddrDuration >= clientWriteDelay { 141 t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration) 142 } 143 144 recv := make([]byte, 4) 145 _, err = conn.Read(recv) 146 if err != nil { 147 t.Fatalf("err: %v", err) 148 } 149 if !bytes.Equal(recv, []byte("ping")) { 150 t.Fatalf("bad: %v", recv) 151 } 152 153 if _, err := conn.Write([]byte("pong")); err != nil { 154 t.Fatalf("err: %v", err) 155 } 156 157 if len(errors) > 0 { 158 t.Fatal(<-errors) 159 } 160 } 161 162 func TestParse_ipv4(t *testing.T) { 163 l, err := net.Listen("tcp", "127.0.0.1:0") 164 if err != nil { 165 t.Fatalf("err: %v", err) 166 } 167 168 pl := &Listener{Listener: l} 169 170 errors := make(chan error, 5) 171 172 go func() { 173 conn, err := net.Dial("tcp", pl.Addr().String()) 174 if err != nil { 175 errors <- err 176 return 177 } 178 defer conn.Close() 179 180 // Write out the header! 181 header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n" 182 _, err = conn.Write([]byte(header)) 183 if err != nil { 184 errors <- err 185 return 186 } 187 188 _, err = conn.Write([]byte("ping")) 189 if err != nil { 190 errors <- err 191 return 192 } 193 194 recv := make([]byte, 4) 195 _, err = conn.Read(recv) 196 if err != nil { 197 errors <- err 198 return 199 } 200 if !bytes.Equal(recv, []byte("pong")) { 201 errors <- fmt.Errorf("bad: %v", recv) 202 return 203 } 204 }() 205 206 conn, err := pl.Accept() 207 if err != nil { 208 t.Fatalf("err: %v", err) 209 } 210 defer conn.Close() 211 212 recv := make([]byte, 4) 213 _, err = conn.Read(recv) 214 if err != nil { 215 t.Fatalf("err: %v", err) 216 } 217 if !bytes.Equal(recv, []byte("ping")) { 218 t.Fatalf("bad: %v", recv) 219 } 220 221 if _, err := conn.Write([]byte("pong")); err != nil { 222 t.Fatalf("err: %v", err) 223 } 224 225 // Check the remote addr 226 addr := conn.RemoteAddr().(*net.TCPAddr) 227 if addr.IP.String() != "10.1.1.1" { 228 t.Fatalf("bad: %v", addr) 229 } 230 if addr.Port != 1000 { 231 t.Fatalf("bad: %v", addr) 232 } 233 234 if len(errors) > 0 { 235 t.Fatal(<-errors) 236 } 237 } 238 239 func TestParse_ipv6(t *testing.T) { 240 l, err := net.Listen("tcp", "127.0.0.1:0") 241 if err != nil { 242 t.Fatalf("err: %v", err) 243 } 244 245 pl := &Listener{Listener: l} 246 247 errors := make(chan error, 5) 248 go func() { 249 conn, err := net.Dial("tcp", pl.Addr().String()) 250 if err != nil { 251 errors <- err 252 return 253 } 254 defer conn.Close() 255 256 // Write out the header! 257 header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n" 258 _, err = conn.Write([]byte(header)) 259 if err != nil { 260 errors <- err 261 return 262 } 263 264 _, err = conn.Write([]byte("ping")) 265 if err != nil { 266 errors <- err 267 return 268 } 269 270 recv := make([]byte, 4) 271 _, err = conn.Read(recv) 272 if err != nil { 273 errors <- err 274 return 275 } 276 if !bytes.Equal(recv, []byte("pong")) { 277 errors <- fmt.Errorf("bad: %v", recv) 278 return 279 } 280 }() 281 282 conn, err := pl.Accept() 283 if err != nil { 284 t.Fatalf("err: %v", err) 285 } 286 defer conn.Close() 287 288 recv := make([]byte, 4) 289 _, err = conn.Read(recv) 290 if err != nil { 291 t.Fatalf("err: %v", err) 292 } 293 if !bytes.Equal(recv, []byte("ping")) { 294 t.Fatalf("bad: %v", recv) 295 } 296 297 if _, err := conn.Write([]byte("pong")); err != nil { 298 t.Fatalf("err: %v", err) 299 } 300 301 // Check the remote addr 302 addr := conn.RemoteAddr().(*net.TCPAddr) 303 if addr.IP.String() != "ffff::ffff" { 304 t.Fatalf("bad: %v", addr) 305 } 306 if addr.Port != 1000 { 307 t.Fatalf("bad: %v", addr) 308 } 309 310 if len(errors) > 0 { 311 t.Fatal(<-errors) 312 } 313 } 314 315 func TestParse_BadHeader(t *testing.T) { 316 l, err := net.Listen("tcp", "127.0.0.1:0") 317 if err != nil { 318 t.Fatalf("err: %v", err) 319 } 320 321 pl := &Listener{Listener: l} 322 323 errors := make(chan error, 5) 324 go func() { 325 conn, err := net.Dial("tcp", pl.Addr().String()) 326 if err != nil { 327 errors <- err 328 return 329 } 330 defer conn.Close() 331 332 // Write out the header! 333 header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n" 334 _, err = conn.Write([]byte(header)) 335 if err != nil { 336 errors <- err 337 return 338 } 339 340 _, err = conn.Write([]byte("ping")) 341 if err != nil { 342 errors <- err 343 return 344 } 345 346 recv := make([]byte, 4) 347 _, err = conn.Read(recv) 348 if err == nil { 349 errors <- fmt.Errorf("err: %v", err) 350 return 351 } 352 }() 353 354 conn, err := pl.Accept() 355 if err != nil { 356 t.Fatalf("err: %v", err) 357 } 358 defer conn.Close() 359 360 // Check the remote addr, should be the local addr 361 addr := conn.RemoteAddr().(*net.TCPAddr) 362 if addr.IP.String() != "127.0.0.1" { 363 t.Fatalf("bad: %v", addr) 364 } 365 366 // Read should fail 367 recv := make([]byte, 4) 368 _, err = conn.Read(recv) 369 if err == nil { 370 t.Fatal("err should be set") 371 } 372 } 373 374 func TestParseIPv4CheckFunc(t *testing.T) { 375 checkAddr = goodAddr 376 testParseIpv4CheckFunc(t) 377 checkAddr = badAddr 378 testParseIpv4CheckFunc(t) 379 checkAddr = errAddr 380 testParseIpv4CheckFunc(t) 381 } 382 383 func testParseIpv4CheckFunc(t *testing.T) { 384 l, err := net.Listen("tcp", "127.0.0.1:0") 385 if err != nil { 386 t.Fatalf("err: %v", err) 387 } 388 389 checkFunc := func(addr net.Addr) (bool, error) { 390 tcpAddr := addr.(*net.TCPAddr) 391 if tcpAddr.IP.String() == checkAddr { 392 return true, nil 393 } 394 return false, nil 395 } 396 397 pl := &Listener{Listener: l, SourceCheck: checkFunc} 398 399 errors := make(chan error, 4) 400 go func() { 401 conn, err := net.Dial("tcp", pl.Addr().String()) 402 if err != nil { 403 errors <- err 404 return 405 } 406 defer conn.Close() 407 408 // Write out the header! 409 header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n" 410 _, err = conn.Write([]byte(header)) 411 if err != nil { 412 errors <- err 413 return 414 } 415 416 _, err = conn.Write([]byte("ping")) 417 if err != nil { 418 errors <- err 419 return 420 } 421 recv := make([]byte, 4) 422 _, err = conn.Read(recv) 423 if err != nil { 424 errors <- err 425 return 426 } 427 if !bytes.Equal(recv, []byte("pong")) { 428 errors <- fmt.Errorf("bad: %v", recv) 429 return 430 } 431 }() 432 433 conn, err := pl.Accept() 434 if err != nil { 435 if checkAddr == badAddr { 436 return 437 } 438 t.Fatalf("err: %v", err) 439 } 440 defer conn.Close() 441 442 recv := make([]byte, 4) 443 _, err = conn.Read(recv) 444 if err != nil { 445 t.Fatalf("err: %v", err) 446 } 447 if !bytes.Equal(recv, []byte("ping")) { 448 t.Fatalf("bad: %v", recv) 449 } 450 451 if _, err := conn.Write([]byte("pong")); err != nil { 452 t.Fatalf("err: %v", err) 453 } 454 455 // Check the remote addr 456 addr := conn.RemoteAddr().(*net.TCPAddr) 457 switch checkAddr { 458 case goodAddr: 459 if addr.IP.String() != "10.1.1.1" { 460 t.Fatalf("bad: %v", addr) 461 } 462 if addr.Port != 1000 { 463 t.Fatalf("bad: %v", addr) 464 } 465 case badAddr: 466 if addr.IP.String() != "127.0.0.1" { 467 t.Fatalf("bad: %v", addr) 468 } 469 if addr.Port == 1000 { 470 t.Fatalf("bad: %v", addr) 471 } 472 } 473 if len(errors) > 0 { 474 t.Fatal(<-errors) 475 } 476 }