github.com/adityamillind98/moby@v23.0.0-rc.4+incompatible/libnetwork/drivers/overlay/encryption.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package overlay
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/binary"
     9  	"encoding/hex"
    10  	"fmt"
    11  	"hash/fnv"
    12  	"net"
    13  	"strconv"
    14  	"sync"
    15  	"syscall"
    16  
    17  	"github.com/docker/docker/libnetwork/drivers/overlay/overlayutils"
    18  	"github.com/docker/docker/libnetwork/iptables"
    19  	"github.com/docker/docker/libnetwork/ns"
    20  	"github.com/docker/docker/libnetwork/types"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/vishvananda/netlink"
    23  )
    24  
    25  const (
    26  	r            = 0xD0C4E3
    27  	pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8)
    28  )
    29  
    30  const (
    31  	forward = iota + 1
    32  	reverse
    33  	bidir
    34  )
    35  
    36  var spMark = netlink.XfrmMark{Value: uint32(r), Mask: 0xffffffff}
    37  
    38  type key struct {
    39  	value []byte
    40  	tag   uint32
    41  }
    42  
    43  func (k *key) String() string {
    44  	if k != nil {
    45  		return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag)
    46  	}
    47  	return ""
    48  }
    49  
    50  type spi struct {
    51  	forward int
    52  	reverse int
    53  }
    54  
    55  func (s *spi) String() string {
    56  	return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse))
    57  }
    58  
    59  type encrMap struct {
    60  	nodes map[string][]*spi
    61  	sync.Mutex
    62  }
    63  
    64  func (e *encrMap) String() string {
    65  	e.Lock()
    66  	defer e.Unlock()
    67  	b := new(bytes.Buffer)
    68  	for k, v := range e.nodes {
    69  		b.WriteString("\n")
    70  		b.WriteString(k)
    71  		b.WriteString(":")
    72  		b.WriteString("[")
    73  		for _, s := range v {
    74  			b.WriteString(s.String())
    75  			b.WriteString(",")
    76  		}
    77  		b.WriteString("]")
    78  	}
    79  	return b.String()
    80  }
    81  
    82  func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error {
    83  	logrus.Debugf("checkEncryption(%.7s, %v, %d, %t)", nid, rIP, vxlanID, isLocal)
    84  
    85  	n := d.network(nid)
    86  	if n == nil || !n.secure {
    87  		return nil
    88  	}
    89  
    90  	if len(d.keys) == 0 {
    91  		return types.ForbiddenErrorf("encryption key is not present")
    92  	}
    93  
    94  	lIP := net.ParseIP(d.bindAddress)
    95  	aIP := net.ParseIP(d.advertiseAddress)
    96  	nodes := map[string]net.IP{}
    97  
    98  	switch {
    99  	case isLocal:
   100  		if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool {
   101  			if !aIP.Equal(pEntry.vtep) {
   102  				nodes[pEntry.vtep.String()] = pEntry.vtep
   103  			}
   104  			return false
   105  		}); err != nil {
   106  			logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %.5s: %v", nid, err)
   107  		}
   108  	default:
   109  		if len(d.network(nid).endpoints) > 0 {
   110  			nodes[rIP.String()] = rIP
   111  		}
   112  	}
   113  
   114  	logrus.Debugf("List of nodes: %s", nodes)
   115  
   116  	if add {
   117  		for _, rIP := range nodes {
   118  			if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil {
   119  				logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err)
   120  			}
   121  		}
   122  	} else {
   123  		if len(nodes) == 0 {
   124  			if err := removeEncryption(lIP, rIP, d.secMap); err != nil {
   125  				logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err)
   126  			}
   127  		}
   128  	}
   129  
   130  	return nil
   131  }
   132  
   133  func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error {
   134  	logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP)
   135  	rIPs := remoteIP.String()
   136  
   137  	indices := make([]*spi, 0, len(keys))
   138  
   139  	err := programMangle(vni, true)
   140  	if err != nil {
   141  		logrus.Warn(err)
   142  	}
   143  
   144  	err = programInput(vni, true)
   145  	if err != nil {
   146  		logrus.Warn(err)
   147  	}
   148  
   149  	for i, k := range keys {
   150  		spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)}
   151  		dir := reverse
   152  		if i == 0 {
   153  			dir = bidir
   154  		}
   155  		fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true)
   156  		if err != nil {
   157  			logrus.Warn(err)
   158  		}
   159  		indices = append(indices, spis)
   160  		if i != 0 {
   161  			continue
   162  		}
   163  		err = programSP(fSA, rSA, true)
   164  		if err != nil {
   165  			logrus.Warn(err)
   166  		}
   167  	}
   168  
   169  	em.Lock()
   170  	em.nodes[rIPs] = indices
   171  	em.Unlock()
   172  
   173  	return nil
   174  }
   175  
   176  func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error {
   177  	em.Lock()
   178  	indices, ok := em.nodes[remoteIP.String()]
   179  	em.Unlock()
   180  	if !ok {
   181  		return nil
   182  	}
   183  	for i, idxs := range indices {
   184  		dir := reverse
   185  		if i == 0 {
   186  			dir = bidir
   187  		}
   188  		fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false)
   189  		if err != nil {
   190  			logrus.Warn(err)
   191  		}
   192  		if i != 0 {
   193  			continue
   194  		}
   195  		err = programSP(fSA, rSA, false)
   196  		if err != nil {
   197  			logrus.Warn(err)
   198  		}
   199  	}
   200  	return nil
   201  }
   202  
   203  func programMangle(vni uint32, add bool) (err error) {
   204  	var (
   205  		p      = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10)
   206  		c      = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
   207  		m      = strconv.FormatUint(uint64(r), 10)
   208  		chain  = "OUTPUT"
   209  		rule   = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m}
   210  		a      = "-A"
   211  		action = "install"
   212  	)
   213  
   214  	// TODO IPv6 support
   215  	iptable := iptables.GetIptable(iptables.IPv4)
   216  
   217  	if add == iptable.Exists(iptables.Mangle, chain, rule...) {
   218  		return
   219  	}
   220  
   221  	if !add {
   222  		a = "-D"
   223  		action = "remove"
   224  	}
   225  
   226  	if err = iptable.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil {
   227  		logrus.Warnf("could not %s mangle rule: %v", action, err)
   228  	}
   229  
   230  	return
   231  }
   232  
   233  func programInput(vni uint32, add bool) (err error) {
   234  	var (
   235  		port       = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10)
   236  		vniMatch   = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
   237  		plainVxlan = []string{"-p", "udp", "--dport", port, "-m", "u32", "--u32", vniMatch, "-j"}
   238  		ipsecVxlan = append([]string{"-m", "policy", "--dir", "in", "--pol", "ipsec"}, plainVxlan...)
   239  		block      = append(plainVxlan, "DROP")
   240  		accept     = append(ipsecVxlan, "ACCEPT")
   241  		chain      = "INPUT"
   242  		action     = iptables.Append
   243  		msg        = "add"
   244  	)
   245  
   246  	// TODO IPv6 support
   247  	iptable := iptables.GetIptable(iptables.IPv4)
   248  
   249  	if !add {
   250  		action = iptables.Delete
   251  		msg = "remove"
   252  	}
   253  
   254  	if err := iptable.ProgramRule(iptables.Filter, chain, action, accept); err != nil {
   255  		logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err)
   256  	}
   257  
   258  	if err := iptable.ProgramRule(iptables.Filter, chain, action, block); err != nil {
   259  		logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err)
   260  	}
   261  
   262  	return
   263  }
   264  
   265  func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) {
   266  	var (
   267  		action      = "Removing"
   268  		xfrmProgram = ns.NlHandle().XfrmStateDel
   269  	)
   270  
   271  	if add {
   272  		action = "Adding"
   273  		xfrmProgram = ns.NlHandle().XfrmStateAdd
   274  	}
   275  
   276  	if dir&reverse > 0 {
   277  		rSA = &netlink.XfrmState{
   278  			Src:   remoteIP,
   279  			Dst:   localIP,
   280  			Proto: netlink.XFRM_PROTO_ESP,
   281  			Spi:   spi.reverse,
   282  			Mode:  netlink.XFRM_MODE_TRANSPORT,
   283  			Reqid: r,
   284  		}
   285  		if add {
   286  			rSA.Aead = buildAeadAlgo(k, spi.reverse)
   287  		}
   288  
   289  		exists, err := saExists(rSA)
   290  		if err != nil {
   291  			exists = !add
   292  		}
   293  
   294  		if add != exists {
   295  			logrus.Debugf("%s: rSA{%s}", action, rSA)
   296  			if err := xfrmProgram(rSA); err != nil {
   297  				logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err)
   298  			}
   299  		}
   300  	}
   301  
   302  	if dir&forward > 0 {
   303  		fSA = &netlink.XfrmState{
   304  			Src:   localIP,
   305  			Dst:   remoteIP,
   306  			Proto: netlink.XFRM_PROTO_ESP,
   307  			Spi:   spi.forward,
   308  			Mode:  netlink.XFRM_MODE_TRANSPORT,
   309  			Reqid: r,
   310  		}
   311  		if add {
   312  			fSA.Aead = buildAeadAlgo(k, spi.forward)
   313  		}
   314  
   315  		exists, err := saExists(fSA)
   316  		if err != nil {
   317  			exists = !add
   318  		}
   319  
   320  		if add != exists {
   321  			logrus.Debugf("%s fSA{%s}", action, fSA)
   322  			if err := xfrmProgram(fSA); err != nil {
   323  				logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err)
   324  			}
   325  		}
   326  	}
   327  
   328  	return
   329  }
   330  
   331  func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error {
   332  	action := "Removing"
   333  	xfrmProgram := ns.NlHandle().XfrmPolicyDel
   334  	if add {
   335  		action = "Adding"
   336  		xfrmProgram = ns.NlHandle().XfrmPolicyAdd
   337  	}
   338  
   339  	// Create a congruent cidr
   340  	s := types.GetMinimalIP(fSA.Src)
   341  	d := types.GetMinimalIP(fSA.Dst)
   342  	fullMask := net.CIDRMask(8*len(s), 8*len(s))
   343  
   344  	fPol := &netlink.XfrmPolicy{
   345  		Src:     &net.IPNet{IP: s, Mask: fullMask},
   346  		Dst:     &net.IPNet{IP: d, Mask: fullMask},
   347  		Dir:     netlink.XFRM_DIR_OUT,
   348  		Proto:   17,
   349  		DstPort: 4789,
   350  		Mark:    &spMark,
   351  		Tmpls: []netlink.XfrmPolicyTmpl{
   352  			{
   353  				Src:   fSA.Src,
   354  				Dst:   fSA.Dst,
   355  				Proto: netlink.XFRM_PROTO_ESP,
   356  				Mode:  netlink.XFRM_MODE_TRANSPORT,
   357  				Spi:   fSA.Spi,
   358  				Reqid: r,
   359  			},
   360  		},
   361  	}
   362  
   363  	exists, err := spExists(fPol)
   364  	if err != nil {
   365  		exists = !add
   366  	}
   367  
   368  	if add != exists {
   369  		logrus.Debugf("%s fSP{%s}", action, fPol)
   370  		if err := xfrmProgram(fPol); err != nil {
   371  			logrus.Warnf("%s fSP{%s}: %v", action, fPol, err)
   372  		}
   373  	}
   374  
   375  	return nil
   376  }
   377  
   378  func saExists(sa *netlink.XfrmState) (bool, error) {
   379  	_, err := ns.NlHandle().XfrmStateGet(sa)
   380  	switch err {
   381  	case nil:
   382  		return true, nil
   383  	case syscall.ESRCH:
   384  		return false, nil
   385  	default:
   386  		err = fmt.Errorf("Error while checking for SA existence: %v", err)
   387  		logrus.Warn(err)
   388  		return false, err
   389  	}
   390  }
   391  
   392  func spExists(sp *netlink.XfrmPolicy) (bool, error) {
   393  	_, err := ns.NlHandle().XfrmPolicyGet(sp)
   394  	switch err {
   395  	case nil:
   396  		return true, nil
   397  	case syscall.ENOENT:
   398  		return false, nil
   399  	default:
   400  		err = fmt.Errorf("Error while checking for SP existence: %v", err)
   401  		logrus.Warn(err)
   402  		return false, err
   403  	}
   404  }
   405  
   406  func buildSPI(src, dst net.IP, st uint32) int {
   407  	b := make([]byte, 4)
   408  	binary.BigEndian.PutUint32(b, st)
   409  	h := fnv.New32a()
   410  	h.Write(src)
   411  	h.Write(b)
   412  	h.Write(dst)
   413  	return int(binary.BigEndian.Uint32(h.Sum(nil)))
   414  }
   415  
   416  func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo {
   417  	salt := make([]byte, 4)
   418  	binary.BigEndian.PutUint32(salt, uint32(s))
   419  	return &netlink.XfrmStateAlgo{
   420  		Name:   "rfc4106(gcm(aes))",
   421  		Key:    append(k.value, salt...),
   422  		ICVLen: 64,
   423  	}
   424  }
   425  
   426  func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error {
   427  	d.secMap.Lock()
   428  	for node, indices := range d.secMap.nodes {
   429  		idxs, stop := f(node, indices)
   430  		if idxs != nil {
   431  			d.secMap.nodes[node] = idxs
   432  		}
   433  		if stop {
   434  			break
   435  		}
   436  	}
   437  	d.secMap.Unlock()
   438  	return nil
   439  }
   440  
   441  func (d *driver) setKeys(keys []*key) error {
   442  	// Remove any stale policy, state
   443  	clearEncryptionStates()
   444  	// Accept the encryption keys and clear any stale encryption map
   445  	d.Lock()
   446  	d.keys = keys
   447  	d.secMap = &encrMap{nodes: map[string][]*spi{}}
   448  	d.Unlock()
   449  	logrus.Debugf("Initial encryption keys: %v", keys)
   450  	return nil
   451  }
   452  
   453  // updateKeys allows to add a new key and/or change the primary key and/or prune an existing key
   454  // The primary key is the key used in transmission and will go in first position in the list.
   455  func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
   456  	logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey)
   457  
   458  	logrus.Debugf("Current: %v", d.keys)
   459  
   460  	var (
   461  		newIdx = -1
   462  		priIdx = -1
   463  		delIdx = -1
   464  		lIP    = net.ParseIP(d.bindAddress)
   465  		aIP    = net.ParseIP(d.advertiseAddress)
   466  	)
   467  
   468  	d.Lock()
   469  	defer d.Unlock()
   470  
   471  	// add new
   472  	if newKey != nil {
   473  		d.keys = append(d.keys, newKey)
   474  		newIdx += len(d.keys)
   475  	}
   476  	for i, k := range d.keys {
   477  		if primary != nil && k.tag == primary.tag {
   478  			priIdx = i
   479  		}
   480  		if pruneKey != nil && k.tag == pruneKey.tag {
   481  			delIdx = i
   482  		}
   483  	}
   484  
   485  	if (newKey != nil && newIdx == -1) ||
   486  		(primary != nil && priIdx == -1) ||
   487  		(pruneKey != nil && delIdx == -1) {
   488  		return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+
   489  			"(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx)
   490  	}
   491  
   492  	if priIdx != -1 && priIdx == delIdx {
   493  		return types.BadRequestErrorf("attempting to both make a key (index %d) primary and delete it", priIdx)
   494  	}
   495  
   496  	d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) {
   497  		rIP := net.ParseIP(rIPs)
   498  		return updateNodeKey(lIP, aIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false
   499  	})
   500  
   501  	// swap primary
   502  	if priIdx != -1 {
   503  		d.keys[0], d.keys[priIdx] = d.keys[priIdx], d.keys[0]
   504  	}
   505  	// prune
   506  	if delIdx != -1 {
   507  		if delIdx == 0 {
   508  			delIdx = priIdx
   509  		}
   510  		d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...)
   511  	}
   512  
   513  	logrus.Debugf("Updated: %v", d.keys)
   514  
   515  	return nil
   516  }
   517  
   518  /********************************************************
   519   * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1
   520   * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1
   521   * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2
   522   *********************************************************/
   523  
   524  // Spis and keys are sorted in such away the one in position 0 is the primary
   525  func updateNodeKey(lIP, aIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi {
   526  	logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx)
   527  
   528  	spis := idxs
   529  	logrus.Debugf("Current: %v", spis)
   530  
   531  	// add new
   532  	if newIdx != -1 {
   533  		spis = append(spis, &spi{
   534  			forward: buildSPI(aIP, rIP, curKeys[newIdx].tag),
   535  			reverse: buildSPI(rIP, aIP, curKeys[newIdx].tag),
   536  		})
   537  	}
   538  
   539  	if delIdx != -1 {
   540  		// -rSA0
   541  		programSA(lIP, rIP, spis[delIdx], nil, reverse, false)
   542  	}
   543  
   544  	if newIdx > -1 {
   545  		// +rSA2
   546  		programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true)
   547  	}
   548  
   549  	if priIdx > 0 {
   550  		// +fSA2
   551  		fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true)
   552  
   553  		// +fSP2, -fSP1
   554  		s := types.GetMinimalIP(fSA2.Src)
   555  		d := types.GetMinimalIP(fSA2.Dst)
   556  		fullMask := net.CIDRMask(8*len(s), 8*len(s))
   557  
   558  		fSP1 := &netlink.XfrmPolicy{
   559  			Src:     &net.IPNet{IP: s, Mask: fullMask},
   560  			Dst:     &net.IPNet{IP: d, Mask: fullMask},
   561  			Dir:     netlink.XFRM_DIR_OUT,
   562  			Proto:   17,
   563  			DstPort: 4789,
   564  			Mark:    &spMark,
   565  			Tmpls: []netlink.XfrmPolicyTmpl{
   566  				{
   567  					Src:   fSA2.Src,
   568  					Dst:   fSA2.Dst,
   569  					Proto: netlink.XFRM_PROTO_ESP,
   570  					Mode:  netlink.XFRM_MODE_TRANSPORT,
   571  					Spi:   fSA2.Spi,
   572  					Reqid: r,
   573  				},
   574  			},
   575  		}
   576  		logrus.Debugf("Updating fSP{%s}", fSP1)
   577  		if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil {
   578  			logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err)
   579  		}
   580  
   581  		// -fSA1
   582  		programSA(lIP, rIP, spis[0], nil, forward, false)
   583  	}
   584  
   585  	// swap
   586  	if priIdx > 0 {
   587  		swp := spis[0]
   588  		spis[0] = spis[priIdx]
   589  		spis[priIdx] = swp
   590  	}
   591  	// prune
   592  	if delIdx != -1 {
   593  		if delIdx == 0 {
   594  			delIdx = priIdx
   595  		}
   596  		spis = append(spis[:delIdx], spis[delIdx+1:]...)
   597  	}
   598  
   599  	logrus.Debugf("Updated: %v", spis)
   600  
   601  	return spis
   602  }
   603  
   604  func (n *network) maxMTU() int {
   605  	mtu := 1500
   606  	if n.mtu != 0 {
   607  		mtu = n.mtu
   608  	}
   609  	mtu -= vxlanEncap
   610  	if n.secure {
   611  		// In case of encryption account for the
   612  		// esp packet expansion and padding
   613  		mtu -= pktExpansion
   614  		mtu -= (mtu % 4)
   615  	}
   616  	return mtu
   617  }
   618  
   619  func clearEncryptionStates() {
   620  	nlh := ns.NlHandle()
   621  	spList, err := nlh.XfrmPolicyList(netlink.FAMILY_ALL)
   622  	if err != nil {
   623  		logrus.Warnf("Failed to retrieve SP list for cleanup: %v", err)
   624  	}
   625  	saList, err := nlh.XfrmStateList(netlink.FAMILY_ALL)
   626  	if err != nil {
   627  		logrus.Warnf("Failed to retrieve SA list for cleanup: %v", err)
   628  	}
   629  	for _, sp := range spList {
   630  		sp := sp
   631  		if sp.Mark != nil && sp.Mark.Value == spMark.Value {
   632  			if err := nlh.XfrmPolicyDel(&sp); err != nil {
   633  				logrus.Warnf("Failed to delete stale SP %s: %v", sp, err)
   634  				continue
   635  			}
   636  			logrus.Debugf("Removed stale SP: %s", sp)
   637  		}
   638  	}
   639  	for _, sa := range saList {
   640  		sa := sa
   641  		if sa.Reqid == r {
   642  			if err := nlh.XfrmStateDel(&sa); err != nil {
   643  				logrus.Warnf("Failed to delete stale SA %s: %v", sa, err)
   644  				continue
   645  			}
   646  			logrus.Debugf("Removed stale SA: %s", sa)
   647  		}
   648  	}
   649  }