github.com/sagernet/sing@v0.2.6/common/bufio/once.go (about)

     1  package bufio
     2  
import (
	"errors"
	"io"

	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	N "github.com/sagernet/sing/common/network"
)
    10  
    11  func CopyTimes(dst io.Writer, src io.Reader, times int) (n int64, err error) {
    12  	return CopyExtendedTimes(NewExtendedWriter(N.UnwrapWriter(dst)), NewExtendedReader(N.UnwrapReader(src)), times)
    13  }
    14  
// CopyExtendedTimes performs at most times read/write rounds from src to dst.
// Each round reads one buffer with src.ReadBuffer and hands it to
// dst.WriteBuffer. The loop stops at the first error, which is returned
// unchanged (io.EOF is not special-cased here). n is the total number of data
// bytes passed to dst.WriteBuffer in rounds whose write succeeded.
//
// The buffer size is the value from N.CalculateMTU plus dst's front and rear
// headroom. When N.CalculateMTU returns a non-positive value, buf.BufferSize is
// used instead.
func CopyExtendedTimes(dst N.ExtendedWriter, src N.ExtendedReader, times int) (n int64, err error) {
	frontHeadroom := N.CalculateFrontHeadroom(dst)
	rearHeadroom := N.CalculateRearHeadroom(dst)
	bufferSize := N.CalculateMTU(src, dst)
	if bufferSize > 0 {
		bufferSize += frontHeadroom + rearHeadroom
	} else {
		bufferSize = buf.BufferSize
	}
	// For an "unsafe" dst a fresh buffer is allocated every round (see the loop).
	// Otherwise one buffer from buf.StackNewSize is reused for all rounds.
	// NOTE(review): presumably an unsafe writer may keep the buffer after
	// WriteBuffer returns — confirm against N.IsUnsafeWriter.
	dstUnsafe := N.IsUnsafeWriter(dst)
	var buffer *buf.Buffer
	if !dstUnsafe {
		_buffer := buf.StackNewSize(bufferSize)
		defer common.KeepAlive(_buffer)
		buffer = common.Dup(_buffer)
		defer buffer.Release()
		// Deferred calls run last-in-first-out: DecRef, then Release, then KeepAlive.
		// NOTE(review): the extra reference presumably keeps the in-loop
		// buffer.Release() calls from freeing this shared buffer — confirm
		// against buf.Buffer's reference counting.
		buffer.IncRef()
		defer buffer.DecRef()
	}
	// NOTE(review): notFirstTime starts as true and is only ever set to true,
	// so the `!notFirstTime` branch below is unreachable and
	// N.HandshakeFailure is never called. If a failed first read was meant to
	// be reported as a handshake failure, this should start as false —
	// confirm the intended behavior.
	notFirstTime := true
	for i := 0; i < times; i++ {
		if dstUnsafe {
			buffer = buf.NewSize(bufferSize)
		}
		// Read into a view of buffer whose capacity excludes the trailing
		// rearHeadroom bytes. Resize(frontHeadroom, 0) presumably places the
		// empty read window frontHeadroom bytes in, leaving room for dst.
		readBufferRaw := buffer.Slice()
		readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom])
		readBuffer.Resize(frontHeadroom, 0)
		err = src.ReadBuffer(readBuffer)
		if err != nil {
			buffer.Release()
			if !notFirstTime {
				err = N.HandshakeFailure(dst, err)
			}
			return
		}
		// Make buffer cover exactly the bytes that were just read.
		dataLen := readBuffer.Len()
		buffer.Resize(readBuffer.Start(), dataLen)
		err = dst.WriteBuffer(buffer)
		if err != nil {
			buffer.Release()
			return
		}
		n += int64(dataLen)
		notFirstTime = true
	}
	return
}
    62  
    63  type ReadFromWriter interface {
    64  	io.ReaderFrom
    65  	io.Writer
    66  }
    67  
    68  func ReadFrom0(readerFrom ReadFromWriter, reader io.Reader) (n int64, err error) {
    69  	n, err = CopyTimes(readerFrom, reader, 1)
    70  	if err != nil {
    71  		return
    72  	}
    73  	var rn int64
    74  	rn, err = readerFrom.ReadFrom(reader)
    75  	if err != nil {
    76  		return
    77  	}
    78  	n += rn
    79  	return
    80  }
    81  
    82  func ReadFromN(readerFrom ReadFromWriter, reader io.Reader, times int) (n int64, err error) {
    83  	n, err = CopyTimes(readerFrom, reader, times)
    84  	if err != nil {
    85  		return
    86  	}
    87  	var rn int64
    88  	rn, err = readerFrom.ReadFrom(reader)
    89  	if err != nil {
    90  		return
    91  	}
    92  	n += rn
    93  	return
    94  }
    95  
    96  type WriteToReader interface {
    97  	io.WriterTo
    98  	io.Reader
    99  }
   100  
   101  func WriteTo0(writerTo WriteToReader, writer io.Writer) (n int64, err error) {
   102  	n, err = CopyTimes(writer, writerTo, 1)
   103  	if err != nil {
   104  		return
   105  	}
   106  	var wn int64
   107  	wn, err = writerTo.WriteTo(writer)
   108  	if err != nil {
   109  		return
   110  	}
   111  	n += wn
   112  	return
   113  }
   114  
   115  func WriteToN(writerTo WriteToReader, writer io.Writer, times int) (n int64, err error) {
   116  	n, err = CopyTimes(writer, writerTo, times)
   117  	if err != nil {
   118  		return
   119  	}
   120  	var wn int64
   121  	wn, err = writerTo.WriteTo(writer)
   122  	if err != nil {
   123  		return
   124  	}
   125  	n += wn
   126  	return
   127  }