github.com/kelleygo/clashcore@v1.0.2/transport/shadowsocks/shadowstream/packet.go (about)

     1  package shadowstream
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  
     9  	N "github.com/kelleygo/clashcore/common/net"
    10  	"github.com/kelleygo/clashcore/common/pool"
    11  )
    12  
    13  // ErrShortPacket means the packet is too short to be a valid encrypted packet.
    14  var ErrShortPacket = errors.New("short packet")
    15  
    16  // Pack encrypts plaintext using stream cipher s and a random IV.
    17  // Returns a slice of dst containing random IV and ciphertext.
    18  // Ensure len(dst) >= s.IVSize() + len(plaintext).
    19  func Pack(dst, plaintext []byte, s Cipher) ([]byte, error) {
    20  	if len(dst) < s.IVSize()+len(plaintext) {
    21  		return nil, io.ErrShortBuffer
    22  	}
    23  	iv := dst[:s.IVSize()]
    24  	_, err := rand.Read(iv)
    25  	if err != nil {
    26  		return nil, err
    27  	}
    28  	s.Encrypter(iv).XORKeyStream(dst[len(iv):], plaintext)
    29  	return dst[:len(iv)+len(plaintext)], nil
    30  }
    31  
    32  // Unpack decrypts pkt using stream cipher s.
    33  // Returns a slice of dst containing decrypted plaintext.
    34  func Unpack(dst, pkt []byte, s Cipher) ([]byte, error) {
    35  	if len(pkt) < s.IVSize() {
    36  		return nil, ErrShortPacket
    37  	}
    38  	if len(dst) < len(pkt)-s.IVSize() {
    39  		return nil, io.ErrShortBuffer
    40  	}
    41  	iv := pkt[:s.IVSize()]
    42  	s.Decrypter(iv).XORKeyStream(dst, pkt[len(iv):])
    43  	return dst[:len(pkt)-len(iv)], nil
    44  }
    45  
    46  type PacketConn struct {
    47  	N.EnhancePacketConn
    48  	Cipher
    49  }
    50  
    51  // NewPacketConn wraps an N.EnhancePacketConn with stream cipher encryption/decryption.
    52  func NewPacketConn(c N.EnhancePacketConn, ciph Cipher) *PacketConn {
    53  	return &PacketConn{EnhancePacketConn: c, Cipher: ciph}
    54  }
    55  
    56  const maxPacketSize = 64 * 1024
    57  
    58  func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
    59  	buf := pool.Get(maxPacketSize)
    60  	defer pool.Put(buf)
    61  	buf, err := Pack(buf, b, c.Cipher)
    62  	if err != nil {
    63  		return 0, err
    64  	}
    65  	_, err = c.EnhancePacketConn.WriteTo(buf, addr)
    66  	return len(b), err
    67  }
    68  
    69  func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
    70  	n, addr, err := c.EnhancePacketConn.ReadFrom(b)
    71  	if err != nil {
    72  		return n, addr, err
    73  	}
    74  	bb, err := Unpack(b[c.IVSize():], b[:n], c.Cipher)
    75  	if err != nil {
    76  		return n, addr, err
    77  	}
    78  	copy(b, bb)
    79  	return len(bb), addr, err
    80  }
    81  
    82  func (c *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
    83  	data, put, addr, err = c.EnhancePacketConn.WaitReadFrom()
    84  	if err != nil {
    85  		return
    86  	}
    87  	data, err = Unpack(data[c.IVSize():], data, c)
    88  	if err != nil {
    89  		if put != nil {
    90  			put()
    91  		}
    92  		data = nil
    93  		put = nil
    94  		return
    95  	}
    96  	return
    97  }