github.com/sagernet/sing@v0.2.6/common/bufio/copy_direct_posix.go (about)

     1  //go:build !windows
     2  
     3  package bufio
     4  
     5  import (
     6  	"errors"
     7  	"io"
     8  	"net/netip"
     9  	"os"
    10  	"syscall"
    11  
    12  	"github.com/sagernet/sing/common/buf"
    13  	E "github.com/sagernet/sing/common/exceptions"
    14  	M "github.com/sagernet/sing/common/metadata"
    15  	N "github.com/sagernet/sing/common/network"
    16  )
    17  
    18  func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
    19  	handled = true
    20  	frontHeadroom := N.CalculateFrontHeadroom(destination)
    21  	rearHeadroom := N.CalculateRearHeadroom(destination)
    22  	bufferSize := N.CalculateMTU(source, destination)
    23  	if bufferSize > 0 {
    24  		bufferSize += frontHeadroom + rearHeadroom
    25  	} else {
    26  		bufferSize = buf.BufferSize
    27  	}
    28  	var (
    29  		buffer       *buf.Buffer
    30  		readBuffer   *buf.Buffer
    31  		notFirstTime bool
    32  	)
    33  	source.InitializeReadWaiter(func() *buf.Buffer {
    34  		buffer = buf.NewSize(bufferSize)
    35  		readBufferRaw := buffer.Slice()
    36  		readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
    37  		readBuffer.Resize(frontHeadroom, 0)
    38  		return readBuffer
    39  	})
    40  	defer source.InitializeReadWaiter(nil)
    41  	for {
    42  		err = source.WaitReadBuffer()
    43  		if err != nil {
    44  			if errors.Is(err, io.EOF) {
    45  				err = nil
    46  				return
    47  			}
    48  			if !notFirstTime {
    49  				err = N.HandshakeFailure(originDestination, err)
    50  			}
    51  			return
    52  		}
    53  		dataLen := readBuffer.Len()
    54  		buffer.Resize(readBuffer.Start(), dataLen)
    55  		err = destination.WriteBuffer(buffer)
    56  		if err != nil {
    57  			buffer.Release()
    58  			return
    59  		}
    60  		n += int64(dataLen)
    61  		for _, counter := range readCounters {
    62  			counter(int64(dataLen))
    63  		}
    64  		for _, counter := range writeCounters {
    65  			counter(int64(dataLen))
    66  		}
    67  		notFirstTime = true
    68  	}
    69  }
    70  
    71  func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
    72  	handled = true
    73  	frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
    74  	rearHeadroom := N.CalculateRearHeadroom(destinationConn)
    75  	bufferSize := N.CalculateMTU(source, destinationConn)
    76  	if bufferSize > 0 {
    77  		bufferSize += frontHeadroom + rearHeadroom
    78  	} else {
    79  		bufferSize = buf.UDPBufferSize
    80  	}
    81  	var (
    82  		buffer       *buf.Buffer
    83  		readBuffer   *buf.Buffer
    84  		destination  M.Socksaddr
    85  		notFirstTime bool
    86  	)
    87  	source.InitializeReadWaiter(func() *buf.Buffer {
    88  		buffer = buf.NewSize(bufferSize)
    89  		readBufferRaw := buffer.Slice()
    90  		readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
    91  		readBuffer.Resize(frontHeadroom, 0)
    92  		return readBuffer
    93  	})
    94  	defer source.InitializeReadWaiter(nil)
    95  	for {
    96  		destination, err = source.WaitReadPacket()
    97  		if err != nil {
    98  			if !notFirstTime {
    99  				err = N.HandshakeFailure(destinationConn, err)
   100  			}
   101  			return
   102  		}
   103  		dataLen := readBuffer.Len()
   104  		buffer.Resize(readBuffer.Start(), dataLen)
   105  		err = destinationConn.WritePacket(buffer, destination)
   106  		if err != nil {
   107  			buffer.Release()
   108  			return
   109  		}
   110  		n += int64(dataLen)
   111  		for _, counter := range readCounters {
   112  			counter(int64(dataLen))
   113  		}
   114  		for _, counter := range writeCounters {
   115  			counter(int64(dataLen))
   116  		}
   117  		notFirstTime = true
   118  	}
   119  }
   120  
   121  var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
   122  
   123  type syscallReadWaiter struct {
   124  	rawConn  syscall.RawConn
   125  	readErr  error
   126  	readFunc func(fd uintptr) (done bool)
   127  }
   128  
   129  func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
   130  	if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
   131  		rawConn, err := syscallConn.SyscallConn()
   132  		if err == nil {
   133  			return &syscallReadWaiter{rawConn: rawConn}, true
   134  		}
   135  	}
   136  	return nil, false
   137  }
   138  
   139  func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
   140  	w.readErr = nil
   141  	if newBuffer == nil {
   142  		w.readFunc = nil
   143  	} else {
   144  		w.readFunc = func(fd uintptr) (done bool) {
   145  			buffer := newBuffer()
   146  			var readN int
   147  			readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
   148  			if readN > 0 {
   149  				buffer.Truncate(readN)
   150  			} else {
   151  				buffer.Release()
   152  				buffer = nil
   153  			}
   154  			if w.readErr == syscall.EAGAIN {
   155  				return false
   156  			}
   157  			if readN == 0 {
   158  				w.readErr = io.EOF
   159  			}
   160  			return true
   161  		}
   162  	}
   163  }
   164  
   165  func (w *syscallReadWaiter) WaitReadBuffer() error {
   166  	if w.readFunc == nil {
   167  		return os.ErrInvalid
   168  	}
   169  	err := w.rawConn.Read(w.readFunc)
   170  	if err != nil {
   171  		return err
   172  	}
   173  	if w.readErr != nil {
   174  		if w.readErr == io.EOF {
   175  			return io.EOF
   176  		}
   177  		return E.Cause(w.readErr, "raw read")
   178  	}
   179  	return nil
   180  }
   181  
   182  var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
   183  
   184  type syscallPacketReadWaiter struct {
   185  	rawConn  syscall.RawConn
   186  	readErr  error
   187  	readFrom M.Socksaddr
   188  	readFunc func(fd uintptr) (done bool)
   189  }
   190  
   191  func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
   192  	if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
   193  		rawConn, err := syscallConn.SyscallConn()
   194  		if err == nil {
   195  			return &syscallPacketReadWaiter{rawConn: rawConn}, true
   196  		}
   197  	}
   198  	return nil, false
   199  }
   200  
   201  func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
   202  	w.readErr = nil
   203  	w.readFrom = M.Socksaddr{}
   204  	if newBuffer == nil {
   205  		w.readFunc = nil
   206  	} else {
   207  		w.readFunc = func(fd uintptr) (done bool) {
   208  			buffer := newBuffer()
   209  			var readN int
   210  			var from syscall.Sockaddr
   211  			readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
   212  			if readN > 0 {
   213  				buffer.Truncate(readN)
   214  			} else {
   215  				buffer.Release()
   216  				buffer = nil
   217  			}
   218  			if w.readErr == syscall.EAGAIN {
   219  				return false
   220  			}
   221  			if from != nil {
   222  				switch fromAddr := from.(type) {
   223  				case *syscall.SockaddrInet4:
   224  					w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
   225  				case *syscall.SockaddrInet6:
   226  					w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port))
   227  				}
   228  			}
   229  			return true
   230  		}
   231  	}
   232  }
   233  
   234  func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
   235  	if w.readFunc == nil {
   236  		return M.Socksaddr{}, os.ErrInvalid
   237  	}
   238  	err = w.rawConn.Read(w.readFunc)
   239  	if err != nil {
   240  		return
   241  	}
   242  	if w.readErr != nil {
   243  		err = E.Cause(w.readErr, "raw read")
   244  		return
   245  	}
   246  	destination = w.readFrom
   247  	return
   248  }