github.com/Heebron/moby@v0.0.0-20221111184709-6eab4f55faf7/libnetwork/drivers/overlay/encryption.go (about) 1 //go:build linux 2 // +build linux 3 4 package overlay 5 6 import ( 7 "bytes" 8 "encoding/binary" 9 "encoding/hex" 10 "fmt" 11 "hash/fnv" 12 "net" 13 "strconv" 14 "sync" 15 "syscall" 16 17 "github.com/docker/docker/libnetwork/drivers/overlay/overlayutils" 18 "github.com/docker/docker/libnetwork/iptables" 19 "github.com/docker/docker/libnetwork/ns" 20 "github.com/docker/docker/libnetwork/types" 21 "github.com/sirupsen/logrus" 22 "github.com/vishvananda/netlink" 23 ) 24 25 const ( 26 r = 0xD0C4E3 27 pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8) 28 ) 29 30 const ( 31 forward = iota + 1 32 reverse 33 bidir 34 ) 35 36 var spMark = netlink.XfrmMark{Value: uint32(r), Mask: 0xffffffff} 37 38 type key struct { 39 value []byte 40 tag uint32 41 } 42 43 func (k *key) String() string { 44 if k != nil { 45 return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag) 46 } 47 return "" 48 } 49 50 type spi struct { 51 forward int 52 reverse int 53 } 54 55 func (s *spi) String() string { 56 return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse)) 57 } 58 59 type encrMap struct { 60 nodes map[string][]*spi 61 sync.Mutex 62 } 63 64 func (e *encrMap) String() string { 65 e.Lock() 66 defer e.Unlock() 67 b := new(bytes.Buffer) 68 for k, v := range e.nodes { 69 b.WriteString("\n") 70 b.WriteString(k) 71 b.WriteString(":") 72 b.WriteString("[") 73 for _, s := range v { 74 b.WriteString(s.String()) 75 b.WriteString(",") 76 } 77 b.WriteString("]") 78 } 79 return b.String() 80 } 81 82 func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error { 83 logrus.Debugf("checkEncryption(%.7s, %v, %d, %t)", nid, rIP, vxlanID, isLocal) 84 85 n := d.network(nid) 86 if n == nil || !n.secure { 87 return nil 88 } 89 90 if len(d.keys) == 0 { 91 return types.ForbiddenErrorf("encryption key is not present") 92 } 93 94 lIP := net.ParseIP(d.bindAddress) 95 aIP := net.ParseIP(d.advertiseAddress) 96 nodes := map[string]net.IP{} 97 98 switch { 99 case isLocal: 100 if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool { 101 if !aIP.Equal(pEntry.vtep) { 102 nodes[pEntry.vtep.String()] = pEntry.vtep 103 } 104 return false 105 }); err != nil { 106 logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %.5s: %v", nid, err) 107 } 108 default: 109 if len(d.network(nid).endpoints) > 0 { 110 nodes[rIP.String()] = rIP 111 } 112 } 113 114 logrus.Debugf("List of nodes: %s", nodes) 115 116 if add { 117 for _, rIP := range nodes { 118 if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil { 119 logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err) 120 } 121 } 122 } else { 123 if len(nodes) == 0 { 124 if err := removeEncryption(lIP, rIP, d.secMap); err != nil { 125 logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err) 126 } 127 } 128 } 129 130 return nil 131 } 132 133 func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error { 134 logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP) 135 rIPs := remoteIP.String() 136 137 indices := make([]*spi, 0, len(keys)) 138 139 err := programMangle(vni, true) 140 if err != nil { 141 logrus.Warn(err) 142 } 143 144 err = programInput(vni, true) 145 if err != nil { 146 logrus.Warn(err) 147 } 148 149 for i, k := range keys { 150 spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)} 151 dir := reverse 152 if i == 0 { 153 dir = bidir 154 } 155 fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true) 156 if err != nil { 157 logrus.Warn(err) 158 } 159 indices = append(indices, spis) 160 if i != 0 { 161 continue 162 } 163 err = programSP(fSA, rSA, true) 164 if err != nil { 165 logrus.Warn(err) 166 } 167 } 168 169 em.Lock() 170 em.nodes[rIPs] = indices 171 em.Unlock() 172 173 return nil 174 } 175 176 func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error { 177 em.Lock() 178 indices, ok := em.nodes[remoteIP.String()] 179 em.Unlock() 180 if !ok { 181 return nil 182 } 183 for i, idxs := range indices { 184 dir := reverse 185 if i == 0 { 186 dir = bidir 187 } 188 fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false) 189 if err != nil { 190 logrus.Warn(err) 191 } 192 if i != 0 { 193 continue 194 } 195 err = programSP(fSA, rSA, false) 196 if err != nil { 197 logrus.Warn(err) 198 } 199 } 200 return nil 201 } 202 203 func programMangle(vni uint32, add bool) (err error) { 204 var ( 205 p = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10) 206 c = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8) 207 m = strconv.FormatUint(uint64(r), 10) 208 chain = "OUTPUT" 209 rule = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m} 210 a = "-A" 211 action = "install" 212 ) 213 214 // TODO IPv6 support 215 iptable := iptables.GetIptable(iptables.IPv4) 216 217 if add == iptable.Exists(iptables.Mangle, chain, rule...) { 218 return 219 } 220 221 if !add { 222 a = "-D" 223 action = "remove" 224 } 225 226 if err = iptable.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil { 227 logrus.Warnf("could not %s mangle rule: %v", action, err) 228 } 229 230 return 231 } 232 233 func programInput(vni uint32, add bool) (err error) { 234 var ( 235 port = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10) 236 vniMatch = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8) 237 plainVxlan = []string{"-p", "udp", "--dport", port, "-m", "u32", "--u32", vniMatch, "-j"} 238 ipsecVxlan = append([]string{"-m", "policy", "--dir", "in", "--pol", "ipsec"}, plainVxlan...) 239 block = append(plainVxlan, "DROP") 240 accept = append(ipsecVxlan, "ACCEPT") 241 chain = "INPUT" 242 action = iptables.Append 243 msg = "add" 244 ) 245 246 // TODO IPv6 support 247 iptable := iptables.GetIptable(iptables.IPv4) 248 249 if !add { 250 action = iptables.Delete 251 msg = "remove" 252 } 253 254 if err := iptable.ProgramRule(iptables.Filter, chain, action, accept); err != nil { 255 logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err) 256 } 257 258 if err := iptable.ProgramRule(iptables.Filter, chain, action, block); err != nil { 259 logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err) 260 } 261 262 return 263 } 264 265 func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) { 266 var ( 267 action = "Removing" 268 xfrmProgram = ns.NlHandle().XfrmStateDel 269 ) 270 271 if add { 272 action = "Adding" 273 xfrmProgram = ns.NlHandle().XfrmStateAdd 274 } 275 276 if dir&reverse > 0 { 277 rSA = &netlink.XfrmState{ 278 Src: remoteIP, 279 Dst: localIP, 280 Proto: netlink.XFRM_PROTO_ESP, 281 Spi: spi.reverse, 282 Mode: netlink.XFRM_MODE_TRANSPORT, 283 Reqid: r, 284 } 285 if add { 286 rSA.Aead = buildAeadAlgo(k, spi.reverse) 287 } 288 289 exists, err := saExists(rSA) 290 if err != nil { 291 exists = !add 292 } 293 294 if add != exists { 295 logrus.Debugf("%s: rSA{%s}", action, rSA) 296 if err := xfrmProgram(rSA); err != nil { 297 logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err) 298 } 299 } 300 } 301 302 if dir&forward > 0 { 303 fSA = &netlink.XfrmState{ 304 Src: localIP, 305 Dst: remoteIP, 306 Proto: netlink.XFRM_PROTO_ESP, 307 Spi: spi.forward, 308 Mode: netlink.XFRM_MODE_TRANSPORT, 309 Reqid: r, 310 } 311 if add { 312 fSA.Aead = buildAeadAlgo(k, spi.forward) 313 } 314 315 exists, err := saExists(fSA) 316 if err != nil { 317 exists = !add 318 } 319 320 if add != exists { 321 logrus.Debugf("%s fSA{%s}", action, fSA) 322 if err := xfrmProgram(fSA); err != nil { 323 logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err) 324 } 325 } 326 } 327 328 return 329 } 330 331 func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error { 332 action := "Removing" 333 xfrmProgram := ns.NlHandle().XfrmPolicyDel 334 if add { 335 action = "Adding" 336 xfrmProgram = ns.NlHandle().XfrmPolicyAdd 337 } 338 339 // Create a congruent cidr 340 s := types.GetMinimalIP(fSA.Src) 341 d := types.GetMinimalIP(fSA.Dst) 342 fullMask := net.CIDRMask(8*len(s), 8*len(s)) 343 344 fPol := &netlink.XfrmPolicy{ 345 Src: &net.IPNet{IP: s, Mask: fullMask}, 346 Dst: &net.IPNet{IP: d, Mask: fullMask}, 347 Dir: netlink.XFRM_DIR_OUT, 348 Proto: 17, 349 DstPort: 4789, 350 Mark: &spMark, 351 Tmpls: []netlink.XfrmPolicyTmpl{ 352 { 353 Src: fSA.Src, 354 Dst: fSA.Dst, 355 Proto: netlink.XFRM_PROTO_ESP, 356 Mode: netlink.XFRM_MODE_TRANSPORT, 357 Spi: fSA.Spi, 358 Reqid: r, 359 }, 360 }, 361 } 362 363 exists, err := spExists(fPol) 364 if err != nil { 365 exists = !add 366 } 367 368 if add != exists { 369 logrus.Debugf("%s fSP{%s}", action, fPol) 370 if err := xfrmProgram(fPol); err != nil { 371 logrus.Warnf("%s fSP{%s}: %v", action, fPol, err) 372 } 373 } 374 375 return nil 376 } 377 378 func saExists(sa *netlink.XfrmState) (bool, error) { 379 _, err := ns.NlHandle().XfrmStateGet(sa) 380 switch err { 381 case nil: 382 return true, nil 383 case syscall.ESRCH: 384 return false, nil 385 default: 386 err = fmt.Errorf("Error while checking for SA existence: %v", err) 387 logrus.Warn(err) 388 return false, err 389 } 390 } 391 392 func spExists(sp *netlink.XfrmPolicy) (bool, error) { 393 _, err := ns.NlHandle().XfrmPolicyGet(sp) 394 switch err { 395 case nil: 396 return true, nil 397 case syscall.ENOENT: 398 return false, nil 399 default: 400 err = fmt.Errorf("Error while checking for SP existence: %v", err) 401 logrus.Warn(err) 402 return false, err 403 } 404 } 405 406 func buildSPI(src, dst net.IP, st uint32) int { 407 b := make([]byte, 4) 408 binary.BigEndian.PutUint32(b, st) 409 h := fnv.New32a() 410 h.Write(src) 411 h.Write(b) 412 h.Write(dst) 413 return int(binary.BigEndian.Uint32(h.Sum(nil))) 414 } 415 416 func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo { 417 salt := make([]byte, 4) 418 binary.BigEndian.PutUint32(salt, uint32(s)) 419 return &netlink.XfrmStateAlgo{ 420 Name: "rfc4106(gcm(aes))", 421 Key: append(k.value, salt...), 422 ICVLen: 64, 423 } 424 } 425 426 func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error { 427 d.secMap.Lock() 428 for node, indices := range d.secMap.nodes { 429 idxs, stop := f(node, indices) 430 if idxs != nil { 431 d.secMap.nodes[node] = idxs 432 } 433 if stop { 434 break 435 } 436 } 437 d.secMap.Unlock() 438 return nil 439 } 440 441 func (d *driver) setKeys(keys []*key) error { 442 // Remove any stale policy, state 443 clearEncryptionStates() 444 // Accept the encryption keys and clear any stale encryption map 445 d.Lock() 446 d.keys = keys 447 d.secMap = &encrMap{nodes: map[string][]*spi{}} 448 d.Unlock() 449 logrus.Debugf("Initial encryption keys: %v", keys) 450 return nil 451 } 452 453 // updateKeys allows to add a new key and/or change the primary key and/or prune an existing key 454 // The primary key is the key used in transmission and will go in first position in the list. 455 func (d *driver) updateKeys(newKey, primary, pruneKey *key) error { 456 logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey) 457 458 logrus.Debugf("Current: %v", d.keys) 459 460 var ( 461 newIdx = -1 462 priIdx = -1 463 delIdx = -1 464 lIP = net.ParseIP(d.bindAddress) 465 aIP = net.ParseIP(d.advertiseAddress) 466 ) 467 468 d.Lock() 469 defer d.Unlock() 470 471 // add new 472 if newKey != nil { 473 d.keys = append(d.keys, newKey) 474 newIdx += len(d.keys) 475 } 476 for i, k := range d.keys { 477 if primary != nil && k.tag == primary.tag { 478 priIdx = i 479 } 480 if pruneKey != nil && k.tag == pruneKey.tag { 481 delIdx = i 482 } 483 } 484 485 if (newKey != nil && newIdx == -1) || 486 (primary != nil && priIdx == -1) || 487 (pruneKey != nil && delIdx == -1) { 488 return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+ 489 "(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx) 490 } 491 492 if priIdx != -1 && priIdx == delIdx { 493 return types.BadRequestErrorf("attempting to both make a key (index %d) primary and delete it", priIdx) 494 } 495 496 d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) { 497 rIP := net.ParseIP(rIPs) 498 return updateNodeKey(lIP, aIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false 499 }) 500 501 // swap primary 502 if priIdx != -1 { 503 d.keys[0], d.keys[priIdx] = d.keys[priIdx], d.keys[0] 504 } 505 // prune 506 if delIdx != -1 { 507 if delIdx == 0 { 508 delIdx = priIdx 509 } 510 d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...) 511 } 512 513 logrus.Debugf("Updated: %v", d.keys) 514 515 return nil 516 } 517 518 /******************************************************** 519 * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1 520 * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1 521 * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2 522 *********************************************************/ 523 524 // Spis and keys are sorted in such away the one in position 0 is the primary 525 func updateNodeKey(lIP, aIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi { 526 logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx) 527 528 spis := idxs 529 logrus.Debugf("Current: %v", spis) 530 531 // add new 532 if newIdx != -1 { 533 spis = append(spis, &spi{ 534 forward: buildSPI(aIP, rIP, curKeys[newIdx].tag), 535 reverse: buildSPI(rIP, aIP, curKeys[newIdx].tag), 536 }) 537 } 538 539 if delIdx != -1 { 540 // -rSA0 541 programSA(lIP, rIP, spis[delIdx], nil, reverse, false) 542 } 543 544 if newIdx > -1 { 545 // +rSA2 546 programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true) 547 } 548 549 if priIdx > 0 { 550 // +fSA2 551 fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true) 552 553 // +fSP2, -fSP1 554 s := types.GetMinimalIP(fSA2.Src) 555 d := types.GetMinimalIP(fSA2.Dst) 556 fullMask := net.CIDRMask(8*len(s), 8*len(s)) 557 558 fSP1 := &netlink.XfrmPolicy{ 559 Src: &net.IPNet{IP: s, Mask: fullMask}, 560 Dst: &net.IPNet{IP: d, Mask: fullMask}, 561 Dir: netlink.XFRM_DIR_OUT, 562 Proto: 17, 563 DstPort: 4789, 564 Mark: &spMark, 565 Tmpls: []netlink.XfrmPolicyTmpl{ 566 { 567 Src: fSA2.Src, 568 Dst: fSA2.Dst, 569 Proto: netlink.XFRM_PROTO_ESP, 570 Mode: netlink.XFRM_MODE_TRANSPORT, 571 Spi: fSA2.Spi, 572 Reqid: r, 573 }, 574 }, 575 } 576 logrus.Debugf("Updating fSP{%s}", fSP1) 577 if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil { 578 logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err) 579 } 580 581 // -fSA1 582 programSA(lIP, rIP, spis[0], nil, forward, false) 583 } 584 585 // swap 586 if priIdx > 0 { 587 swp := spis[0] 588 spis[0] = spis[priIdx] 589 spis[priIdx] = swp 590 } 591 // prune 592 if delIdx != -1 { 593 if delIdx == 0 { 594 delIdx = priIdx 595 } 596 spis = append(spis[:delIdx], spis[delIdx+1:]...) 597 } 598 599 logrus.Debugf("Updated: %v", spis) 600 601 return spis 602 } 603 604 func (n *network) maxMTU() int { 605 mtu := 1500 606 if n.mtu != 0 { 607 mtu = n.mtu 608 } 609 mtu -= vxlanEncap 610 if n.secure { 611 // In case of encryption account for the 612 // esp packet expansion and padding 613 mtu -= pktExpansion 614 mtu -= (mtu % 4) 615 } 616 return mtu 617 } 618 619 func clearEncryptionStates() { 620 nlh := ns.NlHandle() 621 spList, err := nlh.XfrmPolicyList(netlink.FAMILY_ALL) 622 if err != nil { 623 logrus.Warnf("Failed to retrieve SP list for cleanup: %v", err) 624 } 625 saList, err := nlh.XfrmStateList(netlink.FAMILY_ALL) 626 if err != nil { 627 logrus.Warnf("Failed to retrieve SA list for cleanup: %v", err) 628 } 629 for _, sp := range spList { 630 sp := sp 631 if sp.Mark != nil && sp.Mark.Value == spMark.Value { 632 if err := nlh.XfrmPolicyDel(&sp); err != nil { 633 logrus.Warnf("Failed to delete stale SP %s: %v", sp, err) 634 continue 635 } 636 logrus.Debugf("Removed stale SP: %s", sp) 637 } 638 } 639 for _, sa := range saList { 640 sa := sa 641 if sa.Reqid == r { 642 if err := nlh.XfrmStateDel(&sa); err != nil { 643 logrus.Warnf("Failed to delete stale SA %s: %v", sa, err) 644 continue 645 } 646 logrus.Debugf("Removed stale SA: %s", sa) 647 } 648 } 649 }