github.com/blend/go-sdk@v1.20220411.3/proxyprotocol/proxy_protocol_header_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package proxyprotocol
     9  
    10  import (
    11  	"bytes"
    12  	"errors"
    13  	"net"
    14  	"reflect"
    15  	"testing"
    16  )
    17  
    18  // Stuff to be used in both versions tests.
    19  
    20  const (
    21  	noProtocol  = "There is no spoon"
    22  	ip4Addr     = "127.0.0.1"
    23  	ip6Addr     = "::1"
    24  	ip6LongAddr = "1234:5678:9abc:def0:cafe:babe:dead:2bad"
    25  	port        = 65533
    26  	invalidPort = 99999
    27  )
    28  
    29  var (
    30  	v4ip = net.ParseIP(ip4Addr).To4()
    31  	v6ip = net.ParseIP(ip6Addr).To16()
    32  
    33  	v4addr net.Addr = &net.TCPAddr{IP: v4ip, Port: port}
    34  	v6addr net.Addr = &net.TCPAddr{IP: v6ip, Port: port}
    35  
    36  	v4UDPAddr net.Addr = &net.UDPAddr{IP: v4ip, Port: port}
    37  	v6UDPAddr net.Addr = &net.UDPAddr{IP: v6ip, Port: port}
    38  
    39  	unixStreamAddr   net.Addr = &net.UnixAddr{Net: "unix", Name: "socket"}
    40  	unixDatagramAddr net.Addr = &net.UnixAddr{Net: "unixgram", Name: "socket"}
    41  
    42  	errReadIntentionallyBroken = errors.New("read is intentionally broken")
    43  )
    44  
    45  func TestEqualsTo(t *testing.T) {
    46  	var headersEqual = []struct {
    47  		this, that *Header
    48  		expected   bool
    49  	}{
    50  		{
    51  			&Header{
    52  				Version:           1,
    53  				Command:           ProtocolVersionAndCommandProxy,
    54  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
    55  				SourceAddr: &net.TCPAddr{
    56  					IP:   net.ParseIP("10.1.1.1"),
    57  					Port: 1000,
    58  				},
    59  				DestinationAddr: &net.TCPAddr{
    60  					IP:   net.ParseIP("20.2.2.2"),
    61  					Port: 2000,
    62  				},
    63  			},
    64  			nil,
    65  			false,
    66  		},
    67  		{
    68  			&Header{
    69  				Version:           1,
    70  				Command:           ProtocolVersionAndCommandProxy,
    71  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
    72  				SourceAddr: &net.TCPAddr{
    73  					IP:   net.ParseIP("10.1.1.1"),
    74  					Port: 1000,
    75  				},
    76  				DestinationAddr: &net.TCPAddr{
    77  					IP:   net.ParseIP("20.2.2.2"),
    78  					Port: 2000,
    79  				},
    80  			},
    81  			&Header{
    82  				Version:           2,
    83  				Command:           ProtocolVersionAndCommandProxy,
    84  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
    85  				SourceAddr: &net.TCPAddr{
    86  					IP:   net.ParseIP("10.1.1.1"),
    87  					Port: 1000,
    88  				},
    89  				DestinationAddr: &net.TCPAddr{
    90  					IP:   net.ParseIP("20.2.2.2"),
    91  					Port: 2000,
    92  				},
    93  			},
    94  			false,
    95  		},
    96  		{
    97  			&Header{
    98  				Version:           1,
    99  				Command:           ProtocolVersionAndCommandProxy,
   100  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   101  				SourceAddr: &net.TCPAddr{
   102  					IP:   net.ParseIP("10.1.1.1"),
   103  					Port: 1000,
   104  				},
   105  				DestinationAddr: &net.TCPAddr{
   106  					IP:   net.ParseIP("20.2.2.2"),
   107  					Port: 2000,
   108  				},
   109  			},
   110  			&Header{
   111  				Version:           1,
   112  				Command:           ProtocolVersionAndCommandProxy,
   113  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   114  				SourceAddr: &net.TCPAddr{
   115  					IP:   net.ParseIP("10.1.1.1"),
   116  					Port: 1000,
   117  				},
   118  				DestinationAddr: &net.TCPAddr{
   119  					IP:   net.ParseIP("20.2.2.2"),
   120  					Port: 2000,
   121  				},
   122  			},
   123  			true,
   124  		},
   125  	}
   126  
   127  	for _, tt := range headersEqual {
   128  		if actual := tt.this.EqualsTo(tt.that); actual != tt.expected {
   129  			t.Fatalf("expected %t, actual %t", tt.expected, actual)
   130  		}
   131  	}
   132  }
   133  
   134  // This is here just because of coveralls
   135  func TestEqualTo(t *testing.T) {
   136  	TestEqualsTo(t)
   137  }
   138  
   139  func TestGetters(t *testing.T) {
   140  	var tests = []struct {
   141  		name                         string
   142  		header                       *Header
   143  		tcpSourceAddr, tcpDestAddr   *net.TCPAddr
   144  		udpSourceAddr, udpDestAddr   *net.UDPAddr
   145  		unixSourceAddr, unixDestAddr *net.UnixAddr
   146  		ipSource, ipDest             net.IP
   147  		portSource, portDest         int
   148  	}{
   149  		{
   150  			name: "AddressFamilyAndProtocolTCPv4",
   151  			header: &Header{
   152  				Version:           1,
   153  				Command:           ProtocolVersionAndCommandProxy,
   154  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   155  				SourceAddr: &net.TCPAddr{
   156  					IP:   net.ParseIP("10.1.1.1"),
   157  					Port: 1000,
   158  				},
   159  				DestinationAddr: &net.TCPAddr{
   160  					IP:   net.ParseIP("20.2.2.2"),
   161  					Port: 2000,
   162  				},
   163  			},
   164  			tcpSourceAddr: &net.TCPAddr{
   165  				IP:   net.ParseIP("10.1.1.1"),
   166  				Port: 1000,
   167  			},
   168  			tcpDestAddr: &net.TCPAddr{
   169  				IP:   net.ParseIP("20.2.2.2"),
   170  				Port: 2000,
   171  			},
   172  			ipSource:   net.ParseIP("10.1.1.1"),
   173  			ipDest:     net.ParseIP("20.2.2.2"),
   174  			portSource: 1000,
   175  			portDest:   2000,
   176  		},
   177  		{
   178  			name: "UDPv4",
   179  			header: &Header{
   180  				Version:           2,
   181  				Command:           ProtocolVersionAndCommandProxy,
   182  				TransportProtocol: AddressFamilyAndProtocolUDPv6,
   183  				SourceAddr: &net.UDPAddr{
   184  					IP:   net.ParseIP("10.1.1.1"),
   185  					Port: 1000,
   186  				},
   187  				DestinationAddr: &net.UDPAddr{
   188  					IP:   net.ParseIP("20.2.2.2"),
   189  					Port: 2000,
   190  				},
   191  			},
   192  			udpSourceAddr: &net.UDPAddr{
   193  				IP:   net.ParseIP("10.1.1.1"),
   194  				Port: 1000,
   195  			},
   196  			udpDestAddr: &net.UDPAddr{
   197  				IP:   net.ParseIP("20.2.2.2"),
   198  				Port: 2000,
   199  			},
   200  			ipSource:   net.ParseIP("10.1.1.1"),
   201  			ipDest:     net.ParseIP("20.2.2.2"),
   202  			portSource: 1000,
   203  			portDest:   2000,
   204  		},
   205  		{
   206  			name: "UnixStream",
   207  			header: &Header{
   208  				Version:           2,
   209  				Command:           ProtocolVersionAndCommandProxy,
   210  				TransportProtocol: AddressFamilyAndProtocolUnixStream,
   211  				SourceAddr: &net.UnixAddr{
   212  					Net:  "unix",
   213  					Name: "src",
   214  				},
   215  				DestinationAddr: &net.UnixAddr{
   216  					Net:  "unix",
   217  					Name: "dst",
   218  				},
   219  			},
   220  			unixSourceAddr: &net.UnixAddr{
   221  				Net:  "unix",
   222  				Name: "src",
   223  			},
   224  			unixDestAddr: &net.UnixAddr{
   225  				Net:  "unix",
   226  				Name: "dst",
   227  			},
   228  		},
   229  		{
   230  			name: "UnixDatagram",
   231  			header: &Header{
   232  				Version:           2,
   233  				Command:           ProtocolVersionAndCommandProxy,
   234  				TransportProtocol: AddressFamilyAndProtocolUnixDatagram,
   235  				SourceAddr: &net.UnixAddr{
   236  					Net:  "unix",
   237  					Name: "src",
   238  				},
   239  				DestinationAddr: &net.UnixAddr{
   240  					Net:  "unix",
   241  					Name: "dst",
   242  				},
   243  			},
   244  			unixSourceAddr: &net.UnixAddr{
   245  				Net:  "unix",
   246  				Name: "src",
   247  			},
   248  			unixDestAddr: &net.UnixAddr{
   249  				Net:  "unix",
   250  				Name: "dst",
   251  			},
   252  		},
   253  		{
   254  			name: "Unspec",
   255  			header: &Header{
   256  				Version:           1,
   257  				Command:           ProtocolVersionAndCommandProxy,
   258  				TransportProtocol: AddressFamilyAndProtocolUnknown,
   259  			},
   260  		},
   261  	}
   262  
   263  	for _, test := range tests {
   264  		t.Run(test.name, func(t *testing.T) {
   265  			tcpSourceAddr, tcpDestAddr, _ := test.header.TCPAddrs()
   266  			if test.tcpSourceAddr != nil && !reflect.DeepEqual(tcpSourceAddr, test.tcpSourceAddr) {
   267  				t.Errorf("TCPAddrs() source = %v, want %v", tcpSourceAddr, test.tcpSourceAddr)
   268  			}
   269  			if test.tcpDestAddr != nil && !reflect.DeepEqual(tcpDestAddr, test.tcpDestAddr) {
   270  				t.Errorf("TCPAddrs() dest = %v, want %v", tcpDestAddr, test.tcpDestAddr)
   271  			}
   272  
   273  			udpSourceAddr, udpDestAddr, _ := test.header.UDPAddrs()
   274  			if test.udpSourceAddr != nil && !reflect.DeepEqual(udpSourceAddr, test.udpSourceAddr) {
   275  				t.Errorf("TCPAddrs() source = %v, want %v", udpSourceAddr, test.udpSourceAddr)
   276  			}
   277  			if test.udpDestAddr != nil && !reflect.DeepEqual(udpDestAddr, test.udpDestAddr) {
   278  				t.Errorf("TCPAddrs() dest = %v, want %v", udpDestAddr, test.udpDestAddr)
   279  			}
   280  
   281  			unixSourceAddr, unixDestAddr, _ := test.header.UnixAddrs()
   282  			if test.unixSourceAddr != nil && !reflect.DeepEqual(unixSourceAddr, test.unixSourceAddr) {
   283  				t.Errorf("UnixAddrs() source = %v, want %v", unixSourceAddr, test.unixSourceAddr)
   284  			}
   285  			if test.unixDestAddr != nil && !reflect.DeepEqual(unixDestAddr, test.unixDestAddr) {
   286  				t.Errorf("UnixAddrs() dest = %v, want %v", unixDestAddr, test.unixDestAddr)
   287  			}
   288  
   289  			ipSource, ipDest, _ := test.header.IPs()
   290  			if test.ipSource != nil && !ipSource.Equal(test.ipSource) {
   291  				t.Errorf("IPs() source = %v, want %v", ipSource, test.ipSource)
   292  			}
   293  			if test.ipDest != nil && !ipDest.Equal(test.ipDest) {
   294  				t.Errorf("IPs() dest = %v, want %v", ipDest, test.ipDest)
   295  			}
   296  
   297  			portSource, portDest, _ := test.header.Ports()
   298  			if test.portSource != 0 && portSource != test.portSource {
   299  				t.Errorf("Ports() source = %v, want %v", portSource, test.portSource)
   300  			}
   301  			if test.portDest != 0 && portDest != test.portDest {
   302  				t.Errorf("Ports() dest = %v, want %v", portDest, test.portDest)
   303  			}
   304  		})
   305  	}
   306  }
   307  
   308  func TestSetTLVs(t *testing.T) {
   309  	tests := []struct {
   310  		header    *Header
   311  		name      string
   312  		tlvs      []TLV
   313  		expectErr bool
   314  	}{
   315  		{
   316  			name: "add authority TLV",
   317  			header: &Header{
   318  				Version:           1,
   319  				Command:           ProtocolVersionAndCommandProxy,
   320  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   321  				SourceAddr: &net.TCPAddr{
   322  					IP:   net.ParseIP("10.1.1.1"),
   323  					Port: 1000,
   324  				},
   325  				DestinationAddr: &net.TCPAddr{
   326  					IP:   net.ParseIP("20.2.2.2"),
   327  					Port: 2000,
   328  				},
   329  			},
   330  			tlvs: []TLV{{
   331  				Type:  PP2TypeAuthority,
   332  				Value: []byte("example.org"),
   333  			}},
   334  		},
   335  		{
   336  			name: "add too long TLV",
   337  			header: &Header{
   338  				Version:           1,
   339  				Command:           ProtocolVersionAndCommandProxy,
   340  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   341  				SourceAddr: &net.TCPAddr{
   342  					IP:   net.ParseIP("10.1.1.1"),
   343  					Port: 1000,
   344  				},
   345  				DestinationAddr: &net.TCPAddr{
   346  					IP:   net.ParseIP("20.2.2.2"),
   347  					Port: 2000,
   348  				},
   349  			},
   350  			tlvs: []TLV{{
   351  				Type:  PP2TypeAuthority,
   352  				Value: append(bytes.Repeat([]byte("a"), 0xFFFF), []byte(".example.org")...),
   353  			}},
   354  			expectErr: true,
   355  		},
   356  	}
   357  	for _, tt := range tests {
   358  		err := tt.header.SetTLVs(tt.tlvs)
   359  		if err != nil && !tt.expectErr {
   360  			t.Fatalf("shouldn't have thrown error %q", err.Error())
   361  		}
   362  	}
   363  }
   364  
   365  func TestWriteTo(t *testing.T) {
   366  	var buf bytes.Buffer
   367  
   368  	validHeader := &Header{
   369  		Version:           1,
   370  		Command:           ProtocolVersionAndCommandProxy,
   371  		TransportProtocol: AddressFamilyAndProtocolTCPv4,
   372  		SourceAddr: &net.TCPAddr{
   373  			IP:   net.ParseIP("10.1.1.1"),
   374  			Port: 1000,
   375  		},
   376  		DestinationAddr: &net.TCPAddr{
   377  			IP:   net.ParseIP("20.2.2.2"),
   378  			Port: 2000,
   379  		},
   380  	}
   381  
   382  	if _, err := validHeader.WriteTo(&buf); err != nil {
   383  		t.Fatalf("shouldn't have thrown error %q", err.Error())
   384  	}
   385  
   386  	invalidHeader := &Header{
   387  		SourceAddr: &net.TCPAddr{
   388  			IP:   net.ParseIP("10.1.1.1"),
   389  			Port: 1000,
   390  		},
   391  		DestinationAddr: &net.TCPAddr{
   392  			IP:   net.ParseIP("20.2.2.2"),
   393  			Port: 2000,
   394  		},
   395  	}
   396  
   397  	if _, err := invalidHeader.WriteTo(&buf); err == nil {
   398  		t.Fatalf("should have thrown error %q", err.Error())
   399  	}
   400  }
   401  
   402  func TestFormat(t *testing.T) {
   403  	validHeader := &Header{
   404  		Version:           1,
   405  		Command:           ProtocolVersionAndCommandProxy,
   406  		TransportProtocol: AddressFamilyAndProtocolTCPv4,
   407  		SourceAddr: &net.TCPAddr{
   408  			IP:   net.ParseIP("10.1.1.1"),
   409  			Port: 1000,
   410  		},
   411  		DestinationAddr: &net.TCPAddr{
   412  			IP:   net.ParseIP("20.2.2.2"),
   413  			Port: 2000,
   414  		},
   415  	}
   416  
   417  	if _, err := validHeader.Format(); err != nil {
   418  		t.Fatalf("shouldn't have thrown error %q", err.Error())
   419  	}
   420  }
   421  
   422  func TestFormatInvalid(t *testing.T) {
   423  	tests := []struct {
   424  		name   string
   425  		header *Header
   426  		err    error
   427  	}{
   428  		{
   429  			name: "invalidVersion",
   430  			header: &Header{
   431  				Version:           3,
   432  				Command:           ProtocolVersionAndCommandProxy,
   433  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   434  				SourceAddr:        v4addr,
   435  				DestinationAddr:   v4addr,
   436  			},
   437  			err: ErrUnknownProxyProtocolVersion,
   438  		},
   439  		{
   440  			name: "v2MismatchTCPv4_UDPv4",
   441  			header: &Header{
   442  				Version:           2,
   443  				Command:           ProtocolVersionAndCommandProxy,
   444  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   445  				SourceAddr:        v4UDPAddr,
   446  				DestinationAddr:   v4addr,
   447  			},
   448  			err: ErrInvalidAddress,
   449  		},
   450  		{
   451  			name: "v2MismatchTCPv4_TCPv6",
   452  			header: &Header{
   453  				Version:           2,
   454  				Command:           ProtocolVersionAndCommandProxy,
   455  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   456  				SourceAddr:        v4addr,
   457  				DestinationAddr:   v6addr,
   458  			},
   459  			err: ErrInvalidAddress,
   460  		},
   461  		{
   462  			name: "v2MismatchUnixStream_TCPv4",
   463  			header: &Header{
   464  				Version:           2,
   465  				Command:           ProtocolVersionAndCommandProxy,
   466  				TransportProtocol: AddressFamilyAndProtocolUnixStream,
   467  				SourceAddr:        v4addr,
   468  				DestinationAddr:   unixStreamAddr,
   469  			},
   470  			err: ErrInvalidAddress,
   471  		},
   472  		{
   473  			name: "v1MismatchTCPv4_TCPv6",
   474  			header: &Header{
   475  				Version:           1,
   476  				Command:           ProtocolVersionAndCommandProxy,
   477  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   478  				SourceAddr:        v6addr,
   479  				DestinationAddr:   v4addr,
   480  			},
   481  			err: ErrInvalidAddress,
   482  		},
   483  		{
   484  			name: "v1MismatchTCPv4_UDPv4",
   485  			header: &Header{
   486  				Version:           1,
   487  				Command:           ProtocolVersionAndCommandProxy,
   488  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   489  				SourceAddr:        v4UDPAddr,
   490  				DestinationAddr:   v4addr,
   491  			},
   492  			err: ErrInvalidAddress,
   493  		},
   494  	}
   495  
   496  	for _, test := range tests {
   497  		t.Run(test.name, func(t *testing.T) {
   498  			if _, err := test.header.Format(); err == nil {
   499  				t.Errorf("Header.Format() succeeded, want an error")
   500  			} else if err != test.err {
   501  				t.Errorf("Header.Format() = %q, want %q", err, test.err)
   502  			}
   503  		})
   504  	}
   505  }