github.com/MerlinKodo/sing-shadowsocks2@v0.1.6/shadowaead_2022/method.go (about) 1 package shadowaead_2022 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/aes" 7 "crypto/cipher" 8 "crypto/rand" 9 "encoding/base64" 10 "encoding/binary" 11 "math" 12 mRand "math/rand" 13 "net" 14 "os" 15 "strings" 16 "time" 17 18 C "github.com/MerlinKodo/sing-shadowsocks2/cipher" 19 "github.com/MerlinKodo/sing-shadowsocks2/internal/shadowio" 20 "github.com/sagernet/sing/common" 21 "github.com/sagernet/sing/common/buf" 22 "github.com/sagernet/sing/common/bufio" 23 E "github.com/sagernet/sing/common/exceptions" 24 M "github.com/sagernet/sing/common/metadata" 25 N "github.com/sagernet/sing/common/network" 26 "github.com/sagernet/sing/common/ntp" 27 28 "golang.org/x/crypto/chacha20poly1305" 29 "lukechampine.com/blake3" 30 ) 31 32 var MethodList = []string{ 33 "2022-blake3-aes-128-gcm", 34 "2022-blake3-aes-256-gcm", 35 "2022-blake3-chacha20-poly1305", 36 } 37 38 func init() { 39 C.RegisterMethod(MethodList, NewMethod) 40 } 41 42 type Method struct { 43 keySaltLength int 44 timeFunc func() time.Time 45 constructor func(key []byte) (cipher.AEAD, error) 46 blockConstructor func(key []byte) (cipher.Block, error) 47 udpCipher cipher.AEAD 48 udpBlockEncryptCipher cipher.Block 49 udpBlockDecryptCipher cipher.Block 50 pskList [][]byte 51 pskHash []byte 52 } 53 54 func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) { 55 m := &Method{ 56 timeFunc: ntp.TimeFuncFromContext(ctx), 57 pskList: options.KeyList, 58 } 59 if options.Password != "" { 60 var pskList [][]byte 61 keyStrList := strings.Split(options.Password, ":") 62 pskList = make([][]byte, len(keyStrList)) 63 for i, keyStr := range keyStrList { 64 kb, err := base64.StdEncoding.DecodeString(keyStr) 65 if err != nil { 66 return nil, E.Cause(err, "decode key") 67 } 68 pskList[i] = kb 69 } 70 m.pskList = pskList 71 } 72 switch methodName { 73 case "2022-blake3-aes-128-gcm": 74 m.keySaltLength = 16 75 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 76 m.blockConstructor = aes.NewCipher 77 case "2022-blake3-aes-256-gcm": 78 m.keySaltLength = 32 79 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 80 m.blockConstructor = aes.NewCipher 81 case "2022-blake3-chacha20-poly1305": 82 if len(options.KeyList) > 1 { 83 return nil, ErrNoEIH 84 } 85 m.keySaltLength = 32 86 m.constructor = chacha20poly1305.New 87 default: 88 return nil, os.ErrInvalid 89 } 90 if len(m.pskList) == 0 { 91 return nil, C.ErrMissingPassword 92 } 93 for _, key := range m.pskList { 94 if len(key) != m.keySaltLength { 95 return nil, E.New("bad key length, required ", m.keySaltLength, ", got ", len(key)) 96 } 97 } 98 if len(m.pskList) > 1 { 99 pskHash := make([]byte, (len(m.pskList)-1)*aes.BlockSize) 100 for i, key := range m.pskList { 101 if i == 0 { 102 continue 103 } 104 keyHash := blake3.Sum512(key) 105 copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], keyHash[:aes.BlockSize]) 106 } 107 m.pskHash = pskHash 108 } 109 var err error 110 switch methodName { 111 case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm": 112 m.udpBlockEncryptCipher, err = aes.NewCipher(m.pskList[0]) 113 if err != nil { 114 return nil, err 115 } 116 m.udpBlockDecryptCipher, err = aes.NewCipher(m.pskList[len(m.pskList)-1]) 117 if err != nil { 118 return nil, err 119 } 120 case "2022-blake3-chacha20-poly1305": 121 m.udpCipher, err = chacha20poly1305.NewX(m.pskList[0]) 122 if err != nil { 123 return nil, err 124 } 125 } 126 return m, nil 127 } 128 129 func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { 130 shadowsocksConn := &clientConn{ 131 Conn: conn, 132 method: m, 133 destination: destination, 134 } 135 return shadowsocksConn, shadowsocksConn.writeRequest(nil) 136 } 137 138 func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { 139 return &clientConn{ 140 Conn: conn, 141 method: m, 142 destination: destination, 143 } 144 } 145 146 func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { 147 pc := &clientPacketConn{ 148 AbstractConn: conn, 149 reader: bufio.NewExtendedReader(conn), 150 writer: bufio.NewExtendedWriter(conn), 151 method: m, 152 session: m.newUDPSession(), 153 } 154 if waitRead, isWaitRead := N.CastReader[shadowio.WaitReadReader](conn); isWaitRead { 155 return &clientWaitPacketConn{ 156 clientPacketConn: pc, 157 waitRead: waitRead, 158 } 159 } 160 return pc 161 } 162 163 func (m *Method) time() time.Time { 164 if m.timeFunc != nil { 165 return m.timeFunc() 166 } else { 167 return time.Now() 168 } 169 } 170 171 type clientConn struct { 172 net.Conn 173 method *Method 174 destination M.Socksaddr 175 requestSalt []byte 176 reader *shadowio.Reader 177 writer *shadowio.Writer 178 shadowio.WriterInterface 179 } 180 181 func (c *clientConn) writeRequest(payload []byte) error { 182 requestSalt := make([]byte, c.method.keySaltLength) 183 requestBuffer := buf.New() 184 defer requestBuffer.Release() 185 requestBuffer.WriteRandom(c.method.keySaltLength) 186 copy(requestSalt, requestBuffer.Bytes()) 187 key := SessionKey(c.method.pskList[len(c.method.pskList)-1], requestSalt, c.method.keySaltLength) 188 writeCipher, err := c.method.constructor(key) 189 if err != nil { 190 return err 191 } 192 writer := shadowio.NewWriter( 193 c.Conn, 194 writeCipher, 195 nil, 196 buf.BufferSize-shadowio.PacketLengthBufferSize-shadowio.Overhead*2, 197 ) 198 err = c.method.writeExtendedIdentityHeaders(requestBuffer, requestBuffer.To(c.method.keySaltLength)) 199 if err != nil { 200 return err 201 } 202 fixedLengthBuffer := buf.With(requestBuffer.Extend(RequestHeaderFixedChunkLength + shadowio.Overhead)) 203 common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient)) 204 common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(c.method.time().Unix()))) 205 variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 206 var paddingLen int 207 if len(payload) < MaxPaddingLength { 208 paddingLen = mRand.Intn(MaxPaddingLength) + 1 209 } 210 variableLengthHeaderLen += paddingLen 211 maxPayloadLen := requestBuffer.FreeLen() - (variableLengthHeaderLen + shadowio.Overhead) 212 payloadLen := len(payload) 213 if payloadLen > maxPayloadLen { 214 payloadLen = maxPayloadLen 215 } 216 variableLengthHeaderLen += payloadLen 217 common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen))) 218 writer.Encrypt(fixedLengthBuffer.Index(0), fixedLengthBuffer.Bytes()) 219 fixedLengthBuffer.Extend(shadowio.Overhead) 220 221 variableLengthBuffer := buf.With(requestBuffer.Extend(variableLengthHeaderLen + shadowio.Overhead)) 222 err = M.SocksaddrSerializer.WriteAddrPort(variableLengthBuffer, c.destination) 223 if err != nil { 224 return err 225 } 226 common.Must(binary.Write(variableLengthBuffer, binary.BigEndian, uint16(paddingLen))) 227 if paddingLen > 0 { 228 variableLengthBuffer.Extend(paddingLen) 229 } 230 if payloadLen > 0 { 231 common.Must1(variableLengthBuffer.Write(payload[:payloadLen])) 232 } 233 writer.Encrypt(variableLengthBuffer.Index(0), variableLengthBuffer.Bytes()) 234 variableLengthBuffer.Extend(shadowio.Overhead) 235 _, err = c.Conn.Write(requestBuffer.Bytes()) 236 if err != nil { 237 return err 238 } 239 if len(payload) > payloadLen { 240 _, err = writer.Write(payload[payloadLen:]) 241 if err != nil { 242 return err 243 } 244 } 245 c.requestSalt = requestSalt 246 c.writer = writer 247 return nil 248 } 249 250 func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error { 251 pskLen := len(m.pskList) 252 if pskLen < 2 { 253 return nil 254 } 255 for i, psk := range m.pskList { 256 keyMaterial := make([]byte, m.keySaltLength*2) 257 copy(keyMaterial, psk) 258 copy(keyMaterial[m.keySaltLength:], salt) 259 identitySubkey := make([]byte, m.keySaltLength) 260 blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial) 261 pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] 262 header := request.Extend(16) 263 b, err := m.blockConstructor(identitySubkey) 264 if err != nil { 265 return err 266 } 267 b.Encrypt(header, pskHash) 268 if i == pskLen-2 { 269 break 270 } 271 } 272 return nil 273 } 274 275 func (c *clientConn) readResponse() error { 276 salt := buf.NewSize(c.method.keySaltLength) 277 defer salt.Release() 278 _, err := salt.ReadFullFrom(c.Conn, c.method.keySaltLength) 279 if err != nil { 280 return err 281 } 282 key := SessionKey(c.method.pskList[len(c.method.pskList)-1], salt.Bytes(), c.method.keySaltLength) 283 readCipher, err := c.method.constructor(key) 284 if err != nil { 285 return err 286 } 287 reader := shadowio.NewReader(c.Conn, readCipher) 288 fixedResponseBuffer, err := reader.ReadFixedBuffer(1 + 8 + c.method.keySaltLength + 2) 289 if err != nil { 290 return err 291 } 292 headerType := common.Must1(fixedResponseBuffer.ReadByte()) 293 if headerType != HeaderTypeServer { 294 return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType) 295 } 296 var epoch uint64 297 common.Must(binary.Read(fixedResponseBuffer, binary.BigEndian, &epoch)) 298 diff := int(math.Abs(float64(c.method.time().Unix() - int64(epoch)))) 299 if diff > 30 { 300 return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 301 } 302 responseSalt := common.Must1(fixedResponseBuffer.ReadBytes(c.method.keySaltLength)) 303 if !bytes.Equal(responseSalt, c.requestSalt) { 304 return ErrBadRequestSalt 305 } 306 var length uint16 307 common.Must(binary.Read(reader, binary.BigEndian, &length)) 308 _, err = reader.ReadFixedBuffer(int(length)) 309 if err != nil { 310 return err 311 } 312 c.reader = reader 313 return nil 314 } 315 316 func (c *clientConn) Read(p []byte) (n int, err error) { 317 if c.reader == nil { 318 if err = c.readResponse(); err != nil { 319 return 320 } 321 } 322 return c.reader.Read(p) 323 } 324 325 func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error { 326 if c.reader == nil { 327 err := c.readResponse() 328 if err != nil { 329 return err 330 } 331 } 332 return c.reader.ReadBuffer(buffer) 333 } 334 335 func (c *clientConn) ReadBufferThreadSafe() (buffer *buf.Buffer, err error) { 336 if c.reader == nil { 337 err = c.readResponse() 338 if err != nil { 339 return 340 } 341 342 } 343 return c.reader.ReadBufferThreadSafe() 344 } 345 346 func (c *clientConn) Write(p []byte) (n int, err error) { 347 if c.writer == nil { 348 err = c.writeRequest(p) 349 if err == nil { 350 n = len(p) 351 } 352 return 353 } 354 return c.writer.Write(p) 355 } 356 357 func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error { 358 if c.writer == nil { 359 defer buffer.Release() 360 return c.writeRequest(buffer.Bytes()) 361 } 362 return c.writer.WriteBuffer(buffer) 363 } 364 365 func (c *clientConn) NeedHandshake() bool { 366 return c.writer == nil 367 } 368 369 func (c *clientConn) Upstream() any { 370 return c.Conn 371 } 372 373 func (c *clientConn) Close() error { 374 return common.Close( 375 c.Conn, 376 common.PtrOrNil(c.reader), 377 common.PtrOrNil(c.writer), 378 ) 379 } 380 381 type clientPacketConn struct { 382 N.AbstractConn 383 reader N.ExtendedReader 384 writer N.ExtendedWriter 385 method *Method 386 session *udpSession 387 } 388 389 func (m *Method) newUDPSession() *udpSession { 390 session := &udpSession{} 391 if m.udpCipher != nil { 392 session.rng = Blake3KeyedHash(rand.Reader) 393 common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) 394 } else { 395 common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) 396 } 397 session.packetId-- 398 if m.udpCipher == nil { 399 sessionId := make([]byte, 8) 400 binary.BigEndian.PutUint64(sessionId, session.sessionId) 401 key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength) 402 var err error 403 session.cipher, err = m.constructor(key) 404 if err != nil { 405 return nil 406 } 407 } 408 return session 409 } 410 411 func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 412 var hdrLen int 413 if c.method.udpCipher != nil { 414 hdrLen = PacketNonceSize 415 } 416 417 var paddingLen int 418 if destination.Port == 53 && buffer.Len() < MaxPaddingLength { 419 paddingLen = mRand.Intn(MaxPaddingLength-buffer.Len()) + 1 420 } 421 422 hdrLen += 16 // packet header 423 pskLen := len(c.method.pskList) 424 if c.method.udpCipher == nil && pskLen > 1 { 425 hdrLen += (pskLen - 1) * aes.BlockSize 426 } 427 hdrLen += 1 // header type 428 hdrLen += 8 // timestamp 429 hdrLen += 2 // padding length 430 hdrLen += paddingLen 431 hdrLen += M.SocksaddrSerializer.AddrPortLen(destination) 432 header := buf.With(buffer.ExtendHeader(hdrLen)) 433 434 var dataIndex int 435 if c.method.udpCipher != nil { 436 common.Must1(header.ReadFullFrom(c.session.rng, PacketNonceSize)) 437 if pskLen > 1 { 438 panic("unsupported chacha extended header") 439 } 440 dataIndex = PacketNonceSize 441 } else { 442 dataIndex = aes.BlockSize 443 } 444 445 common.Must( 446 binary.Write(header, binary.BigEndian, c.session.sessionId), 447 binary.Write(header, binary.BigEndian, c.session.nextPacketId()), 448 ) 449 450 if c.method.udpCipher == nil && pskLen > 1 { 451 for i, psk := range c.method.pskList { 452 dataIndex += aes.BlockSize 453 pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] 454 455 identityHeader := header.Extend(aes.BlockSize) 456 xorWords(identityHeader, pskHash, header.To(aes.BlockSize)) 457 b, err := c.method.blockConstructor(psk) 458 if err != nil { 459 return err 460 } 461 b.Encrypt(identityHeader, identityHeader) 462 463 if i == pskLen-2 { 464 break 465 } 466 } 467 } 468 common.Must( 469 header.WriteByte(HeaderTypeClient), 470 binary.Write(header, binary.BigEndian, uint64(c.method.time().Unix())), 471 binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length 472 ) 473 474 if paddingLen > 0 { 475 header.Extend(paddingLen) 476 } 477 478 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 479 if err != nil { 480 return err 481 } 482 if c.method.udpCipher != nil { 483 c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) 484 buffer.Extend(shadowio.Overhead) 485 } else { 486 packetHeader := buffer.To(aes.BlockSize) 487 c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) 488 buffer.Extend(shadowio.Overhead) 489 c.method.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader) 490 } 491 return c.writer.WriteBuffer(buffer) 492 } 493 494 func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 495 err = c.reader.ReadBuffer(buffer) 496 if err != nil { 497 return 498 } 499 return c.readPacket(buffer) 500 } 501 502 func (c *clientPacketConn) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 503 var packetHeader []byte 504 if c.method.udpCipher != nil { 505 if buffer.Len() < PacketNonceSize+PacketMinimalHeaderSize { 506 return M.Socksaddr{}, C.ErrPacketTooShort 507 } 508 _, err = c.method.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) 509 if err != nil { 510 return M.Socksaddr{}, E.Cause(err, "decrypt packet") 511 } 512 buffer.Advance(PacketNonceSize) 513 buffer.Truncate(buffer.Len() - shadowio.Overhead) 514 } else { 515 if buffer.Len() < PacketMinimalHeaderSize { 516 return M.Socksaddr{}, C.ErrPacketTooShort 517 } 518 packetHeader = buffer.To(aes.BlockSize) 519 c.method.udpBlockDecryptCipher.Decrypt(packetHeader, packetHeader) 520 } 521 522 var sessionId, packetId uint64 523 err = binary.Read(buffer, binary.BigEndian, &sessionId) 524 if err != nil { 525 return M.Socksaddr{}, err 526 } 527 err = binary.Read(buffer, binary.BigEndian, &packetId) 528 if err != nil { 529 return M.Socksaddr{}, err 530 } 531 532 if sessionId == c.session.remoteSessionId { 533 if !c.session.window.Check(packetId) { 534 return M.Socksaddr{}, ErrPacketIdNotUnique 535 } 536 } else if sessionId == c.session.lastRemoteSessionId { 537 if !c.session.lastWindow.Check(packetId) { 538 return M.Socksaddr{}, ErrPacketIdNotUnique 539 } 540 } 541 542 var remoteCipher cipher.AEAD 543 if packetHeader != nil { 544 if sessionId == c.session.remoteSessionId { 545 remoteCipher = c.session.remoteCipher 546 } else if sessionId == c.session.lastRemoteSessionId { 547 remoteCipher = c.session.lastRemoteCipher 548 } else { 549 key := SessionKey(c.method.pskList[len(c.method.pskList)-1], packetHeader[:8], c.method.keySaltLength) 550 remoteCipher, err = c.method.constructor(key) 551 if err != nil { 552 return M.Socksaddr{}, err 553 } 554 } 555 _, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) 556 if err != nil { 557 return M.Socksaddr{}, E.Cause(err, "decrypt packet") 558 } 559 buffer.Truncate(buffer.Len() - shadowio.Overhead) 560 } 561 562 var headerType byte 563 headerType, err = buffer.ReadByte() 564 if err != nil { 565 return M.Socksaddr{}, err 566 } 567 if headerType != HeaderTypeServer { 568 return M.Socksaddr{}, E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType) 569 } 570 571 var epoch uint64 572 err = binary.Read(buffer, binary.BigEndian, &epoch) 573 if err != nil { 574 return M.Socksaddr{}, err 575 } 576 577 diff := int(math.Abs(float64(c.method.time().Unix() - int64(epoch)))) 578 if diff > 30 { 579 return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 580 } 581 582 if sessionId == c.session.remoteSessionId { 583 c.session.window.Add(packetId) 584 } else if sessionId == c.session.lastRemoteSessionId { 585 c.session.lastWindow.Add(packetId) 586 c.session.lastRemoteSeen = c.method.time().Unix() 587 } else { 588 if c.session.remoteSessionId != 0 { 589 if c.method.time().Unix()-c.session.lastRemoteSeen < 60 { 590 return M.Socksaddr{}, ErrTooManyServerSessions 591 } else { 592 c.session.lastRemoteSessionId = c.session.remoteSessionId 593 c.session.lastWindow = c.session.window 594 c.session.lastRemoteSeen = c.method.time().Unix() 595 c.session.lastRemoteCipher = c.session.remoteCipher 596 c.session.window = SlidingWindow{} 597 } 598 } 599 c.session.remoteSessionId = sessionId 600 c.session.remoteCipher = remoteCipher 601 c.session.window.Add(packetId) 602 } 603 604 var clientSessionId uint64 605 err = binary.Read(buffer, binary.BigEndian, &clientSessionId) 606 if err != nil { 607 return M.Socksaddr{}, err 608 } 609 610 if clientSessionId != c.session.sessionId { 611 return M.Socksaddr{}, ErrBadClientSessionId 612 } 613 614 var paddingLen uint16 615 err = binary.Read(buffer, binary.BigEndian, &paddingLen) 616 if err != nil { 617 return M.Socksaddr{}, E.Cause(err, "read padding length") 618 } 619 buffer.Advance(int(paddingLen)) 620 621 destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) 622 if err != nil { 623 return 624 } 625 return destination.Unwrap(), nil 626 } 627 628 func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 629 n, err = c.reader.Read(p) 630 if err != nil { 631 return 632 } 633 buffer := buf.As(p[:n]) 634 destination, err := c.readPacket(buffer) 635 if destination.IsFqdn() { 636 addr = destination 637 } else { 638 addr = destination.UDPAddr() 639 } 640 n = copy(p, buffer.Bytes()) 641 return 642 } 643 644 func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 645 destination := M.SocksaddrFromNet(addr) 646 var overHead int 647 if c.method.udpCipher != nil { 648 overHead = PacketNonceSize + shadowio.Overhead 649 } else { 650 overHead = shadowio.Overhead 651 } 652 overHead += 16 // packet header 653 pskLen := len(c.method.pskList) 654 if c.method.udpCipher == nil && pskLen > 1 { 655 overHead += (pskLen - 1) * aes.BlockSize 656 } 657 var paddingLen int 658 if destination.Port == 53 && len(p) < MaxPaddingLength { 659 paddingLen = mRand.Intn(MaxPaddingLength-len(p)) + 1 660 } 661 overHead += 1 // header type 662 overHead += 8 // timestamp 663 overHead += 2 // padding length 664 overHead += paddingLen 665 overHead += M.SocksaddrSerializer.AddrPortLen(destination) 666 667 buffer := buf.NewSize(overHead + len(p)) 668 defer buffer.Release() 669 670 var dataIndex int 671 if c.method.udpCipher != nil { 672 common.Must1(buffer.ReadFullFrom(c.session.rng, PacketNonceSize)) 673 if pskLen > 1 { 674 panic("unsupported chacha extended header") 675 } 676 dataIndex = PacketNonceSize 677 } else { 678 dataIndex = aes.BlockSize 679 } 680 681 common.Must( 682 binary.Write(buffer, binary.BigEndian, c.session.sessionId), 683 binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()), 684 ) 685 686 if c.method.udpCipher == nil && pskLen > 1 { 687 for i, psk := range c.method.pskList { 688 dataIndex += aes.BlockSize 689 pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)] 690 691 identityHeader := buffer.Extend(aes.BlockSize) 692 xorWords(identityHeader, pskHash, buffer.To(aes.BlockSize)) 693 b, err := c.method.blockConstructor(psk) 694 if err != nil { 695 return 0, err 696 } 697 b.Encrypt(identityHeader, identityHeader) 698 699 if i == pskLen-2 { 700 break 701 } 702 } 703 } 704 common.Must( 705 buffer.WriteByte(HeaderTypeClient), 706 binary.Write(buffer, binary.BigEndian, uint64(c.method.time().Unix())), 707 binary.Write(buffer, binary.BigEndian, uint16(paddingLen)), // padding length 708 ) 709 710 if paddingLen > 0 { 711 buffer.Extend(paddingLen) 712 } 713 714 err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) 715 if err != nil { 716 return 717 } 718 common.Must1(buffer.Write(p)) 719 if c.method.udpCipher != nil { 720 c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) 721 buffer.Extend(shadowio.Overhead) 722 } else { 723 packetHeader := buffer.To(aes.BlockSize) 724 c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) 725 buffer.Extend(shadowio.Overhead) 726 c.method.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader) 727 } 728 err = common.Error(c.writer.Write(buffer.Bytes())) 729 if err != nil { 730 return 731 } 732 return len(p), nil 733 } 734 735 func (c *clientPacketConn) FrontHeadroom() int { 736 var overHead int 737 if c.method.udpCipher != nil { 738 overHead = PacketNonceSize + shadowio.Overhead 739 } else { 740 overHead = shadowio.Overhead 741 } 742 overHead += 16 // packet header 743 pskLen := len(c.method.pskList) 744 if c.method.udpCipher == nil && pskLen > 1 { 745 overHead += (pskLen - 1) * aes.BlockSize 746 } 747 overHead += 1 // header type 748 overHead += 8 // timestamp 749 overHead += 2 // padding length 750 overHead += MaxPaddingLength 751 overHead += M.MaxSocksaddrLength 752 return overHead 753 } 754 755 func (c *clientPacketConn) RearHeadroom() int { 756 return shadowio.Overhead 757 } 758 759 func (c *clientPacketConn) Upstream() any { 760 return c.AbstractConn 761 } 762 763 func (c *clientPacketConn) Close() error { 764 return c.AbstractConn.Close() 765 } 766 767 var _ shadowio.WaitReadFrom = (*clientWaitPacketConn)(nil) 768 769 type clientWaitPacketConn struct { 770 *clientPacketConn 771 waitRead shadowio.WaitRead 772 } 773 774 func (c *clientWaitPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { 775 data, put, err = c.waitRead.WaitRead() 776 if err != nil { 777 return 778 } 779 if len(data) <= 0 { 780 err = C.ErrPacketTooShort 781 return 782 } 783 buffer := buf.As(data) 784 var destination M.Socksaddr 785 destination, err = c.readPacket(buffer) 786 if err != nil { 787 if put != nil { 788 put() 789 } 790 put = nil 791 data = nil 792 return 793 } 794 if destination.IsFqdn() { 795 addr = destination 796 } else { 797 addr = destination.UDPAddr() 798 } 799 data = buffer.Bytes() 800 return 801 }