github.com/MerlinKodo/sing-shadowsocks@v0.2.6/shadowaead/protocol.go (about)

     1  package shadowaead
     2  
import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/sha1"
	"fmt"
	"io"
	"net"

	shadowsocks "github.com/MerlinKodo/sing-shadowsocks"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/rw"

	"github.com/RyuaNerin/go-krypto/lea"
	"github.com/Yawning/aez"
	"github.com/ericlagergren/aegis"
	"github.com/ericlagergren/siv"
	"github.com/oasisprotocol/deoxysii"
	"github.com/sina-ghaderi/rabaead"
	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/crypto/hkdf"
)
    26  
// List enumerates every method name handled by New, in the same order as
// New's switch. Entries before the marker comment are the standard methods;
// entries after it are non-standard extensions.
var List = []string{
	"aes-128-gcm",
	"aes-192-gcm",
	"aes-256-gcm",
	"chacha20-ietf-poly1305",
	"xchacha20-ietf-poly1305",
	// Non-standard methods begin here.
	"rabbit128-poly1305",
	"aes-128-gcm-siv",
	"aes-256-gcm-siv",
	"aegis-128l",
	"aegis-256",
	"aez-384",
	"deoxys-ii-256-128",
	"lea-128-gcm",
	"lea-192-gcm",
	"lea-256-gcm",
}

// Compile-time check that *Method implements shadowsocks.Method.
var _ shadowsocks.Method = (*Method)(nil)
    47  
    48  func New(method string, key []byte, password string) (*Method, error) {
    49  	m := &Method{
    50  		name: method,
    51  	}
    52  	switch method {
    53  	case "aes-128-gcm":
    54  		m.keySaltLength = 16
    55  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    56  	case "aes-192-gcm":
    57  		m.keySaltLength = 24
    58  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    59  	case "aes-256-gcm":
    60  		m.keySaltLength = 32
    61  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    62  	case "chacha20-ietf-poly1305":
    63  		m.keySaltLength = 32
    64  		m.constructor = chacha20poly1305.New
    65  	case "xchacha20-ietf-poly1305":
    66  		m.keySaltLength = 32
    67  		m.constructor = chacha20poly1305.NewX
    68  	case "rabbit128-poly1305":
    69  		m.keySaltLength = 16
    70  		m.constructor = rabaead.NewAEAD
    71  	case "aes-128-gcm-siv":
    72  		m.keySaltLength = 16
    73  		m.constructor = siv.NewGCM
    74  	case "aes-256-gcm-siv":
    75  		m.keySaltLength = 32
    76  		m.constructor = siv.NewGCM
    77  	case "aegis-128l":
    78  		m.keySaltLength = 16
    79  		m.constructor = aegis.New
    80  	case "aegis-256":
    81  		m.keySaltLength = 32
    82  		m.constructor = aegis.New
    83  	case "aez-384":
    84  		m.keySaltLength = 3 * 16
    85  		m.constructor = aez.New
    86  	case "deoxys-ii-256-128":
    87  		m.keySaltLength = 32
    88  		m.constructor = deoxysii.New
    89  	case "lea-128-gcm":
    90  		m.keySaltLength = 16
    91  		m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM)
    92  	case "lea-192-gcm":
    93  		m.keySaltLength = 24
    94  		m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM)
    95  	case "lea-256-gcm":
    96  		m.keySaltLength = 32
    97  		m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM)
    98  	}
    99  	if len(key) == m.keySaltLength {
   100  		m.key = key
   101  	} else if len(key) > 0 {
   102  		return nil, shadowsocks.ErrBadKey
   103  	} else if password == "" {
   104  		return nil, shadowsocks.ErrMissingPassword
   105  	} else {
   106  		m.key = shadowsocks.Key([]byte(password), m.keySaltLength)
   107  	}
   108  	return m, nil
   109  }
   110  
   111  func Kdf(key, iv []byte, buffer *buf.Buffer) {
   112  	kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
   113  	common.Must1(buffer.ReadFullFrom(kdf, buffer.FreeLen()))
   114  }
   115  
   116  func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) {
   117  	return func(key []byte) (cipher.AEAD, error) {
   118  		b, err := block(key)
   119  		if err != nil {
   120  			return nil, err
   121  		}
   122  		return aead(b)
   123  	}
   124  }
   125  
   126  type Method struct {
   127  	name          string
   128  	keySaltLength int
   129  	constructor   func(key []byte) (cipher.AEAD, error)
   130  	key           []byte
   131  }
   132  
   133  func (m *Method) Name() string {
   134  	return m.name
   135  }
   136  
   137  func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
   138  	shadowsocksConn := &clientConn{
   139  		Conn:        conn,
   140  		Method:      m,
   141  		destination: destination,
   142  	}
   143  	return shadowsocksConn, shadowsocksConn.writeRequest(nil)
   144  }
   145  
   146  func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
   147  	return &clientConn{
   148  		Conn:        conn,
   149  		Method:      m,
   150  		destination: destination,
   151  	}
   152  }
   153  
   154  func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
   155  	return &clientPacketConn{m, conn}
   156  }
   157  
