github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/socks5/addr_test.go (about)

     1  package socks5
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"io"
     7  	"net/netip"
     8  	"testing"
     9  
    10  	"github.com/database64128/shadowsocks-go/conn"
    11  )
    12  
    13  // Test zero value address.
    14  var (
    15  	addrZero         = IPv4UnspecifiedAddr
    16  	addrZeroConnAddr conn.Addr
    17  )
    18  
    19  // Test IPv4 address.
    20  var (
    21  	addr4                = [IPv4AddrLen]byte{AtypIPv4, 127, 0, 0, 1, 4, 56}
    22  	addr4addr            = netip.AddrFrom4([4]byte{127, 0, 0, 1})
    23  	addr4port     uint16 = 1080
    24  	addr4addrport        = netip.AddrPortFrom(addr4addr, addr4port)
    25  	addr4connaddr        = conn.AddrFromIPPort(addr4addrport)
    26  )
    27  
    28  // Test IPv4-mapped IPv6 address.
    29  var (
    30  	addr4in6                = [IPv4AddrLen]byte{AtypIPv4, 127, 0, 0, 1, 4, 56}
    31  	addr4in6addr            = netip.AddrFrom16([16]byte{10: 0xff, 11: 0xff, 127, 0, 0, 1})
    32  	addr4in6port     uint16 = 1080
    33  	addr4in6addrport        = netip.AddrPortFrom(addr4in6addr, addr4in6port)
    34  	addr4in6connaddr        = conn.AddrFromIPPort(addr4in6addrport)
    35  )
    36  
    37  // Test IPv6 address.
    38  var (
    39  	addr6                = [IPv6AddrLen]byte{AtypIPv6, 0x20, 0x01, 0x0d, 0xb8, 0xfa, 0xd6, 0x05, 0x72, 0xac, 0xbe, 0x71, 0x43, 0x14, 0xe5, 0x7a, 0x6e, 4, 56}
    40  	addr6addr            = netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0xfa, 0xd6, 0x05, 0x72, 0xac, 0xbe, 0x71, 0x43, 0x14, 0xe5, 0x7a, 0x6e})
    41  	addr6port     uint16 = 1080
    42  	addr6addrport        = netip.AddrPortFrom(addr6addr, addr6port)
    43  	addr6connaddr        = conn.AddrFromIPPort(addr6addrport)
    44  )
    45  
    46  // Test domain name.
    47  var (
    48  	addrDomain                = [1 + 1 + 11 + 2]byte{AtypDomainName, 11, 'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'c', 'o', 'm', 1, 187}
    49  	addrDomainHost            = "example.com"
    50  	addrDomainPort     uint16 = 443
    51  	addrDomainConnAddr        = conn.MustAddrFromDomainPort(addrDomainHost, addrDomainPort)
    52  )
    53  
    54  func testAddrFromReader(t *testing.T, addr []byte) {
    55  	b := make([]byte, 512)
    56  	n := copy(b, addr)
    57  	_, err := rand.Read(b[n:])
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  	expectedTail := make([]byte, 512-n)
    62  	copy(expectedTail, b[n:])
    63  
    64  	r := bytes.NewReader(b)
    65  	raddr, err := AddrFromReader(r)
    66  	if err != nil {
    67  		t.Fatal(err)
    68  	}
    69  	if !bytes.Equal(addr, raddr) {
    70  		t.Errorf("Expected: %v\nGot: %v", addr, []byte(raddr))
    71  	}
    72  	tail, err := io.ReadAll(r)
    73  	if err != nil {
    74  		t.Fatal(err)
    75  	}
    76  	if !bytes.Equal(tail, expectedTail) {
    77  		t.Error("AddrFromReader(r) read more bytes than expected.")
    78  	}
    79  }
    80  
    81  func TestAddrFromReader(t *testing.T) {
    82  	testAddrFromReader(t, addr4[:])
    83  	testAddrFromReader(t, addr4in6[:])
    84  	testAddrFromReader(t, addr6[:])
    85  	testAddrFromReader(t, addrDomain[:])
    86  }
    87  
    88  func testAddrPortFromSlice(t *testing.T, sa []byte, expectedAddrPort netip.AddrPort, expectedN int, expectedErr error) {
    89  	b := make([]byte, 512)
    90  	n := copy(b, sa)
    91  	_, err := rand.Read(b[n:])
    92  	if err != nil {
    93  		t.Fatal(err)
    94  	}
    95  	expectedTail := make([]byte, 512-n)
    96  	copy(expectedTail, b[n:])
    97  
    98  	addrPort, n, err := AddrPortFromSlice(b)
    99  	if err != expectedErr {
   100  		t.Errorf("AddrPortFromSlice(b) returned error %s, expected error %s", err, expectedErr)
   101  	}
   102  	if n != expectedN {
   103  		t.Errorf("AddrPortFromSlice(b) returned n=%d, expected n=%d.", n, expectedN)
   104  	}
   105  	if addrPort != expectedAddrPort {
   106  		t.Errorf("AddrPortFromSlice(b) returned %s, expected %s.", addrPort, expectedAddrPort)
   107  	}
   108  	if !bytes.Equal(b[len(sa):], expectedTail) {
   109  		t.Error("AddrPortFromSlice(b) modified non-address bytes.")
   110  	}
   111  }
   112  
   113  func TestAddrPortFromSlice(t *testing.T) {
   114  	testAddrPortFromSlice(t, addr4[:], addr4addrport, len(addr4), nil)
   115  	testAddrPortFromSlice(t, addr4in6[:], addr4addrport, len(addr4in6), nil)
   116  	testAddrPortFromSlice(t, addr6[:], addr6addrport, len(addr6), nil)
   117  	testAddrPortFromSlice(t, addrDomain[:], netip.AddrPort{}, 0, errDomain)
   118  }
   119  
   120  func testConnAddrFromSliceAndReader(t *testing.T, sa []byte, expectedAddr conn.Addr) {
   121  	b := make([]byte, 512)
   122  	n := copy(b, sa)
   123  	_, err := rand.Read(b[n:])
   124  	if err != nil {
   125  		t.Fatal(err)
   126  	}
   127  	expectedTail := make([]byte, 512-n)
   128  	copy(expectedTail, b[n:])
   129  
   130  	addr, n, err := ConnAddrFromSlice(b)
   131  	if err != nil {
   132  		t.Fatal(err)
   133  	}
   134  	if n != len(sa) {
   135  		t.Errorf("ConnAddrFromSlice(b) returned n=%d, expected n=%d.", n, len(sa))
   136  	}
   137  	if !addr.Equals(expectedAddr) {
   138  		t.Errorf("ConnAddrFromSlice(b) returned %s, expected %s.", addr, expectedAddr)
   139  	}
   140  	if !bytes.Equal(b[n:], expectedTail) {
   141  		t.Error("ConnAddrFromSlice(b) modified non-address bytes.")
   142  	}
   143  
   144  	r := bytes.NewReader(b)
   145  	addr, err = ConnAddrFromReader(r)
   146  	if err != nil {
   147  		t.Fatal(err)
   148  	}
   149  	if !addr.Equals(expectedAddr) {
   150  		t.Errorf("ConnAddrFromReader(r) returned %s, expected %s.", addr, expectedAddr)
   151  	}
   152  	tail, err := io.ReadAll(r)
   153  	if err != nil {
   154  		t.Fatal(err)
   155  	}
   156  	if !bytes.Equal(tail, expectedTail) {
   157  		t.Error("ConnAddrFromReader(r) read more bytes than expected.")
   158  	}
   159  }
   160  
   161  func TestConnAddrFromSliceAndReader(t *testing.T) {
   162  	testConnAddrFromSliceAndReader(t, addr4[:], addr4connaddr)
   163  	testConnAddrFromSliceAndReader(t, addr4in6[:], addr4connaddr)
   164  	testConnAddrFromSliceAndReader(t, addr6[:], addr6connaddr)
   165  	testConnAddrFromSliceAndReader(t, addrDomain[:], addrDomainConnAddr)
   166  }
   167  
   168  func testConnAddrFromSliceWithDomainCache(t *testing.T, sa []byte, cachedDomain string, expectedAddr conn.Addr) string {
   169  	b := make([]byte, 512)
   170  	n := copy(b, sa)
   171  	_, err := rand.Read(b[n:])
   172  	if err != nil {
   173  		t.Fatal(err)
   174  	}
   175  	expectedTail := make([]byte, 512-n)
   176  	copy(expectedTail, b[n:])
   177  
   178  	addr, n, cachedDomain, err := ConnAddrFromSliceWithDomainCache(b, cachedDomain)
   179  	if err != nil {
   180  		t.Fatal(err)
   181  	}
   182  	if n != len(sa) {
   183  		t.Errorf("ConnAddrFromSlice(b) returned n=%d, expected n=%d.", n, len(sa))
   184  	}
   185  	if !addr.Equals(expectedAddr) {
   186  		t.Errorf("ConnAddrFromSlice(b) returned %s, expected %s.", addr, expectedAddr)
   187  	}
   188  	if !bytes.Equal(b[n:], expectedTail) {
   189  		t.Error("ConnAddrFromSlice(b) modified non-address bytes.")
   190  	}
   191  	return cachedDomain
   192  }
   193  
   194  func TestConnAddrFromSliceWithDomainCache(t *testing.T) {
   195  	const s = "🌐"
   196  	cachedDomain := s
   197  
   198  	cachedDomain = testConnAddrFromSliceWithDomainCache(t, addr4[:], cachedDomain, addr4connaddr)
   199  	if cachedDomain != s {
   200  		t.Errorf("ConnAddrFromSliceWithDomainCache(addr4) modified cachedDomain to %s.", cachedDomain)
   201  	}
   202  
   203  	cachedDomain = testConnAddrFromSliceWithDomainCache(t, addr4in6[:], cachedDomain, addr4connaddr)
   204  	if cachedDomain != s {
   205  		t.Errorf("ConnAddrFromSliceWithDomainCache(addr4in6) modified cachedDomain to %s.", cachedDomain)
   206  	}
   207  
   208  	cachedDomain = testConnAddrFromSliceWithDomainCache(t, addr6[:], cachedDomain, addr6connaddr)
   209  	if cachedDomain != s {
   210  		t.Errorf("ConnAddrFromSliceWithDomainCache(addr6) modified cachedDomain to %s.", cachedDomain)
   211  	}
   212  
   213  	cachedDomain = testConnAddrFromSliceWithDomainCache(t, addrDomain[:], cachedDomain, addrDomainConnAddr)
   214  	if cachedDomain != addrDomainHost {
   215  		t.Errorf("ConnAddrFromSliceWithDomainCache(addrDomain) modified cachedDomain to %s, expected %s.", cachedDomain, addrDomainHost)
   216  	}
   217  }
   218  
   219  func testAppendAddrFromConnAddr(t *testing.T, addr conn.Addr, expectedSA []byte) {
   220  	head := make([]byte, 64)
   221  	_, err := rand.Read(head)
   222  	if err != nil {
   223  		t.Fatal(err)
   224  	}
   225  
   226  	b := make([]byte, 0, 512)
   227  	b = append(b, head...)
   228  
   229  	b = AppendAddrFromConnAddr(b, addr)
   230  	if !bytes.Equal(b[:64], head) {
   231  		t.Error("Random head mismatch.")
   232  	}
   233  	if !bytes.Equal(b[64:], expectedSA) {
   234  		t.Errorf("Appended SOCKS address is %v, expected %v.", b[64:], expectedSA)
   235  	}
   236  }
   237  
   238  func TestAppendAddrFromConnAddr(t *testing.T) {
   239  	testAppendAddrFromConnAddr(t, addrZeroConnAddr, addrZero[:])
   240  	testAppendAddrFromConnAddr(t, addr4connaddr, addr4[:])
   241  	testAppendAddrFromConnAddr(t, addr4in6connaddr, addr4in6[:])
   242  	testAppendAddrFromConnAddr(t, addr6connaddr, addr6[:])
   243  	testAppendAddrFromConnAddr(t, addrDomainConnAddr, addrDomain[:])
   244  }
   245  
   246  func testLengthOfAndWriteAddrFromConnAddr(t *testing.T, addr conn.Addr, expectedSA []byte) {
   247  	addrLen := LengthOfAddrFromConnAddr(addr)
   248  	if addrLen != len(expectedSA) {
   249  		t.Errorf("LengthOfAddrFromConnAddr(addr) returned %d, expected %d.", addrLen, len(expectedSA))
   250  	}
   251  
   252  	b := make([]byte, 512)
   253  	_, err := rand.Read(b[addrLen:])
   254  	if err != nil {
   255  		t.Fatal(err)
   256  	}
   257  	tail := make([]byte, 512-addrLen)
   258  	copy(tail, b[addrLen:])
   259  
   260  	n := WriteAddrFromConnAddr(b, addr)
   261  	if n != len(expectedSA) {
   262  		t.Errorf("WriteAddrFromConnAddr(b, addr) returned n=%d, expected n=%d.", n, len(expectedSA))
   263  	}
   264  	if !bytes.Equal(b[:len(expectedSA)], expectedSA) {
   265  		t.Errorf("WriteAddrFromConnAddr(b, addr) wrote %v, expected %v.", b[:len(expectedSA)], expectedSA)
   266  	}
   267  	if !bytes.Equal(b[len(expectedSA):], tail) {
   268  		t.Error("WriteAddrFromConnAddr(b, addr) modified non-address bytes.")
   269  	}
   270  }
   271  
   272  func TestLengthOfAndWriteAddrFromConnAddr(t *testing.T) {
   273  	testLengthOfAndWriteAddrFromConnAddr(t, addrZeroConnAddr, addrZero[:])
   274  	testLengthOfAndWriteAddrFromConnAddr(t, addr4connaddr, addr4[:])
   275  	testLengthOfAndWriteAddrFromConnAddr(t, addr4in6connaddr, addr4in6[:])
   276  	testLengthOfAndWriteAddrFromConnAddr(t, addr6connaddr, addr6[:])
   277  	testLengthOfAndWriteAddrFromConnAddr(t, addrDomainConnAddr, addrDomain[:])
   278  }