github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/proxy/vmess/encoding/client.go (about)

     1  package encoding
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/aes"
     7  	"crypto/cipher"
     8  	"crypto/rand"
     9  	"crypto/sha256"
    10  	"encoding/binary"
    11  	"hash/fnv"
    12  	"io"
    13  
    14  	"github.com/xmplusdev/xmcore/common"
    15  	"github.com/xmplusdev/xmcore/common/bitmask"
    16  	"github.com/xmplusdev/xmcore/common/buf"
    17  	"github.com/xmplusdev/xmcore/common/crypto"
    18  	"github.com/xmplusdev/xmcore/common/dice"
    19  	"github.com/xmplusdev/xmcore/common/drain"
    20  	"github.com/xmplusdev/xmcore/common/protocol"
    21  	"github.com/xmplusdev/xmcore/proxy/vmess"
    22  	vmessaead "github.com/xmplusdev/xmcore/proxy/vmess/aead"
    23  	"golang.org/x/crypto/chacha20poly1305"
    24  )
    25  
    26  // ClientSession stores connection session info for VMess client.
    27  type ClientSession struct {
    28  	requestBodyKey  [16]byte
    29  	requestBodyIV   [16]byte
    30  	responseBodyKey [16]byte
    31  	responseBodyIV  [16]byte
    32  	responseReader  io.Reader
    33  	responseHeader  byte
    34  
    35  	readDrainer drain.Drainer
    36  }
    37  
    38  // NewClientSession creates a new ClientSession.
    39  func NewClientSession(ctx context.Context, behaviorSeed int64) *ClientSession {
    40  	session := &ClientSession{}
    41  
    42  	randomBytes := make([]byte, 33) // 16 + 16 + 1
    43  	common.Must2(rand.Read(randomBytes))
    44  	copy(session.requestBodyKey[:], randomBytes[:16])
    45  	copy(session.requestBodyIV[:], randomBytes[16:32])
    46  	session.responseHeader = randomBytes[32]
    47  
    48  	BodyKey := sha256.Sum256(session.requestBodyKey[:])
    49  	copy(session.responseBodyKey[:], BodyKey[:16])
    50  	BodyIV := sha256.Sum256(session.requestBodyIV[:])
    51  	copy(session.responseBodyIV[:], BodyIV[:16])
    52  	{
    53  		var err error
    54  		session.readDrainer, err = drain.NewBehaviorSeedLimitedDrainer(behaviorSeed, 18, 3266, 64)
    55  		if err != nil {
    56  			newError("unable to initialize drainer").Base(err).WriteToLog()
    57  			session.readDrainer = drain.NewNopDrainer()
    58  		}
    59  	}
    60  
    61  	return session
    62  }
    63  
    64  func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
    65  	account := header.User.Account.(*vmess.MemoryAccount)
    66  
    67  	buffer := buf.New()
    68  	defer buffer.Release()
    69  
    70  	common.Must(buffer.WriteByte(Version))
    71  	common.Must2(buffer.Write(c.requestBodyIV[:]))
    72  	common.Must2(buffer.Write(c.requestBodyKey[:]))
    73  	common.Must(buffer.WriteByte(c.responseHeader))
    74  	common.Must(buffer.WriteByte(byte(header.Option)))
    75  
    76  	paddingLen := dice.Roll(16)
    77  	security := byte(paddingLen<<4) | byte(header.Security)
    78  	common.Must2(buffer.Write([]byte{security, byte(0), byte(header.Command)}))
    79  
    80  	if header.Command != protocol.RequestCommandMux {
    81  		if err := addrParser.WriteAddressPort(buffer, header.Address, header.Port); err != nil {
    82  			return newError("failed to writer address and port").Base(err)
    83  		}
    84  	}
    85  
    86  	if paddingLen > 0 {
    87  		common.Must2(buffer.ReadFullFrom(rand.Reader, int32(paddingLen)))
    88  	}
    89  
    90  	{
    91  		fnv1a := fnv.New32a()
    92  		common.Must2(fnv1a.Write(buffer.Bytes()))
    93  		hashBytes := buffer.Extend(int32(fnv1a.Size()))
    94  		fnv1a.Sum(hashBytes[:0])
    95  	}
    96  
    97  	var fixedLengthCmdKey [16]byte
    98  	copy(fixedLengthCmdKey[:], account.ID.CmdKey())
    99  	vmessout := vmessaead.SealVMessAEADHeader(fixedLengthCmdKey, buffer.Bytes())
   100  	common.Must2(io.Copy(writer, bytes.NewReader(vmessout)))
   101  
   102  	return nil
   103  }
   104  
   105  func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
   106  	var sizeParser crypto.ChunkSizeEncoder = crypto.PlainChunkSizeParser{}
   107  	if request.Option.Has(protocol.RequestOptionChunkMasking) {
   108  		sizeParser = NewShakeSizeParser(c.requestBodyIV[:])
   109  	}
   110  	var padding crypto.PaddingLengthGenerator
   111  	if request.Option.Has(protocol.RequestOptionGlobalPadding) {
   112  		var ok bool
   113  		padding, ok = sizeParser.(crypto.PaddingLengthGenerator)
   114  		if !ok {
   115  			return nil, newError("invalid option: RequestOptionGlobalPadding")
   116  		}
   117  	}
   118  
   119  	switch request.Security {
   120  	case protocol.SecurityType_NONE:
   121  		if request.Option.Has(protocol.RequestOptionChunkStream) {
   122  			if request.Command.TransferType() == protocol.TransferTypeStream {
   123  				return crypto.NewChunkStreamWriter(sizeParser, writer), nil
   124  			}
   125  			auth := &crypto.AEADAuthenticator{
   126  				AEAD:                    new(NoOpAuthenticator),
   127  				NonceGenerator:          crypto.GenerateEmptyBytes(),
   128  				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   129  			}
   130  			return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket, padding), nil
   131  		}
   132  
   133  		return buf.NewWriter(writer), nil
   134  	case protocol.SecurityType_AES128_GCM:
   135  		aead := crypto.NewAesGcm(c.requestBodyKey[:])
   136  		auth := &crypto.AEADAuthenticator{
   137  			AEAD:                    aead,
   138  			NonceGenerator:          GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())),
   139  			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   140  		}
   141  		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
   142  			AuthenticatedLengthKey := vmessaead.KDF16(c.requestBodyKey[:], "auth_len")
   143  			AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey)
   144  
   145  			lengthAuth := &crypto.AEADAuthenticator{
   146  				AEAD:                    AuthenticatedLengthKeyAEAD,
   147  				NonceGenerator:          GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())),
   148  				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   149  			}
   150  			sizeParser = NewAEADSizeParser(lengthAuth)
   151  		}
   152  		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding), nil
   153  	case protocol.SecurityType_CHACHA20_POLY1305:
   154  		aead, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(c.requestBodyKey[:]))
   155  		common.Must(err)
   156  
   157  		auth := &crypto.AEADAuthenticator{
   158  			AEAD:                    aead,
   159  			NonceGenerator:          GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())),
   160  			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   161  		}
   162  		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
   163  			AuthenticatedLengthKey := vmessaead.KDF16(c.requestBodyKey[:], "auth_len")
   164  			AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey))
   165  			common.Must(err)
   166  
   167  			lengthAuth := &crypto.AEADAuthenticator{
   168  				AEAD:                    AuthenticatedLengthKeyAEAD,
   169  				NonceGenerator:          GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())),
   170  				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   171  			}
   172  			sizeParser = NewAEADSizeParser(lengthAuth)
   173  		}
   174  		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding), nil
   175  	default:
   176  		return nil, newError("invalid option: Security")
   177  	}
   178  }
   179  
   180  func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.ResponseHeader, error) {
   181  	aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey)
   182  	aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12]
   183  
   184  	aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block)
   185  	aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD)
   186  
   187  	var aeadEncryptedResponseHeaderLength [18]byte
   188  	var decryptedResponseHeaderLength int
   189  	var decryptedResponseHeaderLengthBinaryDeserializeBuffer uint16
   190  
   191  	if n, err := io.ReadFull(reader, aeadEncryptedResponseHeaderLength[:]); err != nil {
   192  		c.readDrainer.AcknowledgeReceive(n)
   193  		return nil, drain.WithError(c.readDrainer, reader, newError("Unable to Read Header Len").Base(err))
   194  	} else { // nolint: golint
   195  		c.readDrainer.AcknowledgeReceive(n)
   196  	}
   197  	if decryptedResponseHeaderLengthBinaryBuffer, err := aeadResponseHeaderLengthEncryptionAEAD.Open(nil, aeadResponseHeaderLengthEncryptionIV, aeadEncryptedResponseHeaderLength[:], nil); err != nil {
   198  		return nil, drain.WithError(c.readDrainer, reader, newError("Failed To Decrypt Length").Base(err))
   199  	} else { // nolint: golint
   200  		common.Must(binary.Read(bytes.NewReader(decryptedResponseHeaderLengthBinaryBuffer), binary.BigEndian, &decryptedResponseHeaderLengthBinaryDeserializeBuffer))
   201  		decryptedResponseHeaderLength = int(decryptedResponseHeaderLengthBinaryDeserializeBuffer)
   202  	}
   203  
   204  	aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey)
   205  	aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12]
   206  
   207  	aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block)
   208  	aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD)
   209  
   210  	encryptedResponseHeaderBuffer := make([]byte, decryptedResponseHeaderLength+16)
   211  
   212  	if n, err := io.ReadFull(reader, encryptedResponseHeaderBuffer); err != nil {
   213  		c.readDrainer.AcknowledgeReceive(n)
   214  		return nil, drain.WithError(c.readDrainer, reader, newError("Unable to Read Header Data").Base(err))
   215  	} else { // nolint: golint
   216  		c.readDrainer.AcknowledgeReceive(n)
   217  	}
   218  
   219  	if decryptedResponseHeaderBuffer, err := aeadResponseHeaderPayloadEncryptionAEAD.Open(nil, aeadResponseHeaderPayloadEncryptionIV, encryptedResponseHeaderBuffer, nil); err != nil {
   220  		return nil, drain.WithError(c.readDrainer, reader, newError("Failed To Decrypt Payload").Base(err))
   221  	} else { // nolint: golint
   222  		c.responseReader = bytes.NewReader(decryptedResponseHeaderBuffer)
   223  	}
   224  
   225  	buffer := buf.StackNew()
   226  	defer buffer.Release()
   227  
   228  	if _, err := buffer.ReadFullFrom(c.responseReader, 4); err != nil {
   229  		return nil, newError("failed to read response header").Base(err).AtWarning()
   230  	}
   231  
   232  	if buffer.Byte(0) != c.responseHeader {
   233  		return nil, newError("unexpected response header. Expecting ", int(c.responseHeader), " but actually ", int(buffer.Byte(0)))
   234  	}
   235  
   236  	header := &protocol.ResponseHeader{
   237  		Option: bitmask.Byte(buffer.Byte(1)),
   238  	}
   239  
   240  	if buffer.Byte(2) != 0 {
   241  		cmdID := buffer.Byte(2)
   242  		dataLen := int32(buffer.Byte(3))
   243  
   244  		buffer.Clear()
   245  		if _, err := buffer.ReadFullFrom(c.responseReader, dataLen); err != nil {
   246  			return nil, newError("failed to read response command").Base(err)
   247  		}
   248  		command, err := UnmarshalCommand(cmdID, buffer.Bytes())
   249  		if err == nil {
   250  			header.Command = command
   251  		}
   252  	}
   253  	aesStream := crypto.NewAesDecryptionStream(c.responseBodyKey[:], c.responseBodyIV[:])
   254  	c.responseReader = crypto.NewCryptionReader(aesStream, reader)
   255  	return header, nil
   256  }
   257  
   258  func (c *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, reader io.Reader) (buf.Reader, error) {
   259  	var sizeParser crypto.ChunkSizeDecoder = crypto.PlainChunkSizeParser{}
   260  	if request.Option.Has(protocol.RequestOptionChunkMasking) {
   261  		sizeParser = NewShakeSizeParser(c.responseBodyIV[:])
   262  	}
   263  	var padding crypto.PaddingLengthGenerator
   264  	if request.Option.Has(protocol.RequestOptionGlobalPadding) {
   265  		var ok bool
   266  		padding, ok = sizeParser.(crypto.PaddingLengthGenerator)
   267  		if !ok {
   268  			return nil, newError("invalid option: RequestOptionGlobalPadding")
   269  		}
   270  	}
   271  
   272  	switch request.Security {
   273  	case protocol.SecurityType_NONE:
   274  		if request.Option.Has(protocol.RequestOptionChunkStream) {
   275  			if request.Command.TransferType() == protocol.TransferTypeStream {
   276  				return crypto.NewChunkStreamReader(sizeParser, reader), nil
   277  			}
   278  
   279  			auth := &crypto.AEADAuthenticator{
   280  				AEAD:                    new(NoOpAuthenticator),
   281  				NonceGenerator:          crypto.GenerateEmptyBytes(),
   282  				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   283  			}
   284  
   285  			return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket, padding), nil
   286  		}
   287  
   288  		return buf.NewReader(reader), nil
   289  	case protocol.SecurityType_AES128_GCM:
   290  		aead := crypto.NewAesGcm(c.responseBodyKey[:])
   291  
   292  		auth := &crypto.AEADAuthenticator{
   293  			AEAD:                    aead,
   294  			NonceGenerator:          GenerateChunkNonce(c.responseBodyIV[:], uint32(aead.NonceSize())),
   295  			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   296  		}
   297  		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
   298  			AuthenticatedLengthKey := vmessaead.KDF16(c.requestBodyKey[:], "auth_len")
   299  			AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey)
   300  
   301  			lengthAuth := &crypto.AEADAuthenticator{
   302  				AEAD:                    AuthenticatedLengthKeyAEAD,
   303  				NonceGenerator:          GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())),
   304  				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   305  			}
   306  			sizeParser = NewAEADSizeParser(lengthAuth)
   307  		}
   308  		return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding), nil
   309  	case protocol.SecurityType_CHACHA20_POLY1305:
   310  		aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(c.responseBodyKey[:]))
   311  
   312  		auth := &crypto.AEADAuthenticator{
   313  			AEAD:                    aead,
   314  			NonceGenerator:          GenerateChunkNonce(c.responseBodyIV[:], uint32(aead.NonceSize())),
   315  			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   316  		}
   317  		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
   318  			AuthenticatedLengthKey := vmessaead.KDF16(c.requestBodyKey[:], "auth_len")
   319  			AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey))
   320  			common.Must(err)
   321  
   322  			lengthAuth := &crypto.AEADAuthenticator{
   323  				AEAD:                    AuthenticatedLengthKeyAEAD,
   324  				NonceGenerator:          GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())),
   325  				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
   326  			}
   327  			sizeParser = NewAEADSizeParser(lengthAuth)
   328  		}
   329  		return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding), nil
   330  	default:
   331  		return nil, newError("invalid option: Security")
   332  	}
   333  }
   334  
   335  func GenerateChunkNonce(nonce []byte, size uint32) crypto.BytesGenerator {
   336  	c := append([]byte(nil), nonce...)
   337  	count := uint16(0)
   338  	return func() []byte {
   339  		binary.BigEndian.PutUint16(c, count)
   340  		count++
   341  		return c[:size]
   342  	}
   343  }