github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/net_test.go (about) 1 package bufio 2 3 import ( 4 "context" 5 "crypto/md5" 6 "crypto/rand" 7 "errors" 8 "io" 9 "net" 10 "sync" 11 "testing" 12 "time" 13 14 M "github.com/sagernet/sing/common/metadata" 15 "github.com/sagernet/sing/common/task" 16 17 "github.com/stretchr/testify/assert" 18 "github.com/stretchr/testify/require" 19 ) 20 21 func TCPPipe(t *testing.T) (net.Conn, net.Conn) { 22 listener, err := net.Listen("tcp", "127.0.0.1:0") 23 require.NoError(t, err) 24 var ( 25 group task.Group 26 serverConn net.Conn 27 clientConn net.Conn 28 ) 29 group.Append0(func(ctx context.Context) error { 30 var serverErr error 31 serverConn, serverErr = listener.Accept() 32 return serverErr 33 }) 34 group.Append0(func(ctx context.Context) error { 35 var clientErr error 36 clientConn, clientErr = net.Dial("tcp", listener.Addr().String()) 37 return clientErr 38 }) 39 err = group.Run() 40 require.NoError(t, err) 41 listener.Close() 42 t.Cleanup(func() { 43 serverConn.Close() 44 clientConn.Close() 45 }) 46 return serverConn, clientConn 47 } 48 49 func UDPPipe(t *testing.T) (net.PacketConn, net.PacketConn, M.Socksaddr) { 50 serverConn, err := net.ListenPacket("udp", "127.0.0.1:0") 51 require.NoError(t, err) 52 clientConn, err := net.ListenPacket("udp", "127.0.0.1:0") 53 require.NoError(t, err) 54 return serverConn, clientConn, M.SocksaddrFromNet(clientConn.LocalAddr()) 55 } 56 57 func Timeout(t *testing.T) context.CancelFunc { 58 ctx, cancel := context.WithCancel(context.Background()) 59 go func() { 60 select { 61 case <-ctx.Done(): 62 return 63 case <-time.After(5 * time.Second): 64 t.Error("timeout") 65 } 66 }() 67 return cancel 68 } 69 70 type hashPair struct { 71 sendHash map[int][]byte 72 recvHash map[int][]byte 73 } 74 75 func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) { 76 pingCh := make(chan hashPair) 77 pongCh := make(chan hashPair) 78 test := func(t *testing.T) error { 79 defer close(pingCh) 80 defer close(pongCh) 81 pingOpen := false 82 pongOpen := false 83 var serverPair hashPair 84 var clientPair hashPair 85 86 for { 87 if pingOpen && pongOpen { 88 break 89 } 90 91 select { 92 case serverPair, pingOpen = <-pingCh: 93 assert.True(t, pingOpen) 94 case clientPair, pongOpen = <-pongCh: 95 assert.True(t, pongOpen) 96 case <-time.After(10 * time.Second): 97 return errors.New("timeout") 98 } 99 } 100 101 assert.Equal(t, serverPair.recvHash, clientPair.sendHash) 102 assert.Equal(t, serverPair.sendHash, clientPair.recvHash) 103 104 return nil 105 } 106 107 return pingCh, pongCh, test 108 } 109 110 func TCPTest(t *testing.T, inputConn net.Conn, outputConn net.Conn) error { 111 times := 100 112 chunkSize := int64(64 * 1024) 113 114 pingCh, pongCh, test := newLargeDataPair() 115 writeRandData := func(conn net.Conn) (map[int][]byte, error) { 116 buf := make([]byte, chunkSize) 117 hashMap := map[int][]byte{} 118 for i := 0; i < times; i++ { 119 if _, err := rand.Read(buf[1:]); err != nil { 120 return nil, err 121 } 122 buf[0] = byte(i) 123 124 hash := md5.Sum(buf) 125 hashMap[i] = hash[:] 126 127 if _, err := conn.Write(buf); err != nil { 128 return nil, err 129 } 130 } 131 132 return hashMap, nil 133 } 134 go func() { 135 hashMap := map[int][]byte{} 136 buf := make([]byte, chunkSize) 137 138 for i := 0; i < times; i++ { 139 _, err := io.ReadFull(outputConn, buf) 140 if err != nil { 141 t.Log(err.Error()) 142 return 143 } 144 145 hash := md5.Sum(buf) 146 hashMap[int(buf[0])] = hash[:] 147 } 148 149 sendHash, err := writeRandData(outputConn) 150 if err != nil { 151 t.Log(err.Error()) 152 return 153 } 154 155 pingCh <- hashPair{ 156 sendHash: sendHash, 157 recvHash: hashMap, 158 } 159 }() 160 161 go func() { 162 sendHash, err := writeRandData(inputConn) 163 if err != nil { 164 t.Log(err.Error()) 165 return 166 } 167 168 hashMap := map[int][]byte{} 169 buf := make([]byte, chunkSize) 170 171 for i := 0; i < times; i++ { 172 _, err = io.ReadFull(inputConn, buf) 173 if err != nil { 174 t.Log(err.Error()) 175 return 176 } 177 178 hash := md5.Sum(buf) 179 hashMap[int(buf[0])] = hash[:] 180 } 181 182 pongCh <- hashPair{ 183 sendHash: sendHash, 184 recvHash: hashMap, 185 } 186 }() 187 return test(t) 188 } 189 190 func UDPTest(t *testing.T, inputConn net.PacketConn, outputConn net.PacketConn, outputAddr M.Socksaddr) error { 191 rAddr := outputAddr.UDPAddr() 192 times := 50 193 chunkSize := 9000 194 pingCh, pongCh, test := newLargeDataPair() 195 writeRandData := func(pc net.PacketConn, addr net.Addr) (map[int][]byte, error) { 196 hashMap := map[int][]byte{} 197 mux := sync.Mutex{} 198 for i := 0; i < times; i++ { 199 buf := make([]byte, chunkSize) 200 if _, err := rand.Read(buf[1:]); err != nil { 201 t.Log(err.Error()) 202 continue 203 } 204 buf[0] = byte(i) 205 206 hash := md5.Sum(buf) 207 mux.Lock() 208 hashMap[i] = hash[:] 209 mux.Unlock() 210 211 if _, err := pc.WriteTo(buf, addr); err != nil { 212 t.Log(err.Error()) 213 } 214 215 time.Sleep(10 * time.Millisecond) 216 } 217 218 return hashMap, nil 219 } 220 go func() { 221 var ( 222 lAddr net.Addr 223 err error 224 ) 225 hashMap := map[int][]byte{} 226 buf := make([]byte, 64*1024) 227 228 for i := 0; i < times; i++ { 229 _, lAddr, err = outputConn.ReadFrom(buf) 230 if err != nil { 231 t.Log(err.Error()) 232 return 233 } 234 hash := md5.Sum(buf[:chunkSize]) 235 hashMap[int(buf[0])] = hash[:] 236 } 237 sendHash, err := writeRandData(outputConn, lAddr) 238 if err != nil { 239 t.Log(err.Error()) 240 return 241 } 242 243 pingCh <- hashPair{ 244 sendHash: sendHash, 245 recvHash: hashMap, 246 } 247 }() 248 249 go func() { 250 sendHash, err := writeRandData(inputConn, rAddr) 251 if err != nil { 252 t.Log(err.Error()) 253 return 254 } 255 256 hashMap := map[int][]byte{} 257 buf := make([]byte, 64*1024) 258 259 for i := 0; i < times; i++ { 260 _, _, err := inputConn.ReadFrom(buf) 261 if err != nil { 262 t.Log(err.Error()) 263 return 264 } 265 266 hash := md5.Sum(buf[:chunkSize]) 267 hashMap[int(buf[0])] = hash[:] 268 } 269 270 pongCh <- hashPair{ 271 sendHash: sendHash, 272 recvHash: hashMap, 273 } 274 }() 275 276 return test(t) 277 }