github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/transport/udp/endpoint.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package udp 16 17 import ( 18 "bytes" 19 "fmt" 20 "io" 21 "math" 22 "time" 23 24 "github.com/nicocha30/gvisor-ligolo/pkg/buffer" 25 "github.com/nicocha30/gvisor-ligolo/pkg/sync" 26 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip" 27 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/checksum" 28 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/header" 29 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/ports" 30 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/stack" 31 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/transport" 32 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/transport/internal/network" 33 "github.com/nicocha30/gvisor-ligolo/pkg/waiter" 34 ) 35 36 // +stateify savable 37 type udpPacket struct { 38 udpPacketEntry 39 netProto tcpip.NetworkProtocolNumber 40 senderAddress tcpip.FullAddress 41 destinationAddress tcpip.FullAddress 42 packetInfo tcpip.IPPacketInfo 43 pkt stack.PacketBufferPtr 44 receivedAt time.Time `state:".(int64)"` 45 // tosOrTClass stores either the Type of Service for IPv4 or the Traffic Class 46 // for IPv6. 47 tosOrTClass uint8 48 // ttlOrHopLimit stores either the TTL for IPv4 or the HopLimit for IPv6 49 ttlOrHopLimit uint8 50 } 51 52 // endpoint represents a UDP endpoint. This struct serves as the interface 53 // between users of the endpoint and the protocol implementation; it is legal to 54 // have concurrent goroutines make calls into the endpoint, they are properly 55 // synchronized. 56 // 57 // It implements tcpip.Endpoint. 58 // 59 // +stateify savable 60 type endpoint struct { 61 tcpip.DefaultSocketOptionsHandler 62 63 // The following fields are initialized at creation time and do not 64 // change throughout the lifetime of the endpoint. 65 stack *stack.Stack `state:"manual"` 66 waiterQueue *waiter.Queue 67 uniqueID uint64 68 net network.Endpoint 69 stats tcpip.TransportEndpointStats 70 ops tcpip.SocketOptions 71 72 // The following fields are used to manage the receive queue, and are 73 // protected by rcvMu. 74 rcvMu sync.Mutex `state:"nosave"` 75 rcvReady bool 76 rcvList udpPacketList 77 rcvBufSize int 78 rcvClosed bool 79 80 lastErrorMu sync.Mutex `state:"nosave"` 81 lastError tcpip.Error 82 83 // The following fields are protected by the mu mutex. 84 mu sync.RWMutex `state:"nosave"` 85 portFlags ports.Flags 86 87 // Values used to reserve a port or register a transport endpoint. 88 // (which ever happens first). 89 boundBindToDevice tcpip.NICID 90 boundPortFlags ports.Flags 91 92 readShutdown bool 93 94 // effectiveNetProtos contains the network protocols actually in use. In 95 // most cases it will only contain "netProto", but in cases like IPv6 96 // endpoints with v6only set to false, this could include multiple 97 // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., 98 // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped 99 // address). 100 effectiveNetProtos []tcpip.NetworkProtocolNumber 101 102 // frozen indicates if the packets should be delivered to the endpoint 103 // during restore. 104 frozen bool 105 106 localPort uint16 107 remotePort uint16 108 } 109 110 func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { 111 e := &endpoint{ 112 stack: s, 113 waiterQueue: waiterQueue, 114 uniqueID: s.UniqueID(), 115 } 116 e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) 117 e.ops.SetMulticastLoop(true) 118 e.ops.SetSendBufferSize(32*1024, false /* notify */) 119 e.ops.SetReceiveBufferSize(32*1024, false /* notify */) 120 e.net.Init(s, netProto, header.UDPProtocolNumber, &e.ops, waiterQueue) 121 122 // Override with stack defaults. 123 var ss tcpip.SendBufferSizeOption 124 if err := s.Option(&ss); err == nil { 125 e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */) 126 } 127 128 var rs tcpip.ReceiveBufferSizeOption 129 if err := s.Option(&rs); err == nil { 130 e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */) 131 } 132 133 return e 134 } 135 136 // WakeupWriters implements tcpip.SocketOptionsHandler. 137 func (e *endpoint) WakeupWriters() { 138 e.net.MaybeSignalWritable() 139 } 140 141 // UniqueID implements stack.TransportEndpoint. 142 func (e *endpoint) UniqueID() uint64 { 143 return e.uniqueID 144 } 145 146 func (e *endpoint) LastError() tcpip.Error { 147 e.lastErrorMu.Lock() 148 defer e.lastErrorMu.Unlock() 149 150 err := e.lastError 151 e.lastError = nil 152 return err 153 } 154 155 // UpdateLastError implements tcpip.SocketOptionsHandler. 156 func (e *endpoint) UpdateLastError(err tcpip.Error) { 157 e.lastErrorMu.Lock() 158 e.lastError = err 159 e.lastErrorMu.Unlock() 160 } 161 162 // Abort implements stack.TransportEndpoint. 163 func (e *endpoint) Abort() { 164 e.Close() 165 } 166 167 // Close puts the endpoint in a closed state and frees all resources 168 // associated with it. 169 func (e *endpoint) Close() { 170 e.mu.Lock() 171 172 switch state := e.net.State(); state { 173 case transport.DatagramEndpointStateInitial: 174 case transport.DatagramEndpointStateClosed: 175 e.mu.Unlock() 176 return 177 case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: 178 id := e.net.Info().ID 179 id.LocalPort = e.localPort 180 id.RemotePort = e.remotePort 181 e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, id, e, e.boundPortFlags, e.boundBindToDevice) 182 portRes := ports.Reservation{ 183 Networks: e.effectiveNetProtos, 184 Transport: ProtocolNumber, 185 Addr: id.LocalAddress, 186 Port: id.LocalPort, 187 Flags: e.boundPortFlags, 188 BindToDevice: e.boundBindToDevice, 189 Dest: tcpip.FullAddress{}, 190 } 191 e.stack.ReleasePort(portRes) 192 e.boundBindToDevice = 0 193 e.boundPortFlags = ports.Flags{} 194 default: 195 panic(fmt.Sprintf("unhandled state = %s", state)) 196 } 197 198 // Close the receive list and drain it. 199 e.rcvMu.Lock() 200 e.rcvClosed = true 201 e.rcvBufSize = 0 202 for !e.rcvList.Empty() { 203 p := e.rcvList.Front() 204 e.rcvList.Remove(p) 205 p.pkt.DecRef() 206 } 207 e.rcvMu.Unlock() 208 209 e.net.Shutdown() 210 e.net.Close() 211 e.readShutdown = true 212 e.mu.Unlock() 213 214 e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) 215 } 216 217 // ModerateRecvBuf implements tcpip.Endpoint. 218 func (*endpoint) ModerateRecvBuf(int) {} 219 220 // Read implements tcpip.Endpoint. 221 func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { 222 if err := e.LastError(); err != nil { 223 return tcpip.ReadResult{}, err 224 } 225 226 e.rcvMu.Lock() 227 228 if e.rcvList.Empty() { 229 var err tcpip.Error = &tcpip.ErrWouldBlock{} 230 if e.rcvClosed { 231 e.stats.ReadErrors.ReadClosed.Increment() 232 err = &tcpip.ErrClosedForReceive{} 233 } 234 e.rcvMu.Unlock() 235 return tcpip.ReadResult{}, err 236 } 237 238 p := e.rcvList.Front() 239 if !opts.Peek { 240 e.rcvList.Remove(p) 241 defer p.pkt.DecRef() 242 e.rcvBufSize -= p.pkt.Data().Size() 243 } 244 e.rcvMu.Unlock() 245 246 // Control Messages 247 // TODO(https://gvisor.dev/issue/7012): Share control message code with other 248 // network endpoints. 249 cm := tcpip.ReceivableControlMessages{ 250 HasTimestamp: true, 251 Timestamp: p.receivedAt, 252 } 253 switch p.netProto { 254 case header.IPv4ProtocolNumber: 255 if e.ops.GetReceiveTOS() { 256 cm.HasTOS = true 257 cm.TOS = p.tosOrTClass 258 } 259 if e.ops.GetReceiveTTL() { 260 cm.HasTTL = true 261 cm.TTL = p.ttlOrHopLimit 262 } 263 if e.ops.GetReceivePacketInfo() { 264 cm.HasIPPacketInfo = true 265 cm.PacketInfo = p.packetInfo 266 } 267 case header.IPv6ProtocolNumber: 268 if e.ops.GetReceiveTClass() { 269 cm.HasTClass = true 270 // Although TClass is an 8-bit value it's read in the CMsg as a uint32. 271 cm.TClass = uint32(p.tosOrTClass) 272 } 273 if e.ops.GetReceiveHopLimit() { 274 cm.HasHopLimit = true 275 cm.HopLimit = p.ttlOrHopLimit 276 } 277 if e.ops.GetIPv6ReceivePacketInfo() { 278 cm.HasIPv6PacketInfo = true 279 cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{ 280 NIC: p.packetInfo.NIC, 281 Addr: p.packetInfo.DestinationAddr, 282 } 283 } 284 default: 285 panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto)) 286 } 287 288 if e.ops.GetReceiveOriginalDstAddress() { 289 cm.HasOriginalDstAddress = true 290 cm.OriginalDstAddress = p.destinationAddress 291 } 292 293 // Read Result 294 res := tcpip.ReadResult{ 295 Total: p.pkt.Data().Size(), 296 ControlMessages: cm, 297 } 298 if opts.NeedRemoteAddr { 299 res.RemoteAddr = p.senderAddress 300 } 301 302 n, err := p.pkt.Data().ReadTo(dst, opts.Peek) 303 if n == 0 && err != nil { 304 return res, &tcpip.ErrBadBuffer{} 305 } 306 res.Count = n 307 return res, nil 308 } 309 310 // prepareForWriteInner prepares the endpoint for sending data. In particular, 311 // it binds it if it's still in the initial state. To do so, it must first 312 // reacquire the mutex in exclusive mode. 313 // 314 // Returns true for retry if preparation should be retried. 315 // +checklocksread:e.mu 316 func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { 317 switch e.net.State() { 318 case transport.DatagramEndpointStateInitial: 319 case transport.DatagramEndpointStateConnected: 320 return false, nil 321 322 case transport.DatagramEndpointStateBound: 323 if to == nil { 324 return false, &tcpip.ErrDestinationRequired{} 325 } 326 return false, nil 327 default: 328 return false, &tcpip.ErrInvalidEndpointState{} 329 } 330 331 e.mu.RUnlock() 332 e.mu.Lock() 333 defer e.mu.DowngradeLock() 334 335 // The state changed when we released the shared locked and re-acquired 336 // it in exclusive mode. Try again. 337 if e.net.State() != transport.DatagramEndpointStateInitial { 338 return true, nil 339 } 340 341 // The state is still 'initial', so try to bind the endpoint. 342 if err := e.bindLocked(tcpip.FullAddress{}); err != nil { 343 return false, err 344 } 345 346 return true, nil 347 } 348 349 var _ tcpip.EndpointWithPreflight = (*endpoint)(nil) 350 351 // Validates the passed WriteOptions and prepares the endpoint for writes 352 // using those options. If the endpoint is unbound and the `To` address 353 // is specified, binds the endpoint to that address. 354 func (e *endpoint) Preflight(opts tcpip.WriteOptions) tcpip.Error { 355 var r bytes.Reader 356 udpInfo, err := e.prepareForWrite(&r, opts) 357 if err == nil { 358 udpInfo.ctx.Release() 359 } 360 return err 361 } 362 363 // Write writes data to the endpoint's peer. This method does not block 364 // if the data cannot be written. 365 func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { 366 n, err := e.write(p, opts) 367 switch err.(type) { 368 case nil: 369 e.stats.PacketsSent.Increment() 370 case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: 371 e.stats.WriteErrors.InvalidArgs.Increment() 372 case *tcpip.ErrClosedForSend: 373 e.stats.WriteErrors.WriteClosed.Increment() 374 case *tcpip.ErrInvalidEndpointState: 375 e.stats.WriteErrors.InvalidEndpointState.Increment() 376 case *tcpip.ErrHostUnreachable, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: 377 // Errors indicating any problem with IP routing of the packet. 378 e.stats.SendErrors.NoRoute.Increment() 379 default: 380 // For all other errors when writing to the network layer. 381 e.stats.SendErrors.SendToNetworkFailed.Increment() 382 } 383 return n, err 384 } 385 386 func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (udpPacketInfo, tcpip.Error) { 387 e.mu.RLock() 388 defer e.mu.RUnlock() 389 390 // Prepare for write. 391 for { 392 retry, err := e.prepareForWriteInner(opts.To) 393 if err != nil { 394 return udpPacketInfo{}, err 395 } 396 397 if !retry { 398 break 399 } 400 } 401 402 dst, connected := e.net.GetRemoteAddress() 403 dst.Port = e.remotePort 404 if opts.To != nil { 405 if opts.To.Port == 0 { 406 // Port 0 is an invalid port to send to. 407 return udpPacketInfo{}, &tcpip.ErrInvalidEndpointState{} 408 } 409 410 dst = *opts.To 411 } else if !connected { 412 return udpPacketInfo{}, &tcpip.ErrDestinationRequired{} 413 } 414 415 ctx, err := e.net.AcquireContextForWrite(opts) 416 if err != nil { 417 return udpPacketInfo{}, err 418 } 419 420 if p.Len() > header.UDPMaximumPacketSize { 421 // Native linux behaviour differs for IPv4 and IPv6 packets; IPv4 packet 422 // errors aren't report to the error queue at all. 423 if ctx.PacketInfo().NetProto == header.IPv6ProtocolNumber { 424 so := e.SocketOptions() 425 if so.GetIPv6RecvError() { 426 so.QueueLocalErr( 427 &tcpip.ErrMessageTooLong{}, 428 e.net.NetProto(), 429 uint32(p.Len()), 430 dst, 431 nil, 432 ) 433 } 434 } 435 ctx.Release() 436 return udpPacketInfo{}, &tcpip.ErrMessageTooLong{} 437 } 438 439 var buf buffer.Buffer 440 if _, err := buf.WriteFromReader(p, int64(p.Len())); err != nil { 441 buf.Release() 442 ctx.Release() 443 return udpPacketInfo{}, &tcpip.ErrBadBuffer{} 444 } 445 446 return udpPacketInfo{ 447 ctx: ctx, 448 data: buf, 449 localPort: e.localPort, 450 remotePort: dst.Port, 451 }, nil 452 } 453 454 func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { 455 // Do not hold lock when sending as loopback is synchronous and if the UDP 456 // datagram ends up generating an ICMP response then it can result in a 457 // deadlock where the ICMP response handling ends up acquiring this endpoint's 458 // mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a 459 // deadlock if another caller is trying to acquire e.mu in exclusive mode w/ 460 // e.mu.Lock(). Since e.mu.Lock() prevents any new read locks to ensure the 461 // lock can be eventually acquired. 462 // 463 // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read 464 // locking is prohibited. 465 466 if err := e.LastError(); err != nil { 467 return 0, err 468 } 469 470 udpInfo, err := e.prepareForWrite(p, opts) 471 if err != nil { 472 return 0, err 473 } 474 defer udpInfo.ctx.Release() 475 476 dataSz := udpInfo.data.Size() 477 pktInfo := udpInfo.ctx.PacketInfo() 478 pkt := udpInfo.ctx.TryNewPacketBuffer(header.UDPMinimumSize+int(pktInfo.MaxHeaderLength), udpInfo.data) 479 if pkt.IsNil() { 480 return 0, &tcpip.ErrWouldBlock{} 481 } 482 defer pkt.DecRef() 483 484 // Initialize the UDP header. 485 udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) 486 pkt.TransportProtocolNumber = ProtocolNumber 487 488 length := uint16(pkt.Size()) 489 udp.Encode(&header.UDPFields{ 490 SrcPort: udpInfo.localPort, 491 DstPort: udpInfo.remotePort, 492 Length: length, 493 }) 494 495 // Set the checksum field unless TX checksum offload is enabled. 496 // On IPv4, UDP checksum is optional, and a zero value indicates the 497 // transmitter skipped the checksum generation (RFC768). 498 // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). 499 if pktInfo.RequiresTXTransportChecksum && 500 (!e.ops.GetNoChecksum() || pktInfo.NetProto == header.IPv6ProtocolNumber) { 501 xsum := udp.CalculateChecksum(checksum.Combine( 502 header.PseudoHeaderChecksum(ProtocolNumber, pktInfo.LocalAddress, pktInfo.RemoteAddress, length), 503 pkt.Data().Checksum(), 504 )) 505 // As per RFC 768 page 2, 506 // 507 // Checksum is the 16-bit one's complement of the one's complement sum of 508 // a pseudo header of information from the IP header, the UDP header, and 509 // the data, padded with zero octets at the end (if necessary) to make a 510 // multiple of two octets. 511 // 512 // The pseudo header conceptually prefixed to the UDP header contains the 513 // source address, the destination address, the protocol, and the UDP 514 // length. This information gives protection against misrouted datagrams. 515 // This checksum procedure is the same as is used in TCP. 516 // 517 // If the computed checksum is zero, it is transmitted as all ones (the 518 // equivalent in one's complement arithmetic). An all zero transmitted 519 // checksum value means that the transmitter generated no checksum (for 520 // debugging or for higher level protocols that don't care). 521 // 522 // To avoid the zero value, we only calculate the one's complement of the 523 // one's complement sum if the sum is not all ones. 524 if xsum != math.MaxUint16 { 525 xsum = ^xsum 526 } 527 udp.SetChecksum(xsum) 528 } 529 if err := udpInfo.ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { 530 e.stack.Stats().UDP.PacketSendErrors.Increment() 531 return 0, err 532 } 533 534 // Track count of packets sent. 535 e.stack.Stats().UDP.PacketsSent.Increment() 536 return int64(dataSz), nil 537 } 538 539 // OnReuseAddressSet implements tcpip.SocketOptionsHandler. 540 func (e *endpoint) OnReuseAddressSet(v bool) { 541 e.mu.Lock() 542 e.portFlags.MostRecent = v 543 e.mu.Unlock() 544 } 545 546 // OnReusePortSet implements tcpip.SocketOptionsHandler. 547 func (e *endpoint) OnReusePortSet(v bool) { 548 e.mu.Lock() 549 e.portFlags.LoadBalanced = v 550 e.mu.Unlock() 551 } 552 553 // SetSockOptInt implements tcpip.Endpoint. 554 func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { 555 return e.net.SetSockOptInt(opt, v) 556 } 557 558 var _ tcpip.SocketOptionsHandler = (*endpoint)(nil) 559 560 // HasNIC implements tcpip.SocketOptionsHandler. 561 func (e *endpoint) HasNIC(id int32) bool { 562 return e.stack.HasNIC(tcpip.NICID(id)) 563 } 564 565 // SetSockOpt implements tcpip.Endpoint. 566 func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { 567 return e.net.SetSockOpt(opt) 568 } 569 570 // GetSockOptInt implements tcpip.Endpoint. 571 func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { 572 switch opt { 573 case tcpip.ReceiveQueueSizeOption: 574 v := 0 575 e.rcvMu.Lock() 576 if !e.rcvList.Empty() { 577 p := e.rcvList.Front() 578 v = p.pkt.Data().Size() 579 } 580 e.rcvMu.Unlock() 581 return v, nil 582 583 default: 584 return e.net.GetSockOptInt(opt) 585 } 586 } 587 588 // GetSockOpt implements tcpip.Endpoint. 589 func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { 590 return e.net.GetSockOpt(opt) 591 } 592 593 // udpPacketInfo holds information needed to send a UDP packet. 594 type udpPacketInfo struct { 595 ctx network.WriteContext 596 data buffer.Buffer 597 localPort uint16 598 remotePort uint16 599 } 600 601 // Disconnect implements tcpip.Endpoint. 602 func (e *endpoint) Disconnect() tcpip.Error { 603 e.mu.Lock() 604 defer e.mu.Unlock() 605 606 if e.net.State() != transport.DatagramEndpointStateConnected { 607 return nil 608 } 609 var ( 610 id stack.TransportEndpointID 611 btd tcpip.NICID 612 ) 613 614 // We change this value below and we need the old value to unregister 615 // the endpoint. 616 boundPortFlags := e.boundPortFlags 617 618 // Exclude ephemerally bound endpoints. 619 info := e.net.Info() 620 info.ID.LocalPort = e.localPort 621 info.ID.RemotePort = e.remotePort 622 if e.net.WasBound() { 623 var err tcpip.Error 624 id = stack.TransportEndpointID{ 625 LocalPort: info.ID.LocalPort, 626 LocalAddress: info.ID.LocalAddress, 627 } 628 id, btd, err = e.registerWithStack(e.effectiveNetProtos, id) 629 if err != nil { 630 return err 631 } 632 boundPortFlags = e.boundPortFlags 633 } else { 634 if info.ID.LocalPort != 0 { 635 // Release the ephemeral port. 636 portRes := ports.Reservation{ 637 Networks: e.effectiveNetProtos, 638 Transport: ProtocolNumber, 639 Addr: info.ID.LocalAddress, 640 Port: info.ID.LocalPort, 641 Flags: boundPortFlags, 642 BindToDevice: e.boundBindToDevice, 643 Dest: tcpip.FullAddress{}, 644 } 645 e.stack.ReleasePort(portRes) 646 e.boundPortFlags = ports.Flags{} 647 } 648 } 649 650 e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, info.ID, e, boundPortFlags, e.boundBindToDevice) 651 e.boundBindToDevice = btd 652 e.localPort = id.LocalPort 653 e.remotePort = id.RemotePort 654 655 e.net.Disconnect() 656 657 return nil 658 } 659 660 // Connect connects the endpoint to its peer. Specifying a NIC is optional. 661 func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { 662 e.mu.Lock() 663 defer e.mu.Unlock() 664 665 err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error { 666 nextID.LocalPort = e.localPort 667 nextID.RemotePort = addr.Port 668 669 // Even if we're connected, this endpoint can still be used to send 670 // packets on a different network protocol, so we register both even if 671 // v6only is set to false and this is an ipv6 endpoint. 672 netProtos := []tcpip.NetworkProtocolNumber{netProto} 673 if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) { 674 netProtos = []tcpip.NetworkProtocolNumber{ 675 header.IPv4ProtocolNumber, 676 header.IPv6ProtocolNumber, 677 } 678 } 679 680 oldPortFlags := e.boundPortFlags 681 682 nextID, btd, err := e.registerWithStack(netProtos, nextID) 683 if err != nil { 684 return err 685 } 686 687 // Remove the old registration. 688 if e.localPort != 0 { 689 previousID.LocalPort = e.localPort 690 previousID.RemotePort = e.remotePort 691 e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, previousID, e, oldPortFlags, e.boundBindToDevice) 692 } 693 694 e.localPort = nextID.LocalPort 695 e.remotePort = nextID.RemotePort 696 e.boundBindToDevice = btd 697 e.effectiveNetProtos = netProtos 698 return nil 699 }) 700 if err != nil { 701 return err 702 } 703 704 e.rcvMu.Lock() 705 e.rcvReady = true 706 e.rcvMu.Unlock() 707 return nil 708 } 709 710 // ConnectEndpoint is not supported. 711 func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { 712 return &tcpip.ErrInvalidEndpointState{} 713 } 714 715 // Shutdown closes the read and/or write end of the endpoint connection 716 // to its peer. 717 func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { 718 e.mu.Lock() 719 defer e.mu.Unlock() 720 721 switch state := e.net.State(); state { 722 case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: 723 return &tcpip.ErrNotConnected{} 724 case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: 725 default: 726 panic(fmt.Sprintf("unhandled state = %s", state)) 727 } 728 729 if flags&tcpip.ShutdownWrite != 0 { 730 if err := e.net.Shutdown(); err != nil { 731 return err 732 } 733 } 734 735 if flags&tcpip.ShutdownRead != 0 { 736 e.readShutdown = true 737 738 e.rcvMu.Lock() 739 wasClosed := e.rcvClosed 740 e.rcvClosed = true 741 e.rcvMu.Unlock() 742 743 if !wasClosed { 744 e.waiterQueue.Notify(waiter.ReadableEvents) 745 } 746 } 747 748 if e.net.State() == transport.DatagramEndpointStateBound { 749 return &tcpip.ErrNotConnected{} 750 } 751 return nil 752 } 753 754 // Listen is not supported by UDP, it just fails. 755 func (*endpoint) Listen(int) tcpip.Error { 756 return &tcpip.ErrNotSupported{} 757 } 758 759 // Accept is not supported by UDP, it just fails. 760 func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) { 761 return nil, nil, &tcpip.ErrNotSupported{} 762 } 763 764 func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) { 765 bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) 766 if e.localPort == 0 { 767 portRes := ports.Reservation{ 768 Networks: netProtos, 769 Transport: ProtocolNumber, 770 Addr: id.LocalAddress, 771 Port: id.LocalPort, 772 Flags: e.portFlags, 773 BindToDevice: bindToDevice, 774 Dest: tcpip.FullAddress{}, 775 } 776 port, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */) 777 if err != nil { 778 return id, bindToDevice, err 779 } 780 id.LocalPort = port 781 } 782 e.boundPortFlags = e.portFlags 783 784 err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) 785 if err != nil { 786 portRes := ports.Reservation{ 787 Networks: netProtos, 788 Transport: ProtocolNumber, 789 Addr: id.LocalAddress, 790 Port: id.LocalPort, 791 Flags: e.boundPortFlags, 792 BindToDevice: bindToDevice, 793 Dest: tcpip.FullAddress{}, 794 } 795 e.stack.ReleasePort(portRes) 796 e.boundPortFlags = ports.Flags{} 797 } 798 return id, bindToDevice, err 799 } 800 801 func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { 802 // Don't allow binding once endpoint is not in the initial state 803 // anymore. 804 if e.net.State() != transport.DatagramEndpointStateInitial { 805 return &tcpip.ErrInvalidEndpointState{} 806 } 807 808 err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error { 809 // Expand netProtos to include v4 and v6 if the caller is binding to a 810 // wildcard (empty) address, and this is an IPv6 endpoint with v6only 811 // set to false. 812 netProtos := []tcpip.NetworkProtocolNumber{boundNetProto} 813 if boundNetProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && boundAddr == (tcpip.Address{}) && e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) { 814 netProtos = []tcpip.NetworkProtocolNumber{ 815 header.IPv6ProtocolNumber, 816 header.IPv4ProtocolNumber, 817 } 818 } 819 820 id := stack.TransportEndpointID{ 821 LocalPort: addr.Port, 822 LocalAddress: boundAddr, 823 } 824 id, btd, err := e.registerWithStack(netProtos, id) 825 if err != nil { 826 return err 827 } 828 829 e.localPort = id.LocalPort 830 e.boundBindToDevice = btd 831 e.effectiveNetProtos = netProtos 832 return nil 833 }) 834 if err != nil { 835 return err 836 } 837 838 e.rcvMu.Lock() 839 e.rcvReady = true 840 e.rcvMu.Unlock() 841 return nil 842 } 843 844 // Bind binds the endpoint to a specific local address and port. 845 // Specifying a NIC is optional. 846 func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { 847 e.mu.Lock() 848 defer e.mu.Unlock() 849 850 err := e.bindLocked(addr) 851 if err != nil { 852 return err 853 } 854 855 return nil 856 } 857 858 // GetLocalAddress returns the address to which the endpoint is bound. 859 func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { 860 e.mu.RLock() 861 defer e.mu.RUnlock() 862 863 addr := e.net.GetLocalAddress() 864 addr.Port = e.localPort 865 return addr, nil 866 } 867 868 // GetRemoteAddress returns the address to which the endpoint is connected. 869 func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { 870 e.mu.RLock() 871 defer e.mu.RUnlock() 872 873 addr, connected := e.net.GetRemoteAddress() 874 if !connected || e.remotePort == 0 { 875 return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} 876 } 877 878 addr.Port = e.remotePort 879 return addr, nil 880 } 881 882 // Readiness returns the current readiness of the endpoint. For example, if 883 // waiter.EventIn is set, the endpoint is immediately readable. 884 func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { 885 var result waiter.EventMask 886 887 if e.net.HasSendSpace() { 888 result |= waiter.WritableEvents & mask 889 } 890 891 // Determine if the endpoint is readable if requested. 892 if mask&waiter.ReadableEvents != 0 { 893 e.rcvMu.Lock() 894 if !e.rcvList.Empty() || e.rcvClosed { 895 result |= waiter.ReadableEvents 896 } 897 e.rcvMu.Unlock() 898 } 899 900 e.lastErrorMu.Lock() 901 hasError := e.lastError != nil 902 e.lastErrorMu.Unlock() 903 if hasError { 904 result |= waiter.EventErr 905 } 906 return result 907 } 908 909 // HandlePacket is called by the stack when new packets arrive to this transport 910 // endpoint. 911 func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) { 912 // Get the header then trim it from the view. 913 hdr := header.UDP(pkt.TransportHeader().Slice()) 914 netHdr := pkt.Network() 915 lengthValid, csumValid := header.UDPValid( 916 hdr, 917 func() uint16 { return pkt.Data().Checksum() }, 918 uint16(pkt.Data().Size()), 919 pkt.NetworkProtocolNumber, 920 netHdr.SourceAddress(), 921 netHdr.DestinationAddress(), 922 pkt.RXChecksumValidated) 923 if !lengthValid { 924 // Malformed packet. 925 e.stack.Stats().UDP.MalformedPacketsReceived.Increment() 926 e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() 927 return 928 } 929 930 if !csumValid { 931 e.stack.Stats().UDP.ChecksumErrors.Increment() 932 e.stats.ReceiveErrors.ChecksumErrors.Increment() 933 return 934 } 935 936 e.stack.Stats().UDP.PacketsReceived.Increment() 937 e.stats.PacketsReceived.Increment() 938 939 e.rcvMu.Lock() 940 // Drop the packet if our buffer is not ready to receive packets. 941 if !e.rcvReady || e.rcvClosed { 942 e.rcvMu.Unlock() 943 e.stack.Stats().UDP.ReceiveBufferErrors.Increment() 944 e.stats.ReceiveErrors.ClosedReceiver.Increment() 945 return 946 } 947 948 rcvBufSize := e.ops.GetReceiveBufferSize() 949 // Drop the packet if our buffer is currently full. 950 if e.frozen || e.rcvBufSize >= int(rcvBufSize) { 951 e.rcvMu.Unlock() 952 e.stack.Stats().UDP.ReceiveBufferErrors.Increment() 953 e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() 954 return 955 } 956 957 wasEmpty := e.rcvBufSize == 0 958 959 // Push new packet into receive list and increment the buffer size. 960 packet := &udpPacket{ 961 netProto: pkt.NetworkProtocolNumber, 962 senderAddress: tcpip.FullAddress{ 963 NIC: pkt.NICID, 964 Addr: id.RemoteAddress, 965 Port: hdr.SourcePort(), 966 }, 967 destinationAddress: tcpip.FullAddress{ 968 NIC: pkt.NICID, 969 Addr: id.LocalAddress, 970 Port: hdr.DestinationPort(), 971 }, 972 pkt: pkt.IncRef(), 973 } 974 e.rcvList.PushBack(packet) 975 e.rcvBufSize += pkt.Data().Size() 976 977 // Save any useful information from the network header to the packet. 978 packet.tosOrTClass, _ = pkt.Network().TOS() 979 switch pkt.NetworkProtocolNumber { 980 case header.IPv4ProtocolNumber: 981 packet.ttlOrHopLimit = header.IPv4(pkt.NetworkHeader().Slice()).TTL() 982 case header.IPv6ProtocolNumber: 983 packet.ttlOrHopLimit = header.IPv6(pkt.NetworkHeader().Slice()).HopLimit() 984 } 985 986 // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast 987 // address. packetInfo.LocalAddr should hold a unicast address that can be 988 // used to respond to the incoming packet. 989 localAddr := pkt.Network().DestinationAddress() 990 packet.packetInfo.LocalAddr = localAddr 991 packet.packetInfo.DestinationAddr = localAddr 992 packet.packetInfo.NIC = pkt.NICID 993 packet.receivedAt = e.stack.Clock().Now() 994 995 e.rcvMu.Unlock() 996 997 // Notify any waiters that there's data to be read now. 998 if wasEmpty { 999 e.waiterQueue.Notify(waiter.ReadableEvents) 1000 } 1001 } 1002 1003 func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, pkt stack.PacketBufferPtr) { 1004 // Update last error first. 1005 e.lastErrorMu.Lock() 1006 e.lastError = err 1007 e.lastErrorMu.Unlock() 1008 1009 var recvErr bool 1010 switch pkt.NetworkProtocolNumber { 1011 case header.IPv4ProtocolNumber: 1012 recvErr = e.SocketOptions().GetIPv4RecvError() 1013 case header.IPv6ProtocolNumber: 1014 recvErr = e.SocketOptions().GetIPv6RecvError() 1015 default: 1016 panic(fmt.Sprintf("unhandled network protocol number = %d", pkt.NetworkProtocolNumber)) 1017 } 1018 1019 if recvErr { 1020 // Linux passes the payload without the UDP header. 1021 payload := pkt.Data().AsRange().ToView() 1022 udp := header.UDP(payload.AsSlice()) 1023 if len(udp) >= header.UDPMinimumSize { 1024 payload.TrimFront(header.UDPMinimumSize) 1025 } 1026 1027 id := e.net.Info().ID 1028 e.SocketOptions().QueueErr(&tcpip.SockError{ 1029 Err: err, 1030 Cause: transErr, 1031 Payload: payload, 1032 Dst: tcpip.FullAddress{ 1033 NIC: pkt.NICID, 1034 Addr: id.RemoteAddress, 1035 Port: e.remotePort, 1036 }, 1037 Offender: tcpip.FullAddress{ 1038 NIC: pkt.NICID, 1039 Addr: id.LocalAddress, 1040 Port: e.localPort, 1041 }, 1042 NetProto: pkt.NetworkProtocolNumber, 1043 }) 1044 } 1045 1046 // Notify of the error. 1047 e.waiterQueue.Notify(waiter.EventErr) 1048 } 1049 1050 // HandleError implements stack.TransportEndpoint. 1051 func (e *endpoint) HandleError(transErr stack.TransportError, pkt stack.PacketBufferPtr) { 1052 // TODO(gvisor.dev/issues/5270): Handle all transport errors. 1053 switch transErr.Kind() { 1054 case stack.DestinationPortUnreachableTransportError: 1055 if e.net.State() == transport.DatagramEndpointStateConnected { 1056 e.onICMPError(&tcpip.ErrConnectionRefused{}, transErr, pkt) 1057 } 1058 } 1059 } 1060 1061 // State implements tcpip.Endpoint. 1062 func (e *endpoint) State() uint32 { 1063 return uint32(e.net.State()) 1064 } 1065 1066 // Info returns a copy of the endpoint info. 1067 func (e *endpoint) Info() tcpip.EndpointInfo { 1068 e.mu.RLock() 1069 defer e.mu.RUnlock() 1070 info := e.net.Info() 1071 info.ID.LocalPort = e.localPort 1072 info.ID.RemotePort = e.remotePort 1073 return &info 1074 } 1075 1076 // Stats returns a pointer to the endpoint stats. 1077 func (e *endpoint) Stats() tcpip.EndpointStats { 1078 return &e.stats 1079 } 1080 1081 // Wait implements tcpip.Endpoint. 1082 func (*endpoint) Wait() {} 1083 1084 // SetOwner implements tcpip.Endpoint. 1085 func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { 1086 e.net.SetOwner(owner) 1087 } 1088 1089 // SocketOptions implements tcpip.Endpoint. 1090 func (e *endpoint) SocketOptions() *tcpip.SocketOptions { 1091 return &e.ops 1092 } 1093 1094 // freeze prevents any more packets from being delivered to the endpoint. 1095 func (e *endpoint) freeze() { 1096 e.mu.Lock() 1097 e.frozen = true 1098 e.mu.Unlock() 1099 } 1100 1101 // thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows 1102 // new packets to be delivered again. 1103 func (e *endpoint) thaw() { 1104 e.mu.Lock() 1105 e.frozen = false 1106 e.mu.Unlock() 1107 }