k8s.io/kubernetes@v1.29.3/pkg/proxy/nftables/helpers_test.go (about)

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package nftables
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"regexp"
    23  	"runtime"
    24  	"sort"
    25  	"strings"
    26  	"testing"
    27  
    28  	"github.com/danwinship/knftables"
    29  	"github.com/google/go-cmp/cmp"
    30  	"github.com/lithammer/dedent"
    31  
    32  	"k8s.io/api/core/v1"
    33  	"k8s.io/apimachinery/pkg/util/sets"
    34  	netutils "k8s.io/utils/net"
    35  )
    36  
    37  // getLine returns a string containing the file and line number of the caller, if
    38  // possible. This is useful in tests with a large number of cases - when something goes
    39  // wrong you can find which case more easily.
    40  func getLine() string {
    41  	_, file, line, ok := runtime.Caller(1)
    42  	if !ok {
    43  		return ""
    44  	}
    45  	return fmt.Sprintf(" (from %s:%d)", file, line)
    46  }
    47  
    48  // objectOrder defines the order we sort different types into (higher = earlier); while
    49  // not necessary just for comparison purposes, it's more intuitive in the Diff output to
    50  // see rules/sets/maps before chains/elements.
    51  var objectOrder = map[string]int{
    52  	"table":   10,
    53  	"chain":   9,
    54  	"rule":    8,
    55  	"set":     7,
    56  	"map":     6,
    57  	"element": 5,
    58  	// anything else: 0
    59  }
    60  
    61  // sortNFTablesTransaction sorts an nftables transaction into a standard order for comparison
    62  func sortNFTablesTransaction(tx string) string {
    63  	lines := strings.Split(tx, "\n")
    64  
    65  	// strip blank lines and comments
    66  	for i := 0; i < len(lines); {
    67  		if lines[i] == "" || lines[i][0] == '#' {
    68  			lines = append(lines[:i], lines[i+1:]...)
    69  		} else {
    70  			i++
    71  		}
    72  	}
    73  
    74  	// sort remaining lines
    75  	sort.SliceStable(lines, func(i, j int) bool {
    76  		li := lines[i]
    77  		wi := strings.Split(li, " ")
    78  		lj := lines[j]
    79  		wj := strings.Split(lj, " ")
    80  
    81  		// All lines will start with "add OBJECTTYPE ip kube-proxy". Everything
    82  		// except "add table" will have an object name after the table name, and
    83  		// "add table" will have a comment after the table name. So every line
    84  		// should have at least 5 words.
    85  		if len(wi) < 5 || len(wj) < 5 {
    86  			return false
    87  		}
    88  
    89  		// Sort by object type first.
    90  		if wi[1] != wj[1] {
    91  			return objectOrder[wi[1]] >= objectOrder[wj[1]]
    92  		}
    93  
    94  		// Sort by object name when object type is identical.
    95  		if wi[4] != wj[4] {
    96  			return wi[4] < wj[4]
    97  		}
    98  
    99  		// Leave rules in the order they were originally added.
   100  		if wi[1] == "rule" {
   101  			return false
   102  		}
   103  
   104  		// Sort by the whole line when object type and name is identical. (e.g.,
   105  		// individual "add rule" and "add element" lines in a chain/set/map.)
   106  		return li < lj
   107  	})
   108  	return strings.Join(lines, "\n")
   109  }
   110  
   111  // diffNFTablesTransaction is a (testable) helper function for assertNFTablesTransactionEqual
   112  func diffNFTablesTransaction(expected, result string) string {
   113  	expected = sortNFTablesTransaction(expected)
   114  	result = sortNFTablesTransaction(result)
   115  
   116  	return cmp.Diff(expected, result)
   117  }
   118  
   119  // assertNFTablesTransactionEqual asserts that expected and result are equal, ignoring
   120  // irrelevant differences.
   121  func assertNFTablesTransactionEqual(t *testing.T, line string, expected, result string) {
   122  	diff := diffNFTablesTransaction(expected, result)
   123  	if diff != "" {
   124  		t.Errorf("tables do not match%s:\ndiff:\n%s\nfull result: %+v", line, diff, result)
   125  	}
   126  }
   127  
   128  // diffNFTablesChain is a (testable) helper function for assertNFTablesChainEqual
   129  func diffNFTablesChain(nft *knftables.Fake, chain, expected string) string {
   130  	expected = strings.TrimSpace(expected)
   131  	result := ""
   132  	if ch := nft.Table.Chains[chain]; ch != nil {
   133  		for i, rule := range ch.Rules {
   134  			if i > 0 {
   135  				result += "\n"
   136  			}
   137  			result += rule.Rule
   138  		}
   139  	}
   140  
   141  	return cmp.Diff(expected, result)
   142  }
   143  
   144  // assertNFTablesChainEqual asserts that the indicated chain in nft's table contains
   145  // exactly the rules in expected (in that order).
   146  func assertNFTablesChainEqual(t *testing.T, line string, nft *knftables.Fake, chain, expected string) {
   147  	if diff := diffNFTablesChain(nft, chain, expected); diff != "" {
   148  		t.Errorf("rules do not match%s:\ndiff:\n%s", line, diff)
   149  	}
   150  }
   151  
