github.com/MerlinKodo/sing-shadowsocks2@v0.1.6/shadowstream/method.go (about)

     1  package shadowstream
     2  
     3  import (
     4  	"context"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/md5"
     8  	"crypto/rc4"
     9  	"net"
    10  	"os"
    11  
    12  	C "github.com/MerlinKodo/sing-shadowsocks2/cipher"
    13  	"github.com/MerlinKodo/sing-shadowsocks2/internal/legacykey"
    14  	"github.com/MerlinKodo/sing-shadowsocks2/internal/shadowio"
    15  	"github.com/sagernet/sing/common"
    16  	"github.com/sagernet/sing/common/buf"
    17  	"github.com/sagernet/sing/common/bufio"
    18  	E "github.com/sagernet/sing/common/exceptions"
    19  	M "github.com/sagernet/sing/common/metadata"
    20  	N "github.com/sagernet/sing/common/network"
    21  
    22  	"github.com/aead/chacha20/chacha"
    23  	"golang.org/x/crypto/chacha20"
    24  )
    25  
// MethodList holds the method names that init passes to C.RegisterMethod.
// Each entry has a matching case in NewMethod.
var MethodList = []string{
	"aes-128-ctr",
	"aes-192-ctr",
	"aes-256-ctr",
	"aes-128-cfb",
	"aes-192-cfb",
	"aes-256-cfb",
	"rc4-md5",
	"chacha20-ietf",
	"xchacha20",
	"chacha20",
}
    38  
// init hands MethodList and the NewMethod constructor to C.RegisterMethod
// when the package is loaded.
func init() {
	C.RegisterMethod(MethodList, NewMethod)
}
    42  
