// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package udp

import (
	"bytes"
	"fmt"
	"io"
	"math"
	"time"

	"gvisor.dev/gvisor/pkg/sync"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/checksum"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/ports"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport"
	"gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
	"gvisor.dev/gvisor/pkg/waiter"
)

// udpPacket is one received datagram queued on an endpoint's receive list,
// together with the metadata HandlePacket records for it so that Read can
// later build control messages.
//
// +stateify savable
type udpPacket struct {
	udpPacketEntry
	netProto           tcpip.NetworkProtocolNumber
	senderAddress      tcpip.FullAddress
	destinationAddress tcpip.FullAddress
	packetInfo         tcpip.IPPacketInfo
	// pkt holds a reference taken with IncRef when the packet is queued; it
	// is dropped with DecRef when the packet is read or the queue is drained.
	pkt        *stack.PacketBuffer
	receivedAt time.Time `state:".(int64)"`
	// tosOrTClass stores either the Type of Service for IPv4 or the Traffic Class
	// for IPv6.
	tosOrTClass uint8
	// ttlOrHopLimit stores either the TTL for IPv4 or the HopLimit for IPv6.
	ttlOrHopLimit uint8
}

// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
// synchronized.
//
// It implements tcpip.Endpoint.
//
// +stateify savable
type endpoint struct {
	tcpip.DefaultSocketOptionsHandler

	// The following fields are initialized at creation time and do not
	// change throughout the lifetime of the endpoint.
	stack       *stack.Stack `state:"manual"`
	waiterQueue *waiter.Queue
	uniqueID    uint64
	net         network.Endpoint
	stats       tcpip.TransportEndpointStats
	ops         tcpip.SocketOptions

	// The following fields are used to manage the receive queue, and are
	// protected by rcvMu.
	rcvMu      sync.Mutex `state:"nosave"`
	rcvReady   bool
	rcvList    udpPacketList
	rcvBufSize int
	rcvClosed  bool

	// lastError is the most recent asynchronous error recorded for the
	// endpoint; it is protected by lastErrorMu and cleared by LastError.
	lastErrorMu sync.Mutex `state:"nosave"`
	lastError   tcpip.Error

	// The following fields are protected by the mu mutex.
	mu        sync.RWMutex `state:"nosave"`
	portFlags ports.Flags

	// Values used to reserve a port or register a transport endpoint
	// (whichever happens first).
	boundBindToDevice tcpip.NICID
	boundPortFlags    ports.Flags

	readShutdown bool

	// effectiveNetProtos contains the network protocols actually in use. In
	// most cases it will only contain "netProto", but in cases like IPv6
	// endpoints with v6only set to false, this could include multiple
	// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
	// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
	// address).
	effectiveNetProtos []tcpip.NetworkProtocolNumber

	// frozen indicates if the packets should be delivered to the endpoint
	// during restore.
	//
	// NOTE(review): HandlePacket reads this field while holding rcvMu, not
	// mu; confirm that writers synchronize with that read.
	frozen bool

	localPort  uint16
	remotePort uint16
}

// newEndpoint creates a UDP endpoint for netProto on stack s. The send and
// receive buffer sizes start at 32 KiB and are then replaced by the stack's
// configured defaults when the stack reports them.
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
	e := &endpoint{
		stack:       s,
		waiterQueue: waiterQueue,
		uniqueID:    s.UniqueID(),
	}
	e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
	e.ops.SetMulticastLoop(true)
	e.ops.SetSendBufferSize(32*1024, false /* notify */)
	e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
	e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops, waiterQueue)

	// Override with stack defaults.
	var ss tcpip.SendBufferSizeOption
	if err := s.Option(&ss); err == nil {
		e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
	}

	var rs tcpip.ReceiveBufferSizeOption
	if err := s.Option(&rs); err == nil {
		e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
	}

	return e
}

// WakeupWriters implements tcpip.SocketOptionsHandler.
func (e *endpoint) WakeupWriters() {
	e.net.MaybeSignalWritable()
}

// UniqueID implements stack.TransportEndpoint.
func (e *endpoint) UniqueID() uint64 {
	return e.uniqueID
}

// LastError returns the last error recorded on the endpoint and clears it, so
// each recorded error is reported at most once.
func (e *endpoint) LastError() tcpip.Error {
	e.lastErrorMu.Lock()
	defer e.lastErrorMu.Unlock()

	err := e.lastError
	e.lastError = nil
	return err
}