// nftablesTracer holds data used while virtually tracing a packet through a set of
// nftables rules
type nftablesTracer struct {
	// nft is the fake nftables state whose table's chains, sets, and maps are traced.
	nft     *knftables.Fake
	// nodeIPs are the IPs treated as local when evaluating "fib saddr type local"
	// and "fib daddr type local".
	nodeIPs sets.Set[string]
	// t is used to report problems found while tracing.
	t       *testing.T

	// matches accumulates the list of rules that were matched, for debugging purposes.
	matches []string

	// outputs accumulates the list of matched terminal rule targets (endpoint
	// IP:ports, or a special target like "REJECT") and is eventually used to generate
	// the return value of tracePacket.
	outputs []string

	// markMasq tracks whether the packet has been marked for masquerading
	markMasq bool
}
   170  
   171  // newNFTablesTracer creates an nftablesTracer. nodeIPs are the IP to treat as local node
   172  // IPs (for determining whether rules with "fib saddr type local" or "fib daddr type
   173  // local" match).
   174  func newNFTablesTracer(t *testing.T, nft *knftables.Fake, nodeIPs []string) *nftablesTracer {
   175  	return &nftablesTracer{
   176  		nft:     nft,
   177  		nodeIPs: sets.New(nodeIPs...),
   178  		t:       t,
   179  	}
   180  }
   181  
   182  func (tracer *nftablesTracer) addressMatches(ipStr, not, ruleAddress string) bool {
   183  	ip := netutils.ParseIPSloppy(ipStr)
   184  	if ip == nil {
   185  		tracer.t.Fatalf("Bad IP in test case: %s", ipStr)
   186  	}
   187  
   188  	var match bool
   189  	if strings.Contains(ruleAddress, "/") {
   190  		_, cidr, err := netutils.ParseCIDRSloppy(ruleAddress)
   191  		if err != nil {
   192  			tracer.t.Errorf("Bad CIDR in kube-proxy output: %v", err)
   193  		}
   194  		match = cidr.Contains(ip)
   195  	} else {
   196  		ip2 := netutils.ParseIPSloppy(ruleAddress)
   197  		if ip2 == nil {
   198  			tracer.t.Errorf("Bad IP/CIDR in kube-proxy output: %s", ruleAddress)
   199  		}
   200  		match = ip.Equal(ip2)
   201  	}
   202  
   203  	if not == "!= " {
   204  		return !match
   205  	} else {
   206  		return match
   207  	}
   208  }
   209  
   210  // matchDestIPOnly checks an "ip daddr" against a set/map, and returns the matching
   211  // Element, if found.
   212  func (tracer *nftablesTracer) matchDestIPOnly(elements []*knftables.Element, destIP string) *knftables.Element {
   213  	for _, element := range elements {
   214  		if element.Key[0] == destIP {
   215  			return element
   216  		}
   217  	}
   218  	return nil
   219  }
   220  
   221  // matchDest checks an "ip daddr . meta l4proto . th dport" against a set/map, and returns
   222  // the matching Element, if found.
   223  func (tracer *nftablesTracer) matchDest(elements []*knftables.Element, destIP, protocol, destPort string) *knftables.Element {
   224  	for _, element := range elements {
   225  		if element.Key[0] == destIP && element.Key[1] == protocol && element.Key[2] == destPort {
   226  			return element
   227  		}
   228  	}
   229  	return nil
   230  }
   231  
   232  // matchDestAndSource checks an "ip daddr . meta l4proto . th dport . ip saddr" against a
   233  // set/map, where the source is allowed to be a CIDR, and returns the matching Element, if
   234  // found.
   235  func (tracer *nftablesTracer) matchDestAndSource(elements []*knftables.Element, destIP, protocol, destPort, sourceIP string) *knftables.Element {
   236  	for _, element := range elements {
   237  		if element.Key[0] == destIP && element.Key[1] == protocol && element.Key[2] == destPort && tracer.addressMatches(sourceIP, "", element.Key[3]) {
   238  			return element
   239  		}
   240  	}
   241  	return nil
   242  }
   243  
   244  // matchDestPort checks an "meta l4proto . th dport" against a set/map, and returns the
   245  // matching Element, if found.
   246  func (tracer *nftablesTracer) matchDestPort(elements []*knftables.Element, protocol, destPort string) *knftables.Element {
   247  	for _, element := range elements {
   248  		if element.Key[0] == protocol && element.Key[1] == destPort {
   249  			return element
   250  		}
   251  	}
   252  	return nil
   253  }
   254  
   255  // We intentionally don't try to parse arbitrary nftables rules, as the syntax is quite
   256  // complicated and context sensitive. (E.g., "ip daddr" could be the start of an address
   257  // comparison, or it could be the start of a set/map lookup.) Instead, we just have
   258  // regexps to recognize the specific pieces of rules that we create in proxier.go.
   259  // Anything matching ignoredRegexp gets stripped out of the rule, and then what's left
   260  // *must* match one of the cases in runChain or an error will be logged. In cases where
   261  // the regexp doesn't end with `$`, and the matched rule succeeds against the input data,
   262  // runChain will continue trying to match the rest of the rule. E.g., "ip daddr 10.0.0.1
   263  // drop" would first match destAddrRegexp, and then (assuming destIP was "10.0.0.1") would
   264  // match verdictRegexp.
   265  
   266  var destAddrRegexp = regexp.MustCompile(`^ip6* daddr (!= )?(\S+)`)
   267  var destAddrLocalRegexp = regexp.MustCompile(`^fib daddr type local`)
   268  var destPortRegexp = regexp.MustCompile(`^(tcp|udp|sctp) dport (\d+)`)
   269  var destIPOnlyLookupRegexp = regexp.MustCompile(`^ip6* daddr @(\S+)`)
   270  var destLookupRegexp = regexp.MustCompile(`^ip6* daddr \. meta l4proto \. th dport @(\S+)`)
   271  var destSourceLookupRegexp = regexp.MustCompile(`^ip6* daddr \. meta l4proto \. th dport \. ip6* saddr @(\S+)`)
   272  var destPortLookupRegexp = regexp.MustCompile(`^meta l4proto \. th dport @(\S+)`)
   273  
   274  var destDispatchRegexp = regexp.MustCompile(`^ip6* daddr \. meta l4proto \. th dport vmap @(\S+)$`)
   275  var destPortDispatchRegexp = regexp.MustCompile(`^meta l4proto \. th dport vmap @(\S+)$`)
   276  
   277  var sourceAddrRegexp = regexp.MustCompile(`^ip6* saddr (!= )?(\S+)`)
   278  var sourceAddrLocalRegexp = regexp.MustCompile(`^fib saddr type local`)
   279  
   280  var endpointVMAPRegexp = regexp.MustCompile(`^numgen random mod \d+ vmap \{(.*)\}$`)
   281  var endpointVMapEntryRegexp = regexp.MustCompile(`\d+ : goto (\S+)`)
   282  
   283  var masqueradeRegexp = regexp.MustCompile(`^jump ` + kubeMarkMasqChain + `$`)
   284  var jumpRegexp = regexp.MustCompile(`^(jump|goto) (\S+)$`)
   285  var returnRegexp = regexp.MustCompile(`^return$`)
   286  var verdictRegexp = regexp.MustCompile(`^(drop|reject)$`)
   287  var dnatRegexp = regexp.MustCompile(`^meta l4proto (tcp|udp|sctp) dnat to (\S+)$`)
   288  
   289  var ignoredRegexp = regexp.MustCompile(strings.Join(
   290  	[]string{
   291  		// Ignore comments (which can only appear at the end of a rule).
   292  		` *comment "[^"]*"$`,
   293  
   294  		// The trace tests only check new connections, so for our purposes, this
   295  		// check always succeeds (and thus can be ignored).
   296  		`^ct state new`,
   297  
   298  		// Likewise, this rule never matches and thus never drops anything, and so
   299  		// can be ignored.
   300  		`^ct state invalid drop$`,
   301  	},
   302  	"|",
   303  ))
   304  
   305  // runChain runs the given packet through the rules in the given table and chain, updating
   306  // tracer's internal state accordingly. It returns true if it hits a terminal action.
   307  func (tracer *nftablesTracer) runChain(chname, sourceIP, protocol, destIP, destPort string) bool {
   308  	ch := tracer.nft.Table.Chains[chname]
   309  	if ch == nil {
   310  		tracer.t.Errorf("unknown chain %q", chname)
   311  		return true
   312  	}
   313  
   314  	for _, ruleObj := range ch.Rules {
   315  		rule := ignoredRegexp.ReplaceAllLiteralString(ruleObj.Rule, "")
   316  		for rule != "" {
   317  			rule = strings.TrimLeft(rule, " ")
   318  
   319  			// Note that the order of (some of) the cases is important. e.g.,
   320  			// masqueradeRegexp must be checked before jumpRegexp, since
   321  			// jumpRegexp would also match masqueradeRegexp but do the wrong
   322  			// thing with it.
   323  
   324  			switch {
   325  			case destIPOnlyLookupRegexp.MatchString(rule):
   326  				// `^ip6* daddr @(\S+)`
   327  				// Tests whether destIP is a member of the indicated set.
   328  				match := destIPOnlyLookupRegexp.FindStringSubmatch(rule)
   329  				rule = strings.TrimPrefix(rule, match[0])
   330  				set := match[1]
   331  				if tracer.matchDestIPOnly(tracer.nft.Table.Sets[set].Elements, destIP) == nil {
   332  					rule = ""
   333  					break
   334  				}
   335  
   336  			case destSourceLookupRegexp.MatchString(rule):
   337  				// `^ip6* daddr . meta l4proto . th dport . ip6* saddr @(\S+)`
   338  				// Tests whether "destIP . protocol . destPort . sourceIP" is
   339  				// a member of the indicated set.
   340  				match := destSourceLookupRegexp.FindStringSubmatch(rule)
   341  				rule = strings.TrimPrefix(rule, match[0])
   342  				set := match[1]
   343  				if tracer.matchDestAndSource(tracer.nft.Table.Sets[set].Elements, destIP, protocol, destPort, sourceIP) == nil {
   344  					rule = ""
   345  					break
   346  				}
   347  
   348  			case destLookupRegexp.MatchString(rule):
   349  				// `^ip6* daddr . meta l4proto . th dport @(\S+)`
   350  				// Tests whether "destIP . protocol . destPort" is a member
   351  				// of the indicated set.
   352  				match := destLookupRegexp.FindStringSubmatch(rule)
   353  				rule = strings.TrimPrefix(rule, match[0])
   354  				set := match[1]
   355  				if tracer.matchDest(tracer.nft.Table.Sets[set].Elements, destIP, protocol, destPort) == nil {
   356  					rule = ""
   357  					break
   358  				}
   359  
   360  			case destPortLookupRegexp.MatchString(rule):
   361  				// `^meta l4proto . th dport @(\S+)`
   362  				// Tests whether "protocol . destPort" is a member of the
   363  				// indicated set.
   364  				match := destPortLookupRegexp.FindStringSubmatch(rule)
   365  				rule = strings.TrimPrefix(rule, match[0])
   366  				set := match[1]
   367  				if tracer.matchDestPort(tracer.nft.Table.Sets[set].Elements, protocol, destPort) == nil {
   368  					rule = ""
   369  					break
   370  				}
   371  
   372  			case destDispatchRegexp.MatchString(rule):
   373  				// `^ip6* daddr \. meta l4proto \. th dport vmap @(\S+)$`
   374  				// Looks up "destIP . protocol . destPort" in the indicated
   375  				// verdict map, and if found, runs the assocated verdict.
   376  				match := destDispatchRegexp.FindStringSubmatch(rule)
   377  				mapName := match[1]
   378  				element := tracer.matchDest(tracer.nft.Table.Maps[mapName].Elements, destIP, protocol, destPort)
   379  				if element == nil {
   380  					rule = ""
   381  					break
   382  				} else {
   383  					rule = element.Value[0]
   384  				}
   385  
   386  			case destPortDispatchRegexp.MatchString(rule):
   387  				// `^meta l4proto \. th dport vmap @(\S+)$`
   388  				// Looks up "protocol . destPort" in the indicated verdict map,
   389  				// and if found, runs the assocated verdict.
   390  				match := destPortDispatchRegexp.FindStringSubmatch(rule)
   391  				mapName := match[1]
   392  				element := tracer.matchDestPort(tracer.nft.Table.Maps[mapName].Elements, protocol, destPort)
   393  				if element == nil {
   394  					rule = ""
   395  					break
   396  				} else {
   397  					rule = element.Value[0]
   398  				}
   399  
   400  			case destAddrRegexp.MatchString(rule):
   401  				// `^ip6* daddr (!= )?(\S+)`
   402  				// Tests whether destIP does/doesn't match a literal.
   403  				match := destAddrRegexp.FindStringSubmatch(rule)
   404  				rule = strings.TrimPrefix(rule, match[0])
   405  				not, ip := match[1], match[2]
   406  				if !tracer.addressMatches(destIP, not, ip) {
   407  					rule = ""
   408  					break
   409  				}
   410  
   411  			case destAddrLocalRegexp.MatchString(rule):
   412  				// `^fib daddr type local`
   413  				// Tests whether destIP is a local IP.
   414  				match := destAddrLocalRegexp.FindStringSubmatch(rule)
   415  				rule = strings.TrimPrefix(rule, match[0])
   416  				if !tracer.nodeIPs.Has(destIP) {
   417  					rule = ""
   418  					break
   419  				}
   420  
   421  			case destPortRegexp.MatchString(rule):
   422  				// `^(tcp|udp|sctp) dport (\d+)`
   423  				// Tests whether destPort matches a literal.
   424  				match := destPortRegexp.FindStringSubmatch(rule)
   425  				rule = strings.TrimPrefix(rule, match[0])
   426  				proto, port := match[1], match[2]
   427  				if protocol != proto || destPort != port {
   428  					rule = ""
   429  					break
   430  				}
   431  
   432  			case sourceAddrRegexp.MatchString(rule):
   433  				// `^ip6* saddr (!= )?(\S+)`
   434  				// Tests whether sourceIP does/doesn't match a literal.
   435  				match := sourceAddrRegexp.FindStringSubmatch(rule)
   436  				rule = strings.TrimPrefix(rule, match[0])
   437  				not, ip := match[1], match[2]
   438  				if !tracer.addressMatches(sourceIP, not, ip) {
   439  					rule = ""
   440  					break
   441  				}
   442  
   443  			case sourceAddrLocalRegexp.MatchString(rule):
   444  				// `^fib saddr type local`
   445  				// Tests whether sourceIP is a local IP.
   446  				match := sourceAddrLocalRegexp.FindStringSubmatch(rule)
   447  				rule = strings.TrimPrefix(rule, match[0])
   448  				if !tracer.nodeIPs.Has(sourceIP) {
   449  					rule = ""
   450  					break
   451  				}
   452  
   453  			case masqueradeRegexp.MatchString(rule):
   454  				// `^jump mark-for-masquerade$`
   455  				// Mark for masquerade: we just treat the jump rule itself as
   456  				// being what creates the mark, rather than trying to handle
   457  				// the rules inside that chain and the "masquerading" chain.
   458  				match := jumpRegexp.FindStringSubmatch(rule)
   459  				rule = strings.TrimPrefix(rule, match[0])
   460  
   461  				tracer.matches = append(tracer.matches, ruleObj.Rule)
   462  				tracer.markMasq = true
   463  
   464  			case jumpRegexp.MatchString(rule):
   465  				// `^(jump|goto) (\S+)$`
   466  				// Jumps to another chain.
   467  				match := jumpRegexp.FindStringSubmatch(rule)
   468  				rule = strings.TrimPrefix(rule, match[0])
   469  				action, destChain := match[1], match[2]
   470  
   471  				tracer.matches = append(tracer.matches, ruleObj.Rule)
   472  				terminated := tracer.runChain(destChain, sourceIP, protocol, destIP, destPort)
   473  				if terminated {
   474  					// destChain reached a terminal statement, so we
   475  					// terminate too.
   476  					return true
   477  				} else if action == "goto" {
   478  					// After a goto, return to our calling chain
   479  					// (without terminating) rather than continuing
   480  					// with this chain.
   481  					return false
   482  				}
   483  
   484  			case verdictRegexp.MatchString(rule):
   485  				// `^(drop|reject)$`
   486  				// Drop/reject the packet and terminate processing.
   487  				match := verdictRegexp.FindStringSubmatch(rule)
   488  				verdict := match[1]
   489  
   490  				tracer.matches = append(tracer.matches, ruleObj.Rule)
   491  				tracer.outputs = append(tracer.outputs, strings.ToUpper(verdict))
   492  				return true
   493  
   494  			case returnRegexp.MatchString(rule):
   495  				// `^return$`
   496  				// Returns to the calling chain.
   497  				tracer.matches = append(tracer.matches, ruleObj.Rule)
   498  				return false
   499  
   500  			case dnatRegexp.MatchString(rule):
   501  				// `meta l4proto (tcp|udp|sctp) dnat to (\S+)`
   502  				// DNAT to an endpoint IP and terminate processing.
   503  				match := dnatRegexp.FindStringSubmatch(rule)
   504  				destEndpoint := match[2]
   505  
   506  				tracer.matches = append(tracer.matches, ruleObj.Rule)
   507  				tracer.outputs = append(tracer.outputs, destEndpoint)
   508  				return true
   509  
   510  			case endpointVMAPRegexp.MatchString(rule):
   511  				// `^numgen random mod \d+ vmap \{(.*)\}$`
   512  				// Selects a random endpoint and jumps to it. For tracePacket's
   513  				// purposes, we jump to *all* of the endpoints.
   514  				match := endpointVMAPRegexp.FindStringSubmatch(rule)
   515  				elements := match[1]
   516  
   517  				for _, match = range endpointVMapEntryRegexp.FindAllStringSubmatch(elements, -1) {
   518  					// `\d+ : goto (\S+)`
   519  					destChain := match[1]
   520  
   521  					tracer.matches = append(tracer.matches, ruleObj.Rule)
   522  					// Ignore return value; we know each endpoint has a
   523  					// terminating dnat verdict, but we want to gather all
   524  					// of the endpoints into tracer.output.
   525  					_ = tracer.runChain(destChain, sourceIP, protocol, destIP, destPort)
   526  				}
   527  				return true
   528  
   529  			default:
   530  				tracer.t.Errorf("unmatched rule: %s", ruleObj.Rule)
   531  				rule = ""
   532  			}
   533  		}
   534  	}
   535  
   536  	return false
   537  }
   538  
   539  // tracePacket determines what would happen to a packet with the given sourceIP, destIP,
   540  // and destPort, given the indicated iptables ruleData. nodeIPs are the local node IPs (for
   541  // rules matching "local"). (The protocol value should be lowercase as in nftables
   542  // rules, not uppercase as in corev1.)
   543  //
   544  // The return values are: an array of matched rules (for debugging), the final packet
   545  // destinations (a comma-separated list of IPs, or one of the special targets "ACCEPT",
   546  // "DROP", or "REJECT"), and whether the packet would be masqueraded.
   547  func tracePacket(t *testing.T, nft *knftables.Fake, sourceIP, protocol, destIP, destPort string, nodeIPs []string) ([]string, string, bool) {
   548  	tracer := newNFTablesTracer(t, nft, nodeIPs)
   549  
   550  	// Collect "base chains" (ie, the chains that are run by netfilter directly rather
   551  	// than only being run when they are jumped to). Skip postrouting because it only
   552  	// does masquerading and we handle that separately.
   553  	var baseChains []string
   554  	for chname, ch := range nft.Table.Chains {
   555  		if ch.Priority != nil && chname != "nat-postrouting" {
   556  			baseChains = append(baseChains, chname)
   557  		}
   558  	}
   559  
   560  	// Sort by priority
   561  	sort.Slice(baseChains, func(i, j int) bool {
   562  		// FIXME: IPv4 vs IPv6 doesn't actually matter here
   563  		iprio, _ := knftables.ParsePriority(knftables.IPv4Family, string(*nft.Table.Chains[baseChains[i]].Priority))
   564  		jprio, _ := knftables.ParsePriority(knftables.IPv4Family, string(*nft.Table.Chains[baseChains[j]].Priority))
   565  		return iprio < jprio
   566  	})
   567  
   568  	for _, chname := range baseChains {
   569  		terminated := tracer.runChain(chname, sourceIP, protocol, destIP, destPort)
   570  		if terminated {
   571  			break
   572  		}
   573  	}
   574  
   575  	return tracer.matches, strings.Join(tracer.outputs, ", "), tracer.markMasq
   576  }
   577  
