github.com/MicalKarl/sing-shadowsocks@v0.0.5/shadowaead_2022/protocol.go (about) 1 package shadowaead_2022 2 3 import ( 4 "bytes" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/rand" 8 "crypto/sha256" 9 "encoding/base64" 10 "encoding/binary" 11 "io" 12 "math" 13 mRand "math/rand" 14 "net" 15 "os" 16 "strings" 17 "sync/atomic" 18 "time" 19 20 shadowsocks "github.com/MicalKarl/sing-shadowsocks" 21 "github.com/MicalKarl/sing-shadowsocks/shadowaead" 22 "github.com/MicalKarl/sing-shadowsocks/ssv" 23 "github.com/sagernet/sing/common" 24 "github.com/sagernet/sing/common/buf" 25 "github.com/sagernet/sing/common/bufio" 26 E "github.com/sagernet/sing/common/exceptions" 27 M "github.com/sagernet/sing/common/metadata" 28 N "github.com/sagernet/sing/common/network" 29 "github.com/sagernet/sing/common/random" 30 "github.com/sagernet/sing/common/rw" 31 32 "golang.org/x/crypto/chacha20poly1305" 33 "lukechampine.com/blake3" 34 ) 35 36 const ( 37 HeaderTypeClient = 0 38 HeaderTypeServer = 1 39 MaxPaddingLength = 900 40 PacketNonceSize = 24 41 MaxPacketSize = 65535 42 RequestHeaderFixedChunkLength = 1 + 8 + 2 43 PacketMinimalHeaderSize = 30 44 ) 45 46 var ( 47 ErrMissingPSK = E.New("missing psk") 48 ErrBadHeaderType = E.New("bad header type") 49 ErrBadTimestamp = E.New("bad timestamp") 50 ErrBadRequestSalt = E.New("bad request salt") 51 ErrSaltNotUnique = E.New("salt not unique") 52 ErrBadClientSessionId = E.New("bad client session id") 53 ErrPacketIdNotUnique = E.New("packet id not unique") 54 ErrTooManyServerSessions = E.New("server session changed more than once during the last minute") 55 ErrPacketTooShort = E.New("packet too short") 56 ) 57 58 var List = []string{ 59 "2022-blake3-aes-128-gcm", 60 "2022-blake3-aes-256-gcm", 61 "2022-blake3-chacha20-poly1305", 62 } 63 64 func init() { 65 random.InitializeSeed() 66 } 67 68 func NewWithPassword(method string, password string, timeFunc func() time.Time) (shadowsocks.Method, error) { 69 var pskList [][]byte 70 if password == "" { 71 return nil, ErrMissingPSK 72 } 73 keyStrList := strings.Split(password, ":") 74 pskList = make([][]byte, len(keyStrList)) 75 for i, keyStr := range keyStrList { 76 kb, err := base64.StdEncoding.DecodeString(keyStr) 77 if err != nil { 78 return nil, E.Cause(err, "decode key") 79 } 80 pskList[i] = kb 81 } 82 return New(method, pskList, timeFunc) 83 } 84 85 func New(method string, pskList [][]byte, timeFunc func() time.Time) (shadowsocks.Method, error) { 86 m := &Method{ 87 name: method, 88 timeFunc: timeFunc, 89 } 90 91 switch method { 92 case "2022-blake3-aes-128-gcm": 93 m.keySaltLength = 16 94 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 95 m.blockConstructor = aes.NewCipher 96 case "2022-blake3-aes-256-gcm": 97 m.keySaltLength = 32 98 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 99 m.blockConstructor = aes.NewCipher 100 case "2022-blake3-chacha20-poly1305": 101 if len(pskList) > 1 { 102 return nil, os.ErrInvalid 103 } 104 m.keySaltLength = 32 105 m.constructor = chacha20poly1305.New 106 } 107 108 if len(pskList) == 0 { 109 return nil, ErrMissingPSK 110 } 111 112 for i, psk := range pskList { 113 if len(psk) < m.keySaltLength { 114 return nil, shadowsocks.ErrBadKey 115 } else if len(psk) > m.keySaltLength { 116 pskList[i] = Key(psk, m.keySaltLength) 117 } 118 } 119 120 if len(pskList) > 1 { 121 pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize) 122 for i, psk := range pskList { 123 if i == 0 { 124 continue 125 } 126 hash := blake3.Sum512(psk) 127 copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], hash[:aes.BlockSize]) 128 } 129 m.pskHash = pskHash 130 } 131 132 var err error 133 switch method { 134 case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm": 135 m.udpBlockEncryptCipher, err = aes.NewCipher(pskList[0]) 136 if err != nil { 137 return nil, err 138 } 139 m.udpBlockDecryptCipher, err = aes.NewCipher(pskList[len(pskList)-1]) 140 if err != nil { 141 return nil, err 142 } 143 case "2022-blake3-chacha20-poly1305": 144 m.udpCipher, err = chacha20poly1305.NewX(pskList[0]) 145 if err != nil { 146 return nil, err 147 } 148 } 149 150 m.pskList = pskList 151 return m, nil 152 } 153 154 func Key(key []byte, keyLength int) []byte { 155 psk := sha256.Sum256(key) 156 return psk[:keyLength] 157 } 158 159 func SessionKey(psk []byte, salt []byte, keyLength int) []byte { 160 sessionKey := make([]byte, len(psk)+len(salt)) 161 copy(sessionKey, psk) 162 copy(sessionKey[len(psk):], salt) 163 outKey := make([]byte, keyLength) 164 blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey) 165 return outKey 166 } 167 168 func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) { 169 return func(key []byte) (cipher.AEAD, error) { 170 b, err := block(key) 171 if err != nil { 172 return nil, err 173 } 174 return aead(b) 175 } 176 } 177 178 type Method struct { 179 name string 180 keySaltLength int 181 timeFunc func() time.Time 182 183 constructor func(key []byte) (cipher.AEAD, error) 184 blockConstructor func(key []byte) (cipher.Block, error) 185 udpCipher cipher.AEAD 186 udpBlockEncryptCipher cipher.Block 187 udpBlockDecryptCipher cipher.Block 188 pskList [][]byte 189 pskHash []byte 190 } 191 192 func (m *Method) Name() string { 193 return m.name 194 } 195 196 func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { 197 shadowsocksConn := &clientConn{ 198 Method: m, 199 Conn: conn, 200 destination: destination, 201 } 202 return shadowsocksConn, shadowsocksConn.writeRequest(nil) 203 } 204 205 func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { 206 return &clientConn{ 207 Method: m, 208 Conn: conn, 209 destination: destination, 210 } 211 } 212 213 func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { 214 return &clientPacketConn{m, conn, m.newUDPSession()} 215 } 216 217 type clientConn struct { 218 *Method 219 net.Conn 220 destination M.Socksaddr 221 requestSalt []byte 222 reader *shadowaead.Reader 223 writer *shadowaead.Writer 224 } 225 226 func (m *Method) time() time.Time { 227 if m.timeFunc != nil { 228 return m.timeFunc() 229 } else { 230 return time.Now() 231 } 232 } 233 234 func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error { 235 pskLen := len(m.pskList) 236 if pskLen < 2 { 237 return nil 238 } 239 for i, psk := range m.pskList { 240 keyMaterial := make([]byte, m.keySaltLength*2) 241 copy(keyMaterial, psk) 242 copy(keyMaterial[m.keySaltLength:], salt) 243 identitySubkey := buf.NewSize(m.keySaltLength) 244 identitySubkey.Extend(identitySubkey.FreeLen()) 245 blake3.DeriveKey(identitySubkey.Bytes(), "shadowsocks 2022 identity subkey", keyMaterial) 246 247 pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] 248 249 header := request.Extend(16) 250 b, err := m.blockConstructor(identitySubkey.Bytes()) 251 if err != nil { 252 return err 253 } 254 b.Encrypt(header, pskHash) 255 identitySubkey.Release() 256 if i == pskLen-2 { 257 break 258 } 259 } 260 return nil 261 } 262 263 func (c *clientConn) writeRequest(payload []byte) error { 264 salt := make([]byte, c.keySaltLength) 265 common.Must1(io.ReadFull(rand.Reader, salt)) 266 267 key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength) 268 writeCipher, err := c.constructor(key) 269 if err != nil { 270 return err 271 } 272 writer := shadowaead.NewWriter( 273 c.Conn, 274 writeCipher, 275 MaxPacketSize, 276 ) 277 278 header := writer.Buffer() 279 header.Write(salt) 280 281 err = c.writeExtendedIdentityHeaders(header, salt) 282 if err != nil { 283 return err 284 } 285 286 var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte 287 fixedLengthBuffer := buf.With(_fixedLengthBuffer[:]) 288 common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient)) 289 common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(c.time().Unix()))) 290 var paddingLen int 291 if len(payload) < MaxPaddingLength { 292 paddingLen = mRand.Intn(MaxPaddingLength) + 1 293 } 294 variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen 295 payloadLen := len(payload) 296 variableLengthHeaderLen += payloadLen 297 common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen))) 298 writer.WriteChunk(header, fixedLengthBuffer.Slice()) 299 300 variableLengthBuffer := buf.NewSize(variableLengthHeaderLen) 301 err = ssv.FakeSockSerializer.WriteAddrPort(variableLengthBuffer, c.destination) 302 if err != nil { 303 return err 304 } 305 common.Must(binary.Write(variableLengthBuffer, binary.BigEndian, uint16(paddingLen))) 306 if paddingLen > 0 { 307 variableLengthBuffer.Extend(paddingLen) 308 } 309 if payloadLen > 0 { 310 common.Must1(variableLengthBuffer.Write(payload[:payloadLen])) 311 } 312 writer.WriteChunk(header, variableLengthBuffer.Slice()) 313 variableLengthBuffer.Release() 314 315 err = writer.BufferedWriter(header.Len()).Flush() 316 if err != nil { 317 return E.Cause(err, "client handshake") 318 } 319 320 c.requestSalt = salt 321 c.writer = writer 322 return nil 323 } 324 325 func (c *clientConn) readResponse() error { 326 if c.reader != nil { 327 return nil 328 } 329 330 salt := buf.NewSize(c.keySaltLength) 331 332 _, err := salt.ReadFullFrom(c.Conn, salt.FreeLen()) 333 if err != nil { 334 salt.Release() 335 return err 336 } 337 338 key := SessionKey(c.pskList[len(c.pskList)-1], salt.Bytes(), c.keySaltLength) 339 salt.Release() 340 341 readCipher, err := c.constructor(key) 342 if err != nil { 343 return err 344 } 345 reader := shadowaead.NewReader( 346 c.Conn, 347 readCipher, 348 MaxPacketSize, 349 ) 350 351 err = reader.ReadWithLength(uint16(1 + 8 + c.keySaltLength + 2)) 352 if err != nil { 353 return E.Cause(err, "read response fixed length chunk") 354 } 355 356 headerType, err := rw.ReadByte(reader) 357 if err != nil { 358 return err 359 } 360 if headerType != HeaderTypeServer /* && headerType != HeaderTypeServerEncrypted*/ { 361 return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType) 362 } 363 364 var epoch uint64 365 err = binary.Read(reader, binary.BigEndian, &epoch) 366 if err != nil { 367 return err 368 } 369 370 diff := int(math.Abs(float64(c.time().Unix() - int64(epoch)))) 371 if diff > 30 { 372 return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 373 } 374 375 requestSalt := buf.NewSize(c.keySaltLength) 376 _, err = requestSalt.ReadFullFrom(reader, requestSalt.FreeLen()) 377 if err != nil { 378 return err 379 } 380 381 if bytes.Compare(requestSalt.Bytes(), c.requestSalt) > 0 { 382 return ErrBadRequestSalt 383 } 384 requestSalt.Release() 385 c.requestSalt = nil 386 387 var length uint16 388 err = binary.Read(reader, binary.BigEndian, &length) 389 if err != nil { 390 return err 391 } 392 393 err = reader.ReadWithLength(length) 394 if err != nil { 395 return err 396 } 397 if headerType == HeaderTypeServer { 398 c.reader = reader 399 } 400 return nil 401 } 402 403 func (c *clientConn) Read(p []byte) (n int, err error) { 404 if err = c.readResponse(); err != nil { 405 return 406 } 407 return c.reader.Read(p) 408 } 409 410 func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) { 411 if err = c.readResponse(); err != nil { 412 return 413 } 414 return bufio.Copy(w, c.reader) 415 } 416 417 func (c *clientConn) Write(p []byte) (n int, err error) { 418 if c.writer == nil { 419 err = c.writeRequest(p) 420 if err == nil { 421 n = len(p) 422 } 423 return 424 } 425 return c.writer.Write(p) 426 } 427 428 var _ N.VectorisedWriter = (*clientConn)(nil) 429 430 func (c *clientConn) WriteVectorised(buffers []*buf.Buffer) error { 431 if c.writer != nil { 432 return c.writer.WriteVectorised(buffers) 433 } 434 err := c.writeRequest(buffers[0].Bytes()) 435 if err != nil { 436 buf.ReleaseMulti(buffers) 437 return err 438 } 439 buffers[0].Release() 440 return c.writer.WriteVectorised(buffers[1:]) 441 } 442 443 func (c *clientConn) NeedHandshake() bool { 444 return c.writer == nil 445 } 446 447 func (c *clientConn) NeedAdditionalReadDeadline() bool { 448 return true 449 } 450 451 func (c *clientConn) Upstream() any { 452 return c.Conn 453 } 454 455 func (c *clientConn) Close() error { 456 return common.Close( 457 c.Conn, 458 common.PtrOrNil(c.reader), 459 common.PtrOrNil(c.writer), 460 ) 461 } 462 463 type clientPacketConn struct { 464 *Method 465 net.Conn 466 session *udpSession 467 } 468 469 func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 470 defer buffer.Release() 471 var hdrLen int 472 if c.udpCipher != nil { 473 hdrLen = PacketNonceSize 474 } 475 476 var paddingLen int 477 if destination.Port == 53 && buffer.Len() < MaxPaddingLength { 478 paddingLen = mRand.Intn(MaxPaddingLength-buffer.Len()) + 1 479 } 480 481 hdrLen += 16 // packet header 482 pskLen := len(c.pskList) 483 if c.udpCipher == nil && pskLen > 1 { 484 hdrLen += (pskLen - 1) * aes.BlockSize 485 } 486 hdrLen += 1 // header type 487 hdrLen += 8 // timestamp 488 hdrLen += 2 // padding length 489 hdrLen += paddingLen 490 hdrLen += M.SocksaddrSerializer.AddrPortLen(destination) 491 header := buf.With(buffer.ExtendHeader(hdrLen)) 492 493 var dataIndex int 494 if c.udpCipher != nil { 495 common.Must1(header.ReadFullFrom(c.session.rng, PacketNonceSize)) 496 if pskLen > 1 { 497 panic("unsupported chacha extended header") 498 } 499 dataIndex = PacketNonceSize 500 } else { 501 dataIndex = aes.BlockSize 502 } 503 504 common.Must( 505 binary.Write(header, binary.BigEndian, c.session.sessionId), 506 binary.Write(header, binary.BigEndian, c.session.nextPacketId()), 507 ) 508 509 if c.udpCipher == nil && pskLen > 1 { 510 for i, psk := range c.pskList { 511 dataIndex += aes.BlockSize 512 pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] 513 514 identityHeader := header.Extend(aes.BlockSize) 515 xorWords(identityHeader, pskHash, header.To(aes.BlockSize)) 516 b, err := c.blockConstructor(psk) 517 if err != nil { 518 return err 519 } 520 b.Encrypt(identityHeader, identityHeader) 521 522 if i == pskLen-2 { 523 break 524 } 525 } 526 } 527 common.Must( 528 header.WriteByte(HeaderTypeClient), 529 binary.Write(header, binary.BigEndian, uint64(c.time().Unix())), 530 binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length 531 ) 532 533 if paddingLen > 0 { 534 header.Extend(paddingLen) 535 } 536 537 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 538 if err != nil { 539 return err 540 } 541 if c.udpCipher != nil { 542 c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) 543 buffer.Extend(shadowaead.Overhead) 544 } else { 545 packetHeader := buffer.To(aes.BlockSize) 546 c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) 547 buffer.Extend(shadowaead.Overhead) 548 c.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader) 549 } 550 return common.Error(c.Write(buffer.Bytes())) 551 } 552 553 func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { 554 n, err := c.Read(buffer.FreeBytes()) 555 if err != nil { 556 return M.Socksaddr{}, err 557 } 558 buffer.Truncate(n) 559 560 var packetHeader []byte 561 if c.udpCipher != nil { 562 if buffer.Len() < PacketNonceSize+PacketMinimalHeaderSize { 563 return M.Socksaddr{}, ErrPacketTooShort 564 } 565 _, err = c.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) 566 if err != nil { 567 return M.Socksaddr{}, E.Cause(err, "decrypt packet") 568 } 569 buffer.Advance(PacketNonceSize) 570 buffer.Truncate(buffer.Len() - shadowaead.Overhead) 571 } else { 572 if buffer.Len() < PacketMinimalHeaderSize { 573 return M.Socksaddr{}, ErrPacketTooShort 574 } 575 packetHeader = buffer.To(aes.BlockSize) 576 c.udpBlockDecryptCipher.Decrypt(packetHeader, packetHeader) 577 } 578 579 var sessionId, packetId uint64 580 err = binary.Read(buffer, binary.BigEndian, &sessionId) 581 if err != nil { 582 return M.Socksaddr{}, err 583 } 584 err = binary.Read(buffer, binary.BigEndian, &packetId) 585 if err != nil { 586 return M.Socksaddr{}, err 587 } 588 589 if sessionId == c.session.remoteSessionId { 590 if !c.session.window.Check(packetId) { 591 return M.Socksaddr{}, ErrPacketIdNotUnique 592 } 593 } else if sessionId == c.session.lastRemoteSessionId { 594 if !c.session.lastWindow.Check(packetId) { 595 return M.Socksaddr{}, ErrPacketIdNotUnique 596 } 597 } 598 599 var remoteCipher cipher.AEAD 600 if packetHeader != nil { 601 if sessionId == c.session.remoteSessionId { 602 remoteCipher = c.session.remoteCipher 603 } else if sessionId == c.session.lastRemoteSessionId { 604 remoteCipher = c.session.lastRemoteCipher 605 } else { 606 key := SessionKey(c.pskList[len(c.pskList)-1], packetHeader[:8], c.keySaltLength) 607 remoteCipher, err = c.constructor(key) 608 if err != nil { 609 return M.Socksaddr{}, err 610 } 611 } 612 _, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) 613 if err != nil { 614 return M.Socksaddr{}, E.Cause(err, "decrypt packet") 615 } 616 buffer.Truncate(buffer.Len() - shadowaead.Overhead) 617 } 618 619 var headerType byte 620 headerType, err = buffer.ReadByte() 621 if err != nil { 622 return M.Socksaddr{}, err 623 } 624 if headerType != HeaderTypeServer { 625 return M.Socksaddr{}, E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType) 626 } 627 628 var epoch uint64 629 err = binary.Read(buffer, binary.BigEndian, &epoch) 630 if err != nil { 631 return M.Socksaddr{}, err 632 } 633 634 diff := int(math.Abs(float64(c.time().Unix() - int64(epoch)))) 635 if diff > 30 { 636 return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 637 } 638 639 if sessionId == c.session.remoteSessionId { 640 c.session.window.Add(packetId) 641 } else if sessionId == c.session.lastRemoteSessionId { 642 c.session.lastWindow.Add(packetId) 643 c.session.lastRemoteSeen = c.time().Unix() 644 } else { 645 if c.session.remoteSessionId != 0 { 646 if c.time().Unix()-c.session.lastRemoteSeen < 60 { 647 return M.Socksaddr{}, ErrTooManyServerSessions 648 } else { 649 c.session.lastRemoteSessionId = c.session.remoteSessionId 650 c.session.lastWindow = c.session.window 651 c.session.lastRemoteSeen = c.time().Unix() 652 c.session.lastRemoteCipher = c.session.remoteCipher 653 c.session.window = SlidingWindow{} 654 } 655 } 656 c.session.remoteSessionId = sessionId 657 c.session.remoteCipher = remoteCipher 658 c.session.window.Add(packetId) 659 } 660 661 var clientSessionId uint64 662 err = binary.Read(buffer, binary.BigEndian, &clientSessionId) 663 if err != nil { 664 return M.Socksaddr{}, err 665 } 666 667 if clientSessionId != c.session.sessionId { 668 return M.Socksaddr{}, ErrBadClientSessionId 669 } 670 671 var paddingLen uint16 672 err = binary.Read(buffer, binary.BigEndian, &paddingLen) 673 if err != nil { 674 return M.Socksaddr{}, E.Cause(err, "read padding length") 675 } 676 buffer.Advance(int(paddingLen)) 677 678 destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) 679 if err != nil { 680 return M.Socksaddr{}, err 681 } 682 return destination, nil 683 } 684 685 func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 686 buffer := buf.With(p) 687 destination, err := c.ReadPacket(buffer) 688 if err != nil { 689 return 690 } 691 if destination.IsFqdn() { 692 addr = destination 693 } else { 694 addr = destination.UDPAddr() 695 } 696 n = copy(p, buffer.Bytes()) 697 return 698 } 699 700 func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 701 destination := M.SocksaddrFromNet(addr) 702 var overHead int 703 if c.udpCipher != nil { 704 overHead = PacketNonceSize + shadowaead.Overhead 705 } else { 706 overHead = shadowaead.Overhead 707 } 708 overHead += 16 // packet header 709 pskLen := len(c.pskList) 710 if c.udpCipher == nil && pskLen > 1 { 711 overHead += (pskLen - 1) * aes.BlockSize 712 } 713 var paddingLen int 714 if destination.Port == 53 && len(p) < MaxPaddingLength { 715 paddingLen = mRand.Intn(MaxPaddingLength-len(p)) + 1 716 } 717 overHead += 1 // header type 718 overHead += 8 // timestamp 719 overHead += 2 // padding length 720 overHead += paddingLen 721 overHead += M.SocksaddrSerializer.AddrPortLen(destination) 722 723 buffer := buf.NewSize(overHead + len(p)) 724 defer buffer.Release() 725 726 var dataIndex int 727 if c.udpCipher != nil { 728 common.Must1(buffer.ReadFullFrom(c.session.rng, PacketNonceSize)) 729 if pskLen > 1 { 730 panic("unsupported chacha extended header") 731 } 732 dataIndex = PacketNonceSize 733 } else { 734 dataIndex = aes.BlockSize 735 } 736 737 common.Must( 738 binary.Write(buffer, binary.BigEndian, c.session.sessionId), 739 binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()), 740 ) 741 742 if c.udpCipher == nil && pskLen > 1 { 743 for i, psk := range c.pskList { 744 dataIndex += aes.BlockSize 745 pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] 746 747 identityHeader := buffer.Extend(aes.BlockSize) 748 xorWords(identityHeader, pskHash, buffer.To(aes.BlockSize)) 749 b, err := c.blockConstructor(psk) 750 if err != nil { 751 return 0, err 752 } 753 b.Encrypt(identityHeader, identityHeader) 754 755 if i == pskLen-2 { 756 break 757 } 758 } 759 } 760 common.Must( 761 buffer.WriteByte(HeaderTypeClient), 762 binary.Write(buffer, binary.BigEndian, uint64(c.time().Unix())), 763 binary.Write(buffer, binary.BigEndian, uint16(paddingLen)), // padding length 764 ) 765 766 if paddingLen > 0 { 767 buffer.Extend(paddingLen) 768 } 769 770 err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) 771 if err != nil { 772 return 773 } 774 common.Must1(buffer.Write(p)) 775 if c.udpCipher != nil { 776 c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) 777 buffer.Extend(shadowaead.Overhead) 778 } else { 779 packetHeader := buffer.To(aes.BlockSize) 780 c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) 781 buffer.Extend(shadowaead.Overhead) 782 c.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader) 783 } 784 err = common.Error(c.Write(buffer.Bytes())) 785 if err != nil { 786 return 787 } 788 return len(p), nil 789 } 790 791 func (c *clientPacketConn) FrontHeadroom() int { 792 var overHead int 793 if c.udpCipher != nil { 794 overHead = PacketNonceSize + shadowaead.Overhead 795 } else { 796 overHead = shadowaead.Overhead 797 } 798 overHead += 16 // packet header 799 pskLen := len(c.pskList) 800 if c.udpCipher == nil && pskLen > 1 { 801 overHead += (pskLen - 1) * aes.BlockSize 802 } 803 overHead += 1 // header type 804 overHead += 8 // timestamp 805 overHead += 2 // padding length 806 overHead += MaxPaddingLength 807 overHead += M.MaxSocksaddrLength 808 return overHead 809 } 810 811 func (c *clientPacketConn) RearHeadroom() int { 812 return shadowaead.Overhead 813 } 814 815 type udpSession struct { 816 sessionId uint64 817 packetId uint64 818 remoteSessionId uint64 819 lastRemoteSessionId uint64 820 lastRemoteSeen int64 821 cipher cipher.AEAD 822 remoteCipher cipher.AEAD 823 lastRemoteCipher cipher.AEAD 824 window SlidingWindow 825 lastWindow SlidingWindow 826 rng io.Reader 827 } 828 829 func (s *udpSession) nextPacketId() uint64 { 830 return atomic.AddUint64(&s.packetId, 1) 831 } 832 833 func (m *Method) newUDPSession() *udpSession { 834 session := &udpSession{} 835 if m.udpCipher != nil { 836 session.rng = Blake3KeyedHash(rand.Reader) 837 common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) 838 } else { 839 common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) 840 } 841 session.packetId-- 842 if m.udpCipher == nil { 843 sessionId := make([]byte, 8) 844 binary.BigEndian.PutUint64(sessionId, session.sessionId) 845 key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength) 846 var err error 847 session.cipher, err = m.constructor(key) 848 if err != nil { 849 return nil 850 } 851 } 852 return session 853 } 854 855 func (c *clientPacketConn) Upstream() any { 856 return c.Conn 857 } 858 859 func (c *clientPacketConn) Close() error { 860 return common.Close(c.Conn) 861 } 862 863 func Blake3KeyedHash(reader io.Reader) io.Reader { 864 key := make([]byte, 32) 865 common.Must1(io.ReadFull(reader, key)) 866 h := blake3.New(1024, key) 867 return h.XOF() 868 }