// clientConn is the client side of a Shadowsocks AEAD stream.
// The request header is sent either by DialConn or by the first Write after
// DialEarlyConn. The server's salt is read on the first Read or WriteTo.
type clientConn struct {
	// Underlying transport carrying the encrypted stream.
	net.Conn
	// Cipher configuration and master key.
	*Method
	// destination is the target address sent in the request header.
	destination M.Socksaddr
	// reader stays nil until readResponse has consumed the server's salt.
	reader *Reader
	// writer stays nil until writeRequest has sent the request header.
	writer *Writer
}
   165  
   166  func (c *clientConn) writeRequest(payload []byte) error {
   167  	salt := buf.NewSize(c.keySaltLength)
   168  	defer salt.Release()
   169  	salt.WriteRandom(c.keySaltLength)
   170  
   171  	key := buf.NewSize(c.keySaltLength)
   172  
   173  	Kdf(c.key, salt.Bytes(), key)
   174  	writeCipher, err := c.constructor(key.Bytes())
   175  	key.Release()
   176  	if err != nil {
   177  		return err
   178  	}
   179  	writer := NewWriter(c.Conn, writeCipher, MaxPacketSize)
   180  	header := writer.Buffer()
   181  	common.Must1(header.Write(salt.Bytes()))
   182  	bufferedWriter := writer.BufferedWriter(header.Len())
   183  
   184  	if len(payload) > 0 {
   185  		err = M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination)
   186  		if err != nil {
   187  			return err
   188  		}
   189  
   190  		_, err = bufferedWriter.Write(payload)
   191  		if err != nil {
   192  			return err
   193  		}
   194  	} else {
   195  		err = M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination)
   196  		if err != nil {
   197  			return err
   198  		}
   199  	}
   200  
   201  	err = bufferedWriter.Flush()
   202  	if err != nil {
   203  		return err
   204  	}
   205  
   206  	c.writer = writer
   207  	return nil
   208  }
   209  
   210  func (c *clientConn) readResponse() error {
   211  	salt := buf.NewSize(c.keySaltLength)
   212  	defer salt.Release()
   213  	_, err := salt.ReadFullFrom(c.Conn, c.keySaltLength)
   214  	if err != nil {
   215  		return err
   216  	}
   217  	key := buf.NewSize(c.keySaltLength)
   218  	defer key.Release()
   219  	Kdf(c.key, salt.Bytes(), key)
   220  	readCipher, err := c.constructor(key.Bytes())
   221  	if err != nil {
   222  		return err
   223  	}
   224  	c.reader = NewReader(
   225  		c.Conn,
   226  		readCipher,
   227  		MaxPacketSize,
   228  	)
   229  	return nil
   230  }
   231  
   232  func (c *clientConn) Read(p []byte) (n int, err error) {
   233  	if c.reader == nil {
   234  		if err = c.readResponse(); err != nil {
   235  			return
   236  		}
   237  	}
   238  	return c.reader.Read(p)
   239  }
   240  
   241  func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
   242  	if c.reader == nil {
   243  		if err = c.readResponse(); err != nil {
   244  			return
   245  		}
   246  	}
   247  	return c.reader.WriteTo(w)
   248  }
   249  
   250  func (c *clientConn) Write(p []byte) (n int, err error) {
   251  	if c.writer == nil {
   252  		err = c.writeRequest(p)
   253  		if err != nil {
   254  			return
   255  		}
   256  		return len(p), nil
   257  	}
   258  	return c.writer.Write(p)
   259  }
   260  
   261  func (c *clientConn) NeedHandshake() bool {
   262  	return c.writer == nil
   263  }
   264  
   265  func (c *clientConn) NeedAdditionalReadDeadline() bool {
   266  	return true
   267  }
   268  
   269  func (c *clientConn) Upstream() any {
   270  	return c.Conn
   271  }
   272  
