github.com/TeaOSLab/EdgeNode@v1.3.8/internal/firewalls/nftables/chain_test.go (about)

     1  // Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
     2  //go:build linux
     3  
     4  package nftables_test
     5  
import (
	"errors"
	"net"
	"testing"

	"github.com/TeaOSLab/EdgeNode/internal/firewalls/nftables"
)
    11  
    12  func getIPv4Chain(t *testing.T) *nftables.Chain {
    13  	conn, err := nftables.NewConn()
    14  	if err != nil {
    15  		t.Fatal(err)
    16  	}
    17  	table, err := conn.GetTable("test_ipv4", nftables.TableFamilyIPv4)
    18  	if err != nil {
    19  		if err == nftables.ErrTableNotFound {
    20  			table, err = conn.AddIPv4Table("test_ipv4")
    21  			if err != nil {
    22  				t.Fatal(err)
    23  			}
    24  		} else {
    25  			t.Fatal(err)
    26  		}
    27  	}
    28  
    29  	chain, err := table.GetChain("test_chain")
    30  	if err != nil {
    31  		if err == nftables.ErrChainNotFound {
    32  			chain, err = table.AddAcceptChain("test_chain")
    33  		}
    34  	}
    35  
    36  	if err != nil {
    37  		t.Fatal(err)
    38  	}
    39  
    40  	return chain
    41  }
    42  
    43  func TestChain_AddAcceptIPRule(t *testing.T) {
    44  	var chain = getIPv4Chain(t)
    45  	_, err := chain.AddAcceptIPv4Rule(net.ParseIP("192.168.2.40").To4(), nil)
    46  	if err != nil {
    47  		t.Fatal(err)
    48  	}
    49  }
    50  
    51  func TestChain_AddDropIPRule(t *testing.T) {
    52  	var chain = getIPv4Chain(t)
    53  	_, err := chain.AddDropIPv4Rule(net.ParseIP("192.168.2.31").To4(), nil)
    54  	if err != nil {
    55  		t.Fatal(err)
    56  	}
    57  }
    58  
    59  func TestChain_AddAcceptSetRule(t *testing.T) {
    60  	var chain = getIPv4Chain(t)
    61  	_, err := chain.AddAcceptIPv4SetRule("ipv4_black_set", nil)
    62  	if err != nil {
    63  		t.Fatal(err)
    64  	}
    65  }
    66  
    67  func TestChain_AddDropSetRule(t *testing.T) {
    68  	var chain = getIPv4Chain(t)
    69  	_, err := chain.AddDropIPv4SetRule("ipv4_black_set", nil)
    70  	if err != nil {
    71  		t.Fatal(err)
    72  	}
    73  }
    74  
    75  func TestChain_AddRejectSetRule(t *testing.T) {
    76  	var chain = getIPv4Chain(t)
    77  	_, err := chain.AddRejectIPv4SetRule("ipv4_black_set", nil)
    78  	if err != nil {
    79  		t.Fatal(err)
    80  	}
    81  }
    82  
    83  func TestChain_GetRuleWithUserData(t *testing.T) {
    84  	var chain = getIPv4Chain(t)
    85  	rule, err := chain.GetRuleWithUserData([]byte("test"))
    86  	if err != nil {
    87  		if err == nftables.ErrRuleNotFound {
    88  			t.Log("rule not found")
    89  			return
    90  		} else {
    91  			t.Fatal(err)
    92  		}
    93  	}
    94  	t.Log("rule:", rule)
    95  }
    96  
    97  func TestChain_GetRules(t *testing.T) {
    98  	var chain = getIPv4Chain(t)
    99  	rules, err := chain.GetRules()
   100  	if err != nil {
   101  		t.Fatal(err)
   102  	}
   103  	for _, rule := range rules {
   104  		t.Log("handle:", rule.Handle(), "set name:", rule.LookupSetName(),
   105  			"verdict:", rule.VerDict(), "user data:", string(rule.UserData()))
   106  	}
   107  }
   108  
   109  func TestChain_DeleteRule(t *testing.T) {
   110  	var chain = getIPv4Chain(t)
   111  	rule, err := chain.GetRuleWithUserData([]byte("test"))
   112  	if err != nil {
   113  		if err == nftables.ErrRuleNotFound {
   114  			t.Log("rule not found")
   115  			return
   116  		}
   117  		t.Fatal(err)
   118  	}
   119  	err = chain.DeleteRule(rule)
   120  	if err != nil {
   121  		t.Fatal(err)
   122  	}
   123  }
   124  
   125  func TestChain_Flush(t *testing.T) {
   126  	var chain = getIPv4Chain(t)
   127  	err := chain.Flush()
   128  	if err != nil {
   129  		t.Fatal(err)
   130  	}
   131  	t.Log("ok")
   132  }