github.com/metacubex/sing-shadowsocks2@v0.2.0/shadowaead/method.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"context"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"net"
     8  
     9  	C "github.com/metacubex/sing-shadowsocks2/cipher"
    10  	"github.com/metacubex/sing-shadowsocks2/internal/legacykey"
    11  	"github.com/metacubex/sing-shadowsocks2/internal/shadowio"
    12  	"github.com/sagernet/sing/common"
    13  	"github.com/sagernet/sing/common/buf"
    14  	"github.com/sagernet/sing/common/bufio"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  	N "github.com/sagernet/sing/common/network"
    18  	"github.com/sagernet/sing/common/rw"
    19  
    20  	"github.com/RyuaNerin/go-krypto/lea"
    21  	"github.com/Yawning/aez"
    22  	"github.com/ericlagergren/aegis"
    23  	"github.com/ericlagergren/siv"
    24  	"github.com/oasisprotocol/deoxysii"
    25  	"github.com/sina-ghaderi/rabaead"
    26  	"golang.org/x/crypto/chacha20poly1305"
    27  )
    28  
// MethodList enumerates the method names that init passes to
// C.RegisterMethod. Every entry has a matching case in NewMethod.
var MethodList = []string{
	"aes-128-gcm",
	"aes-192-gcm",
	"aes-256-gcm",
	"chacha20-ietf-poly1305",
	"xchacha20-ietf-poly1305",
	// non-standard methods begin here
	"rabbit128-poly1305",
	"aes-128-gcm-siv",
	"aes-256-gcm-siv",
	"aegis-128l",
	"aegis-256",
	"aez-384",
	"deoxys-ii-256-128",
	"lea-128-gcm",
	"lea-192-gcm",
	"lea-256-gcm",
}
    47  
    48  func init() {
    49  	C.RegisterMethod(MethodList, func(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) {
    50  		return NewMethod(ctx, methodName, options)
    51  	})
    52  }
    53  
// Method holds the parameters NewMethod selects for one cipher method name.
type Method struct {
	// keySaltLength is the byte length used for the master key, for the
	// random salt written ahead of each stream/packet, and for the key that
	// legacykey.Kdf derives from them.
	keySaltLength int
	// constructor builds the AEAD instance from a derived key.
	constructor   func(key []byte) (cipher.AEAD, error)
	// key is the master key: options.Key when it has the right length,
	// otherwise the output of legacykey.Key on the password.
	key           []byte
}
    59  
    60  func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (*Method, error) {
    61  	m := &Method{}
    62  	switch methodName {
    63  	case "aes-128-gcm":
    64  		m.keySaltLength = 16
    65  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    66  	case "aes-192-gcm":
    67  		m.keySaltLength = 24
    68  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    69  	case "aes-256-gcm":
    70  		m.keySaltLength = 32
    71  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    72  	case "chacha20-ietf-poly1305":
    73  		m.keySaltLength = 32
    74  		m.constructor = chacha20poly1305.New
    75  	case "xchacha20-ietf-poly1305":
    76  		m.keySaltLength = 32
    77  		m.constructor = chacha20poly1305.NewX
    78  	case "rabbit128-poly1305":
    79  		m.keySaltLength = 16
    80  		m.constructor = rabaead.NewAEAD
    81  	case "aes-128-gcm-siv":
    82  		m.keySaltLength = 16
    83  		m.constructor = siv.NewGCM
    84  	case "aes-256-gcm-siv":
    85  		m.keySaltLength = 32
    86  		m.constructor = siv.NewGCM
    87  	case "aegis-128l":
    88  		m.keySaltLength = 16
    89  		m.constructor = aegis.New
    90  	case "aegis-256":
    91  		m.keySaltLength = 32
    92  		m.constructor = aegis.New
    93  	case "aez-384":
    94  		m.keySaltLength = 3 * 16
    95  		m.constructor = aez.New
    96  	case "deoxys-ii-256-128":
    97  		m.keySaltLength = 32
    98  		m.constructor = deoxysii.New
    99  	case "lea-128-gcm":
   100  		m.keySaltLength = 16
   101  		m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM)
   102  	case "lea-192-gcm":
   103  		m.keySaltLength = 24
   104  		m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM)
   105  	case "lea-256-gcm":
   106  		m.keySaltLength = 32
   107  		m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM)
   108  	}
   109  	if len(options.Key) == m.keySaltLength {
   110  		m.key = options.Key
   111  	} else if len(options.Key) > 0 {
   112  		return nil, E.New("bad key length, required ", m.keySaltLength, ", got ", len(options.Key))
   113  	} else if options.Password == "" {
   114  		return nil, C.ErrMissingPassword
   115  	} else {
   116  		m.key = legacykey.Key([]byte(options.Password), m.keySaltLength)
   117  	}
   118  	return m, nil
   119  }
   120  
   121  func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) {
   122  	return func(key []byte) (cipher.AEAD, error) {
   123  		b, err := block(key)
   124  		if err != nil {
   125  			return nil, err
   126  		}
   127  		return aead(b)
   128  	}
   129  }
   130  
   131  func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
   132  	ssConn := &clientConn{
   133  		Conn:        conn,
   134  		method:      m,
   135  		destination: destination,
   136  	}
   137  	return ssConn, ssConn.writeRequest(nil)
   138  }
   139  
   140  func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
   141  	return &clientConn{
   142  		Conn:        conn,
   143  		method:      m,
   144  		destination: destination,
   145  	}
   146  }
   147  
   148  func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
   149  	return &clientPacketConn{
   150  		AbstractConn: conn,
   151  		reader:       bufio.NewExtendedReader(conn),
   152  		writer:       bufio.NewExtendedWriter(conn),
   153  		method:       m,
   154  	}
   155  }
   156  
