github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/copy.go (about) 1 package bufio 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "net" 8 "syscall" 9 10 "github.com/sagernet/sing/common" 11 "github.com/sagernet/sing/common/buf" 12 E "github.com/sagernet/sing/common/exceptions" 13 M "github.com/sagernet/sing/common/metadata" 14 N "github.com/sagernet/sing/common/network" 15 "github.com/sagernet/sing/common/rw" 16 "github.com/sagernet/sing/common/task" 17 ) 18 19 func Copy(destination io.Writer, source io.Reader) (n int64, err error) { 20 if source == nil { 21 return 0, E.New("nil reader") 22 } else if destination == nil { 23 return 0, E.New("nil writer") 24 } 25 originSource := source 26 var readCounters, writeCounters []N.CountFunc 27 for { 28 source, readCounters = N.UnwrapCountReader(source, readCounters) 29 destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters) 30 if cachedSrc, isCached := source.(N.CachedReader); isCached { 31 cachedBuffer := cachedSrc.ReadCached() 32 if cachedBuffer != nil { 33 if !cachedBuffer.IsEmpty() { 34 _, err = destination.Write(cachedBuffer.Bytes()) 35 if err != nil { 36 cachedBuffer.Release() 37 return 38 } 39 } 40 cachedBuffer.Release() 41 continue 42 } 43 } 44 srcSyscallConn, srcIsSyscall := source.(syscall.Conn) 45 dstSyscallConn, dstIsSyscall := destination.(syscall.Conn) 46 if srcIsSyscall && dstIsSyscall { 47 var handled bool 48 handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) 49 if handled { 50 return 51 } 52 } 53 break 54 } 55 return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) 56 } 57 58 func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { 59 frontHeadroom := N.CalculateFrontHeadroom(destination) 60 rearHeadroom := N.CalculateRearHeadroom(destination) 61 readWaiter, isReadWaiter := CreateReadWaiter(source) 62 if isReadWaiter { 63 needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{ 64 FrontHeadroom: frontHeadroom, 65 RearHeadroom: rearHeadroom, 66 MTU: N.CalculateMTU(source, destination), 67 }) 68 if !needCopy || common.LowMemory { 69 var handled bool 70 handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters) 71 if handled { 72 return 73 } 74 } 75 } 76 return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters) 77 } 78 79 func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { 80 buffer.IncRef() 81 defer buffer.DecRef() 82 frontHeadroom := N.CalculateFrontHeadroom(destination) 83 rearHeadroom := N.CalculateRearHeadroom(destination) 84 buffer.Resize(frontHeadroom, 0) 85 buffer.Reserve(rearHeadroom) 86 var notFirstTime bool 87 for { 88 err = source.ReadBuffer(buffer) 89 if err != nil { 90 if errors.Is(err, io.EOF) { 91 err = nil 92 return 93 } 94 return 95 } 96 dataLen := buffer.Len() 97 buffer.OverCap(rearHeadroom) 98 err = destination.WriteBuffer(buffer) 99 if err != nil { 100 if !notFirstTime { 101 err = N.ReportHandshakeFailure(originSource, err) 102 } 103 return 104 } 105 n += int64(dataLen) 106 for _, counter := range readCounters { 107 counter(int64(dataLen)) 108 } 109 for _, counter := range writeCounters { 110 counter(int64(dataLen)) 111 } 112 notFirstTime = true 113 } 114 } 115 116 func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { 117 frontHeadroom := N.CalculateFrontHeadroom(destination) 118 rearHeadroom := N.CalculateRearHeadroom(destination) 119 bufferSize := N.CalculateMTU(source, destination) 120 if bufferSize > 0 { 121 bufferSize += frontHeadroom + rearHeadroom 122 } else { 123 bufferSize = buf.BufferSize 124 } 125 var notFirstTime bool 126 for { 127 buffer := buf.NewSize(bufferSize) 128 buffer.Resize(frontHeadroom, 0) 129 buffer.Reserve(rearHeadroom) 130 err = source.ReadBuffer(buffer) 131 if err != nil { 132 buffer.Release() 133 if errors.Is(err, io.EOF) { 134 err = nil 135 return 136 } 137 return 138 } 139 dataLen := buffer.Len() 140 buffer.OverCap(rearHeadroom) 141 err = destination.WriteBuffer(buffer) 142 if err != nil { 143 buffer.Leak() 144 if !notFirstTime { 145 err = N.ReportHandshakeFailure(originSource, err) 146 } 147 return 148 } 149 n += int64(dataLen) 150 for _, counter := range readCounters { 151 counter(int64(dataLen)) 152 } 153 for _, counter := range writeCounters { 154 counter(int64(dataLen)) 155 } 156 notFirstTime = true 157 } 158 } 159 160 func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error { 161 return CopyConnContextList([]context.Context{ctx}, source, destination) 162 } 163 164 func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error { 165 var group task.Group 166 if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex { 167 group.Append("upload", func(ctx context.Context) error { 168 err := common.Error(Copy(destination, source)) 169 if err == nil { 170 rw.CloseWrite(destination) 171 } else { 172 common.Close(destination) 173 } 174 return err 175 }) 176 } else { 177 group.Append("upload", func(ctx context.Context) error { 178 defer common.Close(destination) 179 return common.Error(Copy(destination, source)) 180 }) 181 } 182 if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex { 183 group.Append("download", func(ctx context.Context) error { 184 err := common.Error(Copy(source, destination)) 185 if err == nil { 186 rw.CloseWrite(source) 187 } else { 188 common.Close(source) 189 } 190 return err 191 }) 192 } else { 193 group.Append("download", func(ctx context.Context) error { 194 defer common.Close(source) 195 return common.Error(Copy(source, destination)) 196 }) 197 } 198 group.Cleanup(func() { 199 common.Close(source, destination) 200 }) 201 return group.RunContextList(contextList) 202 } 203 204 func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { 205 var readCounters, writeCounters []N.CountFunc 206 var cachedPackets []*N.PacketBuffer 207 originSource := source 208 for { 209 source, readCounters = N.UnwrapCountPacketReader(source, readCounters) 210 destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters) 211 if cachedReader, isCached := source.(N.CachedPacketReader); isCached { 212 packet := cachedReader.ReadCachedPacket() 213 if packet != nil { 214 cachedPackets = append(cachedPackets, packet) 215 continue 216 } 217 } 218 break 219 } 220 if cachedPackets != nil { 221 n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets) 222 if err != nil { 223 return 224 } 225 } 226 frontHeadroom := N.CalculateFrontHeadroom(destinationConn) 227 rearHeadroom := N.CalculateRearHeadroom(destinationConn) 228 var ( 229 handled bool 230 copeN int64 231 ) 232 readWaiter, isReadWaiter := CreatePacketReadWaiter(source) 233 if isReadWaiter { 234 needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{ 235 FrontHeadroom: frontHeadroom, 236 RearHeadroom: rearHeadroom, 237 MTU: N.CalculateMTU(source, destinationConn), 238 }) 239 if !needCopy || common.LowMemory { 240 handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0) 241 if handled { 242 n += copeN 243 return 244 } 245 } 246 } 247 copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0) 248 n += copeN 249 return 250 } 251 252 func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { 253 frontHeadroom := N.CalculateFrontHeadroom(destinationConn) 254 rearHeadroom := N.CalculateRearHeadroom(destinationConn) 255 bufferSize := N.CalculateMTU(source, destinationConn) 256 if bufferSize > 0 { 257 bufferSize += frontHeadroom + rearHeadroom 258 } else { 259 bufferSize = buf.UDPBufferSize 260 } 261 var destination M.Socksaddr 262 for { 263 buffer := buf.NewSize(bufferSize) 264 buffer.Resize(frontHeadroom, 0) 265 buffer.Reserve(rearHeadroom) 266 destination, err = source.ReadPacket(buffer) 267 if err != nil { 268 buffer.Release() 269 return 270 } 271 dataLen := buffer.Len() 272 buffer.OverCap(rearHeadroom) 273 err = destinationConn.WritePacket(buffer, destination) 274 if err != nil { 275 buffer.Leak() 276 if !notFirstTime { 277 err = N.ReportHandshakeFailure(originSource, err) 278 } 279 return 280 } 281 n += int64(dataLen) 282 for _, counter := range readCounters { 283 counter(int64(dataLen)) 284 } 285 for _, counter := range writeCounters { 286 counter(int64(dataLen)) 287 } 288 notFirstTime = true 289 } 290 } 291 292 func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { 293 frontHeadroom := N.CalculateFrontHeadroom(destinationConn) 294 rearHeadroom := N.CalculateRearHeadroom(destinationConn) 295 var notFirstTime bool 296 for _, packetBuffer := range packetBuffers { 297 buffer := buf.NewPacket() 298 buffer.Resize(frontHeadroom, 0) 299 buffer.Reserve(rearHeadroom) 300 _, err = buffer.Write(packetBuffer.Buffer.Bytes()) 301 packetBuffer.Buffer.Release() 302 if err != nil { 303 buffer.Release() 304 continue 305 } 306 dataLen := buffer.Len() 307 buffer.OverCap(rearHeadroom) 308 err = destinationConn.WritePacket(buffer, packetBuffer.Destination) 309 if err != nil { 310 buffer.Leak() 311 if !notFirstTime { 312 err = N.ReportHandshakeFailure(originSource, err) 313 } 314 return 315 } 316 n += int64(dataLen) 317 } 318 return 319 } 320 321 func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error { 322 return CopyPacketConnContextList([]context.Context{ctx}, source, destination) 323 } 324 325 func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error { 326 var group task.Group 327 group.Append("upload", func(ctx context.Context) error { 328 return common.Error(CopyPacket(destination, source)) 329 }) 330 group.Append("download", func(ctx context.Context) error { 331 return common.Error(CopyPacket(source, destination)) 332 }) 333 group.Cleanup(func() { 334 common.Close(source, destination) 335 }) 336 group.FastFail() 337 return group.RunContextList(contextList) 338 }