github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/proxy/vmess/encoding/server.go (about) 1 package encoding 2 3 import ( 4 "bytes" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/sha256" 8 "encoding/binary" 9 "hash/fnv" 10 "io" 11 "sync" 12 "time" 13 14 "github.com/xmplusdev/xmcore/common" 15 "github.com/xmplusdev/xmcore/common/bitmask" 16 "github.com/xmplusdev/xmcore/common/buf" 17 "github.com/xmplusdev/xmcore/common/crypto" 18 "github.com/xmplusdev/xmcore/common/drain" 19 "github.com/xmplusdev/xmcore/common/net" 20 "github.com/xmplusdev/xmcore/common/protocol" 21 "github.com/xmplusdev/xmcore/common/task" 22 "github.com/xmplusdev/xmcore/proxy/vmess" 23 vmessaead "github.com/xmplusdev/xmcore/proxy/vmess/aead" 24 "golang.org/x/crypto/chacha20poly1305" 25 ) 26 27 type sessionID struct { 28 user [16]byte 29 key [16]byte 30 nonce [16]byte 31 } 32 33 // SessionHistory keeps track of historical session ids, to prevent replay attacks. 34 type SessionHistory struct { 35 sync.RWMutex 36 cache map[sessionID]time.Time 37 task *task.Periodic 38 } 39 40 // NewSessionHistory creates a new SessionHistory object. 41 func NewSessionHistory() *SessionHistory { 42 h := &SessionHistory{ 43 cache: make(map[sessionID]time.Time, 128), 44 } 45 h.task = &task.Periodic{ 46 Interval: time.Second * 30, 47 Execute: h.removeExpiredEntries, 48 } 49 return h 50 } 51 52 // Close implements common.Closable. 53 func (h *SessionHistory) Close() error { 54 return h.task.Close() 55 } 56 57 func (h *SessionHistory) addIfNotExits(session sessionID) bool { 58 h.Lock() 59 60 if expire, found := h.cache[session]; found && expire.After(time.Now()) { 61 h.Unlock() 62 return false 63 } 64 65 h.cache[session] = time.Now().Add(time.Minute * 3) 66 h.Unlock() 67 common.Must(h.task.Start()) 68 return true 69 } 70 71 func (h *SessionHistory) removeExpiredEntries() error { 72 now := time.Now() 73 74 h.Lock() 75 defer h.Unlock() 76 77 if len(h.cache) == 0 { 78 return newError("nothing to do") 79 } 80 81 for session, expire := range h.cache { 82 if expire.Before(now) { 83 delete(h.cache, session) 84 } 85 } 86 87 if len(h.cache) == 0 { 88 h.cache = make(map[sessionID]time.Time, 128) 89 } 90 91 return nil 92 } 93 94 // ServerSession keeps information for a session in VMess server. 95 type ServerSession struct { 96 userValidator *vmess.TimedUserValidator 97 sessionHistory *SessionHistory 98 requestBodyKey [16]byte 99 requestBodyIV [16]byte 100 responseBodyKey [16]byte 101 responseBodyIV [16]byte 102 responseWriter io.Writer 103 responseHeader byte 104 } 105 106 // NewServerSession creates a new ServerSession, using the given UserValidator. 107 // The ServerSession instance doesn't take ownership of the validator. 108 func NewServerSession(validator *vmess.TimedUserValidator, sessionHistory *SessionHistory) *ServerSession { 109 return &ServerSession{ 110 userValidator: validator, 111 sessionHistory: sessionHistory, 112 } 113 } 114 115 func parseSecurityType(b byte) protocol.SecurityType { 116 if _, f := protocol.SecurityType_name[int32(b)]; f { 117 st := protocol.SecurityType(b) 118 // For backward compatibility. 119 if st == protocol.SecurityType_UNKNOWN { 120 st = protocol.SecurityType_AUTO 121 } 122 return st 123 } 124 return protocol.SecurityType_UNKNOWN 125 } 126 127 // DecodeRequestHeader decodes and returns (if successful) a RequestHeader from an input stream. 128 func (s *ServerSession) DecodeRequestHeader(reader io.Reader, isDrain bool) (*protocol.RequestHeader, error) { 129 buffer := buf.New() 130 131 drainer, err := drain.NewBehaviorSeedLimitedDrainer(int64(s.userValidator.GetBehaviorSeed()), 16+38, 3266, 64) 132 if err != nil { 133 return nil, newError("failed to initialize drainer").Base(err) 134 } 135 136 drainConnection := func(e error) error { 137 // We read a deterministic generated length of data before closing the connection to offset padding read pattern 138 drainer.AcknowledgeReceive(int(buffer.Len())) 139 if isDrain { 140 return drain.WithError(drainer, reader, e) 141 } 142 return e 143 } 144 145 defer func() { 146 buffer.Release() 147 }() 148 149 if _, err := buffer.ReadFullFrom(reader, protocol.IDBytesLen); err != nil { 150 return nil, newError("failed to read request header").Base(err) 151 } 152 153 var decryptor io.Reader 154 var vmessAccount *vmess.MemoryAccount 155 156 user, foundAEAD, errorAEAD := s.userValidator.GetAEAD(buffer.Bytes()) 157 158 var fixedSizeAuthID [16]byte 159 copy(fixedSizeAuthID[:], buffer.Bytes()) 160 161 switch { 162 case foundAEAD: 163 vmessAccount = user.Account.(*vmess.MemoryAccount) 164 var fixedSizeCmdKey [16]byte 165 copy(fixedSizeCmdKey[:], vmessAccount.ID.CmdKey()) 166 aeadData, shouldDrain, bytesRead, errorReason := vmessaead.OpenVMessAEADHeader(fixedSizeCmdKey, fixedSizeAuthID, reader) 167 if errorReason != nil { 168 if shouldDrain { 169 drainer.AcknowledgeReceive(bytesRead) 170 return nil, drainConnection(newError("AEAD read failed").Base(errorReason)) 171 } else { 172 return nil, drainConnection(newError("AEAD read failed, drain skipped").Base(errorReason)) 173 } 174 } 175 decryptor = bytes.NewReader(aeadData) 176 default: 177 return nil, drainConnection(newError("invalid user").Base(errorAEAD)) 178 } 179 180 drainer.AcknowledgeReceive(int(buffer.Len())) 181 buffer.Clear() 182 if _, err := buffer.ReadFullFrom(decryptor, 38); err != nil { 183 return nil, newError("failed to read request header").Base(err) 184 } 185 186 request := &protocol.RequestHeader{ 187 User: user, 188 Version: buffer.Byte(0), 189 } 190 191 copy(s.requestBodyIV[:], buffer.BytesRange(1, 17)) // 16 bytes 192 copy(s.requestBodyKey[:], buffer.BytesRange(17, 33)) // 16 bytes 193 var sid sessionID 194 copy(sid.user[:], vmessAccount.ID.Bytes()) 195 sid.key = s.requestBodyKey 196 sid.nonce = s.requestBodyIV 197 if !s.sessionHistory.addIfNotExits(sid) { 198 return nil, newError("duplicated session id, possibly under replay attack, but this is a AEAD request") 199 } 200 201 s.responseHeader = buffer.Byte(33) // 1 byte 202 request.Option = bitmask.Byte(buffer.Byte(34)) // 1 byte 203 paddingLen := int(buffer.Byte(35) >> 4) 204 request.Security = parseSecurityType(buffer.Byte(35) & 0x0F) 205 // 1 bytes reserved 206 request.Command = protocol.RequestCommand(buffer.Byte(37)) 207 208 switch request.Command { 209 case protocol.RequestCommandMux: 210 request.Address = net.DomainAddress("v1.mux.cool") 211 request.Port = 0 212 213 case protocol.RequestCommandTCP, protocol.RequestCommandUDP: 214 if addr, port, err := addrParser.ReadAddressPort(buffer, decryptor); err == nil { 215 request.Address = addr 216 request.Port = port 217 } 218 } 219 220 if paddingLen > 0 { 221 if _, err := buffer.ReadFullFrom(decryptor, int32(paddingLen)); err != nil { 222 return nil, newError("failed to read padding").Base(err) 223 } 224 } 225 226 if _, err := buffer.ReadFullFrom(decryptor, 4); err != nil { 227 return nil, newError("failed to read checksum").Base(err) 228 } 229 230 fnv1a := fnv.New32a() 231 common.Must2(fnv1a.Write(buffer.BytesTo(-4))) 232 actualHash := fnv1a.Sum32() 233 expectedHash := binary.BigEndian.Uint32(buffer.BytesFrom(-4)) 234 235 if actualHash != expectedHash { 236 return nil, newError("invalid auth, but this is a AEAD request") 237 } 238 239 if request.Address == nil { 240 return nil, newError("invalid remote address") 241 } 242 243 if request.Security == protocol.SecurityType_UNKNOWN || request.Security == protocol.SecurityType_AUTO { 244 return nil, newError("unknown security type: ", request.Security) 245 } 246 247 return request, nil 248 } 249 250 // DecodeRequestBody returns Reader from which caller can fetch decrypted body. 251 func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reader io.Reader) (buf.Reader, error) { 252 var sizeParser crypto.ChunkSizeDecoder = crypto.PlainChunkSizeParser{} 253 if request.Option.Has(protocol.RequestOptionChunkMasking) { 254 sizeParser = NewShakeSizeParser(s.requestBodyIV[:]) 255 } 256 var padding crypto.PaddingLengthGenerator 257 if request.Option.Has(protocol.RequestOptionGlobalPadding) { 258 var ok bool 259 padding, ok = sizeParser.(crypto.PaddingLengthGenerator) 260 if !ok { 261 return nil, newError("invalid option: RequestOptionGlobalPadding") 262 } 263 } 264 265 switch request.Security { 266 case protocol.SecurityType_NONE: 267 if request.Option.Has(protocol.RequestOptionChunkStream) { 268 if request.Command.TransferType() == protocol.TransferTypeStream { 269 return crypto.NewChunkStreamReader(sizeParser, reader), nil 270 } 271 272 auth := &crypto.AEADAuthenticator{ 273 AEAD: new(NoOpAuthenticator), 274 NonceGenerator: crypto.GenerateEmptyBytes(), 275 AdditionalDataGenerator: crypto.GenerateEmptyBytes(), 276 } 277 return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket, padding), nil 278 } 279 return buf.NewReader(reader), nil 280 281 case protocol.SecurityType_AES128_GCM: 282 aead := crypto.NewAesGcm(s.requestBodyKey[:]) 283 auth := &crypto.AEADAuthenticator{ 284 AEAD: aead, 285 NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), 286 AdditionalDataGenerator: crypto.GenerateEmptyBytes(), 287 } 288 if request.Option.Has(protocol.RequestOptionAuthenticatedLength) { 289 AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len") 290 AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey) 291 292 lengthAuth := &crypto.AEADAuthenticator{ 293 AEAD: AuthenticatedLengthKeyAEAD, 294 NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), 295 AdditionalDataGenerator: crypto.GenerateEmptyBytes(), 296 } 297 sizeParser = NewAEADSizeParser(lengthAuth) 298 } 299 return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding), nil 300 301 case protocol.SecurityType_CHACHA20_POLY1305: 302 aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(s.requestBodyKey[:])) 303 304 auth := &crypto.AEADAuthenticator{ 305 AEAD: aead, 306 NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), 307 AdditionalDataGenerator: crypto.GenerateEmptyBytes(), 308 } 309 if request.Option.Has(protocol.RequestOptionAuthenticatedLength) { 310 AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len") 311 AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey)) 312 common.Must(err) 313 314 lengthAuth := &crypto.AEADAuthenticator{ 315 AEAD: AuthenticatedLengthKeyAEAD, 316 NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), 317 AdditionalDataGenerator: crypto.GenerateEmptyBytes(), 318 } 319 sizeParser = NewAEADSizeParser(lengthAuth) 320 } 321 return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding), nil 322 323 default: 324 return nil, newError("invalid option: Security") 325 } 326 } 327 328 // EncodeResponseHeader writes encoded response header into the given writer. 329 func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, writer io.Writer) { 330 var encryptionWriter io.Writer 331 BodyKey := sha256.Sum256(s.requestBodyKey[:]) 332 copy(s.responseBodyKey[:], BodyKey[:16]) 333 BodyIV := sha256.Sum256(s.requestBodyIV[:]) 334 copy(s.responseBodyIV[:], BodyIV[:16]) 335 336 aesStream := crypto.NewAesEncryptionStream(s.responseBodyKey[:], s.responseBodyIV[:]) 337 encryptionWriter = crypto.NewCryptionWriter(aesStream, writer) 338 s.responseWriter = encryptionWriter 339 340 aeadEncryptedHeaderBuffer := bytes.NewBuffer(nil) 341 encryptionWriter = aeadEncryptedHeaderBuffer 342 343 common.Must2(encryptionWriter.Write([]byte{s.responseHeader, byte(header.Option)})) 344 err := MarshalCommand(header.Command, encryptionWriter) 345 if err != nil { 346 common.Must2(encryptionWriter.Write([]byte{0x00, 0x00})) 347 } 348 349 aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey) 350 aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12] 351 352 aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block) 353 aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD) 354 355 aeadResponseHeaderLengthEncryptionBuffer := bytes.NewBuffer(nil) 356 357 decryptedResponseHeaderLengthBinaryDeserializeBuffer := uint16(aeadEncryptedHeaderBuffer.Len()) 358 359 common.Must(binary.Write(aeadResponseHeaderLengthEncryptionBuffer, binary.BigEndian, decryptedResponseHeaderLengthBinaryDeserializeBuffer)) 360 361 AEADEncryptedLength := aeadResponseHeaderLengthEncryptionAEAD.Seal(nil, aeadResponseHeaderLengthEncryptionIV, aeadResponseHeaderLengthEncryptionBuffer.Bytes(), nil) 362 common.Must2(io.Copy(writer, bytes.NewReader(AEADEncryptedLength))) 363 364 aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey) 365 aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12] 366 367 aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block) 368 aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD) 369 370 aeadEncryptedHeaderPayload := aeadResponseHeaderPayloadEncryptionAEAD.Seal(nil, aeadResponseHeaderPayloadEncryptionIV, aeadEncryptedHeaderBuffer.Bytes(), nil) 371 common.Must2(io.Copy(writer, bytes.NewReader(aeadEncryptedHeaderPayload))) 372 } 373 374 // EncodeResponseBody returns a Writer that auto-encrypt content written by caller. 375 func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) { 376 var sizeParser crypto.ChunkSizeEncoder = crypto.PlainChunkSizeParser{} 377 if request.Option.Has(protocol.RequestOptionChunkMasking) { 378 sizeParser = NewShakeSizeParser(s.responseBodyIV[:]) 379 } 380 var padding crypto.PaddingLengthGenerator 381 if request.Option.Has(protocol.RequestOptionGlobalPadding) { 382 var ok bool 383 padding, ok = sizeParser.(crypto.PaddingLengthGenerator) 384 if !ok { 385 return nil, newError("invalid option: RequestOptionGlobalPadding") 386 } 387 } 388 389 switch request.Security { 390 case protocol.SecurityType_NONE: 391 if request.Option.Has(protocol.RequestOptionChunkStream) { 392 if request.Command.TransferType() == protocol.TransferTypeStream { 393 return crypto.NewChunkStreamWriter(sizeParser, writer), nil 394 } 395 396 auth := &crypto.AEADAuthenticator{ 397 AEAD: new(NoOpAuthenticator), 398 NonceGenerator: crypto.GenerateEmptyBytes(), 399 AdditionalDataGenerator: crypto.GenerateEmptyBytes(), 400 } 401 return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket, padding), nil 402 } 403 return buf.NewWriter(writer), nil 404 405 case protocol.SecurityType_AES128_GCM: 406 aead := crypto.NewAesGcm(s.responseBodyKey[:]) 407 auth := &crypto.AEADAuthenticator{ 408 AEAD: aead, 409 NonceGenerator: GenerateChunkNonce(s.responseBodyIV[:], uint32(aead.NonceSize())), 410 AdditionalDataGenerator: crypto.GenerateEmptyBytes(), 411 } 412 if request.Option.Has(protocol.RequestOptionAuthenticatedLength) { 413 AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len") 414 AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey) 415 416 lengthAuth := &crypto.AEADAuthenticator{ 417 AEAD: AuthenticatedLengthKeyAEAD, 418 NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), 419 AdditionalDataGenerator: crypto.GenerateEmptyBytes(), 420 } 421 sizeParser = NewAEADSizeParser(lengthAuth) 422 } 423 return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding), nil 424 425 case protocol.SecurityType_CHACHA20_POLY1305: 426 aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(s.responseBodyKey[:])) 427 428 auth := &crypto.AEADAuthenticator{ 429 AEAD: aead, 430 NonceGenerator: GenerateChunkNonce(s.responseBodyIV[:], uint32(aead.NonceSize())), 431 AdditionalDataGenerator: crypto.GenerateEmptyBytes(), 432 } 433 if request.Option.Has(protocol.RequestOptionAuthenticatedLength) { 434 AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len") 435 AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey)) 436 common.Must(err) 437 438 lengthAuth := &crypto.AEADAuthenticator{ 439 AEAD: AuthenticatedLengthKeyAEAD, 440 NonceGenerator: GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())), 441 AdditionalDataGenerator: crypto.GenerateEmptyBytes(), 442 } 443 sizeParser = NewAEADSizeParser(lengthAuth) 444 } 445 return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding), nil 446 447 default: 448 return nil, newError("invalid option: Security") 449 } 450 }