go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/proxyproto/header_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package proxyproto
     9  
    10  import (
    11  	"bytes"
    12  	"errors"
    13  	"net"
    14  	"reflect"
    15  	"testing"
    16  )
    17  
    18  // Stuff to be used in both versions tests.
    19  
    20  const (
    21  	noProtocol  = "There is no spoon"
    22  	ip4Addr     = "127.0.0.1"
    23  	ip6Addr     = "::1"
    24  	ip6LongAddr = "1234:5678:9abc:def0:cafe:babe:dead:2bad"
    25  	port        = 65533
    26  	invalidPort = 99999
    27  )
    28  
    29  var (
    30  	v4ip = net.ParseIP(ip4Addr).To4()
    31  	v6ip = net.ParseIP(ip6Addr).To16()
    32  
    33  	v4addr net.Addr = &net.TCPAddr{IP: v4ip, Port: port}
    34  	v6addr net.Addr = &net.TCPAddr{IP: v6ip, Port: port}
    35  
    36  	v4UDPAddr net.Addr = &net.UDPAddr{IP: v4ip, Port: port}
    37  	v6UDPAddr net.Addr = &net.UDPAddr{IP: v6ip, Port: port}
    38  
    39  	unixStreamAddr   net.Addr = &net.UnixAddr{Net: "unix", Name: "socket"}
    40  	unixDatagramAddr net.Addr = &net.UnixAddr{Net: "unixgram", Name: "socket"}
    41  
    42  	errReadIntentionallyBroken = errors.New("read is intentionally broken")
    43  )
    44  
    45  func Test_Header_EqualsTo(t *testing.T) {
    46  	var headersEqual = []struct {
    47  		this, that *Header
    48  		expected   bool
    49  	}{
    50  		{
    51  			&Header{
    52  				Version:           1,
    53  				Command:           ProtocolVersionAndCommandProxy,
    54  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
    55  				SourceAddr: &net.TCPAddr{
    56  					IP:   net.ParseIP("10.1.1.1"),
    57  					Port: 1000,
    58  				},
    59  				DestinationAddr: &net.TCPAddr{
    60  					IP:   net.ParseIP("20.2.2.2"),
    61  					Port: 2000,
    62  				},
    63  			},
    64  			nil,
    65  			false,
    66  		},
    67  		{
    68  			&Header{
    69  				Version:           1,
    70  				Command:           ProtocolVersionAndCommandProxy,
    71  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
    72  				SourceAddr: &net.TCPAddr{
    73  					IP:   net.ParseIP("10.1.1.1"),
    74  					Port: 1000,
    75  				},
    76  				DestinationAddr: &net.TCPAddr{
    77  					IP:   net.ParseIP("20.2.2.2"),
    78  					Port: 2000,
    79  				},
    80  			},
    81  			&Header{
    82  				Version:           2,
    83  				Command:           ProtocolVersionAndCommandProxy,
    84  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
    85  				SourceAddr: &net.TCPAddr{
    86  					IP:   net.ParseIP("10.1.1.1"),
    87  					Port: 1000,
    88  				},
    89  				DestinationAddr: &net.TCPAddr{
    90  					IP:   net.ParseIP("20.2.2.2"),
    91  					Port: 2000,
    92  				},
    93  			},
    94  			false,
    95  		},
    96  		{
    97  			&Header{
    98  				Version:           1,
    99  				Command:           ProtocolVersionAndCommandProxy,
   100  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   101  				SourceAddr: &net.TCPAddr{
   102  					IP:   net.ParseIP("10.1.1.1"),
   103  					Port: 1000,
   104  				},
   105  				DestinationAddr: &net.TCPAddr{
   106  					IP:   net.ParseIP("20.2.2.2"),
   107  					Port: 2000,
   108  				},
   109  			},
   110  			&Header{
   111  				Version:           1,
   112  				Command:           ProtocolVersionAndCommandProxy,
   113  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   114  				SourceAddr: &net.TCPAddr{
   115  					IP:   net.ParseIP("10.1.1.1"),
   116  					Port: 1000,
   117  				},
   118  				DestinationAddr: &net.TCPAddr{
   119  					IP:   net.ParseIP("20.2.2.2"),
   120  					Port: 2000,
   121  				},
   122  			},
   123  			true,
   124  		},
   125  	}
   126  
   127  	for _, tt := range headersEqual {
   128  		if actual := tt.this.EqualsTo(tt.that); actual != tt.expected {
   129  			t.Fatalf("expected %t, actual %t", tt.expected, actual)
   130  		}
   131  	}
   132  }
   133  
   134  func Test_Header_getters(t *testing.T) {
   135  	var tests = []struct {
   136  		name                         string
   137  		header                       *Header
   138  		tcpSourceAddr, tcpDestAddr   *net.TCPAddr
   139  		udpSourceAddr, udpDestAddr   *net.UDPAddr
   140  		unixSourceAddr, unixDestAddr *net.UnixAddr
   141  		ipSource, ipDest             net.IP
   142  		portSource, portDest         int
   143  	}{
   144  		{
   145  			name: "AddressFamilyAndProtocolTCPv4",
   146  			header: &Header{
   147  				Version:           1,
   148  				Command:           ProtocolVersionAndCommandProxy,
   149  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   150  				SourceAddr: &net.TCPAddr{
   151  					IP:   net.ParseIP("10.1.1.1"),
   152  					Port: 1000,
   153  				},
   154  				DestinationAddr: &net.TCPAddr{
   155  					IP:   net.ParseIP("20.2.2.2"),
   156  					Port: 2000,
   157  				},
   158  			},
   159  			tcpSourceAddr: &net.TCPAddr{
   160  				IP:   net.ParseIP("10.1.1.1"),
   161  				Port: 1000,
   162  			},
   163  			tcpDestAddr: &net.TCPAddr{
   164  				IP:   net.ParseIP("20.2.2.2"),
   165  				Port: 2000,
   166  			},
   167  			ipSource:   net.ParseIP("10.1.1.1"),
   168  			ipDest:     net.ParseIP("20.2.2.2"),
   169  			portSource: 1000,
   170  			portDest:   2000,
   171  		},
   172  		{
   173  			name: "UDPv4",
   174  			header: &Header{
   175  				Version:           2,
   176  				Command:           ProtocolVersionAndCommandProxy,
   177  				TransportProtocol: AddressFamilyAndProtocolUDPv6,
   178  				SourceAddr: &net.UDPAddr{
   179  					IP:   net.ParseIP("10.1.1.1"),
   180  					Port: 1000,
   181  				},
   182  				DestinationAddr: &net.UDPAddr{
   183  					IP:   net.ParseIP("20.2.2.2"),
   184  					Port: 2000,
   185  				},
   186  			},
   187  			udpSourceAddr: &net.UDPAddr{
   188  				IP:   net.ParseIP("10.1.1.1"),
   189  				Port: 1000,
   190  			},
   191  			udpDestAddr: &net.UDPAddr{
   192  				IP:   net.ParseIP("20.2.2.2"),
   193  				Port: 2000,
   194  			},
   195  			ipSource:   net.ParseIP("10.1.1.1"),
   196  			ipDest:     net.ParseIP("20.2.2.2"),
   197  			portSource: 1000,
   198  			portDest:   2000,
   199  		},
   200  		{
   201  			name: "UnixStream",
   202  			header: &Header{
   203  				Version:           2,
   204  				Command:           ProtocolVersionAndCommandProxy,
   205  				TransportProtocol: AddressFamilyAndProtocolUnixStream,
   206  				SourceAddr: &net.UnixAddr{
   207  					Net:  "unix",
   208  					Name: "src",
   209  				},
   210  				DestinationAddr: &net.UnixAddr{
   211  					Net:  "unix",
   212  					Name: "dst",
   213  				},
   214  			},
   215  			unixSourceAddr: &net.UnixAddr{
   216  				Net:  "unix",
   217  				Name: "src",
   218  			},
   219  			unixDestAddr: &net.UnixAddr{
   220  				Net:  "unix",
   221  				Name: "dst",
   222  			},
   223  		},
   224  		{
   225  			name: "UnixDatagram",
   226  			header: &Header{
   227  				Version:           2,
   228  				Command:           ProtocolVersionAndCommandProxy,
   229  				TransportProtocol: AddressFamilyAndProtocolUnixDatagram,
   230  				SourceAddr: &net.UnixAddr{
   231  					Net:  "unix",
   232  					Name: "src",
   233  				},
   234  				DestinationAddr: &net.UnixAddr{
   235  					Net:  "unix",
   236  					Name: "dst",
   237  				},
   238  			},
   239  			unixSourceAddr: &net.UnixAddr{
   240  				Net:  "unix",
   241  				Name: "src",
   242  			},
   243  			unixDestAddr: &net.UnixAddr{
   244  				Net:  "unix",
   245  				Name: "dst",
   246  			},
   247  		},
   248  		{
   249  			name: "Unspec",
   250  			header: &Header{
   251  				Version:           1,
   252  				Command:           ProtocolVersionAndCommandProxy,
   253  				TransportProtocol: AddressFamilyAndProtocolUnknown,
   254  			},
   255  		},
   256  	}
   257  
   258  	for _, test := range tests {
   259  		t.Run(test.name, func(t *testing.T) {
   260  			tcpSourceAddr, tcpDestAddr, _ := test.header.TCPAddrs()
   261  			if test.tcpSourceAddr != nil && !reflect.DeepEqual(tcpSourceAddr, test.tcpSourceAddr) {
   262  				t.Errorf("TCPAddrs() source = %v, want %v", tcpSourceAddr, test.tcpSourceAddr)
   263  			}
   264  			if test.tcpDestAddr != nil && !reflect.DeepEqual(tcpDestAddr, test.tcpDestAddr) {
   265  				t.Errorf("TCPAddrs() dest = %v, want %v", tcpDestAddr, test.tcpDestAddr)
   266  			}
   267  
   268  			udpSourceAddr, udpDestAddr, _ := test.header.UDPAddrs()
   269  			if test.udpSourceAddr != nil && !reflect.DeepEqual(udpSourceAddr, test.udpSourceAddr) {
   270  				t.Errorf("TCPAddrs() source = %v, want %v", udpSourceAddr, test.udpSourceAddr)
   271  			}
   272  			if test.udpDestAddr != nil && !reflect.DeepEqual(udpDestAddr, test.udpDestAddr) {
   273  				t.Errorf("TCPAddrs() dest = %v, want %v", udpDestAddr, test.udpDestAddr)
   274  			}
   275  
   276  			unixSourceAddr, unixDestAddr, _ := test.header.UnixAddrs()
   277  			if test.unixSourceAddr != nil && !reflect.DeepEqual(unixSourceAddr, test.unixSourceAddr) {
   278  				t.Errorf("UnixAddrs() source = %v, want %v", unixSourceAddr, test.unixSourceAddr)
   279  			}
   280  			if test.unixDestAddr != nil && !reflect.DeepEqual(unixDestAddr, test.unixDestAddr) {
   281  				t.Errorf("UnixAddrs() dest = %v, want %v", unixDestAddr, test.unixDestAddr)
   282  			}
   283  
   284  			ipSource, ipDest, _ := test.header.IPs()
   285  			if test.ipSource != nil && !ipSource.Equal(test.ipSource) {
   286  				t.Errorf("IPs() source = %v, want %v", ipSource, test.ipSource)
   287  			}
   288  			if test.ipDest != nil && !ipDest.Equal(test.ipDest) {
   289  				t.Errorf("IPs() dest = %v, want %v", ipDest, test.ipDest)
   290  			}
   291  
   292  			portSource, portDest, _ := test.header.Ports()
   293  			if test.portSource != 0 && portSource != test.portSource {
   294  				t.Errorf("Ports() source = %v, want %v", portSource, test.portSource)
   295  			}
   296  			if test.portDest != 0 && portDest != test.portDest {
   297  				t.Errorf("Ports() dest = %v, want %v", portDest, test.portDest)
   298  			}
   299  		})
   300  	}
   301  }
   302  
   303  func Test_Header_SetTLVs(t *testing.T) {
   304  	tests := []struct {
   305  		header    *Header
   306  		name      string
   307  		tlvs      []TLV
   308  		expectErr bool
   309  	}{
   310  		{
   311  			name: "add authority TLV",
   312  			header: &Header{
   313  				Version:           1,
   314  				Command:           ProtocolVersionAndCommandProxy,
   315  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   316  				SourceAddr: &net.TCPAddr{
   317  					IP:   net.ParseIP("10.1.1.1"),
   318  					Port: 1000,
   319  				},
   320  				DestinationAddr: &net.TCPAddr{
   321  					IP:   net.ParseIP("20.2.2.2"),
   322  					Port: 2000,
   323  				},
   324  			},
   325  			tlvs: []TLV{{
   326  				Type:  PP2TypeAuthority,
   327  				Value: []byte("example.org"),
   328  			}},
   329  		},
   330  		{
   331  			name: "add too long TLV",
   332  			header: &Header{
   333  				Version:           1,
   334  				Command:           ProtocolVersionAndCommandProxy,
   335  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   336  				SourceAddr: &net.TCPAddr{
   337  					IP:   net.ParseIP("10.1.1.1"),
   338  					Port: 1000,
   339  				},
   340  				DestinationAddr: &net.TCPAddr{
   341  					IP:   net.ParseIP("20.2.2.2"),
   342  					Port: 2000,
   343  				},
   344  			},
   345  			tlvs: []TLV{{
   346  				Type:  PP2TypeAuthority,
   347  				Value: append(bytes.Repeat([]byte("a"), 0xFFFF), []byte(".example.org")...),
   348  			}},
   349  			expectErr: true,
   350  		},
   351  	}
   352  	for _, tt := range tests {
   353  		err := tt.header.SetTLVs(tt.tlvs)
   354  		if err != nil && !tt.expectErr {
   355  			t.Fatalf("shouldn't have thrown error %q", err.Error())
   356  		}
   357  	}
   358  }
   359  
   360  func Test_Header_WriteTo(t *testing.T) {
   361  	var buf bytes.Buffer
   362  
   363  	validHeader := &Header{
   364  		Version:           1,
   365  		Command:           ProtocolVersionAndCommandProxy,
   366  		TransportProtocol: AddressFamilyAndProtocolTCPv4,
   367  		SourceAddr: &net.TCPAddr{
   368  			IP:   net.ParseIP("10.1.1.1"),
   369  			Port: 1000,
   370  		},
   371  		DestinationAddr: &net.TCPAddr{
   372  			IP:   net.ParseIP("20.2.2.2"),
   373  			Port: 2000,
   374  		},
   375  	}
   376  
   377  	if _, err := validHeader.WriteTo(&buf); err != nil {
   378  		t.Fatalf("shouldn't have thrown error %q", err.Error())
   379  	}
   380  
   381  	invalidHeader := &Header{
   382  		SourceAddr: &net.TCPAddr{
   383  			IP:   net.ParseIP("10.1.1.1"),
   384  			Port: 1000,
   385  		},
   386  		DestinationAddr: &net.TCPAddr{
   387  			IP:   net.ParseIP("20.2.2.2"),
   388  			Port: 2000,
   389  		},
   390  	}
   391  
   392  	if _, err := invalidHeader.WriteTo(&buf); err == nil {
   393  		t.Fatalf("should have thrown error %q", err.Error())
   394  	}
   395  }
   396  
   397  func Test_Header_Format(t *testing.T) {
   398  	validHeader := &Header{
   399  		Version:           1,
   400  		Command:           ProtocolVersionAndCommandProxy,
   401  		TransportProtocol: AddressFamilyAndProtocolTCPv4,
   402  		SourceAddr: &net.TCPAddr{
   403  			IP:   net.ParseIP("10.1.1.1"),
   404  			Port: 1000,
   405  		},
   406  		DestinationAddr: &net.TCPAddr{
   407  			IP:   net.ParseIP("20.2.2.2"),
   408  			Port: 2000,
   409  		},
   410  	}
   411  
   412  	if _, err := validHeader.Format(); err != nil {
   413  		t.Fatalf("shouldn't have thrown error %q", err.Error())
   414  	}
   415  }
   416  
   417  func Test_Header_Format_invalid(t *testing.T) {
   418  	tests := []struct {
   419  		name   string
   420  		header *Header
   421  		err    error
   422  	}{
   423  		{
   424  			name: "invalidVersion",
   425  			header: &Header{
   426  				Version:           3,
   427  				Command:           ProtocolVersionAndCommandProxy,
   428  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   429  				SourceAddr:        v4addr,
   430  				DestinationAddr:   v4addr,
   431  			},
   432  			err: ErrUnknownProxyProtocolVersion,
   433  		},
   434  		{
   435  			name: "v2MismatchTCPv4_UDPv4",
   436  			header: &Header{
   437  				Version:           2,
   438  				Command:           ProtocolVersionAndCommandProxy,
   439  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   440  				SourceAddr:        v4UDPAddr,
   441  				DestinationAddr:   v4addr,
   442  			},
   443  			err: ErrInvalidAddress,
   444  		},
   445  		{
   446  			name: "v2MismatchTCPv4_TCPv6",
   447  			header: &Header{
   448  				Version:           2,
   449  				Command:           ProtocolVersionAndCommandProxy,
   450  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   451  				SourceAddr:        v4addr,
   452  				DestinationAddr:   v6addr,
   453  			},
   454  			err: ErrInvalidAddress,
   455  		},
   456  		{
   457  			name: "v2MismatchUnixStream_TCPv4",
   458  			header: &Header{
   459  				Version:           2,
   460  				Command:           ProtocolVersionAndCommandProxy,
   461  				TransportProtocol: AddressFamilyAndProtocolUnixStream,
   462  				SourceAddr:        v4addr,
   463  				DestinationAddr:   unixStreamAddr,
   464  			},
   465  			err: ErrInvalidAddress,
   466  		},
   467  		{
   468  			name: "v1MismatchTCPv4_TCPv6",
   469  			header: &Header{
   470  				Version:           1,
   471  				Command:           ProtocolVersionAndCommandProxy,
   472  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   473  				SourceAddr:        v6addr,
   474  				DestinationAddr:   v4addr,
   475  			},
   476  			err: ErrInvalidAddress,
   477  		},
   478  		{
   479  			name: "v1MismatchTCPv4_UDPv4",
   480  			header: &Header{
   481  				Version:           1,
   482  				Command:           ProtocolVersionAndCommandProxy,
   483  				TransportProtocol: AddressFamilyAndProtocolTCPv4,
   484  				SourceAddr:        v4UDPAddr,
   485  				DestinationAddr:   v4addr,
   486  			},
   487  			err: ErrInvalidAddress,
   488  		},
   489  	}
   490  
   491  	for _, test := range tests {
   492  		t.Run(test.name, func(t *testing.T) {
   493  			if _, err := test.header.Format(); err == nil {
   494  				t.Errorf("Header.Format() succeeded, want an error")
   495  			} else if err != test.err {
   496  				t.Errorf("Header.Format() = %q, want %q", err, test.err)
   497  			}
   498  		})
   499  	}
   500  }