// packetFlowTest describes one simulated packet, and the result that tracing it is
// expected to produce, for use with runPacketFlowTests.
type packetFlowTest struct {
	// name is the name of the subtest.
	name     string
	// sourceIP is the source IP of the simulated packet.
	sourceIP string
	// protocol is the protocol of the simulated packet; runPacketFlowTests
	// lowercases it, and uses "tcp" if it is empty.
	protocol v1.Protocol
	// destIP is the destination IP of the simulated packet.
	destIP   string
	// destPort is the destination port of the simulated packet.
	destPort int
	// output is the expected output string from tracePacket.
	output   string
	// masq is whether the packet is expected to be marked for masquerading.
	masq     bool
}
   587  
   588  func runPacketFlowTests(t *testing.T, line string, nft *knftables.Fake, nodeIPs []string, testCases []packetFlowTest) {
   589  	for _, tc := range testCases {
   590  		t.Run(tc.name, func(t *testing.T) {
   591  			protocol := strings.ToLower(string(tc.protocol))
   592  			if protocol == "" {
   593  				protocol = "tcp"
   594  			}
   595  			matches, output, masq := tracePacket(t, nft, tc.sourceIP, protocol, tc.destIP, fmt.Sprintf("%d", tc.destPort), nodeIPs)
   596  			var errors []string
   597  			if output != tc.output {
   598  				errors = append(errors, fmt.Sprintf("wrong output: expected %q got %q", tc.output, output))
   599  			}
   600  			if masq != tc.masq {
   601  				errors = append(errors, fmt.Sprintf("wrong masq: expected %v got %v", tc.masq, masq))
   602  			}
   603  			if errors != nil {
   604  				t.Errorf("Test %q of a packet from %s to %s:%d%s got result:\n%s\n\nBy matching:\n%s\n\n",
   605  					tc.name, tc.sourceIP, tc.destIP, tc.destPort, line, strings.Join(errors, "\n"), strings.Join(matches, "\n"))
   606  			}
   607  		})
   608  	}
   609  }
   610  
   611  // helpers_test unit tests
   612  
