github.com/metacubex/mihomo@v1.18.5/transport/shadowsocks/shadowaead/packet.go (about) 1 package shadowaead 2 3 import ( 4 "crypto/rand" 5 "errors" 6 "io" 7 "net" 8 9 N "github.com/metacubex/mihomo/common/net" 10 "github.com/metacubex/mihomo/common/pool" 11 ) 12 13 // ErrShortPacket means that the packet is too short for a valid encrypted packet. 14 var ErrShortPacket = errors.New("short packet") 15 16 var _zerononce [128]byte // read-only. 128 bytes is more than enough. 17 18 // Pack encrypts plaintext using Cipher with a randomly generated salt and 19 // returns a slice of dst containing the encrypted packet and any error occurred. 20 // Ensure len(dst) >= ciph.SaltSize() + len(plaintext) + aead.Overhead(). 21 func Pack(dst, plaintext []byte, ciph Cipher) ([]byte, error) { 22 saltSize := ciph.SaltSize() 23 salt := dst[:saltSize] 24 if _, err := rand.Read(salt); err != nil { 25 return nil, err 26 } 27 aead, err := ciph.Encrypter(salt) 28 if err != nil { 29 return nil, err 30 } 31 if len(dst) < saltSize+len(plaintext)+aead.Overhead() { 32 return nil, io.ErrShortBuffer 33 } 34 b := aead.Seal(dst[saltSize:saltSize], _zerononce[:aead.NonceSize()], plaintext, nil) 35 return dst[:saltSize+len(b)], nil 36 } 37 38 // Unpack decrypts pkt using Cipher and returns a slice of dst containing the decrypted payload and any error occurred. 39 // Ensure len(dst) >= len(pkt) - aead.SaltSize() - aead.Overhead(). 40 func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) { 41 saltSize := ciph.SaltSize() 42 if len(pkt) < saltSize { 43 return nil, ErrShortPacket 44 } 45 salt := pkt[:saltSize] 46 aead, err := ciph.Decrypter(salt) 47 if err != nil { 48 return nil, err 49 } 50 if len(pkt) < saltSize+aead.Overhead() { 51 return nil, ErrShortPacket 52 } 53 if saltSize+len(dst)+aead.Overhead() < len(pkt) { 54 return nil, io.ErrShortBuffer 55 } 56 b, err := aead.Open(dst[:0], _zerononce[:aead.NonceSize()], pkt[saltSize:], nil) 57 return b, err 58 } 59 60 type PacketConn struct { 61 N.EnhancePacketConn 62 Cipher 63 } 64 65 const maxPacketSize = 64 * 1024 66 67 // NewPacketConn wraps an N.EnhancePacketConn with cipher 68 func NewPacketConn(c N.EnhancePacketConn, ciph Cipher) *PacketConn { 69 return &PacketConn{EnhancePacketConn: c, Cipher: ciph} 70 } 71 72 // WriteTo encrypts b and write to addr using the embedded PacketConn. 73 func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { 74 buf := pool.Get(maxPacketSize) 75 defer pool.Put(buf) 76 buf, err := Pack(buf, b, c) 77 if err != nil { 78 return 0, err 79 } 80 _, err = c.EnhancePacketConn.WriteTo(buf, addr) 81 return len(b), err 82 } 83 84 // ReadFrom reads from the embedded PacketConn and decrypts into b. 85 func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { 86 n, addr, err := c.EnhancePacketConn.ReadFrom(b) 87 if err != nil { 88 return n, addr, err 89 } 90 bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c) 91 if err != nil { 92 return n, addr, err 93 } 94 copy(b, bb) 95 return len(bb), addr, err 96 } 97 98 func (c *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { 99 data, put, addr, err = c.EnhancePacketConn.WaitReadFrom() 100 if err != nil { 101 return 102 } 103 data, err = Unpack(data[c.Cipher.SaltSize():], data, c) 104 if err != nil { 105 if put != nil { 106 put() 107 } 108 data = nil 109 put = nil 110 return 111 } 112 return 113 }