github.com/MerlinKodo/sing-shadowsocks@v0.2.6/shadowstream/protocol.go (about)

     1  package shadowstream
     2  
     3  import (
     4  	"crypto/aes"
     5  	"crypto/cipher"
     6  	"crypto/md5"
     7  	"crypto/rand"
     8  	"crypto/rc4"
     9  	"io"
    10  	"net"
    11  	"os"
    12  
    13  	shadowsocks "github.com/MerlinKodo/sing-shadowsocks"
    14  	"github.com/sagernet/sing/common"
    15  	"github.com/sagernet/sing/common/buf"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  	N "github.com/sagernet/sing/common/network"
    18  
    19  	"github.com/aead/chacha20/chacha"
    20  	"golang.org/x/crypto/chacha20"
    21  )
    22  
    23  var List = []string{
    24  	"aes-128-ctr",
    25  	"aes-192-ctr",
    26  	"aes-256-ctr",
    27  	"aes-128-cfb",
    28  	"aes-192-cfb",
    29  	"aes-256-cfb",
    30  	"rc4-md5",
    31  	"chacha20-ietf",
    32  	"xchacha20",
    33  	"chacha20",
    34  }
    35  
    36  type Method struct {
    37  	name               string
    38  	keyLength          int
    39  	saltLength         int
    40  	encryptConstructor func(key []byte, salt []byte) (cipher.Stream, error)
    41  	decryptConstructor func(key []byte, salt []byte) (cipher.Stream, error)
    42  	key                []byte
    43  }
    44  
    45  func New(method string, key []byte, password string) (shadowsocks.Method, error) {
    46  	m := &Method{
    47  		name: method,
    48  	}
    49  	switch method {
    50  	case "aes-128-ctr":
    51  		m.keyLength = 16
    52  		m.saltLength = aes.BlockSize
    53  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    54  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    55  	case "aes-192-ctr":
    56  		m.keyLength = 24
    57  		m.saltLength = aes.BlockSize
    58  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    59  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    60  	case "aes-256-ctr":
    61  		m.keyLength = 32
    62  		m.saltLength = aes.BlockSize
    63  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    64  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR)
    65  	case "aes-128-cfb":
    66  		m.keyLength = 16
    67  		m.saltLength = aes.BlockSize
    68  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter)
    69  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
    70  	case "aes-192-cfb":
    71  		m.keyLength = 24
    72  		m.saltLength = aes.BlockSize
    73  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter)
    74  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
    75  	case "aes-256-cfb":
    76  		m.keyLength = 32
    77  		m.saltLength = aes.BlockSize
    78  		m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter)
    79  		m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter)
    80  	case "rc4-md5":
    81  		m.keyLength = 16
    82  		m.saltLength = 16
    83  		m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
    84  			h := md5.New()
    85  			h.Write(key)
    86  			h.Write(salt)
    87  			return rc4.NewCipher(h.Sum(nil))
    88  		}
    89  		m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
    90  			h := md5.New()
    91  			h.Write(key)
    92  			h.Write(salt)
    93  			return rc4.NewCipher(h.Sum(nil))
    94  		}
    95  	case "chacha20-ietf":
    96  		m.keyLength = chacha20.KeySize
    97  		m.saltLength = chacha20.NonceSize
    98  		m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
    99  			return chacha20.NewUnauthenticatedCipher(key, salt)
   100  		}
   101  		m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   102  			return chacha20.NewUnauthenticatedCipher(key, salt)
   103  		}
   104  	case "xchacha20":
   105  		m.keyLength = chacha20.KeySize
   106  		m.saltLength = chacha20.NonceSizeX
   107  		m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   108  			return chacha20.NewUnauthenticatedCipher(key, salt)
   109  		}
   110  		m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   111  			return chacha20.NewUnauthenticatedCipher(key, salt)
   112  		}
   113  	case "chacha20":
   114  		m.keyLength = chacha.KeySize
   115  		m.saltLength = chacha.NonceSize
   116  		m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   117  			return chacha.NewCipher(salt, key, 20)
   118  		}
   119  		m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) {
   120  			return chacha.NewCipher(salt, key, 20)
   121  		}
   122  	default:
   123  		return nil, os.ErrInvalid
   124  	}
   125  	if len(key) == m.keyLength {
   126  		m.key = key
   127  	} else if len(key) > 0 {
   128  		return nil, shadowsocks.ErrBadKey
   129  	} else if password != "" {
   130  		m.key = shadowsocks.Key([]byte(password), m.keyLength)
   131  	} else {
   132  		return nil, shadowsocks.ErrMissingPassword
   133  	}
   134  	return m, nil
   135  }
   136  
   137  func blockStream(blockCreator func(key []byte) (cipher.Block, error), streamCreator func(block cipher.Block, iv []byte) cipher.Stream) func([]byte, []byte) (cipher.Stream, error) {
   138  	return func(key []byte, iv []byte) (cipher.Stream, error) {
   139  		block, err := blockCreator(key)
   140  		if err != nil {
   141  			return nil, err
   142  		}
   143  		return streamCreator(block, iv), err
   144  	}
   145  }
   146  
   147  func (m *Method) Name() string {
   148  	return m.name
   149  }
   150  
   151  func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
   152  	shadowsocksConn := &clientConn{
   153  		Method:      m,
   154  		Conn:        conn,
   155  		destination: destination,
   156  	}
   157  	return shadowsocksConn, shadowsocksConn.writeRequest()
   158  }
   159  
   160  func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
   161  	return &clientConn{
   162  		Method:      m,
   163  		Conn:        conn,
   164  		destination: destination,
   165  	}
   166  }
   167  
   168  func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
   169  	return &clientPacketConn{m, conn}
   170  }
   171  
   172  type clientConn struct {
   173  	*Method
   174  	net.Conn
   175  	destination M.Socksaddr
   176  	readStream  cipher.Stream
   177  	writeStream cipher.Stream
   178  }
   179  
   180  func (c *clientConn) writeRequest() error {
   181  	buffer := buf.NewSize(c.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination))
   182  	defer buffer.Release()
   183  
   184  	salt := buffer.Extend(c.saltLength)
   185  	common.Must1(io.ReadFull(rand.Reader, salt))
   186  
   187  	stream, err := c.encryptConstructor(c.key, salt)
   188  	if err != nil {
   189  		return err
   190  	}
   191  
   192  	err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination)
   193  	if err != nil {
   194  		return err
   195  	}
   196  
   197  	stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
   198  
   199  	_, err = c.Conn.Write(buffer.Bytes())
   200  	if err != nil {
   201  		return err
   202  	}
   203  
   204  	c.writeStream = stream
   205  	return nil
   206  }
   207  
   208  func (c *clientConn) readResponse() error {
   209  	if c.readStream != nil {
   210  		return nil
   211  	}
   212  	salt := make([]byte, c.saltLength)
   213  	_, err := io.ReadFull(c.Conn, salt)
   214  	if err != nil {
   215  		return err
   216  	}
   217  	c.readStream, err = c.decryptConstructor(c.key, salt)
   218  	return err
   219  }
   220  
   221  func (c *clientConn) Read(p []byte) (n int, err error) {
   222  	if c.readStream == nil {
   223  		err = c.readResponse()
   224  		if err != nil {
   225  			return
   226  		}
   227  	}
   228  	n, err = c.Conn.Read(p)
   229  	if err != nil {
   230  		return
   231  	}
   232  	c.readStream.XORKeyStream(p[:n], p[:n])
   233  	return
   234  }
   235  
   236  func (c *clientConn) Write(p []byte) (n int, err error) {
   237  	if c.writeStream == nil {
   238  		err = c.writeRequest()
   239  		if err != nil {
   240  			return
   241  		}
   242  	}
   243  
   244  	c.writeStream.XORKeyStream(p, p)
   245  	return c.Conn.Write(p)
   246  }
   247  
   248  func (c *clientConn) NeedAdditionalReadDeadline() bool {
   249  	return true
   250  }
   251  
   252  func (c *clientConn) Upstream() any {
   253  	return c.Conn
   254  }
   255  
   256  type clientPacketConn struct {
   257  	*Method
   258  	net.Conn
   259  }
   260  
   261  func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   262  	defer buffer.Release()
   263  	header := buf.With(buffer.ExtendHeader(c.saltLength + M.SocksaddrSerializer.AddrPortLen(destination)))
   264  	common.Must1(header.ReadFullFrom(rand.Reader, c.saltLength))
   265  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
   266  	if err != nil {
   267  		return err
   268  	}
   269  	stream, err := c.encryptConstructor(c.key, buffer.To(c.saltLength))
   270  	if err != nil {
   271  		return err
   272  	}
   273  	stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
   274  	return common.Error(c.Write(buffer.Bytes()))
   275  }
   276  
   277  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
   278  	n, err := c.Read(buffer.FreeBytes())
   279  	if err != nil {
   280  		return M.Socksaddr{}, err
   281  	}
   282  	buffer.Truncate(n)
   283  	stream, err := c.decryptConstructor(c.key, buffer.To(c.saltLength))
   284  	if err != nil {
   285  		return M.Socksaddr{}, err
   286  	}
   287  	stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
   288  	buffer.Advance(c.saltLength)
   289  	return M.SocksaddrSerializer.ReadAddrPort(buffer)
   290  }
   291  
   292  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   293  	n, err = c.Read(p)
   294  	if err != nil {
   295  		return
   296  	}
   297  	stream, err := c.decryptConstructor(c.key, p[:c.saltLength])
   298  	if err != nil {
   299  		return
   300  	}
   301  	buffer := buf.As(p[c.saltLength:n])
   302  	stream.XORKeyStream(buffer.Bytes(), buffer.Bytes())
   303  	destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
   304  	if err != nil {
   305  		return
   306  	}
   307  	if destination.IsFqdn() {
   308  		addr = destination
   309  	} else {
   310  		addr = destination.UDPAddr()
   311  	}
   312  	n = copy(p, buffer.Bytes())
   313  	return
   314  }
   315  
   316  func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   317  	destination := M.SocksaddrFromNet(addr)
   318  	buffer := buf.NewSize(c.saltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
   319  	defer buffer.Release()
   320  	common.Must1(buffer.ReadFullFrom(rand.Reader, c.saltLength))
   321  	err = M.SocksaddrSerializer.WriteAddrPort(buffer, M.SocksaddrFromNet(addr))
   322  	if err != nil {
   323  		return
   324  	}
   325  	_, err = buffer.Write(p)
   326  	if err != nil {
   327  		return
   328  	}
   329  	stream, err := c.encryptConstructor(c.key, buffer.To(c.saltLength))
   330  	if err != nil {
   331  		return
   332  	}
   333  	stream.XORKeyStream(buffer.From(c.saltLength), buffer.From(c.saltLength))
   334  	_, err = c.Write(buffer.Bytes())
   335  	if err != nil {
   336  		return
   337  	}
   338  	return len(p), nil
   339  }
   340  
   341  func (c *clientPacketConn) FrontHeadroom() int {
   342  	return c.saltLength + M.MaxSocksaddrLength
   343  }
   344  
   345  func (c *clientPacketConn) Upstream() any {
   346  	return c.Conn
   347  }