// UpdateLastError implements tcpip.SocketOptionsHandler.
func (e *endpoint) UpdateLastError(err tcpip.Error) {
	e.lastErrorMu.Lock()
	e.lastError = err
	e.lastErrorMu.Unlock()
}

// Abort implements stack.TransportEndpoint.
func (e *endpoint) Abort() {
	e.Close()
}

// Close puts the endpoint in a closed state and frees all resources
// associated with it.
168 func (e *endpoint) Close() { 169 e.mu.Lock() 170 171 switch state := e.net.State(); state { 172 case transport.DatagramEndpointStateInitial: 173 case transport.DatagramEndpointStateClosed: 174 e.mu.Unlock() 175 return 176 case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: 177 id := e.net.Info().ID 178 id.LocalPort = e.localPort 179 id.RemotePort = e.remotePort 180 e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice) 181 portRes := ports.Reservation{ 182 Networks: e.effectiveNetProtos, 183 Transport: ProtocolNumber, 184 Addr: id.LocalAddress, 185 Port: id.LocalPort, 186 Flags: e.boundPortFlags, 187 BindToDevice: e.boundBindToDevice, 188 Dest: tcpip.FullAddress{}, 189 } 190 e.stack.ReleasePort(portRes) 191 e.boundBindToDevice = 0 192 e.boundPortFlags = ports.Flags{} 193 default: 194 panic(fmt.Sprintf("unhandled state = %s", state)) 195 } 196 197 // Close the receive list and drain it. 198 e.rcvMu.Lock() 199 e.rcvClosed = true 200 e.rcvBufSize = 0 201 for !e.rcvList.Empty() { 202 p := e.rcvList.Front() 203 e.rcvList.Remove(p) 204 p.pkt.DecRef() 205 } 206 e.rcvMu.Unlock() 207 208 e.net.Shutdown() 209 e.net.Close() 210 e.readShutdown = true 211 e.mu.Unlock() 212 213 e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) 214 } 215 216 // ModerateRecvBuf implements tcpip.Endpoint. 217 func (*endpoint) ModerateRecvBuf(int) {} 218 219 // Read implements tcpip.Endpoint. 
func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
	// A pending error is reported (and consumed) before any data.
	if err := e.LastError(); err != nil {
		return tcpip.ReadResult{}, err
	}

	e.rcvMu.Lock()

	if e.rcvList.Empty() {
		var err tcpip.Error = &tcpip.ErrWouldBlock{}
		if e.rcvClosed {
			e.stats.ReadErrors.ReadClosed.Increment()
			err = &tcpip.ErrClosedForReceive{}
		}
		e.rcvMu.Unlock()
		return tcpip.ReadResult{}, err
	}

	// When peeking, the packet stays queued and keeps its reference; otherwise
	// it is dequeued here and its reference is dropped when Read returns.
	p := e.rcvList.Front()
	if !opts.Peek {
		e.rcvList.Remove(p)
		defer p.pkt.DecRef()
		e.rcvBufSize -= p.pkt.Data().Size()
	}
	e.rcvMu.Unlock()

	// Control Messages
	// TODO(https://gvisor.dev/issue/7012): Share control message code with other
	// network endpoints.
	cm := tcpip.ReceivableControlMessages{
		HasTimestamp: true,
		Timestamp:    p.receivedAt,
	}
	switch p.netProto {
	case header.IPv4ProtocolNumber:
		if e.ops.GetReceiveTOS() {
			cm.HasTOS = true
			cm.TOS = p.tosOrTClass
		}
		if e.ops.GetReceiveTTL() {
			cm.HasTTL = true
			cm.TTL = p.ttlOrHopLimit
		}
		if e.ops.GetReceivePacketInfo() {
			cm.HasIPPacketInfo = true
			cm.PacketInfo = p.packetInfo
		}
	case header.IPv6ProtocolNumber:
		if e.ops.GetReceiveTClass() {
			cm.HasTClass = true
			// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
			cm.TClass = uint32(p.tosOrTClass)
		}
		if e.ops.GetReceiveHopLimit() {
			cm.HasHopLimit = true
			cm.HopLimit = p.ttlOrHopLimit
		}
		if e.ops.GetIPv6ReceivePacketInfo() {
			cm.HasIPv6PacketInfo = true
			cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{
				NIC:  p.packetInfo.NIC,
				Addr: p.packetInfo.DestinationAddr,
			}
		}
	default:
		panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto))
	}

	if e.ops.GetReceiveOriginalDstAddress() {
		cm.HasOriginalDstAddress = true
		cm.OriginalDstAddress = p.destinationAddress
	}

	// Read Result
	res := tcpip.ReadResult{
		Total:           p.pkt.Data().Size(),
		ControlMessages: cm,
	}
	if opts.NeedRemoteAddr {
		res.RemoteAddr = p.senderAddress
	}

	// A failed copy is reported as ErrBadBuffer only when nothing was copied;
	// a partial copy is returned as a successful short read.
	n, err := p.pkt.Data().ReadTo(dst, opts.Peek)
	if n == 0 && err != nil {
		return res, &tcpip.ErrBadBuffer{}
	}
	res.Count = n
	return res, nil
}