// Method describes one configured stream cipher: its sizes, the functions
// that build cipher streams for each direction, and the key to use.
type Method struct {
	// keyLength is the key size in bytes that NewMethod requires or derives.
	keyLength int
	// saltLength is the number of random bytes written in front of each
	// stream or packet and passed to the constructors as salt.
	saltLength int
	// encryptConstructor builds the stream used for outgoing data.
	encryptConstructor func(key []byte, salt []byte) (cipher.Stream, error)
	// decryptConstructor builds the stream used for incoming data.
	decryptConstructor func(key []byte, salt []byte) (cipher.Stream, error)
	// key is the key handed to both constructors.
	key []byte
}
    50  
    51  func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) {
    52  	m := &Method{}
    53  	switch methodName {
    54  	case "aes-128-ctr":
    55  		m.keyLength = 16
    56  		m.saltLength = aes.BlockSize
    57  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    58  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    59  	case "aes-192-ctr":
    60  		m.keyLength = 24
    61  		m.saltLength = aes.BlockSize
    62  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    63  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    64  	case "aes-256-ctr":
    65  		m.keyLength = 32
    66  		m.saltLength = aes.BlockSize
    67  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    68  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    69  	case "aes-128-cfb":
    70  		m.keyLength = 16
    71  		m.saltLength = aes.BlockSize
    72  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter)
    73  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
    74  	case "aes-192-cfb":
    75  		m.keyLength = 24
    76  		m.saltLength = aes.BlockSize
    77  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter)
    78  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
    79  	case "aes-256-cfb":
    80  		m.keyLength = 32
    81  		m.saltLength = aes.BlockSize
    82  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter)
    83  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
    84  	case "rc4-md5":
    85  		m.keyLength = 16
    86  		m.saltLength = 16
    87  		m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
    88  			h := md5.New()
    89  			h.Write(key)
    90  			h.Write(salt)
    91  			return rc4.NewCipher(h.Sum(nil))
    92  		}
    93  		m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
    94  			h := md5.New()
    95  			h.Write(key)
    96  			h.Write(salt)
    97  			return rc4.NewCipher(h.Sum(nil))
    98  		}
    99  	case "chacha20-ietf":
   100  		m.keyLength = chacha20.KeySize
   101  		m.saltLength = chacha20.NonceSize
   102  		m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   103  			return chacha20.NewUnauthenticatedCipher(key, salt)
   104  		}
   105  		m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   106  			return chacha20.NewUnauthenticatedCipher(key, salt)
   107  		}
   108  	case "xchacha20":
   109  		m.keyLength = chacha20.KeySize
   110  		m.saltLength = chacha20.NonceSizeX
   111  		m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   112  			return chacha20.NewUnauthenticatedCipher(key, salt)
   113  		}
   114  		m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   115  			return chacha20.NewUnauthenticatedCipher(key, salt)
   116  		}
   117  	case "chacha20":
   118  		m.keyLength = chacha.KeySize
   119  		m.saltLength = chacha.NonceSize
   120  		m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   121  			return chacha.NewCipher(salt, key, 20)
   122  		}
   123  		m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   124  			return chacha.NewCipher(salt, key, 20)
   125  		}
   126  	default:
   127  		return nil, os.ErrInvalid
   128  	}
   129  	if len(options.Key) == m.keyLength {
   130  		m.key = options.Key
   131  	} else if len(options.Key) > 0 {
   132  		return nil, E.New("bad key length, required ", m.keyLength, ", got ", len(options.Key))
   133  	} else if options.Password != "" {
   134  		m.key = legacykey.Key([]byte(options.Password), m.keyLength)
   135  	} else {
   136  		return nil, C.ErrMissingPassword
   137  	}
   138  	return m, nil
   139  }
   140  
   141  func blockStream(blockCreator func(key []byte) (cipher.Block, error), streamCreator func(block cipher.Block, iv []byte) cipher.Stream) func([]byte, []byte) (cipher.Stream, error) {
   142  	return func(key []byte, iv []byte) (cipher.Stream, error) {
   143  		block, err := blockCreator(key)
   144  		if err != nil {
   145  			return nil, err
   146  		}
   147  		return streamCreator(block, iv), err
   148  	}
   149  }
   150  
   151  func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
   152  	ssConn := &clientConn{
   153  		ExtendedConn: bufio.NewExtendedConn(conn),
   154  		method:       m,
   155  		destination:  destination,
   156  	}
   157  	return ssConn, common.Error(ssConn.Write(nil))
   158  }
   159  
   160  func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
   161  	return &clientConn{
   162  		ExtendedConn: bufio.NewExtendedConn(conn),
   163  		method:       m,
   164  		destination:  destination,
   165  	}
   166  }
   167  
   168  func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
   169  	pc := &clientPacketConn{
   170  		ExtendedConn: bufio.NewExtendedConn(conn),
   171  		method:       m,
   172  	}
   173  	if waitRead, isWaitRead := N.CastReader[shadowio.WaitReadReader](conn); isWaitRead {
   174  		return &clientWaitPacketConn{
   175  			clientPacketConn: pc,
   176  			waitRead:         waitRead,
   177  		}
   178  	}
   179  	return pc
   180  }
   181  
