github.com/pwn-term/docker@v0.0.0-20210616085119-6e977cce2565/libnetwork/drivers/overlay/encryption.go (about) 1 package overlay 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "encoding/hex" 7 "fmt" 8 "hash/fnv" 9 "net" 10 "sync" 11 "syscall" 12 13 "strconv" 14 15 "github.com/docker/libnetwork/drivers/overlay/overlayutils" 16 "github.com/docker/libnetwork/iptables" 17 "github.com/docker/libnetwork/ns" 18 "github.com/docker/libnetwork/types" 19 "github.com/sirupsen/logrus" 20 "github.com/vishvananda/netlink" 21 ) 22 23 const ( 24 r = 0xD0C4E3 25 pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8) 26 ) 27 28 const ( 29 forward = iota + 1 30 reverse 31 bidir 32 ) 33 34 var spMark = netlink.XfrmMark{Value: uint32(r), Mask: 0xffffffff} 35 36 type key struct { 37 value []byte 38 tag uint32 39 } 40 41 func (k *key) String() string { 42 if k != nil { 43 return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag) 44 } 45 return "" 46 } 47 48 type spi struct { 49 forward int 50 reverse int 51 } 52 53 func (s *spi) String() string { 54 return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse)) 55 } 56 57 type encrMap struct { 58 nodes map[string][]*spi 59 sync.Mutex 60 } 61 62 func (e *encrMap) String() string { 63 e.Lock() 64 defer e.Unlock() 65 b := new(bytes.Buffer) 66 for k, v := range e.nodes { 67 b.WriteString("\n") 68 b.WriteString(k) 69 b.WriteString(":") 70 b.WriteString("[") 71 for _, s := range v { 72 b.WriteString(s.String()) 73 b.WriteString(",") 74 } 75 b.WriteString("]") 76 77 } 78 return b.String() 79 } 80 81 func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error { 82 logrus.Debugf("checkEncryption(%.7s, %v, %d, %t)", nid, rIP, vxlanID, isLocal) 83 84 n := d.network(nid) 85 if n == nil || !n.secure { 86 return nil 87 } 88 89 if len(d.keys) == 0 { 90 return types.ForbiddenErrorf("encryption key is not present") 91 } 92 93 lIP := net.ParseIP(d.bindAddress) 94 aIP := net.ParseIP(d.advertiseAddress) 95 nodes := map[string]net.IP{} 96 97 switch { 98 case isLocal: 99 if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool { 100 if !aIP.Equal(pEntry.vtep) { 101 nodes[pEntry.vtep.String()] = pEntry.vtep 102 } 103 return false 104 }); err != nil { 105 logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %.5s: %v", nid, err) 106 } 107 default: 108 if len(d.network(nid).endpoints) > 0 { 109 nodes[rIP.String()] = rIP 110 } 111 } 112 113 logrus.Debugf("List of nodes: %s", nodes) 114 115 if add { 116 for _, rIP := range nodes { 117 if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil { 118 logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err) 119 } 120 } 121 } else { 122 if len(nodes) == 0 { 123 if err := removeEncryption(lIP, rIP, d.secMap); err != nil { 124 logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err) 125 } 126 } 127 } 128 129 return nil 130 } 131 132 func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error { 133 logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP) 134 rIPs := remoteIP.String() 135 136 indices := make([]*spi, 0, len(keys)) 137 138 err := programMangle(vni, true) 139 if err != nil { 140 logrus.Warn(err) 141 } 142 143 err = programInput(vni, true) 144 if err != nil { 145 logrus.Warn(err) 146 } 147 148 for i, k := range keys { 149 spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)} 150 dir := reverse 151 if i == 0 { 152 dir = bidir 153 } 154 fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true) 155 if err != nil { 156 logrus.Warn(err) 157 } 158 indices = append(indices, spis) 159 if i != 0 { 160 continue 161 } 162 err = programSP(fSA, rSA, true) 163 if err != nil { 164 logrus.Warn(err) 165 } 166 } 167 168 em.Lock() 169 em.nodes[rIPs] = indices 170 em.Unlock() 171 172 return nil 173 } 174 175 func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error { 176 em.Lock() 177 indices, ok := em.nodes[remoteIP.String()] 178 em.Unlock() 179 if !ok { 180 return nil 181 } 182 for i, idxs := range indices { 183 dir := reverse 184 if i == 0 { 185 dir = bidir 186 } 187 fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false) 188 if err != nil { 189 logrus.Warn(err) 190 } 191 if i != 0 { 192 continue 193 } 194 err = programSP(fSA, rSA, false) 195 if err != nil { 196 logrus.Warn(err) 197 } 198 } 199 return nil 200 } 201 202 func programMangle(vni uint32, add bool) (err error) { 203 var ( 204 p = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10) 205 c = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8) 206 m = strconv.FormatUint(uint64(r), 10) 207 chain = "OUTPUT" 208 rule = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m} 209 a = "-A" 210 action = "install" 211 ) 212 213 // TODO IPv6 support 214 iptable := iptables.GetIptable(iptables.IPv4) 215 216 if add == iptable.Exists(iptables.Mangle, chain, rule...) { 217 return 218 } 219 220 if !add { 221 a = "-D" 222 action = "remove" 223 } 224 225 if err = iptable.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil { 226 logrus.Warnf("could not %s mangle rule: %v", action, err) 227 } 228 229 return 230 } 231 232 func programInput(vni uint32, add bool) (err error) { 233 var ( 234 port = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10) 235 vniMatch = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8) 236 plainVxlan = []string{"-p", "udp", "--dport", port, "-m", "u32", "--u32", vniMatch, "-j"} 237 ipsecVxlan = append([]string{"-m", "policy", "--dir", "in", "--pol", "ipsec"}, plainVxlan...) 238 block = append(plainVxlan, "DROP") 239 accept = append(ipsecVxlan, "ACCEPT") 240 chain = "INPUT" 241 action = iptables.Append 242 msg = "add" 243 ) 244 245 // TODO IPv6 support 246 iptable := iptables.GetIptable(iptables.IPv4) 247 248 if !add { 249 action = iptables.Delete 250 msg = "remove" 251 } 252 253 if err := iptable.ProgramRule(iptables.Filter, chain, action, accept); err != nil { 254 logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err) 255 } 256 257 if err := iptable.ProgramRule(iptables.Filter, chain, action, block); err != nil { 258 logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err) 259 } 260 261 return 262 } 263 264 func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) { 265 var ( 266 action = "Removing" 267 xfrmProgram = ns.NlHandle().XfrmStateDel 268 ) 269 270 if add { 271 action = "Adding" 272 xfrmProgram = ns.NlHandle().XfrmStateAdd 273 } 274 275 if dir&reverse > 0 { 276 rSA = &netlink.XfrmState{ 277 Src: remoteIP, 278 Dst: localIP, 279 Proto: netlink.XFRM_PROTO_ESP, 280 Spi: spi.reverse, 281 Mode: netlink.XFRM_MODE_TRANSPORT, 282 Reqid: r, 283 } 284 if add { 285 rSA.Aead = buildAeadAlgo(k, spi.reverse) 286 } 287 288 exists, err := saExists(rSA) 289 if err != nil { 290 exists = !add 291 } 292 293 if add != exists { 294 logrus.Debugf("%s: rSA{%s}", action, rSA) 295 if err := xfrmProgram(rSA); err != nil { 296 logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err) 297 } 298 } 299 } 300 301 if dir&forward > 0 { 302 fSA = &netlink.XfrmState{ 303 Src: localIP, 304 Dst: remoteIP, 305 Proto: netlink.XFRM_PROTO_ESP, 306 Spi: spi.forward, 307 Mode: netlink.XFRM_MODE_TRANSPORT, 308 Reqid: r, 309 } 310 if add { 311 fSA.Aead = buildAeadAlgo(k, spi.forward) 312 } 313 314 exists, err := saExists(fSA) 315 if err != nil { 316 exists = !add 317 } 318 319 if add != exists { 320 logrus.Debugf("%s fSA{%s}", action, fSA) 321 if err := xfrmProgram(fSA); err != nil { 322 logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err) 323 } 324 } 325 } 326 327 return 328 } 329 330 func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error { 331 action := "Removing" 332 xfrmProgram := ns.NlHandle().XfrmPolicyDel 333 if add { 334 action = "Adding" 335 xfrmProgram = ns.NlHandle().XfrmPolicyAdd 336 } 337 338 // Create a congruent cidr 339 s := types.GetMinimalIP(fSA.Src) 340 d := types.GetMinimalIP(fSA.Dst) 341 fullMask := net.CIDRMask(8*len(s), 8*len(s)) 342 343 fPol := &netlink.XfrmPolicy{ 344 Src: &net.IPNet{IP: s, Mask: fullMask}, 345 Dst: &net.IPNet{IP: d, Mask: fullMask}, 346 Dir: netlink.XFRM_DIR_OUT, 347 Proto: 17, 348 DstPort: 4789, 349 Mark: &spMark, 350 Tmpls: []netlink.XfrmPolicyTmpl{ 351 { 352 Src: fSA.Src, 353 Dst: fSA.Dst, 354 Proto: netlink.XFRM_PROTO_ESP, 355 Mode: netlink.XFRM_MODE_TRANSPORT, 356 Spi: fSA.Spi, 357 Reqid: r, 358 }, 359 }, 360 } 361 362 exists, err := spExists(fPol) 363 if err != nil { 364 exists = !add 365 } 366 367 if add != exists { 368 logrus.Debugf("%s fSP{%s}", action, fPol) 369 if err := xfrmProgram(fPol); err != nil { 370 logrus.Warnf("%s fSP{%s}: %v", action, fPol, err) 371 } 372 } 373 374 return nil 375 } 376 377 func saExists(sa *netlink.XfrmState) (bool, error) { 378 _, err := ns.NlHandle().XfrmStateGet(sa) 379 switch err { 380 case nil: 381 return true, nil 382 case syscall.ESRCH: 383 return false, nil 384 default: 385 err = fmt.Errorf("Error while checking for SA existence: %v", err) 386 logrus.Warn(err) 387 return false, err 388 } 389 } 390 391 func spExists(sp *netlink.XfrmPolicy) (bool, error) { 392 _, err := ns.NlHandle().XfrmPolicyGet(sp) 393 switch err { 394 case nil: 395 return true, nil 396 case syscall.ENOENT: 397 return false, nil 398 default: 399 err = fmt.Errorf("Error while checking for SP existence: %v", err) 400 logrus.Warn(err) 401 return false, err 402 } 403 } 404 405 func buildSPI(src, dst net.IP, st uint32) int { 406 b := make([]byte, 4) 407 binary.BigEndian.PutUint32(b, st) 408 h := fnv.New32a() 409 h.Write(src) 410 h.Write(b) 411 h.Write(dst) 412 return int(binary.BigEndian.Uint32(h.Sum(nil))) 413 } 414 415 func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo { 416 salt := make([]byte, 4) 417 binary.BigEndian.PutUint32(salt, uint32(s)) 418 return &netlink.XfrmStateAlgo{ 419 Name: "rfc4106(gcm(aes))", 420 Key: append(k.value, salt...), 421 ICVLen: 64, 422 } 423 } 424 425 func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error { 426 d.secMap.Lock() 427 for node, indices := range d.secMap.nodes { 428 idxs, stop := f(node, indices) 429 if idxs != nil { 430 d.secMap.nodes[node] = idxs 431 } 432 if stop { 433 break 434 } 435 } 436 d.secMap.Unlock() 437 return nil 438 } 439 440 func (d *driver) setKeys(keys []*key) error { 441 // Remove any stale policy, state 442 clearEncryptionStates() 443 // Accept the encryption keys and clear any stale encryption map 444 d.Lock() 445 d.keys = keys 446 d.secMap = &encrMap{nodes: map[string][]*spi{}} 447 d.Unlock() 448 logrus.Debugf("Initial encryption keys: %v", keys) 449 return nil 450 } 451 452 // updateKeys allows to add a new key and/or change the primary key and/or prune an existing key 453 // The primary key is the key used in transmission and will go in first position in the list. 454 func (d *driver) updateKeys(newKey, primary, pruneKey *key) error { 455 logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey) 456 457 logrus.Debugf("Current: %v", d.keys) 458 459 var ( 460 newIdx = -1 461 priIdx = -1 462 delIdx = -1 463 lIP = net.ParseIP(d.bindAddress) 464 aIP = net.ParseIP(d.advertiseAddress) 465 ) 466 467 d.Lock() 468 defer d.Unlock() 469 470 // add new 471 if newKey != nil { 472 d.keys = append(d.keys, newKey) 473 newIdx += len(d.keys) 474 } 475 for i, k := range d.keys { 476 if primary != nil && k.tag == primary.tag { 477 priIdx = i 478 } 479 if pruneKey != nil && k.tag == pruneKey.tag { 480 delIdx = i 481 } 482 } 483 484 if (newKey != nil && newIdx == -1) || 485 (primary != nil && priIdx == -1) || 486 (pruneKey != nil && delIdx == -1) { 487 return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+ 488 "(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx) 489 } 490 491 if priIdx != -1 && priIdx == delIdx { 492 return types.BadRequestErrorf("attempting to both make a key (index %d) primary and delete it", priIdx) 493 } 494 495 d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) { 496 rIP := net.ParseIP(rIPs) 497 return updateNodeKey(lIP, aIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false 498 }) 499 500 // swap primary 501 if priIdx != -1 { 502 d.keys[0], d.keys[priIdx] = d.keys[priIdx], d.keys[0] 503 } 504 // prune 505 if delIdx != -1 { 506 if delIdx == 0 { 507 delIdx = priIdx 508 } 509 d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...) 510 } 511 512 logrus.Debugf("Updated: %v", d.keys) 513 514 return nil 515 } 516 517 /******************************************************** 518 * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1 519 * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1 520 * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2 521 *********************************************************/ 522 523 // Spis and keys are sorted in such away the one in position 0 is the primary 524 func updateNodeKey(lIP, aIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi { 525 logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx) 526 527 spis := idxs 528 logrus.Debugf("Current: %v", spis) 529 530 // add new 531 if newIdx != -1 { 532 spis = append(spis, &spi{ 533 forward: buildSPI(aIP, rIP, curKeys[newIdx].tag), 534 reverse: buildSPI(rIP, aIP, curKeys[newIdx].tag), 535 }) 536 } 537 538 if delIdx != -1 { 539 // -rSA0 540 programSA(lIP, rIP, spis[delIdx], nil, reverse, false) 541 } 542 543 if newIdx > -1 { 544 // +rSA2 545 programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true) 546 } 547 548 if priIdx > 0 { 549 // +fSA2 550 fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true) 551 552 // +fSP2, -fSP1 553 s := types.GetMinimalIP(fSA2.Src) 554 d := types.GetMinimalIP(fSA2.Dst) 555 fullMask := net.CIDRMask(8*len(s), 8*len(s)) 556 557 fSP1 := &netlink.XfrmPolicy{ 558 Src: &net.IPNet{IP: s, Mask: fullMask}, 559 Dst: &net.IPNet{IP: d, Mask: fullMask}, 560 Dir: netlink.XFRM_DIR_OUT, 561 Proto: 17, 562 DstPort: 4789, 563 Mark: &spMark, 564 Tmpls: []netlink.XfrmPolicyTmpl{ 565 { 566 Src: fSA2.Src, 567 Dst: fSA2.Dst, 568 Proto: netlink.XFRM_PROTO_ESP, 569 Mode: netlink.XFRM_MODE_TRANSPORT, 570 Spi: fSA2.Spi, 571 Reqid: r, 572 }, 573 }, 574 } 575 logrus.Debugf("Updating fSP{%s}", fSP1) 576 if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil { 577 logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err) 578 } 579 580 // -fSA1 581 programSA(lIP, rIP, spis[0], nil, forward, false) 582 } 583 584 // swap 585 if priIdx > 0 { 586 swp := spis[0] 587 spis[0] = spis[priIdx] 588 spis[priIdx] = swp 589 } 590 // prune 591 if delIdx != -1 { 592 if delIdx == 0 { 593 delIdx = priIdx 594 } 595 spis = append(spis[:delIdx], spis[delIdx+1:]...) 596 } 597 598 logrus.Debugf("Updated: %v", spis) 599 600 return spis 601 } 602 603 func (n *network) maxMTU() int { 604 mtu := 1500 605 if n.mtu != 0 { 606 mtu = n.mtu 607 } 608 mtu -= vxlanEncap 609 if n.secure { 610 // In case of encryption account for the 611 // esp packet expansion and padding 612 mtu -= pktExpansion 613 mtu -= (mtu % 4) 614 } 615 return mtu 616 } 617 618 func clearEncryptionStates() { 619 nlh := ns.NlHandle() 620 spList, err := nlh.XfrmPolicyList(netlink.FAMILY_ALL) 621 if err != nil { 622 logrus.Warnf("Failed to retrieve SP list for cleanup: %v", err) 623 } 624 saList, err := nlh.XfrmStateList(netlink.FAMILY_ALL) 625 if err != nil { 626 logrus.Warnf("Failed to retrieve SA list for cleanup: %v", err) 627 } 628 for _, sp := range spList { 629 if sp.Mark != nil && sp.Mark.Value == spMark.Value { 630 if err := nlh.XfrmPolicyDel(&sp); err != nil { 631 logrus.Warnf("Failed to delete stale SP %s: %v", sp, err) 632 continue 633 } 634 logrus.Debugf("Removed stale SP: %s", sp) 635 } 636 } 637 for _, sa := range saList { 638 if sa.Reqid == r { 639 if err := nlh.XfrmStateDel(&sa); err != nil { 640 logrus.Warnf("Failed to delete stale SA %s: %v", sa, err) 641 continue 642 } 643 logrus.Debugf("Removed stale SA: %s", sa) 644 } 645 } 646 }