github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/vmess/aead.go (about)

     1  package vmess
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/cipher"
     6  	"encoding/binary"
     7  	"io"
     8  
     9  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    10  )
    11  
    12  var _ io.WriteCloser = &aeadWriter{}
    13  
    14  type aeadWriter struct {
    15  	io.Writer
    16  	cipher.AEAD
    17  	nonce []byte
    18  	buf   [lenSize + maxChunkSize]byte
    19  	count uint16
    20  	iv    []byte
    21  }
    22  
    23  // AEADWriter returns a aead writer
    24  func AEADWriter(w io.Writer, aead cipher.AEAD, iv []byte) writer {
    25  	return &aeadWriter{
    26  		Writer: w,
    27  		AEAD:   aead,
    28  		nonce:  make([]byte, aead.NonceSize()),
    29  		count:  0,
    30  		iv:     iv,
    31  	}
    32  }
    33  
    34  func (w *aeadWriter) Close() error { return nil }
    35  
    36  func (w *aeadWriter) Write(b []byte) (int, error) {
    37  	n, err := w.ReadFrom(bytes.NewBuffer(b))
    38  	return int(n), err
    39  }
    40  
    41  func (w *aeadWriter) ReadFrom(r io.Reader) (n int64, err error) {
    42  	buf := w.buf[:]
    43  	for {
    44  		payloadBuf := w.buf[lenSize : lenSize+defaultChunkSize-w.Overhead()]
    45  
    46  		nr, er := r.Read(payloadBuf)
    47  		if nr > 0 {
    48  			n += int64(nr)
    49  			buf = buf[:lenSize+nr+w.Overhead()]
    50  			payloadBuf = payloadBuf[:nr]
    51  			binary.BigEndian.PutUint16(w.buf[:lenSize], uint16(nr+w.Overhead()))
    52  
    53  			binary.BigEndian.PutUint16(w.nonce[:2], w.count)
    54  			copy(w.nonce[2:], w.iv[2:12])
    55  
    56  			w.Seal(payloadBuf[:0], w.nonce[:w.NonceSize()], payloadBuf, nil)
    57  			w.count++
    58  
    59  			_, ew := w.Writer.Write(buf)
    60  			if ew != nil {
    61  				err = ew
    62  				break
    63  			}
    64  		}
    65  
    66  		if er != nil {
    67  			if er != io.EOF { // ignore EOF as per io.ReaderFrom contract
    68  				err = er
    69  			}
    70  			break
    71  		}
    72  	}
    73  
    74  	return n, err
    75  }
    76  
    77  var _ io.ReadCloser = &aeadReader{}
    78  
    79  type aeadReader struct {
    80  	io.Reader
    81  	cipher.AEAD
    82  	count uint16
    83  	iv    []byte
    84  
    85  	decrypted bytes.Buffer
    86  }
    87  
    88  // AEADReader returns a aead reader
    89  func AEADReader(r io.Reader, aead cipher.AEAD, iv []byte) io.ReadCloser {
    90  	return &aeadReader{
    91  		Reader: r,
    92  		AEAD:   aead,
    93  		count:  0,
    94  		iv:     iv,
    95  	}
    96  }
    97  
    98  func (r *aeadReader) Close() error { return nil }
    99  
   100  func (r *aeadReader) Read(b []byte) (int, error) {
   101  	if r.decrypted.Len() > 0 {
   102  		return r.decrypted.Read(b)
   103  	}
   104  
   105  	lb := pool.GetBytes(r.NonceSize())
   106  	defer pool.PutBytes(lb)
   107  
   108  	// get length
   109  	_, err := io.ReadFull(r.Reader, lb[:lenSize])
   110  	if err != nil {
   111  		return 0, err
   112  	}
   113  
   114  	// if length == 0, then this is the end
   115  	l := binary.BigEndian.Uint16(lb[:lenSize])
   116  	if l == 0 {
   117  		return 0, nil
   118  	}
   119  
   120  	buf := pool.GetBytes(int(l))
   121  	defer pool.PutBytes(buf)
   122  	// get payload
   123  	_, err = io.ReadFull(r.Reader, buf[:l])
   124  	if err != nil {
   125  		return 0, err
   126  	}
   127  
   128  	binary.BigEndian.PutUint16(lb[:2], r.count)
   129  	copy(lb[2:], r.iv[2:12])
   130  
   131  	_, err = r.Open(buf[:0], lb[:r.NonceSize()], buf[:l], nil)
   132  	r.count++
   133  	if err != nil {
   134  		return 0, err
   135  	}
   136  
   137  	r.decrypted.Write(buf[:int(l)-r.Overhead()])
   138  	return r.decrypted.Read(b)
   139  }