github.com/pwn-term/docker@v0.0.0-20210616085119-6e977cce2565/libnetwork/drivers/overlay/encryption.go (about)

     1  package overlay
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"encoding/hex"
     7  	"fmt"
     8  	"hash/fnv"
     9  	"net"
    10  	"sync"
    11  	"syscall"
    12  
    13  	"strconv"
    14  
    15  	"github.com/docker/libnetwork/drivers/overlay/overlayutils"
    16  	"github.com/docker/libnetwork/iptables"
    17  	"github.com/docker/libnetwork/ns"
    18  	"github.com/docker/libnetwork/types"
    19  	"github.com/sirupsen/logrus"
    20  	"github.com/vishvananda/netlink"
    21  )
    22  
    23  const (
    24  	r            = 0xD0C4E3
    25  	pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8)
    26  )
    27  
    28  const (
    29  	forward = iota + 1
    30  	reverse
    31  	bidir
    32  )
    33  
    34  var spMark = netlink.XfrmMark{Value: uint32(r), Mask: 0xffffffff}
    35  
// key is one encryption key known to the driver.
type key struct {
	value []byte // raw key material; buildAeadAlgo appends a 4-byte salt to it
	tag   uint32 // key identifier: mixed into SPIs by buildSPI, matched on by updateKeys
}
    40  
    41  func (k *key) String() string {
    42  	if k != nil {
    43  		return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag)
    44  	}
    45  	return ""
    46  }
    47  
// spi holds the pair of Security Parameter Indexes derived for one key and
// one peer (see buildSPI).
type spi struct {
	forward int // SPI of the local -> remote (outbound) state
	reverse int // SPI of the remote -> local (inbound) state
}
    52  
    53  func (s *spi) String() string {
    54  	return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse))
    55  }
    56  
// encrMap records, for every remote node (keyed by the IP's string form), the
// SPIs programmed for each encryption key, in key order with the primary
// key's SPIs first. The embedded mutex guards nodes.
type encrMap struct {
	nodes map[string][]*spi
	sync.Mutex
}
    61  
    62  func (e *encrMap) String() string {
    63  	e.Lock()
    64  	defer e.Unlock()
    65  	b := new(bytes.Buffer)
    66  	for k, v := range e.nodes {
    67  		b.WriteString("\n")
    68  		b.WriteString(k)
    69  		b.WriteString(":")
    70  		b.WriteString("[")
    71  		for _, s := range v {
    72  			b.WriteString(s.String())
    73  			b.WriteString(",")
    74  		}
    75  		b.WriteString("]")
    76  
    77  	}
    78  	return b.String()
    79  }
    80  
    81  func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error {
    82  	logrus.Debugf("checkEncryption(%.7s, %v, %d, %t)", nid, rIP, vxlanID, isLocal)
    83  
    84  	n := d.network(nid)
    85  	if n == nil || !n.secure {
    86  		return nil
    87  	}
    88  
    89  	if len(d.keys) == 0 {
    90  		return types.ForbiddenErrorf("encryption key is not present")
    91  	}
    92  
    93  	lIP := net.ParseIP(d.bindAddress)
    94  	aIP := net.ParseIP(d.advertiseAddress)
    95  	nodes := map[string]net.IP{}
    96  
    97  	switch {
    98  	case isLocal:
    99  		if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool {
   100  			if !aIP.Equal(pEntry.vtep) {
   101  				nodes[pEntry.vtep.String()] = pEntry.vtep
   102  			}
   103  			return false
   104  		}); err != nil {
   105  			logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %.5s: %v", nid, err)
   106  		}
   107  	default:
   108  		if len(d.network(nid).endpoints) > 0 {
   109  			nodes[rIP.String()] = rIP
   110  		}
   111  	}
   112  
   113  	logrus.Debugf("List of nodes: %s", nodes)
   114  
   115  	if add {
   116  		for _, rIP := range nodes {
   117  			if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil {
   118  				logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err)
   119  			}
   120  		}
   121  	} else {
   122  		if len(nodes) == 0 {
   123  			if err := removeEncryption(lIP, rIP, d.secMap); err != nil {
   124  				logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err)
   125  			}
   126  		}
   127  	}
   128  
   129  	return nil
   130  }
   131  
   132  func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error {
   133  	logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP)
   134  	rIPs := remoteIP.String()
   135  
   136  	indices := make([]*spi, 0, len(keys))
   137  
   138  	err := programMangle(vni, true)
   139  	if err != nil {
   140  		logrus.Warn(err)
   141  	}
   142  
   143  	err = programInput(vni, true)
   144  	if err != nil {
   145  		logrus.Warn(err)
   146  	}
   147  
   148  	for i, k := range keys {
   149  		spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)}
   150  		dir := reverse
   151  		if i == 0 {
   152  			dir = bidir
   153  		}
   154  		fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true)
   155  		if err != nil {
   156  			logrus.Warn(err)
   157  		}
   158  		indices = append(indices, spis)
   159  		if i != 0 {
   160  			continue
   161  		}
   162  		err = programSP(fSA, rSA, true)
   163  		if err != nil {
   164  			logrus.Warn(err)
   165  		}
   166  	}
   167  
   168  	em.Lock()
   169  	em.nodes[rIPs] = indices
   170  	em.Unlock()
   171  
   172  	return nil
   173  }
   174  
   175  func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error {
   176  	em.Lock()
   177  	indices, ok := em.nodes[remoteIP.String()]
   178  	em.Unlock()
   179  	if !ok {
   180  		return nil
   181  	}
   182  	for i, idxs := range indices {
   183  		dir := reverse
   184  		if i == 0 {
   185  			dir = bidir
   186  		}
   187  		fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false)
   188  		if err != nil {
   189  			logrus.Warn(err)
   190  		}
   191  		if i != 0 {
   192  			continue
   193  		}
   194  		err = programSP(fSA, rSA, false)
   195  		if err != nil {
   196  			logrus.Warn(err)
   197  		}
   198  	}
   199  	return nil
   200  }
   201  
   202  func programMangle(vni uint32, add bool) (err error) {
   203  	var (
   204  		p      = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10)
   205  		c      = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
   206  		m      = strconv.FormatUint(uint64(r), 10)
   207  		chain  = "OUTPUT"
   208  		rule   = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m}
   209  		a      = "-A"
   210  		action = "install"
   211  	)
   212  
   213  	// TODO IPv6 support
   214  	iptable := iptables.GetIptable(iptables.IPv4)
   215  
   216  	if add == iptable.Exists(iptables.Mangle, chain, rule...) {
   217  		return
   218  	}
   219  
   220  	if !add {
   221  		a = "-D"
   222  		action = "remove"
   223  	}
   224  
   225  	if err = iptable.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil {
   226  		logrus.Warnf("could not %s mangle rule: %v", action, err)
   227  	}
   228  
   229  	return
   230  }
   231  
   232  func programInput(vni uint32, add bool) (err error) {
   233  	var (
   234  		port       = strconv.FormatUint(uint64(overlayutils.VXLANUDPPort()), 10)
   235  		vniMatch   = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
   236  		plainVxlan = []string{"-p", "udp", "--dport", port, "-m", "u32", "--u32", vniMatch, "-j"}
   237  		ipsecVxlan = append([]string{"-m", "policy", "--dir", "in", "--pol", "ipsec"}, plainVxlan...)
   238  		block      = append(plainVxlan, "DROP")
   239  		accept     = append(ipsecVxlan, "ACCEPT")
   240  		chain      = "INPUT"
   241  		action     = iptables.Append
   242  		msg        = "add"
   243  	)
   244  
   245  	// TODO IPv6 support
   246  	iptable := iptables.GetIptable(iptables.IPv4)
   247  
   248  	if !add {
   249  		action = iptables.Delete
   250  		msg = "remove"
   251  	}
   252  
   253  	if err := iptable.ProgramRule(iptables.Filter, chain, action, accept); err != nil {
   254  		logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err)
   255  	}
   256  
   257  	if err := iptable.ProgramRule(iptables.Filter, chain, action, block); err != nil {
   258  		logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err)
   259  	}
   260  
   261  	return
   262  }
   263  
   264  func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) {
   265  	var (
   266  		action      = "Removing"
   267  		xfrmProgram = ns.NlHandle().XfrmStateDel
   268  	)
   269  
   270  	if add {
   271  		action = "Adding"
   272  		xfrmProgram = ns.NlHandle().XfrmStateAdd
   273  	}
   274  
   275  	if dir&reverse > 0 {
   276  		rSA = &netlink.XfrmState{
   277  			Src:   remoteIP,
   278  			Dst:   localIP,
   279  			Proto: netlink.XFRM_PROTO_ESP,
   280  			Spi:   spi.reverse,
   281  			Mode:  netlink.XFRM_MODE_TRANSPORT,
   282  			Reqid: r,
   283  		}
   284  		if add {
   285  			rSA.Aead = buildAeadAlgo(k, spi.reverse)
   286  		}
   287  
   288  		exists, err := saExists(rSA)
   289  		if err != nil {
   290  			exists = !add
   291  		}
   292  
   293  		if add != exists {
   294  			logrus.Debugf("%s: rSA{%s}", action, rSA)
   295  			if err := xfrmProgram(rSA); err != nil {
   296  				logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err)
   297  			}
   298  		}
   299  	}
   300  
   301  	if dir&forward > 0 {
   302  		fSA = &netlink.XfrmState{
   303  			Src:   localIP,
   304  			Dst:   remoteIP,
   305  			Proto: netlink.XFRM_PROTO_ESP,
   306  			Spi:   spi.forward,
   307  			Mode:  netlink.XFRM_MODE_TRANSPORT,
   308  			Reqid: r,
   309  		}
   310  		if add {
   311  			fSA.Aead = buildAeadAlgo(k, spi.forward)
   312  		}
   313  
   314  		exists, err := saExists(fSA)
   315  		if err != nil {
   316  			exists = !add
   317  		}
   318  
   319  		if add != exists {
   320  			logrus.Debugf("%s fSA{%s}", action, fSA)
   321  			if err := xfrmProgram(fSA); err != nil {
   322  				logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err)
   323  			}
   324  		}
   325  	}
   326  
   327  	return
   328  }
   329  
   330  func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error {
   331  	action := "Removing"
   332  	xfrmProgram := ns.NlHandle().XfrmPolicyDel
   333  	if add {
   334  		action = "Adding"
   335  		xfrmProgram = ns.NlHandle().XfrmPolicyAdd
   336  	}
   337  
   338  	// Create a congruent cidr
   339  	s := types.GetMinimalIP(fSA.Src)
   340  	d := types.GetMinimalIP(fSA.Dst)
   341  	fullMask := net.CIDRMask(8*len(s), 8*len(s))
   342  
   343  	fPol := &netlink.XfrmPolicy{
   344  		Src:     &net.IPNet{IP: s, Mask: fullMask},
   345  		Dst:     &net.IPNet{IP: d, Mask: fullMask},
   346  		Dir:     netlink.XFRM_DIR_OUT,
   347  		Proto:   17,
   348  		DstPort: 4789,
   349  		Mark:    &spMark,
   350  		Tmpls: []netlink.XfrmPolicyTmpl{
   351  			{
   352  				Src:   fSA.Src,
   353  				Dst:   fSA.Dst,
   354  				Proto: netlink.XFRM_PROTO_ESP,
   355  				Mode:  netlink.XFRM_MODE_TRANSPORT,
   356  				Spi:   fSA.Spi,
   357  				Reqid: r,
   358  			},
   359  		},
   360  	}
   361  
   362  	exists, err := spExists(fPol)
   363  	if err != nil {
   364  		exists = !add
   365  	}
   366  
   367  	if add != exists {
   368  		logrus.Debugf("%s fSP{%s}", action, fPol)
   369  		if err := xfrmProgram(fPol); err != nil {
   370  			logrus.Warnf("%s fSP{%s}: %v", action, fPol, err)
   371  		}
   372  	}
   373  
   374  	return nil
   375  }
   376  
   377  func saExists(sa *netlink.XfrmState) (bool, error) {
   378  	_, err := ns.NlHandle().XfrmStateGet(sa)
   379  	switch err {
   380  	case nil:
   381  		return true, nil
   382  	case syscall.ESRCH:
   383  		return false, nil
   384  	default:
   385  		err = fmt.Errorf("Error while checking for SA existence: %v", err)
   386  		logrus.Warn(err)
   387  		return false, err
   388  	}
   389  }
   390  
   391  func spExists(sp *netlink.XfrmPolicy) (bool, error) {
   392  	_, err := ns.NlHandle().XfrmPolicyGet(sp)
   393  	switch err {
   394  	case nil:
   395  		return true, nil
   396  	case syscall.ENOENT:
   397  		return false, nil
   398  	default:
   399  		err = fmt.Errorf("Error while checking for SP existence: %v", err)
   400  		logrus.Warn(err)
   401  		return false, err
   402  	}
   403  }
   404  
   405  func buildSPI(src, dst net.IP, st uint32) int {
   406  	b := make([]byte, 4)
   407  	binary.BigEndian.PutUint32(b, st)
   408  	h := fnv.New32a()
   409  	h.Write(src)
   410  	h.Write(b)
   411  	h.Write(dst)
   412  	return int(binary.BigEndian.Uint32(h.Sum(nil)))
   413  }
   414  
   415  func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo {
   416  	salt := make([]byte, 4)
   417  	binary.BigEndian.PutUint32(salt, uint32(s))
   418  	return &netlink.XfrmStateAlgo{
   419  		Name:   "rfc4106(gcm(aes))",
   420  		Key:    append(k.value, salt...),
   421  		ICVLen: 64,
   422  	}
   423  }
   424  
   425  func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error {
   426  	d.secMap.Lock()
   427  	for node, indices := range d.secMap.nodes {
   428  		idxs, stop := f(node, indices)
   429  		if idxs != nil {
   430  			d.secMap.nodes[node] = idxs
   431  		}
   432  		if stop {
   433  			break
   434  		}
   435  	}
   436  	d.secMap.Unlock()
   437  	return nil
   438  }
   439  
   440  func (d *driver) setKeys(keys []*key) error {
   441  	// Remove any stale policy, state
   442  	clearEncryptionStates()
   443  	// Accept the encryption keys and clear any stale encryption map
   444  	d.Lock()
   445  	d.keys = keys
   446  	d.secMap = &encrMap{nodes: map[string][]*spi{}}
   447  	d.Unlock()
   448  	logrus.Debugf("Initial encryption keys: %v", keys)
   449  	return nil
   450  }
   451  
   452  // updateKeys allows to add a new key and/or change the primary key and/or prune an existing key
   453  // The primary key is the key used in transmission and will go in first position in the list.
   454  func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
   455  	logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey)
   456  
   457  	logrus.Debugf("Current: %v", d.keys)
   458  
   459  	var (
   460  		newIdx = -1
   461  		priIdx = -1
   462  		delIdx = -1
   463  		lIP    = net.ParseIP(d.bindAddress)
   464  		aIP    = net.ParseIP(d.advertiseAddress)
   465  	)
   466  
   467  	d.Lock()
   468  	defer d.Unlock()
   469  
   470  	// add new
   471  	if newKey != nil {
   472  		d.keys = append(d.keys, newKey)
   473  		newIdx += len(d.keys)
   474  	}
   475  	for i, k := range d.keys {
   476  		if primary != nil && k.tag == primary.tag {
   477  			priIdx = i
   478  		}
   479  		if pruneKey != nil && k.tag == pruneKey.tag {
   480  			delIdx = i
   481  		}
   482  	}
   483  
   484  	if (newKey != nil && newIdx == -1) ||
   485  		(primary != nil && priIdx == -1) ||
   486  		(pruneKey != nil && delIdx == -1) {
   487  		return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+
   488  			"(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx)
   489  	}
   490  
   491  	if priIdx != -1 && priIdx == delIdx {
   492  		return types.BadRequestErrorf("attempting to both make a key (index %d) primary and delete it", priIdx)
   493  	}
   494  
   495  	d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) {
   496  		rIP := net.ParseIP(rIPs)
   497  		return updateNodeKey(lIP, aIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false
   498  	})
   499  
   500  	// swap primary
   501  	if priIdx != -1 {
   502  		d.keys[0], d.keys[priIdx] = d.keys[priIdx], d.keys[0]
   503  	}
   504  	// prune
   505  	if delIdx != -1 {
   506  		if delIdx == 0 {
   507  			delIdx = priIdx
   508  		}
   509  		d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...)
   510  	}
   511  
   512  	logrus.Debugf("Updated: %v", d.keys)
   513  
   514  	return nil
   515  }
   516  
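/********************************************************
 * Key rotation, as seen from one node towards one peer.
 * rSAn = inbound (reverse) SA for key n, fSAn = outbound
 * (forward) SA for key n, fSPn = outbound policy that
 * uses fSAn; "+" installs and "-" removes.
 *
 * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1
 * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1
 * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2
 *********************************************************/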
   522  
   523  // Spis and keys are sorted in such away the one in position 0 is the primary
   524  func updateNodeKey(lIP, aIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi {
   525  	logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx)
   526  
   527  	spis := idxs
   528  	logrus.Debugf("Current: %v", spis)
   529  
   530  	// add new
   531  	if newIdx != -1 {
   532  		spis = append(spis, &spi{
   533  			forward: buildSPI(aIP, rIP, curKeys[newIdx].tag),
   534  			reverse: buildSPI(rIP, aIP, curKeys[newIdx].tag),
   535  		})
   536  	}
   537  
   538  	if delIdx != -1 {
   539  		// -rSA0
   540  		programSA(lIP, rIP, spis[delIdx], nil, reverse, false)
   541  	}
   542  
   543  	if newIdx > -1 {
   544  		// +rSA2
   545  		programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true)
   546  	}
   547  
   548  	if priIdx > 0 {
   549  		// +fSA2
   550  		fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true)
   551  
   552  		// +fSP2, -fSP1
   553  		s := types.GetMinimalIP(fSA2.Src)
   554  		d := types.GetMinimalIP(fSA2.Dst)
   555  		fullMask := net.CIDRMask(8*len(s), 8*len(s))
   556  
   557  		fSP1 := &netlink.XfrmPolicy{
   558  			Src:     &net.IPNet{IP: s, Mask: fullMask},
   559  			Dst:     &net.IPNet{IP: d, Mask: fullMask},
   560  			Dir:     netlink.XFRM_DIR_OUT,
   561  			Proto:   17,
   562  			DstPort: 4789,
   563  			Mark:    &spMark,
   564  			Tmpls: []netlink.XfrmPolicyTmpl{
   565  				{
   566  					Src:   fSA2.Src,
   567  					Dst:   fSA2.Dst,
   568  					Proto: netlink.XFRM_PROTO_ESP,
   569  					Mode:  netlink.XFRM_MODE_TRANSPORT,
   570  					Spi:   fSA2.Spi,
   571  					Reqid: r,
   572  				},
   573  			},
   574  		}
   575  		logrus.Debugf("Updating fSP{%s}", fSP1)
   576  		if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil {
   577  			logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err)
   578  		}
   579  
   580  		// -fSA1
   581  		programSA(lIP, rIP, spis[0], nil, forward, false)
   582  	}
   583  
   584  	// swap
   585  	if priIdx > 0 {
   586  		swp := spis[0]
   587  		spis[0] = spis[priIdx]
   588  		spis[priIdx] = swp
   589  	}
   590  	// prune
   591  	if delIdx != -1 {
   592  		if delIdx == 0 {
   593  			delIdx = priIdx
   594  		}
   595  		spis = append(spis[:delIdx], spis[delIdx+1:]...)
   596  	}
   597  
   598  	logrus.Debugf("Updated: %v", spis)
   599  
   600  	return spis
   601  }
   602  
   603  func (n *network) maxMTU() int {
   604  	mtu := 1500
   605  	if n.mtu != 0 {
   606  		mtu = n.mtu
   607  	}
   608  	mtu -= vxlanEncap
   609  	if n.secure {
   610  		// In case of encryption account for the
   611  		// esp packet expansion and padding
   612  		mtu -= pktExpansion
   613  		mtu -= (mtu % 4)
   614  	}
   615  	return mtu
   616  }
   617  
   618  func clearEncryptionStates() {
   619  	nlh := ns.NlHandle()
   620  	spList, err := nlh.XfrmPolicyList(netlink.FAMILY_ALL)
   621  	if err != nil {
   622  		logrus.Warnf("Failed to retrieve SP list for cleanup: %v", err)
   623  	}
   624  	saList, err := nlh.XfrmStateList(netlink.FAMILY_ALL)
   625  	if err != nil {
   626  		logrus.Warnf("Failed to retrieve SA list for cleanup: %v", err)
   627  	}
   628  	for _, sp := range spList {
   629  		if sp.Mark != nil && sp.Mark.Value == spMark.Value {
   630  			if err := nlh.XfrmPolicyDel(&sp); err != nil {
   631  				logrus.Warnf("Failed to delete stale SP %s: %v", sp, err)
   632  				continue
   633  			}
   634  			logrus.Debugf("Removed stale SP: %s", sp)
   635  		}
   636  	}
   637  	for _, sa := range saList {
   638  		if sa.Reqid == r {
   639  			if err := nlh.XfrmStateDel(&sa); err != nil {
   640  				logrus.Warnf("Failed to delete stale SA %s: %v", sa, err)
   641  				continue
   642  			}
   643  			logrus.Debugf("Removed stale SA: %s", sa)
   644  		}
   645  	}
   646  }