github.com/MerlinKodo/sing-shadowsocks2@v0.1.6/shadowaead/method.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"context"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"net"
     8  
     9  	C "github.com/MerlinKodo/sing-shadowsocks2/cipher"
    10  	"github.com/MerlinKodo/sing-shadowsocks2/internal/legacykey"
    11  	"github.com/MerlinKodo/sing-shadowsocks2/internal/shadowio"
    12  	"github.com/sagernet/sing/common"
    13  	"github.com/sagernet/sing/common/buf"
    14  	"github.com/sagernet/sing/common/bufio"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  	N "github.com/sagernet/sing/common/network"
    18  	"github.com/sagernet/sing/common/rw"
    19  
    20  	"github.com/RyuaNerin/go-krypto/lea"
    21  	"github.com/Yawning/aez"
    22  	"github.com/ericlagergren/aegis"
    23  	"github.com/ericlagergren/siv"
    24  	"github.com/oasisprotocol/deoxysii"
    25  	"github.com/sina-ghaderi/rabaead"
    26  	"golang.org/x/crypto/chacha20poly1305"
    27  )
    28  
    29  var MethodList = []string{
    30  	"aes-128-gcm",
    31  	"aes-192-gcm",
    32  	"aes-256-gcm",
    33  	"chacha20-ietf-poly1305",
    34  	"xchacha20-ietf-poly1305",
    35  	// began not standard methods
    36  	"rabbit128-poly1305",
    37  	"aes-128-gcm-siv",
    38  	"aes-256-gcm-siv",
    39  	"aegis-128l",
    40  	"aegis-256",
    41  	"aez-384",
    42  	"deoxys-ii-256-128",
    43  	"lea-128-gcm",
    44  	"lea-192-gcm",
    45  	"lea-256-gcm",
    46  }
    47  
    48  func init() {
    49  	C.RegisterMethod(MethodList, func(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) {
    50  		return NewMethod(ctx, methodName, options)
    51  	})
    52  }
    53  
    54  type Method struct {
    55  	keySaltLength int
    56  	constructor   func(key []byte) (cipher.AEAD, error)
    57  	key           []byte
    58  }
    59  
    60  func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (*Method, error) {
    61  	m := &Method{}
    62  	switch methodName {
    63  	case "aes-128-gcm":
    64  		m.keySaltLength = 16
    65  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    66  	case "aes-192-gcm":
    67  		m.keySaltLength = 24
    68  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    69  	case "aes-256-gcm":
    70  		m.keySaltLength = 32
    71  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    72  	case "chacha20-ietf-poly1305":
    73  		m.keySaltLength = 32
    74  		m.constructor = chacha20poly1305.New
    75  	case "xchacha20-ietf-poly1305":
    76  		m.keySaltLength = 32
    77  		m.constructor = chacha20poly1305.NewX
    78  	case "rabbit128-poly1305":
    79  		m.keySaltLength = 16
    80  		m.constructor = rabaead.NewAEAD
    81  	case "aes-128-gcm-siv":
    82  		m.keySaltLength = 16
    83  		m.constructor = siv.NewGCM
    84  	case "aes-256-gcm-siv":
    85  		m.keySaltLength = 32
    86  		m.constructor = siv.NewGCM
    87  	case "aegis-128l":
    88  		m.keySaltLength = 16
    89  		m.constructor = aegis.New
    90  	case "aegis-256":
    91  		m.keySaltLength = 32
    92  		m.constructor = aegis.New
    93  	case "aez-384":
    94  		m.keySaltLength = 3 * 16
    95  		m.constructor = aez.New
    96  	case "deoxys-ii-256-128":
    97  		m.keySaltLength = 32
    98  		m.constructor = deoxysii.New
    99  	case "lea-128-gcm":
   100  		m.keySaltLength = 16
   101  		m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM)
   102  	case "lea-192-gcm":
   103  		m.keySaltLength = 24
   104  		m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM)
   105  	case "lea-256-gcm":
   106  		m.keySaltLength = 32
   107  		m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM)
   108  	}
   109  	if len(options.Key) == m.keySaltLength {
   110  		m.key = options.Key
   111  	} else if len(options.Key) > 0 {
   112  		return nil, E.New("bad key length, required ", m.keySaltLength, ", got ", len(options.Key))
   113  	} else if options.Password == "" {
   114  		return nil, C.ErrMissingPassword
   115  	} else {
   116  		m.key = legacykey.Key([]byte(options.Password), m.keySaltLength)
   117  	}
   118  	return m, nil
   119  }
   120  
   121  func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) {
   122  	return func(key []byte) (cipher.AEAD, error) {
   123  		b, err := block(key)
   124  		if err != nil {
   125  			return nil, err
   126  		}
   127  		return aead(b)
   128  	}
   129  }
   130  
   131  func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
   132  	ssConn := &clientConn{
   133  		Conn:        conn,
   134  		method:      m,
   135  		destination: destination,
   136  	}
   137  	return ssConn, ssConn.writeRequest(nil)
   138  }
   139  
   140  func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
   141  	return &clientConn{
   142  		Conn:        conn,
   143  		method:      m,
   144  		destination: destination,
   145  	}
   146  }
   147  
   148  func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
   149  	pc := &clientPacketConn{
   150  		AbstractConn: conn,
   151  		reader:       bufio.NewExtendedReader(conn),
   152  		writer:       bufio.NewExtendedWriter(conn),
   153  		method:       m,
   154  	}
   155  	if waitRead, isWaitRead := N.CastReader[shadowio.WaitReadReader](conn); isWaitRead {
   156  		return &clientWaitPacketConn{
   157  			clientPacketConn: pc,
   158  			waitRead:         waitRead,
   159  		}
   160  	}
   161  	return pc
   162  }
   163  
   164  type clientConn struct {
   165  	net.Conn
   166  	method      *Method
   167  	destination M.Socksaddr
   168  	reader      *shadowio.Reader
   169  	writer      *shadowio.Writer
   170  	shadowio.WriterInterface
   171  }
   172  