// testInput is a sample nftables transaction, in unsorted order and containing blank
// lines and "#" comments. It is written as an indented raw string and passed through
// dedent.Dedent. (Presumably it is the input for the unit tests of the helpers above, with
// testExpected being the corresponding sorted form; confirm against the tests that use
// it.)
var testInput = dedent.Dedent(`
	add table ip testing { comment "rules for kube-proxy" ; }

	add chain ip testing forward
	add rule ip testing forward ct state invalid drop
	add chain ip testing mark-for-masquerade
	add rule ip testing mark-for-masquerade mark set mark or 0x4000
	add chain ip testing masquerading
	add rule ip testing masquerading mark and 0x4000 == 0 return
	add rule ip testing masquerading mark set mark xor 0x4000
	add rule ip testing masquerading masquerade fully-random

	add set ip testing firewall { type ipv4_addr . inet_proto . inet_service ; comment "destinations that are subject to LoadBalancerSourceRanges" ; }
	add set ip testing firewall-allow { type ipv4_addr . inet_proto . inet_service . ipv4_addr ; flags interval ; comment "destinations+sources that are allowed by LoadBalancerSourceRanges" ; }
	add chain ip testing firewall-check
	add chain ip testing firewall-allow-check
	add rule ip testing firewall-allow-check ip daddr . meta l4proto . th dport . ip saddr @firewall-allow return
	add rule ip testing firewall-allow-check drop
	add rule ip testing firewall-check ip daddr . meta l4proto . th dport @firewall jump firewall-allow-check

	# svc1
	add chain ip testing service-ULMVA6XW-ns1/svc1/tcp/p80
	add rule ip testing service-ULMVA6XW-ns1/svc1/tcp/p80 ip daddr 172.30.0.41 tcp dport 80 ip saddr != 10.0.0.0/8 jump mark-for-masquerade
	add rule ip testing service-ULMVA6XW-ns1/svc1/tcp/p80 numgen random mod 1 vmap { 0 : goto endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80 }

	add chain ip testing endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80
	add rule ip testing endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80 ip saddr 10.180.0.1 jump mark-for-masquerade
	add rule ip testing endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80 meta l4proto tcp dnat to 10.180.0.1:80

	add element ip testing service-ips { 172.30.0.41 . tcp . 80 : goto service-ULMVA6XW-ns1/svc1/tcp/p80 }

	# svc2
	add chain ip testing service-42NFTM6N-ns2/svc2/tcp/p80
	add rule ip testing service-42NFTM6N-ns2/svc2/tcp/p80 ip daddr 172.30.0.42 tcp dport 80 ip saddr != 10.0.0.0/8 jump mark-for-masquerade
	add rule ip testing service-42NFTM6N-ns2/svc2/tcp/p80 numgen random mod 1 vmap { 0 : goto endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 }
	add chain ip testing external-42NFTM6N-ns2/svc2/tcp/p80
	add rule ip testing external-42NFTM6N-ns2/svc2/tcp/p80 ip saddr 10.0.0.0/8 goto service-42NFTM6N-ns2/svc2/tcp/p80 comment "short-circuit pod traffic"
	add rule ip testing external-42NFTM6N-ns2/svc2/tcp/p80 fib saddr type local jump mark-for-masquerade comment "masquerade local traffic"
	add rule ip testing external-42NFTM6N-ns2/svc2/tcp/p80 fib saddr type local goto service-42NFTM6N-ns2/svc2/tcp/p80 comment "short-circuit local traffic"
	add chain ip testing endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80
	add rule ip testing endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 ip saddr 10.180.0.2 jump mark-for-masquerade
	add rule ip testing endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 meta l4proto tcp dnat to 10.180.0.2:80

	add element ip testing service-ips { 172.30.0.42 . tcp . 80 : goto service-42NFTM6N-ns2/svc2/tcp/p80 }
	add element ip testing service-ips { 192.168.99.22 . tcp . 80 : goto external-42NFTM6N-ns2/svc2/tcp/p80 }
	add element ip testing service-ips { 1.2.3.4 . tcp . 80 : goto external-42NFTM6N-ns2/svc2/tcp/p80 }
	add element ip testing service-nodeports { tcp . 3001 : goto external-42NFTM6N-ns2/svc2/tcp/p80 }

	add element ip testing no-endpoint-nodeports { tcp . 3001 comment "ns2/svc2:p80" : drop }
	add element ip testing no-endpoint-services { 1.2.3.4 . tcp . 80 comment "ns2/svc2:p80" : drop }
	add element ip testing no-endpoint-services { 192.168.99.22 . tcp . 80 comment "ns2/svc2:p80" : drop }
	`)
   665  
