github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/yuubinsya/crypto/aead.go (about)

     1  package crypto
     2  
import (
	"crypto/cipher"
	"encoding/binary"
	"errors"
	"io"
	"net"
	"sync"

	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
	"golang.org/x/crypto/chacha20poly1305"
)
    14  
    15  var Chacha20poly1305 = chacha20poly1305Aead{}
    16  
    17  type chacha20poly1305Aead struct{}
    18  
    19  func (chacha20poly1305Aead) New(key []byte) (cipher.AEAD, error) { return chacha20poly1305.New(key) }
    20  func (chacha20poly1305Aead) KeySize() int                        { return chacha20poly1305.KeySize }
    21  func (chacha20poly1305Aead) NonceSize() int                      { return chacha20poly1305.NonceSize }
    22  func (chacha20poly1305Aead) Name() []byte                        { return []byte("chacha20poly1305-key") }
    23  
    24  type streamConn struct {
    25  	net.Conn
    26  	r io.Reader
    27  	w io.Writer
    28  }
    29  
    30  func (c *streamConn) Read(b []byte) (int, error)  { return c.r.Read(b) }
    31  func (c *streamConn) Write(b []byte) (int, error) { return c.w.Write(b) }
    32  
    33  // NewConn wraps a stream-oriented net.Conn with cipher.
    34  func NewConn(c net.Conn, rnonce, wnonce []byte, rciph, wciph cipher.AEAD) net.Conn {
    35  	return &streamConn{
    36  		Conn: c,
    37  		r:    NewReader(c, rnonce, rciph, nat.MaxSegmentSize),
    38  		w:    NewWriter(c, wnonce, wciph, nat.MaxSegmentSize),
    39  	}
    40  }
    41  
    42  type writer struct {
    43  	io.Writer
    44  	cipher.AEAD
    45  	nonce          []byte
    46  	maxPayloadSize int
    47  
    48  	mu sync.Mutex
    49  }
    50  
    51  // NewWriter wraps an io.Writer with AEAD encryption.
    52  
    53  func NewWriter(w io.Writer, nonce []byte, aead cipher.AEAD, maxPayloadSize int) *writer {
    54  	return &writer{
    55  		Writer:         w,
    56  		AEAD:           aead,
    57  		nonce:          nonce,
    58  		maxPayloadSize: maxPayloadSize,
    59  	}
    60  }
    61  
    62  func (w *writer) Write(p []byte) (n int, err error) {
    63  	if len(p) == 0 {
    64  		return
    65  	}
    66  
    67  	buf := pool.GetBytes(2 + w.AEAD.Overhead() + w.maxPayloadSize + w.AEAD.Overhead())
    68  	defer pool.PutBytes(buf)
    69  
    70  	for pLen := len(p); pLen > 0; {
    71  		var data []byte
    72  		if pLen > w.maxPayloadSize {
    73  			data = p[:w.maxPayloadSize]
    74  			p = p[w.maxPayloadSize:]
    75  			pLen -= w.maxPayloadSize
    76  		} else {
    77  			data = p
    78  			pLen = 0
    79  		}
    80  		binary.BigEndian.PutUint16(buf[:2], uint16(len(data)))
    81  		w.mu.Lock()
    82  		w.Seal(buf[:0], w.nonce, buf[:2], nil)
    83  		increment(w.nonce)
    84  		offset := w.Overhead() + 2
    85  		packet := w.Seal(buf[offset:offset], w.nonce, data, nil)
    86  		increment(w.nonce)
    87  		_, err = w.Writer.Write(buf[:offset+len(packet)])
    88  		w.mu.Unlock()
    89  		if err != nil {
    90  			return
    91  		}
    92  		n += len(data)
    93  	}
    94  
    95  	return
    96  }
    97  
    98  type reader struct {
    99  	io.Reader
   100  	cipher.AEAD
   101  	nonce    []byte
   102  	buf      []byte
   103  	leftover []byte
   104  
   105  	mu sync.Mutex
   106  }
   107  
   108  func NewReader(r io.Reader, nonce []byte, aead cipher.AEAD, maxPayloadSize int) *reader {
   109  	return &reader{
   110  		Reader: r,
   111  		AEAD:   aead,
   112  		buf:    make([]byte, maxPayloadSize+aead.Overhead()),
   113  		nonce:  nonce,
   114  	}
   115  }
   116  
   117  // read and decrypt a record into the internal buffer. Return decrypted payload length and any error encountered.
   118  func (r *reader) read() (int, error) {
   119  	// decrypt payload size
   120  	buf := r.buf[:2+r.Overhead()]
   121  	_, err := io.ReadFull(r.Reader, buf)
   122  	if err != nil {
   123  		return 0, err
   124  	}
   125  
   126  	_, err = r.Open(buf[:0], r.nonce, buf, nil)
   127  	increment(r.nonce)
   128  	if err != nil {
   129  		return 0, err
   130  	}
   131  
   132  	size := int(binary.BigEndian.Uint16(buf[:2]))
   133  
   134  	// decrypt payload
   135  	buf = r.buf[:size+r.Overhead()]
   136  	_, err = io.ReadFull(r.Reader, buf)
   137  	if err != nil {
   138  		return 0, err
   139  	}
   140  
   141  	_, err = r.Open(buf[:0], r.nonce, buf, nil)
   142  	increment(r.nonce)
   143  	if err != nil {
   144  		return 0, err
   145  	}
   146  
   147  	return size, nil
   148  }
   149  
   150  // Read reads from the embedded io.Reader, decrypts and writes to b.
   151  func (r *reader) Read(b []byte) (int, error) {
   152  	r.mu.Lock()
   153  	defer r.mu.Unlock()
   154  
   155  	// copy decrypted bytes (if any) from previous record first
   156  	if len(r.leftover) > 0 {
   157  		n := copy(b, r.leftover)
   158  		r.leftover = r.leftover[n:]
   159  		return n, nil
   160  	}
   161  
   162  	n, err := r.read()
   163  
   164  	m := copy(b, r.buf[:n])
   165  	if m < n { // insufficient len(b), keep leftover for next read
   166  		r.leftover = r.buf[m:n]
   167  	}
   168  	return m, err
   169  }
   170  
   171  // increment little-endian encoded unsigned integer b. Wrap around on overflow.
   172  func increment(b []byte) {
   173  	for i := range b {
   174  		b[i]++
   175  		if b[i] != 0 {
   176  			return
   177  		}
   178  	}
   179  }