github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/copy.go (about)

     1  package bufio
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  	"syscall"
     9  
    10  	"github.com/sagernet/sing/common"
    11  	"github.com/sagernet/sing/common/buf"
    12  	E "github.com/sagernet/sing/common/exceptions"
    13  	M "github.com/sagernet/sing/common/metadata"
    14  	N "github.com/sagernet/sing/common/network"
    15  	"github.com/sagernet/sing/common/rw"
    16  	"github.com/sagernet/sing/common/task"
    17  )
    18  
    19  func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
    20  	if source == nil {
    21  		return 0, E.New("nil reader")
    22  	} else if destination == nil {
    23  		return 0, E.New("nil writer")
    24  	}
    25  	originSource := source
    26  	var readCounters, writeCounters []N.CountFunc
    27  	for {
    28  		source, readCounters = N.UnwrapCountReader(source, readCounters)
    29  		destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
    30  		if cachedSrc, isCached := source.(N.CachedReader); isCached {
    31  			cachedBuffer := cachedSrc.ReadCached()
    32  			if cachedBuffer != nil {
    33  				if !cachedBuffer.IsEmpty() {
    34  					_, err = destination.Write(cachedBuffer.Bytes())
    35  					if err != nil {
    36  						cachedBuffer.Release()
    37  						return
    38  					}
    39  				}
    40  				cachedBuffer.Release()
    41  				continue
    42  			}
    43  		}
    44  		srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
    45  		dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
    46  		if srcIsSyscall && dstIsSyscall {
    47  			var handled bool
    48  			handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
    49  			if handled {
    50  				return
    51  			}
    52  		}
    53  		break
    54  	}
    55  	return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
    56  }
    57  
    58  func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
    59  	frontHeadroom := N.CalculateFrontHeadroom(destination)
    60  	rearHeadroom := N.CalculateRearHeadroom(destination)
    61  	readWaiter, isReadWaiter := CreateReadWaiter(source)
    62  	if isReadWaiter {
    63  		needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
    64  			FrontHeadroom: frontHeadroom,
    65  			RearHeadroom:  rearHeadroom,
    66  			MTU:           N.CalculateMTU(source, destination),
    67  		})
    68  		if !needCopy || common.LowMemory {
    69  			var handled bool
    70  			handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
    71  			if handled {
    72  				return
    73  			}
    74  		}
    75  	}
    76  	return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
    77  }
    78  
    79  func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
    80  	buffer.IncRef()
    81  	defer buffer.DecRef()
    82  	frontHeadroom := N.CalculateFrontHeadroom(destination)
    83  	rearHeadroom := N.CalculateRearHeadroom(destination)
    84  	buffer.Resize(frontHeadroom, 0)
    85  	buffer.Reserve(rearHeadroom)
    86  	var notFirstTime bool
    87  	for {
    88  		err = source.ReadBuffer(buffer)
    89  		if err != nil {
    90  			if errors.Is(err, io.EOF) {
    91  				err = nil
    92  				return
    93  			}
    94  			return
    95  		}
    96  		dataLen := buffer.Len()
    97  		buffer.OverCap(rearHeadroom)
    98  		err = destination.WriteBuffer(buffer)
    99  		if err != nil {
   100  			if !notFirstTime {
   101  				err = N.ReportHandshakeFailure(originSource, err)
   102  			}
   103  			return
   104  		}
   105  		n += int64(dataLen)
   106  		for _, counter := range readCounters {
   107  			counter(int64(dataLen))
   108  		}
   109  		for _, counter := range writeCounters {
   110  			counter(int64(dataLen))
   111  		}
   112  		notFirstTime = true
   113  	}
   114  }
   115  
   116  func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
   117  	frontHeadroom := N.CalculateFrontHeadroom(destination)
   118  	rearHeadroom := N.CalculateRearHeadroom(destination)
   119  	bufferSize := N.CalculateMTU(source, destination)
   120  	if bufferSize > 0 {
   121  		bufferSize += frontHeadroom + rearHeadroom
   122  	} else {
   123  		bufferSize = buf.BufferSize
   124  	}
   125  	var notFirstTime bool
   126  	for {
   127  		buffer := buf.NewSize(bufferSize)
   128  		buffer.Resize(frontHeadroom, 0)
   129  		buffer.Reserve(rearHeadroom)
   130  		err = source.ReadBuffer(buffer)
   131  		if err != nil {
   132  			buffer.Release()
   133  			if errors.Is(err, io.EOF) {
   134  				err = nil
   135  				return
   136  			}
   137  			return
   138  		}
   139  		dataLen := buffer.Len()
   140  		buffer.OverCap(rearHeadroom)
   141  		err = destination.WriteBuffer(buffer)
   142  		if err != nil {
   143  			buffer.Leak()
   144  			if !notFirstTime {
   145  				err = N.ReportHandshakeFailure(originSource, err)
   146  			}
   147  			return
   148  		}
   149  		n += int64(dataLen)
   150  		for _, counter := range readCounters {
   151  			counter(int64(dataLen))
   152  		}
   153  		for _, counter := range writeCounters {
   154  			counter(int64(dataLen))
   155  		}
   156  		notFirstTime = true
   157  	}
   158  }
   159  
   160  func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
   161  	return CopyConnContextList([]context.Context{ctx}, source, destination)
   162  }
   163  
   164  func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
   165  	var group task.Group
   166  	if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex {
   167  		group.Append("upload", func(ctx context.Context) error {
   168  			err := common.Error(Copy(destination, source))
   169  			if err == nil {
   170  				rw.CloseWrite(destination)
   171  			} else {
   172  				common.Close(destination)
   173  			}
   174  			return err
   175  		})
   176  	} else {
   177  		group.Append("upload", func(ctx context.Context) error {
   178  			defer common.Close(destination)
   179  			return common.Error(Copy(destination, source))
   180  		})
   181  	}
   182  	if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex {
   183  		group.Append("download", func(ctx context.Context) error {
   184  			err := common.Error(Copy(source, destination))
   185  			if err == nil {
   186  				rw.CloseWrite(source)
   187  			} else {
   188  				common.Close(source)
   189  			}
   190  			return err
   191  		})
   192  	} else {
   193  		group.Append("download", func(ctx context.Context) error {
   194  			defer common.Close(source)
   195  			return common.Error(Copy(source, destination))
   196  		})
   197  	}
   198  	group.Cleanup(func() {
   199  		common.Close(source, destination)
   200  	})
   201  	return group.RunContextList(contextList)
   202  }
   203  
   204  func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
   205  	var readCounters, writeCounters []N.CountFunc
   206  	var cachedPackets []*N.PacketBuffer
   207  	originSource := source
   208  	for {
   209  		source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
   210  		destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters)
   211  		if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
   212  			packet := cachedReader.ReadCachedPacket()
   213  			if packet != nil {
   214  				cachedPackets = append(cachedPackets, packet)
   215  				continue
   216  			}
   217  		}
   218  		break
   219  	}
   220  	if cachedPackets != nil {
   221  		n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
   222  		if err != nil {
   223  			return
   224  		}
   225  	}
   226  	frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
   227  	rearHeadroom := N.CalculateRearHeadroom(destinationConn)
   228  	var (
   229  		handled bool
   230  		copeN   int64
   231  	)
   232  	readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
   233  	if isReadWaiter {
   234  		needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
   235  			FrontHeadroom: frontHeadroom,
   236  			RearHeadroom:  rearHeadroom,
   237  			MTU:           N.CalculateMTU(source, destinationConn),
   238  		})
   239  		if !needCopy || common.LowMemory {
   240  			handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
   241  			if handled {
   242  				n += copeN
   243  				return
   244  			}
   245  		}
   246  	}
   247  	copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
   248  	n += copeN
   249  	return
   250  }
   251  
   252  func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
   253  	frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
   254  	rearHeadroom := N.CalculateRearHeadroom(destinationConn)
   255  	bufferSize := N.CalculateMTU(source, destinationConn)
   256  	if bufferSize > 0 {
   257  		bufferSize += frontHeadroom + rearHeadroom
   258  	} else {
   259  		bufferSize = buf.UDPBufferSize
   260  	}
   261  	var destination M.Socksaddr
   262  	for {
   263  		buffer := buf.NewSize(bufferSize)
   264  		buffer.Resize(frontHeadroom, 0)
   265  		buffer.Reserve(rearHeadroom)
   266  		destination, err = source.ReadPacket(buffer)
   267  		if err != nil {
   268  			buffer.Release()
   269  			return
   270  		}
   271  		dataLen := buffer.Len()
   272  		buffer.OverCap(rearHeadroom)
   273  		err = destinationConn.WritePacket(buffer, destination)
   274  		if err != nil {
   275  			buffer.Leak()
   276  			if !notFirstTime {
   277  				err = N.ReportHandshakeFailure(originSource, err)
   278  			}
   279  			return
   280  		}
   281  		n += int64(dataLen)
   282  		for _, counter := range readCounters {
   283  			counter(int64(dataLen))
   284  		}
   285  		for _, counter := range writeCounters {
   286  			counter(int64(dataLen))
   287  		}
   288  		notFirstTime = true
   289  	}
   290  }
   291  
   292  func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
   293  	frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
   294  	rearHeadroom := N.CalculateRearHeadroom(destinationConn)
   295  	var notFirstTime bool
   296  	for _, packetBuffer := range packetBuffers {
   297  		buffer := buf.NewPacket()
   298  		buffer.Resize(frontHeadroom, 0)
   299  		buffer.Reserve(rearHeadroom)
   300  		_, err = buffer.Write(packetBuffer.Buffer.Bytes())
   301  		packetBuffer.Buffer.Release()
   302  		if err != nil {
   303  			buffer.Release()
   304  			continue
   305  		}
   306  		dataLen := buffer.Len()
   307  		buffer.OverCap(rearHeadroom)
   308  		err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
   309  		if err != nil {
   310  			buffer.Leak()
   311  			if !notFirstTime {
   312  				err = N.ReportHandshakeFailure(originSource, err)
   313  			}
   314  			return
   315  		}
   316  		n += int64(dataLen)
   317  	}
   318  	return
   319  }
   320  
   321  func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
   322  	return CopyPacketConnContextList([]context.Context{ctx}, source, destination)
   323  }
   324  
   325  func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
   326  	var group task.Group
   327  	group.Append("upload", func(ctx context.Context) error {
   328  		return common.Error(CopyPacket(destination, source))
   329  	})
   330  	group.Append("download", func(ctx context.Context) error {
   331  		return common.Error(CopyPacket(source, destination))
   332  	})
   333  	group.Cleanup(func() {
   334  		common.Close(source, destination)
   335  	})
   336  	group.FastFail()
   337  	return group.RunContextList(contextList)
   338  }