github.com/yaling888/clash@v1.53.0/transport/shadowsocks/shadowstream/packet.go (about)

     1  package shadowstream
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  
     9  	"github.com/yaling888/clash/common/pool"
    10  )
    11  
    12  // ErrShortPacket means the packet is too short to be a valid encrypted packet.
    13  var ErrShortPacket = errors.New("short packet")
    14  
    15  // Pack encrypts plaintext using stream cipher s and a random IV.
    16  // Returns a slice of dst containing random IV and ciphertext.
    17  // Ensure len(dst) >= s.IVSize() + len(plaintext).
    18  func Pack(dst *[]byte, plaintext []byte, s Cipher) error {
    19  	if len(*dst) < s.IVSize()+len(plaintext) {
    20  		return io.ErrShortBuffer
    21  	}
    22  	iv := (*dst)[:s.IVSize()]
    23  	_, err := rand.Read(iv)
    24  	if err != nil {
    25  		return err
    26  	}
    27  	s.Encrypter(iv).XORKeyStream((*dst)[len(iv):], plaintext)
    28  	*dst = (*dst)[:len(iv)+len(plaintext)]
    29  	return nil
    30  }
    31  
    32  // Unpack decrypts pkt using stream cipher s.
    33  // Returns a slice of dst containing decrypted plaintext.
    34  func Unpack(dst, pkt []byte, s Cipher) ([]byte, error) {
    35  	if len(pkt) < s.IVSize() {
    36  		return nil, ErrShortPacket
    37  	}
    38  	if len(dst) < len(pkt)-s.IVSize() {
    39  		return nil, io.ErrShortBuffer
    40  	}
    41  	iv := pkt[:s.IVSize()]
    42  	s.Decrypter(iv).XORKeyStream(dst, pkt[len(iv):])
    43  	return dst[:len(pkt)-len(iv)], nil
    44  }
    45  
    46  type PacketConn struct {
    47  	net.PacketConn
    48  	Cipher
    49  }
    50  
    51  // NewPacketConn wraps a net.PacketConn with stream cipher encryption/decryption.
    52  func NewPacketConn(c net.PacketConn, ciph Cipher) *PacketConn {
    53  	return &PacketConn{PacketConn: c, Cipher: ciph}
    54  }
    55  
    56  // const maxPacketSize = 64 * 1024
    57  
    58  func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
    59  	bufP := pool.GetNetBuf()
    60  	defer pool.PutNetBuf(bufP)
    61  	err := Pack(bufP, b, c)
    62  	if err != nil {
    63  		return 0, err
    64  	}
    65  	_, err = c.PacketConn.WriteTo(*bufP, addr)
    66  	return len(b), err
    67  }
    68  
    69  func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
    70  	n, addr, err := c.PacketConn.ReadFrom(b)
    71  	if err != nil {
    72  		return n, addr, err
    73  	}
    74  	bb, err := Unpack(b[c.IVSize():], b[:n], c.Cipher)
    75  	if err != nil {
    76  		return n, addr, err
    77  	}
    78  	copy(b, bb)
    79  	return len(bb), addr, err
    80  }