github.com/sagernet/sing@v0.2.6/common/bufio/once.go (about) 1 package bufio 2 3 import ( 4 "io" 5 6 "github.com/sagernet/sing/common" 7 "github.com/sagernet/sing/common/buf" 8 N "github.com/sagernet/sing/common/network" 9 ) 10 11 func CopyTimes(dst io.Writer, src io.Reader, times int) (n int64, err error) { 12 return CopyExtendedTimes(NewExtendedWriter(N.UnwrapWriter(dst)), NewExtendedReader(N.UnwrapReader(src)), times) 13 } 14 15 func CopyExtendedTimes(dst N.ExtendedWriter, src N.ExtendedReader, times int) (n int64, err error) { 16 frontHeadroom := N.CalculateFrontHeadroom(dst) 17 rearHeadroom := N.CalculateRearHeadroom(dst) 18 bufferSize := N.CalculateMTU(src, dst) 19 if bufferSize > 0 { 20 bufferSize += frontHeadroom + rearHeadroom 21 } else { 22 bufferSize = buf.BufferSize 23 } 24 dstUnsafe := N.IsUnsafeWriter(dst) 25 var buffer *buf.Buffer 26 if !dstUnsafe { 27 _buffer := buf.StackNewSize(bufferSize) 28 defer common.KeepAlive(_buffer) 29 buffer = common.Dup(_buffer) 30 defer buffer.Release() 31 buffer.IncRef() 32 defer buffer.DecRef() 33 } 34 notFirstTime := true 35 for i := 0; i < times; i++ { 36 if dstUnsafe { 37 buffer = buf.NewSize(bufferSize) 38 } 39 readBufferRaw := buffer.Slice() 40 readBuffer := buf.With(readBufferRaw[:cap(readBufferRaw)-rearHeadroom]) 41 readBuffer.Resize(frontHeadroom, 0) 42 err = src.ReadBuffer(readBuffer) 43 if err != nil { 44 buffer.Release() 45 if !notFirstTime { 46 err = N.HandshakeFailure(dst, err) 47 } 48 return 49 } 50 dataLen := readBuffer.Len() 51 buffer.Resize(readBuffer.Start(), dataLen) 52 err = dst.WriteBuffer(buffer) 53 if err != nil { 54 buffer.Release() 55 return 56 } 57 n += int64(dataLen) 58 notFirstTime = true 59 } 60 return 61 } 62 63 type ReadFromWriter interface { 64 io.ReaderFrom 65 io.Writer 66 } 67 68 func ReadFrom0(readerFrom ReadFromWriter, reader io.Reader) (n int64, err error) { 69 n, err = CopyTimes(readerFrom, reader, 1) 70 if err != nil { 71 return 72 } 73 var rn int64 74 rn, err = readerFrom.ReadFrom(reader) 75 if err != nil { 76 return 77 } 78 n += rn 79 return 80 } 81 82 func ReadFromN(readerFrom ReadFromWriter, reader io.Reader, times int) (n int64, err error) { 83 n, err = CopyTimes(readerFrom, reader, times) 84 if err != nil { 85 return 86 } 87 var rn int64 88 rn, err = readerFrom.ReadFrom(reader) 89 if err != nil { 90 return 91 } 92 n += rn 93 return 94 } 95 96 type WriteToReader interface { 97 io.WriterTo 98 io.Reader 99 } 100 101 func WriteTo0(writerTo WriteToReader, writer io.Writer) (n int64, err error) { 102 n, err = CopyTimes(writer, writerTo, 1) 103 if err != nil { 104 return 105 } 106 var wn int64 107 wn, err = writerTo.WriteTo(writer) 108 if err != nil { 109 return 110 } 111 n += wn 112 return 113 } 114 115 func WriteToN(writerTo WriteToReader, writer io.Writer, times int) (n int64, err error) { 116 n, err = CopyTimes(writer, writerTo, times) 117 if err != nil { 118 return 119 } 120 var wn int64 121 wn, err = writerTo.WriteTo(writer) 122 if err != nil { 123 return 124 } 125 n += wn 126 return 127 }