// Source: github.com/FlowerWrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/stack/transport_demuxer.go

// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package stack

import (
	"fmt"
	"math/rand"
	"sync"

	"github.com/FlowerWrong/netstack/tcpip"
	"github.com/FlowerWrong/netstack/tcpip/buffer"
	"github.com/FlowerWrong/netstack/tcpip/hash/jenkins"
	"github.com/FlowerWrong/netstack/tcpip/header"
)

// protocolIDs identifies a (network protocol, transport protocol) pair. It is
// the key transportDemuxer uses to look up the transportEndpoints table for
// that pair.
type protocolIDs struct {
	network   tcpip.NetworkProtocolNumber
	transport tcpip.TransportProtocolNumber
}

// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
type transportEndpoints struct {
	// mu protects all fields of the transportEndpoints.
	mu sync.RWMutex
	// endpoints maps an endpoint ID to the endpoints registered under it,
	// grouped by the NIC they are bound to.
	endpoints map[TransportEndpointID]*endpointsByNic
	// rawEndpoints contains endpoints for raw sockets, which receive all
	// traffic of a given protocol regardless of port.
	rawEndpoints []RawTransportEndpoint
}

// endpointsByNic holds the endpoints registered for a single
// TransportEndpointID, keyed by the NIC they are bound to. The NIC ID 0 entry
// holds endpoints that are not bound to a specific device; packet delivery
// falls back to it when no NIC-specific entry exists.
type endpointsByNic struct {
	// mu protects endpoints.
	mu        sync.RWMutex
	endpoints map[tcpip.NICID]*multiPortEndpoint
	// seed is a random secret for a jenkins hash.
	seed uint32
}