// prepareForWriteInner prepares the endpoint for sending data. In particular,
// it binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
// The caller must hold e.mu for reading; on every return path the mutex is
// again held for reading.
//
// Returns true for retry if preparation should be retried.
// +checklocksread:e.mu
func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
	switch e.net.State() {
	case transport.DatagramEndpointStateInitial:
		// Fall out of the switch to bind the endpoint below.
	case transport.DatagramEndpointStateConnected:
		return false, nil

	case transport.DatagramEndpointStateBound:
		if to == nil {
			return false, &tcpip.ErrDestinationRequired{}
		}
		return false, nil
	default:
		return false, &tcpip.ErrInvalidEndpointState{}
	}

	// Swap the read lock for the write lock; the deferred downgrade restores
	// the read lock that the caller expects to hold.
	e.mu.RUnlock()
	e.mu.Lock()
	defer e.mu.DowngradeLock()

	// The state changed when we released the shared lock and re-acquired
	// it in exclusive mode. Try again.
	if e.net.State() != transport.DatagramEndpointStateInitial {
		return true, nil
	}

	// The state is still 'initial', so try to bind the endpoint.
	if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
		return false, err
	}

	return true, nil
}

var _ tcpip.EndpointWithPreflight = (*endpoint)(nil)

// Preflight validates the passed WriteOptions and prepares the endpoint for
// writes using those options. If the endpoint is still in its initial state,
// preparation binds it. The write context acquired during preparation is
// released before returning.
func (e *endpoint) Preflight(opts tcpip.WriteOptions) tcpip.Error {
	// An empty reader stands in for the payload, so its length is zero.
	var r bytes.Reader
	udpInfo, err := e.prepareForWrite(&r, opts)
	if err == nil {
		udpInfo.ctx.Release()
	}
	return err
}

// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
//
// The outcome of the write is recorded in the endpoint's statistics, keyed on
// the type of the returned error.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
	n, err := e.write(p, opts)
	switch err.(type) {
	case nil:
		e.stats.PacketsSent.Increment()
	case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
		e.stats.WriteErrors.InvalidArgs.Increment()
	case *tcpip.ErrClosedForSend:
		e.stats.WriteErrors.WriteClosed.Increment()
	case *tcpip.ErrInvalidEndpointState:
		e.stats.WriteErrors.InvalidEndpointState.Increment()
	case *tcpip.ErrHostUnreachable, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
		// Errors indicating any problem with IP routing of the packet.
		e.stats.SendErrors.NoRoute.Increment()
	default:
		// For all other errors when writing to the network layer.
		e.stats.SendErrors.SendToNetworkFailed.Increment()
	}
	return n, err
}

// prepareForWrite readies the endpoint for sending p with opts: it binds the
// endpoint if needed, resolves the destination, acquires a write context and
// rejects payloads larger than header.UDPMaximumPacketSize.
//
// On success the caller owns the returned context and must release it.
func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) {
	e.mu.RLock()
	defer e.mu.RUnlock()

	// Prepare for write.
	for {
		retry, err := e.prepareForWriteInner(opts.To)
		if err != nil {
			return udpPacketInfo{}, err
		}

		if !retry {
			break
		}
	}

	// An explicit destination in opts overrides the connected peer.
	dst, connected := e.net.GetRemoteAddress()
	dst.Port = e.remotePort
	if opts.To != nil {
		if opts.To.Port == 0 {
			// Port 0 is an invalid port to send to.
			return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{}
		}

		dst = *opts.To
	} else if !connected {
		return udpPacketInfo{}, &tcpip.ErrDestinationRequired{}
	}

	ctx, err := e.net.AcquireContextForWrite(opts)
	if err != nil {
		return udpPacketInfo{}, err
	}

	if p.Len() > header.UDPMaximumPacketSize {
		// Native linux behaviour differs for IPv4 and IPv6 packets; IPv4 packet
		// errors aren't reported to the error queue at all.
		if ctx.PacketInfo().NetProto == header.IPv6ProtocolNumber {
			so := e.SocketOptions()
			if so.GetIPv6RecvError() {
				so.QueueLocalErr(
					&tcpip.ErrMessageTooLong{},
					e.net.NetProto(),
					uint32(p.Len()),
					dst,
					nil,
				)
			}
		}
		// The context is not handed to the caller on this path.
		ctx.Release()
		return udpPacketInfo{}, &tcpip.ErrMessageTooLong{}
	}

	return udpPacketInfo{
		ctx:        ctx,
		localPort:  e.localPort,
		remotePort: dst.Port,
	}, nil
}