// clientConn is the client side of a stream connection. Both cipher streams
// are created lazily: writeStream on the first write, readStream on the
// first read.
type clientConn struct {
	N.ExtendedConn
	// method supplies the key, the salt length and the stream constructors.
	method *Method
	// destination is serialized into the header sent with the first write.
	destination M.Socksaddr
	// readStream decrypts incoming data; nil until readResponse has run.
	readStream cipher.Stream
	// writeStream encrypts outgoing data; nil until the first write.
	writeStream cipher.Stream
}
   189  
   190  func (c *clientConn) readResponse() error {
   191  	saltBuffer := buf.NewSize(c.method.saltLength)
   192  	defer saltBuffer.Release()
   193  	_, err := saltBuffer.ReadFullFrom(c.ExtendedConn, c.method.saltLength)
   194  	if err != nil {
   195  		return err
   196  	}
   197  	c.readStream, err = c.method.decryptConstructor(c.method.key, saltBuffer.Bytes())
   198  	return err
   199  }
   200  
   201  func (c *clientConn) Read(p []byte) (n int, err error) {
   202  	if c.readStream == nil {
   203  		err = c.readResponse()
   204  		if err != nil {
   205  			return
   206  		}
   207  	}
   208  	n, err = c.ExtendedConn.Read(p)
   209  	if err != nil {
   210  		return
   211  	}
   212  	c.readStream.XORKeyStream(p[:n], p[:n])
   213  	return
   214  }
   215  
   216  func (c *clientConn) Write(p []byte) (n int, err error) {
   217  	if c.writeStream == nil {
   218  		buffer := buf.NewSize(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination) + len(p))
   219  		defer buffer.Release()
   220  		buffer.WriteRandom(c.method.saltLength)
   221  		err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination)
   222  		if err != nil {
   223  			return
   224  		}
   225  		common.Must1(buffer.Write(p))
   226  		c.writeStream, err = c.method.encryptConstructor(c.method.key, buffer.To(c.method.saltLength))
   227  		if err != nil {
   228  			return
   229  		}
   230  		c.writeStream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength))
   231  		_, err = c.ExtendedConn.Write(buffer.Bytes())
   232  		if err == nil {
   233  			n = len(p)
   234  		}
   235  		return
   236  	}
   237  	c.writeStream.XORKeyStream(p, p)
   238  	return c.ExtendedConn.Write(p)
   239  }
   240  
   241  func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error {
   242  	if c.readStream == nil {
   243  		err := c.readResponse()
   244  		if err != nil {
   245  			return err
   246  		}
   247  	}
   248  	err := c.ExtendedConn.ReadBuffer(buffer)
   249  	if err != nil {
   250  		return err
   251  	}
   252  	c.readStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
   253  	return nil
   254  }
   255  
   256  func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error {
   257  	if c.writeStream == nil {
   258  		header := buf.With(buffer.ExtendHeader(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination)))
   259  		header.WriteRandom(c.method.saltLength)
   260  		err := M.SocksaddrSerializer.WriteAddrPort(header, c.destination)
   261  		if err != nil {
   262  			return err
   263  		}
   264  		c.writeStream, err = c.method.encryptConstructor(c.method.key, header.To(c.method.saltLength))
   265  		if err != nil {
   266  			return err
   267  		}
   268  		c.writeStream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength))
   269  	} else {
   270  		c.writeStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
   271  	}
   272  	return c.ExtendedConn.WriteBuffer(buffer)
   273  }
   274  
   275  func (c *clientConn) FrontHeadroom() int {
   276  	if c.writeStream == nil {
   277  		return c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination)
   278  	}
   279  	return 0
   280  }
   281  
// NeedHandshake reports whether the write stream has not been created yet,
// i.e. the first write, which carries the salt and destination, is still
// pending.
func (c *clientConn) NeedHandshake() bool {
	return c.writeStream == nil
}
   285  
// Upstream returns the wrapped ExtendedConn.
func (c *clientConn) Upstream() any {
	return c.ExtendedConn
}
   289  
