github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/proxy/vmess/encoding/server.go (about)

     1  package encoding
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/sha256"
     8  	"encoding/binary"
     9  	"hash/fnv"
    10  	"io"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/xmplusdev/xmcore/common"
    15  	"github.com/xmplusdev/xmcore/common/bitmask"
    16  	"github.com/xmplusdev/xmcore/common/buf"
    17  	"github.com/xmplusdev/xmcore/common/crypto"
    18  	"github.com/xmplusdev/xmcore/common/drain"
    19  	"github.com/xmplusdev/xmcore/common/net"
    20  	"github.com/xmplusdev/xmcore/common/protocol"
    21  	"github.com/xmplusdev/xmcore/common/task"
    22  	"github.com/xmplusdev/xmcore/proxy/vmess"
    23  	vmessaead "github.com/xmplusdev/xmcore/proxy/vmess/aead"
    24  	"golang.org/x/crypto/chacha20poly1305"
    25  )
    26  
    27  type sessionID struct {
    28  	user  [16]byte
    29  	key   [16]byte
    30  	nonce [16]byte
    31  }
    32  
// SessionHistory keeps track of historical session ids, to prevent replay attacks.
//
// Each recorded id maps to an expiry time; addIfNotExits rejects an id whose
// expiry is still in the future and otherwise (re)records it with a deadline
// three minutes ahead. A task.Periodic (configured in NewSessionHistory)
// runs removeExpiredEntries to drop stale entries.
type SessionHistory struct {
	// Guards cache. Only the exclusive Lock/Unlock are used by the methods
	// visible in this file.
	sync.RWMutex
	// cache maps a session id to the time after which it is no longer
	// considered a replay.
	cache map[sessionID]time.Time
	// task is the periodic sweeper that calls removeExpiredEntries.
	task  *task.Periodic
}
    39  
    40  // NewSessionHistory creates a new SessionHistory object.
    41  func NewSessionHistory() *SessionHistory {
    42  	h := &SessionHistory{
    43  		cache: make(map[sessionID]time.Time, 128),
    44  	}
    45  	h.task = &task.Periodic{
    46  		Interval: time.Second * 30,
    47  		Execute:  h.removeExpiredEntries,
    48  	}
    49  	return h
    50  }
    51  
    52  // Close implements common.Closable.
    53  func (h *SessionHistory) Close() error {
    54  	return h.task.Close()
    55  }
    56  
    57  func (h *SessionHistory) addIfNotExits(session sessionID) bool {
    58  	h.Lock()
    59  
    60  	if expire, found := h.cache[session]; found && expire.After(time.Now()) {
    61  		h.Unlock()
    62  		return false
    63  	}
    64  
    65  	h.cache[session] = time.Now().Add(time.Minute * 3)
    66  	h.Unlock()
    67  	common.Must(h.task.Start())
    68  	return true
    69  }
    70  
    71  func (h *SessionHistory) removeExpiredEntries() error {
    72  	now := time.Now()
    73  
    74  	h.Lock()
    75  	defer h.Unlock()
    76  
    77  	if len(h.cache) == 0 {
    78  		return newError("nothing to do")
    79  	}
    80  
    81  	for session, expire := range h.cache {
    82  		if expire.Before(now) {
    83  			delete(h.cache, session)
    84  		}
    85  	}
    86  
    87  	if len(h.cache) == 0 {
    88  		h.cache = make(map[sessionID]time.Time, 128)
    89  	}
    90  
    91  	return nil
    92  }
    93  
    94  // ServerSession keeps information for a session in VMess server.
    95  type ServerSession struct {
    96  	userValidator   *vmess.TimedUserValidator
    97  	sessionHistory  *SessionHistory
    98  	requestBodyKey  [16]byte
    99  	requestBodyIV   [16]byte
   100  	responseBodyKey [16]byte
   101  	responseBodyIV  [16]byte
   102  	responseWriter  io.Writer
   103  	responseHeader  byte
   104  }
   105  
   106  // NewServerSession creates a new ServerSession, using the given UserValidator.
   107  // The ServerSession instance doesn't take ownership of the validator.
   108  func NewServerSession(validator *vmess.TimedUserValidator, sessionHistory *SessionHistory) *ServerSession {
   109  	return &ServerSession{
   110  		userValidator:  validator,
   111  		sessionHistory: sessionHistory,
   112  	}
   113  }
   114  
   115  func parseSecurityType(b byte) protocol.SecurityType {
   116  	if _, f := protocol.SecurityType_name[int32(b)]; f {
   117  		st := protocol.SecurityType(b)
   118  		// For backward compatibility.
   119  		if st == protocol.SecurityType_UNKNOWN {
   120  			st = protocol.SecurityType_AUTO
   121  		}
   122  		return st
   123  	}
   124  	return protocol.SecurityType_UNKNOWN
   125  }
   126  
// DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream.
//
// Input consumed from reader:
//   - protocol.IDBytesLen bytes of auth ID, looked up with
//     s.userValidator.GetAEAD to find the user;
//   - an AEAD-sealed header, opened by vmessaead.OpenVMessAEADHeader with the
//     user's command key.
//
// The opened header plaintext is parsed as:
//
//	byte 0       version
//	bytes 1-16   request body IV  (stored in s.requestBodyIV)
//	bytes 17-32  request body key (stored in s.requestBodyKey)
//	byte 33      response header byte (stored in s.responseHeader)
//	byte 34      option bitmask
//	byte 35      high nibble: padding length, low nibble: security type
//	byte 36      reserved
//	byte 37      command
//
// The fixed part is followed by:
//   - the destination address and port (TCP/UDP commands only);
//   - paddingLen bytes of padding;
//   - a 4-byte big-endian FNV-1a (32-bit) checksum over all preceding
//     plaintext held in buffer.
//
// Side effects: once the AEAD header is opened, s.requestBodyIV and
// s.requestBodyKey are overwritten even if a later check fails. The triple
// (user ID, body key, body IV) is recorded in s.sessionHistory, and a triple
// that is already recorded and unexpired is rejected as a replay.
// s.responseHeader is set only after the replay check passes.
//
// isDrain controls failures before the AEAD header is opened (unknown auth
// ID, AEAD open failure). When isDrain is true these are routed through
// drainConnection, which reads further data from reader via drain.WithError
// before returning the error. All other failures, including failing to read
// the auth ID, return immediately without draining.
func (s *ServerSession) DecodeRequestHeader(reader io.Reader, isDrain bool) (*protocol.RequestHeader, error) {
	buffer := buf.New()

	// The drainer is seeded from the validator's behavior seed, so the amount
	// drained is deterministic per server configuration.
	drainer, err := drain.NewBehaviorSeedLimitedDrainer(int64(s.userValidator.GetBehaviorSeed()), 16+38, 3266, 64)
	if err != nil {
		return nil, newError("failed to initialize drainer").Base(err)
	}

	drainConnection := func(e error) error {
		// We read a deterministic generated length of data before closing the connection to offset padding read pattern
		drainer.AcknowledgeReceive(int(buffer.Len()))
		if isDrain {
			return drain.WithError(drainer, reader, e)
		}
		return e
	}

	defer func() {
		buffer.Release()
	}()

	// Auth ID: identifies the user and keys the AEAD header.
	if _, err := buffer.ReadFullFrom(reader, protocol.IDBytesLen); err != nil {
		return nil, newError("failed to read request header").Base(err)
	}

	var decryptor io.Reader
	var vmessAccount *vmess.MemoryAccount

	user, foundAEAD, errorAEAD := s.userValidator.GetAEAD(buffer.Bytes())

	var fixedSizeAuthID [16]byte
	copy(fixedSizeAuthID[:], buffer.Bytes())

	switch {
	case foundAEAD:
		vmessAccount = user.Account.(*vmess.MemoryAccount)
		var fixedSizeCmdKey [16]byte
		copy(fixedSizeCmdKey[:], vmessAccount.ID.CmdKey())
		aeadData, shouldDrain, bytesRead, errorReason := vmessaead.OpenVMessAEADHeader(fixedSizeCmdKey, fixedSizeAuthID, reader)
		if errorReason != nil {
			// NOTE(review): both branches call drainConnection, which drains
			// whenever isDrain is true. "drain skipped" only means bytesRead is
			// not acknowledged to the drainer; confirm that this is intended.
			if shouldDrain {
				drainer.AcknowledgeReceive(bytesRead)
				return nil, drainConnection(newError("AEAD read failed").Base(errorReason))
			} else {
				return nil, drainConnection(newError("AEAD read failed, drain skipped").Base(errorReason))
			}
		}
		// The rest of the header is parsed from the already-decrypted
		// plaintext, not from reader.
		decryptor = bytes.NewReader(aeadData)
	default:
		return nil, drainConnection(newError("invalid user").Base(errorAEAD))
	}

	// NOTE(review): no drain can happen after this point, so this
	// acknowledgement looks like bookkeeping only; confirm.
	drainer.AcknowledgeReceive(int(buffer.Len()))
	buffer.Clear()
	// Fixed 38-byte part of the header plaintext (layout in the doc comment).
	if _, err := buffer.ReadFullFrom(decryptor, 38); err != nil {
		return nil, newError("failed to read request header").Base(err)
	}

	request := &protocol.RequestHeader{
		User:    user,
		Version: buffer.Byte(0),
	}

	copy(s.requestBodyIV[:], buffer.BytesRange(1, 17))   // 16 bytes
	copy(s.requestBodyKey[:], buffer.BytesRange(17, 33)) // 16 bytes
	// Replay protection: the same user + body key + body IV must not be
	// accepted twice while the history entry is unexpired.
	var sid sessionID
	copy(sid.user[:], vmessAccount.ID.Bytes())
	sid.key = s.requestBodyKey
	sid.nonce = s.requestBodyIV
	if !s.sessionHistory.addIfNotExits(sid) {
		return nil, newError("duplicated session id, possibly under replay attack, but this is a AEAD request")
	}

	s.responseHeader = buffer.Byte(33)             // 1 byte
	request.Option = bitmask.Byte(buffer.Byte(34)) // 1 byte
	// Byte 35 packs two 4-bit fields: padding length (high), security (low).
	paddingLen := int(buffer.Byte(35) >> 4)
	request.Security = parseSecurityType(buffer.Byte(35) & 0x0F)
	// 1 bytes reserved
	request.Command = protocol.RequestCommand(buffer.Byte(37))

	switch request.Command {
	case protocol.RequestCommandMux:
		// Mux requests carry no address; a fixed placeholder destination is used.
		request.Address = net.DomainAddress("v1.mux.cool")
		request.Port = 0

	case protocol.RequestCommandTCP, protocol.RequestCommandUDP:
		// A parse error is deliberately not returned here: it leaves
		// request.Address nil, which is rejected after the checksum check.
		// NOTE(review): addrParser is declared elsewhere in the package. It
		// presumably appends the bytes it reads to buffer, so they are
		// covered by the checksum below; verify.
		if addr, port, err := addrParser.ReadAddressPort(buffer, decryptor); err == nil {
			request.Address = addr
			request.Port = port
		}
	}

	if paddingLen > 0 {
		if _, err := buffer.ReadFullFrom(decryptor, int32(paddingLen)); err != nil {
			return nil, newError("failed to read padding").Base(err)
		}
	}

	if _, err := buffer.ReadFullFrom(decryptor, 4); err != nil {
		return nil, newError("failed to read checksum").Base(err)
	}

	// FNV-1a over everything in buffer except the trailing 4-byte checksum.
	fnv1a := fnv.New32a()
	common.Must2(fnv1a.Write(buffer.BytesTo(-4)))
	actualHash := fnv1a.Sum32()
	expectedHash := binary.BigEndian.Uint32(buffer.BytesFrom(-4))

	if actualHash != expectedHash {
		return nil, newError("invalid auth, but this is a AEAD request")
	}

	if request.Address == nil {
		return nil, newError("invalid remote address")
	}

	// AUTO is rejected too: the client must name a concrete security type.
	if request.Security == protocol.SecurityType_UNKNOWN || request.Security == protocol.SecurityType_AUTO {
		return nil, newError("unknown security type: ", request.Security)
	}

	return request, nil
}
   249  
   250  // DecodeRequestBody returns Reader from which caller can fetch decrypted body.
   251  func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reader io.Reader) (buf.Reader, error) {
   252  	var sizeParser crypto.ChunkSizeDecoder = crypto.PlainChunkSizeParser{}
   253  	if request.Option.Has(protocol.RequestOptionChunkMasking) {
   254  		sizeParser = NewShakeSizeParser(s.requestBodyIV[:])
   255  	}
   256  	var padding crypto.PaddingLengthGenerator
   257  	if request.Option.Has(protocol.RequestOptionGlobalPadding) {
   258  		var ok bool
   259  		padding, ok = sizeParser.(crypto.PaddingLengthGenerator)
   260  		if !ok {
   261  			return nil, newError("invalid option: RequestOptionGlobalPadding")
   262  		}
   263  	}
   264  
   265  	switch request.Security {
   266  	case protocol.SecurityType_NONE:
   267  		if request.Option.Has(protocol.RequestOptionChunkStream) {
   268  			if request.Command.TransferType() == protocol.TransferTypeStream {
   269  				return crypto.NewChunkStreamReader(sizeParser, reader), nil
   270  			}
   271  
   272  			auth := &crypto.AEADAuthenticator{
   273  				AEAD:                    new(NoOpAuthenticator),
   274  				NonceGenerator:          crypto.GenerateEmptyBytes(),
   275  				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   276  			}
   277  			return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket, padding), nil
   278  		}
   279  		return buf.NewReader(reader), nil
   280  
   281  	case protocol.SecurityType_AES128_GCM:
   282  		aead := crypto.NewAesGcm(s.requestBodyKey[:])
   283  		auth := &crypto.AEADAuthenticator{
   284  			AEAD:                    aead,
   285  			NonceGenerator:          GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
   286  			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   287  		}
   288  		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
   289  			AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len")
   290  			AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey)
   291  
   292  			lengthAuth := &crypto.AEADAuthenticator{
   293  				AEAD:                    AuthenticatedLengthKeyAEAD,
   294  				NonceGenerator:          GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
   295  				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   296  			}
   297  			sizeParser = NewAEADSizeParser(lengthAuth)
   298  		}
   299  		return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding), nil
   300  
   301  	case protocol.SecurityType_CHACHA20_POLY1305:
   302  		aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(s.requestBodyKey[:]))
   303  
   304  		auth := &crypto.AEADAuthenticator{
   305  			AEAD:                    aead,
   306  			NonceGenerator:          GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
   307  			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   308  		}
   309  		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
   310  			AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len")
   311  			AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey))
   312  			common.Must(err)
   313  
   314  			lengthAuth := &crypto.AEADAuthenticator{
   315  				AEAD:                    AuthenticatedLengthKeyAEAD,
   316  				NonceGenerator:          GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
   317  				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   318  			}
   319  			sizeParser = NewAEADSizeParser(lengthAuth)
   320  		}
   321  		return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding), nil
   322  
   323  	default:
   324  		return nil, newError("invalid option: Security")
   325  	}
   326  }
   327  
   328  // EncodeResponseHeader writes encoded response header into the given writer.
   329  func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, writer io.Writer) {
   330  	var encryptionWriter io.Writer
   331  	BodyKey := sha256.Sum256(s.requestBodyKey[:])
   332  	copy(s.responseBodyKey[:], BodyKey[:16])
   333  	BodyIV := sha256.Sum256(s.requestBodyIV[:])
   334  	copy(s.responseBodyIV[:], BodyIV[:16])
   335  
   336  	aesStream := crypto.NewAesEncryptionStream(s.responseBodyKey[:], s.responseBodyIV[:])
   337  	encryptionWriter = crypto.NewCryptionWriter(aesStream, writer)
   338  	s.responseWriter = encryptionWriter
   339  
   340  	aeadEncryptedHeaderBuffer := bytes.NewBuffer(nil)
   341  	encryptionWriter = aeadEncryptedHeaderBuffer
   342  
   343  	common.Must2(encryptionWriter.Write([]byte{s.responseHeader, byte(header.Option)}))
   344  	err := MarshalCommand(header.Command, encryptionWriter)
   345  	if err != nil {
   346  		common.Must2(encryptionWriter.Write([]byte{0x00, 0x00}))
   347  	}
   348  
   349  	aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey)
   350  	aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12]
   351  
   352  	aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block)
   353  	aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD)
   354  
   355  	aeadResponseHeaderLengthEncryptionBuffer := bytes.NewBuffer(nil)
   356  
   357  	decryptedResponseHeaderLengthBinaryDeserializeBuffer := uint16(aeadEncryptedHeaderBuffer.Len())
   358  
   359  	common.Must(binary.Write(aeadResponseHeaderLengthEncryptionBuffer, binary.BigEndian, decryptedResponseHeaderLengthBinaryDeserializeBuffer))
   360  
   361  	AEADEncryptedLength := aeadResponseHeaderLengthEncryptionAEAD.Seal(nil, aeadResponseHeaderLengthEncryptionIV, aeadResponseHeaderLengthEncryptionBuffer.Bytes(), nil)
   362  	common.Must2(io.Copy(writer, bytes.NewReader(AEADEncryptedLength)))
   363  
   364  	aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey)
   365  	aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12]
   366  
   367  	aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block)
   368  	aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD)
   369  
   370  	aeadEncryptedHeaderPayload := aeadResponseHeaderPayloadEncryptionAEAD.Seal(nil, aeadResponseHeaderPayloadEncryptionIV, aeadEncryptedHeaderBuffer.Bytes(), nil)
   371  	common.Must2(io.Copy(writer, bytes.NewReader(aeadEncryptedHeaderPayload)))
   372  }
   373  
   374  // EncodeResponseBody returns a Writer that auto-encrypt content written by caller.
   375  func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
   376  	var sizeParser crypto.ChunkSizeEncoder = crypto.PlainChunkSizeParser{}
   377  	if request.Option.Has(protocol.RequestOptionChunkMasking) {
   378  		sizeParser = NewShakeSizeParser(s.responseBodyIV[:])
   379  	}
   380  	var padding crypto.PaddingLengthGenerator
   381  	if request.Option.Has(protocol.RequestOptionGlobalPadding) {
   382  		var ok bool
   383  		padding, ok = sizeParser.(crypto.PaddingLengthGenerator)
   384  		if !ok {
   385  			return nil, newError("invalid option: RequestOptionGlobalPadding")
   386  		}
   387  	}
   388  
   389  	switch request.Security {
   390  	case protocol.SecurityType_NONE:
   391  		if request.Option.Has(protocol.RequestOptionChunkStream) {
   392  			if request.Command.TransferType() == protocol.TransferTypeStream {
   393  				return crypto.NewChunkStreamWriter(sizeParser, writer), nil
   394  			}
   395  
   396  			auth := &crypto.AEADAuthenticator{
   397  				AEAD:                    new(NoOpAuthenticator),
   398  				NonceGenerator:          crypto.GenerateEmptyBytes(),
   399  				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   400  			}
   401  			return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket, padding), nil
   402  		}
   403  		return buf.NewWriter(writer), nil
   404  
   405  	case protocol.SecurityType_AES128_GCM:
   406  		aead := crypto.NewAesGcm(s.responseBodyKey[:])
   407  		auth := &crypto.AEADAuthenticator{
   408  			AEAD:                    aead,
   409  			NonceGenerator:          GenerateChunkNonce(s.responseBodyIV[:], uint32(aead.NonceSize())),
   410  			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   411  		}
   412  		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
   413  			AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len")
   414  			AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey)
   415  
   416  			lengthAuth := &crypto.AEADAuthenticator{
   417  				AEAD:                    AuthenticatedLengthKeyAEAD,
   418  				NonceGenerator:          GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
   419  				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   420  			}
   421  			sizeParser = NewAEADSizeParser(lengthAuth)
   422  		}
   423  		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding), nil
   424  
   425  	case protocol.SecurityType_CHACHA20_POLY1305:
   426  		aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(s.responseBodyKey[:]))
   427  
   428  		auth := &crypto.AEADAuthenticator{
   429  			AEAD:                    aead,
   430  			NonceGenerator:          GenerateChunkNonce(s.responseBodyIV[:], uint32(aead.NonceSize())),
   431  			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   432  		}
   433  		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
   434  			AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len")
   435  			AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey))
   436  			common.Must(err)
   437  
   438  			lengthAuth := &crypto.AEADAuthenticator{
   439  				AEAD:                    AuthenticatedLengthKeyAEAD,
   440  				NonceGenerator:          GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
   441  				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   442  			}
   443  			sizeParser = NewAEADSizeParser(lengthAuth)
   444  		}
   445  		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding), nil
   446  
   447  	default:
   448  		return nil, newError("invalid option: Security")
   449  	}
   450  }