// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package stack

import (
	"fmt"

	"github.com/SagerNet/gvisor/pkg/sync"
	"github.com/SagerNet/gvisor/pkg/tcpip"
	"github.com/SagerNet/gvisor/pkg/tcpip/hash/jenkins"
	"github.com/SagerNet/gvisor/pkg/tcpip/header"
	"github.com/SagerNet/gvisor/pkg/tcpip/ports"
)

// protocolIDs identifies a (network protocol, transport protocol) pair, e.g.
// (IPv4, UDP). It is used as a map key throughout the demuxer.
type protocolIDs struct {
	network   tcpip.NetworkProtocolNumber
	transport tcpip.TransportProtocolNumber
}

// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
	// mu protects all fields of the transportEndpoints.
	mu        sync.RWMutex
	endpoints map[TransportEndpointID]*endpointsByNIC
	// rawEndpoints contains endpoints for raw sockets, which receive all
	// traffic of a given protocol regardless of port.
	rawEndpoints []RawTransportEndpoint
}

// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
45 func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { 46 eps.mu.Lock() 47 defer eps.mu.Unlock() 48 epsByNIC, ok := eps.endpoints[id] 49 if !ok { 50 return 51 } 52 if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) { 53 return 54 } 55 delete(eps.endpoints, id) 56 } 57 58 func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { 59 eps.mu.RLock() 60 defer eps.mu.RUnlock() 61 es := make([]TransportEndpoint, 0, len(eps.endpoints)) 62 for _, e := range eps.endpoints { 63 es = append(es, e.transportEndpoints()...) 64 } 65 return es 66 } 67 68 // iterEndpointsLocked yields all endpointsByNIC in eps that match id, in 69 // descending order of match quality. If a call to yield returns false, 70 // iterEndpointsLocked stops iteration and returns immediately. 71 // 72 // Preconditions: eps.mu must be locked. 73 func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) { 74 // Try to find a match with the id as provided. 75 if ep, ok := eps.endpoints[id]; ok { 76 if !yield(ep) { 77 return 78 } 79 } 80 81 // Try to find a match with the id minus the local address. 82 nid := id 83 84 nid.LocalAddress = "" 85 if ep, ok := eps.endpoints[nid]; ok { 86 if !yield(ep) { 87 return 88 } 89 } 90 91 // Try to find a match with the id minus the remote part. 92 nid.LocalAddress = id.LocalAddress 93 nid.RemoteAddress = "" 94 nid.RemotePort = 0 95 if ep, ok := eps.endpoints[nid]; ok { 96 if !yield(ep) { 97 return 98 } 99 } 100 101 // Try to find a match with only the local port. 102 nid.LocalAddress = "" 103 if ep, ok := eps.endpoints[nid]; ok { 104 if !yield(ep) { 105 return 106 } 107 } 108 } 109 110 // findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in 111 // descending order of match quality. 112 // 113 // Preconditions: eps.mu must be locked. 
func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
	var matchedEPs []*endpointsByNIC
	eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
		matchedEPs = append(matchedEPs, ep)
		return true
	})
	return matchedEPs
}

// findEndpointLocked returns the endpoint that most closely matches the given id.
//
// Preconditions: eps.mu must be locked.
func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
	var matchedEP *endpointsByNIC
	eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
		// The first endpoint yielded is the best match; stop iterating.
		matchedEP = ep
		return false
	})
	return matchedEP
}

// endpointsByNIC groups the multiPortEndpoints that share one transport
// endpoint ID, keyed by the NIC they are bound to. Key 0 holds endpoints not
// bound to any particular device (see handlePacket's fallback lookup).
type endpointsByNIC struct {
	// mu protects endpoints below.
	mu        sync.RWMutex
	endpoints map[tcpip.NICID]*multiPortEndpoint
	// seed is a random secret for a jenkins hash.
	seed uint32
}

// transportEndpoints returns all transport endpoints in the group, across all
// NICs.
func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
	epsByNIC.mu.RLock()
	defer epsByNIC.mu.RUnlock()
	var eps []TransportEndpoint
	for _, ep := range epsByNIC.endpoints {
		eps = append(eps, ep.transportEndpoints()...)
	}
	return eps
}

