github.com/jfrazelle/docker@v1.1.2-0.20210712172922-bf78e25fe508/libnetwork/drivers/overlay/encryption.go (about) 1 // +build linux 2 3 package overlay 4 5 import ( 6 "bytes" 7 "encoding/binary" 8 "encoding/hex" 9 "fmt" 10 "hash/fnv" 11 "net" 12 "sync" 13 "syscall" 14 15 "strconv" 16 17 "github.com/docker/docker/libnetwork/drivers/overlay/overlayutils" 18 "github.com/docker/docker/libnetwork/iptables" 19 "github.com/docker/docker/libnetwork/ns" 20 "github.com/docker/docker/libnetwork/types" 21 "github.com/sirupsen/logrus" 22 "github.com/vishvananda/netlink" 23 ) 24 25 const ( 26 r = 0xD0C4E3 27 pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8) 28 ) 29 30 const ( 31 forward = iota + 1 32 reverse 33 bidir 34 ) 35 36 var spMark = netlink.XfrmMark{Value: uint32(r), Mask: 0xffffffff} 37 38 type key struct { 39 value []byte 40 tag uint32 41 } 42 43 func (k *key) String() string { 44 if k != nil { 45 return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag) 46 } 47 return "" 48 } 49 50 type spi struct { 51 forward int 52 reverse int 53 } 54 55 func (s *spi) String() string { 56 return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse)) 57 } 58 59 type encrMap struct { 60 nodes map[string][]*spi 61 sync.Mutex 62 } 63 64 func (e *encrMap) String() string { 65 e.Lock() 66 defer e.Unlock() 67 b := new(bytes.Buffer) 68 for k, v := range e.nodes { 69 b.WriteString("\n") 70 b.WriteString(k) 71 b.WriteString(":") 72 b.WriteString("[") 73 for _, s := range v { 74 b.WriteString(s.String()) 75 b.WriteString(",") 76 } 77 b.WriteString("]") 78 79 } 80 return b.String() 81 } 82 83 func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error { 84 logrus.Debugf("checkEncryption(%.7s, %v, %d, %t)", nid, rIP, vxlanID, isLocal) 85 86 n := d.network(nid) 87 if n == nil || !n.secure { 88 return nil 89 } 90 91 if len(d.keys) == 0 { 92 return types.ForbiddenErrorf("encryption key is not present") 93 } 94 95 lIP := net.ParseIP(d.bindAddress) 96 aIP := net.ParseIP(d.advertiseAddress) 97 nodes := map[string]net.IP{} 98 99 switch { 100 case isLocal: 101 if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool { 102 if !aIP.Equal(pEntry.vtep) { 103 nodes[pEntry.vtep.String()] = pEntry.vtep 104 } 105 return false 106 }); err != nil { 107 logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %.5s: %v", nid, err) 108 } 109 default: 110 if len(d.network(nid).endpoints) > 0 { 111 nodes[rIP.String()] = rIP 112 } 113 } 114 115 logrus.Debugf("List of nodes: %s", nodes) 116 117 if add { 118 for _, rIP := range nodes { 119 if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil { 120 logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err) 121 } 122 } 123 } else { 124 if len(nodes) == 0 { 125 if err := removeEncryption(lIP, rIP, d.secMap); err != nil { 126 logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err) 127 } 128 } 129 } 130 131 return nil 132 } 133 134 func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error { 135 logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP) 136 rIPs := remoteIP.String() 137 138 indices := make([]*spi, 0, len(keys)) 139 140 err := programMangle(vni, true) 141 if err != nil { 142 logrus.Warn(err) 143 } 144 145 err = programInput(vni, true) 146 if err != nil { 147 logrus.Warn(err) 148 } 149 150 for i, k := range keys { 151 spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)} 152 dir := reverse 153 if i == 0 { 154 dir = bidir 155 } 156 fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true) 157 if err != nil { 158 logrus.Warn(err) 159 } 160 indices = append(indices, spis) 161 if i != 0 { 162 continue 163 } 164 err = programSP(fSA, rSA, true) 165 if err != nil { 166 logrus.Warn(err) 167 } 168 } 169 170 em.Lock() 171 em.nodes[rIPs] = indices 172 em.Unlock() 173 174 return nil 175 } 176 177 func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error { 178 em.Lock() 179 indices, ok := em.nodes[remoteIP.String()] 180 em.Unlock() 181 if !ok { 182 return nil 183 } 184 for i, idxs := range indices { 185 dir := reverse 186 if i == 0 { 187 dir = bidir 188 } 189 fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false) 190 if err != nil { 191 logrus.Warn(err) 192 } 193 if i != 0 { 194 continue 195 } 196 err = programSP(fSA, rSA, false) 197 if err != nil { 198 logrus.Warn(err) 199 } 200 } 201 return nil 202 } 203 204 func programMangle(vni uint32, add bool) (err error) { 205 var ( 206 p = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10) 207 c = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8) 208 m = strconv.FormatUint(uint64(r), 10) 209 chain = "OUTPUT" 210 rule = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m} 211 a = "-A" 212 action = "install" 213 ) 214 215 // TODO IPv6 support 216 iptable := iptables.GetIptable(iptables.IPv4) 217 218 if add == iptable.Exists(iptables.Mangle, chain, rule...) { 219 return 220 } 221 222 if !add { 223 a = "-D" 224 action = "remove" 225 } 226 227 if err = iptable.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil { 228 logrus.Warnf("could not %s mangle rule: %v", action, err) 229 } 230 231 return 232 } 233 234 func programInput(vni uint32, add bool) (err error) { 235 var ( 236 port = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10) 237 vniMatch = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8) 238 plainVxlan = []string{"-p", "udp", "--dport", port, "-m", "u32", "--u32", vniMatch, "-j"} 239 ipsecVxlan = append([]string{"-m", "policy", "--dir", "in", "--pol", "ipsec"}, plainVxlan...) 240 block = append(plainVxlan, "DROP") 241 accept = append(ipsecVxlan, "ACCEPT") 242 chain = "INPUT" 243 action = iptables.Append 244 msg = "add" 245 ) 246 247 // TODO IPv6 support 248 iptable := iptables.GetIptable(iptables.IPv4) 249 250 if !add { 251 action = iptables.Delete 252 msg = "remove" 253 } 254 255 if err := iptable.ProgramRule(iptables.Filter, chain, action, accept); err != nil { 256 logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err) 257 } 258 259 if err := iptable.ProgramRule(iptables.Filter, chain, action, block); err != nil { 260 logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err) 261 } 262 263 return 264 } 265 266 func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) { 267 var ( 268 action = "Removing" 269 xfrmProgram = ns.NlHandle().XfrmStateDel 270 ) 271 272 if add { 273 action = "Adding" 274 xfrmProgram = ns.NlHandle().XfrmStateAdd 275 } 276 277 if dir&reverse > 0 { 278 rSA = &netlink.XfrmState{ 279 Src: remoteIP, 280 Dst: localIP, 281 Proto: netlink.XFRM_PROTO_ESP, 282 Spi: spi.reverse, 283 Mode: netlink.XFRM_MODE_TRANSPORT, 284 Reqid: r, 285 } 286 if add { 287 rSA.Aead = buildAeadAlgo(k, spi.reverse) 288 } 289 290 exists, err := saExists(rSA) 291 if err != nil { 292 exists = !add 293 } 294 295 if add != exists { 296 logrus.Debugf("%s: rSA{%s}", action, rSA) 297 if err := xfrmProgram(rSA); err != nil { 298 logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err) 299 } 300 } 301 } 302 303 if dir&forward > 0 { 304 fSA = &netlink.XfrmState{ 305 Src: localIP, 306 Dst: remoteIP, 307 Proto: netlink.XFRM_PROTO_ESP, 308 Spi: spi.forward, 309 Mode: netlink.XFRM_MODE_TRANSPORT, 310 Reqid: r, 311 } 312 if add { 313 fSA.Aead = buildAeadAlgo(k, spi.forward) 314 } 315 316 exists, err := saExists(fSA) 317 if err != nil { 318 exists = !add 319 } 320 321 if add != exists { 322 logrus.Debugf("%s fSA{%s}", action, fSA) 323 if err := xfrmProgram(fSA); err != nil { 324 logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err) 325 } 326 } 327 } 328 329 return 330 } 331 332 func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error { 333 action := "Removing" 334 xfrmProgram := ns.NlHandle().XfrmPolicyDel 335 if add { 336 action = "Adding" 337 xfrmProgram = ns.NlHandle().XfrmPolicyAdd 338 } 339 340 // Create a congruent cidr 341 s := types.GetMinimalIP(fSA.Src) 342 d := types.GetMinimalIP(fSA.Dst) 343 fullMask := net.CIDRMask(8*len(s), 8*len(s)) 344 345 fPol := &netlink.XfrmPolicy{ 346 Src: &net.IPNet{IP: s, Mask: fullMask}, 347 Dst: &net.IPNet{IP: d, Mask: fullMask}, 348 Dir: netlink.XFRM_DIR_OUT, 349 Proto: 17, 350 DstPort: 4789, 351 Mark: &spMark, 352 Tmpls: []netlink.XfrmPolicyTmpl{ 353 { 354 Src: fSA.Src, 355 Dst: fSA.Dst, 356 Proto: netlink.XFRM_PROTO_ESP, 357 Mode: netlink.XFRM_MODE_TRANSPORT, 358 Spi: fSA.Spi, 359 Reqid: r, 360 }, 361 }, 362 } 363 364 exists, err := spExists(fPol) 365 if err != nil { 366 exists = !add 367 } 368 369 if add != exists { 370 logrus.Debugf("%s fSP{%s}", action, fPol) 371 if err := xfrmProgram(fPol); err != nil { 372 logrus.Warnf("%s fSP{%s}: %v", action, fPol, err) 373 } 374 } 375 376 return nil 377 } 378 379 func saExists(sa *netlink.XfrmState) (bool, error) { 380 _, err := ns.NlHandle().XfrmStateGet(sa) 381 switch err { 382 case nil: 383 return true, nil 384 case syscall.ESRCH: 385 return false, nil 386 default: 387 err = fmt.Errorf("Error while checking for SA existence: %v", err) 388 logrus.Warn(err) 389 return false, err 390 } 391 } 392 393 func spExists(sp *netlink.XfrmPolicy) (bool, error) { 394 _, err := ns.NlHandle().XfrmPolicyGet(sp) 395 switch err { 396 case nil: 397 return true, nil 398 case syscall.ENOENT: 399 return false, nil 400 default: 401 err = fmt.Errorf("Error while checking for SP existence: %v", err) 402 logrus.Warn(err) 403 return false, err 404 } 405 } 406 407 func buildSPI(src, dst net.IP, st uint32) int { 408 b := make([]byte, 4) 409 binary.BigEndian.PutUint32(b, st) 410 h := fnv.New32a() 411 h.Write(src) 412 h.Write(b) 413 h.Write(dst) 414 return int(binary.BigEndian.Uint32(h.Sum(nil))) 415 } 416 417 func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo { 418 salt := make([]byte, 4) 419 binary.BigEndian.PutUint32(salt, uint32(s)) 420 return &netlink.XfrmStateAlgo{ 421 Name: "rfc4106(gcm(aes))", 422 Key: append(k.value, salt...), 423 ICVLen: 64, 424 } 425 } 426 427 func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error { 428 d.secMap.Lock() 429 for node, indices := range d.secMap.nodes { 430 idxs, stop := f(node, indices) 431 if idxs != nil { 432 d.secMap.nodes[node] = idxs 433 } 434 if stop { 435 break 436 } 437 } 438 d.secMap.Unlock() 439 return nil 440 } 441 442 func (d *driver) setKeys(keys []*key) error { 443 // Remove any stale policy, state 444 clearEncryptionStates() 445 // Accept the encryption keys and clear any stale encryption map 446 d.Lock() 447 d.keys = keys 448 d.secMap = &encrMap{nodes: map[string][]*spi{}} 449 d.Unlock() 450 logrus.Debugf("Initial encryption keys: %v", keys) 451 return nil 452 } 453 454 // updateKeys allows to add a new key and/or change the primary key and/or prune an existing key 455 // The primary key is the key used in transmission and will go in first position in the list. 456 func (d *driver) updateKeys(newKey, primary, pruneKey *key) error { 457 logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey) 458 459 logrus.Debugf("Current: %v", d.keys) 460 461 var ( 462 newIdx = -1 463 priIdx = -1 464 delIdx = -1 465 lIP = net.ParseIP(d.bindAddress) 466 aIP = net.ParseIP(d.advertiseAddress) 467 ) 468 469 d.Lock() 470 defer d.Unlock() 471 472 // add new 473 if newKey != nil { 474 d.keys = append(d.keys, newKey) 475 newIdx += len(d.keys) 476 } 477 for i, k := range d.keys { 478 if primary != nil && k.tag == primary.tag { 479 priIdx = i 480 } 481 if pruneKey != nil && k.tag == pruneKey.tag { 482 delIdx = i 483 } 484 } 485 486 if (newKey != nil && newIdx == -1) || 487 (primary != nil && priIdx == -1) || 488 (pruneKey != nil && delIdx == -1) { 489 return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+ 490 "(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx) 491 } 492 493 if priIdx != -1 && priIdx == delIdx { 494 return types.BadRequestErrorf("attempting to both make a key (index %d) primary and delete it", priIdx) 495 } 496 497 d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) { 498 rIP := net.ParseIP(rIPs) 499 return updateNodeKey(lIP, aIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false 500 }) 501 502 // swap primary 503 if priIdx != -1 { 504 d.keys[0], d.keys[priIdx] = d.keys[priIdx], d.keys[0] 505 } 506 // prune 507 if delIdx != -1 { 508 if delIdx == 0 { 509 delIdx = priIdx 510 } 511 d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...) 512 } 513 514 logrus.Debugf("Updated: %v", d.keys) 515 516 return nil 517 } 518 519 /******************************************************** 520 * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1 521 * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1 522 * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2 523 *********************************************************/ 524 525 // Spis and keys are sorted in such away the one in position 0 is the primary 526 func updateNodeKey(lIP, aIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi { 527 logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx) 528 529 spis := idxs 530 logrus.Debugf("Current: %v", spis) 531 532 // add new 533 if newIdx != -1 { 534 spis = append(spis, &spi{ 535 forward: buildSPI(aIP, rIP, curKeys[newIdx].tag), 536 reverse: buildSPI(rIP, aIP, curKeys[newIdx].tag), 537 }) 538 } 539 540 if delIdx != -1 { 541 // -rSA0 542 programSA(lIP, rIP, spis[delIdx], nil, reverse, false) 543 } 544 545 if newIdx > -1 { 546 // +rSA2 547 programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true) 548 } 549 550 if priIdx > 0 { 551 // +fSA2 552 fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true) 553 554 // +fSP2, -fSP1 555 s := types.GetMinimalIP(fSA2.Src) 556 d := types.GetMinimalIP(fSA2.Dst) 557 fullMask := net.CIDRMask(8*len(s), 8*len(s)) 558 559 fSP1 := &netlink.XfrmPolicy{ 560 Src: &net.IPNet{IP: s, Mask: fullMask}, 561 Dst: &net.IPNet{IP: d, Mask: fullMask}, 562 Dir: netlink.XFRM_DIR_OUT, 563 Proto: 17, 564 DstPort: 4789, 565 Mark: &spMark, 566 Tmpls: []netlink.XfrmPolicyTmpl{ 567 { 568 Src: fSA2.Src, 569 Dst: fSA2.Dst, 570 Proto: netlink.XFRM_PROTO_ESP, 571 Mode: netlink.XFRM_MODE_TRANSPORT, 572 Spi: fSA2.Spi, 573 Reqid: r, 574 }, 575 }, 576 } 577 logrus.Debugf("Updating fSP{%s}", fSP1) 578 if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil { 579 logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err) 580 } 581 582 // -fSA1 583 programSA(lIP, rIP, spis[0], nil, forward, false) 584 } 585 586 // swap 587 if priIdx > 0 { 588 swp := spis[0] 589 spis[0] = spis[priIdx] 590 spis[priIdx] = swp 591 } 592 // prune 593 if delIdx != -1 { 594 if delIdx == 0 { 595 delIdx = priIdx 596 } 597 spis = append(spis[:delIdx], spis[delIdx+1:]...) 598 } 599 600 logrus.Debugf("Updated: %v", spis) 601 602 return spis 603 } 604 605 func (n *network) maxMTU() int { 606 mtu := 1500 607 if n.mtu != 0 { 608 mtu = n.mtu 609 } 610 mtu -= vxlanEncap 611 if n.secure { 612 // In case of encryption account for the 613 // esp packet expansion and padding 614 mtu -= pktExpansion 615 mtu -= (mtu % 4) 616 } 617 return mtu 618 } 619 620 func clearEncryptionStates() { 621 nlh := ns.NlHandle() 622 spList, err := nlh.XfrmPolicyList(netlink.FAMILY_ALL) 623 if err != nil { 624 logrus.Warnf("Failed to retrieve SP list for cleanup: %v", err) 625 } 626 saList, err := nlh.XfrmStateList(netlink.FAMILY_ALL) 627 if err != nil { 628 logrus.Warnf("Failed to retrieve SA list for cleanup: %v", err) 629 } 630 for _, sp := range spList { 631 sp := sp 632 if sp.Mark != nil && sp.Mark.Value == spMark.Value { 633 if err := nlh.XfrmPolicyDel(&sp); err != nil { 634 logrus.Warnf("Failed to delete stale SP %s: %v", sp, err) 635 continue 636 } 637 logrus.Debugf("Removed stale SP: %s", sp) 638 } 639 } 640 for _, sa := range saList { 641 sa := sa 642 if sa.Reqid == r { 643 if err := nlh.XfrmStateDel(&sa); err != nil { 644 logrus.Warnf("Failed to delete stale SA %s: %v", sa, err) 645 continue 646 } 647 logrus.Debugf("Removed stale SA: %s", sa) 648 } 649 } 650 }