github.com/igoogolx/clash@v1.19.8/transport/shadowsocks/shadowaead/packet.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  
     9  	"github.com/igoogolx/clash/common/pool"
    10  )
    11  
    12  // ErrShortPacket means that the packet is too short for a valid encrypted packet.
    13  var ErrShortPacket = errors.New("short packet")
    14  
    15  var _zerononce [128]byte // read-only. 128 bytes is more than enough.
    16  
    17  // Pack encrypts plaintext using Cipher with a randomly generated salt and
    18  // returns a slice of dst containing the encrypted packet and any error occurred.
    19  // Ensure len(dst) >= ciph.SaltSize() + len(plaintext) + aead.Overhead().
    20  func Pack(dst, plaintext []byte, ciph Cipher) ([]byte, error) {
    21  	saltSize := ciph.SaltSize()
    22  	salt := dst[:saltSize]
    23  	if _, err := rand.Read(salt); err != nil {
    24  		return nil, err
    25  	}
    26  	aead, err := ciph.Encrypter(salt)
    27  	if err != nil {
    28  		return nil, err
    29  	}
    30  	if len(dst) < saltSize+len(plaintext)+aead.Overhead() {
    31  		return nil, io.ErrShortBuffer
    32  	}
    33  	b := aead.Seal(dst[saltSize:saltSize], _zerononce[:aead.NonceSize()], plaintext, nil)
    34  	return dst[:saltSize+len(b)], nil
    35  }
    36  
    37  // Unpack decrypts pkt using Cipher and returns a slice of dst containing the decrypted payload and any error occurred.
    38  // Ensure len(dst) >= len(pkt) - aead.SaltSize() - aead.Overhead().
    39  func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) {
    40  	saltSize := ciph.SaltSize()
    41  	if len(pkt) < saltSize {
    42  		return nil, ErrShortPacket
    43  	}
    44  	salt := pkt[:saltSize]
    45  	aead, err := ciph.Decrypter(salt)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	if len(pkt) < saltSize+aead.Overhead() {
    50  		return nil, ErrShortPacket
    51  	}
    52  	if saltSize+len(dst)+aead.Overhead() < len(pkt) {
    53  		return nil, io.ErrShortBuffer
    54  	}
    55  	b, err := aead.Open(dst[:0], _zerononce[:aead.NonceSize()], pkt[saltSize:], nil)
    56  	return b, err
    57  }
    58  
// PacketConn is a net.PacketConn whose WriteTo and ReadFrom methods encrypt and
// decrypt each packet with the embedded Cipher.
type PacketConn struct {
	net.PacketConn // underlying connection that carries the encrypted packets
	Cipher         // passed to Pack and Unpack to encrypt and decrypt packets
}
    63  
// maxPacketSize is the size in bytes (64 KiB) of the scratch buffer that
// WriteTo requests from the pool to hold one encrypted outgoing packet.
const maxPacketSize = 64 * 1024
    65  
    66  // NewPacketConn wraps a net.PacketConn with cipher
    67  func NewPacketConn(c net.PacketConn, ciph Cipher) *PacketConn {
    68  	return &PacketConn{PacketConn: c, Cipher: ciph}
    69  }
    70  
    71  // WriteTo encrypts b and write to addr using the embedded PacketConn.
    72  func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
    73  	buf := pool.Get(maxPacketSize)
    74  	defer pool.Put(buf)
    75  	buf, err := Pack(buf, b, c)
    76  	if err != nil {
    77  		return 0, err
    78  	}
    79  	_, err = c.PacketConn.WriteTo(buf, addr)
    80  	return len(b), err
    81  }
    82  
    83  // ReadFrom reads from the embedded PacketConn and decrypts into b.
    84  func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
    85  	n, addr, err := c.PacketConn.ReadFrom(b)
    86  	if err != nil {
    87  		return n, addr, err
    88  	}
    89  	bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c)
    90  	if err != nil {
    91  		return n, addr, err
    92  	}
    93  	copy(b, bb)
    94  	return len(bb), addr, err
    95  }