github.com/sagernet/sing@v0.2.6/common/bufio/copy.go (about)

     1  package bufio
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  	"syscall"
     9  
    10  	"github.com/sagernet/sing/common"
    11  	"github.com/sagernet/sing/common/buf"
    12  	E "github.com/sagernet/sing/common/exceptions"
    13  	M "github.com/sagernet/sing/common/metadata"
    14  	N "github.com/sagernet/sing/common/network"
    15  	"github.com/sagernet/sing/common/rw"
    16  	"github.com/sagernet/sing/common/task"
    17  )
    18  
    19  func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
    20  	if source == nil {
    21  		return 0, E.New("nil reader")
    22  	} else if destination == nil {
    23  		return 0, E.New("nil writer")
    24  	}
    25  	originDestination := destination
    26  	var readCounters, writeCounters []N.CountFunc
    27  	for {
    28  		source, readCounters = N.UnwrapCountReader(source, readCounters)
    29  		destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
    30  		if cachedSrc, isCached := source.(N.CachedReader); isCached {
    31  			cachedBuffer := cachedSrc.ReadCached()
    32  			if cachedBuffer != nil {
    33  				if !cachedBuffer.IsEmpty() {
    34  					_, err = destination.Write(cachedBuffer.Bytes())
    35  					if err != nil {
    36  						cachedBuffer.Release()
    37  						return
    38  					}
    39  				}
    40  				cachedBuffer.Release()
    41  				continue
    42  			}
    43  		}
    44  		srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
    45  		dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
    46  		if srcIsSyscall && dstIsSyscall {
    47  			var handled bool
    48  			handled, n, err = CopyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
    49  			if handled {
    50  				return
    51  			}
    52  		}
    53  		break
    54  	}
    55  	return CopyExtended(originDestination, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
    56  }
    57  
    58  func CopyExtended(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
    59  	safeSrc := N.IsSafeReader(source)
    60  	headroom := N.CalculateFrontHeadroom(destination) + N.CalculateRearHeadroom(destination)
    61  	if safeSrc != nil {
    62  		if headroom == 0 {
    63  			return CopyExtendedWithSrcBuffer(originDestination, destination, safeSrc, readCounters, writeCounters)
    64  		}
    65  	}
    66  	readWaiter, isReadWaiter := CreateReadWaiter(source)
    67  	if isReadWaiter {
    68  		var handled bool
    69  		handled, n, err = copyWaitWithPool(originDestination, destination, readWaiter, readCounters, writeCounters)
    70  		if handled {
    71  			return
    72  		}
    73  	}
    74  	if !common.UnsafeBuffer || N.IsUnsafeWriter(destination) {
    75  		return CopyExtendedWithPool(originDestination, destination, source, readCounters, writeCounters)
    76  	}
    77  	bufferSize := N.CalculateMTU(source, destination)
    78  	if bufferSize > 0 {
    79  		bufferSize += headroom
    80  	} else {
    81  		bufferSize = buf.BufferSize
    82  	}
    83  	_buffer := buf.StackNewSize(bufferSize)
    84  	defer common.KeepAlive(_buffer)
    85  	buffer := common.Dup(_buffer)
    86  	defer buffer.Release()
    87  	return CopyExtendedBuffer(originDestination, destination, source, buffer, readCounters, writeCounters)
    88  }
    89  
    90  func CopyExtendedBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
    91  	buffer.IncRef()
    92  	defer buffer.DecRef()
    93  	frontHeadroom := N.CalculateFrontHeadroom(destination)
    94  	rearHeadroom := N.CalculateRearHeadroom(destination)
    95  	readBufferRaw := buffer.Slice()
    96  	readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
    97  	var notFirstTime bool
    98  	for {
    99  		readBuffer.Resize(frontHeadroom, 0)
   100  		err = source.ReadBuffer(readBuffer)
   101  		if err != nil {
   102  			if errors.Is(err, io.EOF) {
   103  				err = nil
   104  				return
   105  			}
   106  			if !notFirstTime {
   107  				err = N.HandshakeFailure(originDestination, err)
   108  			}
   109  			return
   110  		}
   111  		dataLen := readBuffer.Len()
   112  		buffer.Resize(readBuffer.Start(), dataLen)
   113  		err = destination.WriteBuffer(buffer)
   114  		if err != nil {
   115  			return
   116  		}
   117  		n += int64(dataLen)
   118  		for _, counter := range readCounters {
   119  			counter(int64(dataLen))
   120  		}
   121  		for _, counter := range writeCounters {
   122  			counter(int64(dataLen))
   123  		}
   124  		notFirstTime = true
   125  	}
   126  }
   127  
   128  func CopyExtendedWithSrcBuffer(originDestination io.Writer, destination N.ExtendedWriter, source N.ThreadSafeReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
   129  	var notFirstTime bool
   130  	for {
   131  		var buffer *buf.Buffer
   132  		buffer, err = source.ReadBufferThreadSafe()
   133  		if err != nil {
   134  			if errors.Is(err, io.EOF) {
   135  				err = nil
   136  				return
   137  			}
   138  			if !notFirstTime {
   139  				err = N.HandshakeFailure(originDestination, err)
   140  			}
   141  			return
   142  		}
   143  		dataLen := buffer.Len()
   144  		err = destination.WriteBuffer(buffer)
   145  		if err != nil {
   146  			buffer.Release()
   147  			return
   148  		}
   149  		n += int64(dataLen)
   150  		for _, counter := range readCounters {
   151  			counter(int64(dataLen))
   152  		}
   153  		for _, counter := range writeCounters {
   154  			counter(int64(dataLen))
   155  		}
   156  		notFirstTime = true
   157  	}
   158  }
   159  
   160  func CopyExtendedWithPool(originDestination io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
   161  	frontHeadroom := N.CalculateFrontHeadroom(destination)
   162  	rearHeadroom := N.CalculateRearHeadroom(destination)
   163  	bufferSize := N.CalculateMTU(source, destination)
   164  	if bufferSize > 0 {
   165  		bufferSize += frontHeadroom + rearHeadroom
   166  	} else {
   167  		bufferSize = buf.BufferSize
   168  	}
   169  	var notFirstTime bool
   170  	for {
   171  		buffer := buf.NewSize(bufferSize)
   172  		readBufferRaw := buffer.Slice()
   173  		readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
   174  		readBuffer.Resize(frontHeadroom, 0)
   175  		err = source.ReadBuffer(readBuffer)
   176  		if err != nil {
   177  			buffer.Release()
   178  			if errors.Is(err, io.EOF) {
   179  				err = nil
   180  				return
   181  			}
   182  			if !notFirstTime {
   183  				err = N.HandshakeFailure(originDestination, err)
   184  			}
   185  			return
   186  		}
   187  		dataLen := readBuffer.Len()
   188  		buffer.Resize(readBuffer.Start(), dataLen)
   189  		err = destination.WriteBuffer(buffer)
   190  		if err != nil {
   191  			buffer.Release()
   192  			return
   193  		}
   194  		n += int64(dataLen)
   195  		for _, counter := range readCounters {
   196  			counter(int64(dataLen))
   197  		}
   198  		for _, counter := range writeCounters {
   199  			counter(int64(dataLen))
   200  		}
   201  		notFirstTime = true
   202  	}
   203  }
   204  
   205  func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
   206  	return CopyConnContextList([]context.Context{ctx}, source, destination)
   207  }
   208  
   209  func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
   210  	var group task.Group
   211  	if _, dstDuplex := common.Cast[rw.WriteCloser](destination); dstDuplex {
   212  		group.Append("upload", func(ctx context.Context) error {
   213  			err := common.Error(Copy(destination, source))
   214  			if err == nil {
   215  				rw.CloseWrite(destination)
   216  			} else {
   217  				common.Close(destination)
   218  			}
   219  			return err
   220  		})
   221  	} else {
   222  		group.Append("upload", func(ctx context.Context) error {
   223  			defer common.Close(destination)
   224  			return common.Error(Copy(destination, source))
   225  		})
   226  	}
   227  	if _, srcDuplex := common.Cast[rw.WriteCloser](source); srcDuplex {
   228  		group.Append("download", func(ctx context.Context) error {
   229  			err := common.Error(Copy(source, destination))
   230  			if err == nil {
   231  				rw.CloseWrite(source)
   232  			} else {
   233  				common.Close(source)
   234  			}
   235  			return err
   236  		})
   237  	} else {
   238  		group.Append("download", func(ctx context.Context) error {
   239  			defer common.Close(source)
   240  			return common.Error(Copy(source, destination))
   241  		})
   242  	}
   243  	group.Cleanup(func() {
   244  		common.Close(source, destination)
   245  	})
   246  	return group.RunContextList(contextList)
   247  }
   248  
   249  func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
   250  	var readCounters, writeCounters []N.CountFunc
   251  	var cachedPackets []*N.PacketBuffer
   252  	for {
   253  		source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
   254  		destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters)
   255  		if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
   256  			packet := cachedReader.ReadCachedPacket()
   257  			if packet != nil {
   258  				cachedPackets = append(cachedPackets, packet)
   259  				continue
   260  			}
   261  		}
   262  		break
   263  	}
   264  	if cachedPackets != nil {
   265  		n, err = WritePacketWithPool(destinationConn, cachedPackets)
   266  		if err != nil {
   267  			return
   268  		}
   269  	}
   270  	safeSrc := N.IsSafePacketReader(source)
   271  	frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
   272  	rearHeadroom := N.CalculateRearHeadroom(destinationConn)
   273  	headroom := frontHeadroom + rearHeadroom
   274  	if safeSrc != nil {
   275  		if headroom == 0 {
   276  			var copyN int64
   277  			copyN, err = CopyPacketWithSrcBuffer(destinationConn, safeSrc, readCounters, writeCounters)
   278  			n += copyN
   279  			return
   280  		}
   281  	}
   282  	readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
   283  	if isReadWaiter {
   284  		var (
   285  			handled bool
   286  			copeN   int64
   287  		)
   288  		handled, copeN, err = copyPacketWaitWithPool(destinationConn, readWaiter, readCounters, writeCounters)
   289  		if handled {
   290  			n += copeN
   291  			return
   292  		}
   293  	}
   294  	if N.IsUnsafeWriter(destinationConn) {
   295  		return CopyPacketWithPool(destinationConn, source, readCounters, writeCounters)
   296  	}
   297  	bufferSize := N.CalculateMTU(source, destinationConn)
   298  	if bufferSize > 0 {
   299  		bufferSize += headroom
   300  	} else {
   301  		bufferSize = buf.UDPBufferSize
   302  	}
   303  	_buffer := buf.StackNewSize(bufferSize)
   304  	defer common.KeepAlive(_buffer)
   305  	buffer := common.Dup(_buffer)
   306  	defer buffer.Release()
   307  	buffer.IncRef()
   308  	defer buffer.DecRef()
   309  	var destination M.Socksaddr
   310  	var notFirstTime bool
   311  	readBufferRaw := buffer.Slice()
   312  	readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
   313  	for {
   314  		readBuffer.Resize(frontHeadroom, 0)
   315  		destination, err = source.ReadPacket(readBuffer)
   316  		if err != nil {
   317  			if !notFirstTime {
   318  				err = N.HandshakeFailure(destinationConn, err)
   319  			}
   320  			return
   321  		}
   322  		dataLen := readBuffer.Len()
   323  		buffer.Resize(readBuffer.Start(), dataLen)
   324  		err = destinationConn.WritePacket(buffer, destination)
   325  		if err != nil {
   326  			return
   327  		}
   328  		n += int64(dataLen)
   329  		for _, counter := range readCounters {
   330  			counter(int64(dataLen))
   331  		}
   332  		for _, counter := range writeCounters {
   333  			counter(int64(dataLen))
   334  		}
   335  		notFirstTime = true
   336  	}
   337  }
   338  
   339  func CopyPacketWithSrcBuffer(destinationConn N.PacketWriter, source N.ThreadSafePacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
   340  	var buffer *buf.Buffer
   341  	var destination M.Socksaddr
   342  	var notFirstTime bool
   343  	for {
   344  		buffer, destination, err = source.ReadPacketThreadSafe()
   345  		if err != nil {
   346  			if !notFirstTime {
   347  				err = N.HandshakeFailure(destinationConn, err)
   348  			}
   349  			return
   350  		}
   351  		dataLen := buffer.Len()
   352  		if dataLen == 0 {
   353  			continue
   354  		}
   355  		err = destinationConn.WritePacket(buffer, destination)
   356  		if err != nil {
   357  			buffer.Release()
   358  			return
   359  		}
   360  		n += int64(dataLen)
   361  		for _, counter := range readCounters {
   362  			counter(int64(dataLen))
   363  		}
   364  		for _, counter := range writeCounters {
   365  			counter(int64(dataLen))
   366  		}
   367  		notFirstTime = true
   368  	}
   369  }
   370  
   371  func CopyPacketWithPool(destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
   372  	frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
   373  	rearHeadroom := N.CalculateRearHeadroom(destinationConn)
   374  	bufferSize := N.CalculateMTU(source, destinationConn)
   375  	if bufferSize > 0 {
   376  		bufferSize += frontHeadroom + rearHeadroom
   377  	} else {
   378  		bufferSize = buf.UDPBufferSize
   379  	}
   380  	var destination M.Socksaddr
   381  	var notFirstTime bool
   382  	for {
   383  		buffer := buf.NewSize(bufferSize)
   384  		readBufferRaw := buffer.Slice()
   385  		readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
   386  		readBuffer.Resize(frontHeadroom, 0)
   387  		destination, err = source.ReadPacket(readBuffer)
   388  		if err != nil {
   389  			buffer.Release()
   390  			if !notFirstTime {
   391  				err = N.HandshakeFailure(destinationConn, err)
   392  			}
   393  			return
   394  		}
   395  		dataLen := readBuffer.Len()
   396  		buffer.Resize(readBuffer.Start(), dataLen)
   397  		err = destinationConn.WritePacket(buffer, destination)
   398  		if err != nil {
   399  			buffer.Release()
   400  			return
   401  		}
   402  		n += int64(dataLen)
   403  		for _, counter := range readCounters {
   404  			counter(int64(dataLen))
   405  		}
   406  		for _, counter := range writeCounters {
   407  			counter(int64(dataLen))
   408  		}
   409  		notFirstTime = true
   410  	}
   411  }
   412  
   413  func WritePacketWithPool(destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
   414  	frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
   415  	rearHeadroom := N.CalculateRearHeadroom(destinationConn)
   416  	for _, packetBuffer := range packetBuffers {
   417  		buffer := buf.NewPacket()
   418  		readBufferRaw := buffer.Slice()
   419  		readBuffer := buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
   420  		readBuffer.Resize(frontHeadroom, 0)
   421  		_, err = readBuffer.Write(packetBuffer.Buffer.Bytes())
   422  		packetBuffer.Buffer.Release()
   423  		if err != nil {
   424  			continue
   425  		}
   426  		dataLen := readBuffer.Len()
   427  		buffer.Resize(readBuffer.Start(), dataLen)
   428  		err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
   429  		if err != nil {
   430  			buffer.Release()
   431  			return
   432  		}
   433  		n += int64(dataLen)
   434  	}
   435  	return
   436  }
   437  
   438  func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.PacketConn) error {
   439  	return CopyPacketConnContextList([]context.Context{ctx}, source, destination)
   440  }
   441  
   442  func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
   443  	var group task.Group
   444  	group.Append("upload", func(ctx context.Context) error {
   445  		return common.Error(CopyPacket(destination, source))
   446  	})
   447  	group.Append("download", func(ctx context.Context) error {
   448  		return common.Error(CopyPacket(source, destination))
   449  	})
   450  	group.Cleanup(func() {
   451  		common.Close(source, destination)
   452  	})
   453  	group.FastFail()
   454  	return group.RunContextList(contextList)
   455  }