github.com/yaling888/clash@v1.53.0/transport/shadowsocks/shadowaead/packet.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  
     9  	"github.com/yaling888/clash/common/pool"
    10  )
    11  
    12  // ErrShortPacket means that the packet is too short for a valid encrypted packet.
    13  var ErrShortPacket = errors.New("short packet")
    14  
    15  var _zerononce [128]byte // read-only. 128 bytes is more than enough.
    16  
    17  // Pack encrypts plaintext using Cipher with a randomly generated salt and
    18  // returns a slice of dst containing the encrypted packet and any error occurred.
    19  // Ensure len(dst) >= ciph.SaltSize() + len(plaintext) + aead.Overhead().
    20  func Pack(dst *[]byte, plaintext []byte, ciph Cipher) error {
    21  	saltSize := ciph.SaltSize()
    22  	salt := (*dst)[:saltSize]
    23  	if _, err := rand.Read(salt); err != nil {
    24  		return err
    25  	}
    26  	aead, err := ciph.Encrypter(salt)
    27  	if err != nil {
    28  		return err
    29  	}
    30  	if len(*dst) < saltSize+len(plaintext)+aead.Overhead() {
    31  		return io.ErrShortBuffer
    32  	}
    33  	b := aead.Seal((*dst)[saltSize:saltSize], _zerononce[:aead.NonceSize()], plaintext, nil)
    34  	*dst = (*dst)[:saltSize+len(b)]
    35  	return nil
    36  }
    37  
    38  // Unpack decrypts pkt using Cipher and returns a slice of dst containing the decrypted payload and any error occurred.
    39  // Ensure len(dst) >= len(pkt) - aead.SaltSize() - aead.Overhead().
    40  func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) {
    41  	saltSize := ciph.SaltSize()
    42  	if len(pkt) < saltSize {
    43  		return nil, ErrShortPacket
    44  	}
    45  	salt := pkt[:saltSize]
    46  	aead, err := ciph.Decrypter(salt)
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  	if len(pkt) < saltSize+aead.Overhead() {
    51  		return nil, ErrShortPacket
    52  	}
    53  	if saltSize+len(dst)+aead.Overhead() < len(pkt) {
    54  		return nil, io.ErrShortBuffer
    55  	}
    56  	b, err := aead.Open(dst[:0], _zerononce[:aead.NonceSize()], pkt[saltSize:], nil)
    57  	return b, err
    58  }
    59  
    60  type PacketConn struct {
    61  	net.PacketConn
    62  	Cipher
    63  }
    64  
    65  // const maxPacketSize = 64 * 1024
    66  
    67  // NewPacketConn wraps a net.PacketConn with cipher
    68  func NewPacketConn(c net.PacketConn, ciph Cipher) *PacketConn {
    69  	return &PacketConn{PacketConn: c, Cipher: ciph}
    70  }
    71  
    72  // WriteTo encrypts b and write to addr using the embedded PacketConn.
    73  func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
    74  	bufP := pool.GetNetBuf()
    75  	defer pool.PutNetBuf(bufP)
    76  	err := Pack(bufP, b, c)
    77  	if err != nil {
    78  		return 0, err
    79  	}
    80  	_, err = c.PacketConn.WriteTo(*bufP, addr)
    81  	return len(b), err
    82  }
    83  
    84  // ReadFrom reads from the embedded PacketConn and decrypts into b.
    85  func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
    86  	n, addr, err := c.PacketConn.ReadFrom(b)
    87  	if err != nil {
    88  		return n, addr, err
    89  	}
    90  	bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c)
    91  	if err != nil {
    92  		return n, addr, err
    93  	}
    94  	copy(b, bb)
    95  	return len(bb), addr, err
    96  }