github.com/metacubex/mihomo@v1.18.5/transport/shadowsocks/shadowaead/packet.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  
     9  	N "github.com/metacubex/mihomo/common/net"
    10  	"github.com/metacubex/mihomo/common/pool"
    11  )
    12  
    13  // ErrShortPacket means that the packet is too short for a valid encrypted packet.
    14  var ErrShortPacket = errors.New("short packet")
    15  
    16  var _zerononce [128]byte // read-only. 128 bytes is more than enough.
    17  
    18  // Pack encrypts plaintext using Cipher with a randomly generated salt and
    19  // returns a slice of dst containing the encrypted packet and any error occurred.
    20  // Ensure len(dst) >= ciph.SaltSize() + len(plaintext) + aead.Overhead().
    21  func Pack(dst, plaintext []byte, ciph Cipher) ([]byte, error) {
    22  	saltSize := ciph.SaltSize()
    23  	salt := dst[:saltSize]
    24  	if _, err := rand.Read(salt); err != nil {
    25  		return nil, err
    26  	}
    27  	aead, err := ciph.Encrypter(salt)
    28  	if err != nil {
    29  		return nil, err
    30  	}
    31  	if len(dst) < saltSize+len(plaintext)+aead.Overhead() {
    32  		return nil, io.ErrShortBuffer
    33  	}
    34  	b := aead.Seal(dst[saltSize:saltSize], _zerononce[:aead.NonceSize()], plaintext, nil)
    35  	return dst[:saltSize+len(b)], nil
    36  }
    37  
    38  // Unpack decrypts pkt using Cipher and returns a slice of dst containing the decrypted payload and any error occurred.
    39  // Ensure len(dst) >= len(pkt) - aead.SaltSize() - aead.Overhead().
    40  func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) {
    41  	saltSize := ciph.SaltSize()
    42  	if len(pkt) < saltSize {
    43  		return nil, ErrShortPacket
    44  	}
    45  	salt := pkt[:saltSize]
    46  	aead, err := ciph.Decrypter(salt)
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  	if len(pkt) < saltSize+aead.Overhead() {
    51  		return nil, ErrShortPacket
    52  	}
    53  	if saltSize+len(dst)+aead.Overhead() < len(pkt) {
    54  		return nil, io.ErrShortBuffer
    55  	}
    56  	b, err := aead.Open(dst[:0], _zerononce[:aead.NonceSize()], pkt[saltSize:], nil)
    57  	return b, err
    58  }
    59  
// PacketConn is a packet connection that encrypts outgoing datagrams and
// decrypts incoming ones using the embedded Cipher. Methods it does not
// override are promoted unchanged from the embedded N.EnhancePacketConn.
type PacketConn struct {
	// EnhancePacketConn is the underlying connection that carries the
	// encrypted datagrams.
	N.EnhancePacketConn
	// Cipher supplies the salt size and builds the per-packet AEADs
	// (Encrypter/Decrypter) used by Pack and Unpack.
	Cipher
}
    64  
    65  const maxPacketSize = 64 * 1024
    66  
    67  // NewPacketConn wraps an N.EnhancePacketConn with cipher
    68  func NewPacketConn(c N.EnhancePacketConn, ciph Cipher) *PacketConn {
    69  	return &PacketConn{EnhancePacketConn: c, Cipher: ciph}
    70  }
    71  
    72  // WriteTo encrypts b and write to addr using the embedded PacketConn.
    73  func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
    74  	buf := pool.Get(maxPacketSize)
    75  	defer pool.Put(buf)
    76  	buf, err := Pack(buf, b, c)
    77  	if err != nil {
    78  		return 0, err
    79  	}
    80  	_, err = c.EnhancePacketConn.WriteTo(buf, addr)
    81  	return len(b), err
    82  }
    83  
    84  // ReadFrom reads from the embedded PacketConn and decrypts into b.
    85  func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
    86  	n, addr, err := c.EnhancePacketConn.ReadFrom(b)
    87  	if err != nil {
    88  		return n, addr, err
    89  	}
    90  	bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c)
    91  	if err != nil {
    92  		return n, addr, err
    93  	}
    94  	copy(b, bb)
    95  	return len(bb), addr, err
    96  }
    97  
    98  func (c *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
    99  	data, put, addr, err = c.EnhancePacketConn.WaitReadFrom()
   100  	if err != nil {
   101  		return
   102  	}
   103  	data, err = Unpack(data[c.Cipher.SaltSize():], data, c)
   104  	if err != nil {
   105  		if put != nil {
   106  			put()
   107  		}
   108  		data = nil
   109  		put = nil
   110  		return
   111  	}
   112  	return
   113  }