// write builds a UDP datagram carrying p and hands it to the network layer.
// On success it returns the payload length that was sent.
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
	// Do not hold lock when sending as loopback is synchronous and if the UDP
	// datagram ends up generating an ICMP response then it can result in a
	// deadlock where the ICMP response handling ends up acquiring this endpoint's
	// mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a
	// deadlock if another caller is trying to acquire e.mu in exclusive mode w/
	// e.mu.Lock(), since a pending e.mu.Lock() blocks any new read locks to
	// ensure the lock can eventually be acquired.
	//
	// See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read
	// locking is prohibited.

	if err := e.LastError(); err != nil {
		return 0, err
	}

	udpInfo, err := e.prepareForWrite(p, opts)
	if err != nil {
		return 0, err
	}
	defer udpInfo.ctx.Release()

	dataSz := p.Len()
	pktInfo := udpInfo.ctx.PacketInfo()
	pkt := udpInfo.ctx.TryNewPacketBufferFromPayloader(header.UDPMinimumSize+int(pktInfo.MaxHeaderLength), p)
	if pkt == nil {
		return 0, &tcpip.ErrWouldBlock{}
	}
	defer pkt.DecRef()

	// Initialize the UDP header.
	udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
	pkt.TransportProtocolNumber = ProtocolNumber

	// NOTE(review): prepareForWrite bounds the payload length, not payload plus
	// UDP header, by header.UDPMaximumPacketSize; confirm this conversion
	// cannot truncate.
	length := uint16(pkt.Size())
	udp.Encode(&header.UDPFields{
		SrcPort: udpInfo.localPort,
		DstPort: udpInfo.remotePort,
		Length:  length,
	})

	// Set the checksum field unless TX checksum offload is enabled.
	// On IPv4, UDP checksum is optional, and a zero value indicates the
	// transmitter skipped the checksum generation (RFC768).
	// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
	if pktInfo.RequiresTXTransportChecksum &&
		(!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) {
		xsum := udp.CalculateChecksum(checksum.Combine(
			header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length),
			pkt.Data().Checksum(),
		))
		// As per RFC 768 page 2,
		//
		//   Checksum is the 16-bit one's complement of the one's complement sum of
		//   a pseudo header of information from the IP header, the UDP header, and
		//   the data, padded with zero octets at the end (if necessary) to make a
		//   multiple of two octets.
		//
		//	 The pseudo header conceptually prefixed to the UDP header contains the
		//   source address, the destination address, the protocol, and the UDP
		//   length. This information gives protection against misrouted datagrams.
		//   This checksum procedure is the same as is used in TCP.
		//
		//   If the computed checksum is zero, it is transmitted as all ones (the
		//   equivalent in one's complement arithmetic). An all zero transmitted
		//   checksum value means that the transmitter generated no checksum (for
		//   debugging or for higher level protocols that don't care).
		//
		// To avoid the zero value, we only calculate the one's complement of the
		// one's complement sum if the sum is not all ones.
		if xsum != math.MaxUint16 {
			xsum = ^xsum
		}
		udp.SetChecksum(xsum)
	}
	if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
		e.stack.Stats().UDP.PacketSendErrors.Increment()
		return 0, err
	}

	// Track count of packets sent.
	e.stack.Stats().UDP.PacketsSent.Increment()
	return int64(dataSz), nil
}

// OnReuseAddressSet implements tcpip.SocketOptionsHandler.
func (e *endpoint) OnReuseAddressSet(v bool) {
	e.mu.Lock()
	e.portFlags.MostRecent = v
	e.mu.Unlock()
}

// OnReusePortSet implements tcpip.SocketOptionsHandler.
func (e *endpoint) OnReusePortSet(v bool) {
	e.mu.Lock()
	e.portFlags.LoadBalanced = v
	e.mu.Unlock()
}

// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
	return e.net.SetSockOptInt(opt, v)
}

var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)

// HasNIC implements tcpip.SocketOptionsHandler.
func (e *endpoint) HasNIC(id int32) bool {
	return e.stack.HasNIC(tcpip.NICID(id))
}

// SetSockOpt implements tcpip.Endpoint.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
	return e.net.SetSockOpt(opt)
}

// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
	switch opt {
	case tcpip.ReceiveQueueSizeOption:
		// Report the payload size of the packet at the head of the receive
		// queue, or 0 when the queue is empty.
		v := 0
		e.rcvMu.Lock()
		if !e.rcvList.Empty() {
			p := e.rcvList.Front()
			v = p.pkt.Data().Size()
		}
		e.rcvMu.Unlock()
		return v, nil

	default:
		return e.net.GetSockOptInt(opt)
	}
}

// GetSockOpt implements tcpip.Endpoint.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
	return e.net.GetSockOpt(opt)
}

// udpPacketInfo holds information needed to send a UDP packet.
type udpPacketInfo struct {
	ctx        network.WriteContext
	localPort  uint16
	remotePort uint16
}

// Disconnect implements tcpip.Endpoint.
//
// It is a no-op unless the endpoint is connected. An endpoint that was bound
// before connecting is re-registered under its bound local address and port;
// otherwise its port reservation is released and its ports are reset to zero.
func (e *endpoint) Disconnect() tcpip.Error {
	e.mu.Lock()
	defer e.mu.Unlock()

	if e.net.State() != transport.DatagramEndpointStateConnected {
		return nil
	}
	// id and btd stay zero-valued unless the endpoint is re-registered below.
	var (
		id  stack.TransportEndpointID
		btd tcpip.NICID
	)

	// We change this value below and we need the old value to unregister
	// the endpoint.
	boundPortFlags := e.boundPortFlags

	// Exclude ephemerally bound endpoints.
	info := e.net.Info()
	info.ID.LocalPort = e.localPort
	info.ID.RemotePort = e.remotePort
	if e.net.WasBound() {
		var err tcpip.Error
		id = stack.TransportEndpointID{
			LocalPort:    info.ID.LocalPort,
			LocalAddress: info.ID.LocalAddress,
		}
		id, btd, err = e.registerWithStack(e.effectiveNetProtos, id)
		if err != nil {
			return err
		}
		boundPortFlags = e.boundPortFlags
	} else {
		if info.ID.LocalPort != 0 {
			// Release the ephemeral port.
			portRes := ports.Reservation{
				Networks:     e.effectiveNetProtos,
				Transport:    ProtocolNumber,
				Addr:         info.ID.LocalAddress,
				Port:         info.ID.LocalPort,
				Flags:        boundPortFlags,
				BindToDevice: e.boundBindToDevice,
				Dest:         tcpip.FullAddress{},
			}
			e.stack.ReleasePort(portRes)
			e.boundPortFlags = ports.Flags{}
		}
	}

	// Drop the connected registration, identified by the ID that still
	// carries the remote port.
	e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice)
	e.boundBindToDevice = btd
	e.localPort = id.LocalPort
	e.remotePort = id.RemotePort

	e.net.Disconnect()

	return nil
}

// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
	e.mu.Lock()
	defer e.mu.Unlock()

	err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
		nextID.LocalPort = e.localPort
		nextID.RemotePort = addr.Port

		// Even if we're connected, this endpoint can still be used to send
		// packets on a different network protocol, so we register both when
		// v6only is set to false and this is an ipv6 endpoint.
		netProtos := []tcpip.NetworkProtocolNumber{netProto}
		if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) {
			netProtos = []tcpip.NetworkProtocolNumber{
				header.IPv4ProtocolNumber,
				header.IPv6ProtocolNumber,
			}
		}

		// registerWithStack overwrites e.boundPortFlags, so capture the flags
		// the existing registration was made with first.
		oldPortFlags := e.boundPortFlags

		// Remove the old registration.
		if e.localPort != 0 {
			previousID.LocalPort = e.localPort
			previousID.RemotePort = e.remotePort
			e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice)
		}

		nextID, btd, err := e.registerWithStack(netProtos, nextID)
		if err != nil {
			return err
		}

		e.localPort = nextID.LocalPort
		e.remotePort = nextID.RemotePort
		e.boundBindToDevice = btd
		e.effectiveNetProtos = netProtos
		return nil
	})
	if err != nil {
		return err
	}

	// From here on HandlePacket may queue incoming packets.
	e.rcvMu.Lock()
	e.rcvReady = true
	e.rcvMu.Unlock()
	return nil
}