// writeRequest sends the request header (random salt, then the encrypted
// target address followed by payload) on c.Conn and installs c.writer for
// subsequent writes. payload may be nil.
//
// The session key is derived from the master key and the freshly generated
// salt with legacykey.Kdf. The resulting AEAD is reused by c.writer.
func (c *clientConn) writeRequest(payload []byte) error {
	// requestBuffer starts out holding only the random salt.
	requestBuffer := buf.New()
	requestBuffer.WriteRandom(c.method.keySaltLength)
	key := make([]byte, c.method.keySaltLength)
	legacykey.Kdf(c.method.key, requestBuffer.Bytes(), key)
	writeCipher, err := c.method.constructor(key)
	if err != nil {
		return err
	}
	// Two buffering layers: the outer writer has the salt already buffered in
	// front of whatever the encrypting writer produces; the inner writer
	// collects address + payload before they reach the encrypting writer.
	// NOTE(review): assumes BufferedWriter.Fallthrough flushes buffered bytes
	// upstream and then passes writes straight through — confirm against
	// sing's bufio package.
	bufferedRequestWriter := bufio.NewBufferedWriter(c.Conn, requestBuffer)
	requestContentWriter := shadowio.NewWriter(bufferedRequestWriter, writeCipher, nil, MaxPacketSize)
	bufferedRequestContentWriter := bufio.NewBufferedWriter(requestContentWriter, buf.New())
	err = M.SocksaddrSerializer.WriteAddrPort(bufferedRequestContentWriter, c.destination)
	if err != nil {
		return err
	}
	_, err = bufferedRequestContentWriter.Write(payload)
	if err != nil {
		return err
	}
	// Order matters: the inner (plaintext) layer is released first so its
	// ciphertext lands behind the salt, then the outer layer is released.
	err = bufferedRequestContentWriter.Fallthrough()
	if err != nil {
		return err
	}
	err = bufferedRequestWriter.Fallthrough()
	if err != nil {
		return err
	}
	// Hand the nonce state of requestContentWriter to the long-lived writer so
	// later chunks continue under the same key (presumed TakeNonce semantics —
	// confirm in shadowio).
	c.writer = shadowio.NewWriter(c.Conn, writeCipher, requestContentWriter.TakeNonce(), MaxPacketSize)
	return nil
}
   204  
   205  func (c *clientConn) readResponse() error {
   206  	buffer := buf.NewSize(c.method.keySaltLength)
   207  	defer buffer.Release()
   208  	_, err := buffer.ReadFullFrom(c.Conn, c.method.keySaltLength)
   209  	if err != nil {
   210  		return err
   211  	}
   212  	legacykey.Kdf(c.method.key, buffer.Bytes(), buffer.Bytes())
   213  	readCipher, err := c.method.constructor(buffer.Bytes())
   214  	if err != nil {
   215  		return err
   216  	}
   217  	c.reader = shadowio.NewReader(c.Conn, readCipher)
   218  	return nil
   219  }
   220  
   221  func (c *clientConn) Read(p []byte) (n int, err error) {
   222  	if c.reader == nil {
   223  		err = c.readResponse()
   224  		if err != nil {
   225  			return
   226  		}
   227  	}
   228  	return c.reader.Read(p)
   229  }
   230  
   231  func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error {
   232  	if c.reader == nil {
   233  		err := c.readResponse()
   234  		if err != nil {
   235  			return err
   236  		}
   237  	}
   238  	return c.reader.ReadBuffer(buffer)
   239  }
   240  
   241  func (c *clientConn) ReadBufferThreadSafe() (buffer *buf.Buffer, err error) {
   242  	if c.reader == nil {
   243  		err = c.readResponse()
   244  		if err != nil {
   245  			return
   246  		}
   247  	}
   248  	return c.reader.ReadBufferThreadSafe()
   249  }
   250  
   251  func (c *clientConn) Write(p []byte) (n int, err error) {
   252  	if c.writer == nil {
   253  		err = c.writeRequest(p)
   254  		if err == nil {
   255  			n = len(p)
   256  		}
   257  		return
   258  	}
   259  	return c.writer.Write(p)
   260  }
   261  
   262  func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error {
   263  	if c.writer == nil {
   264  		defer buffer.Release()
   265  		return c.writeRequest(buffer.Bytes())
   266  	}
   267  	return c.writer.WriteBuffer(buffer)
   268  }
   269  
   270  func (c *clientConn) NeedHandshake() bool {
   271  	return c.writer == nil
   272  }
   273  
   274  func (c *clientConn) Upstream() any {
   275  	return c.Conn
   276  }
   277  
   278  func (c *clientConn) WriterMTU() int {
   279  	return MaxPacketSize
   280  }
   281  
   282  type clientPacketConn struct {
   283  	N.AbstractConn
   284  	reader N.ExtendedReader
   285  	writer N.ExtendedWriter
   286  	method *Method
   287  }
   288  
   289  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   290  	err = c.reader.ReadBuffer(buffer)
   291  	if err != nil {
   292  		return
   293  	}
   294  	if buffer.Len() < c.method.keySaltLength {
   295  		return M.Socksaddr{}, C.ErrPacketTooShort
   296  	}
   297  	key := buf.NewSize(c.method.keySaltLength)
   298  	legacykey.Kdf(c.method.key, buffer.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength))
   299  	readCipher, err := c.method.constructor(key.Bytes())
   300  	key.Release()
   301  	if err != nil {
   302  		return
   303  	}
   304  	packet, err := readCipher.Open(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil)
   305  	if err != nil {
   306  		return
   307  	}
   308  	buffer.Advance(c.method.keySaltLength)
   309  	buffer.Truncate(len(packet))
   310  	if err != nil {
   311  		return
   312  	}
   313  	destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
   314  	if err != nil {
   315  		return
   316  	}
   317  	return destination.Unwrap(), nil
   318  }
   319  