// Compile-time check that *clientConn satisfies N.ExtendedConn.
var _ N.ExtendedConn = (*clientConn)(nil)

// clientConn is the stream connection returned by DialConn and
// DialEarlyConn. The read and write sides are set up lazily and
// independently of each other.
type clientConn struct {
	// Conn is the underlying connection that carries the encrypted stream.
	net.Conn
	// method supplies the key, salt length and AEAD constructor.
	method          *Method
	// destination is the address serialized at the start of the request.
	destination     M.Socksaddr
	// reader stays nil until readResponse has consumed the response salt.
	reader          *shadowio.Reader
	// readWaitOptions is handed to the reader's InitializeReadWaiter when the
	// reader is created.
	readWaitOptions N.ReadWaitOptions
	// writer stays nil until writeRequest has completed successfully.
	writer          *shadowio.Writer
	// NOTE(review): this embedded interface is not assigned anywhere in the
	// code shown in this file — confirm how it is expected to be used.
	shadowio.WriterInterface
}
   168  
   169  func (c *clientConn) writeRequest(payload []byte) error {
   170  	requestBuffer := buf.New()
   171  	requestBuffer.WriteRandom(c.method.keySaltLength)
   172  	key := make([]byte, c.method.keySaltLength)
   173  	legacykey.Kdf(c.method.key, requestBuffer.Bytes(), key)
   174  	writeCipher, err := c.method.constructor(key)
   175  	if err != nil {
   176  		return err
   177  	}
   178  	bufferedRequestWriter := bufio.NewBufferedWriter(c.Conn, requestBuffer)
   179  	requestContentWriter := shadowio.NewWriter(bufferedRequestWriter, writeCipher, nil, MaxPacketSize)
   180  	bufferedRequestContentWriter := bufio.NewBufferedWriter(requestContentWriter, buf.New())
   181  	err = M.SocksaddrSerializer.WriteAddrPort(bufferedRequestContentWriter, c.destination)
   182  	if err != nil {
   183  		return err
   184  	}
   185  	_, err = bufferedRequestContentWriter.Write(payload)
   186  	if err != nil {
   187  		return err
   188  	}
   189  	err = bufferedRequestContentWriter.Fallthrough()
   190  	if err != nil {
   191  		return err
   192  	}
   193  	err = bufferedRequestWriter.Fallthrough()
   194  	if err != nil {
   195  		return err
   196  	}
   197  	c.writer = shadowio.NewWriter(c.Conn, writeCipher, requestContentWriter.TakeNonce(), MaxPacketSize)
   198  	return nil
   199  }
   200  
   201  func (c *clientConn) readResponse() error {
   202  	buffer := buf.NewSize(c.method.keySaltLength)
   203  	defer buffer.Release()
   204  	_, err := buffer.ReadFullFrom(c.Conn, c.method.keySaltLength)
   205  	if err != nil {
   206  		return err
   207  	}
   208  	legacykey.Kdf(c.method.key, buffer.Bytes(), buffer.Bytes())
   209  	readCipher, err := c.method.constructor(buffer.Bytes())
   210  	if err != nil {
   211  		return err
   212  	}
   213  	reader := shadowio.NewReader(c.Conn, readCipher)
   214  	reader.InitializeReadWaiter(c.readWaitOptions)
   215  	c.reader = reader
   216  	return nil
   217  }
   218  
   219  func (c *clientConn) Read(p []byte) (n int, err error) {
   220  	if c.reader == nil {
   221  		err = c.readResponse()
   222  		if err != nil {
   223  			return
   224  		}
   225  	}
   226  	return c.reader.Read(p)
   227  }
   228  
   229  func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error {
   230  	if c.reader == nil {
   231  		err := c.readResponse()
   232  		if err != nil {
   233  			return err
   234  		}
   235  	}
   236  	return c.reader.ReadBuffer(buffer)
   237  }
   238  
   239  func (c *clientConn) Write(p []byte) (n int, err error) {
   240  	if c.writer == nil {
   241  		err = c.writeRequest(p)
   242  		if err == nil {
   243  			n = len(p)
   244  		}
   245  		return
   246  	}
   247  	return c.writer.Write(p)
   248  }
   249  
   250  func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error {
   251  	if c.writer == nil {
   252  		defer buffer.Release()
   253  		return c.writeRequest(buffer.Bytes())
   254  	}
   255  	return c.writer.WriteBuffer(buffer)
   256  }
   257  
