github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/xfrm_policy_test.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package netlink
     5  
     6  import (
     7  	"bytes"
     8  	"net"
     9  	"testing"
    10  )
    11  
    12  const zeroCIDR = "0.0.0.0/0"
    13  
    14  func TestXfrmPolicyAddUpdateDel(t *testing.T) {
    15  	tearDown := setUpNetlinkTest(t)
    16  	defer tearDown()
    17  
    18  	policy := getPolicy()
    19  	if err := XfrmPolicyAdd(policy); err != nil {
    20  		t.Fatal(err)
    21  	}
    22  	policies, err := XfrmPolicyList(FAMILY_ALL)
    23  	if err != nil {
    24  		t.Fatal(err)
    25  	}
    26  
    27  	if len(policies) != 1 {
    28  		t.Fatal("Policy not added properly")
    29  	}
    30  
    31  	if !comparePolicies(policy, &policies[0]) {
    32  		t.Fatalf("unexpected policy returned.\nExpected: %v.\nGot %v", policy, policies[0])
    33  	}
    34  
    35  	if policies[0].Ifindex != 0 {
    36  		t.Fatalf("default policy has a non-zero interface index.\nGot %d", policies[0].Ifindex)
    37  	}
    38  
    39  	if policies[0].Ifid != 0 {
    40  		t.Fatalf("default policy has non-zero if_id.\nGot %d", policies[0].Ifid)
    41  	}
    42  
    43  	if policies[0].Action != XFRM_POLICY_ALLOW {
    44  		t.Fatalf("default policy has non-allow action.\nGot %s", policies[0].Action)
    45  	}
    46  
    47  	// Look for a specific policy
    48  	sp, err := XfrmPolicyGet(policy)
    49  	if err != nil {
    50  		t.Fatal(err)
    51  	}
    52  
    53  	if !comparePolicies(policy, sp) {
    54  		t.Fatalf("unexpected policy returned")
    55  	}
    56  
    57  	// Modify the policy
    58  	policy.Priority = 100
    59  	if err := XfrmPolicyUpdate(policy); err != nil {
    60  		t.Fatal(err)
    61  	}
    62  	sp, err = XfrmPolicyGet(policy)
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  	if sp.Priority != 100 {
    67  		t.Fatalf("failed to modify the policy")
    68  	}
    69  
    70  	if err = XfrmPolicyDel(policy); err != nil {
    71  		t.Fatal(err)
    72  	}
    73  
    74  	policies, err = XfrmPolicyList(FAMILY_ALL)
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  	if len(policies) != 0 {
    79  		t.Fatal("Policy not removed properly")
    80  	}
    81  
    82  	// Src and dst are not mandatory field. Creation should succeed
    83  	policy.Src = nil
    84  	policy.Dst = nil
    85  	if err = XfrmPolicyAdd(policy); err != nil {
    86  		t.Fatal(err)
    87  	}
    88  
    89  	sp, err = XfrmPolicyGet(policy)
    90  	if err != nil {
    91  		t.Fatal(err)
    92  	}
    93  
    94  	if !comparePolicies(policy, sp) {
    95  		t.Fatalf("unexpected policy returned")
    96  	}
    97  
    98  	if err = XfrmPolicyDel(policy); err != nil {
    99  		t.Fatal(err)
   100  	}
   101  
   102  	if _, err := XfrmPolicyGet(policy); err == nil {
   103  		t.Fatalf("Unexpected success")
   104  	}
   105  }
   106  
   107  func TestXfrmPolicyFlush(t *testing.T) {
   108  	defer setUpNetlinkTest(t)()
   109  
   110  	p1 := getPolicy()
   111  	if err := XfrmPolicyAdd(p1); err != nil {
   112  		t.Fatal(err)
   113  	}
   114  
   115  	p1.Dir = XFRM_DIR_IN
   116  	s := p1.Src
   117  	p1.Src = p1.Dst
   118  	p1.Dst = s
   119  	if err := XfrmPolicyAdd(p1); err != nil {
   120  		t.Fatal(err)
   121  	}
   122  
   123  	policies, err := XfrmPolicyList(FAMILY_ALL)
   124  	if err != nil {
   125  		t.Fatal(err)
   126  	}
   127  	if len(policies) != 2 {
   128  		t.Fatalf("unexpected number of policies: %d", len(policies))
   129  	}
   130  
   131  	if err := XfrmPolicyFlush(); err != nil {
   132  		t.Fatal(err)
   133  	}
   134  
   135  	policies, err = XfrmPolicyList(FAMILY_ALL)
   136  	if err != nil {
   137  		t.Fatal(err)
   138  	}
   139  	if len(policies) != 0 {
   140  		t.Fatalf("unexpected number of policies: %d", len(policies))
   141  	}
   142  
   143  }
   144  
   145  func TestXfrmPolicyBlockWithIfindex(t *testing.T) {
   146  	defer setUpNetlinkTest(t)()
   147  
   148  	pBlock := getPolicy()
   149  	pBlock.Action = XFRM_POLICY_BLOCK
   150  	pBlock.Ifindex = 1 // loopback interface
   151  	if err := XfrmPolicyAdd(pBlock); err != nil {
   152  		t.Fatal(err)
   153  	}
   154  	policies, err := XfrmPolicyList(FAMILY_ALL)
   155  	if err != nil {
   156  		t.Fatal(err)
   157  	}
   158  	if len(policies) != 1 {
   159  		t.Fatalf("unexpected number of policies: %d", len(policies))
   160  	}
   161  	if !comparePolicies(pBlock, &policies[0]) {
   162  		t.Fatalf("unexpected policy returned.\nExpected: %v.\nGot %v", pBlock, policies[0])
   163  	}
   164  	if err = XfrmPolicyDel(pBlock); err != nil {
   165  		t.Fatal(err)
   166  	}
   167  }
   168  
   169  func TestXfrmPolicyWithIfid(t *testing.T) {
   170  	minKernelRequired(t, 4, 19)
   171  	defer setUpNetlinkTest(t)()
   172  
   173  	pol := getPolicy()
   174  	pol.Ifid = 54321
   175  
   176  	if err := XfrmPolicyAdd(pol); err != nil {
   177  		t.Fatal(err)
   178  	}
   179  	policies, err := XfrmPolicyList(FAMILY_ALL)
   180  	if err != nil {
   181  		t.Fatal(err)
   182  	}
   183  	if len(policies) != 1 {
   184  		t.Fatalf("unexpected number of policies: %d", len(policies))
   185  	}
   186  	if !comparePolicies(pol, &policies[0]) {
   187  		t.Fatalf("unexpected policy returned.\nExpected: %v.\nGot %v", pol, policies[0])
   188  	}
   189  	if err = XfrmPolicyDel(&policies[0]); err != nil {
   190  		t.Fatal(err)
   191  	}
   192  }
   193  
   194  func TestXfrmPolicyWithOptional(t *testing.T) {
   195  	minKernelRequired(t, 4, 19)
   196  	defer setUpNetlinkTest(t)()
   197  
   198  	pol := getPolicy()
   199  	pol.Tmpls[0].Optional = 1
   200  
   201  	if err := XfrmPolicyAdd(pol); err != nil {
   202  		t.Fatal(err)
   203  	}
   204  	policies, err := XfrmPolicyList(FAMILY_ALL)
   205  	if err != nil {
   206  		t.Fatal(err)
   207  	}
   208  	if len(policies) != 1 {
   209  		t.Fatalf("unexpected number of policies: %d", len(policies))
   210  	}
   211  	if !comparePolicies(pol, &policies[0]) {
   212  		t.Fatalf("unexpected policy returned.\nExpected: %v.\nGot %v", pol, policies[0])
   213  	}
   214  	if err = XfrmPolicyDel(&policies[0]); err != nil {
   215  		t.Fatal(err)
   216  	}
   217  }
   218  
   219  func comparePolicies(a, b *XfrmPolicy) bool {
   220  	if a == b {
   221  		return true
   222  	}
   223  	if a == nil || b == nil {
   224  		return false
   225  	}
   226  	// Do not check Index which is assigned by kernel
   227  	return a.Dir == b.Dir && a.Priority == b.Priority &&
   228  		compareIPNet(a.Src, b.Src) && compareIPNet(a.Dst, b.Dst) &&
   229  		a.Action == b.Action && a.Ifindex == b.Ifindex &&
   230  		a.Mark.Value == b.Mark.Value && a.Mark.Mask == b.Mark.Mask &&
   231  		a.Ifid == b.Ifid && compareTemplates(a.Tmpls, b.Tmpls)
   232  }
   233  
   234  func compareTemplates(a, b []XfrmPolicyTmpl) bool {
   235  	if len(a) != len(b) {
   236  		return false
   237  	}
   238  	for i, ta := range a {
   239  		tb := b[i]
   240  		if !ta.Dst.Equal(tb.Dst) || !ta.Src.Equal(tb.Src) || ta.Spi != tb.Spi ||
   241  			ta.Mode != tb.Mode || ta.Reqid != tb.Reqid || ta.Proto != tb.Proto ||
   242  			ta.Optional != tb.Optional {
   243  			return false
   244  		}
   245  	}
   246  	return true
   247  }
   248  
   249  func compareIPNet(a, b *net.IPNet) bool {
   250  	if a == b {
   251  		return true
   252  	}
   253  	// For unspecified src/dst parseXfrmPolicy would set the zero address cidr
   254  	if (a == nil && b.String() == zeroCIDR) || (b == nil && a.String() == zeroCIDR) {
   255  		return true
   256  	}
   257  	if a == nil || b == nil {
   258  		return false
   259  	}
   260  	return a.IP.Equal(b.IP) && bytes.Equal(a.Mask, b.Mask)
   261  }
   262  
   263  func getPolicy() *XfrmPolicy {
   264  	src, _ := ParseIPNet("127.1.1.1/32")
   265  	dst, _ := ParseIPNet("127.1.1.2/32")
   266  	policy := &XfrmPolicy{
   267  		Src:     src,
   268  		Dst:     dst,
   269  		Proto:   17,
   270  		DstPort: 1234,
   271  		SrcPort: 5678,
   272  		Dir:     XFRM_DIR_OUT,
   273  		Mark: &XfrmMark{
   274  			Value: 0xabff22,
   275  			Mask:  0xffffffff,
   276  		},
   277  		Priority: 10,
   278  	}
   279  	tmpl := XfrmPolicyTmpl{
   280  		Src:   net.ParseIP("127.0.0.1"),
   281  		Dst:   net.ParseIP("127.0.0.2"),
   282  		Proto: XFRM_PROTO_ESP,
   283  		Mode:  XFRM_MODE_TUNNEL,
   284  		Spi:   0x1bcdef99,
   285  	}
   286  	policy.Tmpls = append(policy.Tmpls, tmpl)
   287  	return policy
   288  }