github.com/jfrazelle/docker@v1.1.2-0.20210712172922-bf78e25fe508/libnetwork/drivers/overlay/encryption.go (about)

     1  // +build linux
     2  
     3  package overlay
     4  
     5  import (
     6  	"bytes"
     7  	"encoding/binary"
     8  	"encoding/hex"
     9  	"fmt"
    10  	"hash/fnv"
    11  	"net"
    12  	"sync"
    13  	"syscall"
    14  
    15  	"strconv"
    16  
    17  	"github.com/docker/docker/libnetwork/drivers/overlay/overlayutils"
    18  	"github.com/docker/docker/libnetwork/iptables"
    19  	"github.com/docker/docker/libnetwork/ns"
    20  	"github.com/docker/docker/libnetwork/types"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/vishvananda/netlink"
    23  )
    24  
    25  const (
    26  	r            = 0xD0C4E3
    27  	pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8)
    28  )
    29  
    30  const (
    31  	forward = iota + 1
    32  	reverse
    33  	bidir
    34  )
    35  
    36  var spMark = netlink.XfrmMark{Value: uint32(r), Mask: 0xffffffff}
    37  
    38  type key struct {
    39  	value []byte
    40  	tag   uint32
    41  }
    42  
    43  func (k *key) String() string {
    44  	if k != nil {
    45  		return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag)
    46  	}
    47  	return ""
    48  }
    49  
    50  type spi struct {
    51  	forward int
    52  	reverse int
    53  }
    54  
    55  func (s *spi) String() string {
    56  	return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse))
    57  }
    58  
    59  type encrMap struct {
    60  	nodes map[string][]*spi
    61  	sync.Mutex
    62  }
    63  
    64  func (e *encrMap) String() string {
    65  	e.Lock()
    66  	defer e.Unlock()
    67  	b := new(bytes.Buffer)
    68  	for k, v := range e.nodes {
    69  		b.WriteString("\n")
    70  		b.WriteString(k)
    71  		b.WriteString(":")
    72  		b.WriteString("[")
    73  		for _, s := range v {
    74  			b.WriteString(s.String())
    75  			b.WriteString(",")
    76  		}
    77  		b.WriteString("]")
    78  
    79  	}
    80  	return b.String()
    81  }
    82  
    83  func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error {
    84  	logrus.Debugf("checkEncryption(%.7s, %v, %d, %t)", nid, rIP, vxlanID, isLocal)
    85  
    86  	n := d.network(nid)
    87  	if n == nil || !n.secure {
    88  		return nil
    89  	}
    90  
    91  	if len(d.keys) == 0 {
    92  		return types.ForbiddenErrorf("encryption key is not present")
    93  	}
    94  
    95  	lIP := net.ParseIP(d.bindAddress)
    96  	aIP := net.ParseIP(d.advertiseAddress)
    97  	nodes := map[string]net.IP{}
    98  
    99  	switch {
   100  	case isLocal:
   101  		if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool {
   102  			if !aIP.Equal(pEntry.vtep) {
   103  				nodes[pEntry.vtep.String()] = pEntry.vtep
   104  			}
   105  			return false
   106  		}); err != nil {
   107  			logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %.5s: %v", nid, err)
   108  		}
   109  	default:
   110  		if len(d.network(nid).endpoints) > 0 {
   111  			nodes[rIP.String()] = rIP
   112  		}
   113  	}
   114  
   115  	logrus.Debugf("List of nodes: %s", nodes)
   116  
   117  	if add {
   118  		for _, rIP := range nodes {
   119  			if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil {
   120  				logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err)
   121  			}
   122  		}
   123  	} else {
   124  		if len(nodes) == 0 {
   125  			if err := removeEncryption(lIP, rIP, d.secMap); err != nil {
   126  				logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err)
   127  			}
   128  		}
   129  	}
   130  
   131  	return nil
   132  }
   133  
   134  func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error {
   135  	logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP)
   136  	rIPs := remoteIP.String()
   137  
   138  	indices := make([]*spi, 0, len(keys))
   139  
   140  	err := programMangle(vni, true)
   141  	if err != nil {
   142  		logrus.Warn(err)
   143  	}
   144  
   145  	err = programInput(vni, true)
   146  	if err != nil {
   147  		logrus.Warn(err)
   148  	}
   149  
   150  	for i, k := range keys {
   151  		spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)}
   152  		dir := reverse
   153  		if i == 0 {
   154  			dir = bidir
   155  		}
   156  		fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true)
   157  		if err != nil {
   158  			logrus.Warn(err)
   159  		}
   160  		indices = append(indices, spis)
   161  		if i != 0 {
   162  			continue
   163  		}
   164  		err = programSP(fSA, rSA, true)
   165  		if err != nil {
   166  			logrus.Warn(err)
   167  		}
   168  	}
   169  
   170  	em.Lock()
   171  	em.nodes[rIPs] = indices
   172  	em.Unlock()
   173  
   174  	return nil
   175  }
   176  
   177  func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error {
   178  	em.Lock()
   179  	indices, ok := em.nodes[remoteIP.String()]
   180  	em.Unlock()
   181  	if !ok {
   182  		return nil
   183  	}
   184  	for i, idxs := range indices {
   185  		dir := reverse
   186  		if i == 0 {
   187  			dir = bidir
   188  		}
   189  		fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false)
   190  		if err != nil {
   191  			logrus.Warn(err)
   192  		}
   193  		if i != 0 {
   194  			continue
   195  		}
   196  		err = programSP(fSA, rSA, false)
   197  		if err != nil {
   198  			logrus.Warn(err)
   199  		}
   200  	}
   201  	return nil
   202  }
   203  
   204  func programMangle(vni uint32, add bool) (err error) {
   205  	var (
   206  		p      = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10)
   207  		c      = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
   208  		m      = strconv.FormatUint(uint64(r), 10)
   209  		chain  = "OUTPUT"
   210  		rule   = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m}
   211  		a      = "-A"
   212  		action = "install"
   213  	)
   214  
   215  	// TODO IPv6 support
   216  	iptable := iptables.GetIptable(iptables.IPv4)
   217  
   218  	if add == iptable.Exists(iptables.Mangle, chain, rule...) {
   219  		return
   220  	}
   221  
   222  	if !add {
   223  		a = "-D"
   224  		action = "remove"
   225  	}
   226  
   227  	if err = iptable.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil {
   228  		logrus.Warnf("could not %s mangle rule: %v", action, err)
   229  	}
   230  
   231  	return
   232  }
   233  
   234  func programInput(vni uint32, add bool) (err error) {
   235  	var (
   236  		port       = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10)
   237  		vniMatch   = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
   238  		plainVxlan = []string{"-p", "udp", "--dport", port, "-m", "u32", "--u32", vniMatch, "-j"}
   239  		ipsecVxlan = append([]string{"-m", "policy", "--dir", "in", "--pol", "ipsec"}, plainVxlan...)
   240  		block      = append(plainVxlan, "DROP")
   241  		accept     = append(ipsecVxlan, "ACCEPT")
   242  		chain      = "INPUT"
   243  		action     = iptables.Append
   244  		msg        = "add"
   245  	)
   246  
   247  	// TODO IPv6 support
   248  	iptable := iptables.GetIptable(iptables.IPv4)
   249  
   250  	if !add {
   251  		action = iptables.Delete
   252  		msg = "remove"
   253  	}
   254  
   255  	if err := iptable.ProgramRule(iptables.Filter, chain, action, accept); err != nil {
   256  		logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err)
   257  	}
   258  
   259  	if err := iptable.ProgramRule(iptables.Filter, chain, action, block); err != nil {
   260  		logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err)
   261  	}
   262  
   263  	return
   264  }
   265  
   266  func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) {
   267  	var (
   268  		action      = "Removing"
   269  		xfrmProgram = ns.NlHandle().XfrmStateDel
   270  	)
   271  
   272  	if add {
   273  		action = "Adding"
   274  		xfrmProgram = ns.NlHandle().XfrmStateAdd
   275  	}
   276  
   277  	if dir&reverse > 0 {
   278  		rSA = &netlink.XfrmState{
   279  			Src:   remoteIP,
   280  			Dst:   localIP,
   281  			Proto: netlink.XFRM_PROTO_ESP,
   282  			Spi:   spi.reverse,
   283  			Mode:  netlink.XFRM_MODE_TRANSPORT,
   284  			Reqid: r,
   285  		}
   286  		if add {
   287  			rSA.Aead = buildAeadAlgo(k, spi.reverse)
   288  		}
   289  
   290  		exists, err := saExists(rSA)
   291  		if err != nil {
   292  			exists = !add
   293  		}
   294  
   295  		if add != exists {
   296  			logrus.Debugf("%s: rSA{%s}", action, rSA)
   297  			if err := xfrmProgram(rSA); err != nil {
   298  				logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err)
   299  			}
   300  		}
   301  	}
   302  
   303  	if dir&forward > 0 {
   304  		fSA = &netlink.XfrmState{
   305  			Src:   localIP,
   306  			Dst:   remoteIP,
   307  			Proto: netlink.XFRM_PROTO_ESP,
   308  			Spi:   spi.forward,
   309  			Mode:  netlink.XFRM_MODE_TRANSPORT,
   310  			Reqid: r,
   311  		}
   312  		if add {
   313  			fSA.Aead = buildAeadAlgo(k, spi.forward)
   314  		}
   315  
   316  		exists, err := saExists(fSA)
   317  		if err != nil {
   318  			exists = !add
   319  		}
   320  
   321  		if add != exists {
   322  			logrus.Debugf("%s fSA{%s}", action, fSA)
   323  			if err := xfrmProgram(fSA); err != nil {
   324  				logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err)
   325  			}
   326  		}
   327  	}
   328  
   329  	return
   330  }
   331  
   332  func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error {
   333  	action := "Removing"
   334  	xfrmProgram := ns.NlHandle().XfrmPolicyDel
   335  	if add {
   336  		action = "Adding"
   337  		xfrmProgram = ns.NlHandle().XfrmPolicyAdd
   338  	}
   339  
   340  	// Create a congruent cidr
   341  	s := types.GetMinimalIP(fSA.Src)
   342  	d := types.GetMinimalIP(fSA.Dst)
   343  	fullMask := net.CIDRMask(8*len(s), 8*len(s))
   344  
   345  	fPol := &netlink.XfrmPolicy{
   346  		Src:     &net.IPNet{IP: s, Mask: fullMask},
   347  		Dst:     &net.IPNet{IP: d, Mask: fullMask},
   348  		Dir:     netlink.XFRM_DIR_OUT,
   349  		Proto:   17,
   350  		DstPort: 4789,
   351  		Mark:    &spMark,
   352  		Tmpls: []netlink.XfrmPolicyTmpl{
   353  			{
   354  				Src:   fSA.Src,
   355  				Dst:   fSA.Dst,
   356  				Proto: netlink.XFRM_PROTO_ESP,
   357  				Mode:  netlink.XFRM_MODE_TRANSPORT,
   358  				Spi:   fSA.Spi,
   359  				Reqid: r,
   360  			},
   361  		},
   362  	}
   363  
   364  	exists, err := spExists(fPol)
   365  	if err != nil {
   366  		exists = !add
   367  	}
   368  
   369  	if add != exists {
   370  		logrus.Debugf("%s fSP{%s}", action, fPol)
   371  		if err := xfrmProgram(fPol); err != nil {
   372  			logrus.Warnf("%s fSP{%s}: %v", action, fPol, err)
   373  		}
   374  	}
   375  
   376  	return nil
   377  }
   378  
   379  func saExists(sa *netlink.XfrmState) (bool, error) {
   380  	_, err := ns.NlHandle().XfrmStateGet(sa)
   381  	switch err {
   382  	case nil:
   383  		return true, nil
   384  	case syscall.ESRCH:
   385  		return false, nil
   386  	default:
   387  		err = fmt.Errorf("Error while checking for SA existence: %v", err)
   388  		logrus.Warn(err)
   389  		return false, err
   390  	}
   391  }
   392  
   393  func spExists(sp *netlink.XfrmPolicy) (bool, error) {
   394  	_, err := ns.NlHandle().XfrmPolicyGet(sp)
   395  	switch err {
   396  	case nil:
   397  		return true, nil
   398  	case syscall.ENOENT:
   399  		return false, nil
   400  	default:
   401  		err = fmt.Errorf("Error while checking for SP existence: %v", err)
   402  		logrus.Warn(err)
   403  		return false, err
   404  	}
   405  }
   406  
   407  func buildSPI(src, dst net.IP, st uint32) int {
   408  	b := make([]byte, 4)
   409  	binary.BigEndian.PutUint32(b, st)
   410  	h := fnv.New32a()
   411  	h.Write(src)
   412  	h.Write(b)
   413  	h.Write(dst)
   414  	return int(binary.BigEndian.Uint32(h.Sum(nil)))
   415  }
   416  
   417  func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo {
   418  	salt := make([]byte, 4)
   419  	binary.BigEndian.PutUint32(salt, uint32(s))
   420  	return &netlink.XfrmStateAlgo{
   421  		Name:   "rfc4106(gcm(aes))",
   422  		Key:    append(k.value, salt...),
   423  		ICVLen: 64,
   424  	}
   425  }
   426  
   427  func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error {
   428  	d.secMap.Lock()
   429  	for node, indices := range d.secMap.nodes {
   430  		idxs, stop := f(node, indices)
   431  		if idxs != nil {
   432  			d.secMap.nodes[node] = idxs
   433  		}
   434  		if stop {
   435  			break
   436  		}
   437  	}
   438  	d.secMap.Unlock()
   439  	return nil
   440  }
   441  
   442  func (d *driver) setKeys(keys []*key) error {
   443  	// Remove any stale policy, state
   444  	clearEncryptionStates()
   445  	// Accept the encryption keys and clear any stale encryption map
   446  	d.Lock()
   447  	d.keys = keys
   448  	d.secMap = &encrMap{nodes: map[string][]*spi{}}
   449  	d.Unlock()
   450  	logrus.Debugf("Initial encryption keys: %v", keys)
   451  	return nil
   452  }
   453  
   454  // updateKeys allows to add a new key and/or change the primary key and/or prune an existing key
   455  // The primary key is the key used in transmission and will go in first position in the list.
   456  func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
   457  	logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey)
   458  
   459  	logrus.Debugf("Current: %v", d.keys)
   460  
   461  	var (
   462  		newIdx = -1
   463  		priIdx = -1
   464  		delIdx = -1
   465  		lIP    = net.ParseIP(d.bindAddress)
   466  		aIP    = net.ParseIP(d.advertiseAddress)
   467  	)
   468  
   469  	d.Lock()
   470  	defer d.Unlock()
   471  
   472  	// add new
   473  	if newKey != nil {
   474  		d.keys = append(d.keys, newKey)
   475  		newIdx += len(d.keys)
   476  	}
   477  	for i, k := range d.keys {
   478  		if primary != nil && k.tag == primary.tag {
   479  			priIdx = i
   480  		}
   481  		if pruneKey != nil && k.tag == pruneKey.tag {
   482  			delIdx = i
   483  		}
   484  	}
   485  
   486  	if (newKey != nil && newIdx == -1) ||
   487  		(primary != nil && priIdx == -1) ||
   488  		(pruneKey != nil && delIdx == -1) {
   489  		return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+
   490  			"(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx)
   491  	}
   492  
   493  	if priIdx != -1 && priIdx == delIdx {
   494  		return types.BadRequestErrorf("attempting to both make a key (index %d) primary and delete it", priIdx)
   495  	}
   496  
   497  	d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) {
   498  		rIP := net.ParseIP(rIPs)
   499  		return updateNodeKey(lIP, aIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false
   500  	})
   501  
   502  	// swap primary
   503  	if priIdx != -1 {
   504  		d.keys[0], d.keys[priIdx] = d.keys[priIdx], d.keys[0]
   505  	}
   506  	// prune
   507  	if delIdx != -1 {
   508  		if delIdx == 0 {
   509  			delIdx = priIdx
   510  		}
   511  		d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...)
   512  	}
   513  
   514  	logrus.Debugf("Updated: %v", d.keys)
   515  
   516  	return nil
   517  }
   518  
   519  /********************************************************
   520   * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1
   521   * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1
   522   * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2
   523   *********************************************************/
   524  
   525  // Spis and keys are sorted in such away the one in position 0 is the primary
   526  func updateNodeKey(lIP, aIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi {
   527  	logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx)
   528  
   529  	spis := idxs
   530  	logrus.Debugf("Current: %v", spis)
   531  
   532  	// add new
   533  	if newIdx != -1 {
   534  		spis = append(spis, &spi{
   535  			forward: buildSPI(aIP, rIP, curKeys[newIdx].tag),
   536  			reverse: buildSPI(rIP, aIP, curKeys[newIdx].tag),
   537  		})
   538  	}
   539  
   540  	if delIdx != -1 {
   541  		// -rSA0
   542  		programSA(lIP, rIP, spis[delIdx], nil, reverse, false)
   543  	}
   544  
   545  	if newIdx > -1 {
   546  		// +rSA2
   547  		programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true)
   548  	}
   549  
   550  	if priIdx > 0 {
   551  		// +fSA2
   552  		fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true)
   553  
   554  		// +fSP2, -fSP1
   555  		s := types.GetMinimalIP(fSA2.Src)
   556  		d := types.GetMinimalIP(fSA2.Dst)
   557  		fullMask := net.CIDRMask(8*len(s), 8*len(s))
   558  
   559  		fSP1 := &netlink.XfrmPolicy{
   560  			Src:     &net.IPNet{IP: s, Mask: fullMask},
   561  			Dst:     &net.IPNet{IP: d, Mask: fullMask},
   562  			Dir:     netlink.XFRM_DIR_OUT,
   563  			Proto:   17,
   564  			DstPort: 4789,
   565  			Mark:    &spMark,
   566  			Tmpls: []netlink.XfrmPolicyTmpl{
   567  				{
   568  					Src:   fSA2.Src,
   569  					Dst:   fSA2.Dst,
   570  					Proto: netlink.XFRM_PROTO_ESP,
   571  					Mode:  netlink.XFRM_MODE_TRANSPORT,
   572  					Spi:   fSA2.Spi,
   573  					Reqid: r,
   574  				},
   575  			},
   576  		}
   577  		logrus.Debugf("Updating fSP{%s}", fSP1)
   578  		if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil {
   579  			logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err)
   580  		}
   581  
   582  		// -fSA1
   583  		programSA(lIP, rIP, spis[0], nil, forward, false)
   584  	}
   585  
   586  	// swap
   587  	if priIdx > 0 {
   588  		swp := spis[0]
   589  		spis[0] = spis[priIdx]
   590  		spis[priIdx] = swp
   591  	}
   592  	// prune
   593  	if delIdx != -1 {
   594  		if delIdx == 0 {
   595  			delIdx = priIdx
   596  		}
   597  		spis = append(spis[:delIdx], spis[delIdx+1:]...)
   598  	}
   599  
   600  	logrus.Debugf("Updated: %v", spis)
   601  
   602  	return spis
   603  }
   604  
   605  func (n *network) maxMTU() int {
   606  	mtu := 1500
   607  	if n.mtu != 0 {
   608  		mtu = n.mtu
   609  	}
   610  	mtu -= vxlanEncap
   611  	if n.secure {
   612  		// In case of encryption account for the
   613  		// esp packet expansion and padding
   614  		mtu -= pktExpansion
   615  		mtu -= (mtu % 4)
   616  	}
   617  	return mtu
   618  }
   619  
   620  func clearEncryptionStates() {
   621  	nlh := ns.NlHandle()
   622  	spList, err := nlh.XfrmPolicyList(netlink.FAMILY_ALL)
   623  	if err != nil {
   624  		logrus.Warnf("Failed to retrieve SP list for cleanup: %v", err)
   625  	}
   626  	saList, err := nlh.XfrmStateList(netlink.FAMILY_ALL)
   627  	if err != nil {
   628  		logrus.Warnf("Failed to retrieve SA list for cleanup: %v", err)
   629  	}
   630  	for _, sp := range spList {
   631  		sp := sp
   632  		if sp.Mark != nil && sp.Mark.Value == spMark.Value {
   633  			if err := nlh.XfrmPolicyDel(&sp); err != nil {
   634  				logrus.Warnf("Failed to delete stale SP %s: %v", sp, err)
   635  				continue
   636  			}
   637  			logrus.Debugf("Removed stale SP: %s", sp)
   638  		}
   639  	}
   640  	for _, sa := range saList {
   641  		sa := sa
   642  		if sa.Reqid == r {
   643  			if err := nlh.XfrmStateDel(&sa); err != nil {
   644  				logrus.Warnf("Failed to delete stale SA %s: %v", sa, err)
   645  				continue
   646  			}
   647  			logrus.Debugf("Removed stale SA: %s", sa)
   648  		}
   649  	}
   650  }