github.com/database64128/shadowsocks-go@v1.7.0/ss2022/udp_test.go (about)

     1  package ss2022
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"errors"
     7  	"net/netip"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/database64128/shadowsocks-go/conn"
    12  )
    13  
// Parameters for the typical (non-jumbogram) test cases.
//
//   - name is the client name passed to NewUDPClient and compared against the
//     Name reported by the client's Info().
//   - mtu is the MTU passed to NewUDPClient.
//   - packetSize is the MaxPacketSize the client is expected to report for mtu,
//     and is also passed to the server packer's PackInPlace. It is 48 bytes
//     below mtu — presumably the IPv6 (40) and UDP (8) header overhead;
//     TODO confirm against NewUDPClient.
//   - payloadLen is the length of the random payload sent in each direction.
//   - fwmark is not referenced by the code visible in this file; presumably it
//     is used by other tests in the package — verify before removing.
const (
	name       = "test"
	mtu        = 1500
	packetSize = 1452
	payloadLen = 1280
	fwmark     = 10240
)
    21  
// UDP jumbograms: MTU, expected maximum packet size and payload length used by
// the "Jumbogram" subtests in place of mtu, packetSize and payloadLen.
//
// jumboPacketSize subtracts 40, 8 and 8 bytes from the MTU — presumably the
// IPv6 header, the hop-by-hop jumbo payload option and the UDP header;
// TODO confirm against the client's MaxPacketSize calculation.
const (
	jumboMTU        = 128 * 1024
	jumboPacketSize = 128*1024 - 40 - 8 - 8
	jumboPayloadLen = 127 * 1024
)
    28  
