github.com/blend/go-sdk@v1.20220411.3/proxyprotocol/proxy_protocol_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package proxyprotocol
     9  
    10  import (
    11  	"bytes"
    12  	"fmt"
    13  	"net"
    14  	"testing"
    15  	"time"
    16  )
    17  
    18  const (
    19  	goodAddr = "127.0.0.1"
    20  	badAddr  = "127.0.0.2"
    21  	errAddr  = "9999.0.0.2"
    22  )
    23  
    24  var (
    25  	checkAddr string
    26  )
    27  
    28  func TestPassthrough(t *testing.T) {
    29  	l, err := net.Listen("tcp", "127.0.0.1:0")
    30  	if err != nil {
    31  		t.Fatalf("err: %v", err)
    32  	}
    33  
    34  	pl := &Listener{Listener: l}
    35  
    36  	errors := make(chan error, 4)
    37  	go func() {
    38  		conn, err := net.Dial("tcp", pl.Addr().String())
    39  		if err != nil {
    40  			errors <- err
    41  			return
    42  		}
    43  		defer conn.Close()
    44  
    45  		_, err = conn.Write([]byte("ping"))
    46  		if err != nil {
    47  			errors <- err
    48  			return
    49  		}
    50  		recv := make([]byte, 4)
    51  		_, err = conn.Read(recv)
    52  		if err != nil {
    53  			errors <- err
    54  			return
    55  		}
    56  		if !bytes.Equal(recv, []byte("pong")) {
    57  			errors <- fmt.Errorf("bad: %v", recv)
    58  			return
    59  		}
    60  	}()
    61  
    62  	conn, err := pl.Accept()
    63  	if err != nil {
    64  		t.Fatalf("err: %v", err)
    65  	}
    66  	defer func() { _ = conn.Close() }()
    67  
    68  	recv := make([]byte, 4)
    69  	_, err = conn.Read(recv)
    70  	if err != nil {
    71  		t.Fatalf("err: %v", err)
    72  	}
    73  	if !bytes.Equal(recv, []byte("ping")) {
    74  		t.Fatalf("bad: %v", recv)
    75  	}
    76  
    77  	if _, err := conn.Write([]byte("pong")); err != nil {
    78  		t.Fatalf("err: %v", err)
    79  	}
    80  
    81  	if len(errors) > 0 {
    82  		t.Fatal(<-errors)
    83  	}
    84  }
    85  
    86  func TestTimeout(t *testing.T) {
    87  	l, err := net.Listen("tcp", "127.0.0.1:0")
    88  	if err != nil {
    89  		t.Fatalf("err: %v", err)
    90  	}
    91  
    92  	clientWriteDelay := 200 * time.Millisecond
    93  	proxyHeaderTimeout := 50 * time.Millisecond
    94  	pl := &Listener{Listener: l, ProxyHeaderTimeout: proxyHeaderTimeout}
    95  
    96  	errors := make(chan error, 4)
    97  	go func() {
    98  		conn, err := net.Dial("tcp", pl.Addr().String())
    99  		if err != nil {
   100  			errors <- err
   101  			return
   102  		}
   103  		defer conn.Close()
   104  
   105  		// Do not send data for a while
   106  		time.Sleep(clientWriteDelay)
   107  
   108  		_, err = conn.Write([]byte("ping"))
   109  		if err != nil {
   110  			errors <- err
   111  			return
   112  		}
   113  		recv := make([]byte, 4)
   114  		_, err = conn.Read(recv)
   115  		if err != nil {
   116  			errors <- err
   117  			return
   118  		}
   119  		if !bytes.Equal(recv, []byte("pong")) {
   120  			errors <- fmt.Errorf("bad: %v", recv)
   121  			return
   122  		}
   123  	}()
   124  
   125  	conn, err := pl.Accept()
   126  	if err != nil {
   127  		t.Fatalf("err: %v", err)
   128  	}
   129  	defer conn.Close()
   130  
   131  	// Check the remote addr is the original 127.0.0.1
   132  	remoteAddrStartTime := time.Now()
   133  	addr := conn.RemoteAddr().(*net.TCPAddr)
   134  	if addr.IP.String() != "127.0.0.1" {
   135  		t.Fatalf("bad: %v", addr)
   136  	}
   137  	remoteAddrDuration := time.Since(remoteAddrStartTime)
   138  
   139  	// Check RemoteAddr() call did timeout
   140  	if remoteAddrDuration >= clientWriteDelay {
   141  		t.Fatalf("RemoteAddr() took longer than the specified timeout: %v < %v", proxyHeaderTimeout, remoteAddrDuration)
   142  	}
   143  
   144  	recv := make([]byte, 4)
   145  	_, err = conn.Read(recv)
   146  	if err != nil {
   147  		t.Fatalf("err: %v", err)
   148  	}
   149  	if !bytes.Equal(recv, []byte("ping")) {
   150  		t.Fatalf("bad: %v", recv)
   151  	}
   152  
   153  	if _, err := conn.Write([]byte("pong")); err != nil {
   154  		t.Fatalf("err: %v", err)
   155  	}
   156  
   157  	if len(errors) > 0 {
   158  		t.Fatal(<-errors)
   159  	}
   160  }
   161  
   162  func TestParse_ipv4(t *testing.T) {
   163  	l, err := net.Listen("tcp", "127.0.0.1:0")
   164  	if err != nil {
   165  		t.Fatalf("err: %v", err)
   166  	}
   167  
   168  	pl := &Listener{Listener: l}
   169  
   170  	errors := make(chan error, 5)
   171  
   172  	go func() {
   173  		conn, err := net.Dial("tcp", pl.Addr().String())
   174  		if err != nil {
   175  			errors <- err
   176  			return
   177  		}
   178  		defer conn.Close()
   179  
   180  		// Write out the header!
   181  		header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
   182  		_, err = conn.Write([]byte(header))
   183  		if err != nil {
   184  			errors <- err
   185  			return
   186  		}
   187  
   188  		_, err = conn.Write([]byte("ping"))
   189  		if err != nil {
   190  			errors <- err
   191  			return
   192  		}
   193  
   194  		recv := make([]byte, 4)
   195  		_, err = conn.Read(recv)
   196  		if err != nil {
   197  			errors <- err
   198  			return
   199  		}
   200  		if !bytes.Equal(recv, []byte("pong")) {
   201  			errors <- fmt.Errorf("bad: %v", recv)
   202  			return
   203  		}
   204  	}()
   205  
   206  	conn, err := pl.Accept()
   207  	if err != nil {
   208  		t.Fatalf("err: %v", err)
   209  	}
   210  	defer conn.Close()
   211  
   212  	recv := make([]byte, 4)
   213  	_, err = conn.Read(recv)
   214  	if err != nil {
   215  		t.Fatalf("err: %v", err)
   216  	}
   217  	if !bytes.Equal(recv, []byte("ping")) {
   218  		t.Fatalf("bad: %v", recv)
   219  	}
   220  
   221  	if _, err := conn.Write([]byte("pong")); err != nil {
   222  		t.Fatalf("err: %v", err)
   223  	}
   224  
   225  	// Check the remote addr
   226  	addr := conn.RemoteAddr().(*net.TCPAddr)
   227  	if addr.IP.String() != "10.1.1.1" {
   228  		t.Fatalf("bad: %v", addr)
   229  	}
   230  	if addr.Port != 1000 {
   231  		t.Fatalf("bad: %v", addr)
   232  	}
   233  
   234  	if len(errors) > 0 {
   235  		t.Fatal(<-errors)
   236  	}
   237  }
   238  
   239  func TestParse_ipv6(t *testing.T) {
   240  	l, err := net.Listen("tcp", "127.0.0.1:0")
   241  	if err != nil {
   242  		t.Fatalf("err: %v", err)
   243  	}
   244  
   245  	pl := &Listener{Listener: l}
   246  
   247  	errors := make(chan error, 5)
   248  	go func() {
   249  		conn, err := net.Dial("tcp", pl.Addr().String())
   250  		if err != nil {
   251  			errors <- err
   252  			return
   253  		}
   254  		defer conn.Close()
   255  
   256  		// Write out the header!
   257  		header := "PROXY TCP6 ffff::ffff ffff::ffff 1000 2000\r\n"
   258  		_, err = conn.Write([]byte(header))
   259  		if err != nil {
   260  			errors <- err
   261  			return
   262  		}
   263  
   264  		_, err = conn.Write([]byte("ping"))
   265  		if err != nil {
   266  			errors <- err
   267  			return
   268  		}
   269  
   270  		recv := make([]byte, 4)
   271  		_, err = conn.Read(recv)
   272  		if err != nil {
   273  			errors <- err
   274  			return
   275  		}
   276  		if !bytes.Equal(recv, []byte("pong")) {
   277  			errors <- fmt.Errorf("bad: %v", recv)
   278  			return
   279  		}
   280  	}()
   281  
   282  	conn, err := pl.Accept()
   283  	if err != nil {
   284  		t.Fatalf("err: %v", err)
   285  	}
   286  	defer conn.Close()
   287  
   288  	recv := make([]byte, 4)
   289  	_, err = conn.Read(recv)
   290  	if err != nil {
   291  		t.Fatalf("err: %v", err)
   292  	}
   293  	if !bytes.Equal(recv, []byte("ping")) {
   294  		t.Fatalf("bad: %v", recv)
   295  	}
   296  
   297  	if _, err := conn.Write([]byte("pong")); err != nil {
   298  		t.Fatalf("err: %v", err)
   299  	}
   300  
   301  	// Check the remote addr
   302  	addr := conn.RemoteAddr().(*net.TCPAddr)
   303  	if addr.IP.String() != "ffff::ffff" {
   304  		t.Fatalf("bad: %v", addr)
   305  	}
   306  	if addr.Port != 1000 {
   307  		t.Fatalf("bad: %v", addr)
   308  	}
   309  
   310  	if len(errors) > 0 {
   311  		t.Fatal(<-errors)
   312  	}
   313  }
   314  
   315  func TestParse_BadHeader(t *testing.T) {
   316  	l, err := net.Listen("tcp", "127.0.0.1:0")
   317  	if err != nil {
   318  		t.Fatalf("err: %v", err)
   319  	}
   320  
   321  	pl := &Listener{Listener: l}
   322  
   323  	errors := make(chan error, 5)
   324  	go func() {
   325  		conn, err := net.Dial("tcp", pl.Addr().String())
   326  		if err != nil {
   327  			errors <- err
   328  			return
   329  		}
   330  		defer conn.Close()
   331  
   332  		// Write out the header!
   333  		header := "PROXY TCP4 what 127.0.0.1 1000 2000\r\n"
   334  		_, err = conn.Write([]byte(header))
   335  		if err != nil {
   336  			errors <- err
   337  			return
   338  		}
   339  
   340  		_, err = conn.Write([]byte("ping"))
   341  		if err != nil {
   342  			errors <- err
   343  			return
   344  		}
   345  
   346  		recv := make([]byte, 4)
   347  		_, err = conn.Read(recv)
   348  		if err == nil {
   349  			errors <- fmt.Errorf("err: %v", err)
   350  			return
   351  		}
   352  	}()
   353  
   354  	conn, err := pl.Accept()
   355  	if err != nil {
   356  		t.Fatalf("err: %v", err)
   357  	}
   358  	defer conn.Close()
   359  
   360  	// Check the remote addr, should be the local addr
   361  	addr := conn.RemoteAddr().(*net.TCPAddr)
   362  	if addr.IP.String() != "127.0.0.1" {
   363  		t.Fatalf("bad: %v", addr)
   364  	}
   365  
   366  	// Read should fail
   367  	recv := make([]byte, 4)
   368  	_, err = conn.Read(recv)
   369  	if err == nil {
   370  		t.Fatal("err should be set")
   371  	}
   372  }
   373  
   374  func TestParseIPv4CheckFunc(t *testing.T) {
   375  	checkAddr = goodAddr
   376  	testParseIpv4CheckFunc(t)
   377  	checkAddr = badAddr
   378  	testParseIpv4CheckFunc(t)
   379  	checkAddr = errAddr
   380  	testParseIpv4CheckFunc(t)
   381  }
   382  
   383  func testParseIpv4CheckFunc(t *testing.T) {
   384  	l, err := net.Listen("tcp", "127.0.0.1:0")
   385  	if err != nil {
   386  		t.Fatalf("err: %v", err)
   387  	}
   388  
   389  	checkFunc := func(addr net.Addr) (bool, error) {
   390  		tcpAddr := addr.(*net.TCPAddr)
   391  		if tcpAddr.IP.String() == checkAddr {
   392  			return true, nil
   393  		}
   394  		return false, nil
   395  	}
   396  
   397  	pl := &Listener{Listener: l, SourceCheck: checkFunc}
   398  
   399  	errors := make(chan error, 4)
   400  	go func() {
   401  		conn, err := net.Dial("tcp", pl.Addr().String())
   402  		if err != nil {
   403  			errors <- err
   404  			return
   405  		}
   406  		defer conn.Close()
   407  
   408  		// Write out the header!
   409  		header := "PROXY TCP4 10.1.1.1 20.2.2.2 1000 2000\r\n"
   410  		_, err = conn.Write([]byte(header))
   411  		if err != nil {
   412  			errors <- err
   413  			return
   414  		}
   415  
   416  		_, err = conn.Write([]byte("ping"))
   417  		if err != nil {
   418  			errors <- err
   419  			return
   420  		}
   421  		recv := make([]byte, 4)
   422  		_, err = conn.Read(recv)
   423  		if err != nil {
   424  			errors <- err
   425  			return
   426  		}
   427  		if !bytes.Equal(recv, []byte("pong")) {
   428  			errors <- fmt.Errorf("bad: %v", recv)
   429  			return
   430  		}
   431  	}()
   432  
   433  	conn, err := pl.Accept()
   434  	if err != nil {
   435  		if checkAddr == badAddr {
   436  			return
   437  		}
   438  		t.Fatalf("err: %v", err)
   439  	}
   440  	defer conn.Close()
   441  
   442  	recv := make([]byte, 4)
   443  	_, err = conn.Read(recv)
   444  	if err != nil {
   445  		t.Fatalf("err: %v", err)
   446  	}
   447  	if !bytes.Equal(recv, []byte("ping")) {
   448  		t.Fatalf("bad: %v", recv)
   449  	}
   450  
   451  	if _, err := conn.Write([]byte("pong")); err != nil {
   452  		t.Fatalf("err: %v", err)
   453  	}
   454  
   455  	// Check the remote addr
   456  	addr := conn.RemoteAddr().(*net.TCPAddr)
   457  	switch checkAddr {
   458  	case goodAddr:
   459  		if addr.IP.String() != "10.1.1.1" {
   460  			t.Fatalf("bad: %v", addr)
   461  		}
   462  		if addr.Port != 1000 {
   463  			t.Fatalf("bad: %v", addr)
   464  		}
   465  	case badAddr:
   466  		if addr.IP.String() != "127.0.0.1" {
   467  			t.Fatalf("bad: %v", addr)
   468  		}
   469  		if addr.Port == 1000 {
   470  			t.Fatalf("bad: %v", addr)
   471  		}
   472  	}
   473  	if len(errors) > 0 {
   474  		t.Fatal(<-errors)
   475  	}
   476  }