github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/shadowsocks/shadowaead/packet.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  	"sync"
     9  
    10  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/shadowsocks/internal"
    11  )
    12  
    13  // ErrShortPacket means that the packet is too short for a valid encrypted packet.
    14  var ErrShortPacket = errors.New("short packet")
    15  
    16  var _zerononce [128]byte // read-only. 128 bytes is more than enough.
    17  
    18  // Pack encrypts plaintext using Cipher with a randomly generated salt and
    19  // returns a slice of dst containing the encrypted packet and any error occurred.
    20  // Ensure len(dst) >= ciph.SaltSize() + len(plaintext) + aead.Overhead().
    21  func Pack(dst, plaintext []byte, ciph Cipher) ([]byte, error) {
    22  	saltSize := ciph.SaltSize()
    23  	salt := dst[:saltSize]
    24  	if _, err := io.ReadFull(rand.Reader, salt); err != nil {
    25  		return nil, err
    26  	}
    27  
    28  	aead, err := ciph.Encrypter(salt)
    29  	if err != nil {
    30  		return nil, err
    31  	}
    32  	internal.AddSalt(salt)
    33  
    34  	if len(dst) < saltSize+len(plaintext)+aead.Overhead() {
    35  		return nil, io.ErrShortBuffer
    36  	}
    37  	b := aead.Seal(dst[saltSize:saltSize], _zerononce[:aead.NonceSize()], plaintext, nil)
    38  	return dst[:saltSize+len(b)], nil
    39  }
    40  
    41  // Unpack decrypts pkt using Cipher and returns a slice of dst containing the decrypted payload and any error occurred.
    42  // Ensure len(dst) >= len(pkt) - aead.SaltSize() - aead.Overhead().
    43  func Unpack(dst, pkt []byte, ciph Cipher) ([]byte, error) {
    44  	saltSize := ciph.SaltSize()
    45  	if len(pkt) < saltSize {
    46  		return nil, ErrShortPacket
    47  	}
    48  	salt := pkt[:saltSize]
    49  	aead, err := ciph.Decrypter(salt)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	if internal.CheckSalt(salt) {
    54  		return nil, ErrRepeatedSalt
    55  	}
    56  	if len(pkt) < saltSize+aead.Overhead() {
    57  		return nil, ErrShortPacket
    58  	}
    59  	if saltSize+len(dst)+aead.Overhead() < len(pkt) {
    60  		return nil, io.ErrShortBuffer
    61  	}
    62  	b, err := aead.Open(dst[:0], _zerononce[:aead.NonceSize()], pkt[saltSize:], nil)
    63  	return b, err
    64  }
    65  
    66  type packetConn struct {
    67  	net.PacketConn
    68  	Cipher
    69  	sync.Mutex
    70  	buf []byte // write lock
    71  }
    72  
    73  // NewPacketConn wraps a net.PacketConn with cipher
    74  func NewPacketConn(c net.PacketConn, ciph Cipher) net.PacketConn {
    75  	const maxPacketSize = 64 * 1024
    76  	return &packetConn{PacketConn: c, Cipher: ciph, buf: make([]byte, maxPacketSize)}
    77  }
    78  
    79  // WriteTo encrypts b and write to addr using the embedded PacketConn.
    80  func (c *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
    81  	c.Lock()
    82  	defer c.Unlock()
    83  	buf, err := Pack(c.buf, b, c)
    84  	if err != nil {
    85  		return 0, err
    86  	}
    87  	_, err = c.PacketConn.WriteTo(buf, addr)
    88  	return len(b), err
    89  }
    90  
    91  // ReadFrom reads from the embedded PacketConn and decrypts into b.
    92  func (c *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
    93  	n, addr, err := c.PacketConn.ReadFrom(b)
    94  	if err != nil {
    95  		return n, addr, err
    96  	}
    97  	bb, err := Unpack(b[c.Cipher.SaltSize():], b[:n], c)
    98  	if err != nil {
    99  		return n, addr, err
   100  	}
   101  	copy(b, bb)
   102  	return len(bb), addr, err
   103  }