// handlePacket is called by the stack when new packets arrive to this transport
// endpoint. It returns false if the packet could not be matched to any
// transport endpoint, true otherwise.
func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) bool {
	epsByNIC.mu.RLock()

	mpep, ok := epsByNIC.endpoints[pkt.NICID]
	if !ok {
		// Fall back to endpoints not bound to any device (key 0).
		if mpep, ok = epsByNIC.endpoints[0]; !ok {
			epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
			return false
		}
	}

	// If this is a broadcast or multicast datagram, deliver the datagram to all
	// endpoints bound to the right device.
	if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
		mpep.handlePacketAll(id, pkt)
		epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
		return true
	}
	// multiPortEndpoints are guaranteed to have at least one element.
	transEP := selectEndpoint(id, mpep, epsByNIC.seed)
	// If the protocol opted in to queued delivery (queuedTransportProtocol),
	// hand the packet to its queue instead of the endpoint directly.
	if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
		queuedProtocol.QueuePacket(transEP, id, pkt)
		epsByNIC.mu.RUnlock()
		return true
	}

	transEP.HandlePacket(id, pkt)
	epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
	return true
}

// handleError delivers an error to the transport endpoint identified by id.
// The lookup mirrors handlePacket: the NIC's endpoints first, then the
// device-unbound (key 0) fallback.
func (epsByNIC *endpointsByNIC) handleError(n *nic, id TransportEndpointID, transErr TransportError, pkt *PacketBuffer) {
	epsByNIC.mu.RLock()
	defer epsByNIC.mu.RUnlock()

	mpep, ok := epsByNIC.endpoints[n.ID()]
	if !ok {
		mpep, ok = epsByNIC.endpoints[0]
	}
	if !ok {
		return
	}

	// TODO(eyalsoha): Why don't we look at id to see if this packet needs to
	// broadcast like we are doing with handlePacket above?

	// multiPortEndpoints are guaranteed to have at least one element.
	selectEndpoint(id, mpep, epsByNIC.seed).HandleError(transErr, pkt)
}

// registerEndpoint registers t with the multiPortEndpoint for bindToDevice,
// creating that multiPortEndpoint if it does not exist yet. It returns a
// non-nil error (e.g. *tcpip.ErrPortInUse) if t cannot share the binding.
208 func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { 209 epsByNIC.mu.Lock() 210 defer epsByNIC.mu.Unlock() 211 212 multiPortEp, ok := epsByNIC.endpoints[bindToDevice] 213 if !ok { 214 multiPortEp = &multiPortEndpoint{ 215 demux: d, 216 netProto: netProto, 217 transProto: transProto, 218 } 219 } 220 221 if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil { 222 return err 223 } 224 // Only add this newly created multiportEndpoint if the singleRegisterEndpoint 225 // succeeded. 226 if !ok { 227 epsByNIC.endpoints[bindToDevice] = multiPortEp 228 } 229 return nil 230 } 231 232 func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { 233 epsByNIC.mu.RLock() 234 defer epsByNIC.mu.RUnlock() 235 236 multiPortEp, ok := epsByNIC.endpoints[bindToDevice] 237 if !ok { 238 return nil 239 } 240 241 return multiPortEp.singleCheckEndpoint(flags) 242 } 243 244 // unregisterEndpoint returns true if endpointsByNIC has to be unregistered. 245 func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool { 246 epsByNIC.mu.Lock() 247 defer epsByNIC.mu.Unlock() 248 multiPortEp, ok := epsByNIC.endpoints[bindToDevice] 249 if !ok { 250 return false 251 } 252 if multiPortEp.unregisterEndpoint(t, flags) { 253 delete(epsByNIC.endpoints, bindToDevice) 254 } 255 return len(epsByNIC.endpoints) == 0 256 } 257 258 // transportDemuxer demultiplexes packets targeted at a transport endpoint 259 // (i.e., after they've been parsed by the network layer). It does two levels 260 // of demultiplexing: first based on the network and transport protocols, then 261 // based on endpoints IDs. It should only be instantiated via 262 // newTransportDemuxer. 
263 type transportDemuxer struct { 264 stack *Stack 265 266 // protocol is immutable. 267 protocol map[protocolIDs]*transportEndpoints 268 queuedProtocols map[protocolIDs]queuedTransportProtocol 269 } 270 271 // queuedTransportProtocol if supported by a protocol implementation will cause 272 // the dispatcher to delivery packets to the QueuePacket method instead of 273 // calling HandlePacket directly on the endpoint. 274 type queuedTransportProtocol interface { 275 QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) 276 } 277 278 func newTransportDemuxer(stack *Stack) *transportDemuxer { 279 d := &transportDemuxer{ 280 stack: stack, 281 protocol: make(map[protocolIDs]*transportEndpoints), 282 queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), 283 } 284 285 // Add each network and transport pair to the demuxer. 286 for netProto := range stack.networkProtocols { 287 for proto := range stack.transportProtocols { 288 protoIDs := protocolIDs{netProto, proto} 289 d.protocol[protoIDs] = &transportEndpoints{ 290 endpoints: make(map[TransportEndpointID]*endpointsByNIC), 291 } 292 qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) 293 if isQueued { 294 d.queuedProtocols[protoIDs] = qTransProto 295 } 296 } 297 } 298 299 return d 300 } 301 302 // registerEndpoint registers the given endpoint with the dispatcher such that 303 // packets that match the endpoint ID are delivered to it. 
func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
	for i, n := range netProtos {
		if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil {
			// Roll back the registrations that already succeeded.
			d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice)
			return err
		}
	}

	return nil
}

