github.com/yaling888/clash@v1.53.0/transport/vmess/aead.go (about)

     1  package vmess
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"encoding/binary"
     6  	"errors"
     7  	"io"
     8  	"sync"
     9  
    10  	"github.com/yaling888/clash/common/pool"
    11  )
    12  
    13  type aeadWriter struct {
    14  	io.Writer
    15  	cipher.AEAD
    16  	nonce [32]byte
    17  	count uint16
    18  	iv    []byte
    19  
    20  	writeLock sync.Mutex
    21  }
    22  
    23  func newAEADWriter(w io.Writer, aead cipher.AEAD, iv []byte) *aeadWriter {
    24  	return &aeadWriter{Writer: w, AEAD: aead, iv: iv}
    25  }
    26  
    27  func (w *aeadWriter) Write(b []byte) (n int, err error) {
    28  	w.writeLock.Lock()
    29  	bufP := pool.GetNetBuf()
    30  	defer func() {
    31  		w.writeLock.Unlock()
    32  		pool.PutNetBuf(bufP)
    33  	}()
    34  	length := len(b)
    35  	for {
    36  		if length == 0 {
    37  			break
    38  		}
    39  		readLen := chunkSize - w.Overhead()
    40  		if length < readLen {
    41  			readLen = length
    42  		}
    43  		payloadBuf := (*bufP)[lenSize : lenSize+chunkSize-w.Overhead()]
    44  		copy(payloadBuf, b[n:n+readLen])
    45  
    46  		binary.BigEndian.PutUint16((*bufP)[:lenSize], uint16(readLen+w.Overhead()))
    47  		binary.BigEndian.PutUint16(w.nonce[:2], w.count)
    48  		copy(w.nonce[2:], w.iv[2:12])
    49  
    50  		w.Seal(payloadBuf[:0], w.nonce[:w.NonceSize()], payloadBuf[:readLen], nil)
    51  		w.count++
    52  
    53  		_, err = w.Writer.Write((*bufP)[:lenSize+readLen+w.Overhead()])
    54  		if err != nil {
    55  			break
    56  		}
    57  		n += readLen
    58  		length -= readLen
    59  	}
    60  	return
    61  }
    62  
    63  type aeadReader struct {
    64  	io.Reader
    65  	cipher.AEAD
    66  	nonce   [32]byte
    67  	bufP    *[]byte
    68  	offset  int
    69  	iv      []byte
    70  	sizeBuf []byte
    71  	count   uint16
    72  }
    73  
    74  func newAEADReader(r io.Reader, aead cipher.AEAD, iv []byte) *aeadReader {
    75  	return &aeadReader{Reader: r, AEAD: aead, iv: iv, sizeBuf: make([]byte, lenSize)}
    76  }
    77  
    78  func (r *aeadReader) Read(b []byte) (int, error) {
    79  	if r.bufP != nil {
    80  		n := copy(b, (*r.bufP)[r.offset:])
    81  		r.offset += n
    82  		if r.offset == len(*r.bufP) {
    83  			pool.PutNetBuf(r.bufP)
    84  			r.bufP = nil
    85  		}
    86  		return n, nil
    87  	}
    88  
    89  	_, err := io.ReadFull(r.Reader, r.sizeBuf)
    90  	if err != nil {
    91  		return 0, err
    92  	}
    93  
    94  	size := int(binary.BigEndian.Uint16(r.sizeBuf))
    95  	if size > maxSize {
    96  		return 0, errors.New("buffer is larger than standard")
    97  	}
    98  
    99  	bufP := pool.GetNetBuf()
   100  	_, err = io.ReadFull(r.Reader, (*bufP)[:size])
   101  	if err != nil {
   102  		pool.PutNetBuf(bufP)
   103  		return 0, err
   104  	}
   105  
   106  	binary.BigEndian.PutUint16(r.nonce[:2], r.count)
   107  	copy(r.nonce[2:], r.iv[2:12])
   108  
   109  	_, err = r.Open((*bufP)[:0], r.nonce[:r.NonceSize()], (*bufP)[:size], nil)
   110  	r.count++
   111  	if err != nil {
   112  		pool.PutNetBuf(bufP)
   113  		return 0, err
   114  	}
   115  	realLen := size - r.Overhead()
   116  	n := copy(b, (*bufP)[:realLen])
   117  	if len(b) >= realLen {
   118  		pool.PutNetBuf(bufP)
   119  		return n, nil
   120  	}
   121  
   122  	*bufP = (*bufP)[:realLen]
   123  	r.offset = n
   124  	r.bufP = bufP
   125  	return n, nil
   126  }