// ConnectEndpoint is not supported.
func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
	return &tcpip.ErrInvalidEndpointState{}
}

// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
//
// It returns ErrNotConnected for an endpoint in the initial or closed state,
// and also for an endpoint that is bound but not connected; in the bound case
// the requested shutdown has already been applied when the error is returned.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
	e.mu.Lock()
	defer e.mu.Unlock()

	switch state := e.net.State(); state {
	case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
		return &tcpip.ErrNotConnected{}
	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
	default:
		panic(fmt.Sprintf("unhandled state = %s", state))
	}

	if flags&tcpip.ShutdownWrite != 0 {
		if err := e.net.Shutdown(); err != nil {
			return err
		}
	}

	if flags&tcpip.ShutdownRead != 0 {
		e.readShutdown = true

		e.rcvMu.Lock()
		wasClosed := e.rcvClosed
		e.rcvClosed = true
		e.rcvMu.Unlock()

		// Notify readers only on the transition to closed.
		if !wasClosed {
			e.waiterQueue.Notify(waiter.ReadableEvents)
		}
	}

	if e.net.State() == transport.DatagramEndpointStateBound {
		return &tcpip.ErrNotConnected{}
	}
	return nil
}

// Listen is not supported by UDP, it just fails.
745 func (*endpoint) Listen(int) tcpip.Error { 746 return &tcpip.ErrNotSupported{} 747 } 748 749 // Accept is not supported by UDP, it just fails. 750 func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { 751 return nil, nil, &tcpip.ErrNotSupported{} 752 } 753 754 func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) { 755 bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) 756 if e.localPort == 0 { 757 portRes := ports.Reservation{ 758 Networks: netProtos, 759 Transport: ProtocolNumber, 760 Addr: id.LocalAddress, 761 Port: id.LocalPort, 762 Flags: e.portFlags, 763 BindToDevice: bindToDevice, 764 Dest: tcpip.FullAddress{}, 765 } 766 port, err := e.stack.ReservePort(e.stack.SecureRNG(), portRes, nil /* testPort */) 767 if err != nil { 768 return id, bindToDevice, err 769 } 770 id.LocalPort = port 771 } 772 e.boundPortFlags = e.portFlags 773 774 err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) 775 if err != nil { 776 portRes := ports.Reservation{ 777 Networks: netProtos, 778 Transport: ProtocolNumber, 779 Addr: id.LocalAddress, 780 Port: id.LocalPort, 781 Flags: e.boundPortFlags, 782 BindToDevice: bindToDevice, 783 Dest: tcpip.FullAddress{}, 784 } 785 e.stack.ReleasePort(portRes) 786 e.boundPortFlags = ports.Flags{} 787 } 788 return id, bindToDevice, err 789 } 790 791 func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { 792 // Don't allow binding once endpoint is not in the initial state 793 // anymore. 
794 if e.net.State() != transport.DatagramEndpointStateInitial { 795 return &tcpip.ErrInvalidEndpointState{} 796 } 797 798 err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error { 799 // Expand netProtos to include v4 and v6 if the caller is binding to a 800 // wildcard (empty) address, and this is an IPv6 endpoint with v6only 801 // set to false. 802 netProtos := []tcpip.NetworkProtocolNumber{boundNetProto} 803 if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == (tcpip.Address{}) && e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) { 804 netProtos = []tcpip.NetworkProtocolNumber{ 805 header.IPv6ProtocolNumber, 806 header.IPv4ProtocolNumber, 807 } 808 } 809 810 id := stack.TransportEndpointID{ 811 LocalPort: addr.Port, 812 LocalAddress: boundAddr, 813 } 814 id, btd, err := e.registerWithStack(netProtos, id) 815 if err != nil { 816 return err 817 } 818 819 e.localPort = id.LocalPort 820 e.boundBindToDevice = btd 821 e.effectiveNetProtos = netProtos 822 return nil 823 }) 824 if err != nil { 825 return err 826 } 827 828 e.rcvMu.Lock() 829 e.rcvReady = true 830 e.rcvMu.Unlock() 831 return nil 832 } 833 834 // Bind binds the endpoint to a specific local address and port. 835 // Specifying a NIC is optional. 836 func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { 837 e.mu.Lock() 838 defer e.mu.Unlock() 839 840 err := e.bindLocked(addr) 841 if err != nil { 842 return err 843 } 844 845 return nil 846 } 847 848 // GetLocalAddress returns the address to which the endpoint is bound. 849 func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { 850 e.mu.RLock() 851 defer e.mu.RUnlock() 852 853 addr := e.net.GetLocalAddress() 854 addr.Port = e.localPort 855 return addr, nil 856 } 857 858 // GetRemoteAddress returns the address to which the endpoint is connected. 
func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
	e.mu.RLock()
	defer e.mu.RUnlock()

	// A zero remote port is treated as not connected.
	addr, connected := e.net.GetRemoteAddress()
	if !connected || e.remotePort == 0 {
		return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
	}

	addr.Port = e.remotePort
	return addr, nil
}

