github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/net_test.go (about)

     1  package bufio
     2  
     3  import (
     4  	"context"
     5  	"crypto/md5"
     6  	"crypto/rand"
     7  	"errors"
     8  	"io"
     9  	"net"
    10  	"sync"
    11  	"testing"
    12  	"time"
    13  
    14  	M "github.com/sagernet/sing/common/metadata"
    15  	"github.com/sagernet/sing/common/task"
    16  
    17  	"github.com/stretchr/testify/assert"
    18  	"github.com/stretchr/testify/require"
    19  )
    20  
    21  func TCPPipe(t *testing.T) (net.Conn, net.Conn) {
    22  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    23  	require.NoError(t, err)
    24  	var (
    25  		group      task.Group
    26  		serverConn net.Conn
    27  		clientConn net.Conn
    28  	)
    29  	group.Append0(func(ctx context.Context) error {
    30  		var serverErr error
    31  		serverConn, serverErr = listener.Accept()
    32  		return serverErr
    33  	})
    34  	group.Append0(func(ctx context.Context) error {
    35  		var clientErr error
    36  		clientConn, clientErr = net.Dial("tcp", listener.Addr().String())
    37  		return clientErr
    38  	})
    39  	err = group.Run()
    40  	require.NoError(t, err)
    41  	listener.Close()
    42  	t.Cleanup(func() {
    43  		serverConn.Close()
    44  		clientConn.Close()
    45  	})
    46  	return serverConn, clientConn
    47  }
    48  
    49  func UDPPipe(t *testing.T) (net.PacketConn, net.PacketConn, M.Socksaddr) {
    50  	serverConn, err := net.ListenPacket("udp", "127.0.0.1:0")
    51  	require.NoError(t, err)
    52  	clientConn, err := net.ListenPacket("udp", "127.0.0.1:0")
    53  	require.NoError(t, err)
    54  	return serverConn, clientConn, M.SocksaddrFromNet(clientConn.LocalAddr())
    55  }
    56  
    57  func Timeout(t *testing.T) context.CancelFunc {
    58  	ctx, cancel := context.WithCancel(context.Background())
    59  	go func() {
    60  		select {
    61  		case <-ctx.Done():
    62  			return
    63  		case <-time.After(5 * time.Second):
    64  			t.Error("timeout")
    65  		}
    66  	}()
    67  	return cancel
    68  }
    69  
    70  type hashPair struct {
    71  	sendHash map[int][]byte
    72  	recvHash map[int][]byte
    73  }
    74  
    75  func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) {
    76  	pingCh := make(chan hashPair)
    77  	pongCh := make(chan hashPair)
    78  	test := func(t *testing.T) error {
    79  		defer close(pingCh)
    80  		defer close(pongCh)
    81  		pingOpen := false
    82  		pongOpen := false
    83  		var serverPair hashPair
    84  		var clientPair hashPair
    85  
    86  		for {
    87  			if pingOpen && pongOpen {
    88  				break
    89  			}
    90  
    91  			select {
    92  			case serverPair, pingOpen = <-pingCh:
    93  				assert.True(t, pingOpen)
    94  			case clientPair, pongOpen = <-pongCh:
    95  				assert.True(t, pongOpen)
    96  			case <-time.After(10 * time.Second):
    97  				return errors.New("timeout")
    98  			}
    99  		}
   100  
   101  		assert.Equal(t, serverPair.recvHash, clientPair.sendHash)
   102  		assert.Equal(t, serverPair.sendHash, clientPair.recvHash)
   103  
   104  		return nil
   105  	}
   106  
   107  	return pingCh, pongCh, test
   108  }
   109  
   110  func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error {
   111  	times := 100
   112  	chunkSize := int64(64 * 1024)
   113  
   114  	pingCh, pongCh, test := newLargeDataPair()
   115  	writeRandData := func(conn net.Conn) (map[int][]byte, error) {
   116  		buf := make([]byte, chunkSize)
   117  		hashMap := map[int][]byte{}
   118  		for i := 0; i < times; i++ {
   119  			if _, err := rand.Read(buf[1:]); err != nil {
   120  				return nil, err
   121  			}
   122  			buf[0] = byte(i)
   123  
   124  			hash := md5.Sum(buf)
   125  			hashMap[i] = hash[:]
   126  
   127  			if _, err := conn.Write(buf); err != nil {
   128  				return nil, err
   129  			}
   130  		}
   131  
   132  		return hashMap, nil
   133  	}
   134  	go func() {
   135  		hashMap := map[int][]byte{}
   136  		buf := make([]byte, chunkSize)
   137  
   138  		for i := 0; i < times; i++ {
   139  			_, err := io.ReadFull(outputConn, buf)
   140  			if err != nil {
   141  				t.Log(err.Error())
   142  				return
   143  			}
   144  
   145  			hash := md5.Sum(buf)
   146  			hashMap[int(buf[0])] = hash[:]
   147  		}
   148  
   149  		sendHash, err := writeRandData(outputConn)
   150  		if err != nil {
   151  			t.Log(err.Error())
   152  			return
   153  		}
   154  
   155  		pingCh <- hashPair{
   156  			sendHash: sendHash,
   157  			recvHash: hashMap,
   158  		}
   159  	}()
   160  
   161  	go func() {
   162  		sendHash, err := writeRandData(inputConn)
   163  		if err != nil {
   164  			t.Log(err.Error())
   165  			return
   166  		}
   167  
   168  		hashMap := map[int][]byte{}
   169  		buf := make([]byte, chunkSize)
   170  
   171  		for i := 0; i < times; i++ {
   172  			_, err = io.ReadFull(inputConn, buf)
   173  			if err != nil {
   174  				t.Log(err.Error())
   175  				return
   176  			}
   177  
   178  			hash := md5.Sum(buf)
   179  			hashMap[int(buf[0])] = hash[:]
   180  		}
   181  
   182  		pongCh <- hashPair{
   183  			sendHash: sendHash,
   184  			recvHash: hashMap,
   185  		}
   186  	}()
   187  	return test(t)
   188  }
   189  
   190  func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error {
   191  	rAddr := outputAddr.UDPAddr()
   192  	times := 50
   193  	chunkSize := 9000
   194  	pingCh, pongCh, test := newLargeDataPair()
   195  	writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) {
   196  		hashMap := map[int][]byte{}
   197  		mux := sync.Mutex{}
   198  		for i := 0; i < times; i++ {
   199  			buf := make([]byte, chunkSize)
   200  			if _, err := rand.Read(buf[1:]); err != nil {
   201  				t.Log(err.Error())
   202  				continue
   203  			}
   204  			buf[0] = byte(i)
   205  
   206  			hash := md5.Sum(buf)
   207  			mux.Lock()
   208  			hashMap[i] = hash[:]
   209  			mux.Unlock()
   210  
   211  			if _, err := pc.WriteTo(buf, addr); err != nil {
   212  				t.Log(err.Error())
   213  			}
   214  
   215  			time.Sleep(10 * time.Millisecond)
   216  		}
   217  
   218  		return hashMap, nil
   219  	}
   220  	go func() {
   221  		var (
   222  			lAddr net.Addr
   223  			err   error
   224  		)
   225  		hashMap := map[int][]byte{}
   226  		buf := make([]byte, 64*1024)
   227  
   228  		for i := 0; i < times; i++ {
   229  			_, lAddr, err = outputConn.ReadFrom(buf)
   230  			if err != nil {
   231  				t.Log(err.Error())
   232  				return
   233  			}
   234  			hash := md5.Sum(buf[:chunkSize])
   235  			hashMap[int(buf[0])] = hash[:]
   236  		}
   237  		sendHash, err := writeRandData(outputConn, lAddr)
   238  		if err != nil {
   239  			t.Log(err.Error())
   240  			return
   241  		}
   242  
   243  		pingCh <- hashPair{
   244  			sendHash: sendHash,
   245  			recvHash: hashMap,
   246  		}
   247  	}()
   248  
   249  	go func() {
   250  		sendHash, err := writeRandData(inputConn, rAddr)
   251  		if err != nil {
   252  			t.Log(err.Error())
   253  			return
   254  		}
   255  
   256  		hashMap := map[int][]byte{}
   257  		buf := make([]byte, 64*1024)
   258  
   259  		for i := 0; i < times; i++ {
   260  			_, _, err := inputConn.ReadFrom(buf)
   261  			if err != nil {
   262  				t.Log(err.Error())
   263  				return
   264  			}
   265  
   266  			hash := md5.Sum(buf[:chunkSize])
   267  			hashMap[int(buf[0])] = hash[:]
   268  		}
   269  
   270  		pongCh <- hashPair{
   271  			sendHash: sendHash,
   272  			recvHash: hashMap,
   273  		}
   274  	}()
   275  
   276  	return test(t)
   277  }