github.com/sagernet/sing-shadowsocks2@v0.2.0/shadowstream/method.go (about)

     1  package shadowstream
     2  
     3  import (
     4  	"context"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/md5"
     8  	"crypto/rc4"
     9  	"net"
    10  	"os"
    11  
    12  	C "github.com/sagernet/sing-shadowsocks2/cipher"
    13  	"github.com/sagernet/sing-shadowsocks2/internal/legacykey"
    14  	"github.com/sagernet/sing/common"
    15  	"github.com/sagernet/sing/common/buf"
    16  	"github.com/sagernet/sing/common/bufio"
    17  	E "github.com/sagernet/sing/common/exceptions"
    18  	M "github.com/sagernet/sing/common/metadata"
    19  	N "github.com/sagernet/sing/common/network"
    20  
    21  	"golang.org/x/crypto/chacha20"
    22  )
    23  
    24  var MethodList = []string{
    25  	"aes-128-ctr",
    26  	"aes-192-ctr",
    27  	"aes-256-ctr",
    28  	"aes-128-cfb",
    29  	"aes-192-cfb",
    30  	"aes-256-cfb",
    31  	"rc4-md5",
    32  	"chacha20-ietf",
    33  	"xchacha20",
    34  }
    35  
    36  func init() {
    37  	C.RegisterMethod(MethodList, NewMethod)
    38  }
    39  
    40  type Method struct {
    41  	keyLength          int
    42  	saltLength         int
    43  	encryptConstructor func(key []byte, salt []byte) (cipher.Stream, error)
    44  	decryptConstructor func(key []byte, salt []byte) (cipher.Stream, error)
    45  	key                []byte
    46  }
    47  
    48  func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) {
    49  	m := &Method{}
    50  	switch methodName {
    51  	case "aes-128-ctr":
    52  		m.keyLength = 16
    53  		m.saltLength = aes.BlockSize
    54  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    55  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    56  	case "aes-192-ctr":
    57  		m.keyLength = 24
    58  		m.saltLength = aes.BlockSize
    59  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    60  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    61  	case "aes-256-ctr":
    62  		m.keyLength = 32
    63  		m.saltLength = aes.BlockSize
    64  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    65  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    66  	case "aes-128-cfb":
    67  		m.keyLength = 16
    68  		m.saltLength = aes.BlockSize
    69  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter)
    70  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
    71  	case "aes-192-cfb":
    72  		m.keyLength = 24
    73  		m.saltLength = aes.BlockSize
    74  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter)
    75  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
    76  	case "aes-256-cfb":
    77  		m.keyLength = 32
    78  		m.saltLength = aes.BlockSize
    79  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter)
    80  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
    81  	case "rc4-md5":
    82  		m.keyLength = 16
    83  		m.saltLength = 16
    84  		m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
    85  			h := md5.New()
    86  			h.Write(key)
    87  			h.Write(salt)
    88  			return rc4.NewCipher(h.Sum(nil))
    89  		}
    90  		m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
    91  			h := md5.New()
    92  			h.Write(key)
    93  			h.Write(salt)
    94  			return rc4.NewCipher(h.Sum(nil))
    95  		}
    96  	case "chacha20-ietf":
    97  		m.keyLength = chacha20.KeySize
    98  		m.saltLength = chacha20.NonceSize
    99  		m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   100  			return chacha20.NewUnauthenticatedCipher(key, salt)
   101  		}
   102  		m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   103  			return chacha20.NewUnauthenticatedCipher(key, salt)
   104  		}
   105  	case "xchacha20":
   106  		m.keyLength = chacha20.KeySize
   107  		m.saltLength = chacha20.NonceSizeX
   108  		m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   109  			return chacha20.NewUnauthenticatedCipher(key, salt)
   110  		}
   111  		m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   112  			return chacha20.NewUnauthenticatedCipher(key, salt)
   113  		}
   114  	default:
   115  		return nil, os.ErrInvalid
   116  	}
   117  	if len(options.Key) == m.keyLength {
   118  		m.key = options.Key
   119  	} else if len(options.Key) > 0 {
   120  		return nil, E.New("bad key length, required ", m.keyLength, ", got ", len(options.Key))
   121  	} else if options.Password != "" {
   122  		m.key = legacykey.Key([]byte(options.Password), m.keyLength)
   123  	} else {
   124  		return nil, C.ErrMissingPassword
   125  	}
   126  	return m, nil
   127  }
   128  
   129  func blockStream(blockCreator func(key []byte) (cipher.Block, error), streamCreator func(block cipher.Block, iv []byte) cipher.Stream) func([]byte, []byte) (cipher.Stream, error) {
   130  	return func(key []byte, iv []byte) (cipher.Stream, error) {
   131  		block, err := blockCreator(key)
   132  		if err != nil {
   133  			return nil, err
   134  		}
   135  		return streamCreator(block, iv), err
   136  	}
   137  }
   138  
   139  func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
   140  	ssConn := &clientConn{
   141  		ExtendedConn: bufio.NewExtendedConn(conn),
   142  		method:       m,
   143  		destination:  destination,
   144  	}
   145  	return ssConn, common.Error(ssConn.Write(nil))
   146  }
   147  
   148  func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
   149  	return &clientConn{
   150  		ExtendedConn: bufio.NewExtendedConn(conn),
   151  		method:       m,
   152  		destination:  destination,
   153  	}
   154  }
   155  
   156  func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
   157  	return &clientPacketConn{
   158  		ExtendedConn: bufio.NewExtendedConn(conn),
   159  		method:       m,
   160  	}
   161  }
   162  
   163  type clientConn struct {
   164  	N.ExtendedConn
   165  	method      *Method
   166  	destination M.Socksaddr
   167  	readStream  cipher.Stream
   168  	writeStream cipher.Stream
   169  }
   170  
   171  func (c *clientConn) readResponse() error {
   172  	saltBuffer := buf.NewSize(c.method.saltLength)
   173  	defer saltBuffer.Release()
   174  	_, err := saltBuffer.ReadFullFrom(c.ExtendedConn, c.method.saltLength)
   175  	if err != nil {
   176  		return err
   177  	}
   178  	c.readStream, err = c.method.decryptConstructor(c.method.key, saltBuffer.Bytes())
   179  	return err
   180  }
   181  
   182  func (c *clientConn) Read(p []byte) (n int, err error) {
   183  	if c.readStream == nil {
   184  		err = c.readResponse()
   185  		if err != nil {
   186  			return
   187  		}
   188  	}
   189  	n, err = c.ExtendedConn.Read(p)
   190  	if err != nil {
   191  		return
   192  	}
   193  	c.readStream.XORKeyStream(p[:n], p[:n])
   194  	return
   195  }
   196  
   197  func (c *clientConn) Write(p []byte) (n int, err error) {
   198  	if c.writeStream == nil {
   199  		buffer := buf.NewSize(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination) + len(p))
   200  		defer buffer.Release()
   201  		buffer.WriteRandom(c.method.saltLength)
   202  		err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination)
   203  		if err != nil {
   204  			return
   205  		}
   206  		common.Must1(buffer.Write(p))
   207  		c.writeStream, err = c.method.encryptConstructor(c.method.key, buffer.To(c.method.saltLength))
   208  		if err != nil {
   209  			return
   210  		}
   211  		c.writeStream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength))
   212  		_, err = c.ExtendedConn.Write(buffer.Bytes())
   213  		if err == nil {
   214  			n = len(p)
   215  		}
   216  		return
   217  	}
   218  	c.writeStream.XORKeyStream(p, p)
   219  	return c.ExtendedConn.Write(p)
   220  }
   221  
   222  func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error {
   223  	if c.readStream == nil {
   224  		err := c.readResponse()
   225  		if err != nil {
   226  			return err
   227  		}
   228  	}
   229  	err := c.ExtendedConn.ReadBuffer(buffer)
   230  	if err != nil {
   231  		return err
   232  	}
   233  	c.readStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
   234  	return nil
   235  }
   236  
   237  func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error {
   238  	if c.writeStream == nil {
   239  		header := buf.With(buffer.ExtendHeader(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination)))
   240  		header.WriteRandom(c.method.saltLength)
   241  		err := M.SocksaddrSerializer.WriteAddrPort(header, c.destination)
   242  		if err != nil {
   243  			return err
   244  		}
   245  		c.writeStream, err = c.method.encryptConstructor(c.method.key, header.To(c.method.saltLength))
   246  		if err != nil {
   247  			return err
   248  		}
   249  		c.writeStream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength))
   250  	} else {
   251  		c.writeStream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
   252  	}
   253  	return c.ExtendedConn.WriteBuffer(buffer)
   254  }
   255  
   256  func (c *clientConn) FrontHeadroom() int {
   257  	if c.writeStream == nil {
   258  		return c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination)
   259  	}
   260  	return 0
   261  }
   262  
   263  func (c *clientConn) NeedHandshake() bool {
   264  	return c.writeStream == nil
   265  }
   266  
   267  func (c *clientConn) Upstream() any {
   268  	return c.ExtendedConn
   269  }
   270  
   271  type clientPacketConn struct {
   272  	N.ExtendedConn
   273  	method *Method
   274  }
   275  
   276  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   277  	err = c.ReadBuffer(buffer)
   278  	if err != nil {
   279  		return
   280  	}
   281  	stream, err := c.method.decryptConstructor(c.method.key, buffer.To(c.method.saltLength))
   282  	if err != nil {
   283  		return
   284  	}
   285  	stream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength))
   286  	buffer.Advance(c.method.saltLength)
   287  	destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
   288  	if err != nil {
   289  		return
   290  	}
   291  	return destination.Unwrap(), nil
   292  }
   293  
   294  func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   295  	header := buf.With(buffer.ExtendHeader(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(destination)))
   296  	header.WriteRandom(c.method.saltLength)
   297  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
   298  	if err != nil {
   299  		return err
   300  	}
   301  	stream, err := c.method.encryptConstructor(c.method.key, buffer.To(c.method.saltLength))
   302  	if err != nil {
   303  		return err
   304  	}
   305  	stream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength))
   306  	return c.ExtendedConn.WriteBuffer(buffer)
   307  }
   308  
   309  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   310  	n, err = c.ExtendedConn.Read(p)
   311  	if err != nil {
   312  		return
   313  	}
   314  	stream, err := c.method.decryptConstructor(c.method.key, p[:c.method.saltLength])
   315  	if err != nil {
   316  		return
   317  	}
   318  	buffer := buf.As(p[c.method.saltLength:n])
   319  	stream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
   320  	destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
   321  	if err != nil {
   322  		return
   323  	}
   324  	if destination.IsFqdn() {
   325  		addr = destination
   326  	} else {
   327  		addr = destination.UDPAddr()
   328  	}
   329  	n = copy(p, buffer.Bytes())
   330  	return
   331  }
   332  
   333  func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   334  	destination := M.SocksaddrFromNet(addr)
   335  	buffer := buf.NewSize(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
   336  	defer buffer.Release()
   337  	buffer.WriteRandom(c.method.saltLength)
   338  	err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
   339  	if err != nil {
   340  		return
   341  	}
   342  	stream, err := c.method.encryptConstructor(c.method.key, buffer.To(c.method.saltLength))
   343  	if err != nil {
   344  		return
   345  	}
   346  	stream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength))
   347  	stream.XORKeyStream(buffer.Extend(len(p)), p)
   348  	_, err = c.ExtendedConn.Write(buffer.Bytes())
   349  	if err == nil {
   350  		n = len(p)
   351  	}
   352  	return
   353  }
   354  
   355  func (c *clientPacketConn) FrontHeadroom() int {
   356  	return c.method.saltLength + M.MaxSocksaddrLength
   357  }
   358  
   359  func (c *clientPacketConn) Upstream() any {
   360  	return c.ExtendedConn
   361  }