github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/copy_direct.go (about)

     1  package bufio
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"syscall"
     7  
     8  	"github.com/sagernet/sing/common/buf"
     9  	M "github.com/sagernet/sing/common/metadata"
    10  	N "github.com/sagernet/sing/common/network"
    11  )
    12  
    13  func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
    14  	rawSource, err := source.SyscallConn()
    15  	if err != nil {
    16  		return
    17  	}
    18  	rawDestination, err := destination.SyscallConn()
    19  	if err != nil {
    20  		return
    21  	}
    22  	handed, n, err = splice(rawSource, rawDestination, readCounters, writeCounters)
    23  	return
    24  }
    25  
    26  func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
    27  	handled = true
    28  	var (
    29  		buffer       *buf.Buffer
    30  		notFirstTime bool
    31  	)
    32  	for {
    33  		buffer, err = source.WaitReadBuffer()
    34  		if err != nil {
    35  			if errors.Is(err, io.EOF) {
    36  				err = nil
    37  				return
    38  			}
    39  			return
    40  		}
    41  		dataLen := buffer.Len()
    42  		err = destination.WriteBuffer(buffer)
    43  		if err != nil {
    44  			buffer.Leak()
    45  			if !notFirstTime {
    46  				err = N.ReportHandshakeFailure(originSource, err)
    47  			}
    48  			return
    49  		}
    50  		n += int64(dataLen)
    51  		for _, counter := range readCounters {
    52  			counter(int64(dataLen))
    53  		}
    54  		for _, counter := range writeCounters {
    55  			counter(int64(dataLen))
    56  		}
    57  		notFirstTime = true
    58  	}
    59  }
    60  
    61  func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) {
    62  	handled = true
    63  	var (
    64  		buffer      *buf.Buffer
    65  		destination M.Socksaddr
    66  	)
    67  	for {
    68  		buffer, destination, err = source.WaitReadPacket()
    69  		if err != nil {
    70  			return
    71  		}
    72  		dataLen := buffer.Len()
    73  		err = destinationConn.WritePacket(buffer, destination)
    74  		if err != nil {
    75  			buffer.Leak()
    76  			if !notFirstTime {
    77  				err = N.ReportHandshakeFailure(originSource, err)
    78  			}
    79  			return
    80  		}
    81  		n += int64(dataLen)
    82  		for _, counter := range readCounters {
    83  			counter(int64(dataLen))
    84  		}
    85  		for _, counter := range writeCounters {
    86  			counter(int64(dataLen))
    87  		}
    88  		notFirstTime = true
    89  	}
    90  }