// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package stack

import (
	"fmt"

	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip"
	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/hash/jenkins"
	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/header"
	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/ports"
)

// protocolIDs identifies a (network protocol, transport protocol) pair, e.g.
// (IPv4, UDP). It is used as the key for the demuxer's per-protocol tables.
type protocolIDs struct {
	network   tcpip.NetworkProtocolNumber
	transport tcpip.TransportProtocolNumber
}

// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
	mu transportEndpointsRWMutex
	// endpoints maps an endpoint ID (the local/remote address-and-port
	// 4-tuple) to the set of endpoints registered under that ID, grouped
	// by bound device.
	//
	// +checklocks:mu
	endpoints map[TransportEndpointID]*endpointsByNIC
	// rawEndpoints contains endpoints for raw sockets, which receive all
	// traffic of a given protocol regardless of port.
	//
	// +checklocks:mu
	rawEndpoints []RawTransportEndpoint
}

// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
	eps.mu.Lock()
	defer eps.mu.Unlock()
	epsByNIC, ok := eps.endpoints[id]
	if !ok {
		return
	}
	// Only remove the id's entry once the per-NIC group reports that it no
	// longer holds any endpoints at all.
	if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) {
		return
	}
	delete(eps.endpoints, id)
}

// transportEndpoints returns a flattened snapshot of every transport endpoint
// currently registered, across all endpoint IDs and devices.
func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
	eps.mu.RLock()
	defer eps.mu.RUnlock()
	es := make([]TransportEndpoint, 0, len(eps.endpoints))
	for _, e := range eps.endpoints {
		es = append(es, e.transportEndpoints()...)
	}
	return es
}

// iterEndpointsLocked yields all endpointsByNIC in eps that match id, in
// descending order of match quality. If a call to yield returns false,
// iterEndpointsLocked stops iteration and returns immediately.
//
// Match precedence (most to least specific):
//  1. exact 4-tuple match,
//  2. wildcard local address (remote fully specified),
//  3. wildcard remote address/port (local fully specified),
//  4. local port only.
//
// +checklocksread:eps.mu
func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
	// Try to find a match with the id as provided.
	if ep, ok := eps.endpoints[id]; ok {
		if !yield(ep) {
			return
		}
	}

	// Try to find a match with the id minus the local address.
	nid := id

	nid.LocalAddress = tcpip.Address{}
	if ep, ok := eps.endpoints[nid]; ok {
		if !yield(ep) {
			return
		}
	}

	// Try to find a match with the id minus the remote part.
	nid.LocalAddress = id.LocalAddress
	nid.RemoteAddress = tcpip.Address{}
	nid.RemotePort = 0
	if ep, ok := eps.endpoints[nid]; ok {
		if !yield(ep) {
			return
		}
	}

	// Try to find a match with only the local port.
	nid.LocalAddress = tcpip.Address{}
	if ep, ok := eps.endpoints[nid]; ok {
		if !yield(ep) {
			return
		}
	}
}

// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
// descending order of match quality.
//
// +checklocksread:eps.mu
func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
	var matchedEPs []*endpointsByNIC
	// Collect every match by always continuing iteration.
	eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
		matchedEPs = append(matchedEPs, ep)
		return true
	})
	return matchedEPs
}

// findEndpointLocked returns the endpoint that most closely matches the given id.
//
// +checklocksread:eps.mu
func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
	var matchedEP *endpointsByNIC
	// Stop at the first (best-quality) match by returning false from yield.
	eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
		matchedEP = ep
		return false
	})
	return matchedEP
}

// endpointsByNIC groups the endpoints registered under a single endpoint ID by
// the device (NIC) they are bound to. NICID 0 holds endpoints not bound to any
// specific device.
type endpointsByNIC struct {
	// seed is a random secret for a jenkins hash.
	seed uint32

	mu endpointsByNICRWMutex
	// +checklocks:mu
	endpoints map[tcpip.NICID]*multiPortEndpoint
}

// transportEndpoints returns a flattened snapshot of all endpoints registered
// on any device.
func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
	epsByNIC.mu.RLock()
	defer epsByNIC.mu.RUnlock()
	var eps []TransportEndpoint
	for _, ep := range epsByNIC.endpoints {
		eps = append(eps, ep.transportEndpoints()...)
	}
	return eps
}