// checkEndpoint checks if an endpoint can be registered with the dispatcher.
func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
	for _, n := range netProtos {
		if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil {
			return err
		}
	}

	return nil
}

// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpointsArr always has at least one
// element.
//
// FIXME(github.com/SagerNet/issue/873): Restore this properly. Currently, we just save
// this to ensure that the underlying endpoints get saved/restored, but do not
// use the restored copy.
//
// +stateify savable
type multiPortEndpoint struct {
	mu         sync.RWMutex `state:"nosave"`
	demux      *transportDemuxer
	netProto   tcpip.NetworkProtocolNumber
	transProto tcpip.TransportProtocolNumber

	// endpoints stores the transport endpoints in the order in which they
	// were bound. This is required for UDP SO_REUSEADDR.
	endpoints []TransportEndpoint
	// flags counts the multi-bind flags (e.g. load-balanced/most-recent)
	// of the registered endpoints.
	flags ports.FlagCounter
}

// transportEndpoints returns a copy of the endpoint list.
func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
	ep.mu.RLock()
	eps := append([]TransportEndpoint(nil), ep.endpoints...)
	ep.mu.RUnlock()
	return eps
}

// reciprocalScale scales a value into range [0, n).
//
// This is similar to val % n, but faster.
// See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
func reciprocalScale(val, n uint32) uint32 {
	return uint32((uint64(val) * uint64(n)) >> 32)
}

// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
	if len(mpep.endpoints) == 1 {
		// Fast path: nothing to load-balance over.
		return mpep.endpoints[0]
	}

	if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent {
		// "Most recent" semantics: the last endpoint bound wins.
		return mpep.endpoints[len(mpep.endpoints)-1]
	}

	// Hash the ports (low byte first) and both addresses with a per-group
	// random seed, so a given flow consistently maps to one endpoint while
	// the mapping stays unpredictable across groups.
	payload := []byte{
		byte(id.LocalPort),
		byte(id.LocalPort >> 8),
		byte(id.RemotePort),
		byte(id.RemotePort >> 8),
	}

	h := jenkins.Sum32(seed)
	h.Write(payload)
	h.Write([]byte(id.LocalAddress))
	h.Write([]byte(id.RemoteAddress))
	hash := h.Sum32()

	idx := reciprocalScale(hash, uint32(len(mpep.endpoints)))
	return mpep.endpoints[idx]
}

// handlePacketAll delivers pkt to every endpoint in the group; used for
// inbound broadcast/multicast datagrams.
func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) {
	ep.mu.RLock()
	queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
	// HandlePacket takes ownership of pkt, so each endpoint needs
	// its own copy except for the final one.
	for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
		if mustQueue {
			queuedProtocol.QueuePacket(endpoint, id, pkt.Clone())
		} else {
			endpoint.HandlePacket(id, pkt.Clone())
		}
	}
	if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
		queuedProtocol.QueuePacket(endpoint, id, pkt)
	} else {
		endpoint.HandlePacket(id, pkt)
	}
	ep.mu.RUnlock() // Don't use defer for performance reasons.
}

// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error {
	ep.mu.Lock()
	defer ep.mu.Unlock()
	bits := flags.Bits() & ports.MultiBindFlagMask

	if len(ep.endpoints) != 0 {
		// If it was previously bound, we need to check if we can bind again.
		if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
			return &tcpip.ErrPortInUse{}
		}
	}

	ep.endpoints = append(ep.endpoints, t)
	ep.flags.AddRef(bits)

	return nil
}

// singleCheckEndpoint performs the same compatibility check as
// singleRegisterEndpoint without mutating any state.
func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error {
	ep.mu.RLock()
	defer ep.mu.RUnlock()

	bits := flags.Bits() & ports.MultiBindFlagMask

	if len(ep.endpoints) != 0 {
		// If it was previously bound, we need to check if we can bind again.
		if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
			return &tcpip.ErrPortInUse{}
		}
	}

	return nil
}

// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool {
	ep.mu.Lock()
	defer ep.mu.Unlock()

	for i, endpoint := range ep.endpoints {
		if endpoint == t {
			// Remove t while preserving the bind order of the remaining
			// endpoints (required for UDP SO_REUSEADDR; see the endpoints
			// field comment). Nil out the vacated slot so the slice does
			// not retain a reference.
			copy(ep.endpoints[i:], ep.endpoints[i+1:])
			ep.endpoints[len(ep.endpoints)-1] = nil
			ep.endpoints = ep.endpoints[:len(ep.endpoints)-1]

			ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask)
			break
		}
	}
	return len(ep.endpoints) == 0
}

// singleRegisterEndpoint registers ep under a single network protocol,
// creating the endpointsByNIC for id if it does not exist yet.
func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
	if id.RemotePort != 0 {
		// SO_REUSEPORT only applies to bound/listening endpoints.
		flags.LoadBalanced = false
	}

	eps, ok := d.protocol[protocolIDs{netProto, protocol}]
	if !ok {
		return &tcpip.ErrUnknownProtocol{}
	}

	eps.mu.Lock()
	defer eps.mu.Unlock()
	epsByNIC, ok := eps.endpoints[id]
	if !ok {
		epsByNIC = &endpointsByNIC{
			endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
			seed:      d.stack.Seed(),
		}
	}
	if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil {
		return err
	}
	// Only add this newly created epsByNIC if registerEndpoint succeeded.
	if !ok {
		eps.endpoints[id] = epsByNIC
	}
	return nil
}

// singleCheckEndpoint checks whether an endpoint could be registered under a
// single network protocol, without registering anything.
func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error {
	if id.RemotePort != 0 {
		// SO_REUSEPORT only applies to bound/listening endpoints.
		flags.LoadBalanced = false
	}

	eps, ok := d.protocol[protocolIDs{netProto, protocol}]
	if !ok {
		return &tcpip.ErrUnknownProtocol{}
	}

	eps.mu.RLock()
	defer eps.mu.RUnlock()

	epsByNIC, ok := eps.endpoints[id]
	if !ok {
		return nil
	}

	return epsByNIC.checkEndpoint(flags, bindToDevice)
}

// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
	if id.RemotePort != 0 {
		// SO_REUSEPORT only applies to bound/listening endpoints.
		flags.LoadBalanced = false
	}

	for _, n := range netProtos {
		if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
			eps.unregisterEndpoint(id, ep, flags, bindToDevice)
		}
	}
}

// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
	eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
	if !ok {
		return false
	}

	// If the packet is a UDP broadcast or multicast, then find all matching
	// transport endpoints.
	if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) {
		eps.mu.RLock()
		destEPs := eps.findAllEndpointsLocked(id)
		eps.mu.RUnlock()
		// Fail if we didn't find at least one matching transport endpoint.
		if len(destEPs) == 0 {
			d.stack.stats.UDP.UnknownPortErrors.Increment()
			return false
		}
		// handlePacket takes ownership of pkt, so each endpoint needs its own
		// copy except for the final one.
		for _, ep := range destEPs[:len(destEPs)-1] {
			ep.handlePacket(id, pkt.Clone())
		}
		destEPs[len(destEPs)-1].handlePacket(id, pkt)
		return true
	}

	// If the packet is a TCP packet with a unspecified source or non-unicast
	// destination address, then do nothing further and instruct the caller to do
	// the same. The network layer handles address validation for specified source
	// addresses.
	if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) {
		// TCP can only be used to communicate between a single source and a
		// single destination; the addresses must be unicast.
		d.stack.stats.TCP.InvalidSegmentsReceived.Increment()
		return true
	}

	eps.mu.RLock()
	ep := eps.findEndpointLocked(id)
	eps.mu.RUnlock()
	if ep == nil {
		if protocol == header.UDPProtocolNumber {
			d.stack.stats.UDP.UnknownPortErrors.Increment()
		}
		return false
	}
	return ep.handlePacket(id, pkt)
}

// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
	eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}]
	if !ok {
		return false
	}

	// As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
	// raw endpoint first. If there are multiple raw endpoints, they all
	// receive the packet.
	eps.mu.RLock()
	// Copy the list of raw endpoints to avoid packet handling under lock.
	var rawEPs []RawTransportEndpoint
	if n := len(eps.rawEndpoints); n != 0 {
		rawEPs = make([]RawTransportEndpoint, n)
		if m := copy(rawEPs, eps.rawEndpoints); m != n {
			panic(fmt.Sprintf("unexpected copy = %d, want %d", m, n))
		}
	}
	eps.mu.RUnlock()
	for _, rawEP := range rawEPs {
		// Each endpoint gets its own copy of the packet for the sake
		// of save/restore.
		rawEP.HandlePacket(pkt.Clone())
	}

	return len(rawEPs) != 0
}

// deliverError attempts to deliver the given error to the appropriate transport
// endpoint.
//
// Returns true if the error was delivered.
func (d *transportDemuxer) deliverError(n *nic, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, transErr TransportError, pkt *PacketBuffer, id TransportEndpointID) bool {
	eps, ok := d.protocol[protocolIDs{net, trans}]
	if !ok {
		return false
	}

	eps.mu.RLock()
	ep := eps.findEndpointLocked(id)
	eps.mu.RUnlock()
	if ep == nil {
		return false
	}

	ep.handleError(n, id, transErr, pkt)
	return true
}

// findTransportEndpoint find a single endpoint that most closely matches the provided id.
func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint {
	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
	if !ok {
		return nil
	}

	eps.mu.RLock()
	epsByNIC := eps.findEndpointLocked(id)
	if epsByNIC == nil {
		eps.mu.RUnlock()
		return nil
	}

	// Lock hand-off: acquire epsByNIC.mu before releasing eps.mu so the
	// matched entry cannot be mutated or removed in between.
	epsByNIC.mu.RLock()
	eps.mu.RUnlock()

	mpep, ok := epsByNIC.endpoints[nicID]
	if !ok {
		// Fall back to endpoints not bound to any device (key 0).
		if mpep, ok = epsByNIC.endpoints[0]; !ok {
			epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
			return nil
		}
	}

	ep := selectEndpoint(id, mpep, epsByNIC.seed)
	epsByNIC.mu.RUnlock()
	return ep
}

// registerRawEndpoint registers the given endpoint with the dispatcher such
// that packets of the appropriate protocol are delivered to it. A single
// packet can be sent to one or more raw endpoints along with a non-raw
// endpoint.
func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) tcpip.Error {
	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
	if !ok {
		return &tcpip.ErrNotSupported{}
	}

	eps.mu.Lock()
	eps.rawEndpoints = append(eps.rawEndpoints, ep)
	eps.mu.Unlock()

	return nil
}

// unregisterRawEndpoint unregisters the raw endpoint for the given transport
// protocol such that it won't receive any more packets.
func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
	eps, ok := d.protocol[protocolIDs{netProto, transProto}]
	if !ok {
		panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto))
	}

	eps.mu.Lock()
	for i, rawEP := range eps.rawEndpoints {
		if rawEP == ep {
			// Swap-remove: order of rawEndpoints does not matter here, and
			// nil out the vacated slot so it is not retained.
			lastIdx := len(eps.rawEndpoints) - 1
			eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx]
			eps.rawEndpoints[lastIdx] = nil
			eps.rawEndpoints = eps.rawEndpoints[:lastIdx]
			break
		}
	}
	eps.mu.Unlock()
}

// isInboundMulticastOrBroadcast returns true if pkt is addressed to the local
// broadcast address or localAddr is an IPv4/IPv6 multicast address.
func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool {
	return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr)
}

// isSpecified returns true unless addr is the IPv4 or IPv6 unspecified ("any")
// address.
func isSpecified(addr tcpip.Address) bool {
	return addr != header.IPv4Any && addr != header.IPv6Any
}