github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/rule_test.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package netlink
     5  
     6  import (
     7  	"net/netip"
     8  	"testing"
     9  
    10  	"golang.org/x/sys/unix"
    11  )
    12  
    13  func TestRuleAddDel(t *testing.T) {
    14  	skipUnlessRoot(t)
    15  	defer setUpNetlinkTest(t)()
    16  
    17  	srcNet := netip.MustParsePrefix("172.16.0.1/16")
    18  	dstNet := netip.MustParsePrefix("172.16.1.1/24")
    19  
    20  	rulesBegin, err := RuleList(FAMILY_V4)
    21  	if err != nil {
    22  		t.Fatal(err)
    23  	}
    24  
    25  	rule := NewRule()
    26  	rule.Table = unix.RT_TABLE_MAIN
    27  	rule.Src = srcNet
    28  	rule.Dst = dstNet
    29  	rule.Priority = 5
    30  	rule.OifName = "lo"
    31  	rule.IifName = "lo"
    32  	rule.Invert = true
    33  	rule.Tos = 0x10
    34  	rule.Dport = NewRulePortRange(80, 80)
    35  	rule.Sport = NewRulePortRange(1000, 1024)
    36  	rule.IPProto = unix.IPPROTO_UDP
    37  	rule.UIDRange = NewRuleUIDRange(100, 100)
    38  	if err := RuleAdd(rule); err != nil {
    39  		t.Fatal(err)
    40  	}
    41  
    42  	rules, err := RuleList(FAMILY_V4)
    43  	if err != nil {
    44  		t.Fatal(err)
    45  	}
    46  
    47  	if len(rules) != len(rulesBegin)+1 {
    48  		t.Fatal("Rule not added properly")
    49  	}
    50  
    51  	// find this rule
    52  	found := ruleExists(rules, *rule)
    53  	if !found {
    54  		t.Fatal("Rule has diffrent options than one added")
    55  	}
    56  
    57  	if err := RuleDel(rule); err != nil {
    58  		t.Fatal(err)
    59  	}
    60  
    61  	rulesEnd, err := RuleList(FAMILY_V4)
    62  	if err != nil {
    63  		t.Fatal(err)
    64  	}
    65  
    66  	if len(rulesEnd) != len(rulesBegin) {
    67  		t.Fatal("Rule not removed properly")
    68  	}
    69  }
    70  
    71  func TestRuleListFiltered(t *testing.T) {
    72  	skipUnlessRoot(t)
    73  	defer setUpNetlinkTest(t)()
    74  
    75  	t.Run("IPv4", testRuleListFilteredIPv4)
    76  	t.Run("IPv6", testRuleListFilteredIPv6)
    77  }
    78  
    79  func testRuleListFilteredIPv4(t *testing.T) {
    80  	srcNet := netip.MustParsePrefix("172.16.0.1/16")
    81  	dstNet := netip.MustParsePrefix("172.16.1.1/16")
    82  	runRuleListFiltered(t, FAMILY_V4, srcNet, dstNet)
    83  }
    84  
    85  func testRuleListFilteredIPv6(t *testing.T) {
    86  	srcNet := netip.MustParsePrefix("fd56:6b58:db28:2913::/64")
    87  	dstNet := netip.MustParsePrefix("fde9:379f:3b35:6635::/96")
    88  	runRuleListFiltered(t, FAMILY_V6, srcNet, dstNet)
    89  }
    90  
    91  func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet netip.Prefix) {
    92  	defaultRules, _ := RuleList(family)
    93  
    94  	tests := []struct {
    95  		name       string
    96  		ruleFilter *Rule
    97  		filterMask uint64
    98  		preRun     func() *Rule // Creates sample rule harness
    99  		postRun    func(*Rule)  // Deletes sample rule harness
   100  		setupWant  func(*Rule) ([]Rule, bool)
   101  	}{
   102  		{
   103  			name:       "returns all rules",
   104  			ruleFilter: nil,
   105  			filterMask: 0,
   106  			preRun:     func() *Rule { return nil },
   107  			postRun:    func(r *Rule) {},
   108  			setupWant: func(_ *Rule) ([]Rule, bool) {
   109  				return defaultRules, false
   110  			},
   111  		},
   112  		{
   113  			name:       "returns one rule filtered by Src",
   114  			ruleFilter: &Rule{Src: srcNet},
   115  			filterMask: RT_FILTER_SRC,
   116  			preRun: func() *Rule {
   117  				r := NewRule()
   118  				r.Src = srcNet
   119  				r.Priority = 1 // Must add priority and table otherwise it's auto-assigned
   120  				r.Table = 1
   121  				err := RuleAdd(r)
   122  				if err != nil {
   123  					t.Fatal(err)
   124  				}
   125  				return r
   126  			},
   127  			postRun: func(r *Rule) { RuleDel(r) },
   128  			setupWant: func(r *Rule) ([]Rule, bool) {
   129  				return []Rule{*r}, false
   130  			},
   131  		},
   132  		{
   133  			name:       "returns one rule filtered by Dst",
   134  			ruleFilter: &Rule{Dst: dstNet},
   135  			filterMask: RT_FILTER_DST,
   136  			preRun: func() *Rule {
   137  				r := NewRule()
   138  				r.Dst = dstNet
   139  				r.Priority = 1 // Must add priority and table otherwise it's auto-assigned
   140  				r.Table = 1
   141  				err := RuleAdd(r)
   142  				if err != nil {
   143  					t.Fatal(err)
   144  				}
   145  				return r
   146  			},
   147  			postRun: func(r *Rule) { RuleDel(r) },
   148  			setupWant: func(r *Rule) ([]Rule, bool) {
   149  				return []Rule{*r}, false
   150  			},
   151  		},
   152  		{
   153  			name:       "returns two rules filtered by Dst",
   154  			ruleFilter: &Rule{Dst: dstNet},
   155  			filterMask: RT_FILTER_DST,
   156  			preRun: func() *Rule {
   157  				r := NewRule()
   158  				r.Dst = dstNet
   159  				r.Priority = 1 // Must add priority and table otherwise it's auto-assigned
   160  				r.Table = 1
   161  				RuleAdd(r)
   162  
   163  				rc := *r // Create almost identical copy
   164  				rc.Src = srcNet
   165  				RuleAdd(&rc)
   166  
   167  				return r
   168  			},
   169  			postRun: func(r *Rule) {
   170  				RuleDel(r)
   171  
   172  				rc := *r // Delete the almost identical copy
   173  				rc.Src = srcNet
   174  				RuleDel(&rc)
   175  			},
   176  			setupWant: func(r *Rule) ([]Rule, bool) {
   177  				rs := []Rule{}
   178  				rs = append(rs, *r)
   179  
   180  				rc := *r // Append the almost identical copy
   181  				rc.Src = srcNet
   182  				rs = append(rs, rc)
   183  
   184  				return rs, false
   185  			},
   186  		},
   187  		{
   188  			name:       "returns one rule filtered by Src when two rules exist",
   189  			ruleFilter: &Rule{Src: srcNet},
   190  			filterMask: RT_FILTER_SRC,
   191  			preRun: func() *Rule {
   192  				r := NewRule()
   193  				r.Dst = dstNet
   194  				r.Priority = 1 // Must add priority and table otherwise it's auto-assigned
   195  				r.Table = 1
   196  				RuleAdd(r)
   197  
   198  				rc := *r // Create almost identical copy
   199  				rc.Src = srcNet
   200  				RuleAdd(&rc)
   201  
   202  				return r
   203  			},
   204  			postRun: func(r *Rule) {
   205  				RuleDel(r)
   206  
   207  				rc := *r // Delete the almost identical copy
   208  				rc.Src = srcNet
   209  				RuleDel(&rc)
   210  			},
   211  			setupWant: func(r *Rule) ([]Rule, bool) {
   212  				rs := []Rule{}
   213  				// Do not append `r`
   214  
   215  				rc := *r // Append the almost identical copy
   216  				rc.Src = srcNet
   217  				rs = append(rs, rc)
   218  
   219  				return rs, false
   220  			},
   221  		},
   222  		{
   223  			name:       "returns rules with specific priority",
   224  			ruleFilter: &Rule{Priority: 5},
   225  			filterMask: RT_FILTER_PRIORITY,
   226  			preRun: func() *Rule {
   227  				r := NewRule()
   228  				r.Src = srcNet
   229  				r.Priority = 5
   230  				r.Table = 1
   231  				RuleAdd(r)
   232  
   233  				for i := 2; i < 5; i++ {
   234  					rc := *r // Create almost identical copy
   235  					rc.Table = i
   236  					RuleAdd(&rc)
   237  				}
   238  
   239  				return r
   240  			},
   241  			postRun: func(r *Rule) {
   242  				RuleDel(r)
   243  
   244  				for i := 2; i < 5; i++ {
   245  					rc := *r // Delete the almost identical copy
   246  					rc.Table = -1
   247  					RuleDel(&rc)
   248  				}
   249  			},
   250  			setupWant: func(r *Rule) ([]Rule, bool) {
   251  				rs := []Rule{}
   252  				rs = append(rs, *r)
   253  
   254  				for i := 2; i < 5; i++ {
   255  					rc := *r // Append the almost identical copy
   256  					rc.Table = i
   257  					rs = append(rs, rc)
   258  				}
   259  
   260  				return rs, false
   261  			},
   262  		},
   263  		{
   264  			name:       "returns rules filtered by Table",
   265  			ruleFilter: &Rule{Table: 199},
   266  			filterMask: RT_FILTER_TABLE,
   267  			preRun: func() *Rule {
   268  				r := NewRule()
   269  				r.Src = srcNet
   270  				r.Priority = 1 // Must add priority otherwise it's auto-assigned
   271  				r.Table = 199
   272  				RuleAdd(r)
   273  				return r
   274  			},
   275  			postRun: func(r *Rule) { RuleDel(r) },
   276  			setupWant: func(r *Rule) ([]Rule, bool) {
   277  				return []Rule{*r}, false
   278  			},
   279  		},
   280  		{
   281  			name:       "returns rules filtered by Mask",
   282  			ruleFilter: &Rule{Mask: 0x5},
   283  			filterMask: RT_FILTER_MASK,
   284  			preRun: func() *Rule {
   285  				r := NewRule()
   286  				r.Src = srcNet
   287  				r.Priority = 1 // Must add priority and table otherwise it's auto-assigned
   288  				r.Table = 1
   289  				r.Mask = 0x5
   290  				RuleAdd(r)
   291  				return r
   292  			},
   293  			postRun: func(r *Rule) { RuleDel(r) },
   294  			setupWant: func(r *Rule) ([]Rule, bool) {
   295  				return []Rule{*r}, false
   296  			},
   297  		},
   298  		{
   299  			name:       "returns rules filtered by Mark",
   300  			ruleFilter: &Rule{Mark: 0xbb, MarkSet: true},
   301  			filterMask: RT_FILTER_MARK,
   302  			preRun: func() *Rule {
   303  				r := NewRule()
   304  				r.Src = srcNet
   305  				r.Priority = 1 // Must add priority, table, mask otherwise it's auto-assigned
   306  				r.Table = 1
   307  				r.Mask = 0xff
   308  				r.Mark = 0xbb
   309  				RuleAdd(r)
   310  				return r
   311  			},
   312  			postRun: func(r *Rule) { RuleDel(r) },
   313  			setupWant: func(r *Rule) ([]Rule, bool) {
   314  				return []Rule{*r}, false
   315  			},
   316  		},
   317  		{
   318  			name:       "returns rules filtered by Tos",
   319  			ruleFilter: &Rule{Tos: 12},
   320  			filterMask: RT_FILTER_TOS,
   321  			preRun: func() *Rule {
   322  				r := NewRule()
   323  				r.Src = srcNet
   324  				r.Priority = 1 // Must add priority, table, mask otherwise it's auto-assigned
   325  				r.Table = 12
   326  				r.Tos = 12 // Tos must equal table
   327  				RuleAdd(r)
   328  				return r
   329  			},
   330  			postRun: func(r *Rule) { RuleDel(r) },
   331  			setupWant: func(r *Rule) ([]Rule, bool) {
   332  				return []Rule{*r}, false
   333  			},
   334  		},
   335  	}
   336  	for _, tt := range tests {
   337  		t.Run(tt.name, func(t *testing.T) {
   338  			rule := tt.preRun()
   339  			rules, err := RuleListFiltered(family, tt.ruleFilter, tt.filterMask)
   340  			tt.postRun(rule)
   341  
   342  			wantRules, wantErr := tt.setupWant(rule)
   343  
   344  			if len(wantRules) != len(rules) {
   345  				t.Errorf("Expected len: %d, got: %d", len(wantRules), len(rules))
   346  			} else {
   347  				for i := range wantRules {
   348  					if !ruleEquals(wantRules[i], rules[i]) {
   349  						t.Errorf("Rules mismatch, want %v, got %v", wantRules[i], rules[i])
   350  					}
   351  				}
   352  			}
   353  
   354  			if (err != nil) != wantErr {
   355  				t.Errorf("Error expectation not met, want %v, got %v", (err != nil), wantErr)
   356  			}
   357  		})
   358  	}
   359  }
   360  
   361  func TestRuleString(t *testing.T) {
   362  	t.Parallel()
   363  
   364  	testCases := map[string]struct {
   365  		r Rule
   366  		s string
   367  	}{
   368  		"empty rule": {
   369  			s: "ip rule 0: from all to all table 0",
   370  		},
   371  		"rule with src and dst equivalent to <nil>": {
   372  			r: Rule{
   373  				Priority: 100,
   374  				Table:    99,
   375  			},
   376  			s: "ip rule 100: from all to all table 99",
   377  		},
   378  		"rule with src and dst": {
   379  			r: Rule{
   380  				Priority: 100,
   381  				Src:      netip.MustParsePrefix("10.0.0.0/24"),
   382  				Dst:      netip.MustParsePrefix("20.0.0.0/24"),
   383  				Table:    99,
   384  			},
   385  			s: "ip rule 100: from 10.0.0.0/24 to 20.0.0.0/24 table 99",
   386  		},
   387  	}
   388  
   389  	for name, testCase := range testCases {
   390  		testCase := testCase
   391  		t.Run(name, func(t *testing.T) {
   392  			t.Parallel()
   393  
   394  			s := testCase.r.String()
   395  
   396  			if s != testCase.s {
   397  				t.Errorf("expected %q but got %q", testCase.s, s)
   398  			}
   399  		})
   400  	}
   401  }
   402  
   403  func ruleExists(rules []Rule, rule Rule) bool {
   404  	for i := range rules {
   405  		if ruleEquals(rules[i], rule) {
   406  			return true
   407  		}
   408  	}
   409  
   410  	return false
   411  }
   412  
   413  func ruleEquals(a, b Rule) bool {
   414  	return a.Table == b.Table &&
   415  		((!a.Src.IsValid() && !b.Src.IsValid()) ||
   416  			(a.Src.IsValid() && b.Src.IsValid() && a.Src.String() == b.Src.String())) &&
   417  		((!a.Dst.IsValid() && !b.Dst.IsValid()) ||
   418  			(a.Dst.IsValid() && b.Dst.IsValid() && a.Dst.String() == b.Dst.String())) &&
   419  		a.OifName == b.OifName &&
   420  		a.Priority == b.Priority &&
   421  		a.IifName == b.IifName &&
   422  		a.Invert == b.Invert &&
   423  		a.Tos == b.Tos &&
   424  		a.IPProto == b.IPProto
   425  }