// Readiness returns the current readiness of the endpoint. For example, if
// waiter.EventIn is set, the endpoint is immediately readable.
//
// waiter.EventErr is reported whenever an error is pending, regardless of
// mask.
func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
	var result waiter.EventMask

	if e.net.HasSendSpace() {
		result |= waiter.WritableEvents & mask
	}

	// Determine if the endpoint is readable if requested.
	if mask&waiter.ReadableEvents != 0 {
		e.rcvMu.Lock()
		if !e.rcvList.Empty() || e.rcvClosed {
			result |= waiter.ReadableEvents
		}
		e.rcvMu.Unlock()
	}

	e.lastErrorMu.Lock()
	hasError := e.lastError != nil
	e.lastErrorMu.Unlock()
	if hasError {
		result |= waiter.EventErr
	}
	return result
}

// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
//
// Packets with an invalid length or checksum are counted and dropped, as are
// packets that arrive while the endpoint is not ready, closed for receive,
// frozen, or has a full receive buffer. Accepted packets are queued with an
// extra reference, and waiters are notified when the receive buffer was empty
// before the packet was queued.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
	// Get the header then trim it from the view.
	hdr := header.UDP(pkt.TransportHeader().Slice())
	netHdr := pkt.Network()
	lengthValid, csumValid := header.UDPValid(
		hdr,
		func() uint16 { return pkt.Data().Checksum() },
		uint16(pkt.Data().Size()),
		pkt.NetworkProtocolNumber,
		netHdr.SourceAddress(),
		netHdr.DestinationAddress(),
		pkt.RXChecksumValidated)
	if !lengthValid {
		// Malformed packet.
		e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
		e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
		return
	}

	if !csumValid {
		e.stack.Stats().UDP.ChecksumErrors.Increment()
		e.stats.ReceiveErrors.ChecksumErrors.Increment()
		return
	}

	e.stack.Stats().UDP.PacketsReceived.Increment()
	e.stats.PacketsReceived.Increment()

	e.rcvMu.Lock()
	// Drop the packet if our buffer is not ready to receive packets.
	if !e.rcvReady || e.rcvClosed {
		e.rcvMu.Unlock()
		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
		e.stats.ReceiveErrors.ClosedReceiver.Increment()
		return
	}

	rcvBufSize := e.ops.GetReceiveBufferSize()
	// Drop the packet if our buffer is currently full.
	//
	// NOTE(review): e.frozen is read here under rcvMu only, while the struct
	// lists it among the fields protected by mu; confirm writers synchronize
	// with this read.
	if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
		e.rcvMu.Unlock()
		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
		e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
		return
	}

	wasEmpty := e.rcvBufSize == 0

	// Push new packet into receive list and increment the buffer size.
	packet := &udpPacket{
		netProto: pkt.NetworkProtocolNumber,
		senderAddress: tcpip.FullAddress{
			NIC:  pkt.NICID,
			Addr: id.RemoteAddress,
			Port: hdr.SourcePort(),
		},
		destinationAddress: tcpip.FullAddress{
			NIC:  pkt.NICID,
			Addr: id.LocalAddress,
			Port: hdr.DestinationPort(),
		},
		pkt: pkt.IncRef(),
	}
	e.rcvList.PushBack(packet)
	e.rcvBufSize += pkt.Data().Size()

	// Save any useful information from the network header to the packet.
	packet.tosOrTClass, _ = pkt.Network().TOS()
	switch pkt.NetworkProtocolNumber {
	case header.IPv4ProtocolNumber:
		packet.ttlOrHopLimit = header.IPv4(pkt.NetworkHeader().Slice()).TTL()
	case header.IPv6ProtocolNumber:
		packet.ttlOrHopLimit = header.IPv6(pkt.NetworkHeader().Slice()).HopLimit()
	}

	// TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast
	// address. packetInfo.LocalAddr should hold a unicast address that can be
	// used to respond to the incoming packet.
	localAddr := pkt.Network().DestinationAddress()
	packet.packetInfo.LocalAddr = localAddr
	packet.packetInfo.DestinationAddr = localAddr
	packet.packetInfo.NIC = pkt.NICID
	packet.receivedAt = e.stack.Clock().Now()

	e.rcvMu.Unlock()

	// Notify any waiters that there's data to be read now.
	if wasEmpty {
		e.waiterQueue.Notify(waiter.ReadableEvents)
	}
}