// testExpected is the sorted counterpart of testInput: Test_sortNFTablesTransaction
// requires sortNFTablesTransaction(testInput) to equal this text once the surrounding
// whitespace is trimmed, and Test_diffNFTablesTransaction requires
// diffNFTablesTransaction to report no difference between the two.
var testExpected = dedent.Dedent(`
	add table ip testing { comment "rules for kube-proxy" ; }
	add chain ip testing endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80
	add chain ip testing endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80
	add chain ip testing external-42NFTM6N-ns2/svc2/tcp/p80
	add chain ip testing firewall-allow-check
	add chain ip testing firewall-check
	add chain ip testing forward
	add chain ip testing mark-for-masquerade
	add chain ip testing masquerading
	add chain ip testing service-42NFTM6N-ns2/svc2/tcp/p80
	add chain ip testing service-ULMVA6XW-ns1/svc1/tcp/p80
	add rule ip testing endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80 ip saddr 10.180.0.1 jump mark-for-masquerade
	add rule ip testing endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80 meta l4proto tcp dnat to 10.180.0.1:80
	add rule ip testing endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 ip saddr 10.180.0.2 jump mark-for-masquerade
	add rule ip testing endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 meta l4proto tcp dnat to 10.180.0.2:80
	add rule ip testing external-42NFTM6N-ns2/svc2/tcp/p80 ip saddr 10.0.0.0/8 goto service-42NFTM6N-ns2/svc2/tcp/p80 comment "short-circuit pod traffic"
	add rule ip testing external-42NFTM6N-ns2/svc2/tcp/p80 fib saddr type local jump mark-for-masquerade comment "masquerade local traffic"
	add rule ip testing external-42NFTM6N-ns2/svc2/tcp/p80 fib saddr type local goto service-42NFTM6N-ns2/svc2/tcp/p80 comment "short-circuit local traffic"
	add rule ip testing firewall-allow-check ip daddr . meta l4proto . th dport . ip saddr @firewall-allow return
	add rule ip testing firewall-allow-check drop
	add rule ip testing firewall-check ip daddr . meta l4proto . th dport @firewall jump firewall-allow-check
	add rule ip testing forward ct state invalid drop
	add rule ip testing mark-for-masquerade mark set mark or 0x4000
	add rule ip testing masquerading mark and 0x4000 == 0 return
	add rule ip testing masquerading mark set mark xor 0x4000
	add rule ip testing masquerading masquerade fully-random
	add rule ip testing service-42NFTM6N-ns2/svc2/tcp/p80 ip daddr 172.30.0.42 tcp dport 80 ip saddr != 10.0.0.0/8 jump mark-for-masquerade
	add rule ip testing service-42NFTM6N-ns2/svc2/tcp/p80 numgen random mod 1 vmap { 0 : goto endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 }
	add rule ip testing service-ULMVA6XW-ns1/svc1/tcp/p80 ip daddr 172.30.0.41 tcp dport 80 ip saddr != 10.0.0.0/8 jump mark-for-masquerade
	add rule ip testing service-ULMVA6XW-ns1/svc1/tcp/p80 numgen random mod 1 vmap { 0 : goto endpoint-5OJB2KTY-ns1/svc1/tcp/p80__10.180.0.1/80 }
	add set ip testing firewall { type ipv4_addr . inet_proto . inet_service ; comment "destinations that are subject to LoadBalancerSourceRanges" ; }
	add set ip testing firewall-allow { type ipv4_addr . inet_proto . inet_service . ipv4_addr ; flags interval ; comment "destinations+sources that are allowed by LoadBalancerSourceRanges" ; }
	add element ip testing no-endpoint-nodeports { tcp . 3001 comment "ns2/svc2:p80" : drop }
	add element ip testing no-endpoint-services { 1.2.3.4 . tcp . 80 comment "ns2/svc2:p80" : drop }
	add element ip testing no-endpoint-services { 192.168.99.22 . tcp . 80 comment "ns2/svc2:p80" : drop }
	add element ip testing service-ips { 1.2.3.4 . tcp . 80 : goto external-42NFTM6N-ns2/svc2/tcp/p80 }
	add element ip testing service-ips { 172.30.0.41 . tcp . 80 : goto service-ULMVA6XW-ns1/svc1/tcp/p80 }
	add element ip testing service-ips { 172.30.0.42 . tcp . 80 : goto service-42NFTM6N-ns2/svc2/tcp/p80 }
	add element ip testing service-ips { 192.168.99.22 . tcp . 80 : goto external-42NFTM6N-ns2/svc2/tcp/p80 }
	add element ip testing service-nodeports { tcp . 3001 : goto external-42NFTM6N-ns2/svc2/tcp/p80 }
	`)
   708  
   709  func Test_sortNFTablesTransaction(t *testing.T) {
   710  	output := sortNFTablesTransaction(testInput)
   711  	expected := strings.TrimSpace(testExpected)
   712  
   713  	diff := cmp.Diff(expected, output)
   714  	if diff != "" {
   715  		t.Errorf("output does not match expected:\n%s", diff)
   716  	}
   717  }
   718  
   719  func Test_diffNFTablesTransaction(t *testing.T) {
   720  	diff := diffNFTablesTransaction(testInput, testExpected)
   721  	if diff != "" {
   722  		t.Errorf("found diff in inputs that should have been equal:\n%s", diff)
   723  	}
   724  
   725  	notExpected := strings.Join(strings.Split(testExpected, "\n")[2:], "\n")
   726  	diff = diffNFTablesTransaction(testInput, notExpected)
   727  	if diff == "" {
   728  		t.Errorf("found no diff in inputs that should have been different")
   729  	}
   730  }
   731  
   732  func Test_diffNFTablesChain(t *testing.T) {
   733  	fake := knftables.NewFake(knftables.IPv4Family, "testing")
   734  	tx := fake.NewTransaction()
   735  
   736  	tx.Add(&knftables.Table{})
   737  	tx.Add(&knftables.Chain{
   738  		Name: "mark-masq-chain",
   739  	})
   740  	tx.Add(&knftables.Chain{
   741  		Name: "masquerade-chain",
   742  	})
   743  	tx.Add(&knftables.Chain{
   744  		Name: "empty-chain",
   745  	})
   746  
   747  	tx.Add(&knftables.Rule{
   748  		Chain: "mark-masq-chain",
   749  		Rule:  "mark set mark or 0x4000",
   750  	})
   751  
   752  	tx.Add(&knftables.Rule{
   753  		Chain: "masquerade-chain",
   754  		Rule:  "mark and 0x4000 == 0 return",
   755  	})
   756  	tx.Add(&knftables.Rule{
   757  		Chain: "masquerade-chain",
   758  		Rule:  "mark set mark xor 0x4000",
   759  	})
   760  	tx.Add(&knftables.Rule{
   761  		Chain: "masquerade-chain",
   762  		Rule:  "masquerade fully-random",
   763  	})
   764  
   765  	err := fake.Run(context.Background(), tx)
   766  	if err != nil {
   767  		t.Fatalf("Unexpected error running transaction: %v", err)
   768  	}
   769  
   770  	diff := diffNFTablesChain(fake, "mark-masq-chain", "mark set mark or 0x4000")
   771  	if diff != "" {
   772  		t.Errorf("unexpected difference in mark-masq-chain:\n%s", diff)
   773  	}
   774  	diff = diffNFTablesChain(fake, "mark-masq-chain", "mark set mark or 0x4000\n")
   775  	if diff != "" {
   776  		t.Errorf("unexpected difference in mark-masq-chain with trailing newline:\n%s", diff)
   777  	}
   778  
   779  	diff = diffNFTablesChain(fake, "masquerade-chain", "mark and 0x4000 == 0 return\nmark set mark xor 0x4000\nmasquerade fully-random")
   780  	if diff != "" {
   781  		t.Errorf("unexpected difference in masquerade-chain:\n%s", diff)
   782  	}
   783  	diff = diffNFTablesChain(fake, "masquerade-chain", "mark set mark xor 0x4000\nmasquerade fully-random")
   784  	if diff == "" {
   785  		t.Errorf("unexpected lack of difference in wrong masquerade-chain")
   786  	}
   787  
   788  	diff = diffNFTablesChain(fake, "empty-chain", "")
   789  	if diff != "" {
   790  		t.Errorf("unexpected difference in empty-chain:\n%s", diff)
   791  	}
   792  	diff = diffNFTablesChain(fake, "empty-chain", "\n")
   793  	if diff != "" {
   794  		t.Errorf("unexpected difference in empty-chain with trailing newline:\n%s", diff)
   795  	}
   796  }