github.com/slackhq/nebula@v1.9.0/connection_manager.go (about) 1 package nebula 2 3 import ( 4 "bytes" 5 "context" 6 "sync" 7 "time" 8 9 "github.com/rcrowley/go-metrics" 10 "github.com/sirupsen/logrus" 11 "github.com/slackhq/nebula/cert" 12 "github.com/slackhq/nebula/header" 13 "github.com/slackhq/nebula/iputil" 14 "github.com/slackhq/nebula/udp" 15 ) 16 17 type trafficDecision int 18 19 const ( 20 doNothing trafficDecision = 0 21 deleteTunnel trafficDecision = 1 // delete the hostinfo on our side, do not notify the remote 22 closeTunnel trafficDecision = 2 // delete the hostinfo and notify the remote 23 swapPrimary trafficDecision = 3 24 migrateRelays trafficDecision = 4 25 tryRehandshake trafficDecision = 5 26 sendTestPacket trafficDecision = 6 27 ) 28 29 type connectionManager struct { 30 in map[uint32]struct{} 31 inLock *sync.RWMutex 32 33 out map[uint32]struct{} 34 outLock *sync.RWMutex 35 36 // relayUsed holds which relay localIndexs are in use 37 relayUsed map[uint32]struct{} 38 relayUsedLock *sync.RWMutex 39 40 hostMap *HostMap 41 trafficTimer *LockingTimerWheel[uint32] 42 intf *Interface 43 pendingDeletion map[uint32]struct{} 44 punchy *Punchy 45 checkInterval time.Duration 46 pendingDeletionInterval time.Duration 47 metricsTxPunchy metrics.Counter 48 49 l *logrus.Logger 50 } 51 52 func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager { 53 var max time.Duration 54 if checkInterval < pendingDeletionInterval { 55 max = pendingDeletionInterval 56 } else { 57 max = checkInterval 58 } 59 60 nc := &connectionManager{ 61 hostMap: intf.hostMap, 62 in: make(map[uint32]struct{}), 63 inLock: &sync.RWMutex{}, 64 out: make(map[uint32]struct{}), 65 outLock: &sync.RWMutex{}, 66 relayUsed: make(map[uint32]struct{}), 67 relayUsedLock: &sync.RWMutex{}, 68 trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max), 69 intf: intf, 70 pendingDeletion: make(map[uint32]struct{}), 71 checkInterval: checkInterval, 72 pendingDeletionInterval: pendingDeletionInterval, 73 punchy: punchy, 74 metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), 75 l: l, 76 } 77 78 nc.Start(ctx) 79 return nc 80 } 81 82 func (n *connectionManager) In(localIndex uint32) { 83 n.inLock.RLock() 84 // If this already exists, return 85 if _, ok := n.in[localIndex]; ok { 86 n.inLock.RUnlock() 87 return 88 } 89 n.inLock.RUnlock() 90 n.inLock.Lock() 91 n.in[localIndex] = struct{}{} 92 n.inLock.Unlock() 93 } 94 95 func (n *connectionManager) Out(localIndex uint32) { 96 n.outLock.RLock() 97 // If this already exists, return 98 if _, ok := n.out[localIndex]; ok { 99 n.outLock.RUnlock() 100 return 101 } 102 n.outLock.RUnlock() 103 n.outLock.Lock() 104 n.out[localIndex] = struct{}{} 105 n.outLock.Unlock() 106 } 107 108 func (n *connectionManager) RelayUsed(localIndex uint32) { 109 n.relayUsedLock.RLock() 110 // If this already exists, return 111 if _, ok := n.relayUsed[localIndex]; ok { 112 n.relayUsedLock.RUnlock() 113 return 114 } 115 n.relayUsedLock.RUnlock() 116 n.relayUsedLock.Lock() 117 n.relayUsed[localIndex] = struct{}{} 118 n.relayUsedLock.Unlock() 119 } 120 121 // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and 122 // resets the state for this local index 123 func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) { 124 n.inLock.Lock() 125 n.outLock.Lock() 126 _, in := n.in[localIndex] 127 _, out := n.out[localIndex] 128 delete(n.in, localIndex) 129 delete(n.out, localIndex) 130 n.inLock.Unlock() 131 n.outLock.Unlock() 132 return in, out 133 } 134 135 func (n *connectionManager) AddTrafficWatch(localIndex uint32) { 136 // Use a write lock directly because it should be incredibly rare that we are ever already tracking this index 137 n.outLock.Lock() 138 if _, ok := n.out[localIndex]; ok { 139 n.outLock.Unlock() 140 return 141 } 142 n.out[localIndex] = struct{}{} 143 n.trafficTimer.Add(localIndex, n.checkInterval) 144 n.outLock.Unlock() 145 } 146 147 func (n *connectionManager) Start(ctx context.Context) { 148 go n.Run(ctx) 149 } 150 151 func (n *connectionManager) Run(ctx context.Context) { 152 //TODO: this tick should be based on the min wheel tick? Check firewall 153 clockSource := time.NewTicker(500 * time.Millisecond) 154 defer clockSource.Stop() 155 156 p := []byte("") 157 nb := make([]byte, 12, 12) 158 out := make([]byte, mtu) 159 160 for { 161 select { 162 case <-ctx.Done(): 163 return 164 165 case now := <-clockSource.C: 166 n.trafficTimer.Advance(now) 167 for { 168 localIndex, has := n.trafficTimer.Purge() 169 if !has { 170 break 171 } 172 173 n.doTrafficCheck(localIndex, p, nb, out, now) 174 } 175 } 176 } 177 } 178 179 func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { 180 decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now) 181 182 switch decision { 183 case deleteTunnel: 184 if n.hostMap.DeleteHostInfo(hostinfo) { 185 // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap 186 n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp) 187 } 188 189 case closeTunnel: 190 n.intf.sendCloseTunnel(hostinfo) 191 n.intf.closeTunnel(hostinfo) 192 193 case swapPrimary: 194 n.swapPrimary(hostinfo, primary) 195 196 case migrateRelays: 197 n.migrateRelayUsed(hostinfo, primary) 198 199 case tryRehandshake: 200 n.tryRehandshake(hostinfo) 201 202 case sendTestPacket: 203 n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) 204 } 205 206 n.resetRelayTrafficCheck(hostinfo) 207 } 208 209 func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) { 210 if hostinfo != nil { 211 n.relayUsedLock.Lock() 212 defer n.relayUsedLock.Unlock() 213 // No need to migrate any relays, delete usage info now. 214 for _, idx := range hostinfo.relayState.CopyRelayForIdxs() { 215 delete(n.relayUsed, idx) 216 } 217 } 218 } 219 220 func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) { 221 relayFor := oldhostinfo.relayState.CopyAllRelayFor() 222 223 for _, r := range relayFor { 224 existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) 225 226 var index uint32 227 var relayFrom iputil.VpnIp 228 var relayTo iputil.VpnIp 229 switch { 230 case ok && existing.State == Established: 231 // This relay already exists in newhostinfo, then do nothing. 232 continue 233 case ok && existing.State == Requested: 234 // The relay exists in a Requested state; re-send the request 235 index = existing.LocalIndex 236 switch r.Type { 237 case TerminalType: 238 relayFrom = n.intf.myVpnIp 239 relayTo = existing.PeerIp 240 case ForwardingType: 241 relayFrom = existing.PeerIp 242 relayTo = newhostinfo.vpnIp 243 default: 244 // should never happen 245 } 246 case !ok: 247 n.relayUsedLock.RLock() 248 if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed { 249 // The relay hasn't been used; don't migrate it. 250 n.relayUsedLock.RUnlock() 251 continue 252 } 253 n.relayUsedLock.RUnlock() 254 // The relay doesn't exist at all; create some relay state and send the request. 255 var err error 256 index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested) 257 if err != nil { 258 n.l.WithError(err).Error("failed to migrate relay to new hostinfo") 259 continue 260 } 261 switch r.Type { 262 case TerminalType: 263 relayFrom = n.intf.myVpnIp 264 relayTo = r.PeerIp 265 case ForwardingType: 266 relayFrom = r.PeerIp 267 relayTo = newhostinfo.vpnIp 268 default: 269 // should never happen 270 } 271 } 272 273 // Send a CreateRelayRequest to the peer. 274 req := NebulaControl{ 275 Type: NebulaControl_CreateRelayRequest, 276 InitiatorRelayIndex: index, 277 RelayFromIp: uint32(relayFrom), 278 RelayToIp: uint32(relayTo), 279 } 280 msg, err := req.Marshal() 281 if err != nil { 282 n.l.WithError(err).Error("failed to marshal Control message to migrate relay") 283 } else { 284 n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) 285 n.l.WithFields(logrus.Fields{ 286 "relayFrom": iputil.VpnIp(req.RelayFromIp), 287 "relayTo": iputil.VpnIp(req.RelayToIp), 288 "initiatorRelayIndex": req.InitiatorRelayIndex, 289 "responderRelayIndex": req.ResponderRelayIndex, 290 "vpnIp": newhostinfo.vpnIp}). 291 Info("send CreateRelayRequest") 292 } 293 } 294 } 295 296 func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { 297 n.hostMap.RLock() 298 defer n.hostMap.RUnlock() 299 300 hostinfo := n.hostMap.Indexes[localIndex] 301 if hostinfo == nil { 302 n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap") 303 delete(n.pendingDeletion, localIndex) 304 return doNothing, nil, nil 305 } 306 307 if n.isInvalidCertificate(now, hostinfo) { 308 delete(n.pendingDeletion, hostinfo.localIndexId) 309 return closeTunnel, hostinfo, nil 310 } 311 312 primary := n.hostMap.Hosts[hostinfo.vpnIp] 313 mainHostInfo := true 314 if primary != nil && primary != hostinfo { 315 mainHostInfo = false 316 } 317 318 // Check for traffic on this hostinfo 319 inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex) 320 321 // A hostinfo is determined alive if there is incoming traffic 322 if inTraffic { 323 decision := doNothing 324 if n.l.Level >= logrus.DebugLevel { 325 hostinfo.logger(n.l). 326 WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). 327 Debug("Tunnel status") 328 } 329 delete(n.pendingDeletion, hostinfo.localIndexId) 330 331 if mainHostInfo { 332 decision = tryRehandshake 333 334 } else { 335 if n.shouldSwapPrimary(hostinfo, primary) { 336 decision = swapPrimary 337 } else { 338 // migrate the relays to the primary, if in use. 339 decision = migrateRelays 340 } 341 } 342 343 n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) 344 345 if !outTraffic { 346 // Send a punch packet to keep the NAT state alive 347 n.sendPunch(hostinfo) 348 } 349 350 return decision, hostinfo, primary 351 } 352 353 if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok { 354 // We have already sent a test packet and nothing was returned, this hostinfo is dead 355 hostinfo.logger(n.l). 356 WithField("tunnelCheck", m{"state": "dead", "method": "active"}). 357 Info("Tunnel status") 358 359 delete(n.pendingDeletion, hostinfo.localIndexId) 360 return deleteTunnel, hostinfo, nil 361 } 362 363 decision := doNothing 364 if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { 365 if !outTraffic { 366 // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. 367 // Just maintain NAT state if configured to do so. 368 n.sendPunch(hostinfo) 369 n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) 370 return doNothing, nil, nil 371 372 } 373 374 if n.punchy.GetTargetEverything() { 375 // This is similar to the old punchy behavior with a slight optimization. 376 // We aren't receiving traffic but we are sending it, punch on all known 377 // ips in case we need to re-prime NAT state 378 n.sendPunch(hostinfo) 379 } 380 381 if n.l.Level >= logrus.DebugLevel { 382 hostinfo.logger(n.l). 383 WithField("tunnelCheck", m{"state": "testing", "method": "active"}). 384 Debug("Tunnel status") 385 } 386 387 // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues 388 decision = sendTestPacket 389 390 } else { 391 if n.l.Level >= logrus.DebugLevel { 392 hostinfo.logger(n.l).Debugf("Hostinfo sadness") 393 } 394 } 395 396 n.pendingDeletion[hostinfo.localIndexId] = struct{}{} 397 n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval) 398 return decision, hostinfo, nil 399 } 400 401 func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { 402 // The primary tunnel is the most recent handshake to complete locally and should work entirely fine. 403 // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. 404 // Let's sort this out. 405 406 if current.vpnIp < n.intf.myVpnIp { 407 // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. 408 // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. 409 // The remotes vpn ip is lower than mine. I will not flip. 410 return false 411 } 412 413 certState := n.intf.pki.GetCertState() 414 return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature) 415 } 416 417 func (n *connectionManager) swapPrimary(current, primary *HostInfo) { 418 n.hostMap.Lock() 419 // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. 420 if n.hostMap.Hosts[current.vpnIp] == primary { 421 n.hostMap.unlockedMakePrimary(current) 422 } 423 n.hostMap.Unlock() 424 } 425 426 // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and 427 // the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid 428 // check and return true. 429 func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { 430 remoteCert := hostinfo.GetCert() 431 if remoteCert == nil { 432 return false 433 } 434 435 valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool()) 436 if valid { 437 return false 438 } 439 440 if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed { 441 // Block listed certificates should always be disconnected 442 return false 443 } 444 445 fingerprint, _ := remoteCert.Sha256Sum() 446 hostinfo.logger(n.l).WithError(err). 447 WithField("fingerprint", fingerprint). 448 Info("Remote certificate is no longer valid, tearing down the tunnel") 449 450 return true 451 } 452 453 func (n *connectionManager) sendPunch(hostinfo *HostInfo) { 454 if !n.punchy.GetPunch() { 455 // Punching is disabled 456 return 457 } 458 459 if n.punchy.GetTargetEverything() { 460 hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) { 461 n.metricsTxPunchy.Inc(1) 462 n.intf.outside.WriteTo([]byte{1}, addr) 463 }) 464 465 } else if hostinfo.remote != nil { 466 n.metricsTxPunchy.Inc(1) 467 n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) 468 } 469 } 470 471 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { 472 certState := n.intf.pki.GetCertState() 473 if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) { 474 return 475 } 476 477 n.l.WithField("vpnIp", hostinfo.vpnIp). 478 WithField("reason", "local certificate is not current"). 479 Info("Re-handshaking with remote") 480 481 n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil) 482 }