github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/ipset_linux_test.go (about)

     1  package netlink
     2  
     3  import (
     4  	"bytes"
     5  	"io/ioutil"
     6  	"net"
     7  	"testing"
     8  
     9  	"github.com/sagernet/netlink/nl"
    10  	"golang.org/x/sys/unix"
    11  )
    12  
    13  func TestParseIpsetProtocolResult(t *testing.T) {
    14  	msgBytes, err := ioutil.ReadFile("testdata/ipset_protocol_result")
    15  	if err != nil {
    16  		t.Fatalf("reading test fixture failed: %v", err)
    17  	}
    18  
    19  	msg := ipsetUnserialize([][]byte{msgBytes})
    20  	if msg.Protocol != 6 {
    21  		t.Errorf("expected msg.Protocol to equal 6, got %d", msg.Protocol)
    22  	}
    23  }
    24  
    25  func TestParseIpsetListResult(t *testing.T) {
    26  	msgBytes, err := ioutil.ReadFile("testdata/ipset_list_result")
    27  	if err != nil {
    28  		t.Fatalf("reading test fixture failed: %v", err)
    29  	}
    30  
    31  	msg := ipsetUnserialize([][]byte{msgBytes})
    32  	if msg.SetName != "clients" {
    33  		t.Errorf(`expected SetName to equal "clients", got %q`, msg.SetName)
    34  	}
    35  	if msg.TypeName != "hash:mac" {
    36  		t.Errorf(`expected TypeName to equal "hash:mac", got %q`, msg.TypeName)
    37  	}
    38  	if msg.Protocol != 6 {
    39  		t.Errorf("expected Protocol to equal 6, got %d", msg.Protocol)
    40  	}
    41  	if msg.References != 0 {
    42  		t.Errorf("expected References to equal 0, got %d", msg.References)
    43  	}
    44  	if msg.NumEntries != 2 {
    45  		t.Errorf("expected NumEntries to equal 2, got %d", msg.NumEntries)
    46  	}
    47  	if msg.HashSize != 1024 {
    48  		t.Errorf("expected HashSize to equal 1024, got %d", msg.HashSize)
    49  	}
    50  	if *msg.Timeout != 3600 {
    51  		t.Errorf("expected Timeout to equal 3600, got %d", *msg.Timeout)
    52  	}
    53  	if msg.MaxElements != 65536 {
    54  		t.Errorf("expected MaxElements to equal 65536, got %d", msg.MaxElements)
    55  	}
    56  	if msg.CadtFlags != nl.IPSET_FLAG_WITH_COMMENT|nl.IPSET_FLAG_WITH_COUNTERS {
    57  		t.Error("expected CadtFlags to be IPSET_FLAG_WITH_COMMENT and IPSET_FLAG_WITH_COUNTERS")
    58  	}
    59  	if len(msg.Entries) != 2 {
    60  		t.Fatalf("expected 2 Entries, got %d", len(msg.Entries))
    61  	}
    62  
    63  	// first entry
    64  	ent := msg.Entries[0]
    65  	if int(*ent.Timeout) != 3577 {
    66  		t.Errorf("expected Timeout for first entry to equal 3577, got %d", *ent.Timeout)
    67  	}
    68  	if int(*ent.Bytes) != 4121 {
    69  		t.Errorf("expected Bytes for first entry to equal 4121, got %d", *ent.Bytes)
    70  	}
    71  	if int(*ent.Packets) != 42 {
    72  		t.Errorf("expected Packets for first entry to equal 42, got %d", *ent.Packets)
    73  	}
    74  	if ent.Comment != "foo bar" {
    75  		t.Errorf("unexpected Comment for first entry: %q", ent.Comment)
    76  	}
    77  	expectedMAC := net.HardwareAddr{0xde, 0xad, 0x0, 0x0, 0xbe, 0xef}
    78  	if !bytes.Equal(ent.MAC, expectedMAC) {
    79  		t.Errorf("expected MAC for first entry to be %s, got %s", expectedMAC.String(), ent.MAC.String())
    80  	}
    81  
    82  	// second entry
    83  	ent = msg.Entries[1]
    84  	expectedMAC = net.HardwareAddr{0x1, 0x2, 0x3, 0x0, 0x1, 0x2}
    85  	if !bytes.Equal(ent.MAC, expectedMAC) {
    86  		t.Errorf("expected MAC for second entry to be %s, got %s", expectedMAC.String(), ent.MAC.String())
    87  	}
    88  }
    89  
    90  func TestIpsetCreateListAddDelDestroy(t *testing.T) {
    91  	tearDown := setUpNetlinkTest(t)
    92  	defer tearDown()
    93  	timeout := uint32(3)
    94  	err := IpsetCreate("my-test-ipset-1", "hash:ip", IpsetCreateOptions{
    95  		Replace:  true,
    96  		Timeout:  &timeout,
    97  		Counters: true,
    98  		Comments: false,
    99  		Skbinfo:  false,
   100  	})
   101  	if err != nil {
   102  		t.Fatal(err)
   103  	}
   104  
   105  	err = IpsetCreate("my-test-ipset-2", "hash:net", IpsetCreateOptions{
   106  		Replace:  true,
   107  		Timeout:  &timeout,
   108  		Counters: false,
   109  		Comments: true,
   110  		Skbinfo:  true,
   111  	})
   112  	if err != nil {
   113  		t.Fatal(err)
   114  	}
   115  
   116  	results, err := IpsetListAll()
   117  
   118  	if err != nil {
   119  		t.Fatal(err)
   120  	}
   121  
   122  	if len(results) != 2 {
   123  		t.Fatalf("expected 2 IPSets to be created, got %d", len(results))
   124  	}
   125  
   126  	if results[0].SetName != "my-test-ipset-1" {
   127  		t.Errorf("expected name to be 'my-test-ipset-1', but got '%s'", results[0].SetName)
   128  	}
   129  
   130  	if results[1].SetName != "my-test-ipset-2" {
   131  		t.Errorf("expected name to be 'my-test-ipset-2', but got '%s'", results[1].SetName)
   132  	}
   133  
   134  	if results[0].TypeName != "hash:ip" {
   135  		t.Errorf("expected type to be 'hash:ip', but got '%s'", results[0].TypeName)
   136  	}
   137  
   138  	if results[1].TypeName != "hash:net" {
   139  		t.Errorf("expected type to be 'hash:net', but got '%s'", results[1].TypeName)
   140  	}
   141  
   142  	if *results[0].Timeout != 3 {
   143  		t.Errorf("expected timeout to be 3, but got '%d'", *results[0].Timeout)
   144  	}
   145  
   146  	err = IpsetAdd("my-test-ipset-1", &IPSetEntry{
   147  		Comment: "test comment",
   148  		IP:      net.ParseIP("10.99.99.99").To4(),
   149  		Replace: false,
   150  	})
   151  
   152  	if err != nil {
   153  		t.Fatal(err)
   154  	}
   155  
   156  	result, err := IpsetList("my-test-ipset-1")
   157  
   158  	if err != nil {
   159  		t.Fatal(err)
   160  	}
   161  
   162  	if len(result.Entries) != 1 {
   163  		t.Fatalf("expected 1 entry be created, got '%d'", len(result.Entries))
   164  	}
   165  	if result.Entries[0].IP.String() != "10.99.99.99" {
   166  		t.Fatalf("expected entry to be '10.99.99.99', got '%s'", result.Entries[0].IP.String())
   167  	}
   168  
   169  	if result.Entries[0].Comment != "test comment" {
   170  		// This is only supported in the kernel module from revision 2 or 4, so comments may be ignored.
   171  		t.Logf("expected comment to be 'test comment', got '%s'", result.Entries[0].Comment)
   172  	}
   173  
   174  	err = IpsetDel("my-test-ipset-1", &IPSetEntry{
   175  		Comment: "test comment",
   176  		IP:      net.ParseIP("10.99.99.99").To4(),
   177  	})
   178  	if err != nil {
   179  		t.Fatal(err)
   180  	}
   181  
   182  	result, err = IpsetList("my-test-ipset-1")
   183  	if err != nil {
   184  		t.Fatal(err)
   185  	}
   186  
   187  	if len(result.Entries) != 0 {
   188  		t.Fatalf("expected 0 entries to exist, got %d", len(result.Entries))
   189  	}
   190  
   191  	err = IpsetDestroy("my-test-ipset-1")
   192  	if err != nil {
   193  		t.Fatal(err)
   194  	}
   195  
   196  	err = IpsetDestroy("my-test-ipset-2")
   197  	if err != nil {
   198  		t.Fatal(err)
   199  	}
   200  }
   201  
   202  func TestIpsetCreateListAddDelDestroyWithTestCases(t *testing.T) {
   203  	timeout := uint32(3)
   204  	protocalTCP := uint8(unix.IPPROTO_TCP)
   205  	port := uint16(80)
   206  
   207  	testCases := []struct {
   208  		desc     string
   209  		setname  string
   210  		typename string
   211  		options  IpsetCreateOptions
   212  		entry    *IPSetEntry
   213  	}{
   214  		{
   215  			desc:     "Type-hash:ip",
   216  			setname:  "my-test-ipset-1",
   217  			typename: "hash:ip",
   218  			options: IpsetCreateOptions{
   219  				Replace:  true,
   220  				Timeout:  &timeout,
   221  				Counters: true,
   222  				Comments: false,
   223  				Skbinfo:  false,
   224  			},
   225  			entry: &IPSetEntry{
   226  				Comment: "test comment",
   227  				IP:      net.ParseIP("10.99.99.99").To4(),
   228  				Replace: false,
   229  			},
   230  		},
   231  		{
   232  			desc:     "Type-hash:net",
   233  			setname:  "my-test-ipset-2",
   234  			typename: "hash:net",
   235  			options: IpsetCreateOptions{
   236  				Replace:  true,
   237  				Timeout:  &timeout,
   238  				Counters: false,
   239  				Comments: true,
   240  				Skbinfo:  true,
   241  			},
   242  			entry: &IPSetEntry{
   243  				Comment: "test comment",
   244  				IP:      net.ParseIP("10.99.99.0").To4(),
   245  				CIDR:    24,
   246  				Replace: false,
   247  			},
   248  		},
   249  		{
   250  			desc:     "Type-hash:net,net",
   251  			setname:  "my-test-ipset-4",
   252  			typename: "hash:net,net",
   253  			options: IpsetCreateOptions{
   254  				Replace:  true,
   255  				Timeout:  &timeout,
   256  				Counters: false,
   257  				Comments: true,
   258  				Skbinfo:  true,
   259  			},
   260  			entry: &IPSetEntry{
   261  				Comment: "test comment",
   262  				IP:      net.ParseIP("10.99.99.0").To4(),
   263  				CIDR:    24,
   264  				IP2:     net.ParseIP("10.99.0.0").To4(),
   265  				CIDR2:   24,
   266  				Replace: false,
   267  			},
   268  		},
   269  		{
   270  			desc:     "Type-hash:ip,ip",
   271  			setname:  "my-test-ipset-5",
   272  			typename: "hash:net,net",
   273  			options: IpsetCreateOptions{
   274  				Replace:  true,
   275  				Timeout:  &timeout,
   276  				Counters: false,
   277  				Comments: true,
   278  				Skbinfo:  true,
   279  			},
   280  			entry: &IPSetEntry{
   281  				Comment: "test comment",
   282  				IP:      net.ParseIP("10.99.99.0").To4(),
   283  				IP2:     net.ParseIP("10.99.0.0").To4(),
   284  				Replace: false,
   285  			},
   286  		},
   287  		{
   288  			desc:     "Type-hash:ip,port",
   289  			setname:  "my-test-ipset-6",
   290  			typename: "hash:ip,port",
   291  			options: IpsetCreateOptions{
   292  				Replace:  true,
   293  				Timeout:  &timeout,
   294  				Counters: false,
   295  				Comments: true,
   296  				Skbinfo:  true,
   297  			},
   298  			entry: &IPSetEntry{
   299  				Comment:  "test comment",
   300  				IP:       net.ParseIP("10.99.99.1").To4(),
   301  				Protocol: &protocalTCP,
   302  				Port:     &port,
   303  				Replace:  false,
   304  			},
   305  		},
   306  		{
   307  			desc:     "Type-hash:net,port,net",
   308  			setname:  "my-test-ipset-7",
   309  			typename: "hash:net,port,net",
   310  			options: IpsetCreateOptions{
   311  				Replace:  true,
   312  				Timeout:  &timeout,
   313  				Counters: false,
   314  				Comments: true,
   315  				Skbinfo:  true,
   316  			},
   317  			entry: &IPSetEntry{
   318  				Comment:  "test comment",
   319  				IP:       net.ParseIP("10.99.99.0").To4(),
   320  				CIDR:     24,
   321  				IP2:      net.ParseIP("10.99.0.0").To4(),
   322  				CIDR2:    24,
   323  				Protocol: &protocalTCP,
   324  				Port:     &port,
   325  				Replace:  false,
   326  			},
   327  		},
   328  		{
   329  			desc:     "Type-hash:mac",
   330  			setname:  "my-test-ipset-8",
   331  			typename: "hash:mac",
   332  			options: IpsetCreateOptions{
   333  				Replace:  true,
   334  				Timeout:  &timeout,
   335  				Counters: true,
   336  				Comments: false,
   337  				Skbinfo:  false,
   338  			},
   339  			entry: &IPSetEntry{
   340  				Comment: "test comment",
   341  				MAC:     net.HardwareAddr{0x26, 0x6f, 0x0d, 0x5b, 0xc1, 0x9d},
   342  				Replace: false,
   343  			},
   344  		},
   345  		{
   346  			desc:     "Type-hash:net,iface",
   347  			setname:  "my-test-ipset-9",
   348  			typename: "hash:net,iface",
   349  			options: IpsetCreateOptions{
   350  				Replace:  true,
   351  				Timeout:  &timeout,
   352  				Counters: true,
   353  				Comments: false,
   354  				Skbinfo:  false,
   355  			},
   356  			entry: &IPSetEntry{
   357  				Comment: "test comment",
   358  				IP:      net.ParseIP("10.99.99.0").To4(),
   359  				CIDR:    24,
   360  				IFace:   "eth0",
   361  				Replace: false,
   362  			},
   363  		},
   364  		{
   365  			desc:     "Type-hash:ip,mark",
   366  			setname:  "my-test-ipset-10",
   367  			typename: "hash:ip,mark",
   368  			options: IpsetCreateOptions{
   369  				Replace:  true,
   370  				Timeout:  &timeout,
   371  				Counters: true,
   372  				Comments: false,
   373  				Skbinfo:  false,
   374  			},
   375  			entry: &IPSetEntry{
   376  				Comment: "test comment",
   377  				IP:      net.ParseIP("10.99.99.0").To4(),
   378  				Mark:    &timeout,
   379  				Replace: false,
   380  			},
   381  		},
   382  	}
   383  
   384  	for _, tC := range testCases {
   385  		t.Run(tC.desc, func(t *testing.T) {
   386  			tearDown := setUpNetlinkTest(t)
   387  			defer tearDown()
   388  
   389  			err := IpsetCreate(tC.setname, tC.typename, tC.options)
   390  			if err != nil {
   391  				t.Fatal(err)
   392  			}
   393  
   394  			result, err := IpsetList(tC.setname)
   395  			if err != nil {
   396  				t.Fatal(err)
   397  			}
   398  
   399  			if result.SetName != tC.setname {
   400  				t.Errorf("expected name to be '%s', but got '%s'", tC.setname, result.SetName)
   401  			}
   402  
   403  			if result.TypeName != tC.typename {
   404  				t.Errorf("expected type to be '%s', but got '%s'", tC.typename, result.TypeName)
   405  			}
   406  
   407  			if *result.Timeout != timeout {
   408  				t.Errorf("expected timeout to be %d, but got '%d'", timeout, *result.Timeout)
   409  			}
   410  
   411  			err = IpsetAdd(tC.setname, tC.entry)
   412  
   413  			if err != nil {
   414  				t.Error(result.Protocol, result.Family)
   415  				t.Fatal(err)
   416  			}
   417  
   418  			result, err = IpsetList(tC.setname)
   419  
   420  			if err != nil {
   421  				t.Fatal(err)
   422  			}
   423  
   424  			if len(result.Entries) != 1 {
   425  				t.Fatalf("expected 1 entry be created, got '%d'", len(result.Entries))
   426  			}
   427  
   428  			if tC.entry.IP != nil {
   429  				if !tC.entry.IP.Equal(result.Entries[0].IP) {
   430  					t.Fatalf("expected entry to be '%v', got '%v'", tC.entry.IP, result.Entries[0].IP)
   431  				}
   432  			}
   433  
   434  			if tC.entry.CIDR > 0 {
   435  				if result.Entries[0].CIDR != tC.entry.CIDR {
   436  					t.Fatalf("expected cidr to be '%d', got '%d'", tC.entry.CIDR, result.Entries[0].CIDR)
   437  				}
   438  			}
   439  
   440  			if tC.entry.IP2 != nil {
   441  				if !tC.entry.IP2.Equal(result.Entries[0].IP2) {
   442  					t.Fatalf("expected entry.ip2 to be '%v', got '%v'", tC.entry.IP2, result.Entries[0].IP2)
   443  				}
   444  			}
   445  
   446  			if tC.entry.CIDR2 > 0 {
   447  				if result.Entries[0].CIDR2 != tC.entry.CIDR2 {
   448  					t.Fatalf("expected cidr2 to be '%d', got '%d'", tC.entry.CIDR2, result.Entries[0].CIDR2)
   449  				}
   450  			}
   451  
   452  			if tC.entry.Port != nil {
   453  				if *result.Entries[0].Protocol != *tC.entry.Protocol {
   454  					t.Fatalf("expected protocol to be '%d', got '%d'", *tC.entry.Protocol, *result.Entries[0].Protocol)
   455  				}
   456  				if *result.Entries[0].Port != *tC.entry.Port {
   457  					t.Fatalf("expected port to be '%d', got '%d'", *tC.entry.Port, *result.Entries[0].Port)
   458  				}
   459  			}
   460  
   461  			if tC.entry.MAC != nil {
   462  				if result.Entries[0].MAC.String() != tC.entry.MAC.String() {
   463  					t.Fatalf("expected mac to be '%v', got '%v'", tC.entry.MAC, result.Entries[0].MAC)
   464  				}
   465  			}
   466  
   467  			if tC.entry.IFace != "" {
   468  				if result.Entries[0].IFace != tC.entry.IFace {
   469  					t.Fatalf("expected iface to be '%v', got '%v'", tC.entry.IFace, result.Entries[0].IFace)
   470  				}
   471  			}
   472  
   473  			if tC.entry.Mark != nil {
   474  				if *result.Entries[0].Mark != *tC.entry.Mark {
   475  					t.Fatalf("expected mark to be '%v', got '%v'", *tC.entry.Mark, *result.Entries[0].Mark)
   476  				}
   477  			}
   478  
   479  			if result.Entries[0].Comment != tC.entry.Comment {
   480  				// This is only supported in the kernel module from revision 2 or 4, so comments may be ignored.
   481  				t.Logf("expected comment to be '%s', got '%s'", tC.entry.Comment, result.Entries[0].Comment)
   482  			}
   483  
   484  			err = IpsetDel(tC.setname, tC.entry)
   485  			if err != nil {
   486  				t.Fatal(err)
   487  			}
   488  
   489  			result, err = IpsetList(tC.setname)
   490  			if err != nil {
   491  				t.Fatal(err)
   492  			}
   493  
   494  			if len(result.Entries) != 0 {
   495  				t.Fatalf("expected 0 entries to exist, got %d", len(result.Entries))
   496  			}
   497  
   498  			err = IpsetDestroy(tC.setname)
   499  			if err != nil {
   500  				t.Fatal(err)
   501  			}
   502  		})
   503  	}
   504  }
   505  
   506  func TestIpsetBitmapCreateListWithTestCases(t *testing.T) {
   507  	timeout := uint32(3)
   508  
   509  	testCases := []struct {
   510  		desc     string
   511  		setname  string
   512  		typename string
   513  		options  IpsetCreateOptions
   514  		entry    *IPSetEntry
   515  	}{
   516  		{
   517  			desc:     "Type-bitmap:port",
   518  			setname:  "my-test-ipset-11",
   519  			typename: "bitmap:port",
   520  			options: IpsetCreateOptions{
   521  				Replace:  true,
   522  				Timeout:  &timeout,
   523  				Counters: true,
   524  				Comments: false,
   525  				Skbinfo:  false,
   526  				PortFrom: 100,
   527  				PortTo:   600,
   528  			},
   529  			entry: &IPSetEntry{
   530  				Comment: "test comment",
   531  				IP:      net.ParseIP("10.99.99.0").To4(),
   532  				CIDR:    26,
   533  				Mark:    &timeout,
   534  				Replace: false,
   535  			},
   536  		},
   537  	}
   538  
   539  	for _, tC := range testCases {
   540  		t.Run(tC.desc, func(t *testing.T) {
   541  			tearDown := setUpNetlinkTest(t)
   542  			defer tearDown()
   543  
   544  			err := IpsetCreate(tC.setname, tC.typename, tC.options)
   545  			if err != nil {
   546  				t.Fatal(err)
   547  			}
   548  
   549  			result, err := IpsetList(tC.setname)
   550  			if err != nil {
   551  				t.Fatal(err)
   552  			}
   553  
   554  			if tC.typename == "bitmap:port" {
   555  				if result.PortFrom != tC.options.PortFrom || result.PortTo != tC.options.PortTo {
   556  					t.Fatalf("expected port range %d-%d, got %d-%d", tC.options.PortFrom, tC.options.PortTo, result.PortFrom, result.PortTo)
   557  				}
   558  			} else if tC.typename == "bitmap:ip" {
   559  				if result.IPFrom == nil || result.IPTo == nil || result.IPFrom.Equal(tC.options.IPFrom) || result.IPTo.Equal(tC.options.IPTo) {
   560  					t.Fatalf("expected ip range %v-%v, got %v-%v", tC.options.IPFrom, tC.options.IPTo, result.IPFrom, result.IPTo)
   561  				}
   562  			}
   563  
   564  		})
   565  	}
   566  }
   567  
   568  func TestIpsetSwap(t *testing.T) {
   569  	tearDown := setUpNetlinkTest(t)
   570  	defer tearDown()
   571  
   572  	ipset1 := "my-test-ipset-swap-1"
   573  	ipset2 := "my-test-ipset-swap-2"
   574  
   575  	err := IpsetCreate(ipset1, "hash:ip", IpsetCreateOptions{
   576  		Replace: true,
   577  	})
   578  	if err != nil {
   579  		t.Fatal(err)
   580  	}
   581  	defer func() {
   582  		_ = IpsetDestroy(ipset1)
   583  	}()
   584  
   585  	err = IpsetCreate(ipset2, "hash:ip", IpsetCreateOptions{
   586  		Replace: true,
   587  	})
   588  	if err != nil {
   589  		t.Fatal(err)
   590  	}
   591  	defer func() {
   592  		_ = IpsetDestroy(ipset2)
   593  	}()
   594  
   595  	err = IpsetAdd(ipset1, &IPSetEntry{
   596  		IP: net.ParseIP("10.99.99.99").To4(),
   597  	})
   598  	if err != nil {
   599  		t.Fatal(err)
   600  	}
   601  
   602  	assertHasOneEntry := func(name string) {
   603  		result, err := IpsetList(name)
   604  		if err != nil {
   605  			t.Fatal(err)
   606  		}
   607  		if len(result.Entries) != 1 {
   608  			t.Fatalf("expected 1 entry be created, got '%d'", len(result.Entries))
   609  		}
   610  		if result.Entries[0].IP.String() != "10.99.99.99" {
   611  			t.Fatalf("expected entry to be '10.99.99.99', got '%s'", result.Entries[0].IP.String())
   612  		}
   613  	}
   614  
   615  	assertIsEmpty := func(name string) {
   616  		result, err := IpsetList(name)
   617  		if err != nil {
   618  			t.Fatal(err)
   619  		}
   620  		if len(result.Entries) != 0 {
   621  			t.Fatalf("expected 0 entry be created, got '%d'", len(result.Entries))
   622  		}
   623  	}
   624  
   625  	assertHasOneEntry(ipset1)
   626  	assertIsEmpty(ipset2)
   627  
   628  	err = IpsetSwap(ipset1, ipset2)
   629  	if err != nil {
   630  		t.Fatal(err)
   631  	}
   632  
   633  	assertIsEmpty(ipset1)
   634  	assertHasOneEntry(ipset2)
   635  }