// handlePacket is called by the stack when new packets arrive to this transport
// endpoint. It returns false if the packet could not be matched to any
// transport endpoint, true otherwise.
func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt PacketBufferPtr) bool {
	epsByNIC.mu.RLock()

	// Prefer an endpoint bound to the arrival NIC; otherwise fall back to
	// the entry for NICID 0 (endpoints not bound to a specific device).
	mpep, ok := epsByNIC.endpoints[pkt.NICID]
	if !ok {
		if mpep, ok = epsByNIC.endpoints[0]; !ok {
			epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
			return false
		}
	}

	// If this is a broadcast or multicast datagram, deliver the datagram to all
	// endpoints bound to the right device.
	if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
		mpep.handlePacketAll(id, pkt)
		epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
		return true
	}
	// multiPortEndpoints are guaranteed to have at least one element.
	transEP := mpep.selectEndpoint(id, epsByNIC.seed)
	// Protocols that implement queuedTransportProtocol get their packets
	// queued rather than handled synchronously.
	if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
		queuedProtocol.QueuePacket(transEP, id, pkt)
		epsByNIC.mu.RUnlock()
		return true
	}
	epsByNIC.mu.RUnlock()

	transEP.HandlePacket(id, pkt)
	return true
}

// handleError delivers an error to the transport endpoint identified by id.
func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, transErr TransportError, pkt PacketBufferPtr) {
	epsByNIC.mu.RLock()

	// Same device-matching fallback as handlePacket: the error's NIC first,
	// then the device-unbound entry (NICID 0).
	mpep, ok := epsByNIC.endpoints[n.ID()]
	if !ok {
		mpep, ok = epsByNIC.endpoints[0]
	}
	if !ok {
		epsByNIC.mu.RUnlock()
		return
	}

	// TODO(eyalsoha): Why don't we look at id to see if this packet needs to
	// broadcast like we are doing with handlePacket above?

	// multiPortEndpoints are guaranteed to have at least one element.
	transEP := mpep.selectEndpoint(id, epsByNIC.seed)
	epsByNIC.mu.RUnlock()

	transEP.HandleError(transErr, pkt)
}

// registerEndpoint registers t under bindToDevice, returning nil on success.
// It fails (e.g. with ErrPortInUse) if an incompatible endpoint is already
// registered for the same device.
214 func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { 215 epsByNIC.mu.Lock() 216 defer epsByNIC.mu.Unlock() 217 218 multiPortEp, ok := epsByNIC.endpoints[bindToDevice] 219 if !ok { 220 multiPortEp = &multiPortEndpoint{ 221 demux: d, 222 netProto: netProto, 223 transProto: transProto, 224 } 225 } 226 227 if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil { 228 return err 229 } 230 // Only add this newly created multiportEndpoint if the singleRegisterEndpoint 231 // succeeded. 232 if !ok { 233 epsByNIC.endpoints[bindToDevice] = multiPortEp 234 } 235 return nil 236 } 237 238 func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { 239 epsByNIC.mu.RLock() 240 defer epsByNIC.mu.RUnlock() 241 242 multiPortEp, ok := epsByNIC.endpoints[bindToDevice] 243 if !ok { 244 return nil 245 } 246 247 return multiPortEp.singleCheckEndpoint(flags) 248 } 249 250 // unregisterEndpoint returns true if endpointsByNIC has to be unregistered. 251 func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool { 252 epsByNIC.mu.Lock() 253 defer epsByNIC.mu.Unlock() 254 multiPortEp, ok := epsByNIC.endpoints[bindToDevice] 255 if !ok { 256 return false 257 } 258 if multiPortEp.unregisterEndpoint(t, flags) { 259 delete(epsByNIC.endpoints, bindToDevice) 260 } 261 return len(epsByNIC.endpoints) == 0 262 } 263 264 // transportDemuxer demultiplexes packets targeted at a transport endpoint 265 // (i.e., after they've been parsed by the network layer). It does two levels 266 // of demultiplexing: first based on the network and transport protocols, then 267 // based on endpoints IDs. It should only be instantiated via 268 // newTransportDemuxer. 
269 type transportDemuxer struct { 270 stack *Stack 271 272 // protocol is immutable. 273 protocol map[protocolIDs]*transportEndpoints 274 queuedProtocols map[protocolIDs]queuedTransportProtocol 275 } 276 277 // queuedTransportProtocol if supported by a protocol implementation will cause 278 // the dispatcher to delivery packets to the QueuePacket method instead of 279 // calling HandlePacket directly on the endpoint. 280 type queuedTransportProtocol interface { 281 QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt PacketBufferPtr) 282 } 283 284 func newTransportDemuxer(stack *Stack) *transportDemuxer { 285 d := &transportDemuxer{ 286 stack: stack, 287 protocol: make(map[protocolIDs]*transportEndpoints), 288 queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), 289 } 290 291 // Add each network and transport pair to the demuxer. 292 for netProto := range stack.networkProtocols { 293 for proto := range stack.transportProtocols { 294 protoIDs := protocolIDs{netProto, proto} 295 d.protocol[protoIDs] = &transportEndpoints{ 296 endpoints: make(map[TransportEndpointID]*endpointsByNIC), 297 } 298 qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) 299 if isQueued { 300 d.queuedProtocols[protoIDs] = qTransProto 301 } 302 } 303 } 304 305 return d 306 } 307 308 // registerEndpoint registers the given endpoint with the dispatcher such that 309 // packets that match the endpoint ID are delivered to it. 
func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
	for i, n := range netProtos {
		if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil {
			// Roll back the registrations that already succeeded so a
			// partial failure leaves no stale state behind.
			d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice)
			return err
		}
	}

	return nil
}