// onICMPError records err as the endpoint's last error, queues a socket error
// when error reception is enabled for the packet's network protocol, and then
// notifies waiters of the error.
func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, pkt *stack.PacketBuffer) {
	// Update last error first.
	e.lastErrorMu.Lock()
	e.lastError = err
	e.lastErrorMu.Unlock()

	var recvErr bool
	switch pkt.NetworkProtocolNumber {
	case header.IPv4ProtocolNumber:
		recvErr = e.SocketOptions().GetIPv4RecvError()
	case header.IPv6ProtocolNumber:
		recvErr = e.SocketOptions().GetIPv6RecvError()
	default:
		panic(fmt.Sprintf("unhandled network protocol number = %d", pkt.NetworkProtocolNumber))
	}

	if recvErr {
		// Linux passes the payload without the UDP header.
		payload := pkt.Data().AsRange().ToView()
		udp := header.UDP(payload.AsSlice())
		if len(udp) >= header.UDPMinimumSize {
			payload.TrimFront(header.UDPMinimumSize)
		}

		id := e.net.Info().ID
		// e.mu is held for reading while the ports are read for the queued
		// error.
		e.mu.RLock()
		e.SocketOptions().QueueErr(&tcpip.SockError{
			Err:     err,
			Cause:   transErr,
			Payload: payload,
			Dst: tcpip.FullAddress{
				NIC:  pkt.NICID,
				Addr: id.RemoteAddress,
				Port: e.remotePort,
			},
			Offender: tcpip.FullAddress{
				NIC:  pkt.NICID,
				Addr: id.LocalAddress,
				Port: e.localPort,
			},
			NetProto: pkt.NetworkProtocolNumber,
		})
		e.mu.RUnlock()
	}

	// Notify of the error.
	e.waiterQueue.Notify(waiter.EventErr)
}

// HandleError implements stack.TransportEndpoint.
//
// Only destination-port-unreachable errors are acted on, and only while the
// endpoint is connected; they surface as ErrConnectionRefused.
func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketBuffer) {
	// TODO(gvisor.dev/issues/5270): Handle all transport errors.
	switch transErr.Kind() {
	case stack.DestinationPortUnreachableTransportError:
		if e.net.State() == transport.DatagramEndpointStateConnected {
			e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt)
		}
	}
}

// State implements tcpip.Endpoint.
func (e *endpoint) State() uint32 {
	return uint32(e.net.State())
}

// Info returns a copy of the endpoint info, with the UDP local and remote
// ports filled in.
func (e *endpoint) Info() tcpip.EndpointInfo {
	e.mu.RLock()
	defer e.mu.RUnlock()
	info := e.net.Info()
	info.ID.LocalPort = e.localPort
	info.ID.RemotePort = e.remotePort
	return &info
}

// Stats returns a pointer to the endpoint stats.
func (e *endpoint) Stats() tcpip.EndpointStats {
	return &e.stats
}

// Wait implements tcpip.Endpoint. It is a no-op for UDP.
func (*endpoint) Wait() {}

// SetOwner implements tcpip.Endpoint.
1077 func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { 1078 e.net.SetOwner(owner) 1079 } 1080 1081 // SocketOptions implements tcpip.Endpoint. 1082 func (e *endpoint) SocketOptions() *tcpip.SocketOptions { 1083 return &e.ops 1084 } 1085 1086 // freeze prevents any more packets from being delivered to the endpoint. 1087 func (e *endpoint) freeze() { 1088 e.mu.Lock() 1089 e.frozen = true 1090 e.mu.Unlock() 1091 } 1092 1093 // thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows 1094 // new packets to be delivered again. 1095 func (e *endpoint) thaw() { 1096 e.mu.Lock() 1097 e.frozen = false 1098 e.mu.Unlock() 1099 }