github.com/zhuohuang-hust/src-cbuild@v0.0.0-20230105071821-c7aab3e7c840/mergeCode/libnetwork/drivers/overlay/encryption.go (about)

     1  package overlay
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"encoding/hex"
     7  	"fmt"
     8  	"hash/fnv"
     9  	"net"
    10  	"sync"
    11  	"syscall"
    12  
    13  	"strconv"
    14  
    15  	log "github.com/Sirupsen/logrus"
    16  	"github.com/docker/libnetwork/iptables"
    17  	"github.com/docker/libnetwork/ns"
    18  	"github.com/docker/libnetwork/types"
    19  	"github.com/vishvananda/netlink"
    20  )
    21  
    22  const (
    23  	mark         = uint32(0xD0C4E3)
    24  	timeout      = 30
    25  	pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8)
    26  )
    27  
    28  const (
    29  	forward = iota + 1
    30  	reverse
    31  	bidir
    32  )
    33  
    34  type key struct {
    35  	value []byte
    36  	tag   uint32
    37  }
    38  
    39  func (k *key) String() string {
    40  	if k != nil {
    41  		return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag)
    42  	}
    43  	return ""
    44  }
    45  
    46  type spi struct {
    47  	forward int
    48  	reverse int
    49  }
    50  
    51  func (s *spi) String() string {
    52  	return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse))
    53  }
    54  
    55  type encrMap struct {
    56  	nodes map[string][]*spi
    57  	sync.Mutex
    58  }
    59  
    60  func (e *encrMap) String() string {
    61  	e.Lock()
    62  	defer e.Unlock()
    63  	b := new(bytes.Buffer)
    64  	for k, v := range e.nodes {
    65  		b.WriteString("\n")
    66  		b.WriteString(k)
    67  		b.WriteString(":")
    68  		b.WriteString("[")
    69  		for _, s := range v {
    70  			b.WriteString(s.String())
    71  			b.WriteString(",")
    72  		}
    73  		b.WriteString("]")
    74  
    75  	}
    76  	return b.String()
    77  }
    78  
    79  func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error {
    80  	log.Debugf("checkEncryption(%s, %v, %d, %t)", nid[0:7], rIP, vxlanID, isLocal)
    81  
    82  	n := d.network(nid)
    83  	if n == nil || !n.secure {
    84  		return nil
    85  	}
    86  
    87  	if len(d.keys) == 0 {
    88  		return types.ForbiddenErrorf("encryption key is not present")
    89  	}
    90  
    91  	lIP := net.ParseIP(d.bindAddress)
    92  	aIP := net.ParseIP(d.advertiseAddress)
    93  	nodes := map[string]net.IP{}
    94  
    95  	switch {
    96  	case isLocal:
    97  		if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool {
    98  			if !aIP.Equal(pEntry.vtep) {
    99  				nodes[pEntry.vtep.String()] = pEntry.vtep
   100  			}
   101  			return false
   102  		}); err != nil {
   103  			log.Warnf("Failed to retrieve list of participating nodes in overlay network %s: %v", nid[0:5], err)
   104  		}
   105  	default:
   106  		if len(d.network(nid).endpoints) > 0 {
   107  			nodes[rIP.String()] = rIP
   108  		}
   109  	}
   110  
   111  	log.Debugf("List of nodes: %s", nodes)
   112  
   113  	if add {
   114  		for _, rIP := range nodes {
   115  			if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil {
   116  				log.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err)
   117  			}
   118  		}
   119  	} else {
   120  		if len(nodes) == 0 {
   121  			if err := removeEncryption(lIP, rIP, d.secMap); err != nil {
   122  				log.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err)
   123  			}
   124  		}
   125  	}
   126  
   127  	return nil
   128  }
   129  
   130  func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error {
   131  	log.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP)
   132  	rIPs := remoteIP.String()
   133  
   134  	indices := make([]*spi, 0, len(keys))
   135  
   136  	err := programMangle(vni, true)
   137  	if err != nil {
   138  		log.Warn(err)
   139  	}
   140  
   141  	for i, k := range keys {
   142  		spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)}
   143  		dir := reverse
   144  		if i == 0 {
   145  			dir = bidir
   146  		}
   147  		fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true)
   148  		if err != nil {
   149  			log.Warn(err)
   150  		}
   151  		indices = append(indices, spis)
   152  		if i != 0 {
   153  			continue
   154  		}
   155  		err = programSP(fSA, rSA, true)
   156  		if err != nil {
   157  			log.Warn(err)
   158  		}
   159  	}
   160  
   161  	em.Lock()
   162  	em.nodes[rIPs] = indices
   163  	em.Unlock()
   164  
   165  	return nil
   166  }
   167  
   168  func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error {
   169  	em.Lock()
   170  	indices, ok := em.nodes[remoteIP.String()]
   171  	em.Unlock()
   172  	if !ok {
   173  		return nil
   174  	}
   175  	for i, idxs := range indices {
   176  		dir := reverse
   177  		if i == 0 {
   178  			dir = bidir
   179  		}
   180  		fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false)
   181  		if err != nil {
   182  			log.Warn(err)
   183  		}
   184  		if i != 0 {
   185  			continue
   186  		}
   187  		err = programSP(fSA, rSA, false)
   188  		if err != nil {
   189  			log.Warn(err)
   190  		}
   191  	}
   192  	return nil
   193  }
   194  
   195  func programMangle(vni uint32, add bool) (err error) {
   196  	var (
   197  		p      = strconv.FormatUint(uint64(vxlanPort), 10)
   198  		c      = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
   199  		m      = strconv.FormatUint(uint64(mark), 10)
   200  		chain  = "OUTPUT"
   201  		rule   = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m}
   202  		a      = "-A"
   203  		action = "install"
   204  	)
   205  
   206  	if add == iptables.Exists(iptables.Mangle, chain, rule...) {
   207  		return
   208  	}
   209  
   210  	if !add {
   211  		a = "-D"
   212  		action = "remove"
   213  	}
   214  
   215  	if err = iptables.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil {
   216  		log.Warnf("could not %s mangle rule: %v", action, err)
   217  	}
   218  
   219  	return
   220  }
   221  
   222  func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) {
   223  	var (
   224  		action      = "Removing"
   225  		xfrmProgram = ns.NlHandle().XfrmStateDel
   226  	)
   227  
   228  	if add {
   229  		action = "Adding"
   230  		xfrmProgram = ns.NlHandle().XfrmStateAdd
   231  	}
   232  
   233  	if dir&reverse > 0 {
   234  		rSA = &netlink.XfrmState{
   235  			Src:   remoteIP,
   236  			Dst:   localIP,
   237  			Proto: netlink.XFRM_PROTO_ESP,
   238  			Spi:   spi.reverse,
   239  			Mode:  netlink.XFRM_MODE_TRANSPORT,
   240  		}
   241  		if add {
   242  			rSA.Aead = buildAeadAlgo(k, spi.reverse)
   243  		}
   244  
   245  		exists, err := saExists(rSA)
   246  		if err != nil {
   247  			exists = !add
   248  		}
   249  
   250  		if add != exists {
   251  			log.Debugf("%s: rSA{%s}", action, rSA)
   252  			if err := xfrmProgram(rSA); err != nil {
   253  				log.Warnf("Failed %s rSA{%s}: %v", action, rSA, err)
   254  			}
   255  		}
   256  	}
   257  
   258  	if dir&forward > 0 {
   259  		fSA = &netlink.XfrmState{
   260  			Src:   localIP,
   261  			Dst:   remoteIP,
   262  			Proto: netlink.XFRM_PROTO_ESP,
   263  			Spi:   spi.forward,
   264  			Mode:  netlink.XFRM_MODE_TRANSPORT,
   265  		}
   266  		if add {
   267  			fSA.Aead = buildAeadAlgo(k, spi.forward)
   268  		}
   269  
   270  		exists, err := saExists(fSA)
   271  		if err != nil {
   272  			exists = !add
   273  		}
   274  
   275  		if add != exists {
   276  			log.Debugf("%s fSA{%s}", action, fSA)
   277  			if err := xfrmProgram(fSA); err != nil {
   278  				log.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err)
   279  			}
   280  		}
   281  	}
   282  
   283  	return
   284  }
   285  
   286  func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error {
   287  	action := "Removing"
   288  	xfrmProgram := ns.NlHandle().XfrmPolicyDel
   289  	if add {
   290  		action = "Adding"
   291  		xfrmProgram = ns.NlHandle().XfrmPolicyAdd
   292  	}
   293  
   294  	fullMask := net.CIDRMask(8*len(fSA.Src), 8*len(fSA.Src))
   295  
   296  	fPol := &netlink.XfrmPolicy{
   297  		Src:     &net.IPNet{IP: fSA.Src, Mask: fullMask},
   298  		Dst:     &net.IPNet{IP: fSA.Dst, Mask: fullMask},
   299  		Dir:     netlink.XFRM_DIR_OUT,
   300  		Proto:   17,
   301  		DstPort: 4789,
   302  		Mark: &netlink.XfrmMark{
   303  			Value: mark,
   304  		},
   305  		Tmpls: []netlink.XfrmPolicyTmpl{
   306  			{
   307  				Src:   fSA.Src,
   308  				Dst:   fSA.Dst,
   309  				Proto: netlink.XFRM_PROTO_ESP,
   310  				Mode:  netlink.XFRM_MODE_TRANSPORT,
   311  				Spi:   fSA.Spi,
   312  			},
   313  		},
   314  	}
   315  
   316  	exists, err := spExists(fPol)
   317  	if err != nil {
   318  		exists = !add
   319  	}
   320  
   321  	if add != exists {
   322  		log.Debugf("%s fSP{%s}", action, fPol)
   323  		if err := xfrmProgram(fPol); err != nil {
   324  			log.Warnf("%s fSP{%s}: %v", action, fPol, err)
   325  		}
   326  	}
   327  
   328  	return nil
   329  }
   330  
   331  func saExists(sa *netlink.XfrmState) (bool, error) {
   332  	_, err := ns.NlHandle().XfrmStateGet(sa)
   333  	switch err {
   334  	case nil:
   335  		return true, nil
   336  	case syscall.ESRCH:
   337  		return false, nil
   338  	default:
   339  		err = fmt.Errorf("Error while checking for SA existence: %v", err)
   340  		log.Warn(err)
   341  		return false, err
   342  	}
   343  }
   344  
   345  func spExists(sp *netlink.XfrmPolicy) (bool, error) {
   346  	_, err := ns.NlHandle().XfrmPolicyGet(sp)
   347  	switch err {
   348  	case nil:
   349  		return true, nil
   350  	case syscall.ENOENT:
   351  		return false, nil
   352  	default:
   353  		err = fmt.Errorf("Error while checking for SP existence: %v", err)
   354  		log.Warn(err)
   355  		return false, err
   356  	}
   357  }
   358  
   359  func buildSPI(src, dst net.IP, st uint32) int {
   360  	b := make([]byte, 4)
   361  	binary.BigEndian.PutUint32(b, st)
   362  	h := fnv.New32a()
   363  	h.Write(src)
   364  	h.Write(b)
   365  	h.Write(dst)
   366  	return int(binary.BigEndian.Uint32(h.Sum(nil)))
   367  }
   368  
   369  func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo {
   370  	salt := make([]byte, 4)
   371  	binary.BigEndian.PutUint32(salt, uint32(s))
   372  	return &netlink.XfrmStateAlgo{
   373  		Name:   "rfc4106(gcm(aes))",
   374  		Key:    append(k.value, salt...),
   375  		ICVLen: 64,
   376  	}
   377  }
   378  
   379  func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error {
   380  	d.secMap.Lock()
   381  	for node, indices := range d.secMap.nodes {
   382  		idxs, stop := f(node, indices)
   383  		if idxs != nil {
   384  			d.secMap.nodes[node] = idxs
   385  		}
   386  		if stop {
   387  			break
   388  		}
   389  	}
   390  	d.secMap.Unlock()
   391  	return nil
   392  }
   393  
   394  func (d *driver) setKeys(keys []*key) error {
   395  	// Accept the encryption keys and clear any stale encryption map
   396  	d.Lock()
   397  	d.keys = keys
   398  	d.secMap = &encrMap{nodes: map[string][]*spi{}}
   399  	d.Unlock()
   400  	log.Debugf("Initial encryption keys: %v", d.keys)
   401  	return nil
   402  }
   403  
   404  // updateKeys allows to add a new key and/or change the primary key and/or prune an existing key
   405  // The primary key is the key used in transmission and will go in first position in the list.
   406  func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
   407  	log.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey)
   408  
   409  	log.Debugf("Current: %v", d.keys)
   410  
   411  	var (
   412  		newIdx = -1
   413  		priIdx = -1
   414  		delIdx = -1
   415  		lIP    = net.ParseIP(d.bindAddress)
   416  	)
   417  
   418  	d.Lock()
   419  	// add new
   420  	if newKey != nil {
   421  		d.keys = append(d.keys, newKey)
   422  		newIdx += len(d.keys)
   423  	}
   424  	for i, k := range d.keys {
   425  		if primary != nil && k.tag == primary.tag {
   426  			priIdx = i
   427  		}
   428  		if pruneKey != nil && k.tag == pruneKey.tag {
   429  			delIdx = i
   430  		}
   431  	}
   432  	d.Unlock()
   433  
   434  	if (newKey != nil && newIdx == -1) ||
   435  		(primary != nil && priIdx == -1) ||
   436  		(pruneKey != nil && delIdx == -1) {
   437  		return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+
   438  			"(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx)
   439  	}
   440  
   441  	d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) {
   442  		rIP := net.ParseIP(rIPs)
   443  		return updateNodeKey(lIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false
   444  	})
   445  
   446  	d.Lock()
   447  	// swap primary
   448  	if priIdx != -1 {
   449  		swp := d.keys[0]
   450  		d.keys[0] = d.keys[priIdx]
   451  		d.keys[priIdx] = swp
   452  	}
   453  	// prune
   454  	if delIdx != -1 {
   455  		if delIdx == 0 {
   456  			delIdx = priIdx
   457  		}
   458  		d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...)
   459  	}
   460  	d.Unlock()
   461  
   462  	log.Debugf("Updated: %v", d.keys)
   463  
   464  	return nil
   465  }
   466  
   467  /********************************************************
   468   * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1
   469   * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1
   470   * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2
   471   *********************************************************/
   472  
   473  // Spis and keys are sorted in such away the one in position 0 is the primary
   474  func updateNodeKey(lIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi {
   475  	log.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx)
   476  
   477  	spis := idxs
   478  	log.Debugf("Current: %v", spis)
   479  
   480  	// add new
   481  	if newIdx != -1 {
   482  		spis = append(spis, &spi{
   483  			forward: buildSPI(lIP, rIP, curKeys[newIdx].tag),
   484  			reverse: buildSPI(rIP, lIP, curKeys[newIdx].tag),
   485  		})
   486  	}
   487  
   488  	if delIdx != -1 {
   489  		// -rSA0
   490  		programSA(lIP, rIP, spis[delIdx], nil, reverse, false)
   491  	}
   492  
   493  	if newIdx > -1 {
   494  		// +RSA2
   495  		programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true)
   496  	}
   497  
   498  	if priIdx > 0 {
   499  		// +fSA2
   500  		fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true)
   501  
   502  		// +fSP2, -fSP1
   503  		fullMask := net.CIDRMask(8*len(fSA2.Src), 8*len(fSA2.Src))
   504  		fSP1 := &netlink.XfrmPolicy{
   505  			Src:     &net.IPNet{IP: fSA2.Src, Mask: fullMask},
   506  			Dst:     &net.IPNet{IP: fSA2.Dst, Mask: fullMask},
   507  			Dir:     netlink.XFRM_DIR_OUT,
   508  			Proto:   17,
   509  			DstPort: 4789,
   510  			Mark: &netlink.XfrmMark{
   511  				Value: mark,
   512  			},
   513  			Tmpls: []netlink.XfrmPolicyTmpl{
   514  				{
   515  					Src:   fSA2.Src,
   516  					Dst:   fSA2.Dst,
   517  					Proto: netlink.XFRM_PROTO_ESP,
   518  					Mode:  netlink.XFRM_MODE_TRANSPORT,
   519  					Spi:   fSA2.Spi,
   520  				},
   521  			},
   522  		}
   523  		log.Debugf("Updating fSP{%s}", fSP1)
   524  		if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil {
   525  			log.Warnf("Failed to update fSP{%s}: %v", fSP1, err)
   526  		}
   527  
   528  		// -fSA1
   529  		programSA(lIP, rIP, spis[0], nil, forward, false)
   530  	}
   531  
   532  	// swap
   533  	if priIdx > 0 {
   534  		swp := spis[0]
   535  		spis[0] = spis[priIdx]
   536  		spis[priIdx] = swp
   537  	}
   538  	// prune
   539  	if delIdx != -1 {
   540  		if delIdx == 0 {
   541  			delIdx = priIdx
   542  		}
   543  		spis = append(spis[:delIdx], spis[delIdx+1:]...)
   544  	}
   545  
   546  	log.Debugf("Updated: %v", spis)
   547  
   548  	return spis
   549  }
   550  
   551  func (n *network) maxMTU() int {
   552  	mtu := 1500
   553  	if n.mtu != 0 {
   554  		mtu = n.mtu
   555  	}
   556  	mtu -= vxlanEncap
   557  	if n.secure {
   558  		// In case of encryption account for the
   559  		// esp packet espansion and padding
   560  		mtu -= pktExpansion
   561  		mtu -= (mtu % 4)
   562  	}
   563  	return mtu
   564  }