github.com/yaling888/clash@v1.53.0/transport/crypto/conn.go (about)

     1  package crypto
     2  
     3  import (
     4  	"crypto/rand"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/yaling888/clash/common/pool"
    13  )
    14  
// AEADOption carries the user-supplied settings for wrapping a connection in
// an AEAD cipher. The three fields are passed verbatim to NewAEAD.
type AEADOption struct {
	// Cipher names the AEAD algorithm. An empty value or "none" (compared
	// case-insensitively) is treated as "no encryption" by StreamAEADConnOrNot.
	Cipher string `proxy:"cipher,omitempty"`
	// Key is the key material handed to NewAEAD.
	Key string `proxy:"key,omitempty"`
	// Salt is the salt handed to NewAEAD.
	Salt string `proxy:"salt,omitempty"`
}
    20  
// Compile-time assertion that *aeadConn satisfies net.Conn.
var _ net.Conn = (*aeadConn)(nil)
    22  
// aeadConn wraps a net.Conn and passes traffic through an AEAD cipher: Write
// seals outgoing data into frames and Read opens incoming frames. All other
// net.Conn methods are served by the embedded connection.
type aeadConn struct {
	net.Conn
	cipher *AEAD // seals outgoing frames and opens incoming ones

	rMux sync.Mutex // guards buf and lasR; taken by Read and Close
	buf  []byte     // read scratch space; after a frame is opened it holds the plaintext
	lasR int        // number of trailing bytes of buf not yet returned to a caller
}
    31  
    32  func (c *aeadConn) Read(p []byte) (n int, err error) {
    33  	c.rMux.Lock()
    34  	defer c.rMux.Unlock()
    35  
    36  	if c.lasR > 0 && c.buf != nil {
    37  		n = copy(p, c.buf[len(c.buf)-c.lasR:])
    38  		c.lasR -= n
    39  		return
    40  	}
    41  
    42  	if c.buf == nil {
    43  		c.buf = make([]byte, 64<<10)
    44  	} else {
    45  		c.buf = c.buf[:64<<10]
    46  	}
    47  
    48  	defer func() {
    49  		if err != nil {
    50  			c.lasR = 0
    51  			c.buf = nil
    52  		}
    53  	}()
    54  
    55  	hdSize := c.cipher.NonceSize() + 2
    56  	_, err = io.ReadFull(c.Conn, c.buf[:hdSize])
    57  	if err != nil {
    58  		return
    59  	}
    60  
    61  	length := binary.BigEndian.Uint16(c.buf[c.cipher.NonceSize():])
    62  	if length == 0 {
    63  		err = io.EOF
    64  		return
    65  	}
    66  
    67  	nonce := make([]byte, c.cipher.NonceSize())
    68  	copy(nonce, c.buf[:c.cipher.NonceSize()])
    69  
    70  	_, err = io.ReadAtLeast(c.Conn, c.buf[:length], int(length))
    71  	if err != nil {
    72  		return
    73  	}
    74  
    75  	b, err := c.cipher.Open(c.buf[:0], nonce, c.buf[:length], nil)
    76  	if err != nil {
    77  		return
    78  	}
    79  
    80  	c.lasR = len(b)
    81  	c.buf = c.buf[:c.lasR]
    82  
    83  	n = copy(p, c.buf[len(c.buf)-c.lasR:])
    84  	c.lasR -= n
    85  	return
    86  }
    87  
    88  func (c *aeadConn) Write(p []byte) (n int, err error) {
    89  	bufP := pool.GetBufferWriter()
    90  	defer pool.PutBufferWriter(bufP)
    91  
    92  	bufP.Grow(c.cipher.NonceSize() + 2 + c.cipher.Overhead() + len(p))
    93  
    94  	nonce := (*bufP)[:c.cipher.NonceSize()]
    95  	if _, err = rand.Read(nonce); err != nil {
    96  		return
    97  	}
    98  
    99  	b := c.cipher.Seal((*bufP)[:c.cipher.NonceSize()+2], nonce, p, nil)
   100  	lenB := len(b)
   101  
   102  	binary.BigEndian.PutUint16(b[c.cipher.NonceSize():], uint16(lenB-c.cipher.NonceSize()-2))
   103  
   104  	lenP := len(p)
   105  	delta := lenB - lenP
   106  	nw, err := c.Conn.Write(b)
   107  	n = max(nw-delta, 0)
   108  	if n < lenP && err == nil {
   109  		err = io.ErrShortWrite
   110  	}
   111  	return
   112  }
   113  
   114  func (c *aeadConn) Close() (err error) {
   115  	err = c.Conn.Close()
   116  
   117  	c.rMux.Lock()
   118  	defer c.rMux.Unlock()
   119  
   120  	c.lasR = 0
   121  	c.buf = nil
   122  	return
   123  }
   124  
   125  func StreamAEADConn(conn net.Conn, opt AEADOption) (net.Conn, error) {
   126  	aead, err := NewAEAD(opt.Cipher, opt.Key, opt.Salt)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	if aead == nil {
   132  		return nil, fmt.Errorf("unsupported cipher: %s", opt.Cipher)
   133  	}
   134  
   135  	return &aeadConn{
   136  		Conn:   conn,
   137  		cipher: aead,
   138  	}, nil
   139  }
   140  
   141  func StreamAEADConnOrNot(conn net.Conn, opt AEADOption) (net.Conn, error) {
   142  	if opt.Cipher == "" || strings.ToLower(opt.Cipher) == "none" {
   143  		return conn, nil
   144  	}
   145  
   146  	return StreamAEADConn(conn, opt)
   147  }
   148  
   149  func VerifyAEADOption(opt AEADOption, allowNone bool) (bool, error) {
   150  	if !allowNone && (opt.Cipher == "" || strings.ToLower(opt.Cipher) == "none" || opt.Key == "") {
   151  		return false, nil
   152  	}
   153  	if _, err := NewAEAD(opt.Cipher, opt.Key, opt.Salt); err != nil {
   154  		return false, err
   155  	}
   156  	return true, nil
   157  }