github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/copy_direct_test.go (about)

     1  package bufio
     2  
     3  import (
     4  	"net"
     5  	"testing"
     6  
     7  	"github.com/sagernet/sing/common/buf"
     8  	N "github.com/sagernet/sing/common/network"
     9  
    10  	"github.com/stretchr/testify/require"
    11  )
    12  
    13  func TestCopyWaitTCP(t *testing.T) {
    14  	t.Parallel()
    15  	inputConn, outputConn := TCPPipe(t)
    16  	readWaiter, created := CreateReadWaiter(outputConn)
    17  	require.True(t, created)
    18  	require.NotNil(t, readWaiter)
    19  	readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
    20  	require.NoError(t, TCPTest(t, inputConn, &readWaitWrapper{
    21  		Conn:       outputConn,
    22  		readWaiter: readWaiter,
    23  	}))
    24  }
    25  
    26  type readWaitWrapper struct {
    27  	net.Conn
    28  	readWaiter N.ReadWaiter
    29  	buffer     *buf.Buffer
    30  }
    31  
    32  func (r *readWaitWrapper) Read(p []byte) (n int, err error) {
    33  	if r.buffer != nil {
    34  		if r.buffer.Len() > 0 {
    35  			return r.buffer.Read(p)
    36  		}
    37  		if r.buffer.IsEmpty() {
    38  			r.buffer.Release()
    39  			r.buffer = nil
    40  		}
    41  	}
    42  	buffer, err := r.readWaiter.WaitReadBuffer()
    43  	if err != nil {
    44  		return
    45  	}
    46  	r.buffer = buffer
    47  	return r.buffer.Read(p)
    48  }
    49  
    50  func TestCopyWaitUDP(t *testing.T) {
    51  	t.Parallel()
    52  	inputConn, outputConn, outputAddr := UDPPipe(t)
    53  	readWaiter, created := CreatePacketReadWaiter(NewPacketConn(outputConn))
    54  	require.True(t, created)
    55  	require.NotNil(t, readWaiter)
    56  	readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
    57  	require.NoError(t, UDPTest(t, inputConn, &packetReadWaitWrapper{
    58  		PacketConn: outputConn,
    59  		readWaiter: readWaiter,
    60  	}, outputAddr))
    61  }
    62  
    63  type packetReadWaitWrapper struct {
    64  	net.PacketConn
    65  	readWaiter N.PacketReadWaiter
    66  }
    67  
    68  func (r *packetReadWaitWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
    69  	buffer, destination, err := r.readWaiter.WaitReadPacket()
    70  	if err != nil {
    71  		return
    72  	}
    73  	n = copy(p, buffer.Bytes())
    74  	buffer.Release()
    75  	addr = destination.UDPAddr()
    76  	return
    77  }