github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/copy_direct_posix.go (about)

     1  //go:build !windows
     2  
     3  package bufio
     4  
     5  import (
     6  	"io"
     7  	"net/netip"
     8  	"os"
     9  	"syscall"
    10  
    11  	"github.com/sagernet/sing/common/buf"
    12  	E "github.com/sagernet/sing/common/exceptions"
    13  	M "github.com/sagernet/sing/common/metadata"
    14  	N "github.com/sagernet/sing/common/network"
    15  )
    16  
    17  var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
    18  
    19  type syscallReadWaiter struct {
    20  	rawConn  syscall.RawConn
    21  	readErr  error
    22  	readFunc func(fd uintptr) (done bool)
    23  	buffer   *buf.Buffer
    24  	options  N.ReadWaitOptions
    25  }
    26  
    27  func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
    28  	if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
    29  		rawConn, err := syscallConn.SyscallConn()
    30  		if err == nil {
    31  			return &syscallReadWaiter{rawConn: rawConn}, true
    32  		}
    33  	}
    34  	return nil, false
    35  }
    36  
    37  func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
    38  	w.options = options
    39  	w.readFunc = func(fd uintptr) (done bool) {
    40  		buffer := w.options.NewBuffer()
    41  		var readN int
    42  		readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes())
    43  		if readN > 0 {
    44  			buffer.Truncate(readN)
    45  			w.options.PostReturn(buffer)
    46  			w.buffer = buffer
    47  		} else {
    48  			buffer.Release()
    49  		}
    50  		//goland:noinspection GoDirectComparisonOfErrors
    51  		if w.readErr == syscall.EAGAIN {
    52  			return false
    53  		}
    54  		if readN == 0 && w.readErr == nil {
    55  			w.readErr = io.EOF
    56  		}
    57  		return true
    58  	}
    59  	return false
    60  }
    61  
    62  func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
    63  	if w.readFunc == nil {
    64  		return nil, os.ErrInvalid
    65  	}
    66  	err = w.rawConn.Read(w.readFunc)
    67  	if err != nil {
    68  		return
    69  	}
    70  	if w.readErr != nil {
    71  		if w.readErr == io.EOF {
    72  			return nil, io.EOF
    73  		}
    74  		return nil, E.Cause(w.readErr, "raw read")
    75  	}
    76  	buffer = w.buffer
    77  	w.buffer = nil
    78  	return
    79  }
    80  
    81  var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
    82  
    83  type syscallPacketReadWaiter struct {
    84  	rawConn  syscall.RawConn
    85  	readErr  error
    86  	readFrom M.Socksaddr
    87  	readFunc func(fd uintptr) (done bool)
    88  	buffer   *buf.Buffer
    89  	options  N.ReadWaitOptions
    90  }
    91  
    92  func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
    93  	if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
    94  		rawConn, err := syscallConn.SyscallConn()
    95  		if err == nil {
    96  			return &syscallPacketReadWaiter{rawConn: rawConn}, true
    97  		}
    98  	}
    99  	return nil, false
   100  }
   101  
   102  func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
   103  	w.options = options
   104  	w.readFunc = func(fd uintptr) (done bool) {
   105  		buffer := w.options.NewPacketBuffer()
   106  		var readN int
   107  		var from syscall.Sockaddr
   108  		readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0)
   109  		//goland:noinspection GoDirectComparisonOfErrors
   110  		if w.readErr != nil {
   111  			buffer.Release()
   112  			return w.readErr != syscall.EAGAIN
   113  		}
   114  		if readN > 0 {
   115  			buffer.Truncate(readN)
   116  		}
   117  		w.options.PostReturn(buffer)
   118  		w.buffer = buffer
   119  		switch fromAddr := from.(type) {
   120  		case *syscall.SockaddrInet4:
   121  			w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
   122  		case *syscall.SockaddrInet6:
   123  			w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
   124  		}
   125  		return true
   126  	}
   127  	return false
   128  }
   129  
   130  func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
   131  	if w.readFunc == nil {
   132  		return nil, M.Socksaddr{}, os.ErrInvalid
   133  	}
   134  	err = w.rawConn.Read(w.readFunc)
   135  	if err != nil {
   136  		return
   137  	}
   138  	if w.readErr != nil {
   139  		err = E.Cause(w.readErr, "raw read")
   140  		return
   141  	}
   142  	buffer = w.buffer
   143  	w.buffer = nil
   144  	destination = w.readFrom
   145  	return
   146  }