github.com/docker/engine@v22.0.0-20211208180946-d456264580cf+incompatible/libnetwork/drivers/overlay/encryption.go (about) 1 //go:build linux 2 // +build linux 3 4 package overlay 5 6 import ( 7 "bytes" 8 "encoding/binary" 9 "encoding/hex" 10 "fmt" 11 "hash/fnv" 12 "net" 13 "sync" 14 "syscall" 15 16 "strconv" 17 18 "github.com/docker/docker/libnetwork/drivers/overlay/overlayutils" 19 "github.com/docker/docker/libnetwork/iptables" 20 "github.com/docker/docker/libnetwork/ns" 21 "github.com/docker/docker/libnetwork/types" 22 "github.com/sirupsen/logrus" 23 "github.com/vishvananda/netlink" 24 ) 25 26 const ( 27 r = 0xD0C4E3 28 pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8) 29 ) 30 31 const ( 32 forward = iota + 1 33 reverse 34 bidir 35 ) 36 37 var spMark = netlink.XfrmMark{Value: uint32(r), Mask: 0xffffffff} 38 39 type key struct { 40 value []byte 41 tag uint32 42 } 43 44 func (k *key) String() string { 45 if k != nil { 46 return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag) 47 } 48 return "" 49 } 50 51 type spi struct { 52 forward int 53 reverse int 54 } 55 56 func (s *spi) String() string { 57 return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse)) 58 } 59 60 type encrMap struct { 61 nodes map[string][]*spi 62 sync.Mutex 63 } 64 65 func (e *encrMap) String() string { 66 e.Lock() 67 defer e.Unlock() 68 b := new(bytes.Buffer) 69 for k, v := range e.nodes { 70 b.WriteString("\n") 71 b.WriteString(k) 72 b.WriteString(":") 73 b.WriteString("[") 74 for _, s := range v { 75 b.WriteString(s.String()) 76 b.WriteString(",") 77 } 78 b.WriteString("]") 79 80 } 81 return b.String() 82 } 83 84 func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error { 85 logrus.Debugf("checkEncryption(%.7s, %v, %d, %t)", nid, rIP, vxlanID, isLocal) 86 87 n := d.network(nid) 88 if n == nil || !n.secure { 89 return nil 90 } 91 92 if len(d.keys) == 0 { 93 return types.ForbiddenErrorf("encryption key is not present") 94 } 95 96 lIP := net.ParseIP(d.bindAddress) 97 aIP := net.ParseIP(d.advertiseAddress) 98 nodes := map[string]net.IP{} 99 100 switch { 101 case isLocal: 102 if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool { 103 if !aIP.Equal(pEntry.vtep) { 104 nodes[pEntry.vtep.String()] = pEntry.vtep 105 } 106 return false 107 }); err != nil { 108 logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %.5s: %v", nid, err) 109 } 110 default: 111 if len(d.network(nid).endpoints) > 0 { 112 nodes[rIP.String()] = rIP 113 } 114 } 115 116 logrus.Debugf("List of nodes: %s", nodes) 117 118 if add { 119 for _, rIP := range nodes { 120 if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil { 121 logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err) 122 } 123 } 124 } else { 125 if len(nodes) == 0 { 126 if err := removeEncryption(lIP, rIP, d.secMap); err != nil { 127 logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err) 128 } 129 } 130 } 131 132 return nil 133 } 134 135 func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error { 136 logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP) 137 rIPs := remoteIP.String() 138 139 indices := make([]*spi, 0, len(keys)) 140 141 err := programMangle(vni, true) 142 if err != nil { 143 logrus.Warn(err) 144 } 145 146 err = programInput(vni, true) 147 if err != nil { 148 logrus.Warn(err) 149 } 150 151 for i, k := range keys { 152 spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)} 153 dir := reverse 154 if i == 0 { 155 dir = bidir 156 } 157 fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true) 158 if err != nil { 159 logrus.Warn(err) 160 } 161 indices = append(indices, spis) 162 if i != 0 { 163 continue 164 } 165 err = programSP(fSA, rSA, true) 166 if err != nil { 167 logrus.Warn(err) 168 } 169 } 170 171 em.Lock() 172 em.nodes[rIPs] = indices 173 em.Unlock() 174 175 return nil 176 } 177 178 func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error { 179 em.Lock() 180 indices, ok := em.nodes[remoteIP.String()] 181 em.Unlock() 182 if !ok { 183 return nil 184 } 185 for i, idxs := range indices { 186 dir := reverse 187 if i == 0 { 188 dir = bidir 189 } 190 fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false) 191 if err != nil { 192 logrus.Warn(err) 193 } 194 if i != 0 { 195 continue 196 } 197 err = programSP(fSA, rSA, false) 198 if err != nil { 199 logrus.Warn(err) 200 } 201 } 202 return nil 203 } 204 205 func programMangle(vni uint32, add bool) (err error) { 206 var ( 207 p = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10) 208 c = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8) 209 m = strconv.FormatUint(uint64(r), 10) 210 chain = "OUTPUT" 211 rule = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m} 212 a = "-A" 213 action = "install" 214 ) 215 216 // TODO IPv6 support 217 iptable := iptables.GetIptable(iptables.IPv4) 218 219 if add == iptable.Exists(iptables.Mangle, chain, rule...) { 220 return 221 } 222 223 if !add { 224 a = "-D" 225 action = "remove" 226 } 227 228 if err = iptable.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil { 229 logrus.Warnf("could not %s mangle rule: %v", action, err) 230 } 231 232 return 233 } 234 235 func programInput(vni uint32, add bool) (err error) { 236 var ( 237 port = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10) 238 vniMatch = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8) 239 plainVxlan = []string{"-p", "udp", "--dport", port, "-m", "u32", "--u32", vniMatch, "-j"} 240 ipsecVxlan = append([]string{"-m", "policy", "--dir", "in", "--pol", "ipsec"}, plainVxlan...) 241 block = append(plainVxlan, "DROP") 242 accept = append(ipsecVxlan, "ACCEPT") 243 chain = "INPUT" 244 action = iptables.Append 245 msg = "add" 246 ) 247 248 // TODO IPv6 support 249 iptable := iptables.GetIptable(iptables.IPv4) 250 251 if !add { 252 action = iptables.Delete 253 msg = "remove" 254 } 255 256 if err := iptable.ProgramRule(iptables.Filter, chain, action, accept); err != nil { 257 logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err) 258 } 259 260 if err := iptable.ProgramRule(iptables.Filter, chain, action, block); err != nil { 261 logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err) 262 } 263 264 return 265 } 266 267 func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) { 268 var ( 269 action = "Removing" 270 xfrmProgram = ns.NlHandle().XfrmStateDel 271 ) 272 273 if add { 274 action = "Adding" 275 xfrmProgram = ns.NlHandle().XfrmStateAdd 276 } 277 278 if dir&reverse > 0 { 279 rSA = &netlink.XfrmState{ 280 Src: remoteIP, 281 Dst: localIP, 282 Proto: netlink.XFRM_PROTO_ESP, 283 Spi: spi.reverse, 284 Mode: netlink.XFRM_MODE_TRANSPORT, 285 Reqid: r, 286 } 287 if add { 288 rSA.Aead = buildAeadAlgo(k, spi.reverse) 289 } 290 291 exists, err := saExists(rSA) 292 if err != nil { 293 exists = !add 294 } 295 296 if add != exists { 297 logrus.Debugf("%s: rSA{%s}", action, rSA) 298 if err := xfrmProgram(rSA); err != nil { 299 logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err) 300 } 301 } 302 } 303 304 if dir&forward > 0 { 305 fSA = &netlink.XfrmState{ 306 Src: localIP, 307 Dst: remoteIP, 308 Proto: netlink.XFRM_PROTO_ESP, 309 Spi: spi.forward, 310 Mode: netlink.XFRM_MODE_TRANSPORT, 311 Reqid: r, 312 } 313 if add { 314 fSA.Aead = buildAeadAlgo(k, spi.forward) 315 } 316 317 exists, err := saExists(fSA) 318 if err != nil { 319 exists = !add 320 } 321 322 if add != exists { 323 logrus.Debugf("%s fSA{%s}", action, fSA) 324 if err := xfrmProgram(fSA); err != nil { 325 logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err) 326 } 327 } 328 } 329 330 return 331 } 332 333 func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error { 334 action := "Removing" 335 xfrmProgram := ns.NlHandle().XfrmPolicyDel 336 if add { 337 action = "Adding" 338 xfrmProgram = ns.NlHandle().XfrmPolicyAdd 339 } 340 341 // Create a congruent cidr 342 s := types.GetMinimalIP(fSA.Src) 343 d := types.GetMinimalIP(fSA.Dst) 344 fullMask := net.CIDRMask(8*len(s), 8*len(s)) 345 346 fPol := &netlink.XfrmPolicy{ 347 Src: &net.IPNet{IP: s, Mask: fullMask}, 348 Dst: &net.IPNet{IP: d, Mask: fullMask}, 349 Dir: netlink.XFRM_DIR_OUT, 350 Proto: 17, 351 DstPort: 4789, 352 Mark: &spMark, 353 Tmpls: []netlink.XfrmPolicyTmpl{ 354 { 355 Src: fSA.Src, 356 Dst: fSA.Dst, 357 Proto: netlink.XFRM_PROTO_ESP, 358 Mode: netlink.XFRM_MODE_TRANSPORT, 359 Spi: fSA.Spi, 360 Reqid: r, 361 }, 362 }, 363 } 364 365 exists, err := spExists(fPol) 366 if err != nil { 367 exists = !add 368 } 369 370 if add != exists { 371 logrus.Debugf("%s fSP{%s}", action, fPol) 372 if err := xfrmProgram(fPol); err != nil { 373 logrus.Warnf("%s fSP{%s}: %v", action, fPol, err) 374 } 375 } 376 377 return nil 378 } 379 380 func saExists(sa *netlink.XfrmState) (bool, error) { 381 _, err := ns.NlHandle().XfrmStateGet(sa) 382 switch err { 383 case nil: 384 return true, nil 385 case syscall.ESRCH: 386 return false, nil 387 default: 388 err = fmt.Errorf("Error while checking for SA existence: %v", err) 389 logrus.Warn(err) 390 return false, err 391 } 392 } 393 394 func spExists(sp *netlink.XfrmPolicy) (bool, error) { 395 _, err := ns.NlHandle().XfrmPolicyGet(sp) 396 switch err { 397 case nil: 398 return true, nil 399 case syscall.ENOENT: 400 return false, nil 401 default: 402 err = fmt.Errorf("Error while checking for SP existence: %v", err) 403 logrus.Warn(err) 404 return false, err 405 } 406 } 407 408 func buildSPI(src, dst net.IP, st uint32) int { 409 b := make([]byte, 4) 410 binary.BigEndian.PutUint32(b, st) 411 h := fnv.New32a() 412 h.Write(src) 413 h.Write(b) 414 h.Write(dst) 415 return int(binary.BigEndian.Uint32(h.Sum(nil))) 416 } 417 418 func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo { 419 salt := make([]byte, 4) 420 binary.BigEndian.PutUint32(salt, uint32(s)) 421 return &netlink.XfrmStateAlgo{ 422 Name: "rfc4106(gcm(aes))", 423 Key: append(k.value, salt...), 424 ICVLen: 64, 425 } 426 } 427 428 func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error { 429 d.secMap.Lock() 430 for node, indices := range d.secMap.nodes { 431 idxs, stop := f(node, indices) 432 if idxs != nil { 433 d.secMap.nodes[node] = idxs 434 } 435 if stop { 436 break 437 } 438 } 439 d.secMap.Unlock() 440 return nil 441 } 442 443 func (d *driver) setKeys(keys []*key) error { 444 // Remove any stale policy, state 445 clearEncryptionStates() 446 // Accept the encryption keys and clear any stale encryption map 447 d.Lock() 448 d.keys = keys 449 d.secMap = &encrMap{nodes: map[string][]*spi{}} 450 d.Unlock() 451 logrus.Debugf("Initial encryption keys: %v", keys) 452 return nil 453 } 454 455 // updateKeys allows to add a new key and/or change the primary key and/or prune an existing key 456 // The primary key is the key used in transmission and will go in first position in the list. 457 func (d *driver) updateKeys(newKey, primary, pruneKey *key) error { 458 logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey) 459 460 logrus.Debugf("Current: %v", d.keys) 461 462 var ( 463 newIdx = -1 464 priIdx = -1 465 delIdx = -1 466 lIP = net.ParseIP(d.bindAddress) 467 aIP = net.ParseIP(d.advertiseAddress) 468 ) 469 470 d.Lock() 471 defer d.Unlock() 472 473 // add new 474 if newKey != nil { 475 d.keys = append(d.keys, newKey) 476 newIdx += len(d.keys) 477 } 478 for i, k := range d.keys { 479 if primary != nil && k.tag == primary.tag { 480 priIdx = i 481 } 482 if pruneKey != nil && k.tag == pruneKey.tag { 483 delIdx = i 484 } 485 } 486 487 if (newKey != nil && newIdx == -1) || 488 (primary != nil && priIdx == -1) || 489 (pruneKey != nil && delIdx == -1) { 490 return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+ 491 "(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx) 492 } 493 494 if priIdx != -1 && priIdx == delIdx { 495 return types.BadRequestErrorf("attempting to both make a key (index %d) primary and delete it", priIdx) 496 } 497 498 d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) { 499 rIP := net.ParseIP(rIPs) 500 return updateNodeKey(lIP, aIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false 501 }) 502 503 // swap primary 504 if priIdx != -1 { 505 d.keys[0], d.keys[priIdx] = d.keys[priIdx], d.keys[0] 506 } 507 // prune 508 if delIdx != -1 { 509 if delIdx == 0 { 510 delIdx = priIdx 511 } 512 d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...) 513 } 514 515 logrus.Debugf("Updated: %v", d.keys) 516 517 return nil 518 } 519 520 /******************************************************** 521 * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1 522 * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1 523 * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2 524 *********************************************************/ 525 526 // Spis and keys are sorted in such away the one in position 0 is the primary 527 func updateNodeKey(lIP, aIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi { 528 logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx) 529 530 spis := idxs 531 logrus.Debugf("Current: %v", spis) 532 533 // add new 534 if newIdx != -1 { 535 spis = append(spis, &spi{ 536 forward: buildSPI(aIP, rIP, curKeys[newIdx].tag), 537 reverse: buildSPI(rIP, aIP, curKeys[newIdx].tag), 538 }) 539 } 540 541 if delIdx != -1 { 542 // -rSA0 543 programSA(lIP, rIP, spis[delIdx], nil, reverse, false) 544 } 545 546 if newIdx > -1 { 547 // +rSA2 548 programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true) 549 } 550 551 if priIdx > 0 { 552 // +fSA2 553 fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true) 554 555 // +fSP2, -fSP1 556 s := types.GetMinimalIP(fSA2.Src) 557 d := types.GetMinimalIP(fSA2.Dst) 558 fullMask := net.CIDRMask(8*len(s), 8*len(s)) 559 560 fSP1 := &netlink.XfrmPolicy{ 561 Src: &net.IPNet{IP: s, Mask: fullMask}, 562 Dst: &net.IPNet{IP: d, Mask: fullMask}, 563 Dir: netlink.XFRM_DIR_OUT, 564 Proto: 17, 565 DstPort: 4789, 566 Mark: &spMark, 567 Tmpls: []netlink.XfrmPolicyTmpl{ 568 { 569 Src: fSA2.Src, 570 Dst: fSA2.Dst, 571 Proto: netlink.XFRM_PROTO_ESP, 572 Mode: netlink.XFRM_MODE_TRANSPORT, 573 Spi: fSA2.Spi, 574 Reqid: r, 575 }, 576 }, 577 } 578 logrus.Debugf("Updating fSP{%s}", fSP1) 579 if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil { 580 logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err) 581 } 582 583 // -fSA1 584 programSA(lIP, rIP, spis[0], nil, forward, false) 585 } 586 587 // swap 588 if priIdx > 0 { 589 swp := spis[0] 590 spis[0] = spis[priIdx] 591 spis[priIdx] = swp 592 } 593 // prune 594 if delIdx != -1 { 595 if delIdx == 0 { 596 delIdx = priIdx 597 } 598 spis = append(spis[:delIdx], spis[delIdx+1:]...) 599 } 600 601 logrus.Debugf("Updated: %v", spis) 602 603 return spis 604 } 605 606 func (n *network) maxMTU() int { 607 mtu := 1500 608 if n.mtu != 0 { 609 mtu = n.mtu 610 } 611 mtu -= vxlanEncap 612 if n.secure { 613 // In case of encryption account for the 614 // esp packet expansion and padding 615 mtu -= pktExpansion 616 mtu -= (mtu % 4) 617 } 618 return mtu 619 } 620 621 func clearEncryptionStates() { 622 nlh := ns.NlHandle() 623 spList, err := nlh.XfrmPolicyList(netlink.FAMILY_ALL) 624 if err != nil { 625 logrus.Warnf("Failed to retrieve SP list for cleanup: %v", err) 626 } 627 saList, err := nlh.XfrmStateList(netlink.FAMILY_ALL) 628 if err != nil { 629 logrus.Warnf("Failed to retrieve SA list for cleanup: %v", err) 630 } 631 for _, sp := range spList { 632 sp := sp 633 if sp.Mark != nil && sp.Mark.Value == spMark.Value { 634 if err := nlh.XfrmPolicyDel(&sp); err != nil { 635 logrus.Warnf("Failed to delete stale SP %s: %v", sp, err) 636 continue 637 } 638 logrus.Debugf("Removed stale SP: %s", sp) 639 } 640 } 641 for _, sa := range saList { 642 sa := sa 643 if sa.Reqid == r { 644 if err := nlh.XfrmStateDel(&sa); err != nil { 645 logrus.Warnf("Failed to delete stale SA %s: %v", sa, err) 646 continue 647 } 648 logrus.Debugf("Removed stale SA: %s", sa) 649 } 650 } 651 }