github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/service/udp_nat.go (about) 1 package service 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "net" 8 "net/netip" 9 "os" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/database64128/shadowsocks-go/conn" 15 "github.com/database64128/shadowsocks-go/router" 16 "github.com/database64128/shadowsocks-go/stats" 17 "github.com/database64128/shadowsocks-go/zerocopy" 18 "go.uber.org/zap" 19 ) 20 21 // natQueuedPacket is the structure used by send channels to queue packets for sending. 22 type natQueuedPacket struct { 23 buf []byte 24 start int 25 length int 26 targetAddr conn.Addr 27 } 28 29 // natEntry is an entry in the NAT table. 30 type natEntry struct { 31 // state synchronizes session initialization and shutdown. 32 // 33 // - Swap the natConn in to signal initialization completion. 34 // - Swap the serverConn in to signal shutdown. 35 // 36 // Callers must check the swapped-out value to determine the next action. 37 // 38 // - During initialization, if the swapped-out value is non-nil, 39 // initialization must not proceed. 40 // - During shutdown, if the swapped-out value is nil, preceed to the next entry. 41 state atomic.Pointer[net.UDPConn] 42 clientPktinfo atomic.Pointer[[]byte] 43 clientPktinfoCache []byte 44 natConnSendCh chan<- *natQueuedPacket 45 serverConn *net.UDPConn 46 serverConnUnpacker zerocopy.ServerUnpacker 47 logger *zap.Logger 48 } 49 50 // natUplinkGeneric is used for passing information about relay uplink to the relay goroutine. 51 type natUplinkGeneric struct { 52 clientName string 53 clientAddrPort netip.AddrPort 54 natConn *net.UDPConn 55 natConnSendCh <-chan *natQueuedPacket 56 natConnPacker zerocopy.ClientPacker 57 natTimeout time.Duration 58 logger *zap.Logger 59 } 60 61 // natDownlinkGeneric is used for passing information about relay downlink to the relay goroutine. 62 type natDownlinkGeneric struct { 63 clientName string 64 clientAddrPort netip.AddrPort 65 clientPktinfo *atomic.Pointer[[]byte] 66 natConn *net.UDPConn 67 natConnRecvBufSize int 68 natConnUnpacker zerocopy.ClientUnpacker 69 serverConn *net.UDPConn 70 serverConnPacker zerocopy.ServerPacker 71 logger *zap.Logger 72 } 73 74 // UDPNATRelay is an address-based UDP relay service. 75 // 76 // Incoming UDP packets are dispatched to NAT sessions based on the source address and port. 77 type UDPNATRelay struct { 78 serverName string 79 serverIndex int 80 mtu int 81 packetBufFrontHeadroom int 82 packetBufRecvSize int 83 listeners []udpRelayServerConn 84 server zerocopy.UDPNATServer 85 collector stats.Collector 86 router *router.Router 87 logger *zap.Logger 88 queuedPacketPool sync.Pool 89 mu sync.Mutex 90 wg sync.WaitGroup 91 mwg sync.WaitGroup 92 table map[netip.AddrPort]*natEntry 93 } 94 95 func NewUDPNATRelay( 96 serverName string, 97 serverIndex, mtu, packetBufFrontHeadroom, packetBufRecvSize, packetBufSize int, 98 listeners []udpRelayServerConn, 99 server zerocopy.UDPNATServer, 100 collector stats.Collector, 101 router *router.Router, 102 logger *zap.Logger, 103 ) *UDPNATRelay { 104 return &UDPNATRelay{ 105 serverName: serverName, 106 serverIndex: serverIndex, 107 mtu: mtu, 108 packetBufFrontHeadroom: packetBufFrontHeadroom, 109 packetBufRecvSize: packetBufRecvSize, 110 listeners: listeners, 111 server: server, 112 collector: collector, 113 router: router, 114 logger: logger, 115 queuedPacketPool: sync.Pool{ 116 New: func() any { 117 return &natQueuedPacket{ 118 buf: make([]byte, packetBufSize), 119 } 120 }, 121 }, 122 table: make(map[netip.AddrPort]*natEntry), 123 } 124 } 125 126 // String implements the Service String method. 127 func (s *UDPNATRelay) String() string { 128 return "UDP NAT relay service for " + s.serverName 129 } 130 131 // Start implements the Service Start method. 132 func (s *UDPNATRelay) Start(ctx context.Context) error { 133 for i := range s.listeners { 134 if err := s.start(ctx, i, &s.listeners[i]); err != nil { 135 return err 136 } 137 } 138 return nil 139 } 140 141 func (s *UDPNATRelay) startGeneric(ctx context.Context, index int, lnc *udpRelayServerConn) (err error) { 142 lnc.serverConn, err = lnc.listenConfig.ListenUDP(ctx, lnc.network, lnc.address) 143 if err != nil { 144 return 145 } 146 lnc.address = lnc.serverConn.LocalAddr().String() 147 lnc.logger = s.logger.With( 148 zap.String("server", s.serverName), 149 zap.Int("listener", index), 150 zap.String("listenAddress", lnc.address), 151 ) 152 153 s.mwg.Add(1) 154 155 go func() { 156 s.recvFromServerConnGeneric(ctx, lnc) 157 s.mwg.Done() 158 }() 159 160 lnc.logger.Info("Started UDP NAT relay service listener") 161 return 162 } 163 164 func (s *UDPNATRelay) recvFromServerConnGeneric(ctx context.Context, lnc *udpRelayServerConn) { 165 cmsgBuf := make([]byte, conn.SocketControlMessageBufferSize) 166 167 var ( 168 packetsReceived uint64 169 payloadBytesReceived uint64 170 ) 171 172 for { 173 queuedPacket := s.getQueuedPacket() 174 packetBuf := queuedPacket.buf 175 recvBuf := packetBuf[s.packetBufFrontHeadroom : s.packetBufFrontHeadroom+s.packetBufRecvSize] 176 177 n, cmsgn, flags, clientAddrPort, err := lnc.serverConn.ReadMsgUDPAddrPort(recvBuf, cmsgBuf) 178 if err != nil { 179 if errors.Is(err, os.ErrDeadlineExceeded) { 180 s.putQueuedPacket(queuedPacket) 181 break 182 } 183 184 lnc.logger.Warn("Failed to read packet from serverConn", 185 zap.Stringer("clientAddress", clientAddrPort), 186 zap.Int("packetLength", n), 187 zap.Error(err), 188 ) 189 190 s.putQueuedPacket(queuedPacket) 191 continue 192 } 193 err = conn.ParseFlagsForError(flags) 194 if err != nil { 195 lnc.logger.Warn("Failed to read packet from serverConn", 196 zap.Stringer("clientAddress", clientAddrPort), 197 zap.Int("packetLength", n), 198 zap.Error(err), 199 ) 200 201 s.putQueuedPacket(queuedPacket) 202 continue 203 } 204 205 s.mu.Lock() 206 207 entry, ok := s.table[clientAddrPort] 208 if !ok { 209 entry = &natEntry{ 210 serverConn: lnc.serverConn, 211 logger: lnc.logger, 212 } 213 214 entry.serverConnUnpacker, err = s.server.NewUnpacker() 215 if err != nil { 216 lnc.logger.Warn("Failed to create unpacker for serverConn", 217 zap.Stringer("clientAddress", clientAddrPort), 218 zap.Error(err), 219 ) 220 221 s.putQueuedPacket(queuedPacket) 222 s.mu.Unlock() 223 continue 224 } 225 } 226 227 queuedPacket.targetAddr, queuedPacket.start, queuedPacket.length, err = entry.serverConnUnpacker.UnpackInPlace(packetBuf, clientAddrPort, s.packetBufFrontHeadroom, n) 228 if err != nil { 229 lnc.logger.Warn("Failed to unpack packet from serverConn", 230 zap.Stringer("clientAddress", clientAddrPort), 231 zap.Int("packetLength", n), 232 zap.Error(err), 233 ) 234 235 s.putQueuedPacket(queuedPacket) 236 s.mu.Unlock() 237 continue 238 } 239 240 packetsReceived++ 241 payloadBytesReceived += uint64(queuedPacket.length) 242 243 cmsg := cmsgBuf[:cmsgn] 244 245 if !bytes.Equal(entry.clientPktinfoCache, cmsg) { 246 clientPktinfoAddr, clientPktinfoIfindex, err := conn.ParsePktinfoCmsg(cmsg) 247 if err != nil { 248 lnc.logger.Warn("Failed to parse pktinfo control message from serverConn", 249 zap.Stringer("clientAddress", clientAddrPort), 250 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 251 zap.Error(err), 252 ) 253 254 s.putQueuedPacket(queuedPacket) 255 s.mu.Unlock() 256 continue 257 } 258 259 clientPktinfoCache := make([]byte, len(cmsg)) 260 copy(clientPktinfoCache, cmsg) 261 entry.clientPktinfo.Store(&clientPktinfoCache) 262 entry.clientPktinfoCache = clientPktinfoCache 263 264 if ce := lnc.logger.Check(zap.DebugLevel, "Updated client pktinfo"); ce != nil { 265 ce.Write( 266 zap.String("server", s.serverName), 267 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 268 zap.Stringer("clientPktinfoAddr", clientPktinfoAddr), 269 zap.Uint32("clientPktinfoIfindex", clientPktinfoIfindex), 270 ) 271 } 272 } 273 274 if !ok { 275 natConnSendCh := make(chan *natQueuedPacket, lnc.sendChannelCapacity) 276 entry.natConnSendCh = natConnSendCh 277 s.table[clientAddrPort] = entry 278 s.wg.Add(1) 279 280 go func() { 281 var sendChClean bool 282 283 defer func() { 284 s.mu.Lock() 285 close(natConnSendCh) 286 delete(s.table, clientAddrPort) 287 s.mu.Unlock() 288 289 if !sendChClean { 290 for queuedPacket := range natConnSendCh { 291 s.putQueuedPacket(queuedPacket) 292 } 293 } 294 295 s.wg.Done() 296 }() 297 298 c, err := s.router.GetUDPClient(ctx, router.RequestInfo{ 299 ServerIndex: s.serverIndex, 300 SourceAddrPort: clientAddrPort, 301 TargetAddr: queuedPacket.targetAddr, 302 }) 303 if err != nil { 304 lnc.logger.Warn("Failed to get UDP client for new NAT session", 305 zap.Stringer("clientAddress", clientAddrPort), 306 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 307 zap.Error(err), 308 ) 309 return 310 } 311 312 clientInfo, clientSession, err := c.NewSession(ctx) 313 if err != nil { 314 lnc.logger.Warn("Failed to create new UDP client session", 315 zap.Stringer("clientAddress", clientAddrPort), 316 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 317 zap.String("client", clientInfo.Name), 318 zap.Error(err), 319 ) 320 return 321 } 322 323 natConn, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "") 324 if err != nil { 325 lnc.logger.Warn("Failed to create UDP socket for new NAT session", 326 zap.Stringer("clientAddress", clientAddrPort), 327 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 328 zap.String("client", clientInfo.Name), 329 zap.Error(err), 330 ) 331 clientSession.Close() 332 return 333 } 334 335 err = natConn.SetReadDeadline(time.Now().Add(lnc.natTimeout)) 336 if err != nil { 337 lnc.logger.Warn("Failed to set read deadline on natConn", 338 zap.Stringer("clientAddress", clientAddrPort), 339 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 340 zap.String("client", clientInfo.Name), 341 zap.Duration("natTimeout", lnc.natTimeout), 342 zap.Error(err), 343 ) 344 natConn.Close() 345 clientSession.Close() 346 return 347 } 348 349 serverConnPacker, err := entry.serverConnUnpacker.NewPacker() 350 if err != nil { 351 lnc.logger.Warn("Failed to create packer for serverConn", 352 zap.Stringer("clientAddress", clientAddrPort), 353 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 354 zap.Error(err), 355 ) 356 natConn.Close() 357 clientSession.Close() 358 return 359 } 360 361 oldState := entry.state.Swap(natConn) 362 if oldState != nil { 363 natConn.Close() 364 clientSession.Close() 365 return 366 } 367 368 // No more early returns! 369 sendChClean = true 370 371 lnc.logger.Info("UDP NAT relay started", 372 zap.Stringer("clientAddress", clientAddrPort), 373 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 374 zap.String("client", clientInfo.Name), 375 ) 376 377 s.wg.Add(1) 378 379 go func() { 380 s.relayServerConnToNatConnGeneric(ctx, natUplinkGeneric{ 381 clientName: clientInfo.Name, 382 clientAddrPort: clientAddrPort, 383 natConn: natConn, 384 natConnSendCh: natConnSendCh, 385 natConnPacker: clientSession.Packer, 386 natTimeout: lnc.natTimeout, 387 logger: lnc.logger, 388 }) 389 natConn.Close() 390 clientSession.Close() 391 s.wg.Done() 392 }() 393 394 s.relayNatConnToServerConnGeneric(natDownlinkGeneric{ 395 clientName: clientInfo.Name, 396 clientAddrPort: clientAddrPort, 397 clientPktinfo: &entry.clientPktinfo, 398 natConn: natConn, 399 natConnRecvBufSize: clientSession.MaxPacketSize, 400 natConnUnpacker: clientSession.Unpacker, 401 serverConn: lnc.serverConn, 402 serverConnPacker: serverConnPacker, 403 logger: lnc.logger, 404 }) 405 }() 406 407 if ce := lnc.logger.Check(zap.DebugLevel, "New UDP NAT session"); ce != nil { 408 ce.Write( 409 zap.Stringer("clientAddress", clientAddrPort), 410 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 411 ) 412 } 413 } 414 415 select { 416 case entry.natConnSendCh <- queuedPacket: 417 default: 418 if ce := lnc.logger.Check(zap.DebugLevel, "Dropping packet due to full send channel"); ce != nil { 419 ce.Write( 420 zap.Stringer("clientAddress", clientAddrPort), 421 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 422 ) 423 } 424 425 s.putQueuedPacket(queuedPacket) 426 } 427 428 s.mu.Unlock() 429 } 430 431 lnc.logger.Info("Finished receiving from serverConn", 432 zap.Uint64("packetsReceived", packetsReceived), 433 zap.Uint64("payloadBytesReceived", payloadBytesReceived), 434 ) 435 } 436 437 func (s *UDPNATRelay) relayServerConnToNatConnGeneric(ctx context.Context, uplink natUplinkGeneric) { 438 var ( 439 destAddrPort netip.AddrPort 440 packetStart int 441 packetLength int 442 err error 443 packetsSent uint64 444 payloadBytesSent uint64 445 ) 446 447 for queuedPacket := range uplink.natConnSendCh { 448 destAddrPort, packetStart, packetLength, err = uplink.natConnPacker.PackInPlace(ctx, queuedPacket.buf, queuedPacket.targetAddr, queuedPacket.start, queuedPacket.length) 449 if err != nil { 450 uplink.logger.Warn("Failed to pack packet for natConn", 451 zap.Stringer("clientAddress", uplink.clientAddrPort), 452 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 453 zap.String("client", uplink.clientName), 454 zap.Int("payloadLength", queuedPacket.length), 455 zap.Error(err), 456 ) 457 458 s.putQueuedPacket(queuedPacket) 459 continue 460 } 461 462 _, err = uplink.natConn.WriteToUDPAddrPort(queuedPacket.buf[packetStart:packetStart+packetLength], destAddrPort) 463 if err != nil { 464 uplink.logger.Warn("Failed to write packet to natConn", 465 zap.Stringer("clientAddress", uplink.clientAddrPort), 466 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 467 zap.String("client", uplink.clientName), 468 zap.Stringer("writeDestAddress", destAddrPort), 469 zap.Int("packetLength", packetLength), 470 zap.Error(err), 471 ) 472 } 473 474 err = uplink.natConn.SetReadDeadline(time.Now().Add(uplink.natTimeout)) 475 if err != nil { 476 uplink.logger.Warn("Failed to set read deadline on natConn", 477 zap.Stringer("clientAddress", uplink.clientAddrPort), 478 zap.Stringer("targetAddress", &queuedPacket.targetAddr), 479 zap.String("client", uplink.clientName), 480 zap.Stringer("writeDestAddress", destAddrPort), 481 zap.Duration("natTimeout", uplink.natTimeout), 482 zap.Error(err), 483 ) 484 } 485 486 s.putQueuedPacket(queuedPacket) 487 packetsSent++ 488 payloadBytesSent += uint64(queuedPacket.length) 489 } 490 491 uplink.logger.Info("Finished relay serverConn -> natConn", 492 zap.Stringer("clientAddress", uplink.clientAddrPort), 493 zap.String("client", uplink.clientName), 494 zap.Stringer("lastWriteDestAddress", destAddrPort), 495 zap.Uint64("packetsSent", packetsSent), 496 zap.Uint64("payloadBytesSent", payloadBytesSent), 497 ) 498 499 s.collector.CollectUDPSessionUplink("", packetsSent, payloadBytesSent) 500 } 501 502 func (s *UDPNATRelay) relayNatConnToServerConnGeneric(downlink natDownlinkGeneric) { 503 maxClientPacketSize := zerocopy.MaxPacketSizeForAddr(s.mtu, downlink.clientAddrPort.Addr()) 504 505 serverConnPackerInfo := downlink.serverConnPacker.ServerPackerInfo() 506 natConnUnpackerInfo := downlink.natConnUnpacker.ClientUnpackerInfo() 507 headroom := zerocopy.UDPRelayHeadroom(serverConnPackerInfo.Headroom, natConnUnpackerInfo.Headroom) 508 509 var ( 510 clientPktinfo []byte 511 clientPktinfop *[]byte 512 packetsSent uint64 513 payloadBytesSent uint64 514 ) 515 516 packetBuf := make([]byte, headroom.Front+downlink.natConnRecvBufSize+headroom.Rear) 517 recvBuf := packetBuf[headroom.Front : headroom.Front+downlink.natConnRecvBufSize] 518 519 for { 520 n, _, flags, packetSourceAddrPort, err := downlink.natConn.ReadMsgUDPAddrPort(recvBuf, nil) 521 if err != nil { 522 if errors.Is(err, os.ErrDeadlineExceeded) { 523 break 524 } 525 526 downlink.logger.Warn("Failed to read packet from natConn", 527 zap.Stringer("clientAddress", downlink.clientAddrPort), 528 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 529 zap.String("client", downlink.clientName), 530 zap.Int("packetLength", n), 531 zap.Error(err), 532 ) 533 continue 534 } 535 err = conn.ParseFlagsForError(flags) 536 if err != nil { 537 downlink.logger.Warn("Failed to read packet from natConn", 538 zap.Stringer("clientAddress", downlink.clientAddrPort), 539 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 540 zap.String("client", downlink.clientName), 541 zap.Int("packetLength", n), 542 zap.Error(err), 543 ) 544 continue 545 } 546 547 payloadSourceAddrPort, payloadStart, payloadLength, err := downlink.natConnUnpacker.UnpackInPlace(packetBuf, packetSourceAddrPort, headroom.Front, n) 548 if err != nil { 549 downlink.logger.Warn("Failed to unpack packet from natConn", 550 zap.Stringer("clientAddress", downlink.clientAddrPort), 551 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 552 zap.String("client", downlink.clientName), 553 zap.Int("packetLength", n), 554 zap.Error(err), 555 ) 556 continue 557 } 558 559 packetStart, packetLength, err := downlink.serverConnPacker.PackInPlace(packetBuf, payloadSourceAddrPort, payloadStart, payloadLength, maxClientPacketSize) 560 if err != nil { 561 downlink.logger.Warn("Failed to pack packet for serverConn", 562 zap.Stringer("clientAddress", downlink.clientAddrPort), 563 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 564 zap.String("client", downlink.clientName), 565 zap.Stringer("payloadSourceAddress", payloadSourceAddrPort), 566 zap.Int("payloadLength", payloadLength), 567 zap.Int("maxClientPacketSize", maxClientPacketSize), 568 zap.Error(err), 569 ) 570 continue 571 } 572 573 if cpp := downlink.clientPktinfo.Load(); cpp != clientPktinfop { 574 clientPktinfo = *cpp 575 clientPktinfop = cpp 576 } 577 578 _, _, err = downlink.serverConn.WriteMsgUDPAddrPort(packetBuf[packetStart:packetStart+packetLength], clientPktinfo, downlink.clientAddrPort) 579 if err != nil { 580 downlink.logger.Warn("Failed to write packet to serverConn", 581 zap.Stringer("clientAddress", downlink.clientAddrPort), 582 zap.Stringer("packetSourceAddress", packetSourceAddrPort), 583 zap.String("client", downlink.clientName), 584 zap.Stringer("payloadSourceAddress", payloadSourceAddrPort), 585 zap.Int("packetLength", packetLength), 586 zap.Error(err), 587 ) 588 } 589 590 packetsSent++ 591 payloadBytesSent += uint64(payloadLength) 592 } 593 594 downlink.logger.Info("Finished relay serverConn <- natConn", 595 zap.Stringer("clientAddress", downlink.clientAddrPort), 596 zap.String("client", downlink.clientName), 597 zap.Uint64("packetsSent", packetsSent), 598 zap.Uint64("payloadBytesSent", payloadBytesSent), 599 ) 600 601 s.collector.CollectUDPSessionDownlink("", packetsSent, payloadBytesSent) 602 } 603 604 // getQueuedPacket retrieves a queued packet from the pool. 605 func (s *UDPNATRelay) getQueuedPacket() *natQueuedPacket { 606 return s.queuedPacketPool.Get().(*natQueuedPacket) 607 } 608 609 // putQueuedPacket puts the queued packet back into the pool. 610 func (s *UDPNATRelay) putQueuedPacket(queuedPacket *natQueuedPacket) { 611 s.queuedPacketPool.Put(queuedPacket) 612 } 613 614 // Stop implements the Service Stop method. 615 func (s *UDPNATRelay) Stop() error { 616 for i := range s.listeners { 617 lnc := &s.listeners[i] 618 if err := lnc.serverConn.SetReadDeadline(conn.ALongTimeAgo); err != nil { 619 lnc.logger.Warn("Failed to set read deadline on serverConn", zap.Error(err)) 620 } 621 } 622 623 // Wait for serverConn receive goroutines to exit, 624 // so there won't be any new sessions added to the table. 625 s.mwg.Wait() 626 627 s.mu.Lock() 628 for clientAddrPort, entry := range s.table { 629 natConn := entry.state.Swap(entry.serverConn) 630 if natConn == nil { 631 continue 632 } 633 634 if err := natConn.SetReadDeadline(conn.ALongTimeAgo); err != nil { 635 entry.logger.Warn("Failed to set read deadline on natConn", 636 zap.Stringer("clientAddress", clientAddrPort), 637 zap.Error(err), 638 ) 639 } 640 } 641 s.mu.Unlock() 642 643 // Wait for all relay goroutines to exit before closing serverConn, 644 // so in-flight packets can be written out. 645 s.wg.Wait() 646 647 for i := range s.listeners { 648 lnc := &s.listeners[i] 649 if err := lnc.serverConn.Close(); err != nil { 650 lnc.logger.Warn("Failed to close serverConn", zap.Error(err)) 651 } 652 } 653 654 return nil 655 }