github.com/sagernet/sing@v0.2.6/common/bufio/copy_direct_posix.go (about) 1 //go:build !windows 2 3 package bufio 4 5 import ( 6 "errors" 7 "io" 8 "net/netip" 9 "os" 10 "syscall" 11 12 "github.com/sagernet/sing/common/buf" 13 E "github.com/sagernet/sing/common/exceptions" 14 M "github.com/sagernet/sing/common/metadata" 15 N "github.com/sagernet/sing/common/network" 16 ) 17 18 func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { 19 handled = true 20 frontHeadroom := N.CalculateFrontHeadroom(destination) 21 rearHeadroom := N.CalculateRearHeadroom(destination) 22 bufferSize := N.CalculateMTU(source, destination) 23 if bufferSize > 0 { 24 bufferSize += frontHeadroom + rearHeadroom 25 } else { 26 bufferSize = buf.BufferSize 27 } 28 var ( 29 buffer *buf.Buffer 30 readBuffer *buf.Buffer 31 notFirstTime bool 32 ) 33 source.InitializeReadWaiter(func() *buf.Buffer { 34 buffer = buf.NewSize(bufferSize) 35 readBufferRaw := buffer.Slice() 36 readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) 37 readBuffer.Resize(frontHeadroom, 0) 38 return readBuffer 39 }) 40 defer source.InitializeReadWaiter(nil) 41 for { 42 err = source.WaitReadBuffer() 43 if err != nil { 44 if errors.Is(err, io.EOF) { 45 err = nil 46 return 47 } 48 if !notFirstTime { 49 err = N.HandshakeFailure(originDestination, err) 50 } 51 return 52 } 53 dataLen := readBuffer.Len() 54 buffer.Resize(readBuffer.Start(), dataLen) 55 err = destination.WriteBuffer(buffer) 56 if err != nil { 57 buffer.Release() 58 return 59 } 60 n += int64(dataLen) 61 for _, counter := range readCounters { 62 counter(int64(dataLen)) 63 } 64 for _, counter := range writeCounters { 65 counter(int64(dataLen)) 66 } 67 notFirstTime = true 68 } 69 } 70 71 func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) { 72 handled = true 73 frontHeadroom := N.CalculateFrontHeadroom(destinationConn) 74 rearHeadroom := N.CalculateRearHeadroom(destinationConn) 75 bufferSize := N.CalculateMTU(source, destinationConn) 76 if bufferSize > 0 { 77 bufferSize += frontHeadroom + rearHeadroom 78 } else { 79 bufferSize = buf.UDPBufferSize 80 } 81 var ( 82 buffer *buf.Buffer 83 readBuffer *buf.Buffer 84 destination M.Socksaddr 85 notFirstTime bool 86 ) 87 source.InitializeReadWaiter(func() *buf.Buffer { 88 buffer = buf.NewSize(bufferSize) 89 readBufferRaw := buffer.Slice() 90 readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) 91 readBuffer.Resize(frontHeadroom, 0) 92 return readBuffer 93 }) 94 defer source.InitializeReadWaiter(nil) 95 for { 96 destination, err = source.WaitReadPacket() 97 if err != nil { 98 if !notFirstTime { 99 err = N.HandshakeFailure(destinationConn, err) 100 } 101 return 102 } 103 dataLen := readBuffer.Len() 104 buffer.Resize(readBuffer.Start(), dataLen) 105 err = destinationConn.WritePacket(buffer, destination) 106 if err != nil { 107 buffer.Release() 108 return 109 } 110 n += int64(dataLen) 111 for _, counter := range readCounters { 112 counter(int64(dataLen)) 113 } 114 for _, counter := range writeCounters { 115 counter(int64(dataLen)) 116 } 117 notFirstTime = true 118 } 119 } 120 121 var _ N.ReadWaiter = (*syscallReadWaiter)(nil) 122 123 type syscallReadWaiter struct { 124 rawConn syscall.RawConn 125 readErr error 126 readFunc func(fd uintptr) (done bool) 127 } 128 129 func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) { 130 if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn { 131 rawConn, err := syscallConn.SyscallConn() 132 if err == nil { 133 return &syscallReadWaiter{rawConn: rawConn}, true 134 } 135 } 136 return nil, false 137 } 138 139 func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { 140 w.readErr = nil 141 if newBuffer == nil { 142 w.readFunc = nil 143 } else { 144 w.readFunc = func(fd uintptr) (done bool) { 145 buffer := newBuffer() 146 var readN int 147 readN, w.readErr = syscall.Read(int(fd), buffer.FreeBytes()) 148 if readN > 0 { 149 buffer.Truncate(readN) 150 } else { 151 buffer.Release() 152 buffer = nil 153 } 154 if w.readErr == syscall.EAGAIN { 155 return false 156 } 157 if readN == 0 { 158 w.readErr = io.EOF 159 } 160 return true 161 } 162 } 163 } 164 165 func (w *syscallReadWaiter) WaitReadBuffer() error { 166 if w.readFunc == nil { 167 return os.ErrInvalid 168 } 169 err := w.rawConn.Read(w.readFunc) 170 if err != nil { 171 return err 172 } 173 if w.readErr != nil { 174 if w.readErr == io.EOF { 175 return io.EOF 176 } 177 return E.Cause(w.readErr, "raw read") 178 } 179 return nil 180 } 181 182 var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil) 183 184 type syscallPacketReadWaiter struct { 185 rawConn syscall.RawConn 186 readErr error 187 readFrom M.Socksaddr 188 readFunc func(fd uintptr) (done bool) 189 } 190 191 func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) { 192 if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn { 193 rawConn, err := syscallConn.SyscallConn() 194 if err == nil { 195 return &syscallPacketReadWaiter{rawConn: rawConn}, true 196 } 197 } 198 return nil, false 199 } 200 201 func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { 202 w.readErr = nil 203 w.readFrom = M.Socksaddr{} 204 if newBuffer == nil { 205 w.readFunc = nil 206 } else { 207 w.readFunc = func(fd uintptr) (done bool) { 208 buffer := newBuffer() 209 var readN int 210 var from syscall.Sockaddr 211 readN, _, _, from, w.readErr = syscall.Recvmsg(int(fd), buffer.FreeBytes(), nil, 0) 212 if readN > 0 { 213 buffer.Truncate(readN) 214 } else { 215 buffer.Release() 216 buffer = nil 217 } 218 if w.readErr == syscall.EAGAIN { 219 return false 220 } 221 if from != nil { 222 switch fromAddr := from.(type) { 223 case *syscall.SockaddrInet4: 224 w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port)) 225 case *syscall.SockaddrInet6: 226 w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)) 227 } 228 } 229 return true 230 } 231 } 232 } 233 234 func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) { 235 if w.readFunc == nil { 236 return M.Socksaddr{}, os.ErrInvalid 237 } 238 err = w.rawConn.Read(w.readFunc) 239 if err != nil { 240 return 241 } 242 if w.readErr != nil { 243 err = E.Cause(w.readErr, "raw read") 244 return 245 } 246 destination = w.readFrom 247 return 248 }