// Addresses used by the tests. Every one is the IPv6 unspecified address and
// they differ only by port.
//
//   - targetAddrPort (port 53) is the proxied packet's target; targetAddr is the
//     same address converted with conn.AddrFromIPPort. Package-level variables
//     are initialized in dependency order, so declaring targetAddr first is fine.
//   - serverAddrPort (port 1080) is the server address given to NewUDPClient and
//     the source address passed when the client unpacks server packets.
//   - clientAddrPort (port 10800) is the source address passed when the server
//     unpacks client packets.
//   - replayClientAddrPort and replayServerAddrPort are the source addresses
//     passed when a previously seen packet is fed in again; the tests expect
//     them to show up in the resulting ShadowPacketReplayError.
var (
	targetAddr           = conn.AddrFromIPPort(targetAddrPort)
	targetAddrPort       = netip.AddrPortFrom(netip.IPv6Unspecified(), 53)
	serverAddrPort       = netip.AddrPortFrom(netip.IPv6Unspecified(), 1080)
	clientAddrPort       = netip.AddrPortFrom(netip.IPv6Unspecified(), 10800)
	replayClientAddrPort = netip.AddrPortFrom(netip.IPv6Unspecified(), 10801)
	replayServerAddrPort = netip.AddrPortFrom(netip.IPv6Unspecified(), 10802)
)
    37  
    38  func testUDPClientServer(t *testing.T, clientCipherConfig *ClientCipherConfig, userCipherConfig UserCipherConfig, identityCipherConfig ServerIdentityCipherConfig, userLookupMap UserLookupMap, clientShouldPad, serverShouldPad PaddingPolicy, mtu, packetSize, payloadLen int) {
    39  	c := NewUDPClient(serverAddrPort, name, mtu, conn.DefaultUDPClientListenConfig, clientCipherConfig, clientShouldPad)
    40  	s := NewUDPServer(userCipherConfig, identityCipherConfig, serverShouldPad)
    41  	s.ReplaceUserLookupMap(userLookupMap)
    42  
    43  	clientInfo := c.Info()
    44  	if clientInfo.Name != name {
    45  		t.Errorf("Fixed name mismatch: in: %s, out: %s", name, clientInfo.Name)
    46  	}
    47  	if clientInfo.MaxPacketSize != packetSize {
    48  		t.Errorf("Fixed MTU mismatch: in: %d, out: %d", mtu, clientInfo.MaxPacketSize)
    49  	}
    50  
    51  	_, clientPacker, clientUnpacker, err := c.NewSession()
    52  	clientPackerInfo := clientPacker.ClientPackerInfo()
    53  	frontHeadroom := clientPackerInfo.Headroom.Front + 8 // Compensate for server message overhead.
    54  	rearHeadroom := clientPackerInfo.Headroom.Rear
    55  	b := make([]byte, frontHeadroom+payloadLen+rearHeadroom)
    56  	payload := b[frontHeadroom : frontHeadroom+payloadLen]
    57  
    58  	// Fill random payload.
    59  	_, err = rand.Read(payload)
    60  	if err != nil {
    61  		t.Fatal(err)
    62  	}
    63  
    64  	// Backup payload.
    65  	payloadBackup := make([]byte, len(payload))
    66  	copy(payloadBackup, payload)
    67  
    68  	// Client packs.
    69  	dap, pkts, pktl, err := clientPacker.PackInPlace(b, targetAddr, frontHeadroom, payloadLen)
    70  	if err != nil {
    71  		t.Fatal(err)
    72  	}
    73  	if dap != serverAddrPort {
    74  		t.Errorf("Expected packed client packet destAddrPort %s, got %s", serverAddrPort, dap)
    75  	}
    76  	p := b[pkts : pkts+pktl]
    77  
    78  	// Server unpacks.
    79  	csid, err := s.SessionInfo(p)
    80  	if err != nil {
    81  		t.Fatal(err)
    82  	}
    83  	serverUnpacker, _, err := s.NewUnpacker(p, csid)
    84  	if err != nil {
    85  		t.Fatal(err)
    86  	}
    87  	ta, ps, pl, err := serverUnpacker.UnpackInPlace(b, clientAddrPort, pkts, pktl)
    88  	if err != nil {
    89  		t.Error(err)
    90  	}
    91  
    92  	// Check target address.
    93  	if !ta.Equals(targetAddr) {
    94  		t.Errorf("Target address mismatch: c: %s, s: %s", targetAddr, ta)
    95  	}
    96  
    97  	// Check payload.
    98  	p = b[ps : ps+pl]
    99  	if !bytes.Equal(payloadBackup, p) {
   100  		t.Errorf("Payload mismatch: c: %v, s: %v", payloadBackup, p)
   101  	}
   102  
   103  	// Fill random again.
   104  	_, err = rand.Read(payload)
   105  	if err != nil {
   106  		t.Fatal(err)
   107  	}
   108  	copy(payloadBackup, payload)
   109  
   110  	// Server packs.
   111  	serverPacker, err := serverUnpacker.NewPacker()
   112  	if err != nil {
   113  		t.Fatal(err)
   114  	}
   115  	pkts, pktl, err = serverPacker.PackInPlace(b, targetAddrPort, frontHeadroom, payloadLen, packetSize)
   116  	if err != nil {
   117  		t.Fatal(err)
   118  	}
   119  
   120  	// Client unpacks.
   121  	tap, ps, pl, err := clientUnpacker.UnpackInPlace(b, serverAddrPort, pkts, pktl)
   122  	if err != nil {
   123  		t.Error(err)
   124  	}
   125  
   126  	// Check target address.
   127  	if tap != targetAddrPort {
   128  		t.Errorf("Target address mismatch: s: %s, c: %s", targetAddrPort, tap)
   129  	}
   130  
   131  	// Check payload.
   132  	p = b[ps : ps+pl]
   133  	if !bytes.Equal(payloadBackup, p) {
   134  		t.Errorf("Payload mismatch: s: %v, c: %v", payloadBackup, p)
   135  	}
   136  }
   137  
   138  func testUDPClientServerSessionChangeAndReplay(t *testing.T, clientCipherConfig *ClientCipherConfig, userCipherConfig UserCipherConfig, identityCipherConfig ServerIdentityCipherConfig, userLookupMap UserLookupMap) {
   139  	shouldPad, err := ParsePaddingPolicy("")
   140  	if err != nil {
   141  		t.Fatal(err)
   142  	}
   143  
   144  	c := NewUDPClient(serverAddrPort, name, mtu, conn.DefaultUDPClientListenConfig, clientCipherConfig, shouldPad)
   145  	s := NewUDPServer(userCipherConfig, identityCipherConfig, shouldPad)
   146  	s.ReplaceUserLookupMap(userLookupMap)
   147  
   148  	_, clientPacker, clientUnpacker, err := c.NewSession()
   149  	if err != nil {
   150  		t.Fatal(err)
   151  	}
   152  
   153  	clientPackerInfo := clientPacker.ClientPackerInfo()
   154  	frontHeadroom := clientPackerInfo.Headroom.Front + 8 // Compensate for server message overhead.
   155  	rearHeadroom := clientPackerInfo.Headroom.Rear
   156  	b := make([]byte, frontHeadroom+payloadLen+rearHeadroom)
   157  
   158  	// Client packs.
   159  	dap, pkts, pktl, err := clientPacker.PackInPlace(b, targetAddr, frontHeadroom, payloadLen)
   160  	if err != nil {
   161  		t.Fatal(err)
   162  	}
   163  	if dap != serverAddrPort {
   164  		t.Errorf("Expected packed client packet destAddrPort %s, got %s", serverAddrPort, dap)
   165  	}
   166  	p := b[pkts : pkts+pktl]
   167  
   168  	// Server processes client packet.
   169  	csid, err := s.SessionInfo(p)
   170  	if err != nil {
   171  		t.Fatal(err)
   172  	}
   173  	serverUnpacker, _, err := s.NewUnpacker(p, csid)
   174  	if err != nil {
   175  		t.Fatal(err)
   176  	}
   177  
   178  	// Backup processed client packet.
   179  	pb := make([]byte, pktl)
   180  	copy(pb, p)
   181  
   182  	// Server unpacks.
   183  	_, _, _, err = serverUnpacker.UnpackInPlace(b, clientAddrPort, pkts, pktl)
   184  	if err != nil {
   185  		t.Error(err)
   186  	}
   187  
   188  	// Server unpacks the same packet again.
   189  	_, _, _, err = serverUnpacker.UnpackInPlace(pb, replayClientAddrPort, 0, pktl)
   190  	var sprErr *ShadowPacketReplayError
   191  	if !errors.As(err, &sprErr) {
   192  		t.Errorf("Expected ShadowPacketReplayError, got %T", err)
   193  	}
   194  	if sprErr.srcAddr != replayClientAddrPort {
   195  		t.Errorf("Expected ShadowPacketReplayError srcAddr %s, got %s", replayClientAddrPort, sprErr.srcAddr)
   196  	}
   197  	if sprErr.sid != csid {
   198  		t.Errorf("Expected ShadowPacketReplayError sid %d, got %d", csid, sprErr.sid)
   199  	}
   200  	if sprErr.pid != 0 {
   201  		t.Errorf("Expected ShadowPacketReplayError pid 0, got %d", sprErr.pid)
   202  	}
   203  
   204  	// Server packs.
   205  	serverPacker, err := serverUnpacker.NewPacker()
   206  	if err != nil {
   207  		t.Fatal(err)
   208  	}
   209  	pkts, pktl, err = serverPacker.PackInPlace(b, targetAddrPort, frontHeadroom, payloadLen, packetSize)
   210  	if err != nil {
   211  		t.Fatal(err)
   212  	}
   213  	ssid0 := serverPacker.(*ShadowPacketServerPacker).ssid
   214  
   215  	// Backup packed server packet.
   216  	pb0 := make([]byte, pktl)
   217  	copy(pb0, b[pkts:pkts+pktl])
   218  
   219  	// Client unpacks.
   220  	_, _, _, err = clientUnpacker.UnpackInPlace(b, serverAddrPort, pkts, pktl)
   221  	if err != nil {
   222  		t.Error(err)
   223  	}
   224  
   225  	// Refresh server session.
   226  	serverPacker, err = serverUnpacker.NewPacker()
   227  	if err != nil {
   228  		t.Fatal(err)
   229  	}
   230  	pkts, pktl, err = serverPacker.PackInPlace(b, targetAddrPort, frontHeadroom, payloadLen, packetSize)
   231  	if err != nil {
   232  		t.Fatal(err)
   233  	}
   234  	ssid1 := serverPacker.(*ShadowPacketServerPacker).ssid
   235  
   236  	// Backup packed server packet.
   237  	pb1 := make([]byte, pktl)
   238  	copy(pb1, b[pkts:pkts+pktl])
   239  
   240  	// Trick client into accepting refreshed server session.
   241  	spcu := clientUnpacker.(*ShadowPacketClientUnpacker)
   242  	spcu.oldServerSessionLastSeenTime = spcu.oldServerSessionLastSeenTime.Add(-time.Minute - time.Nanosecond)
   243  
   244  	// Client unpacks.
   245  	_, _, _, err = clientUnpacker.UnpackInPlace(b, serverAddrPort, pkts, pktl)
   246  	if err != nil {
   247  		t.Error(err)
   248  	}
   249  
   250  	// Refresh server session again. No tricks this time!
   251  	serverPacker, err = serverUnpacker.NewPacker()
   252  	if err != nil {
   253  		t.Fatal(err)
   254  	}
   255  	pkts, pktl, err = serverPacker.PackInPlace(b, targetAddrPort, frontHeadroom, payloadLen, packetSize)
   256  	if err != nil {
   257  		t.Fatal(err)
   258  	}
   259  
   260  	// Client unpacks.
   261  	_, _, _, err = clientUnpacker.UnpackInPlace(b, serverAddrPort, pkts, pktl)
   262  	if err != ErrTooManyServerSessions {
   263  		t.Errorf("Expected ErrTooManyServerSessions, got %v", err)
   264  	}
   265  
   266  	// Client unpacks pb0.
   267  	_, _, _, err = clientUnpacker.UnpackInPlace(pb0, replayServerAddrPort, 0, len(pb0))
   268  	if !errors.As(err, &sprErr) {
   269  		t.Errorf("Expected ShadowPacketReplayError, got %T", err)
   270  	}
   271  	if sprErr.srcAddr != replayServerAddrPort {
   272  		t.Errorf("Expected ShadowPacketReplayError srcAddr %s, got %s", replayServerAddrPort, sprErr.srcAddr)
   273  	}
   274  	if sprErr.sid != ssid0 {
   275  		t.Errorf("Expected ShadowPacketReplayError sid %d, got %d", ssid0, sprErr.sid)
   276  	}
   277  	if sprErr.pid != 0 {
   278  		t.Errorf("Expected ShadowPacketReplayError pid 0, got %d", sprErr.pid)
   279  	}
   280  
   281  	// Client unpacks pb1.
   282  	_, _, _, err = clientUnpacker.UnpackInPlace(pb1, replayServerAddrPort, 0, len(pb1))
   283  	if !errors.As(err, &sprErr) {
   284  		t.Errorf("Expected ShadowPacketReplayError, got %T", err)
   285  	}
   286  	if sprErr.srcAddr != replayServerAddrPort {
   287  		t.Errorf("Expected ShadowPacketReplayError srcAddr %s, got %s", replayServerAddrPort, sprErr.srcAddr)
   288  	}
   289  	if sprErr.sid != ssid1 {
   290  		t.Errorf("Expected ShadowPacketReplayError sid %d, got %d", ssid1, sprErr.sid)
   291  	}
   292  	if sprErr.pid != 0 {
   293  		t.Errorf("Expected ShadowPacketReplayError pid 0, got %d", sprErr.pid)
   294  	}
   295  }
   296  
   297  func testUDPClientServerPaddingPolicy(t *testing.T, clientCipherConfig *ClientCipherConfig, userCipherConfig UserCipherConfig, identityCipherConfig ServerIdentityCipherConfig, userLookupMap UserLookupMap, mtu, packetSize, payloadLen int) {
   298  	t.Run("NoPadding", func(t *testing.T) {
   299  		testUDPClientServer(t, clientCipherConfig, userCipherConfig, identityCipherConfig, userLookupMap, NoPadding, NoPadding, mtu, packetSize, payloadLen)
   300  	})
   301  	t.Run("PadPlainDNS", func(t *testing.T) {
   302  		testUDPClientServer(t, clientCipherConfig, userCipherConfig, identityCipherConfig, userLookupMap, PadPlainDNS, PadPlainDNS, mtu, packetSize, payloadLen)
   303  	})
   304  	t.Run("PadAll", func(t *testing.T) {
   305  		testUDPClientServer(t, clientCipherConfig, userCipherConfig, identityCipherConfig, userLookupMap, PadAll, PadAll, mtu, packetSize, payloadLen)
   306  	})
   307  }
   308  
   309  func testUDPClientServerWithCipher(t *testing.T, clientCipherConfig *ClientCipherConfig, userCipherConfig UserCipherConfig, identityCipherConfig ServerIdentityCipherConfig, userLookupMap UserLookupMap) {
   310  	t.Run("Typical", func(t *testing.T) {
   311  		testUDPClientServerPaddingPolicy(t, clientCipherConfig, userCipherConfig, identityCipherConfig, userLookupMap, mtu, packetSize, payloadLen)
   312  	})
   313  	t.Run("EmptyPayload", func(t *testing.T) {
   314  		testUDPClientServerPaddingPolicy(t, clientCipherConfig, userCipherConfig, identityCipherConfig, userLookupMap, mtu, packetSize, 0)
   315  	})
   316  	t.Run("Jumbogram", func(t *testing.T) {
   317  		testUDPClientServerPaddingPolicy(t, clientCipherConfig, userCipherConfig, identityCipherConfig, userLookupMap, jumboMTU, jumboPacketSize, jumboPayloadLen)
   318  	})
   319  	t.Run("SessionChangeAndReplay", func(t *testing.T) {
   320  		testUDPClientServerSessionChangeAndReplay(t, clientCipherConfig, userCipherConfig, identityCipherConfig, userLookupMap)
   321  	})
   322  }
   323  
   324  func TestUDPClientServerNoEIH(t *testing.T) {
   325  	clientCipherConfig128, userCipherConfig128, err := newRandomCipherConfigTupleNoEIH("2022-blake3-aes-128-gcm", true)
   326  	if err != nil {
   327  		t.Fatal(err)
   328  	}
   329  	clientCipherConfig256, userCipherConfig256, err := newRandomCipherConfigTupleNoEIH("2022-blake3-aes-256-gcm", true)
   330  	if err != nil {
   331  		t.Fatal(err)
   332  	}
   333  
   334  	t.Run("128", func(t *testing.T) {
   335  		testUDPClientServerWithCipher(t, clientCipherConfig128, userCipherConfig128, ServerIdentityCipherConfig{}, nil)
   336  	})
   337  	t.Run("256", func(t *testing.T) {
   338  		testUDPClientServerWithCipher(t, clientCipherConfig256, userCipherConfig256, ServerIdentityCipherConfig{}, nil)
   339  	})
   340  }
   341  
   342  func TestUDPClientServerWithEIH(t *testing.T) {
   343  	clientCipherConfig128, identityCipherConfig128, userLookupMap128, err := newRandomCipherConfigTupleWithEIH("2022-blake3-aes-128-gcm", true)
   344  	if err != nil {
   345  		t.Fatal(err)
   346  	}
   347  	clientCipherConfig256, identityCipherConfig256, userLookupMap256, err := newRandomCipherConfigTupleWithEIH("2022-blake3-aes-256-gcm", true)
   348  	if err != nil {
   349  		t.Fatal(err)
   350  	}
   351  
   352  	t.Run("128", func(t *testing.T) {
   353  		testUDPClientServerWithCipher(t, clientCipherConfig128, UserCipherConfig{}, identityCipherConfig128, userLookupMap128)
   354  	})
   355  	t.Run("256", func(t *testing.T) {
   356  		testUDPClientServerWithCipher(t, clientCipherConfig256, UserCipherConfig{}, identityCipherConfig256, userLookupMap256)
   357  	})
   358  }