github.com/yaling888/clash@v1.53.0/transport/shadowsocks/shadowaead/stream.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"crypto/rand"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"sync"
    10  )
    11  
    12  const (
    13  	// payloadSizeMask is the maximum size of payload in bytes.
    14  	payloadSizeMask = 0x3FFF    // 16*1024 - 1
    15  	bufSize         = 17 * 1024 // >= 2+aead.Overhead()+payloadSizeMask+aead.Overhead()
    16  )
    17  
    18  var ErrZeroChunk = errors.New("zero chunk")
    19  
    20  type Writer struct {
    21  	io.Writer
    22  	cipher.AEAD
    23  	nonce [32]byte // should be sufficient for most nonce sizes
    24  }
    25  
    26  // NewWriter wraps an io.Writer with authenticated encryption.
    27  func NewWriter(w io.Writer, aead cipher.AEAD) *Writer { return &Writer{Writer: w, AEAD: aead} }
    28  
    29  var bufPool = sync.Pool{
    30  	New: func() any {
    31  		b := make([]byte, bufSize)
    32  		return &b
    33  	},
    34  }
    35  
    36  // Write encrypts p and writes to the embedded io.Writer.
    37  func (w *Writer) Write(p []byte) (n int, err error) {
    38  	bufP := bufPool.Get().(*[]byte)
    39  	defer bufPool.Put(bufP)
    40  	nonce := w.nonce[:w.NonceSize()]
    41  	tag := w.Overhead()
    42  	off := 2 + tag
    43  
    44  	// compatible with snell
    45  	if len(p) == 0 {
    46  		(*bufP)[0], (*bufP)[1] = byte(0), byte(0)
    47  		w.Seal((*bufP)[:0], nonce, (*bufP)[:2], nil)
    48  		increment(nonce)
    49  		_, err = w.Writer.Write((*bufP)[:off])
    50  		return
    51  	}
    52  
    53  	for nr := 0; n < len(p) && err == nil; n += nr {
    54  		nr = payloadSizeMask
    55  		if n+nr > len(p) {
    56  			nr = len(p) - n
    57  		}
    58  		(*bufP)[0], (*bufP)[1] = byte(nr>>8), byte(nr) // big-endian payload size
    59  		w.Seal((*bufP)[:0], nonce, (*bufP)[:2], nil)
    60  		increment(nonce)
    61  		w.Seal((*bufP)[:off], nonce, p[n:n+nr], nil)
    62  		increment(nonce)
    63  		_, err = w.Writer.Write((*bufP)[:off+nr+tag])
    64  	}
    65  	return
    66  }
    67  
    68  // ReadFrom reads from the given io.Reader until EOF or error, encrypts and
    69  // writes to the embedded io.Writer. Returns number of bytes read from r and
    70  // any error encountered.
    71  func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
    72  	bufP := bufPool.Get().(*[]byte)
    73  	defer bufPool.Put(bufP)
    74  	nonce := w.nonce[:w.NonceSize()]
    75  	tag := w.Overhead()
    76  	off := 2 + tag
    77  	for {
    78  		nr, er := r.Read((*bufP)[off : off+payloadSizeMask])
    79  		n += int64(nr)
    80  		(*bufP)[0], (*bufP)[1] = byte(nr>>8), byte(nr)
    81  		w.Seal((*bufP)[:0], nonce, (*bufP)[:2], nil)
    82  		increment(nonce)
    83  		w.Seal((*bufP)[:off], nonce, (*bufP)[off:off+nr], nil)
    84  		increment(nonce)
    85  		if _, ew := w.Writer.Write((*bufP)[:off+nr+tag]); ew != nil {
    86  			err = ew
    87  			return
    88  		}
    89  		if er != nil {
    90  			if er != io.EOF { // ignore EOF as per io.ReaderFrom contract
    91  				err = er
    92  			}
    93  			return
    94  		}
    95  	}
    96  }
    97  
    98  type Reader struct {
    99  	io.Reader
   100  	cipher.AEAD
   101  	nonce [32]byte // should be sufficient for most nonce sizes
   102  	bufP  *[]byte  // to be put back into bufPool
   103  	off   int      // offset to unconsumed part of buf
   104  }
   105  
   106  // NewReader wraps an io.Reader with authenticated decryption.
   107  func NewReader(r io.Reader, aead cipher.AEAD) *Reader { return &Reader{Reader: r, AEAD: aead} }
   108  
   109  // Read and decrypt a record into p. len(p) >= max payload size + AEAD overhead.
   110  func (r *Reader) read(p []byte) (int, error) {
   111  	nonce := r.nonce[:r.NonceSize()]
   112  	tag := r.Overhead()
   113  
   114  	// decrypt payload size
   115  	p = p[:2+tag]
   116  	if _, err := io.ReadFull(r.Reader, p); err != nil {
   117  		return 0, err
   118  	}
   119  	_, err := r.Open(p[:0], nonce, p, nil)
   120  	increment(nonce)
   121  	if err != nil {
   122  		return 0, err
   123  	}
   124  
   125  	// decrypt payload
   126  	size := (int(p[0])<<8 + int(p[1])) & payloadSizeMask
   127  	if size == 0 {
   128  		return 0, ErrZeroChunk
   129  	}
   130  
   131  	p = p[:size+tag]
   132  	if _, err := io.ReadFull(r.Reader, p); err != nil {
   133  		return 0, err
   134  	}
   135  	_, err = r.Open(p[:0], nonce, p, nil)
   136  	increment(nonce)
   137  	if err != nil {
   138  		return 0, err
   139  	}
   140  	return size, nil
   141  }
   142  
   143  // Read reads from the embedded io.Reader, decrypts and writes to p.
   144  func (r *Reader) Read(p []byte) (int, error) {
   145  	if r.bufP == nil {
   146  		if len(p) >= payloadSizeMask+r.Overhead() {
   147  			return r.read(p)
   148  		}
   149  		bp := bufPool.Get().(*[]byte)
   150  		n, err := r.read(*bp)
   151  		if err != nil {
   152  			return 0, err
   153  		}
   154  		*bp = (*bp)[:n]
   155  		r.bufP = bp
   156  		r.off = 0
   157  	}
   158  
   159  	n := copy(p, (*r.bufP)[r.off:])
   160  	r.off += n
   161  	if r.off == len(*r.bufP) {
   162  		*r.bufP = (*r.bufP)[:bufSize]
   163  		bufPool.Put(r.bufP)
   164  		r.bufP = nil
   165  	}
   166  	return n, nil
   167  }
   168  
   169  // WriteTo reads from the embedded io.Reader, decrypts and writes to w until
   170  // there's no more data to write or when an error occurs. Return number of
   171  // bytes written to w and any error encountered.
   172  func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
   173  	if r.bufP == nil {
   174  		r.bufP = bufPool.Get().(*[]byte)
   175  		r.off = len(*r.bufP)
   176  	}
   177  
   178  	for {
   179  		for r.off < len(*r.bufP) {
   180  			nw, ew := w.Write((*r.bufP)[r.off:])
   181  			r.off += nw
   182  			n += int64(nw)
   183  			if ew != nil {
   184  				if r.off == len(*r.bufP) {
   185  					*r.bufP = (*r.bufP)[:bufSize]
   186  					bufPool.Put(r.bufP)
   187  					r.bufP = nil
   188  				}
   189  				err = ew
   190  				return
   191  			}
   192  		}
   193  
   194  		nr, er := r.read(*r.bufP)
   195  		if er != nil {
   196  			if er != io.EOF {
   197  				err = er
   198  			}
   199  			return
   200  		}
   201  		*r.bufP = (*r.bufP)[:nr]
   202  		r.off = 0
   203  	}
   204  }
   205  
   206  // increment little-endian encoded unsigned integer b. Wrap around on overflow.
   207  func increment(b []byte) {
   208  	for i := range b {
   209  		b[i]++
   210  		if b[i] != 0 {
   211  			return
   212  		}
   213  	}
   214  }
   215  
   216  type Conn struct {
   217  	net.Conn
   218  	Cipher
   219  	r *Reader
   220  	w *Writer
   221  }
   222  
   223  // NewConn wraps a stream-oriented net.Conn with cipher.
   224  func NewConn(c net.Conn, ciph Cipher) *Conn { return &Conn{Conn: c, Cipher: ciph} }
   225  
   226  func (c *Conn) initReader() error {
   227  	salt := make([]byte, c.SaltSize())
   228  	if _, err := io.ReadFull(c.Conn, salt); err != nil {
   229  		return err
   230  	}
   231  
   232  	aead, err := c.Decrypter(salt)
   233  	if err != nil {
   234  		return err
   235  	}
   236  
   237  	c.r = NewReader(c.Conn, aead)
   238  	return nil
   239  }
   240  
   241  func (c *Conn) Read(b []byte) (int, error) {
   242  	if c.r == nil {
   243  		if err := c.initReader(); err != nil {
   244  			return 0, err
   245  		}
   246  	}
   247  	return c.r.Read(b)
   248  }
   249  
   250  func (c *Conn) WriteTo(w io.Writer) (int64, error) {
   251  	if c.r == nil {
   252  		if err := c.initReader(); err != nil {
   253  			return 0, err
   254  		}
   255  	}
   256  	return c.r.WriteTo(w)
   257  }
   258  
   259  func (c *Conn) initWriter() error {
   260  	salt := make([]byte, c.SaltSize())
   261  	if _, err := rand.Read(salt); err != nil {
   262  		return err
   263  	}
   264  	aead, err := c.Encrypter(salt)
   265  	if err != nil {
   266  		return err
   267  	}
   268  	_, err = c.Conn.Write(salt)
   269  	if err != nil {
   270  		return err
   271  	}
   272  	c.w = NewWriter(c.Conn, aead)
   273  	return nil
   274  }
   275  
   276  func (c *Conn) Write(b []byte) (int, error) {
   277  	if c.w == nil {
   278  		if err := c.initWriter(); err != nil {
   279  			return 0, err
   280  		}
   281  	}
   282  	return c.w.Write(b)
   283  }
   284  
   285  func (c *Conn) ReadFrom(r io.Reader) (int64, error) {
   286  	if c.w == nil {
   287  		if err := c.initWriter(); err != nil {
   288  			return 0, err
   289  		}
   290  	}
   291  	return c.w.ReadFrom(r)
   292  }