// clientPacketConn is the client side of the Shadowsocks AEAD packet relay.
// Every packet is sealed independently, with a subkey derived from that
// packet's own random salt.
// Keep the field order stable: unkeyed composite literals of this type depend on it.
type clientPacketConn struct {
	// Cipher configuration and master key.
	*Method
	// Underlying connection carrying the encrypted packets.
	net.Conn
}
   277  
   278  func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   279  	defer buffer.Release()
   280  	header := buf.With(buffer.ExtendHeader(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination)))
   281  	header.WriteRandom(c.keySaltLength)
   282  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
   283  	if err != nil {
   284  		return err
   285  	}
   286  	key := buf.NewSize(c.keySaltLength)
   287  	Kdf(c.key, buffer.To(c.keySaltLength), key)
   288  	writeCipher, err := c.constructor(key.Bytes())
   289  	key.Release()
   290  	if err != nil {
   291  		return err
   292  	}
   293  	writeCipher.Seal(buffer.Index(c.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.keySaltLength), nil)
   294  	buffer.Extend(Overhead)
   295  	return common.Error(c.Write(buffer.Bytes()))
   296  }
   297  
   298  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
   299  	n, err := c.Read(buffer.FreeBytes())
   300  	if err != nil {
   301  		return M.Socksaddr{}, err
   302  	}
   303  	buffer.Truncate(n)
   304  	if buffer.Len() < c.keySaltLength {
   305  		return M.Socksaddr{}, io.ErrShortBuffer
   306  	}
   307  	key := buf.NewSize(c.keySaltLength)
   308  	Kdf(c.key, buffer.To(c.keySaltLength), key)
   309  	readCipher, err := c.constructor(key.Bytes())
   310  	key.Release()
   311  	if err != nil {
   312  		return M.Socksaddr{}, err
   313  	}
   314  	packet, err := readCipher.Open(buffer.Index(c.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(c.keySaltLength), nil)
   315  	if err != nil {
   316  		return M.Socksaddr{}, err
   317  	}
   318  	buffer.Advance(c.keySaltLength)
   319  	buffer.Truncate(len(packet))
   320  	if err != nil {
   321  		return M.Socksaddr{}, err
   322  	}
   323  	return M.SocksaddrSerializer.ReadAddrPort(buffer)
   324  }
   325  
   326  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   327  	buffer := buf.With(p)
   328  	destination, err := c.ReadPacket(buffer)
   329  	if err != nil {
   330  		return
   331  	}
   332  	if destination.IsFqdn() {
   333  		addr = destination
   334  	} else {
   335  		addr = destination.UDPAddr()
   336  	}
   337  	n = copy(p, buffer.Bytes())
   338  	return
   339  }
   340  
   341  func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   342  	destination := M.SocksaddrFromNet(addr)
   343  	buffer := buf.NewSize(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p) + Overhead)
   344  	defer buffer.Release()
   345  	buffer.Resize(c.keySaltLength+M.SocksaddrSerializer.AddrPortLen(destination), 0)
   346  	common.Must1(buffer.Write(p))
   347  	err = c.WritePacket(buffer, destination)
   348  	if err != nil {
   349  		return
   350  	}
   351  	return len(p), nil
   352  }
   353  
   354  func (c *clientPacketConn) FrontHeadroom() int {
   355  	return c.keySaltLength + M.MaxSocksaddrLength
   356  }
   357  
   358  func (c *clientPacketConn) RearHeadroom() int {
   359  	return Overhead
   360  }
   361  
   362  func (c *clientPacketConn) ReaderMTU() int {
   363  	return MaxPacketSize
   364  }
   365  
   366  func (c *clientPacketConn) WriterMTU() int {
   367  	return MaxPacketSize
   368  }
   369  
   370  func (c *clientPacketConn) Upstream() any {
   371  	return c.Conn
   372  }