// handlePacket is called by the stack when new packets arrive to this
// transport endpoint. It delivers the packet to the endpoints bound to the
// route's NIC, falling back to those bound to no device (NIC 0): broadcast and
// multicast datagrams go to every such endpoint, anything else to a single
// endpoint chosen by hashing id.
53 func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { 54 epsByNic.mu.RLock() 55 56 mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] 57 if !ok { 58 if mpep, ok = epsByNic.endpoints[0]; !ok { 59 epsByNic.mu.RUnlock() // Don't use defer for performance reasons. 60 return 61 } 62 } 63 64 // If this is a broadcast or multicast datagram, deliver the datagram to all 65 // endpoints bound to the right device. 66 if isMulticastOrBroadcast(id.LocalAddress) { 67 mpep.handlePacketAll(r, id, vv) 68 epsByNic.mu.RUnlock() // Don't use defer for performance reasons. 69 return 70 } 71 72 // multiPortEndpoints are guaranteed to have at least one element. 73 selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv) 74 epsByNic.mu.RUnlock() // Don't use defer for performance reasons. 75 } 76 77 // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. 78 func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { 79 epsByNic.mu.RLock() 80 defer epsByNic.mu.RUnlock() 81 82 mpep, ok := epsByNic.endpoints[n.ID()] 83 if !ok { 84 mpep, ok = epsByNic.endpoints[0] 85 } 86 if !ok { 87 return 88 } 89 90 // TODO(eyalsoha): Why don't we look at id to see if this packet needs to 91 // broadcast like we are doing with handlePacket above? 92 93 // multiPortEndpoints are guaranteed to have at least one element. 94 selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv) 95 } 96 97 // registerEndpoint returns true if it succeeds. It fails and returns 98 // false if ep already has an element with the same key. 99 func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { 100 epsByNic.mu.Lock() 101 defer epsByNic.mu.Unlock() 102 103 if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { 104 // There was already a bind. 
105 return multiPortEp.singleRegisterEndpoint(t, reusePort) 106 } 107 108 // This is a new binding. 109 multiPortEp := &multiPortEndpoint{} 110 multiPortEp.endpointsMap = make(map[TransportEndpoint]int) 111 multiPortEp.reuse = reusePort 112 epsByNic.endpoints[bindToDevice] = multiPortEp 113 return multiPortEp.singleRegisterEndpoint(t, reusePort) 114 } 115 116 // unregisterEndpoint returns true if endpointsByNic has to be unregistered. 117 func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { 118 epsByNic.mu.Lock() 119 defer epsByNic.mu.Unlock() 120 multiPortEp, ok := epsByNic.endpoints[bindToDevice] 121 if !ok { 122 return false 123 } 124 if multiPortEp.unregisterEndpoint(t) { 125 delete(epsByNic.endpoints, bindToDevice) 126 } 127 return len(epsByNic.endpoints) == 0 128 } 129 130 // unregisterEndpoint unregisters the endpoint with the given id such that it 131 // won't receive any more packets. 132 func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { 133 eps.mu.Lock() 134 defer eps.mu.Unlock() 135 epsByNic, ok := eps.endpoints[id] 136 if !ok { 137 return 138 } 139 if !epsByNic.unregisterEndpoint(bindToDevice, ep) { 140 return 141 } 142 delete(eps.endpoints, id) 143 } 144 145 // transportDemuxer demultiplexes packets targeted at a transport endpoint 146 // (i.e., after they've been parsed by the network layer). It does two levels 147 // of demultiplexing: first based on the network and transport protocols, then 148 // based on endpoints IDs. It should only be instantiated via 149 // newTransportDemuxer. 150 type transportDemuxer struct { 151 // protocol is immutable. 152 protocol map[protocolIDs]*transportEndpoints 153 } 154 155 func newTransportDemuxer(stack *Stack) *transportDemuxer { 156 d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} 157 158 // Add each network and transport pair to the demuxer. 
159 for netProto := range stack.networkProtocols { 160 for proto := range stack.transportProtocols { 161 d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ 162 endpoints: make(map[TransportEndpointID]*endpointsByNic), 163 } 164 } 165 } 166 167 return d 168 } 169 170 // registerEndpoint registers the given endpoint with the dispatcher such that 171 // packets that match the endpoint ID are delivered to it. 172 func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { 173 for i, n := range netProtos { 174 if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { 175 d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) 176 return err 177 } 178 } 179 180 return nil 181 } 182 183 // multiPortEndpoint is a container for TransportEndpoints which are bound to 184 // the same pair of address and port. endpointsArr always has at least one 185 // element. 186 type multiPortEndpoint struct { 187 mu sync.RWMutex 188 endpointsArr []TransportEndpoint 189 endpointsMap map[TransportEndpoint]int 190 // reuse indicates if more than one endpoint is allowed. 191 reuse bool 192 } 193 194 // reciprocalScale scales a value into range [0, n). 195 // 196 // This is similar to val % n, but faster. 197 // See http://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ 198 func reciprocalScale(val, n uint32) uint32 { 199 return uint32((uint64(val) * uint64(n)) >> 32) 200 } 201 202 // selectEndpoint calculates a hash of destination and source addresses and 203 // ports then uses it to select a socket. In this case, all packets from one 204 // address will be sent to same endpoint. 
205 func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { 206 if len(mpep.endpointsArr) == 1 { 207 return mpep.endpointsArr[0] 208 } 209 210 payload := []byte{ 211 byte(id.LocalPort), 212 byte(id.LocalPort >> 8), 213 byte(id.RemotePort), 214 byte(id.RemotePort >> 8), 215 } 216 217 h := jenkins.Sum32(seed) 218 h.Write(payload) 219 h.Write([]byte(id.LocalAddress)) 220 h.Write([]byte(id.RemoteAddress)) 221 hash := h.Sum32() 222 223 idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) 224 return mpep.endpointsArr[idx] 225 } 226 227 func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { 228 ep.mu.RLock() 229 for i, endpoint := range ep.endpointsArr { 230 // HandlePacket modifies vv, so each endpoint needs its own copy except for 231 // the final one. 232 if i == len(ep.endpointsArr)-1 { 233 endpoint.HandlePacket(r, id, vv) 234 break 235 } 236 vvCopy := buffer.NewView(vv.Size()) 237 copy(vvCopy, vv.ToView()) 238 endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) 239 } 240 ep.mu.RUnlock() // Don't use defer for performance reasons. 241 } 242 243 // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint 244 // list. The list might be empty already. 245 func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { 246 ep.mu.Lock() 247 defer ep.mu.Unlock() 248 249 if len(ep.endpointsArr) > 0 { 250 // If it was previously bound, we need to check if we can bind again. 251 if !ep.reuse || !reusePort { 252 return tcpip.ErrPortInUse 253 } 254 } 255 256 // A new endpoint is added into endpointsArr and its index there is saved in 257 // endpointsMap. This will allow us to remove endpoint from the array fast. 258 ep.endpointsMap[t] = len(ep.endpointsArr) 259 ep.endpointsArr = append(ep.endpointsArr, t) 260 return nil 261 } 262 263 // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. 
264 func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { 265 ep.mu.Lock() 266 defer ep.mu.Unlock() 267 268 idx, ok := ep.endpointsMap[t] 269 if !ok { 270 return false 271 } 272 delete(ep.endpointsMap, t) 273 l := len(ep.endpointsArr) 274 if l > 1 { 275 // The last endpoint in endpointsArr is moved instead of the deleted one. 276 lastEp := ep.endpointsArr[l-1] 277 ep.endpointsArr[idx] = lastEp 278 ep.endpointsMap[lastEp] = idx 279 ep.endpointsArr = ep.endpointsArr[0 : l-1] 280 return false 281 } 282 return true 283 } 284 285 func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { 286 if id.RemotePort != 0 { 287 // TODO(eyalsoha): Why? 288 reusePort = false 289 } 290 291 eps, ok := d.protocol[protocolIDs{netProto, protocol}] 292 if !ok { 293 return tcpip.ErrUnknownProtocol 294 } 295 296 eps.mu.Lock() 297 defer eps.mu.Unlock() 298 299 if epsByNic, ok := eps.endpoints[id]; ok { 300 // There was already a binding. 301 return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) 302 } 303 304 // This is a new binding. 305 epsByNic := &endpointsByNic{ 306 endpoints: make(map[tcpip.NICID]*multiPortEndpoint), 307 seed: rand.Uint32(), 308 } 309 eps.endpoints[id] = epsByNic 310 311 return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) 312 } 313 314 // unregisterEndpoint unregisters the endpoint with the given id such that it 315 // won't receive any more packets. 
316 func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { 317 for _, n := range netProtos { 318 if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { 319 eps.unregisterEndpoint(id, ep, bindToDevice) 320 } 321 } 322 } 323 324 var loopbackSubnet = func() tcpip.Subnet { 325 sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00") 326 if err != nil { 327 panic(err) 328 } 329 return sn 330 }() 331 332 // deliverPacket attempts to find one or more matching transport endpoints, and 333 // then, if matches are found, delivers the packet to them. Returns true if it 334 // found one or more endpoints, false otherwise. 335 func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool { 336 eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] 337 if !ok { 338 return false 339 } 340 341 eps.mu.RLock() 342 343 // Determine which transport endpoint or endpoints to deliver this packet to. 344 // If the packet is a broadcast or multicast, then find all matching 345 // transport endpoints. 346 var destEps []*endpointsByNic 347 if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { 348 destEps = d.findAllEndpointsLocked(eps, vv, id) 349 } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { 350 destEps = append(destEps, ep) 351 } 352 353 eps.mu.RUnlock() 354 355 // Fail if we didn't find at least one matching transport endpoint. 356 if len(destEps) == 0 { 357 // UDP packet could not be delivered to an unknown destination port. 358 if protocol == header.UDPProtocolNumber { 359 r.Stats().UDP.UnknownPortErrors.Increment() 360 } 361 return false 362 } 363 364 // Deliver the packet. 
365 for _, ep := range destEps { 366 ep.handlePacket(r, id, vv) 367 } 368 369 return true 370 } 371 372 // deliverRawPacket attempts to deliver the given packet and returns whether it 373 // was delivered successfully. 374 func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool { 375 eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] 376 if !ok { 377 return false 378 } 379 380 // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via 381 // raw endpoint first. If there are multiple raw endpoints, they all 382 // receive the packet. 383 foundRaw := false 384 eps.mu.RLock() 385 for _, rawEP := range eps.rawEndpoints { 386 // Each endpoint gets its own copy of the packet for the sake 387 // of save/restore. 388 rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) 389 foundRaw = true 390 } 391 eps.mu.RUnlock() 392 393 return foundRaw 394 } 395 396 // deliverControlPacket attempts to deliver the given control packet. Returns 397 // true if it found an endpoint, false otherwise. 398 func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { 399 eps, ok := d.protocol[protocolIDs{net, trans}] 400 if !ok { 401 return false 402 } 403 404 // Try to find the endpoint. 405 eps.mu.RLock() 406 ep := d.findEndpointLocked(eps, vv, id) 407 eps.mu.RUnlock() 408 409 // Fail if we didn't find one. 410 if ep == nil { 411 return false 412 } 413 414 // Deliver the packet. 415 ep.handleControlPacket(n, id, typ, extra, vv) 416 417 return true 418 } 419 420 func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic { 421 var matchedEPs []*endpointsByNic 422 // Try to find a match with the id as provided. 
423 if ep, ok := eps.endpoints[id]; ok { 424 matchedEPs = append(matchedEPs, ep) 425 } 426 427 // Try to find a match with the id minus the local address. 428 nid := id 429 430 nid.LocalAddress = "" 431 if ep, ok := eps.endpoints[nid]; ok { 432 matchedEPs = append(matchedEPs, ep) 433 } 434 435 // Try to find a match with the id minus the remote part. 436 nid.LocalAddress = id.LocalAddress 437 nid.RemoteAddress = "" 438 nid.RemotePort = 0 439 if ep, ok := eps.endpoints[nid]; ok { 440 matchedEPs = append(matchedEPs, ep) 441 } 442 443 // Try to find a match with only the local port. 444 nid.LocalAddress = "" 445 if ep, ok := eps.endpoints[nid]; ok { 446 matchedEPs = append(matchedEPs, ep) 447 } 448 449 return matchedEPs 450 } 451 452 // findEndpointLocked returns the endpoint that most closely matches the given 453 // id. 454 func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { 455 if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 { 456 return matchedEPs[0] 457 } 458 return nil 459 } 460 461 // registerRawEndpoint registers the given endpoint with the dispatcher such 462 // that packets of the appropriate protocol are delivered to it. A single 463 // packet can be sent to one or more raw endpoints along with a non-raw 464 // endpoint. 465 func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { 466 eps, ok := d.protocol[protocolIDs{netProto, transProto}] 467 if !ok { 468 return nil 469 } 470 471 eps.mu.Lock() 472 defer eps.mu.Unlock() 473 eps.rawEndpoints = append(eps.rawEndpoints, ep) 474 475 return nil 476 } 477 478 // unregisterRawEndpoint unregisters the raw endpoint for the given transport 479 // protocol such that it won't receive any more packets. 
480 func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { 481 eps, ok := d.protocol[protocolIDs{netProto, transProto}] 482 if !ok { 483 panic(fmt.Errorf("tried to unregister endpoint with unsupported network and transport protocol pair: %d, %d", netProto, transProto)) 484 } 485 486 eps.mu.Lock() 487 defer eps.mu.Unlock() 488 for i, rawEP := range eps.rawEndpoints { 489 if rawEP == ep { 490 eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) 491 return 492 } 493 } 494 } 495 496 func isMulticastOrBroadcast(addr tcpip.Address) bool { 497 return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) 498 }