github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/ss2022/header_test.go (about)

     1  package ss2022
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"encoding/binary"
     7  	"errors"
     8  	"math"
     9  	mrand "math/rand/v2"
    10  	"net/netip"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/database64128/shadowsocks-go/conn"
    15  	"github.com/database64128/shadowsocks-go/socks5"
    16  )
    17  
    18  func TestHeaderErrorString(t *testing.T) {
    19  	const errMsg = "time diff is over 30 seconds: expected 1, got 2"
    20  	err := HeaderError[int]{ErrBadTimestamp, 1, 2}
    21  	if err.Error() != errMsg {
    22  		t.FailNow()
    23  	}
    24  }
    25  
    26  func TestWriteAndParseTCPRequestFixedLengthHeader(t *testing.T) {
    27  	b := make([]byte, TCPRequestFixedLengthHeaderLength)
    28  	length := int(mrand.Uint64() & math.MaxUint16)
    29  
    30  	// 1. Good header
    31  	WriteTCPRequestFixedLengthHeader(b, uint16(length))
    32  
    33  	n, err := ParseTCPRequestFixedLengthHeader(b)
    34  	if err != nil {
    35  		t.Fatal(err)
    36  	}
    37  	if n != length {
    38  		t.Fatalf("Expected: %d\nGot: %d", length, n)
    39  	}
    40  
    41  	// 2. Bad timestamp (31s ago)
    42  	ts := time.Now().Add(-31 * time.Second)
    43  	binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix()))
    44  
    45  	_, err = ParseTCPRequestFixedLengthHeader(b)
    46  	if !errors.Is(err, ErrBadTimestamp) {
    47  		t.Fatalf("Expected: %s\nGot: %s", ErrBadTimestamp, err)
    48  	}
    49  
    50  	// 3. Bad timestamp (31s later)
    51  	ts = time.Now().Add(31 * time.Second)
    52  	binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix()))
    53  
    54  	_, err = ParseTCPRequestFixedLengthHeader(b)
    55  	if !errors.Is(err, ErrBadTimestamp) {
    56  		t.Fatalf("Expected: %s\nGot: %s", ErrBadTimestamp, err)
    57  	}
    58  
    59  	// 4. Bad type
    60  	b[0] = HeaderTypeServerStream
    61  
    62  	_, err = ParseTCPRequestFixedLengthHeader(b)
    63  	if !errors.Is(err, ErrTypeMismatch) {
    64  		t.Fatalf("Expected: %s\nGot: %s", ErrTypeMismatch, err)
    65  	}
    66  }
    67  
    68  func TestWriteAndParseTCPRequestVariableLengthHeader(t *testing.T) {
    69  	payloadLen := 1 + int(mrand.Uint64()&1023)
    70  	payload := make([]byte, payloadLen)
    71  	_, err := rand.Read(payload)
    72  	if err != nil {
    73  		t.Fatal(err)
    74  	}
    75  	targetAddr := conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Unspecified(), 443))
    76  	targetAddrLen := socks5.LengthOfAddrFromConnAddr(targetAddr)
    77  	noPayloadLen := targetAddrLen + 2 + 1 + mrand.IntN(MaxPaddingLength)
    78  	noPaddingLen := targetAddrLen + 2 + payloadLen
    79  	bufLen := noPaddingLen + MaxPaddingLength
    80  	b := make([]byte, bufLen)
    81  
    82  	// 1. Good header (padding + initial payload)
    83  	WriteTCPRequestVariableLengthHeader(b, targetAddr, payload)
    84  
    85  	ta, p, err := ParseTCPRequestVariableLengthHeader(b)
    86  	if err != nil {
    87  		t.Fatal(err)
    88  	}
    89  	if !bytes.Equal(p, payload) {
    90  		t.Fatalf("Expected payload %v\nGot: %v", payload, p)
    91  	}
    92  	if !ta.Equals(targetAddr) {
    93  		t.Fatalf("Expected target address %s, got %s", targetAddr, ta)
    94  	}
    95  
    96  	// 2. Good header (initial payload)
    97  	b = b[:noPaddingLen]
    98  	WriteTCPRequestVariableLengthHeader(b, targetAddr, payload)
    99  
   100  	ta, p, err = ParseTCPRequestVariableLengthHeader(b)
   101  	if err != nil {
   102  		t.Fatal(err)
   103  	}
   104  	if !bytes.Equal(p, payload) {
   105  		t.Fatalf("Expected payload %v\nGot: %v", payload, p)
   106  	}
   107  	if !ta.Equals(targetAddr) {
   108  		t.Fatalf("Expected target address %s, got %s", targetAddr, ta)
   109  	}
   110  
   111  	// 3. Good header (padding)
   112  	b = b[:noPayloadLen]
   113  	WriteTCPRequestVariableLengthHeader(b, targetAddr, nil)
   114  
   115  	ta, p, err = ParseTCPRequestVariableLengthHeader(b)
   116  	if err != nil {
   117  		t.Fatal(err)
   118  	}
   119  	if len(p) > 0 {
   120  		t.Fatalf("Expected empty initial payload, got length %d", len(p))
   121  	}
   122  	if !ta.Equals(targetAddr) {
   123  		t.Fatalf("Expected target address %s, got %s", targetAddr, ta)
   124  	}
   125  
   126  	// 4. Bad header (incomplete padding)
   127  	b = b[:noPayloadLen-1]
   128  
   129  	_, _, err = ParseTCPRequestVariableLengthHeader(b)
   130  	if !errors.Is(err, ErrPaddingExceedChunkBorder) {
   131  		t.Fatalf("Expected: %s\nGot: %s", ErrPaddingExceedChunkBorder, err)
   132  	}
   133  
   134  	// 5. Bad header (incomplete padding length)
   135  	b = b[:targetAddrLen+1]
   136  
   137  	_, _, err = ParseTCPRequestVariableLengthHeader(b)
   138  	if !errors.Is(err, ErrIncompleteHeaderInFirstChunk) {
   139  		t.Fatalf("Expected: %s\nGot: %s", ErrIncompleteHeaderInFirstChunk, err)
   140  	}
   141  
   142  	// 6. Bad header (incomplete SOCKS address)
   143  	b = b[:targetAddrLen-1]
   144  
   145  	_, _, err = ParseTCPRequestVariableLengthHeader(b)
   146  	if err == nil {
   147  		t.Fatal("Expected error, got nil")
   148  	}
   149  }
   150  
   151  func TestWriteAndParseTCPResponseHeader(t *testing.T) {
   152  	const (
   153  		saltLen = 32
   154  		bufLen  = 1 + 8 + saltLen + 2
   155  	)
   156  
   157  	b := make([]byte, bufLen)
   158  	length := int(mrand.Uint64() & math.MaxUint16)
   159  	requestSalt := make([]byte, saltLen)
   160  	_, err := rand.Read(requestSalt)
   161  	if err != nil {
   162  		t.Fatal(err)
   163  	}
   164  
   165  	// 1. Good header
   166  	WriteTCPResponseHeader(b, requestSalt, uint16(length))
   167  
   168  	n, err := ParseTCPResponseHeader(b, requestSalt)
   169  	if err != nil {
   170  		t.Fatal(err)
   171  	}
   172  	if n != length {
   173  		t.Fatalf("Expected: %d\nGot: %d", length, n)
   174  	}
   175  
   176  	// 2. Bad request salt
   177  	_, err = rand.Read(b[1+8 : 1+8+saltLen])
   178  	if err != nil {
   179  		t.Fatal(err)
   180  	}
   181  
   182  	_, err = ParseTCPResponseHeader(b, requestSalt)
   183  	if !errors.Is(err, ErrClientSaltMismatch) {
   184  		t.Fatalf("Expected: %s\nGot: %s", ErrClientSaltMismatch, err)
   185  	}
   186  
   187  	// 3. Bad timestamp (31s ago)
   188  	ts := time.Now().Add(-31 * time.Second)
   189  	binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix()))
   190  
   191  	_, err = ParseTCPResponseHeader(b, requestSalt)
   192  	if !errors.Is(err, ErrBadTimestamp) {
   193  		t.Fatalf("Expected: %s\nGot: %s", ErrBadTimestamp, err)
   194  	}
   195  
   196  	// 4. Bad timestamp (31s later)
   197  	ts = time.Now().Add(31 * time.Second)
   198  	binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix()))
   199  
   200  	_, err = ParseTCPResponseHeader(b, requestSalt)
   201  	if !errors.Is(err, ErrBadTimestamp) {
   202  		t.Fatalf("Expected: %s\nGot: %s", ErrBadTimestamp, err)
   203  	}
   204  
   205  	// 5. Bad type
   206  	b[0] = HeaderTypeClientStream
   207  
   208  	_, err = ParseTCPResponseHeader(b, requestSalt)
   209  	if !errors.Is(err, ErrTypeMismatch) {
   210  		t.Fatalf("Expected: %s\nGot: %s", ErrTypeMismatch, err)
   211  	}
   212  }
   213  
   214  func TestWriteAndParseSessionIDAndPacketID(t *testing.T) {
   215  	sid := mrand.Uint64()
   216  	pid := mrand.Uint64()
   217  	b := make([]byte, 16)
   218  
   219  	WriteSessionIDAndPacketID(b, sid, pid)
   220  	psid, ppid := ParseSessionIDAndPacketID(b)
   221  	if psid != sid {
   222  		t.Fatalf("Expected session ID %d, got %d", sid, psid)
   223  	}
   224  	if ppid != pid {
   225  		t.Fatalf("Expected packet ID %d, got %d", pid, ppid)
   226  	}
   227  }
   228  
   229  func TestWriteAndParseUDPClientMessageHeader(t *testing.T) {
   230  	var cachedDomain string
   231  	targetAddr := conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Unspecified(), 53))
   232  	targetAddrLen := socks5.LengthOfAddrFromConnAddr(targetAddr)
   233  	noPaddingLen := UDPClientMessageHeaderFixedLength + targetAddrLen
   234  	paddingLen := 1 + mrand.IntN(MaxPaddingLength)
   235  	headerLen := noPaddingLen + paddingLen
   236  	payloadLen := 1 + int(mrand.Uint64()&math.MaxUint16)
   237  	bufLen := headerLen + payloadLen
   238  	b := make([]byte, bufLen)
   239  	bNoPadding := b[paddingLen:]
   240  	headerBuf := b[:headerLen]
   241  	headerNoPaddingBuf := bNoPadding[:noPaddingLen]
   242  	payload := b[headerLen:]
   243  	_, err := rand.Read(payload)
   244  	if err != nil {
   245  		t.Fatal(err)
   246  	}
   247  
   248  	// 1. Good header (no padding)
   249  	WriteUDPClientMessageHeader(headerNoPaddingBuf, 0, targetAddr)
   250  
   251  	ta, cachedDomain, ps, pl, err := ParseUDPClientMessageHeader(bNoPadding, cachedDomain)
   252  	if err != nil {
   253  		t.Fatal(err)
   254  	}
   255  	ps += headerLen - noPaddingLen
   256  	if ps != headerLen {
   257  		t.Errorf("Expected payload start %d, got %d", headerLen, ps)
   258  	}
   259  	if pl != payloadLen {
   260  		t.Errorf("Expected payload length %d, got %d", payloadLen, pl)
   261  	}
   262  	if !ta.Equals(targetAddr) {
   263  		t.Errorf("Expected target address %s, got %s", targetAddr, ta)
   264  	}
   265  
   266  	// 2. Good header (padding)
   267  	WriteUDPClientMessageHeader(headerBuf, paddingLen, targetAddr)
   268  
   269  	ta, cachedDomain, ps, pl, err = ParseUDPClientMessageHeader(b, cachedDomain)
   270  	if err != nil {
   271  		t.Fatal(err)
   272  	}
   273  	if ps != headerLen {
   274  		t.Errorf("Expected payload start %d, got %d", headerLen, ps)
   275  	}
   276  	if pl != payloadLen {
   277  		t.Errorf("Expected payload length %d, got %d", payloadLen, pl)
   278  	}
   279  	if !ta.Equals(targetAddr) {
   280  		t.Errorf("Expected target address %s, got %s", targetAddr, ta)
   281  	}
   282  
   283  	// 3. Bad header (incomplete SOCKS address)
   284  	b = b[:headerLen-1]
   285  
   286  	_, cachedDomain, _, _, err = ParseUDPClientMessageHeader(b, cachedDomain)
   287  	if err == nil {
   288  		t.Error("Expected error, got nil")
   289  	}
   290  
   291  	// 4. Bad header (incomplete padding)
   292  	b = b[:len(b)-targetAddrLen]
   293  
   294  	_, cachedDomain, _, _, err = ParseUDPClientMessageHeader(b, cachedDomain)
   295  	if !errors.Is(err, ErrPacketIncompleteHeader) {
   296  		t.Errorf("Expected: %s\nGot: %s", ErrPacketIncompleteHeader, err)
   297  	}
   298  
   299  	// 5. Bad header (incomplete padding length)
   300  	b = b[:1+8+1]
   301  
   302  	_, cachedDomain, _, _, err = ParseUDPClientMessageHeader(b, cachedDomain)
   303  	if !errors.Is(err, ErrPacketIncompleteHeader) {
   304  		t.Errorf("Expected: %s\nGot: %s", ErrPacketIncompleteHeader, err)
   305  	}
   306  
   307  	// 6. Bad timestamp (31s ago)
   308  	b = b[:UDPClientMessageHeaderFixedLength]
   309  
   310  	ts := time.Now().Add(-31 * time.Second)
   311  	binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix()))
   312  
   313  	_, cachedDomain, _, _, err = ParseUDPClientMessageHeader(b, cachedDomain)
   314  	if !errors.Is(err, ErrBadTimestamp) {
   315  		t.Errorf("Expected: %s\nGot: %s", ErrBadTimestamp, err)
   316  	}
   317  
   318  	// 7. Bad timestamp (31s later)
   319  	ts = time.Now().Add(31 * time.Second)
   320  	binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix()))
   321  
   322  	_, cachedDomain, _, _, err = ParseUDPClientMessageHeader(b, cachedDomain)
   323  	if !errors.Is(err, ErrBadTimestamp) {
   324  		t.Errorf("Expected: %s\nGot: %s", ErrBadTimestamp, err)
   325  	}
   326  
   327  	// 8. Bad type
   328  	b[0] = HeaderTypeServerPacket
   329  
   330  	_, _, _, _, err = ParseUDPClientMessageHeader(b, cachedDomain)
   331  	if !errors.Is(err, ErrTypeMismatch) {
   332  		t.Errorf("Expected: %s\nGot: %s", ErrTypeMismatch, err)
   333  	}
   334  }
   335  
   336  func TestWriteAndParseUDPServerMessageHeader(t *testing.T) {
   337  	csid := mrand.Uint64()
   338  	sourceAddrPort := netip.AddrPortFrom(netip.IPv6Unspecified(), 53)
   339  	sourceAddrPortLen := socks5.LengthOfAddrFromAddrPort(sourceAddrPort)
   340  	noPaddingLen := UDPServerMessageHeaderFixedLength + sourceAddrPortLen
   341  	paddingLen := 1 + mrand.IntN(MaxPaddingLength)
   342  	headerLen := noPaddingLen + paddingLen
   343  	payloadLen := 1 + int(mrand.Uint64()&math.MaxUint16)
   344  	bufLen := headerLen + payloadLen
   345  	b := make([]byte, bufLen)
   346  	bNoPadding := b[paddingLen:]
   347  	headerBuf := b[:headerLen]
   348  	headerNoPaddingBuf := bNoPadding[:noPaddingLen]
   349  	payload := b[headerLen:]
   350  	_, err := rand.Read(payload)
   351  	if err != nil {
   352  		t.Fatal(err)
   353  	}
   354  
   355  	// 1. Good header (no padding)
   356  	WriteUDPServerMessageHeader(headerNoPaddingBuf, csid, 0, sourceAddrPort)
   357  
   358  	sa, ps, pl, err := ParseUDPServerMessageHeader(bNoPadding, csid)
   359  	if err != nil {
   360  		t.Fatal(err)
   361  	}
   362  	ps += headerLen - noPaddingLen
   363  	if ps != headerLen {
   364  		t.Errorf("Expected payload start %d, got %d", headerLen, ps)
   365  	}
   366  	if pl != payloadLen {
   367  		t.Errorf("Expected payload length %d, got %d", payloadLen, pl)
   368  	}
   369  	if sa != sourceAddrPort {
   370  		t.Errorf("Expected target address %s, got %s", sourceAddrPort, sa)
   371  	}
   372  
   373  	// 2. Good header (pad)
   374  	WriteUDPServerMessageHeader(headerBuf, csid, paddingLen, sourceAddrPort)
   375  
   376  	sa, ps, pl, err = ParseUDPServerMessageHeader(b, csid)
   377  	if err != nil {
   378  		t.Fatal(err)
   379  	}
   380  	if ps != headerLen {
   381  		t.Errorf("Expected payload start %d, got %d", headerLen, ps)
   382  	}
   383  	if pl != payloadLen {
   384  		t.Errorf("Expected payload length %d, got %d", payloadLen, pl)
   385  	}
   386  	if sa != sourceAddrPort {
   387  		t.Errorf("Expected target address %s, got %s", sourceAddrPort, sa)
   388  	}
   389  
   390  	// 3. Bad header (incomplete SOCKS address)
   391  	b = b[:headerLen-1]
   392  
   393  	_, _, _, err = ParseUDPServerMessageHeader(b, csid)
   394  	if err == nil {
   395  		t.Error("Expected error, got nil")
   396  	}
   397  
   398  	// 4. Bad header (incomplete padding)
   399  	b = b[:len(b)-sourceAddrPortLen]
   400  
   401  	_, _, _, err = ParseUDPServerMessageHeader(b, csid)
   402  	if !errors.Is(err, ErrPacketIncompleteHeader) {
   403  		t.Errorf("Expected: %s\nGot: %s", ErrPacketIncompleteHeader, err)
   404  	}
   405  
   406  	// 5. Bad header (incomplete padding length)
   407  	b = b[:1+8+8+1]
   408  
   409  	_, _, _, err = ParseUDPServerMessageHeader(b, csid)
   410  	if !errors.Is(err, ErrPacketIncompleteHeader) {
   411  		t.Errorf("Expected: %s\nGot: %s", ErrPacketIncompleteHeader, err)
   412  	}
   413  
   414  	// 6. Bad client session ID
   415  	b = b[:UDPServerMessageHeaderFixedLength]
   416  	badCsid := csid + 1
   417  	binary.BigEndian.PutUint64(b[1+8:], badCsid)
   418  
   419  	_, _, _, err = ParseUDPServerMessageHeader(b, csid)
   420  	if !errors.Is(err, ErrClientSessionIDMismatch) {
   421  		t.Errorf("Expected: %s\nGot: %s", ErrClientSessionIDMismatch, err)
   422  	}
   423  
   424  	// 7. Bad timestamp (31s ago)
   425  	ts := time.Now().Add(-31 * time.Second)
   426  	binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix()))
   427  
   428  	_, _, _, err = ParseUDPServerMessageHeader(b, csid)
   429  	if !errors.Is(err, ErrBadTimestamp) {
   430  		t.Errorf("Expected: %s\nGot: %s", ErrBadTimestamp, err)
   431  	}
   432  
   433  	// 8. Bad timestamp (31s later)
   434  	ts = time.Now().Add(31 * time.Second)
   435  	binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix()))
   436  
   437  	_, _, _, err = ParseUDPServerMessageHeader(b, csid)
   438  	if !errors.Is(err, ErrBadTimestamp) {
   439  		t.Errorf("Expected: %s\nGot: %s", ErrBadTimestamp, err)
   440  	}
   441  
   442  	// 9. Bad type
   443  	b[0] = HeaderTypeClientPacket
   444  
   445  	_, _, _, err = ParseUDPServerMessageHeader(b, csid)
   446  	if !errors.Is(err, ErrTypeMismatch) {
   447  		t.Errorf("Expected: %s\nGot: %s", ErrTypeMismatch, err)
   448  	}
   449  }