github.com/polevpn/netstack@v1.10.9/tcpip/stack/transport_demuxer.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package stack 16 17 import ( 18 "fmt" 19 "math/rand" 20 "sort" 21 "sync" 22 23 "github.com/polevpn/netstack/tcpip" 24 "github.com/polevpn/netstack/tcpip/hash/jenkins" 25 "github.com/polevpn/netstack/tcpip/header" 26 ) 27 28 type protocolIDs struct { 29 network tcpip.NetworkProtocolNumber 30 transport tcpip.TransportProtocolNumber 31 } 32 33 // transportEndpoints manages all endpoints of a given protocol. It has its own 34 // mutex so as to reduce interference between protocols. 35 type transportEndpoints struct { 36 // mu protects all fields of the transportEndpoints. 37 mu sync.RWMutex 38 endpoints map[TransportEndpointID]*endpointsByNic 39 // rawEndpoints contains endpoints for raw sockets, which receive all 40 // traffic of a given protocol regardless of port. 41 rawEndpoints []RawTransportEndpoint 42 } 43 44 // unregisterEndpoint unregisters the endpoint with the given id such that it 45 // won't receive any more packets. 46 func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { 47 eps.mu.Lock() 48 defer eps.mu.Unlock() 49 epsByNic, ok := eps.endpoints[id] 50 if !ok { 51 return 52 } 53 if !epsByNic.unregisterEndpoint(bindToDevice, ep) { 54 return 55 } 56 delete(eps.endpoints, id) 57 } 58 59 func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { 60 eps.mu.RLock() 61 defer eps.mu.RUnlock() 62 es := make([]TransportEndpoint, 0, len(eps.endpoints)) 63 for _, e := range eps.endpoints { 64 es = append(es, e.transportEndpoints()...) 65 } 66 return es 67 } 68 69 type endpointsByNic struct { 70 mu sync.RWMutex 71 endpoints map[tcpip.NICID]*multiPortEndpoint 72 // seed is a random secret for a jenkins hash. 73 seed uint32 74 } 75 76 func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { 77 epsByNic.mu.RLock() 78 defer epsByNic.mu.RUnlock() 79 var eps []TransportEndpoint 80 for _, ep := range epsByNic.endpoints { 81 eps = append(eps, ep.transportEndpoints()...) 82 } 83 return eps 84 } 85 86 // HandlePacket is called by the stack when new packets arrive to this transport 87 // endpoint. 88 func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { 89 epsByNic.mu.RLock() 90 91 mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] 92 if !ok { 93 if mpep, ok = epsByNic.endpoints[0]; !ok { 94 epsByNic.mu.RUnlock() // Don't use defer for performance reasons. 95 return 96 } 97 } 98 99 // If this is a broadcast or multicast datagram, deliver the datagram to all 100 // endpoints bound to the right device. 101 if isMulticastOrBroadcast(id.LocalAddress) { 102 mpep.handlePacketAll(r, id, pkt) 103 epsByNic.mu.RUnlock() // Don't use defer for performance reasons. 104 return 105 } 106 // multiPortEndpoints are guaranteed to have at least one element. 107 selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) 108 epsByNic.mu.RUnlock() // Don't use defer for performance reasons. 109 } 110 111 // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. 112 func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { 113 epsByNic.mu.RLock() 114 defer epsByNic.mu.RUnlock() 115 116 mpep, ok := epsByNic.endpoints[n.ID()] 117 if !ok { 118 mpep, ok = epsByNic.endpoints[0] 119 } 120 if !ok { 121 return 122 } 123 124 // TODO(eyalsoha): Why don't we look at id to see if this packet needs to 125 // broadcast like we are doing with handlePacket above? 126 127 // multiPortEndpoints are guaranteed to have at least one element. 128 selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt) 129 } 130 131 // registerEndpoint returns true if it succeeds. It fails and returns 132 // false if ep already has an element with the same key. 133 func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { 134 epsByNic.mu.Lock() 135 defer epsByNic.mu.Unlock() 136 137 if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { 138 // There was already a bind. 139 return multiPortEp.singleRegisterEndpoint(t, reusePort) 140 } 141 142 // This is a new binding. 143 multiPortEp := &multiPortEndpoint{} 144 multiPortEp.endpointsMap = make(map[TransportEndpoint]int) 145 multiPortEp.reuse = reusePort 146 epsByNic.endpoints[bindToDevice] = multiPortEp 147 return multiPortEp.singleRegisterEndpoint(t, reusePort) 148 } 149 150 // unregisterEndpoint returns true if endpointsByNic has to be unregistered. 151 func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { 152 epsByNic.mu.Lock() 153 defer epsByNic.mu.Unlock() 154 multiPortEp, ok := epsByNic.endpoints[bindToDevice] 155 if !ok { 156 return false 157 } 158 if multiPortEp.unregisterEndpoint(t) { 159 delete(epsByNic.endpoints, bindToDevice) 160 } 161 return len(epsByNic.endpoints) == 0 162 } 163 164 // transportDemuxer demultiplexes packets targeted at a transport endpoint 165 // (i.e., after they've been parsed by the network layer). It does two levels 166 // of demultiplexing: first based on the network and transport protocols, then 167 // based on endpoints IDs. It should only be instantiated via 168 // newTransportDemuxer. 169 type transportDemuxer struct { 170 // protocol is immutable. 171 protocol map[protocolIDs]*transportEndpoints 172 } 173 174 func newTransportDemuxer(stack *Stack) *transportDemuxer { 175 d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} 176 177 // Add each network and transport pair to the demuxer. 178 for netProto := range stack.networkProtocols { 179 for proto := range stack.transportProtocols { 180 d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ 181 endpoints: make(map[TransportEndpointID]*endpointsByNic), 182 } 183 } 184 } 185 186 return d 187 } 188 189 // registerEndpoint registers the given endpoint with the dispatcher such that 190 // packets that match the endpoint ID are delivered to it. 191 func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { 192 for i, n := range netProtos { 193 if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { 194 d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) 195 return err 196 } 197 } 198 199 return nil 200 } 201 202 // multiPortEndpoint is a container for TransportEndpoints which are bound to 203 // the same pair of address and port. endpointsArr always has at least one 204 // element. 205 // 206 // FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save 207 // this to ensure that the underlying endpoints get saved/restored, but not not 208 // use the restored copy. 209 // 210 // +stateify savable 211 type multiPortEndpoint struct { 212 mu sync.RWMutex 213 endpointsArr []TransportEndpoint 214 endpointsMap map[TransportEndpoint]int 215 // reuse indicates if more than one endpoint is allowed. 216 reuse bool 217 } 218 219 func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { 220 ep.mu.RLock() 221 eps := append([]TransportEndpoint(nil), ep.endpointsArr...) 222 ep.mu.RUnlock() 223 return eps 224 } 225 226 // reciprocalScale scales a value into range [0, n). 227 // 228 // This is similar to val % n, but faster. 229 // See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ 230 func reciprocalScale(val, n uint32) uint32 { 231 return uint32((uint64(val) * uint64(n)) >> 32) 232 } 233 234 // selectEndpoint calculates a hash of destination and source addresses and 235 // ports then uses it to select a socket. In this case, all packets from one 236 // address will be sent to same endpoint. 237 func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { 238 if len(mpep.endpointsArr) == 1 { 239 return mpep.endpointsArr[0] 240 } 241 242 payload := []byte{ 243 byte(id.LocalPort), 244 byte(id.LocalPort >> 8), 245 byte(id.RemotePort), 246 byte(id.RemotePort >> 8), 247 } 248 249 h := jenkins.Sum32(seed) 250 h.Write(payload) 251 h.Write([]byte(id.LocalAddress)) 252 h.Write([]byte(id.RemoteAddress)) 253 hash := h.Sum32() 254 255 idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) 256 return mpep.endpointsArr[idx] 257 } 258 259 func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { 260 ep.mu.RLock() 261 for i, endpoint := range ep.endpointsArr { 262 // HandlePacket takes ownership of pkt, so each endpoint needs 263 // its own copy except for the final one. 264 if i == len(ep.endpointsArr)-1 { 265 endpoint.HandlePacket(r, id, pkt) 266 break 267 } 268 endpoint.HandlePacket(r, id, pkt.Clone()) 269 } 270 ep.mu.RUnlock() // Don't use defer for performance reasons. 271 } 272 273 // Close implements stack.TransportEndpoint.Close. 274 func (ep *multiPortEndpoint) Close() { 275 ep.mu.RLock() 276 eps := append([]TransportEndpoint(nil), ep.endpointsArr...) 277 ep.mu.RUnlock() 278 for _, e := range eps { 279 e.Close() 280 } 281 } 282 283 // Wait implements stack.TransportEndpoint.Wait. 284 func (ep *multiPortEndpoint) Wait() { 285 ep.mu.RLock() 286 eps := append([]TransportEndpoint(nil), ep.endpointsArr...) 287 ep.mu.RUnlock() 288 for _, e := range eps { 289 e.Wait() 290 } 291 } 292 293 // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint 294 // list. The list might be empty already. 295 func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { 296 ep.mu.Lock() 297 defer ep.mu.Unlock() 298 299 if len(ep.endpointsArr) > 0 { 300 // If it was previously bound, we need to check if we can bind again. 301 if !ep.reuse || !reusePort { 302 return tcpip.ErrPortInUse 303 } 304 } 305 306 // A new endpoint is added into endpointsArr and its index there is saved in 307 // endpointsMap. This will allow us to remove endpoint from the array fast. 308 ep.endpointsMap[t] = len(ep.endpointsArr) 309 ep.endpointsArr = append(ep.endpointsArr, t) 310 311 // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints 312 // can be restored in the same order. 313 sort.Slice(ep.endpointsArr, func(i, j int) bool { 314 return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID() 315 }) 316 for i, e := range ep.endpointsArr { 317 ep.endpointsMap[e] = i 318 } 319 return nil 320 } 321 322 // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. 323 func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { 324 ep.mu.Lock() 325 defer ep.mu.Unlock() 326 327 idx, ok := ep.endpointsMap[t] 328 if !ok { 329 return false 330 } 331 delete(ep.endpointsMap, t) 332 l := len(ep.endpointsArr) 333 if l > 1 { 334 // The last endpoint in endpointsArr is moved instead of the deleted one. 335 lastEp := ep.endpointsArr[l-1] 336 ep.endpointsArr[idx] = lastEp 337 ep.endpointsMap[lastEp] = idx 338 ep.endpointsArr = ep.endpointsArr[0 : l-1] 339 return false 340 } 341 return true 342 } 343 344 func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { 345 if id.RemotePort != 0 { 346 // TODO(eyalsoha): Why? 347 reusePort = false 348 } 349 350 eps, ok := d.protocol[protocolIDs{netProto, protocol}] 351 if !ok { 352 return tcpip.ErrUnknownProtocol 353 } 354 355 eps.mu.Lock() 356 defer eps.mu.Unlock() 357 358 if epsByNic, ok := eps.endpoints[id]; ok { 359 // There was already a binding. 360 return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) 361 } 362 363 // This is a new binding. 364 epsByNic := &endpointsByNic{ 365 endpoints: make(map[tcpip.NICID]*multiPortEndpoint), 366 seed: rand.Uint32(), 367 } 368 eps.endpoints[id] = epsByNic 369 370 return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) 371 } 372 373 // unregisterEndpoint unregisters the endpoint with the given id such that it 374 // won't receive any more packets. 375 func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { 376 for _, n := range netProtos { 377 if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { 378 eps.unregisterEndpoint(id, ep, bindToDevice) 379 } 380 } 381 } 382 383 var loopbackSubnet = func() tcpip.Subnet { 384 sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00") 385 if err != nil { 386 panic(err) 387 } 388 return sn 389 }() 390 391 // deliverPacket attempts to find one or more matching transport endpoints, and 392 // then, if matches are found, delivers the packet to them. Returns true if 393 // the packet no longer needs to be handled. 394 func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { 395 eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] 396 if !ok { 397 return false 398 } 399 400 eps.mu.RLock() 401 402 // Determine which transport endpoint or endpoints to deliver this packet to. 403 // If the packet is a UDP broadcast or multicast, then find all matching 404 // transport endpoints. If the packet is a TCP packet with a non-unicast 405 // source or destination address, then do nothing further and instruct 406 // the caller to do the same. 407 var destEps []*endpointsByNic 408 switch protocol { 409 case header.UDPProtocolNumber: 410 if isMulticastOrBroadcast(id.LocalAddress) { 411 destEps = d.findAllEndpointsLocked(eps, id) 412 break 413 } 414 415 if ep := d.findEndpointLocked(eps, id); ep != nil { 416 destEps = append(destEps, ep) 417 } 418 419 case header.TCPProtocolNumber: 420 if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) { 421 // TCP can only be used to communicate between a single 422 // source and a single destination; the addresses must 423 // be unicast. 424 eps.mu.RUnlock() 425 r.Stats().TCP.InvalidSegmentsReceived.Increment() 426 return true 427 } 428 429 fallthrough 430 431 default: 432 if ep := d.findEndpointLocked(eps, id); ep != nil { 433 destEps = append(destEps, ep) 434 } 435 } 436 437 eps.mu.RUnlock() 438 439 // Fail if we didn't find at least one matching transport endpoint. 440 if len(destEps) == 0 { 441 // UDP packet could not be delivered to an unknown destination port. 442 if protocol == header.UDPProtocolNumber { 443 r.Stats().UDP.UnknownPortErrors.Increment() 444 } 445 return false 446 } 447 448 // HandlePacket takes ownership of pkt, so each endpoint needs its own 449 // copy except for the final one. 450 for _, ep := range destEps[:len(destEps)-1] { 451 ep.handlePacket(r, id, pkt.Clone()) 452 } 453 destEps[len(destEps)-1].handlePacket(r, id, pkt) 454 455 return true 456 } 457 458 // deliverRawPacket attempts to deliver the given packet and returns whether it 459 // was delivered successfully. 460 func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool { 461 eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] 462 if !ok { 463 return false 464 } 465 466 // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via 467 // raw endpoint first. If there are multiple raw endpoints, they all 468 // receive the packet. 469 foundRaw := false 470 eps.mu.RLock() 471 for _, rawEP := range eps.rawEndpoints { 472 // Each endpoint gets its own copy of the packet for the sake 473 // of save/restore. 474 rawEP.HandlePacket(r, pkt) 475 foundRaw = true 476 } 477 eps.mu.RUnlock() 478 479 return foundRaw 480 } 481 482 // deliverControlPacket attempts to deliver the given control packet. Returns 483 // true if it found an endpoint, false otherwise. 484 func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { 485 eps, ok := d.protocol[protocolIDs{net, trans}] 486 if !ok { 487 return false 488 } 489 490 // Try to find the endpoint. 491 eps.mu.RLock() 492 ep := d.findEndpointLocked(eps, id) 493 eps.mu.RUnlock() 494 495 // Fail if we didn't find one. 496 if ep == nil { 497 return false 498 } 499 500 // Deliver the packet. 501 ep.handleControlPacket(n, id, typ, extra, pkt) 502 503 return true 504 } 505 506 func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { 507 var matchedEPs []*endpointsByNic 508 // Try to find a match with the id as provided. 509 if ep, ok := eps.endpoints[id]; ok { 510 matchedEPs = append(matchedEPs, ep) 511 } 512 513 // Try to find a match with the id minus the local address. 514 nid := id 515 516 nid.LocalAddress = "" 517 if ep, ok := eps.endpoints[nid]; ok { 518 matchedEPs = append(matchedEPs, ep) 519 } 520 521 // Try to find a match with the id minus the remote part. 522 nid.LocalAddress = id.LocalAddress 523 nid.RemoteAddress = "" 524 nid.RemotePort = 0 525 if ep, ok := eps.endpoints[nid]; ok { 526 matchedEPs = append(matchedEPs, ep) 527 } 528 529 // Try to find a match with only the local port. 530 nid.LocalAddress = "" 531 if ep, ok := eps.endpoints[nid]; ok { 532 matchedEPs = append(matchedEPs, ep) 533 } 534 return matchedEPs 535 } 536 537 // findTransportEndpoint find a single endpoint that most closely matches the provided id. 538 func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { 539 eps, ok := d.protocol[protocolIDs{netProto, transProto}] 540 if !ok { 541 return nil 542 } 543 // Try to find the endpoint. 544 eps.mu.RLock() 545 epsByNic := d.findEndpointLocked(eps, id) 546 // Fail if we didn't find one. 547 if epsByNic == nil { 548 eps.mu.RUnlock() 549 return nil 550 } 551 552 epsByNic.mu.RLock() 553 eps.mu.RUnlock() 554 555 mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] 556 if !ok { 557 if mpep, ok = epsByNic.endpoints[0]; !ok { 558 epsByNic.mu.RUnlock() // Don't use defer for performance reasons. 559 return nil 560 } 561 } 562 563 ep := selectEndpoint(id, mpep, epsByNic.seed) 564 epsByNic.mu.RUnlock() 565 return ep 566 } 567 568 // findEndpointLocked returns the endpoint that most closely matches the given 569 // id. 570 func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { 571 if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 { 572 return matchedEPs[0] 573 } 574 return nil 575 } 576 577 // registerRawEndpoint registers the given endpoint with the dispatcher such 578 // that packets of the appropriate protocol are delivered to it. A single 579 // packet can be sent to one or more raw endpoints along with a non-raw 580 // endpoint. 581 func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { 582 eps, ok := d.protocol[protocolIDs{netProto, transProto}] 583 if !ok { 584 return tcpip.ErrNotSupported 585 } 586 587 eps.mu.Lock() 588 defer eps.mu.Unlock() 589 eps.rawEndpoints = append(eps.rawEndpoints, ep) 590 591 return nil 592 } 593 594 // unregisterRawEndpoint unregisters the raw endpoint for the given transport 595 // protocol such that it won't receive any more packets. 596 func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { 597 eps, ok := d.protocol[protocolIDs{netProto, transProto}] 598 if !ok { 599 panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto)) 600 } 601 602 eps.mu.Lock() 603 defer eps.mu.Unlock() 604 for i, rawEP := range eps.rawEndpoints { 605 if rawEP == ep { 606 eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) 607 return 608 } 609 } 610 } 611 612 func isMulticastOrBroadcast(addr tcpip.Address) bool { 613 return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) 614 } 615 616 func isUnicast(addr tcpip.Address) bool { 617 return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) 618 }