// clientPacketConn is the client side of a packet connection. It keeps no
// cipher state: every packet gets a stream built from that packet's own salt.
type clientPacketConn struct {
	N.ExtendedConn
	// method supplies the key, the salt length and the stream constructors.
	method *Method
}
   294  
   295  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   296  	err = c.ReadBuffer(buffer)
   297  	if err != nil {
   298  		return
   299  	}
   300  	stream, err := c.method.decryptConstructor(c.method.key, buffer.To(c.method.saltLength))
   301  	if err != nil {
   302  		return
   303  	}
   304  	stream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength))
   305  	buffer.Advance(c.method.saltLength)
   306  	destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
   307  	if err != nil {
   308  		return
   309  	}
   310  	return destination.Unwrap(), nil
   311  }
   312  
   313  func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   314  	header := buf.With(buffer.ExtendHeader(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(destination)))
   315  	header.WriteRandom(c.method.saltLength)
   316  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
   317  	if err != nil {
   318  		return err
   319  	}
   320  	stream, err := c.method.encryptConstructor(c.method.key, buffer.To(c.method.saltLength))
   321  	if err != nil {
   322  		return err
   323  	}
   324  	stream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength))
   325  	return c.ExtendedConn.WriteBuffer(buffer)
   326  }
   327  
   328  func (c *clientPacketConn) readFrom(p []byte) (data []byte, addr net.Addr, err error) {
   329  	if len(p) < c.method.saltLength {
   330  		err = C.ErrPacketTooShort
   331  		return
   332  	}
   333  	stream, err := c.method.decryptConstructor(c.method.key, p[:c.method.saltLength])
   334  	if err != nil {
   335  		return
   336  	}
   337  	buffer := buf.As(p[c.method.saltLength:])
   338  	stream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
   339  	destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
   340  	if err != nil {
   341  		return
   342  	}
   343  	if destination.IsFqdn() {
   344  		addr = destination
   345  	} else {
   346  		addr = destination.UDPAddr()
   347  	}
   348  	data = buffer.Bytes()
   349  	return
   350  }
   351  
   352  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   353  	n, err = c.ExtendedConn.Read(p)
   354  	if err != nil {
   355  		return
   356  	}
   357  	var data []byte
   358  	data, addr, err = c.readFrom(p[:n])
   359  	n = copy(p, data)
   360  	return
   361  }
   362  
   363  func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   364  	destination := M.SocksaddrFromNet(addr)
   365  	buffer := buf.NewSize(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
   366  	defer buffer.Release()
   367  	buffer.WriteRandom(c.method.saltLength)
   368  	err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
   369  	if err != nil {
   370  		return
   371  	}
   372  	stream, err := c.method.encryptConstructor(c.method.key, buffer.To(c.method.saltLength))
   373  	if err != nil {
   374  		return
   375  	}
   376  	stream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength))
   377  	stream.XORKeyStream(buffer.Extend(len(p)), p)
   378  	_, err = c.ExtendedConn.Write(buffer.Bytes())
   379  	if err == nil {
   380  		n = len(p)
   381  	}
   382  	return
   383  }
   384  
// FrontHeadroom returns the salt length plus M.MaxSocksaddrLength. The
// destination differs per packet, so the maximum address size is used rather
// than the size of one particular address.
func (c *clientPacketConn) FrontHeadroom() int {
	return c.method.saltLength + M.MaxSocksaddrLength
}
   388  
// Upstream returns the wrapped ExtendedConn.
func (c *clientPacketConn) Upstream() any {
	return c.ExtendedConn
}
   392  
// Compile-time check that clientWaitPacketConn implements shadowio.WaitReadFrom.
var _ shadowio.WaitReadFrom = (*clientWaitPacketConn)(nil)

// clientWaitPacketConn is a clientPacketConn whose underlying connection also
// offers WaitRead; DialPacketConn creates it in that case.
type clientWaitPacketConn struct {
	*clientPacketConn
	// waitRead supplies the raw packets consumed by WaitReadFrom.
	waitRead shadowio.WaitRead
}
   399  
   400  func (c *clientWaitPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
   401  	data, put, err = c.waitRead.WaitRead()
   402  	if err != nil {
   403  		return
   404  	}
   405  	if len(data) <= 0 {
   406  		err = C.ErrPacketTooShort
   407  		return
   408  	}
   409  	data, addr, err = c.readFrom(data)
   410  	if err != nil {
   411  		if put != nil {
   412  			put()
   413  		}
   414  		put = nil
   415  		data = nil
   416  		return
   417  	}
   418  	return
   419  }