github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/vmess/chunk.go (about)

     1  package vmess
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"io"
     7  	"net"
     8  
     9  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    10  	"github.com/Asutorufa/yuhaiin/pkg/utils/relay"
    11  )
    12  
    13  type writer interface {
    14  	io.WriteCloser
    15  	io.ReaderFrom
    16  }
    17  
    18  type connWriter struct {
    19  	net.Conn
    20  }
    21  
    22  func (c *connWriter) ReadFrom(r io.Reader) (int64, error) {
    23  	return relay.Copy(c.Conn, r)
    24  }
    25  
    26  const (
    27  	lenSize          = 2
    28  	maxChunkSize     = 1 << 14 // 16384
    29  	defaultChunkSize = 1 << 13 // 8192
    30  )
    31  
    32  var _ writer = &aeadWriter{}
    33  
    34  type chunkedWriter struct {
    35  	io.Writer
    36  	buf [lenSize + maxChunkSize]byte
    37  }
    38  
    39  // ChunkedWriter returns a chunked writer
    40  func ChunkedWriter(w io.Writer) writer { return &chunkedWriter{Writer: w} }
    41  
    42  func (w *chunkedWriter) Close() error { return nil }
    43  
    44  func (w *chunkedWriter) Write(b []byte) (int, error) {
    45  	n, err := w.ReadFrom(bytes.NewBuffer(b))
    46  	return int(n), err
    47  }
    48  
    49  func (w *chunkedWriter) ReadFrom(r io.Reader) (n int64, err error) {
    50  	for {
    51  		nr, er := r.Read(w.buf[lenSize : lenSize+defaultChunkSize])
    52  		if nr > 0 {
    53  			n += int64(nr)
    54  			binary.BigEndian.PutUint16(w.buf[:lenSize], uint16(nr))
    55  			_, err = w.Writer.Write(w.buf[:lenSize+nr])
    56  			if err != nil {
    57  				// err = ew
    58  				break
    59  			}
    60  		}
    61  
    62  		if er != nil {
    63  			if er != io.EOF { // ignore EOF as per io.ReaderFrom contract
    64  				err = er
    65  			}
    66  			break
    67  		}
    68  	}
    69  
    70  	return n, err
    71  }
    72  
    73  type chunkedReader struct {
    74  	io.Reader
    75  	leftBytes int
    76  }
    77  
    78  // ChunkedReader returns a chunked reader
    79  func ChunkedReader(r io.Reader) io.ReadCloser { return &chunkedReader{Reader: r} }
    80  func (r *chunkedReader) Close() error         { return nil }
    81  func (r *chunkedReader) Read(b []byte) (int, error) {
    82  	if r.leftBytes <= 0 {
    83  		buf := pool.GetBytes(lenSize)
    84  		defer pool.PutBytes(buf)
    85  
    86  		// get length
    87  		_, err := io.ReadFull(r.Reader, buf[:lenSize])
    88  		if err != nil {
    89  			return 0, err
    90  		}
    91  		r.leftBytes = int(binary.BigEndian.Uint16(buf[:lenSize]))
    92  
    93  		// if length == 0, then this is the end
    94  		if r.leftBytes <= 0 {
    95  			return 0, nil
    96  		}
    97  	}
    98  
    99  	m, err := r.Reader.Read(b)
   100  	if err != nil {
   101  		return 0, err
   102  	}
   103  	r.leftBytes -= m
   104  
   105  	return m, nil
   106  }