github.com/ipfans/trojan-go@v0.11.0/common/io.go (about)

     1  package common
     2  
     3  import (
     4  	"io"
     5  	"net"
     6  	"sync"
     7  
     8  	"github.com/ipfans/trojan-go/log"
     9  )
    10  
    11  type RewindReader struct {
    12  	mu         sync.Mutex
    13  	rawReader  io.Reader
    14  	buf        []byte
    15  	bufReadIdx int
    16  	rewound    bool
    17  	buffering  bool
    18  	bufferSize int
    19  }
    20  
    21  func (r *RewindReader) Read(p []byte) (int, error) {
    22  	r.mu.Lock()
    23  	defer r.mu.Unlock()
    24  
    25  	if r.rewound {
    26  		if len(r.buf) > r.bufReadIdx {
    27  			n := copy(p, r.buf[r.bufReadIdx:])
    28  			r.bufReadIdx += n
    29  			return n, nil
    30  		}
    31  		r.rewound = false // all buffering content has been read
    32  	}
    33  	n, err := r.rawReader.Read(p)
    34  	if r.buffering {
    35  		r.buf = append(r.buf, p[:n]...)
    36  		if len(r.buf) > r.bufferSize*2 {
    37  			log.Debug("read too many bytes!")
    38  		}
    39  	}
    40  	return n, err
    41  }
    42  
    43  func (r *RewindReader) ReadByte() (byte, error) {
    44  	buf := [1]byte{}
    45  	_, err := r.Read(buf[:])
    46  	return buf[0], err
    47  }
    48  
    49  func (r *RewindReader) Discard(n int) (int, error) {
    50  	buf := [128]byte{}
    51  	if n < 128 {
    52  		return r.Read(buf[:n])
    53  	}
    54  	for discarded := 0; discarded+128 < n; discarded += 128 {
    55  		_, err := r.Read(buf[:])
    56  		if err != nil {
    57  			return discarded, err
    58  		}
    59  	}
    60  	if rest := n % 128; rest != 0 {
    61  		return r.Read(buf[:rest])
    62  	}
    63  	return n, nil
    64  }
    65  
    66  func (r *RewindReader) Rewind() {
    67  	r.mu.Lock()
    68  	if r.bufferSize == 0 {
    69  		panic("no buffer")
    70  	}
    71  	r.rewound = true
    72  	r.bufReadIdx = 0
    73  	r.mu.Unlock()
    74  }
    75  
    76  func (r *RewindReader) StopBuffering() {
    77  	r.mu.Lock()
    78  	r.buffering = false
    79  	r.mu.Unlock()
    80  }
    81  
    82  func (r *RewindReader) SetBufferSize(size int) {
    83  	r.mu.Lock()
    84  	if size == 0 { // disable buffering
    85  		if !r.buffering {
    86  			panic("reader is disabled")
    87  		}
    88  		r.buffering = false
    89  		r.buf = nil
    90  		r.bufReadIdx = 0
    91  		r.bufferSize = 0
    92  	} else {
    93  		if r.buffering {
    94  			panic("reader is buffering")
    95  		}
    96  		r.buffering = true
    97  		r.bufReadIdx = 0
    98  		r.bufferSize = size
    99  		r.buf = make([]byte, 0, size)
   100  	}
   101  	r.mu.Unlock()
   102  }
   103  
   104  type RewindConn struct {
   105  	net.Conn
   106  	*RewindReader
   107  }
   108  
   109  func (c *RewindConn) Read(p []byte) (int, error) {
   110  	return c.RewindReader.Read(p)
   111  }
   112  
   113  func NewRewindConn(conn net.Conn) *RewindConn {
   114  	return &RewindConn{
   115  		Conn: conn,
   116  		RewindReader: &RewindReader{
   117  			rawReader: conn,
   118  		},
   119  	}
   120  }
   121  
   122  type StickyWriter struct {
   123  	rawWriter   io.Writer
   124  	writeBuffer []byte
   125  	MaxBuffered int
   126  }
   127  
   128  func (w *StickyWriter) Write(p []byte) (int, error) {
   129  	if w.MaxBuffered > 0 {
   130  		w.MaxBuffered--
   131  		w.writeBuffer = append(w.writeBuffer, p...)
   132  		if w.MaxBuffered != 0 {
   133  			return len(p), nil
   134  		}
   135  		w.MaxBuffered = 0
   136  		_, err := w.rawWriter.Write(w.writeBuffer)
   137  		w.writeBuffer = nil
   138  		return len(p), err
   139  	}
   140  	return w.rawWriter.Write(p)
   141  }