github.com/database64128/shadowsocks-go@v1.7.0/socks5/addr_test.go (about)

     1  package socks5
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"io"
     7  	"net/netip"
     8  	"testing"
     9  
    10  	"github.com/database64128/shadowsocks-go/conn"
    11  )
    12  
    13  // Test IPv4 address.
    14  var (
    15  	addr4                = []byte{AtypIPv4, 127, 0, 0, 1, 4, 56}
    16  	addr4addr            = netip.AddrFrom4([4]byte{127, 0, 0, 1})
    17  	addr4port     uint16 = 1080
    18  	addr4addrport        = netip.AddrPortFrom(addr4addr, addr4port)
    19  	addr4connaddr        = conn.AddrFromIPPort(addr4addrport)
    20  )
    21  
    22  // Test IPv4-mapped IPv6 address.
    23  var (
    24  	addr4in6                = []byte{AtypIPv4, 127, 0, 0, 1, 4, 56}
    25  	addr4in6addr            = netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 127, 0, 0, 1})
    26  	addr4in6port     uint16 = 1080
    27  	addr4in6addrport        = netip.AddrPortFrom(addr4in6addr, addr4in6port)
    28  	addr4in6connaddr        = conn.AddrFromIPPort(addr4in6addrport)
    29  )
    30  
    31  // Test IPv6 address.
    32  var (
    33  	addr6                = []byte{AtypIPv6, 0x20, 0x01, 0x0d, 0xb8, 0xfa, 0xd6, 0x05, 0x72, 0xac, 0xbe, 0x71, 0x43, 0x14, 0xe5, 0x7a, 0x6e, 4, 56}
    34  	addr6addr            = netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8, 0xfa, 0xd6, 0x05, 0x72, 0xac, 0xbe, 0x71, 0x43, 0x14, 0xe5, 0x7a, 0x6e})
    35  	addr6port     uint16 = 1080
    36  	addr6addrport        = netip.AddrPortFrom(addr6addr, addr6port)
    37  	addr6connaddr        = conn.AddrFromIPPort(addr6addrport)
    38  )
    39  
    40  // Test domain name.
    41  var (
    42  	addrDomain                = []byte{AtypDomainName, 11, 'e', 'x', 'a', 'm', 'p', 'l', 'e', '.', 'c', 'o', 'm', 1, 187}
    43  	addrDomainHost            = "example.com"
    44  	addrDomainPort     uint16 = 443
    45  	addrDomainConnAddr        = conn.MustAddrFromDomainPort(addrDomainHost, addrDomainPort)
    46  )
    47  
    48  func testAddrFromReader(t *testing.T, addr []byte) {
    49  	b := make([]byte, 512)
    50  	n := copy(b, addr)
    51  	_, err := rand.Read(b[n:])
    52  	if err != nil {
    53  		t.Fatal(err)
    54  	}
    55  	expectedTail := make([]byte, 512-n)
    56  	copy(expectedTail, b[n:])
    57  
    58  	r := bytes.NewReader(b)
    59  	raddr, err := AddrFromReader(r)
    60  	if err != nil {
    61  		t.Fatal(err)
    62  	}
    63  	if !bytes.Equal(addr, raddr) {
    64  		t.Errorf("Expected: %v\nGot: %v", addr, []byte(raddr))
    65  	}
    66  	tail, err := io.ReadAll(r)
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	if !bytes.Equal(tail, expectedTail) {
    71  		t.Error("AddrFromReader(r) read more bytes than expected.")
    72  	}
    73  }
    74  
    75  func TestAddrFromReader(t *testing.T) {
    76  	testAddrFromReader(t, addr4)
    77  	testAddrFromReader(t, addr4in6)
    78  	testAddrFromReader(t, addr6)
    79  	testAddrFromReader(t, addrDomain)
    80  }
    81  
    82  func testAddrPortFromSlice(t *testing.T, sa []byte, expectedAddrPort netip.AddrPort, expectedN int, expectedErr error) {
    83  	b := make([]byte, 512)
    84  	n := copy(b, sa)
    85  	_, err := rand.Read(b[n:])
    86  	if err != nil {
    87  		t.Fatal(err)
    88  	}
    89  	expectedTail := make([]byte, 512-n)
    90  	copy(expectedTail, b[n:])
    91  
    92  	addrPort, n, err := AddrPortFromSlice(b)
    93  	if err != expectedErr {
    94  		t.Errorf("AddrPortFromSlice(b) returned error %s, expected error %s", err, expectedErr)
    95  	}
    96  	if n != expectedN {
    97  		t.Errorf("AddrPortFromSlice(b) returned n=%d, expected n=%d.", n, expectedN)
    98  	}
    99  	if addrPort != expectedAddrPort {
   100  		t.Errorf("AddrPortFromSlice(b) returned %s, expected %s.", addrPort, expectedAddrPort)
   101  	}
   102  	if !bytes.Equal(b[len(sa):], expectedTail) {
   103  		t.Error("AddrPortFromSlice(b) modified non-address bytes.")
   104  	}
   105  }
   106  
   107  func TestAddrPortFromSlice(t *testing.T) {
   108  	testAddrPortFromSlice(t, addr4, addr4addrport, len(addr4), nil)
   109  	testAddrPortFromSlice(t, addr4in6, addr4addrport, len(addr4in6), nil)
   110  	testAddrPortFromSlice(t, addr6, addr6addrport, len(addr6), nil)
   111  	testAddrPortFromSlice(t, addrDomain, netip.AddrPort{}, 0, errDomain)
   112  }
   113  
   114  func testConnAddrFromSliceAndReader(t *testing.T, sa []byte, expectedAddr conn.Addr) {
   115  	b := make([]byte, 512)
   116  	n := copy(b, sa)
   117  	_, err := rand.Read(b[n:])
   118  	if err != nil {
   119  		t.Fatal(err)
   120  	}
   121  	expectedTail := make([]byte, 512-n)
   122  	copy(expectedTail, b[n:])
   123  
   124  	addr, n, err := ConnAddrFromSlice(b)
   125  	if err != nil {
   126  		t.Fatal(err)
   127  	}
   128  	if n != len(sa) {
   129  		t.Errorf("ConnAddrFromSlice(b) returned n=%d, expected n=%d.", n, len(sa))
   130  	}
   131  	if !addr.Equals(expectedAddr) {
   132  		t.Errorf("ConnAddrFromSlice(b) returned %s, expected %s.", addr, expectedAddr)
   133  	}
   134  	if !bytes.Equal(b[n:], expectedTail) {
   135  		t.Error("ConnAddrFromSlice(b) modified non-address bytes.")
   136  	}
   137  
   138  	r := bytes.NewReader(b)
   139  	addr, err = ConnAddrFromReader(r)
   140  	if err != nil {
   141  		t.Fatal(err)
   142  	}
   143  	if !addr.Equals(expectedAddr) {
   144  		t.Errorf("ConnAddrFromReader(r) returned %s, expected %s.", addr, expectedAddr)
   145  	}
   146  	tail, err := io.ReadAll(r)
   147  	if err != nil {
   148  		t.Fatal(err)
   149  	}
   150  	if !bytes.Equal(tail, expectedTail) {
   151  		t.Error("ConnAddrFromReader(r) read more bytes than expected.")
   152  	}
   153  }
   154  
   155  func TestConnAddrFromSliceAndReader(t *testing.T) {
   156  	testConnAddrFromSliceAndReader(t, addr4, addr4connaddr)
   157  	testConnAddrFromSliceAndReader(t, addr4in6, addr4connaddr)
   158  	testConnAddrFromSliceAndReader(t, addr6, addr6connaddr)
   159  	testConnAddrFromSliceAndReader(t, addrDomain, addrDomainConnAddr)
   160  }
   161  
   162  func testConnAddrFromSliceWithDomainCache(t *testing.T, sa []byte, cachedDomain string, expectedAddr conn.Addr) string {
   163  	b := make([]byte, 512)
   164  	n := copy(b, sa)
   165  	_, err := rand.Read(b[n:])
   166  	if err != nil {
   167  		t.Fatal(err)
   168  	}
   169  	expectedTail := make([]byte, 512-n)
   170  	copy(expectedTail, b[n:])
   171  
   172  	addr, n, cachedDomain, err := ConnAddrFromSliceWithDomainCache(b, cachedDomain)
   173  	if err != nil {
   174  		t.Fatal(err)
   175  	}
   176  	if n != len(sa) {
   177  		t.Errorf("ConnAddrFromSlice(b) returned n=%d, expected n=%d.", n, len(sa))
   178  	}
   179  	if !addr.Equals(expectedAddr) {
   180  		t.Errorf("ConnAddrFromSlice(b) returned %s, expected %s.", addr, expectedAddr)
   181  	}
   182  	if !bytes.Equal(b[n:], expectedTail) {
   183  		t.Error("ConnAddrFromSlice(b) modified non-address bytes.")
   184  	}
   185  	return cachedDomain
   186  }
   187  
   188  func TestConnAddrFromSliceWithDomainCache(t *testing.T) {
   189  	const s = "🌐"
   190  	cachedDomain := s
   191  
   192  	cachedDomain = testConnAddrFromSliceWithDomainCache(t, addr4, cachedDomain, addr4connaddr)
   193  	if cachedDomain != s {
   194  		t.Errorf("ConnAddrFromSliceWithDomainCache(addr4) modified cachedDomain to %s.", cachedDomain)
   195  	}
   196  
   197  	cachedDomain = testConnAddrFromSliceWithDomainCache(t, addr4in6, cachedDomain, addr4connaddr)
   198  	if cachedDomain != s {
   199  		t.Errorf("ConnAddrFromSliceWithDomainCache(addr4in6) modified cachedDomain to %s.", cachedDomain)
   200  	}
   201  
   202  	cachedDomain = testConnAddrFromSliceWithDomainCache(t, addr6, cachedDomain, addr6connaddr)
   203  	if cachedDomain != s {
   204  		t.Errorf("ConnAddrFromSliceWithDomainCache(addr6) modified cachedDomain to %s.", cachedDomain)
   205  	}
   206  
   207  	cachedDomain = testConnAddrFromSliceWithDomainCache(t, addrDomain, cachedDomain, addrDomainConnAddr)
   208  	if cachedDomain != addrDomainHost {
   209  		t.Errorf("ConnAddrFromSliceWithDomainCache(addrDomain) modified cachedDomain to %s, expected %s.", cachedDomain, addrDomainHost)
   210  	}
   211  }
   212  
   213  func testAppendAddrFromConnAddr(t *testing.T, addr conn.Addr, expectedSA []byte) {
   214  	head := make([]byte, 64)
   215  	_, err := rand.Read(head)
   216  	if err != nil {
   217  		t.Fatal(err)
   218  	}
   219  
   220  	b := make([]byte, 0, 512)
   221  	b = append(b, head...)
   222  
   223  	b = AppendAddrFromConnAddr(b, addr)
   224  	if !bytes.Equal(b[:64], head) {
   225  		t.Error("Random head mismatch.")
   226  	}
   227  	if !bytes.Equal(b[64:], expectedSA) {
   228  		t.Errorf("Appended SOCKS address is %v, expected %v.", b[64:], expectedSA)
   229  	}
   230  }
   231  
   232  func TestAppendAddrFromConnAddr(t *testing.T) {
   233  	testAppendAddrFromConnAddr(t, addr4connaddr, addr4)
   234  	testAppendAddrFromConnAddr(t, addr4in6connaddr, addr4in6)
   235  	testAppendAddrFromConnAddr(t, addr6connaddr, addr6)
   236  	testAppendAddrFromConnAddr(t, addrDomainConnAddr, addrDomain)
   237  }
   238  
   239  func testLengthOfAndWriteAddrFromConnAddr(t *testing.T, addr conn.Addr, expectedSA []byte) {
   240  	addrLen := LengthOfAddrFromConnAddr(addr)
   241  	if addrLen != len(expectedSA) {
   242  		t.Errorf("LengthOfAddrFromConnAddr(addr) returned %d, expected %d.", addrLen, len(expectedSA))
   243  	}
   244  
   245  	b := make([]byte, 512)
   246  	_, err := rand.Read(b[addrLen:])
   247  	if err != nil {
   248  		t.Fatal(err)
   249  	}
   250  	tail := make([]byte, 512-addrLen)
   251  	copy(tail, b[addrLen:])
   252  
   253  	n := WriteAddrFromConnAddr(b, addr)
   254  	if n != len(expectedSA) {
   255  		t.Errorf("WriteAddrFromConnAddr(b, addr) returned n=%d, expected n=%d.", n, len(expectedSA))
   256  	}
   257  	if !bytes.Equal(b[:len(expectedSA)], expectedSA) {
   258  		t.Errorf("WriteAddrFromConnAddr(b, addr) wrote %v, expected %v.", b[:len(expectedSA)], expectedSA)
   259  	}
   260  	if !bytes.Equal(b[len(expectedSA):], tail) {
   261  		t.Error("WriteAddrFromConnAddr(b, addr) modified non-address bytes.")
   262  	}
   263  }
   264  
   265  func TestLengthOfAndWriteAddrFromConnAddr(t *testing.T) {
   266  	testLengthOfAndWriteAddrFromConnAddr(t, addr4connaddr, addr4)
   267  	testLengthOfAndWriteAddrFromConnAddr(t, addr4in6connaddr, addr4in6)
   268  	testLengthOfAndWriteAddrFromConnAddr(t, addr6connaddr, addr6)
   269  	testLengthOfAndWriteAddrFromConnAddr(t, addrDomainConnAddr, addrDomain)
   270  }