github.com/MerlinKodo/sing-shadowsocks@v0.2.6/shadowaead_2022/protocol.go (about) 1 package shadowaead_2022 2 3 import ( 4 "bytes" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/rand" 8 "crypto/sha256" 9 "encoding/base64" 10 "encoding/binary" 11 "io" 12 "math" 13 mRand "math/rand" 14 "net" 15 "os" 16 "strings" 17 "sync/atomic" 18 "time" 19 20 shadowsocks "github.com/MerlinKodo/sing-shadowsocks" 21 "github.com/MerlinKodo/sing-shadowsocks/shadowaead" 22 "github.com/sagernet/sing/common" 23 "github.com/sagernet/sing/common/buf" 24 "github.com/sagernet/sing/common/bufio" 25 E "github.com/sagernet/sing/common/exceptions" 26 M "github.com/sagernet/sing/common/metadata" 27 N "github.com/sagernet/sing/common/network" 28 "github.com/sagernet/sing/common/random" 29 "github.com/sagernet/sing/common/rw" 30 31 "golang.org/x/crypto/chacha20poly1305" 32 "lukechampine.com/blake3" 33 ) 34 35 const ( 36 HeaderTypeClient = 0 37 HeaderTypeServer = 1 38 MaxPaddingLength = 900 39 PacketNonceSize = 24 40 MaxPacketSize = 65535 41 RequestHeaderFixedChunkLength = 1 + 8 + 2 42 PacketMinimalHeaderSize = 30 43 ) 44 45 var ( 46 ErrMissingPSK = E.New("missing psk") 47 ErrBadHeaderType = E.New("bad header type") 48 ErrBadTimestamp = E.New("bad timestamp") 49 ErrBadRequestSalt = E.New("bad request salt") 50 ErrSaltNotUnique = E.New("salt not unique") 51 ErrBadClientSessionId = E.New("bad client session id") 52 ErrPacketIdNotUnique = E.New("packet id not unique") 53 ErrTooManyServerSessions = E.New("server session changed more than once during the last minute") 54 ErrPacketTooShort = E.New("packet too short") 55 ) 56 57 var List = []string{ 58 "2022-blake3-aes-128-gcm", 59 "2022-blake3-aes-256-gcm", 60 "2022-blake3-chacha20-poly1305", 61 } 62 63 func init() { 64 random.InitializeSeed() 65 } 66 67 func NewWithPassword(method string, password string, timeFunc func() time.Time) (shadowsocks.Method, error) { 68 var pskList [][]byte 69 if password == "" { 70 return nil, ErrMissingPSK 71 } 72 keyStrList := strings.Split(password, ":") 73 pskList = make([][]byte, len(keyStrList)) 74 for i, keyStr := range keyStrList { 75 kb, err := base64.StdEncoding.DecodeString(keyStr) 76 if err != nil { 77 return nil, E.Cause(err, "decode key") 78 } 79 pskList[i] = kb 80 } 81 return New(method, pskList, timeFunc) 82 } 83 84 func New(method string, pskList [][]byte, timeFunc func() time.Time) (shadowsocks.Method, error) { 85 m := &Method{ 86 name: method, 87 timeFunc: timeFunc, 88 } 89 90 switch method { 91 case "2022-blake3-aes-128-gcm": 92 m.keySaltLength = 16 93 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 94 m.blockConstructor = aes.NewCipher 95 case "2022-blake3-aes-256-gcm": 96 m.keySaltLength = 32 97 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 98 m.blockConstructor = aes.NewCipher 99 case "2022-blake3-chacha20-poly1305": 100 if len(pskList) > 1 { 101 return nil, os.ErrInvalid 102 } 103 m.keySaltLength = 32 104 m.constructor = chacha20poly1305.New 105 } 106 107 if len(pskList) == 0 { 108 return nil, ErrMissingPSK 109 } 110 111 for i, psk := range pskList { 112 if len(psk) < m.keySaltLength { 113 return nil, shadowsocks.ErrBadKey 114 } else if len(psk) > m.keySaltLength { 115 pskList[i] = Key(psk, m.keySaltLength) 116 } 117 } 118 119 if len(pskList) > 1 { 120 pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize) 121 for i, psk := range pskList { 122 if i == 0 { 123 continue 124 } 125 hash := blake3.Sum512(psk) 126 copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], hash[:aes.BlockSize]) 127 } 128 m.pskHash = pskHash 129 } 130 131 var err error 132 switch method { 133 case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm": 134 m.udpBlockEncryptCipher, err = aes.NewCipher(pskList[0]) 135 if err != nil { 136 return nil, err 137 } 138 m.udpBlockDecryptCipher, err = aes.NewCipher(pskList[len(pskList)-1]) 139 if err != nil { 140 return nil, err 141 } 142 case "2022-blake3-chacha20-poly1305": 143 m.udpCipher, err = chacha20poly1305.NewX(pskList[0]) 144 if err != nil { 145 return nil, err 146 } 147 } 148 149 m.pskList = pskList 150 return m, nil 151 } 152 153 func Key(key []byte, keyLength int) []byte { 154 psk := sha256.Sum256(key) 155 return psk[:keyLength] 156 } 157 158 func SessionKey(psk []byte, salt []byte, keyLength int) []byte { 159 sessionKey := make([]byte, len(psk)+len(salt)) 160 copy(sessionKey, psk) 161 copy(sessionKey[len(psk):], salt) 162 outKey := make([]byte, keyLength) 163 blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey) 164 return outKey 165 } 166 167 func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) { 168 return func(key []byte) (cipher.AEAD, error) { 169 b, err := block(key) 170 if err != nil { 171 return nil, err 172 } 173 return aead(b) 174 } 175 } 176 177 type Method struct { 178 name string 179 keySaltLength int 180 timeFunc func() time.Time 181 182 constructor func(key []byte) (cipher.AEAD, error) 183 blockConstructor func(key []byte) (cipher.Block, error) 184 udpCipher cipher.AEAD 185 udpBlockEncryptCipher cipher.Block 186 udpBlockDecryptCipher cipher.Block 187 pskList [][]byte 188 pskHash []byte 189 } 190 191 func (m *Method) Name() string { 192 return m.name 193 } 194 195 func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { 196 shadowsocksConn := &clientConn{ 197 Method: m, 198 Conn: conn, 199 destination: destination, 200 } 201 return shadowsocksConn, shadowsocksConn.writeRequest(nil) 202 } 203 204 func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { 205 return &clientConn{ 206 Method: m, 207 Conn: conn, 208 destination: destination, 209 } 210 } 211 212 func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { 213 return &clientPacketConn{m, conn, m.newUDPSession()} 214 } 215 216 type clientConn struct { 217 *Method 218 net.Conn 219 destination M.Socksaddr 220 requestSalt []byte 221 reader *shadowaead.Reader 222 writer *shadowaead.Writer 223 } 224 225 func (m *Method) time() time.Time { 226 if m.timeFunc != nil { 227 return m.timeFunc() 228 } else { 229 return time.Now() 230 } 231 } 232 233 func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error { 234 pskLen := len(m.pskList) 235 if pskLen < 2 { 236 return nil 237 } 238 for i, psk := range m.pskList { 239 keyMaterial := make([]byte, m.keySaltLength*2) 240 copy(keyMaterial, psk) 241 copy(keyMaterial[m.keySaltLength:], salt) 242 identitySubkey := buf.NewSize(m.keySaltLength) 243 identitySubkey.Extend(identitySubkey.FreeLen()) 244 blake3.DeriveKey(identitySubkey.Bytes(), "shadowsocks 2022 identity subkey", keyMaterial) 245 246 pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] 247 248 header := request.Extend(16) 249 b, err := m.blockConstructor(identitySubkey.Bytes()) 250 if err != nil { 251 return err 252 } 253 b.Encrypt(header, pskHash) 254 identitySubkey.Release() 255 if i == pskLen-2 { 256 break 257 } 258 } 259 return nil 260 } 261 262 func (c *clientConn) writeRequest(payload []byte) error { 263 salt := make([]byte, c.keySaltLength) 264 common.Must1(io.ReadFull(rand.Reader, salt)) 265 266 key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength) 267 writeCipher, err := c.constructor(key) 268 if err != nil { 269 return err 270 } 271 writer := shadowaead.NewWriter( 272 c.Conn, 273 writeCipher, 274 MaxPacketSize, 275 ) 276 277 header := writer.Buffer() 278 header.Write(salt) 279 280 err = c.writeExtendedIdentityHeaders(header, salt) 281 if err != nil { 282 return err 283 } 284 285 var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte 286 fixedLengthBuffer := buf.With(_fixedLengthBuffer[:]) 287 common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient)) 288 common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(c.time().Unix()))) 289 var paddingLen int 290 if len(payload) < MaxPaddingLength { 291 paddingLen = mRand.Intn(MaxPaddingLength) + 1 292 } 293 variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen 294 payloadLen := len(payload) 295 variableLengthHeaderLen += payloadLen 296 common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen))) 297 writer.WriteChunk(header, fixedLengthBuffer.Slice()) 298 299 variableLengthBuffer := buf.NewSize(variableLengthHeaderLen) 300 err = M.SocksaddrSerializer.WriteAddrPort(variableLengthBuffer, c.destination) 301 if err != nil { 302 return err 303 } 304 common.Must(binary.Write(variableLengthBuffer, binary.BigEndian, uint16(paddingLen))) 305 if paddingLen > 0 { 306 variableLengthBuffer.Extend(paddingLen) 307 } 308 if payloadLen > 0 { 309 common.Must1(variableLengthBuffer.Write(payload[:payloadLen])) 310 } 311 writer.WriteChunk(header, variableLengthBuffer.Slice()) 312 variableLengthBuffer.Release() 313 314 err = writer.BufferedWriter(header.Len()).Flush() 315 if err != nil { 316 return E.Cause(err, "client handshake") 317 } 318 319 c.requestSalt = salt 320 c.writer = writer 321 return nil 322 } 323 324 func (c *clientConn) readResponse() error { 325 if c.reader != nil { 326 return nil 327 } 328 329 salt := buf.NewSize(c.keySaltLength) 330 331 _, err := salt.ReadFullFrom(c.Conn, salt.FreeLen()) 332 if err != nil { 333 salt.Release() 334 return err 335 } 336 337 key := SessionKey(c.pskList[len(c.pskList)-1], salt.Bytes(), c.keySaltLength) 338 salt.Release() 339 340 readCipher, err := c.constructor(key) 341 if err != nil { 342 return err 343 } 344 reader := shadowaead.NewReader( 345 c.Conn, 346 readCipher, 347 MaxPacketSize, 348 ) 349 350 err = reader.ReadWithLength(uint16(1 + 8 + c.keySaltLength + 2)) 351 if err != nil { 352 return E.Cause(err, "read response fixed length chunk") 353 } 354 355 headerType, err := rw.ReadByte(reader) 356 if err != nil { 357 return err 358 } 359 if headerType != HeaderTypeServer /* && headerType != HeaderTypeServerEncrypted*/ { 360 return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType) 361 } 362 363 var epoch uint64 364 err = binary.Read(reader, binary.BigEndian, &epoch) 365 if err != nil { 366 return err 367 } 368 369 diff := int(math.Abs(float64(c.time().Unix() - int64(epoch)))) 370 if diff > 30 { 371 return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 372 } 373 374 requestSalt := buf.NewSize(c.keySaltLength) 375 _, err = requestSalt.ReadFullFrom(reader, requestSalt.FreeLen()) 376 if err != nil { 377 return err 378 } 379 380 if bytes.Compare(requestSalt.Bytes(), c.requestSalt) > 0 { 381 return ErrBadRequestSalt 382 } 383 requestSalt.Release() 384 c.requestSalt = nil 385 386 var length uint16 387 err = binary.Read(reader, binary.BigEndian, &length) 388 if err != nil { 389 return err 390 } 391 392 err = reader.ReadWithLength(length) 393 if err != nil { 394 return err 395 } 396 if headerType == HeaderTypeServer { 397 c.reader = reader 398 } 399 return nil 400 } 401 402 func (c *clientConn) Read(p []byte) (n int, err error) { 403 if err = c.readResponse(); err != nil { 404 return 405 } 406 return c.reader.Read(p) 407 } 408 409 func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) { 410 if err = c.readResponse(); err != nil { 411 return 412 } 413 return bufio.Copy(w, c.reader) 414 } 415 416 func (c *clientConn) Write(p []byte) (n int, err error) { 417 if c.writer == nil { 418 err = c.writeRequest(p) 419 if err == nil { 420 n = len(p) 421 } 422 return 423 } 424 return c.writer.Write(p) 425 } 426 427 var _ N.VectorisedWriter = (*clientConn)(nil) 428 429 func (c *clientConn) WriteVectorised(buffers []*buf.Buffer) error { 430 if c.writer != nil { 431 return c.writer.WriteVectorised(buffers) 432 } 433 err := c.writeRequest(buffers[0].Bytes()) 434 if err != nil { 435 buf.ReleaseMulti(buffers) 436 return err 437 } 438 buffers[0].Release() 439 return c.writer.WriteVectorised(buffers[1:]) 440 } 441 442 func (c *clientConn) NeedHandshake() bool { 443 return c.writer == nil 444 } 445 446 func (c *clientConn) NeedAdditionalReadDeadline() bool { 447 return true 448 } 449 450 func (c *clientConn) Upstream() any { 451 return c.Conn 452 } 453 454 func (c *clientConn) Close() error { 455 return common.Close( 456 c.Conn, 457 common.PtrOrNil(c.reader), 458 common.PtrOrNil(c.writer), 459 ) 460 } 461 462 type clientPacketConn struct { 463 *Method 464 net.Conn 465 session *udpSession 466 } 467 468 func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 469 defer buffer.Release() 470 var hdrLen int 471 if c.udpCipher != nil { 472 hdrLen = PacketNonceSize 473 } 474 475 var paddingLen int 476 if destination.Port == 53 && buffer.Len() < MaxPaddingLength { 477 paddingLen = mRand.Intn(MaxPaddingLength-buffer.Len()) + 1 478 } 479 480 hdrLen += 16 // packet header 481 pskLen := len(c.pskList) 482 if c.udpCipher == nil && pskLen > 1 { 483 hdrLen += (pskLen - 1) * aes.BlockSize 484 } 485 hdrLen += 1 // header type 486 hdrLen += 8 // timestamp 487 hdrLen += 2 // padding length 488 hdrLen += paddingLen 489 hdrLen += M.SocksaddrSerializer.AddrPortLen(destination) 490 header := buf.With(buffer.ExtendHeader(hdrLen)) 491 492 var dataIndex int 493 if c.udpCipher != nil { 494 common.Must1(header.ReadFullFrom(c.session.rng, PacketNonceSize)) 495 if pskLen > 1 { 496 panic("unsupported chacha extended header") 497 } 498 dataIndex = PacketNonceSize 499 } else { 500 dataIndex = aes.BlockSize 501 } 502 503 common.Must( 504 binary.Write(header, binary.BigEndian, c.session.sessionId), 505 binary.Write(header, binary.BigEndian, c.session.nextPacketId()), 506 ) 507 508 if c.udpCipher == nil && pskLen > 1 { 509 for i, psk := range c.pskList { 510 dataIndex += aes.BlockSize 511 pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] 512 513 identityHeader := header.Extend(aes.BlockSize) 514 xorWords(identityHeader, pskHash, header.To(aes.BlockSize)) 515 b, err := c.blockConstructor(psk) 516 if err != nil { 517 return err 518 } 519 b.Encrypt(identityHeader, identityHeader) 520 521 if i == pskLen-2 { 522 break 523 } 524 } 525 } 526 common.Must( 527 header.WriteByte(HeaderTypeClient), 528 binary.Write(header, binary.BigEndian, uint64(c.time().Unix())), 529 binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length 530 ) 531 532 if paddingLen > 0 { 533 header.Extend(paddingLen) 534 } 535 536 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 537 if err != nil { 538 return err 539 } 540 if c.udpCipher != nil { 541 c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) 542 buffer.Extend(shadowaead.Overhead) 543 } else { 544 packetHeader := buffer.To(aes.BlockSize) 545 c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) 546 buffer.Extend(shadowaead.Overhead) 547 c.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader) 548 } 549 return common.Error(c.Write(buffer.Bytes())) 550 } 551 552 func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { 553 n, err := c.Read(buffer.FreeBytes()) 554 if err != nil { 555 return M.Socksaddr{}, err 556 } 557 buffer.Truncate(n) 558 559 var packetHeader []byte 560 if c.udpCipher != nil { 561 if buffer.Len() < PacketNonceSize+PacketMinimalHeaderSize { 562 return M.Socksaddr{}, ErrPacketTooShort 563 } 564 _, err = c.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) 565 if err != nil { 566 return M.Socksaddr{}, E.Cause(err, "decrypt packet") 567 } 568 buffer.Advance(PacketNonceSize) 569 buffer.Truncate(buffer.Len() - shadowaead.Overhead) 570 } else { 571 if buffer.Len() < PacketMinimalHeaderSize { 572 return M.Socksaddr{}, ErrPacketTooShort 573 } 574 packetHeader = buffer.To(aes.BlockSize) 575 c.udpBlockDecryptCipher.Decrypt(packetHeader, packetHeader) 576 } 577 578 var sessionId, packetId uint64 579 err = binary.Read(buffer, binary.BigEndian, &sessionId) 580 if err != nil { 581 return M.Socksaddr{}, err 582 } 583 err = binary.Read(buffer, binary.BigEndian, &packetId) 584 if err != nil { 585 return M.Socksaddr{}, err 586 } 587 588 if sessionId == c.session.remoteSessionId { 589 if !c.session.window.Check(packetId) { 590 return M.Socksaddr{}, ErrPacketIdNotUnique 591 } 592 } else if sessionId == c.session.lastRemoteSessionId { 593 if !c.session.lastWindow.Check(packetId) { 594 return M.Socksaddr{}, ErrPacketIdNotUnique 595 } 596 } 597 598 var remoteCipher cipher.AEAD 599 if packetHeader != nil { 600 if sessionId == c.session.remoteSessionId { 601 remoteCipher = c.session.remoteCipher 602 } else if sessionId == c.session.lastRemoteSessionId { 603 remoteCipher = c.session.lastRemoteCipher 604 } else { 605 key := SessionKey(c.pskList[len(c.pskList)-1], packetHeader[:8], c.keySaltLength) 606 remoteCipher, err = c.constructor(key) 607 if err != nil { 608 return M.Socksaddr{}, err 609 } 610 } 611 _, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) 612 if err != nil { 613 return M.Socksaddr{}, E.Cause(err, "decrypt packet") 614 } 615 buffer.Truncate(buffer.Len() - shadowaead.Overhead) 616 } 617 618 var headerType byte 619 headerType, err = buffer.ReadByte() 620 if err != nil { 621 return M.Socksaddr{}, err 622 } 623 if headerType != HeaderTypeServer { 624 return M.Socksaddr{}, E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType) 625 } 626 627 var epoch uint64 628 err = binary.Read(buffer, binary.BigEndian, &epoch) 629 if err != nil { 630 return M.Socksaddr{}, err 631 } 632 633 diff := int(math.Abs(float64(c.time().Unix() - int64(epoch)))) 634 if diff > 30 { 635 return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 636 } 637 638 if sessionId == c.session.remoteSessionId { 639 c.session.window.Add(packetId) 640 } else if sessionId == c.session.lastRemoteSessionId { 641 c.session.lastWindow.Add(packetId) 642 c.session.lastRemoteSeen = c.time().Unix() 643 } else { 644 if c.session.remoteSessionId != 0 { 645 if c.time().Unix()-c.session.lastRemoteSeen < 60 { 646 return M.Socksaddr{}, ErrTooManyServerSessions 647 } else { 648 c.session.lastRemoteSessionId = c.session.remoteSessionId 649 c.session.lastWindow = c.session.window 650 c.session.lastRemoteSeen = c.time().Unix() 651 c.session.lastRemoteCipher = c.session.remoteCipher 652 c.session.window = SlidingWindow{} 653 } 654 } 655 c.session.remoteSessionId = sessionId 656 c.session.remoteCipher = remoteCipher 657 c.session.window.Add(packetId) 658 } 659 660 var clientSessionId uint64 661 err = binary.Read(buffer, binary.BigEndian, &clientSessionId) 662 if err != nil { 663 return M.Socksaddr{}, err 664 } 665 666 if clientSessionId != c.session.sessionId { 667 return M.Socksaddr{}, ErrBadClientSessionId 668 } 669 670 var paddingLen uint16 671 err = binary.Read(buffer, binary.BigEndian, &paddingLen) 672 if err != nil { 673 return M.Socksaddr{}, E.Cause(err, "read padding length") 674 } 675 buffer.Advance(int(paddingLen)) 676 677 destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) 678 if err != nil { 679 return M.Socksaddr{}, err 680 } 681 return destination, nil 682 } 683 684 func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 685 buffer := buf.With(p) 686 destination, err := c.ReadPacket(buffer) 687 if err != nil { 688 return 689 } 690 if destination.IsFqdn() { 691 addr = destination 692 } else { 693 addr = destination.UDPAddr() 694 } 695 n = copy(p, buffer.Bytes()) 696 return 697 } 698 699 func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 700 destination := M.SocksaddrFromNet(addr) 701 var overHead int 702 if c.udpCipher != nil { 703 overHead = PacketNonceSize + shadowaead.Overhead 704 } else { 705 overHead = shadowaead.Overhead 706 } 707 overHead += 16 // packet header 708 pskLen := len(c.pskList) 709 if c.udpCipher == nil && pskLen > 1 { 710 overHead += (pskLen - 1) * aes.BlockSize 711 } 712 var paddingLen int 713 if destination.Port == 53 && len(p) < MaxPaddingLength { 714 paddingLen = mRand.Intn(MaxPaddingLength-len(p)) + 1 715 } 716 overHead += 1 // header type 717 overHead += 8 // timestamp 718 overHead += 2 // padding length 719 overHead += paddingLen 720 overHead += M.SocksaddrSerializer.AddrPortLen(destination) 721 722 buffer := buf.NewSize(overHead + len(p)) 723 defer buffer.Release() 724 725 var dataIndex int 726 if c.udpCipher != nil { 727 common.Must1(buffer.ReadFullFrom(c.session.rng, PacketNonceSize)) 728 if pskLen > 1 { 729 panic("unsupported chacha extended header") 730 } 731 dataIndex = PacketNonceSize 732 } else { 733 dataIndex = aes.BlockSize 734 } 735 736 common.Must( 737 binary.Write(buffer, binary.BigEndian, c.session.sessionId), 738 binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()), 739 ) 740 741 if c.udpCipher == nil && pskLen > 1 { 742 for i, psk := range c.pskList { 743 dataIndex += aes.BlockSize 744 pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] 745 746 identityHeader := buffer.Extend(aes.BlockSize) 747 xorWords(identityHeader, pskHash, buffer.To(aes.BlockSize)) 748 b, err := c.blockConstructor(psk) 749 if err != nil { 750 return 0, err 751 } 752 b.Encrypt(identityHeader, identityHeader) 753 754 if i == pskLen-2 { 755 break 756 } 757 } 758 } 759 common.Must( 760 buffer.WriteByte(HeaderTypeClient), 761 binary.Write(buffer, binary.BigEndian, uint64(c.time().Unix())), 762 binary.Write(buffer, binary.BigEndian, uint16(paddingLen)), // padding length 763 ) 764 765 if paddingLen > 0 { 766 buffer.Extend(paddingLen) 767 } 768 769 err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) 770 if err != nil { 771 return 772 } 773 common.Must1(buffer.Write(p)) 774 if c.udpCipher != nil { 775 c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) 776 buffer.Extend(shadowaead.Overhead) 777 } else { 778 packetHeader := buffer.To(aes.BlockSize) 779 c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) 780 buffer.Extend(shadowaead.Overhead) 781 c.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader) 782 } 783 err = common.Error(c.Write(buffer.Bytes())) 784 if err != nil { 785 return 786 } 787 return len(p), nil 788 } 789 790 func (c *clientPacketConn) FrontHeadroom() int { 791 var overHead int 792 if c.udpCipher != nil { 793 overHead = PacketNonceSize + shadowaead.Overhead 794 } else { 795 overHead = shadowaead.Overhead 796 } 797 overHead += 16 // packet header 798 pskLen := len(c.pskList) 799 if c.udpCipher == nil && pskLen > 1 { 800 overHead += (pskLen - 1) * aes.BlockSize 801 } 802 overHead += 1 // header type 803 overHead += 8 // timestamp 804 overHead += 2 // padding length 805 overHead += MaxPaddingLength 806 overHead += M.MaxSocksaddrLength 807 return overHead 808 } 809 810 func (c *clientPacketConn) RearHeadroom() int { 811 return shadowaead.Overhead 812 } 813 814 type udpSession struct { 815 sessionId uint64 816 packetId uint64 817 remoteSessionId uint64 818 lastRemoteSessionId uint64 819 lastRemoteSeen int64 820 cipher cipher.AEAD 821 remoteCipher cipher.AEAD 822 lastRemoteCipher cipher.AEAD 823 window SlidingWindow 824 lastWindow SlidingWindow 825 rng io.Reader 826 } 827 828 func (s *udpSession) nextPacketId() uint64 { 829 return atomic.AddUint64(&s.packetId, 1) 830 } 831 832 func (m *Method) newUDPSession() *udpSession { 833 session := &udpSession{} 834 if m.udpCipher != nil { 835 session.rng = Blake3KeyedHash(rand.Reader) 836 common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) 837 } else { 838 common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) 839 } 840 session.packetId-- 841 if m.udpCipher == nil { 842 sessionId := make([]byte, 8) 843 binary.BigEndian.PutUint64(sessionId, session.sessionId) 844 key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength) 845 var err error 846 session.cipher, err = m.constructor(key) 847 if err != nil { 848 return nil 849 } 850 } 851 return session 852 } 853 854 func (c *clientPacketConn) Upstream() any { 855 return c.Conn 856 } 857 858 func (c *clientPacketConn) Close() error { 859 return common.Close(c.Conn) 860 } 861 862 func Blake3KeyedHash(reader io.Reader) io.Reader { 863 key := make([]byte, 32) 864 common.Must1(io.ReadFull(reader, key)) 865 h := blake3.New(1024, key) 866 return h.XOF() 867 }