github.com/docker/engine@v22.0.0-20211208180946-d456264580cf+incompatible/libnetwork/drivers/overlay/encryption.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package overlay
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/binary"
     9  	"encoding/hex"
    10  	"fmt"
    11  	"hash/fnv"
    12  	"net"
    13  	"sync"
    14  	"syscall"
    15  
    16  	"strconv"
    17  
    18  	"github.com/docker/docker/libnetwork/drivers/overlay/overlayutils"
    19  	"github.com/docker/docker/libnetwork/iptables"
    20  	"github.com/docker/docker/libnetwork/ns"
    21  	"github.com/docker/docker/libnetwork/types"
    22  	"github.com/sirupsen/logrus"
    23  	"github.com/vishvananda/netlink"
    24  )
    25  
    26  const (
    27  	r            = 0xD0C4E3
    28  	pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8)
    29  )
    30  
    31  const (
    32  	forward = iota + 1
    33  	reverse
    34  	bidir
    35  )
    36  
    37  var spMark = netlink.XfrmMark{Value: uint32(r), Mask: 0xffffffff}
    38  
    39  type key struct {
    40  	value []byte
    41  	tag   uint32
    42  }
    43  
    44  func (k *key) String() string {
    45  	if k != nil {
    46  		return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag)
    47  	}
    48  	return ""
    49  }
    50  
    51  type spi struct {
    52  	forward int
    53  	reverse int
    54  }
    55  
    56  func (s *spi) String() string {
    57  	return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse))
    58  }
    59  
    60  type encrMap struct {
    61  	nodes map[string][]*spi
    62  	sync.Mutex
    63  }
    64  
    65  func (e *encrMap) String() string {
    66  	e.Lock()
    67  	defer e.Unlock()
    68  	b := new(bytes.Buffer)
    69  	for k, v := range e.nodes {
    70  		b.WriteString("\n")
    71  		b.WriteString(k)
    72  		b.WriteString(":")
    73  		b.WriteString("[")
    74  		for _, s := range v {
    75  			b.WriteString(s.String())
    76  			b.WriteString(",")
    77  		}
    78  		b.WriteString("]")
    79  
    80  	}
    81  	return b.String()
    82  }
    83  
    84  func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error {
    85  	logrus.Debugf("checkEncryption(%.7s, %v, %d, %t)", nid, rIP, vxlanID, isLocal)
    86  
    87  	n := d.network(nid)
    88  	if n == nil || !n.secure {
    89  		return nil
    90  	}
    91  
    92  	if len(d.keys) == 0 {
    93  		return types.ForbiddenErrorf("encryption key is not present")
    94  	}
    95  
    96  	lIP := net.ParseIP(d.bindAddress)
    97  	aIP := net.ParseIP(d.advertiseAddress)
    98  	nodes := map[string]net.IP{}
    99  
   100  	switch {
   101  	case isLocal:
   102  		if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool {
   103  			if !aIP.Equal(pEntry.vtep) {
   104  				nodes[pEntry.vtep.String()] = pEntry.vtep
   105  			}
   106  			return false
   107  		}); err != nil {
   108  			logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %.5s: %v", nid, err)
   109  		}
   110  	default:
   111  		if len(d.network(nid).endpoints) > 0 {
   112  			nodes[rIP.String()] = rIP
   113  		}
   114  	}
   115  
   116  	logrus.Debugf("List of nodes: %s", nodes)
   117  
   118  	if add {
   119  		for _, rIP := range nodes {
   120  			if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil {
   121  				logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err)
   122  			}
   123  		}
   124  	} else {
   125  		if len(nodes) == 0 {
   126  			if err := removeEncryption(lIP, rIP, d.secMap); err != nil {
   127  				logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err)
   128  			}
   129  		}
   130  	}
   131  
   132  	return nil
   133  }
   134  
   135  func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error {
   136  	logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP)
   137  	rIPs := remoteIP.String()
   138  
   139  	indices := make([]*spi, 0, len(keys))
   140  
   141  	err := programMangle(vni, true)
   142  	if err != nil {
   143  		logrus.Warn(err)
   144  	}
   145  
   146  	err = programInput(vni, true)
   147  	if err != nil {
   148  		logrus.Warn(err)
   149  	}
   150  
   151  	for i, k := range keys {
   152  		spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)}
   153  		dir := reverse
   154  		if i == 0 {
   155  			dir = bidir
   156  		}
   157  		fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true)
   158  		if err != nil {
   159  			logrus.Warn(err)
   160  		}
   161  		indices = append(indices, spis)
   162  		if i != 0 {
   163  			continue
   164  		}
   165  		err = programSP(fSA, rSA, true)
   166  		if err != nil {
   167  			logrus.Warn(err)
   168  		}
   169  	}
   170  
   171  	em.Lock()
   172  	em.nodes[rIPs] = indices
   173  	em.Unlock()
   174  
   175  	return nil
   176  }
   177  
   178  func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error {
   179  	em.Lock()
   180  	indices, ok := em.nodes[remoteIP.String()]
   181  	em.Unlock()
   182  	if !ok {
   183  		return nil
   184  	}
   185  	for i, idxs := range indices {
   186  		dir := reverse
   187  		if i == 0 {
   188  			dir = bidir
   189  		}
   190  		fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false)
   191  		if err != nil {
   192  			logrus.Warn(err)
   193  		}
   194  		if i != 0 {
   195  			continue
   196  		}
   197  		err = programSP(fSA, rSA, false)
   198  		if err != nil {
   199  			logrus.Warn(err)
   200  		}
   201  	}
   202  	return nil
   203  }
   204  
   205  func programMangle(vni uint32, add bool) (err error) {
   206  	var (
   207  		p      = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10)
   208  		c      = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
   209  		m      = strconv.FormatUint(uint64(r), 10)
   210  		chain  = "OUTPUT"
   211  		rule   = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m}
   212  		a      = "-A"
   213  		action = "install"
   214  	)
   215  
   216  	// TODO IPv6 support
   217  	iptable := iptables.GetIptable(iptables.IPv4)
   218  
   219  	if add == iptable.Exists(iptables.Mangle, chain, rule...) {
   220  		return
   221  	}
   222  
   223  	if !add {
   224  		a = "-D"
   225  		action = "remove"
   226  	}
   227  
   228  	if err = iptable.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil {
   229  		logrus.Warnf("could not %s mangle rule: %v", action, err)
   230  	}
   231  
   232  	return
   233  }
   234  
   235  func programInput(vni uint32, add bool) (err error) {
   236  	var (
   237  		port       = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10)
   238  		vniMatch   = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
   239  		plainVxlan = []string{"-p", "udp", "--dport", port, "-m", "u32", "--u32", vniMatch, "-j"}
   240  		ipsecVxlan = append([]string{"-m", "policy", "--dir", "in", "--pol", "ipsec"}, plainVxlan...)
   241  		block      = append(plainVxlan, "DROP")
   242  		accept     = append(ipsecVxlan, "ACCEPT")
   243  		chain      = "INPUT"
   244  		action     = iptables.Append
   245  		msg        = "add"
   246  	)
   247  
   248  	// TODO IPv6 support
   249  	iptable := iptables.GetIptable(iptables.IPv4)
   250  
   251  	if !add {
   252  		action = iptables.Delete
   253  		msg = "remove"
   254  	}
   255  
   256  	if err := iptable.ProgramRule(iptables.Filter, chain, action, accept); err != nil {
   257  		logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err)
   258  	}
   259  
   260  	if err := iptable.ProgramRule(iptables.Filter, chain, action, block); err != nil {
   261  		logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err)
   262  	}
   263  
   264  	return
   265  }
   266  
   267  func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) {
   268  	var (
   269  		action      = "Removing"
   270  		xfrmProgram = ns.NlHandle().XfrmStateDel
   271  	)
   272  
   273  	if add {
   274  		action = "Adding"
   275  		xfrmProgram = ns.NlHandle().XfrmStateAdd
   276  	}
   277  
   278  	if dir&reverse > 0 {
   279  		rSA = &netlink.XfrmState{
   280  			Src:   remoteIP,
   281  			Dst:   localIP,
   282  			Proto: netlink.XFRM_PROTO_ESP,
   283  			Spi:   spi.reverse,
   284  			Mode:  netlink.XFRM_MODE_TRANSPORT,
   285  			Reqid: r,
   286  		}
   287  		if add {
   288  			rSA.Aead = buildAeadAlgo(k, spi.reverse)
   289  		}
   290  
   291  		exists, err := saExists(rSA)
   292  		if err != nil {
   293  			exists = !add
   294  		}
   295  
   296  		if add != exists {
   297  			logrus.Debugf("%s: rSA{%s}", action, rSA)
   298  			if err := xfrmProgram(rSA); err != nil {
   299  				logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err)
   300  			}
   301  		}
   302  	}
   303  
   304  	if dir&forward > 0 {
   305  		fSA = &netlink.XfrmState{
   306  			Src:   localIP,
   307  			Dst:   remoteIP,
   308  			Proto: netlink.XFRM_PROTO_ESP,
   309  			Spi:   spi.forward,
   310  			Mode:  netlink.XFRM_MODE_TRANSPORT,
   311  			Reqid: r,
   312  		}
   313  		if add {
   314  			fSA.Aead = buildAeadAlgo(k, spi.forward)
   315  		}
   316  
   317  		exists, err := saExists(fSA)
   318  		if err != nil {
   319  			exists = !add
   320  		}
   321  
   322  		if add != exists {
   323  			logrus.Debugf("%s fSA{%s}", action, fSA)
   324  			if err := xfrmProgram(fSA); err != nil {
   325  				logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err)
   326  			}
   327  		}
   328  	}
   329  
   330  	return
   331  }
   332  
   333  func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error {
   334  	action := "Removing"
   335  	xfrmProgram := ns.NlHandle().XfrmPolicyDel
   336  	if add {
   337  		action = "Adding"
   338  		xfrmProgram = ns.NlHandle().XfrmPolicyAdd
   339  	}
   340  
   341  	// Create a congruent cidr
   342  	s := types.GetMinimalIP(fSA.Src)
   343  	d := types.GetMinimalIP(fSA.Dst)
   344  	fullMask := net.CIDRMask(8*len(s), 8*len(s))
   345  
   346  	fPol := &netlink.XfrmPolicy{
   347  		Src:     &net.IPNet{IP: s, Mask: fullMask},
   348  		Dst:     &net.IPNet{IP: d, Mask: fullMask},
   349  		Dir:     netlink.XFRM_DIR_OUT,
   350  		Proto:   17,
   351  		DstPort: 4789,
   352  		Mark:    &spMark,
   353  		Tmpls: []netlink.XfrmPolicyTmpl{
   354  			{
   355  				Src:   fSA.Src,
   356  				Dst:   fSA.Dst,
   357  				Proto: netlink.XFRM_PROTO_ESP,
   358  				Mode:  netlink.XFRM_MODE_TRANSPORT,
   359  				Spi:   fSA.Spi,
   360  				Reqid: r,
   361  			},
   362  		},
   363  	}
   364  
   365  	exists, err := spExists(fPol)
   366  	if err != nil {
   367  		exists = !add
   368  	}
   369  
   370  	if add != exists {
   371  		logrus.Debugf("%s fSP{%s}", action, fPol)
   372  		if err := xfrmProgram(fPol); err != nil {
   373  			logrus.Warnf("%s fSP{%s}: %v", action, fPol, err)
   374  		}
   375  	}
   376  
   377  	return nil
   378  }
   379  
   380  func saExists(sa *netlink.XfrmState) (bool, error) {
   381  	_, err := ns.NlHandle().XfrmStateGet(sa)
   382  	switch err {
   383  	case nil:
   384  		return true, nil
   385  	case syscall.ESRCH:
   386  		return false, nil
   387  	default:
   388  		err = fmt.Errorf("Error while checking for SA existence: %v", err)
   389  		logrus.Warn(err)
   390  		return false, err
   391  	}
   392  }
   393  
   394  func spExists(sp *netlink.XfrmPolicy) (bool, error) {
   395  	_, err := ns.NlHandle().XfrmPolicyGet(sp)
   396  	switch err {
   397  	case nil:
   398  		return true, nil
   399  	case syscall.ENOENT:
   400  		return false, nil
   401  	default:
   402  		err = fmt.Errorf("Error while checking for SP existence: %v", err)
   403  		logrus.Warn(err)
   404  		return false, err
   405  	}
   406  }
   407  
   408  func buildSPI(src, dst net.IP, st uint32) int {
   409  	b := make([]byte, 4)
   410  	binary.BigEndian.PutUint32(b, st)
   411  	h := fnv.New32a()
   412  	h.Write(src)
   413  	h.Write(b)
   414  	h.Write(dst)
   415  	return int(binary.BigEndian.Uint32(h.Sum(nil)))
   416  }
   417  
   418  func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo {
   419  	salt := make([]byte, 4)
   420  	binary.BigEndian.PutUint32(salt, uint32(s))
   421  	return &netlink.XfrmStateAlgo{
   422  		Name:   "rfc4106(gcm(aes))",
   423  		Key:    append(k.value, salt...),
   424  		ICVLen: 64,
   425  	}
   426  }
   427  
   428  func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error {
   429  	d.secMap.Lock()
   430  	for node, indices := range d.secMap.nodes {
   431  		idxs, stop := f(node, indices)
   432  		if idxs != nil {
   433  			d.secMap.nodes[node] = idxs
   434  		}
   435  		if stop {
   436  			break
   437  		}
   438  	}
   439  	d.secMap.Unlock()
   440  	return nil
   441  }
   442  
   443  func (d *driver) setKeys(keys []*key) error {
   444  	// Remove any stale policy, state
   445  	clearEncryptionStates()
   446  	// Accept the encryption keys and clear any stale encryption map
   447  	d.Lock()
   448  	d.keys = keys
   449  	d.secMap = &encrMap{nodes: map[string][]*spi{}}
   450  	d.Unlock()
   451  	logrus.Debugf("Initial encryption keys: %v", keys)
   452  	return nil
   453  }
   454  
   455  // updateKeys allows to add a new key and/or change the primary key and/or prune an existing key
   456  // The primary key is the key used in transmission and will go in first position in the list.
   457  func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
   458  	logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey)
   459  
   460  	logrus.Debugf("Current: %v", d.keys)
   461  
   462  	var (
   463  		newIdx = -1
   464  		priIdx = -1
   465  		delIdx = -1
   466  		lIP    = net.ParseIP(d.bindAddress)
   467  		aIP    = net.ParseIP(d.advertiseAddress)
   468  	)
   469  
   470  	d.Lock()
   471  	defer d.Unlock()
   472  
   473  	// add new
   474  	if newKey != nil {
   475  		d.keys = append(d.keys, newKey)
   476  		newIdx += len(d.keys)
   477  	}
   478  	for i, k := range d.keys {
   479  		if primary != nil && k.tag == primary.tag {
   480  			priIdx = i
   481  		}
   482  		if pruneKey != nil && k.tag == pruneKey.tag {
   483  			delIdx = i
   484  		}
   485  	}
   486  
   487  	if (newKey != nil && newIdx == -1) ||
   488  		(primary != nil && priIdx == -1) ||
   489  		(pruneKey != nil && delIdx == -1) {
   490  		return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+
   491  			"(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx)
   492  	}
   493  
   494  	if priIdx != -1 && priIdx == delIdx {
   495  		return types.BadRequestErrorf("attempting to both make a key (index %d) primary and delete it", priIdx)
   496  	}
   497  
   498  	d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) {
   499  		rIP := net.ParseIP(rIPs)
   500  		return updateNodeKey(lIP, aIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false
   501  	})
   502  
   503  	// swap primary
   504  	if priIdx != -1 {
   505  		d.keys[0], d.keys[priIdx] = d.keys[priIdx], d.keys[0]
   506  	}
   507  	// prune
   508  	if delIdx != -1 {
   509  		if delIdx == 0 {
   510  			delIdx = priIdx
   511  		}
   512  		d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...)
   513  	}
   514  
   515  	logrus.Debugf("Updated: %v", d.keys)
   516  
   517  	return nil
   518  }
   519  
   520  /********************************************************
   521   * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1
   522   * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1
   523   * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2
   524   *********************************************************/
   525  
   526  // Spis and keys are sorted in such away the one in position 0 is the primary
   527  func updateNodeKey(lIP, aIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi {
   528  	logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx)
   529  
   530  	spis := idxs
   531  	logrus.Debugf("Current: %v", spis)
   532  
   533  	// add new
   534  	if newIdx != -1 {
   535  		spis = append(spis, &spi{
   536  			forward: buildSPI(aIP, rIP, curKeys[newIdx].tag),
   537  			reverse: buildSPI(rIP, aIP, curKeys[newIdx].tag),
   538  		})
   539  	}
   540  
   541  	if delIdx != -1 {
   542  		// -rSA0
   543  		programSA(lIP, rIP, spis[delIdx], nil, reverse, false)
   544  	}
   545  
   546  	if newIdx > -1 {
   547  		// +rSA2
   548  		programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true)
   549  	}
   550  
   551  	if priIdx > 0 {
   552  		// +fSA2
   553  		fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true)
   554  
   555  		// +fSP2, -fSP1
   556  		s := types.GetMinimalIP(fSA2.Src)
   557  		d := types.GetMinimalIP(fSA2.Dst)
   558  		fullMask := net.CIDRMask(8*len(s), 8*len(s))
   559  
   560  		fSP1 := &netlink.XfrmPolicy{
   561  			Src:     &net.IPNet{IP: s, Mask: fullMask},
   562  			Dst:     &net.IPNet{IP: d, Mask: fullMask},
   563  			Dir:     netlink.XFRM_DIR_OUT,
   564  			Proto:   17,
   565  			DstPort: 4789,
   566  			Mark:    &spMark,
   567  			Tmpls: []netlink.XfrmPolicyTmpl{
   568  				{
   569  					Src:   fSA2.Src,
   570  					Dst:   fSA2.Dst,
   571  					Proto: netlink.XFRM_PROTO_ESP,
   572  					Mode:  netlink.XFRM_MODE_TRANSPORT,
   573  					Spi:   fSA2.Spi,
   574  					Reqid: r,
   575  				},
   576  			},
   577  		}
   578  		logrus.Debugf("Updating fSP{%s}", fSP1)
   579  		if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil {
   580  			logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err)
   581  		}
   582  
   583  		// -fSA1
   584  		programSA(lIP, rIP, spis[0], nil, forward, false)
   585  	}
   586  
   587  	// swap
   588  	if priIdx > 0 {
   589  		swp := spis[0]
   590  		spis[0] = spis[priIdx]
   591  		spis[priIdx] = swp
   592  	}
   593  	// prune
   594  	if delIdx != -1 {
   595  		if delIdx == 0 {
   596  			delIdx = priIdx
   597  		}
   598  		spis = append(spis[:delIdx], spis[delIdx+1:]...)
   599  	}
   600  
   601  	logrus.Debugf("Updated: %v", spis)
   602  
   603  	return spis
   604  }
   605  
   606  func (n *network) maxMTU() int {
   607  	mtu := 1500
   608  	if n.mtu != 0 {
   609  		mtu = n.mtu
   610  	}
   611  	mtu -= vxlanEncap
   612  	if n.secure {
   613  		// In case of encryption account for the
   614  		// esp packet expansion and padding
   615  		mtu -= pktExpansion
   616  		mtu -= (mtu % 4)
   617  	}
   618  	return mtu
   619  }
   620  
   621  func clearEncryptionStates() {
   622  	nlh := ns.NlHandle()
   623  	spList, err := nlh.XfrmPolicyList(netlink.FAMILY_ALL)
   624  	if err != nil {
   625  		logrus.Warnf("Failed to retrieve SP list for cleanup: %v", err)
   626  	}
   627  	saList, err := nlh.XfrmStateList(netlink.FAMILY_ALL)
   628  	if err != nil {
   629  		logrus.Warnf("Failed to retrieve SA list for cleanup: %v", err)
   630  	}
   631  	for _, sp := range spList {
   632  		sp := sp
   633  		if sp.Mark != nil && sp.Mark.Value == spMark.Value {
   634  			if err := nlh.XfrmPolicyDel(&sp); err != nil {
   635  				logrus.Warnf("Failed to delete stale SP %s: %v", sp, err)
   636  				continue
   637  			}
   638  			logrus.Debugf("Removed stale SP: %s", sp)
   639  		}
   640  	}
   641  	for _, sa := range saList {
   642  		sa := sa
   643  		if sa.Reqid == r {
   644  			if err := nlh.XfrmStateDel(&sa); err != nil {
   645  				logrus.Warnf("Failed to delete stale SA %s: %v", sa, err)
   646  				continue
   647  			}
   648  			logrus.Debugf("Removed stale SA: %s", sa)
   649  		}
   650  	}
   651  }