github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/copy_direct_test.go (about) 1 package bufio 2 3 import ( 4 "net" 5 "testing" 6 7 "github.com/sagernet/sing/common/buf" 8 N "github.com/sagernet/sing/common/network" 9 10 "github.com/stretchr/testify/require" 11 ) 12 13 func TestCopyWaitTCP(t *testing.T) { 14 t.Parallel() 15 inputConn, outputConn := TCPPipe(t) 16 readWaiter, created := CreateReadWaiter(outputConn) 17 require.True(t, created) 18 require.NotNil(t, readWaiter) 19 readWaiter.InitializeReadWaiter(N.ReadWaitOptions{}) 20 require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{ 21 Conn: outputConn, 22 readWaiter: readWaiter, 23 })) 24 } 25 26 type readWaitWrapper struct { 27 net.Conn 28 readWaiter N.ReadWaiter 29 buffer *buf.Buffer 30 } 31 32 func (r *readWaitWrapper) Read(p []byte) (n int, err error) { 33 if r.buffer != nil { 34 if r.buffer.Len() > 0 { 35 return r.buffer.Read(p) 36 } 37 if r.buffer.IsEmpty() { 38 r.buffer.Release() 39 r.buffer = nil 40 } 41 } 42 buffer, err := r.readWaiter.WaitReadBuffer() 43 if err != nil { 44 return 45 } 46 r.buffer = buffer 47 return r.buffer.Read(p) 48 } 49 50 func TestCopyWaitUDP(t *testing.T) { 51 t.Parallel() 52 inputConn, outputConn, outputAddr := UDPPipe(t) 53 readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn)) 54 require.True(t, created) 55 require.NotNil(t, readWaiter) 56 readWaiter.InitializeReadWaiter(N.ReadWaitOptions{}) 57 require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{ 58 PacketConn: outputConn, 59 readWaiter: readWaiter, 60 }, outputAddr)) 61 } 62 63 type packetReadWaitWrapper struct { 64 net.PacketConn 65 readWaiter N.PacketReadWaiter 66 } 67 68 func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 69 buffer, destination, err := r.readWaiter.WaitReadPacket() 70 if err != nil { 71 return 72 } 73 n = copy(p, buffer.Bytes()) 74 buffer.Release() 75 addr = destination.UDPAddr() 76 return 77 }