// checkEndpoint checks if an endpoint can be registered with the dispatcher.
func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
	for _, n := range netProtos {
		if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil {
			return err
		}
	}

	return nil
}

// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpointsArr always has at least one
// element.
//
// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save
// this to ensure that the underlying endpoints get saved/restored, but do not
// use the restored copy.
//
// +stateify savable
type multiPortEndpoint struct {
	demux      *transportDemuxer
	netProto   tcpip.NetworkProtocolNumber
	transProto tcpip.TransportProtocolNumber

	// flags tracks the multi-bind (reuse) flags of all registered
	// endpoints, so compatibility of a new binding can be checked.
	flags ports.FlagCounter

	mu multiPortEndpointRWMutex `state:"nosave"`
	// endpoints stores the transport endpoints in the order in which they
	// were bound. This is required for UDP SO_REUSEADDR.
	//
	// +checklocks:mu
	endpoints []TransportEndpoint
}

// transportEndpoints returns a copy of the endpoint list, so callers can
// iterate it without holding ep.mu.
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
	ep.mu.RLock()
	eps := append([]TransportEndpoint(nil), ep.endpoints...)
	ep.mu.RUnlock()
	return eps
}

// reciprocalScale scales a value into range [0, n).
//
// This is similar to val % n, but faster.
// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
func reciprocalScale(val, n uint32) uint32 {
	return uint32((uint64(val) * uint64(n)) >> 32)
}

// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID, seed uint32) TransportEndpoint {
	ep.mu.RLock()
	defer ep.mu.RUnlock()

	// Fast path: no load balancing needed with a single endpoint.
	if len(ep.endpoints) == 1 {
		return ep.endpoints[0]
	}

	// With SO_REUSEADDR-style semantics, the most recently bound endpoint
	// (last in the bind-ordered slice) wins.
	if ep.flags.SharedFlags().ToFlags().Effective().MostRecent {
		return ep.endpoints[len(ep.endpoints)-1]
	}

	// Hash the 4-tuple (ports little-endian, then both addresses) so that
	// all packets of one flow consistently land on the same endpoint.
	payload := []byte{
		byte(id.LocalPort),
		byte(id.LocalPort >> 8),
		byte(id.RemotePort),
		byte(id.RemotePort >> 8),
	}

	h := jenkins.Sum32(seed)
	h.Write(payload)
	h.Write(id.LocalAddress.AsSlice())
	h.Write(id.RemoteAddress.AsSlice())
	hash := h.Sum32()

	idx := reciprocalScale(hash, uint32(len(ep.endpoints)))
	return ep.endpoints[idx]
}

// handlePacketAll delivers pkt to every endpoint in this group (used for
// broadcast/multicast datagrams).
func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt PacketBufferPtr) {
	ep.mu.RLock()
	queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
	// HandlePacket may modify pkt, so each endpoint needs
	// its own copy except for the final one.
	for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
		clone := pkt.Clone()
		if mustQueue {
			queuedProtocol.QueuePacket(endpoint, id, clone)
		} else {
			endpoint.HandlePacket(id, clone)
		}
		clone.DecRef()
	}
	// The last endpoint receives the original packet, avoiding one clone.
	if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
		queuedProtocol.QueuePacket(endpoint, id, pkt)
	} else {
		endpoint.HandlePacket(id, pkt)
	}
	ep.mu.RUnlock() // Don't use defer for performance reasons.
}

// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error {
	ep.mu.Lock()
	defer ep.mu.Unlock()
	bits := flags.Bits() & ports.MultiBindFlagMask

	if len(ep.endpoints) != 0 {
		// If it was previously bound, we need to check if we can bind again.
		// The new binding must share at least one multi-bind flag with
		// every existing binding; otherwise the port is busy.
		if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
			return &tcpip.ErrPortInUse{}
		}
	}

	ep.endpoints = append(ep.endpoints, t)
	ep.flags.AddRef(bits)

	return nil
}

// singleCheckEndpoint checks whether an endpoint with the given flags could be
// added, using the same compatibility rule as singleRegisterEndpoint, without
// mutating any state.
func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error {
	ep.mu.RLock()
	defer ep.mu.RUnlock()

	bits := flags.Bits() & ports.MultiBindFlagMask

	if len(ep.endpoints) != 0 {
		// If it was previously bound, we need to check if we can bind again.
		if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
			return &tcpip.ErrPortInUse{}
		}
	}

	return nil
}

// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool {
	ep.mu.Lock()
	defer ep.mu.Unlock()

	for i, endpoint := range ep.endpoints {
		if endpoint == t {
			// Shift the tail left to preserve bind order (required for
			// UDP SO_REUSEADDR semantics; see the endpoints field doc).
			copy(ep.endpoints[i:], ep.endpoints[i+1:])
			ep.endpoints[len(ep.endpoints)-1] = nil
			ep.endpoints = ep.endpoints[:len(ep.endpoints)-1]

			ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask)
			break
		}
	}
	return len(ep.endpoints) == 0
}

// singleRegisterEndpoint registers ep for a single network protocol, creating
// the per-ID endpointsByNIC entry if needed.
func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
	if id.RemotePort != 0 {
		// SO_REUSEPORT only applies to bound/listening endpoints.
		flags.LoadBalanced = false
	}

	eps, ok := d.protocol[protocolIDs{netProto, protocol}]
	if !ok {
		return &tcpip.ErrUnknownProtocol{}
	}

	eps.mu.Lock()
	defer eps.mu.Unlock()
	epsByNIC, ok := eps.endpoints[id]
	if !ok {
		epsByNIC = &endpointsByNIC{
			endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
			seed:      d.stack.seed,
		}
	}
	if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil {
		return err
	}
	// Only add this newly created epsByNIC if registerEndpoint succeeded.
	if !ok {
		eps.endpoints[id] = epsByNIC
	}
	return nil
}

// singleCheckEndpoint checks whether an endpoint could be registered for a
// single network protocol, without registering it.
func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
	if id.RemotePort != 0 {
		// SO_REUSEPORT only applies to bound/listening endpoints.
		flags.LoadBalanced = false
	}

	eps, ok := d.protocol[protocolIDs{netProto, protocol}]
	if !ok {
		return &tcpip.ErrUnknownProtocol{}
	}

	eps.mu.RLock()
	defer eps.mu.RUnlock()

	epsByNIC, ok := eps.endpoints[id]
	if !ok {
		return nil
	}

	return epsByNIC.checkEndpoint(flags, bindToDevice)
}

// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
	if id.RemotePort != 0 {
		// SO_REUSEPORT only applies to bound/listening endpoints.
		flags.LoadBalanced = false
	}

	for _, n := range netProtos {
		if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
			eps.unregisterEndpoint(id, ep, flags, bindToDevice)
		}
	}
}

// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt PacketBufferPtr, id TransportEndpointID) bool {
	eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
	if !ok {
		return false
	}

	// If the packet is a UDP broadcast or multicast, then find all matching
	// transport endpoints.
	if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
		eps.mu.RLock()
		destEPs := eps.findAllEndpointsLocked(id)
		eps.mu.RUnlock()
		// Fail if we didn't find at least one matching transport endpoint.
		if len(destEPs) == 0 {
			d.stack.stats.UDP.UnknownPortErrors.Increment()
			return false
		}
		// handlePacket may modify pkt, so each endpoint needs its own
		// copy except for the final one.
		for _, ep := range destEPs[:len(destEPs)-1] {
			clone := pkt.Clone()
			ep.handlePacket(id, clone)
			clone.DecRef()
		}
		destEPs[len(destEPs)-1].handlePacket(id, pkt)
		return true
	}

	// If the packet is a TCP packet with a unspecified source or non-unicast
	// destination address, then do nothing further and instruct the caller to do
	// the same. The network layer handles address validation for specified source
	// addresses.
	if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) {
		// TCP can only be used to communicate between a single source and a
		// single destination; the addresses must be unicast.
		d.stack.stats.TCP.InvalidSegmentsReceived.Increment()
		return true
	}

	eps.mu.RLock()
	ep := eps.findEndpointLocked(id)
	eps.mu.RUnlock()
	if ep == nil {
		if protocol == header.UDPProtocolNumber {
			d.stack.stats.UDP.UnknownPortErrors.Increment()
		}
		return false
	}
	return ep.handlePacket(id, pkt)
}

// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt PacketBufferPtr) bool {
	eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
	if !ok {
		return false
	}

	// As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
	// raw endpoint first. If there are multiple raw endpoints, they all
	// receive the packet.
	eps.mu.RLock()
	// Copy the list of raw endpoints to avoid packet handling under lock.
	var rawEPs []RawTransportEndpoint
	if n := len(eps.rawEndpoints); n != 0 {
		rawEPs = make([]RawTransportEndpoint, n)
		if m := copy(rawEPs, eps.rawEndpoints); m != n {
			panic(fmt.Sprintf("unexpected copy = %d, want %d", m, n))
		}
	}
	eps.mu.RUnlock()
	for _, rawEP := range rawEPs {
		// Each endpoint gets its own copy of the packet for the sake
		// of save/restore.
		clone := pkt.Clone()
		rawEP.HandlePacket(clone)
		clone.DecRef()
	}

	return len(rawEPs) != 0
}

// deliverError attempts to deliver the given error to the appropriate transport
// endpoint.
//
// Returns true if the error was delivered.
func (d *transportDemuxer) deliverError(n *nic, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt PacketBufferPtr, id TransportEndpointID) bool {
	eps, ok := d.protocol[protocolIDs{net, trans}]
	if !ok {
		return false
	}

	eps.mu.RLock()
	ep := eps.findEndpointLocked(id)
	eps.mu.RUnlock()
	if ep == nil {
		return false
	}

	ep.handleError(n, id, transErr, pkt)
	return true
}

// findTransportEndpoint find a single endpoint that most closely matches the provided id.
func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
	if !ok {
		return nil
	}

	eps.mu.RLock()
	epsByNIC := eps.findEndpointLocked(id)
	if epsByNIC == nil {
		eps.mu.RUnlock()
		return nil
	}

	// Lock handoff: take the inner lock before releasing the outer one so
	// epsByNIC cannot be unregistered and mutated between the two.
	epsByNIC.mu.RLock()
	eps.mu.RUnlock()

	// Prefer an endpoint bound to nicID; otherwise fall back to the
	// device-unbound entry (NICID 0), as in handlePacket.
	mpep, ok := epsByNIC.endpoints[nicID]
	if !ok {
		if mpep, ok = epsByNIC.endpoints[0]; !ok {
			epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
			return nil
		}
	}

	ep := mpep.selectEndpoint(id, epsByNIC.seed)
	epsByNIC.mu.RUnlock()
	return ep
}

// registerRawEndpoint registers the given endpoint with the dispatcher such
// that packets of the appropriate protocol are delivered to it. A single
// packet can be sent to one or more raw endpoints along with a non-raw
// endpoint.
func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error {
	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
	if !ok {
		return &tcpip.ErrNotSupported{}
	}

	eps.mu.Lock()
	eps.rawEndpoints = append(eps.rawEndpoints, ep)
	eps.mu.Unlock()

	return nil
}

// unregisterRawEndpoint unregisters the raw endpoint for the given transport
// protocol such that it won't receive any more packets.
func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
	if !ok {
		panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
	}

	eps.mu.Lock()
	// Swap-remove: raw endpoint order (unlike multiPortEndpoint order) does
	// not matter.
	for i, rawEP := range eps.rawEndpoints {
		if rawEP == ep {
			lastIdx := len(eps.rawEndpoints) - 1
			eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx]
			eps.rawEndpoints[lastIdx] = nil
			eps.rawEndpoints = eps.rawEndpoints[:lastIdx]
			break
		}
	}
	eps.mu.Unlock()
}

// isInboundMulticastOrBroadcast reports whether the packet is destined for a
// broadcast or (IPv4/IPv6) multicast address.
func isInboundMulticastOrBroadcast(pkt PacketBufferPtr, localAddr tcpip.Address) bool {
	return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr)
}

// isSpecified reports whether addr is neither the IPv4 nor the IPv6
// unspecified ("any") address.
func isSpecified(addr tcpip.Address) bool {
	return addr != header.IPv4Any && addr != header.IPv6Any
}