// NeedHandshake reports whether the request has not been written yet, i.e.
// c.writer is still nil (it is only set at the end of writeRequest).
func (c *clientConn) NeedHandshake() bool {
	return c.writer == nil
}
   261  
// Upstream returns the wrapped underlying connection.
func (c *clientConn) Upstream() any {
	return c.Conn
}
   265  
// WriterMTU returns MaxPacketSize, the same limit passed to shadowio.NewWriter
// in writeRequest.
func (c *clientConn) WriterMTU() int {
	return MaxPacketSize
}
   269  
// clientPacketConn is the packet connection returned by DialPacketConn. Each
// packet is encrypted or decrypted on its own with a key derived from the
// salt at the start of that packet.
type clientPacketConn struct {
	// AbstractConn is the underlying connection passed to DialPacketConn.
	N.AbstractConn
	// reader and writer are extended adapters created from that connection.
	reader N.ExtendedReader
	writer N.ExtendedWriter
	// method supplies the key, salt length and AEAD constructor.
	method *Method
}
   276  
   277  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   278  	err = c.reader.ReadBuffer(buffer)
   279  	if err != nil {
   280  		return
   281  	}
   282  	return c.readPacket(buffer)
   283  }
   284  
   285  func (c *clientPacketConn) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   286  	if buffer.Len() < c.method.keySaltLength {
   287  		return M.Socksaddr{}, C.ErrPacketTooShort
   288  	}
   289  	key := buf.NewSize(c.method.keySaltLength)
   290  	legacykey.Kdf(c.method.key, buffer.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength))
   291  	readCipher, err := c.method.constructor(key.Bytes())
   292  	key.Release()
   293  	if err != nil {
   294  		return
   295  	}
   296  	packet, err := readCipher.Open(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil)
   297  	if err != nil {
   298  		return
   299  	}
   300  	buffer.Advance(c.method.keySaltLength)
   301  	buffer.Truncate(len(packet))
   302  	if err != nil {
   303  		return
   304  	}
   305  	destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
   306  	if err != nil {
   307  		return
   308  	}
   309  	return destination.Unwrap(), nil
   310  }
   311  
   312  func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   313  	header := buf.With(buffer.ExtendHeader(c.method.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination)))
   314  	header.WriteRandom(c.method.keySaltLength)
   315  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
   316  	if err != nil {
   317  		return err
   318  	}
   319  	key := buf.NewSize(c.method.keySaltLength)
   320  	legacykey.Kdf(c.method.key, header.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength))
   321  	writeCipher, err := c.method.constructor(key.Bytes())
   322  	key.Release()
   323  	if err != nil {
   324  		return err
   325  	}
   326  	writeCipher.Seal(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil)
   327  	buffer.Extend(shadowio.Overhead)
   328  	return c.writer.WriteBuffer(buffer)
   329  }
   330  
   331  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   332  	n, err = c.reader.Read(p)
   333  	if err != nil {
   334  		return
   335  	}
   336  	if n < c.method.keySaltLength {
   337  		err = C.ErrPacketTooShort
   338  		return
   339  	}
   340  	key := buf.NewSize(c.method.keySaltLength)
   341  	legacykey.Kdf(c.method.key, p[:c.method.keySaltLength], key.Extend(c.method.keySaltLength))
   342  	readCipher, err := c.method.constructor(key.Bytes())
   343  	key.Release()
   344  	if err != nil {
   345  		return
   346  	}
   347  	packet, err := readCipher.Open(p[c.method.keySaltLength:c.method.keySaltLength], rw.ZeroBytes[:readCipher.NonceSize()], p[c.method.keySaltLength:n], nil)
   348  	if err != nil {
   349  		return
   350  	}
   351  	packetContent := buf.As(packet)
   352  	destination, err := M.SocksaddrSerializer.ReadAddrPort(packetContent)
   353  	if err != nil {
   354  		return
   355  	}
   356  	if !destination.IsFqdn() {
   357  		addr = destination.UDPAddr()
   358  	} else {
   359  		addr = destination
   360  	}
   361  	n = copy(p, packetContent.Bytes())
   362  	return
   363  }
   364  
   365  func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   366  	destination := M.SocksaddrFromNet(addr)
   367  	buffer := buf.NewSize(c.method.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p) + shadowio.Overhead)
   368  	defer buffer.Release()
   369  	buffer.WriteRandom(c.method.keySaltLength)
   370  	err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
   371  	if err != nil {
   372  		return
   373  	}
   374  	common.Must1(buffer.Write(p))
   375  	key := buf.NewSize(c.method.keySaltLength)
   376  	legacykey.Kdf(c.method.key, buffer.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength))
   377  	writeCipher, err := c.method.constructor(key.Bytes())
   378  	key.Release()
   379  	if err != nil {
   380  		return
   381  	}
   382  	writeCipher.Seal(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil)
   383  	buffer.Extend(shadowio.Overhead)
   384  	_, err = c.writer.Write(buffer.Bytes())
   385  	if err != nil {
   386  		return
   387  	}
   388  	return len(p), nil
   389  }
   390  
// FrontHeadroom returns the salt length plus M.MaxSocksaddrLength. WritePacket
// prepends a salt and a serialized address to the caller's buffer with
// ExtendHeader.
func (c *clientPacketConn) FrontHeadroom() int {
	return c.method.keySaltLength + M.MaxSocksaddrLength
}
   394  
// RearHeadroom returns shadowio.Overhead, the number of bytes WritePacket
// appends to the buffer with Extend after sealing.
func (c *clientPacketConn) RearHeadroom() int {
	return shadowio.Overhead
}
   398  
// ReaderMTU returns MaxPacketSize.
func (c *clientPacketConn) ReaderMTU() int {
	return MaxPacketSize
}
   402  
// WriterMTU returns MaxPacketSize.
func (c *clientPacketConn) WriterMTU() int {
	return MaxPacketSize
}
   406  
// Upstream returns the wrapped underlying connection.
func (c *clientPacketConn) Upstream() any {
	return c.AbstractConn
}