github.com/igoogolx/clash@v1.19.8/transport/shadowsocks/shadowaead/stream.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"crypto/rand"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  
    10  	"github.com/igoogolx/clash/common/pool"
    11  )
    12  
    13  const (
    14  	// payloadSizeMask is the maximum size of payload in bytes.
    15  	payloadSizeMask = 0x3FFF    // 16*1024 - 1
    16  	bufSize         = 17 * 1024 // >= 2+aead.Overhead()+payloadSizeMask+aead.Overhead()
    17  )
    18  
    19  var ErrZeroChunk = errors.New("zero chunk")
    20  
    21  type Writer struct {
    22  	io.Writer
    23  	cipher.AEAD
    24  	nonce [32]byte // should be sufficient for most nonce sizes
    25  }
    26  
    27  // NewWriter wraps an io.Writer with authenticated encryption.
    28  func NewWriter(w io.Writer, aead cipher.AEAD) *Writer { return &Writer{Writer: w, AEAD: aead} }
    29  
    30  // Write encrypts p and writes to the embedded io.Writer.
    31  func (w *Writer) Write(p []byte) (n int, err error) {
    32  	buf := pool.Get(bufSize)
    33  	defer pool.Put(buf)
    34  	nonce := w.nonce[:w.NonceSize()]
    35  	tag := w.Overhead()
    36  	off := 2 + tag
    37  
    38  	// compatible with snell
    39  	if len(p) == 0 {
    40  		buf = buf[:off]
    41  		buf[0], buf[1] = byte(0), byte(0)
    42  		w.Seal(buf[:0], nonce, buf[:2], nil)
    43  		increment(nonce)
    44  		_, err = w.Writer.Write(buf)
    45  		return
    46  	}
    47  
    48  	for nr := 0; n < len(p) && err == nil; n += nr {
    49  		nr = payloadSizeMask
    50  		if n+nr > len(p) {
    51  			nr = len(p) - n
    52  		}
    53  		buf = buf[:off+nr+tag]
    54  		buf[0], buf[1] = byte(nr>>8), byte(nr) // big-endian payload size
    55  		w.Seal(buf[:0], nonce, buf[:2], nil)
    56  		increment(nonce)
    57  		w.Seal(buf[:off], nonce, p[n:n+nr], nil)
    58  		increment(nonce)
    59  		_, err = w.Writer.Write(buf)
    60  	}
    61  	return
    62  }
    63  
    64  // ReadFrom reads from the given io.Reader until EOF or error, encrypts and
    65  // writes to the embedded io.Writer. Returns number of bytes read from r and
    66  // any error encountered.
    67  func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
    68  	buf := pool.Get(bufSize)
    69  	defer pool.Put(buf)
    70  	nonce := w.nonce[:w.NonceSize()]
    71  	tag := w.Overhead()
    72  	off := 2 + tag
    73  	for {
    74  		nr, er := r.Read(buf[off : off+payloadSizeMask])
    75  		n += int64(nr)
    76  		buf[0], buf[1] = byte(nr>>8), byte(nr)
    77  		w.Seal(buf[:0], nonce, buf[:2], nil)
    78  		increment(nonce)
    79  		w.Seal(buf[:off], nonce, buf[off:off+nr], nil)
    80  		increment(nonce)
    81  		if _, ew := w.Writer.Write(buf[:off+nr+tag]); ew != nil {
    82  			err = ew
    83  			return
    84  		}
    85  		if er != nil {
    86  			if er != io.EOF { // ignore EOF as per io.ReaderFrom contract
    87  				err = er
    88  			}
    89  			return
    90  		}
    91  	}
    92  }
    93  
    94  type Reader struct {
    95  	io.Reader
    96  	cipher.AEAD
    97  	nonce [32]byte // should be sufficient for most nonce sizes
    98  	buf   []byte   // to be put back into bufPool
    99  	off   int      // offset to unconsumed part of buf
   100  }
   101  
   102  // NewReader wraps an io.Reader with authenticated decryption.
   103  func NewReader(r io.Reader, aead cipher.AEAD) *Reader { return &Reader{Reader: r, AEAD: aead} }
   104  
   105  // Read and decrypt a record into p. len(p) >= max payload size + AEAD overhead.
   106  func (r *Reader) read(p []byte) (int, error) {
   107  	nonce := r.nonce[:r.NonceSize()]
   108  	tag := r.Overhead()
   109  
   110  	// decrypt payload size
   111  	p = p[:2+tag]
   112  	if _, err := io.ReadFull(r.Reader, p); err != nil {
   113  		return 0, err
   114  	}
   115  	_, err := r.Open(p[:0], nonce, p, nil)
   116  	increment(nonce)
   117  	if err != nil {
   118  		return 0, err
   119  	}
   120  
   121  	// decrypt payload
   122  	size := (int(p[0])<<8 + int(p[1])) & payloadSizeMask
   123  	if size == 0 {
   124  		return 0, ErrZeroChunk
   125  	}
   126  
   127  	p = p[:size+tag]
   128  	if _, err := io.ReadFull(r.Reader, p); err != nil {
   129  		return 0, err
   130  	}
   131  	_, err = r.Open(p[:0], nonce, p, nil)
   132  	increment(nonce)
   133  	if err != nil {
   134  		return 0, err
   135  	}
   136  	return size, nil
   137  }
   138  
   139  // Read reads from the embedded io.Reader, decrypts and writes to p.
   140  func (r *Reader) Read(p []byte) (int, error) {
   141  	if r.buf == nil {
   142  		if len(p) >= payloadSizeMask+r.Overhead() {
   143  			return r.read(p)
   144  		}
   145  		b := pool.Get(bufSize)
   146  		n, err := r.read(b)
   147  		if err != nil {
   148  			return 0, err
   149  		}
   150  		r.buf = b[:n]
   151  		r.off = 0
   152  	}
   153  
   154  	n := copy(p, r.buf[r.off:])
   155  	r.off += n
   156  	if r.off == len(r.buf) {
   157  		pool.Put(r.buf[:cap(r.buf)])
   158  		r.buf = nil
   159  	}
   160  	return n, nil
   161  }
   162  
   163  // WriteTo reads from the embedded io.Reader, decrypts and writes to w until
   164  // there's no more data to write or when an error occurs. Return number of
   165  // bytes written to w and any error encountered.
   166  func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
   167  	if r.buf == nil {
   168  		r.buf = pool.Get(bufSize)
   169  		r.off = len(r.buf)
   170  	}
   171  
   172  	for {
   173  		for r.off < len(r.buf) {
   174  			nw, ew := w.Write(r.buf[r.off:])
   175  			r.off += nw
   176  			n += int64(nw)
   177  			if ew != nil {
   178  				if r.off == len(r.buf) {
   179  					pool.Put(r.buf[:cap(r.buf)])
   180  					r.buf = nil
   181  				}
   182  				err = ew
   183  				return
   184  			}
   185  		}
   186  
   187  		nr, er := r.read(r.buf)
   188  		if er != nil {
   189  			if er != io.EOF {
   190  				err = er
   191  			}
   192  			return
   193  		}
   194  		r.buf = r.buf[:nr]
   195  		r.off = 0
   196  	}
   197  }
   198  
   199  // increment little-endian encoded unsigned integer b. Wrap around on overflow.
   200  func increment(b []byte) {
   201  	for i := range b {
   202  		b[i]++
   203  		if b[i] != 0 {
   204  			return
   205  		}
   206  	}
   207  }
   208  
   209  type Conn struct {
   210  	net.Conn
   211  	Cipher
   212  	r *Reader
   213  	w *Writer
   214  }
   215  
   216  // NewConn wraps a stream-oriented net.Conn with cipher.
   217  func NewConn(c net.Conn, ciph Cipher) *Conn { return &Conn{Conn: c, Cipher: ciph} }
   218  
   219  func (c *Conn) initReader() error {
   220  	salt := make([]byte, c.SaltSize())
   221  	if _, err := io.ReadFull(c.Conn, salt); err != nil {
   222  		return err
   223  	}
   224  
   225  	aead, err := c.Decrypter(salt)
   226  	if err != nil {
   227  		return err
   228  	}
   229  
   230  	c.r = NewReader(c.Conn, aead)
   231  	return nil
   232  }
   233  
   234  func (c *Conn) Read(b []byte) (int, error) {
   235  	if c.r == nil {
   236  		if err := c.initReader(); err != nil {
   237  			return 0, err
   238  		}
   239  	}
   240  	return c.r.Read(b)
   241  }
   242  
   243  func (c *Conn) WriteTo(w io.Writer) (int64, error) {
   244  	if c.r == nil {
   245  		if err := c.initReader(); err != nil {
   246  			return 0, err
   247  		}
   248  	}
   249  	return c.r.WriteTo(w)
   250  }
   251  
   252  func (c *Conn) initWriter() error {
   253  	salt := make([]byte, c.SaltSize())
   254  	if _, err := rand.Read(salt); err != nil {
   255  		return err
   256  	}
   257  	aead, err := c.Encrypter(salt)
   258  	if err != nil {
   259  		return err
   260  	}
   261  	_, err = c.Conn.Write(salt)
   262  	if err != nil {
   263  		return err
   264  	}
   265  	c.w = NewWriter(c.Conn, aead)
   266  	return nil
   267  }
   268  
   269  func (c *Conn) Write(b []byte) (int, error) {
   270  	if c.w == nil {
   271  		if err := c.initWriter(); err != nil {
   272  			return 0, err
   273  		}
   274  	}
   275  	return c.w.Write(b)
   276  }
   277  
   278  func (c *Conn) ReadFrom(r io.Reader) (int64, error) {
   279  	if c.w == nil {
   280  		if err := c.initWriter(); err != nil {
   281  			return 0, err
   282  		}
   283  	}
   284  	return c.w.ReadFrom(r)
   285  }