github.com/sagernet/sing@v0.2.6/common/bufio/copy.go (about) 1 package bufio 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "net" 8 "syscall" 9 10 "github.com/sagernet/sing/common" 11 "github.com/sagernet/sing/common/buf" 12 E "github.com/sagernet/sing/common/exceptions" 13 M "github.com/sagernet/sing/common/metadata" 14 N "github.com/sagernet/sing/common/network" 15 "github.com/sagernet/sing/common/rw" 16 "github.com/sagernet/sing/common/task" 17 ) 18 19 func Copy(destination io.Writer, source io.Reader) (n int64, err error) { 20 if source == nil { 21 return 0, E.New("nil reader") 22 } else if destination == nil { 23 return 0, E.New("nil writer") 24 } 25 originDestination := destination 26 var readCounters, writeCounters []N.CountFunc 27 for { 28 source, readCounters = N.UnwrapCountReader(source, readCounters) 29 destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters) 30 if cachedSrc, isCached := source.(N.CachedReader); isCached { 31 cachedBuffer := cachedSrc.ReadCached() 32 if cachedBuffer != nil { 33 if !cachedBuffer.IsEmpty() { 34 _, err = destination.Write(cachedBuffer.Bytes()) 35 if err != nil { 36 cachedBuffer.Release() 37 return 38 } 39 } 40 cachedBuffer.Release() 41 continue 42 } 43 } 44 srcSyscallConn, srcIsSyscall := source.(syscall.Conn) 45 dstSyscallConn, dstIsSyscall := destination.(syscall.Conn) 46 if srcIsSyscall && dstIsSyscall { 47 var handled bool 48 handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters) 49 if handled { 50 return 51 } 52 } 53 break 54 } 55 return CopyExtended(originDestination, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters) 56 } 57 58 func CopyExtended(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { 59 safeSrc := N.IsSafeReader(source) 60 headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination) 61 if safeSrc != nil { 62 if headroom == 0 { 63 return CopyExtendedWithSrcBuffer(originDestination, destination, safeSrc, readCounters, writeCounters) 64 } 65 } 66 readWaiter, isReadWaiter := CreateReadWaiter(source) 67 if isReadWaiter { 68 var handled bool 69 handled, n, err = copyWaitWithPool(originDestination, destination, readWaiter, readCounters, writeCounters) 70 if handled { 71 return 72 } 73 } 74 if !common.UnsafeBuffer || N.IsUnsafeWriter(destination) { 75 return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters) 76 } 77 bufferSize := N.CalculateMTU(source, destination) 78 if bufferSize > 0 { 79 bufferSize += headroom 80 } else { 81 bufferSize = buf.BufferSize 82 } 83 _buffer := buf.StackNewSize(bufferSize) 84 defer common.KeepAlive(_buffer) 85 buffer := common.Dup(_buffer) 86 defer buffer.Release() 87 return CopyExtendedBuffer(originDestination, destination, source, buffer, readCounters, writeCounters) 88 } 89 90 func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { 91 buffer.IncRef() 92 defer buffer.DecRef() 93 frontHeadroom := N.CalculateFrontHeadroom(destination) 94 rearHeadroom := N.CalculateRearHeadroom(destination) 95 readBufferRaw := buffer.Slice() 96 readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) 97 var notFirstTime bool 98 for { 99 readBuffer.Resize(frontHeadroom, 0) 100 err = source.ReadBuffer(readBuffer) 101 if err != nil { 102 if errors.Is(err, io.EOF) { 103 err = nil 104 return 105 } 106 if !notFirstTime { 107 err = N.HandshakeFailure(originDestination, err) 108 } 109 return 110 } 111 dataLen := readBuffer.Len() 112 buffer.Resize(readBuffer.Start(), dataLen) 113 err = destination.WriteBuffer(buffer) 114 if err != nil { 115 return 116 } 117 n += int64(dataLen) 118 for _, counter := range readCounters { 119 counter(int64(dataLen)) 120 } 121 for _, counter := range writeCounters { 122 counter(int64(dataLen)) 123 } 124 notFirstTime = true 125 } 126 } 127 128 func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { 129 var notFirstTime bool 130 for { 131 var buffer *buf.Buffer 132 buffer, err = source.ReadBufferThreadSafe() 133 if err != nil { 134 if errors.Is(err, io.EOF) { 135 err = nil 136 return 137 } 138 if !notFirstTime { 139 err = N.HandshakeFailure(originDestination, err) 140 } 141 return 142 } 143 dataLen := buffer.Len() 144 err = destination.WriteBuffer(buffer) 145 if err != nil { 146 buffer.Release() 147 return 148 } 149 n += int64(dataLen) 150 for _, counter := range readCounters { 151 counter(int64(dataLen)) 152 } 153 for _, counter := range writeCounters { 154 counter(int64(dataLen)) 155 } 156 notFirstTime = true 157 } 158 } 159 160 func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { 161 frontHeadroom := N.CalculateFrontHeadroom(destination) 162 rearHeadroom := N.CalculateRearHeadroom(destination) 163 bufferSize := N.CalculateMTU(source, destination) 164 if bufferSize > 0 { 165 bufferSize += frontHeadroom + rearHeadroom 166 } else { 167 bufferSize = buf.BufferSize 168 } 169 var notFirstTime bool 170 for { 171 buffer := buf.NewSize(bufferSize) 172 readBufferRaw := buffer.Slice() 173 readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) 174 readBuffer.Resize(frontHeadroom, 0) 175 err = source.ReadBuffer(readBuffer) 176 if err != nil { 177 buffer.Release() 178 if errors.Is(err, io.EOF) { 179 err = nil 180 return 181 } 182 if !notFirstTime { 183 err = N.HandshakeFailure(originDestination, err) 184 } 185 return 186 } 187 dataLen := readBuffer.Len() 188 buffer.Resize(readBuffer.Start(), dataLen) 189 err = destination.WriteBuffer(buffer) 190 if err != nil { 191 buffer.Release() 192 return 193 } 194 n += int64(dataLen) 195 for _, counter := range readCounters { 196 counter(int64(dataLen)) 197 } 198 for _, counter := range writeCounters { 199 counter(int64(dataLen)) 200 } 201 notFirstTime = true 202 } 203 } 204 205 func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error { 206 return CopyConnContextList([]context.Context{ctx}, source, destination) 207 } 208 209 func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error { 210 var group task.Group 211 if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex { 212 group.Append("upload", func(ctx context.Context) error { 213 err := common.Error(Copy(destination, source)) 214 if err == nil { 215 rw.CloseWrite(destination) 216 } else { 217 common.Close(destination) 218 } 219 return err 220 }) 221 } else { 222 group.Append("upload", func(ctx context.Context) error { 223 defer common.Close(destination) 224 return common.Error(Copy(destination, source)) 225 }) 226 } 227 if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex { 228 group.Append("download", func(ctx context.Context) error { 229 err := common.Error(Copy(source, destination)) 230 if err == nil { 231 rw.CloseWrite(source) 232 } else { 233 common.Close(source) 234 } 235 return err 236 }) 237 } else { 238 group.Append("download", func(ctx context.Context) error { 239 defer common.Close(source) 240 return common.Error(Copy(source, destination)) 241 }) 242 } 243 group.Cleanup(func() { 244 common.Close(source, destination) 245 }) 246 return group.RunContextList(contextList) 247 } 248 249 func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { 250 var readCounters, writeCounters []N.CountFunc 251 var cachedPackets []*N.PacketBuffer 252 for { 253 source, readCounters = N.UnwrapCountPacketReader(source, readCounters) 254 destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters) 255 if cachedReader, isCached := source.(N.CachedPacketReader); isCached { 256 packet := cachedReader.ReadCachedPacket() 257 if packet != nil { 258 cachedPackets = append(cachedPackets, packet) 259 continue 260 } 261 } 262 break 263 } 264 if cachedPackets != nil { 265 n, err = WritePacketWithPool(destinationConn, cachedPackets) 266 if err != nil { 267 return 268 } 269 } 270 safeSrc := N.IsSafePacketReader(source) 271 frontHeadroom := N.CalculateFrontHeadroom(destinationConn) 272 rearHeadroom := N.CalculateRearHeadroom(destinationConn) 273 headroom := frontHeadroom + rearHeadroom 274 if safeSrc != nil { 275 if headroom == 0 { 276 var copyN int64 277 copyN, err = CopyPacketWithSrcBuffer(destinationConn, safeSrc, readCounters, writeCounters) 278 n += copyN 279 return 280 } 281 } 282 readWaiter, isReadWaiter := CreatePacketReadWaiter(source) 283 if isReadWaiter { 284 var ( 285 handled bool 286 copeN int64 287 ) 288 handled, copeN, err = copyPacketWaitWithPool(destinationConn, readWaiter, readCounters, writeCounters) 289 if handled { 290 n += copeN 291 return 292 } 293 } 294 if N.IsUnsafeWriter(destinationConn) { 295 return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters) 296 } 297 bufferSize := N.CalculateMTU(source, destinationConn) 298 if bufferSize > 0 { 299 bufferSize += headroom 300 } else { 301 bufferSize = buf.UDPBufferSize 302 } 303 _buffer := buf.StackNewSize(bufferSize) 304 defer common.KeepAlive(_buffer) 305 buffer := common.Dup(_buffer) 306 defer buffer.Release() 307 buffer.IncRef() 308 defer buffer.DecRef() 309 var destination M.Socksaddr 310 var notFirstTime bool 311 readBufferRaw := buffer.Slice() 312 readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) 313 for { 314 readBuffer.Resize(frontHeadroom, 0) 315 destination, err = source.ReadPacket(readBuffer) 316 if err != nil { 317 if !notFirstTime { 318 err = N.HandshakeFailure(destinationConn, err) 319 } 320 return 321 } 322 dataLen := readBuffer.Len() 323 buffer.Resize(readBuffer.Start(), dataLen) 324 err = destinationConn.WritePacket(buffer, destination) 325 if err != nil { 326 return 327 } 328 n += int64(dataLen) 329 for _, counter := range readCounters { 330 counter(int64(dataLen)) 331 } 332 for _, counter := range writeCounters { 333 counter(int64(dataLen)) 334 } 335 notFirstTime = true 336 } 337 } 338 339 func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { 340 var buffer *buf.Buffer 341 var destination M.Socksaddr 342 var notFirstTime bool 343 for { 344 buffer, destination, err = source.ReadPacketThreadSafe() 345 if err != nil { 346 if !notFirstTime { 347 err = N.HandshakeFailure(destinationConn, err) 348 } 349 return 350 } 351 dataLen := buffer.Len() 352 if dataLen == 0 { 353 continue 354 } 355 err = destinationConn.WritePacket(buffer, destination) 356 if err != nil { 357 buffer.Release() 358 return 359 } 360 n += int64(dataLen) 361 for _, counter := range readCounters { 362 counter(int64(dataLen)) 363 } 364 for _, counter := range writeCounters { 365 counter(int64(dataLen)) 366 } 367 notFirstTime = true 368 } 369 } 370 371 func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { 372 frontHeadroom := N.CalculateFrontHeadroom(destinationConn) 373 rearHeadroom := N.CalculateRearHeadroom(destinationConn) 374 bufferSize := N.CalculateMTU(source, destinationConn) 375 if bufferSize > 0 { 376 bufferSize += frontHeadroom + rearHeadroom 377 } else { 378 bufferSize = buf.UDPBufferSize 379 } 380 var destination M.Socksaddr 381 var notFirstTime bool 382 for { 383 buffer := buf.NewSize(bufferSize) 384 readBufferRaw := buffer.Slice() 385 readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) 386 readBuffer.Resize(frontHeadroom, 0) 387 destination, err = source.ReadPacket(readBuffer) 388 if err != nil { 389 buffer.Release() 390 if !notFirstTime { 391 err = N.HandshakeFailure(destinationConn, err) 392 } 393 return 394 } 395 dataLen := readBuffer.Len() 396 buffer.Resize(readBuffer.Start(), dataLen) 397 err = destinationConn.WritePacket(buffer, destination) 398 if err != nil { 399 buffer.Release() 400 return 401 } 402 n += int64(dataLen) 403 for _, counter := range readCounters { 404 counter(int64(dataLen)) 405 } 406 for _, counter := range writeCounters { 407 counter(int64(dataLen)) 408 } 409 notFirstTime = true 410 } 411 } 412 413 func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) { 414 frontHeadroom := N.CalculateFrontHeadroom(destinationConn) 415 rearHeadroom := N.CalculateRearHeadroom(destinationConn) 416 for _, packetBuffer := range packetBuffers { 417 buffer := buf.NewPacket() 418 readBufferRaw := buffer.Slice() 419 readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom]) 420 readBuffer.Resize(frontHeadroom, 0) 421 _, err = readBuffer.Write(packetBuffer.Buffer.Bytes()) 422 packetBuffer.Buffer.Release() 423 if err != nil { 424 continue 425 } 426 dataLen := readBuffer.Len() 427 buffer.Resize(readBuffer.Start(), dataLen) 428 err = destinationConn.WritePacket(buffer, packetBuffer.Destination) 429 if err != nil { 430 buffer.Release() 431 return 432 } 433 n += int64(dataLen) 434 } 435 return 436 } 437 438 func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error { 439 return CopyPacketConnContextList([]context.Context{ctx}, source, destination) 440 } 441 442 func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error { 443 var group task.Group 444 group.Append("upload", func(ctx context.Context) error { 445 return common.Error(CopyPacket(destination, source)) 446 }) 447 group.Append("download", func(ctx context.Context) error { 448 return common.Error(CopyPacket(source, destination)) 449 }) 450 group.Cleanup(func() { 451 common.Close(source, destination) 452 }) 453 group.FastFail() 454 return group.RunContextList(contextList) 455 }