// WritePacket prepends a random salt and the encoded destination to buffer's
// payload, seals address + payload in place, and writes the packet out.
//
// NOTE(review): relies on buffer having enough front headroom for
// keySaltLength + the address length (see FrontHeadroom), and enough rear
// headroom for shadowio.Overhead tag bytes (see RearHeadroom). It also assumes
// shadowio.Overhead equals the AEAD's tag size for every method — confirm.
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
	// header is a view over the newly extended front of buffer: salt, then address.
	header := buf.With(buffer.ExtendHeader(c.method.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination)))
	header.WriteRandom(c.method.keySaltLength)
	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
	if err != nil {
		return err
	}
	// Per-packet key derived from the salt just written.
	key := buf.NewSize(c.method.keySaltLength)
	legacykey.Kdf(c.method.key, header.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength))
	writeCipher, err := c.method.constructor(key.Bytes())
	key.Release()
	if err != nil {
		return err
	}
	// Each packet gets its own key from a fresh random salt, so a constant
	// zero nonce (rw.ZeroBytes) is passed. The seal is in place over
	// everything after the salt; the tag is written past the current end of
	// buffer, and Extend then brings it into range.
	writeCipher.Seal(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil)
	buffer.Extend(shadowio.Overhead)
	return c.writer.WriteBuffer(buffer)
}
   338  
   339  func (c *clientPacketConn) readFrom(p []byte) (data []byte, addr net.Addr, err error) {
   340  	if len(p) < c.method.keySaltLength {
   341  		err = C.ErrPacketTooShort
   342  		return
   343  	}
   344  	key := buf.NewSize(c.method.keySaltLength)
   345  	legacykey.Kdf(c.method.key, p[:c.method.keySaltLength], key.Extend(c.method.keySaltLength))
   346  	readCipher, err := c.method.constructor(key.Bytes())
   347  	key.Release()
   348  	if err != nil {
   349  		return
   350  	}
   351  	packet, err := readCipher.Open(p[c.method.keySaltLength:c.method.keySaltLength], rw.ZeroBytes[:readCipher.NonceSize()], p[c.method.keySaltLength:], nil)
   352  	if err != nil {
   353  		return
   354  	}
   355  	packetContent := buf.As(packet)
   356  	destination, err := M.SocksaddrSerializer.ReadAddrPort(packetContent)
   357  	if err != nil {
   358  		return
   359  	}
   360  	if !destination.IsFqdn() {
   361  		addr = destination.UDPAddr()
   362  	} else {
   363  		addr = destination
   364  	}
   365  	data = packetContent.Bytes()
   366  	return
   367  }
   368  
   369  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   370  	n, err = c.reader.Read(p)
   371  	if err != nil {
   372  		return
   373  	}
   374  	var data []byte
   375  	data, addr, err = c.readFrom(p[:n])
   376  	n = copy(p, data)
   377  	return
   378  }
   379  
   380  func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   381  	destination := M.SocksaddrFromNet(addr)
   382  	buffer := buf.NewSize(c.method.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p) + shadowio.Overhead)
   383  	defer buffer.Release()
   384  	buffer.WriteRandom(c.method.keySaltLength)
   385  	err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
   386  	if err != nil {
   387  		return
   388  	}
   389  	common.Must1(buffer.Write(p))
   390  	key := buf.NewSize(c.method.keySaltLength)
   391  	legacykey.Kdf(c.method.key, buffer.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength))
   392  	writeCipher, err := c.method.constructor(key.Bytes())
   393  	key.Release()
   394  	if err != nil {
   395  		return
   396  	}
   397  	writeCipher.Seal(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil)
   398  	buffer.Extend(shadowio.Overhead)
   399  	_, err = c.writer.Write(buffer.Bytes())
   400  	if err != nil {
   401  		return
   402  	}
   403  	return len(p), nil
   404  }
   405  
   406  func (c *clientPacketConn) FrontHeadroom() int {
   407  	return c.method.keySaltLength + M.MaxSocksaddrLength
   408  }
   409  
   410  func (c *clientPacketConn) RearHeadroom() int {
   411  	return shadowio.Overhead
   412  }
   413  
   414  func (c *clientPacketConn) Upstream() any {
   415  	return c.AbstractConn
   416  }
   417  
   418  var _ shadowio.WaitReadFrom = (*clientWaitPacketConn)(nil)
   419  
   420  type clientWaitPacketConn struct {
   421  	*clientPacketConn
   422  	waitRead shadowio.WaitRead
   423  }
   424  
   425  func (c *clientWaitPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
   426  	data, put, err = c.waitRead.WaitRead()
   427  	if err != nil {
   428  		return
   429  	}
   430  	if len(data) <= 0 {
   431  		err = C.ErrPacketTooShort
   432  		return
   433  	}
   434  	data, addr, err = c.readFrom(data)
   435  	if err != nil {
   436  		if put != nil {
   437  			put()
   438  		}
   439  		put = nil
   440  		data = nil
   441  		return
   442  	}
   443  	return
   444  }