github.com/ipfans/trojan-go@v0.11.0/common/io.go (about) 1 package common 2 3 import ( 4 "io" 5 "net" 6 "sync" 7 8 "github.com/ipfans/trojan-go/log" 9 ) 10 11 type RewindReader struct { 12 mu sync.Mutex 13 rawReader io.Reader 14 buf []byte 15 bufReadIdx int 16 rewound bool 17 buffering bool 18 bufferSize int 19 } 20 21 func (r *RewindReader) Read(p []byte) (int, error) { 22 r.mu.Lock() 23 defer r.mu.Unlock() 24 25 if r.rewound { 26 if len(r.buf) > r.bufReadIdx { 27 n := copy(p, r.buf[r.bufReadIdx:]) 28 r.bufReadIdx += n 29 return n, nil 30 } 31 r.rewound = false // all buffering content has been read 32 } 33 n, err := r.rawReader.Read(p) 34 if r.buffering { 35 r.buf = append(r.buf, p[:n]...) 36 if len(r.buf) > r.bufferSize*2 { 37 log.Debug("read too many bytes!") 38 } 39 } 40 return n, err 41 } 42 43 func (r *RewindReader) ReadByte() (byte, error) { 44 buf := [1]byte{} 45 _, err := r.Read(buf[:]) 46 return buf[0], err 47 } 48 49 func (r *RewindReader) Discard(n int) (int, error) { 50 buf := [128]byte{} 51 if n < 128 { 52 return r.Read(buf[:n]) 53 } 54 for discarded := 0; discarded+128 < n; discarded += 128 { 55 _, err := r.Read(buf[:]) 56 if err != nil { 57 return discarded, err 58 } 59 } 60 if rest := n % 128; rest != 0 { 61 return r.Read(buf[:rest]) 62 } 63 return n, nil 64 } 65 66 func (r *RewindReader) Rewind() { 67 r.mu.Lock() 68 if r.bufferSize == 0 { 69 panic("no buffer") 70 } 71 r.rewound = true 72 r.bufReadIdx = 0 73 r.mu.Unlock() 74 } 75 76 func (r *RewindReader) StopBuffering() { 77 r.mu.Lock() 78 r.buffering = false 79 r.mu.Unlock() 80 } 81 82 func (r *RewindReader) SetBufferSize(size int) { 83 r.mu.Lock() 84 if size == 0 { // disable buffering 85 if !r.buffering { 86 panic("reader is disabled") 87 } 88 r.buffering = false 89 r.buf = nil 90 r.bufReadIdx = 0 91 r.bufferSize = 0 92 } else { 93 if r.buffering { 94 panic("reader is buffering") 95 } 96 r.buffering = true 97 r.bufReadIdx = 0 98 r.bufferSize = size 99 r.buf = make([]byte, 0, size) 100 } 101 r.mu.Unlock() 102 } 103 104 type RewindConn struct { 105 net.Conn 106 *RewindReader 107 } 108 109 func (c *RewindConn) Read(p []byte) (int, error) { 110 return c.RewindReader.Read(p) 111 } 112 113 func NewRewindConn(conn net.Conn) *RewindConn { 114 return &RewindConn{ 115 Conn: conn, 116 RewindReader: &RewindReader{ 117 rawReader: conn, 118 }, 119 } 120 } 121 122 type StickyWriter struct { 123 rawWriter io.Writer 124 writeBuffer []byte 125 MaxBuffered int 126 } 127 128 func (w *StickyWriter) Write(p []byte) (int, error) { 129 if w.MaxBuffered > 0 { 130 w.MaxBuffered-- 131 w.writeBuffer = append(w.writeBuffer, p...) 132 if w.MaxBuffered != 0 { 133 return len(p), nil 134 } 135 w.MaxBuffered = 0 136 _, err := w.rawWriter.Write(w.writeBuffer) 137 w.writeBuffer = nil 138 return len(p), err 139 } 140 return w.rawWriter.Write(p) 141 }