github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/cmd/ipset-test/main.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package main
     5  
     6  import (
     7  	"flag"
     8  	"fmt"
     9  	"log"
    10  	"net"
    11  	"os"
    12  	"sort"
    13  
    14  	"github.com/sagernet/netlink"
    15  )
    16  
    17  type command struct {
    18  	Function    func([]string)
    19  	Description string
    20  	ArgCount    int
    21  }
    22  
    23  var (
    24  	commands = map[string]command{
    25  		"protocol": {cmdProtocol, "prints the protocol version", 0},
    26  		"create":   {cmdCreate, "creates a new ipset", 2},
    27  		"destroy":  {cmdDestroy, "creates a new ipset", 1},
    28  		"list":     {cmdList, "list specific ipset", 1},
    29  		"listall":  {cmdListAll, "list all ipsets", 0},
    30  		"add":      {cmdAddDel(netlink.IpsetAdd), "add entry", 2},
    31  		"del":      {cmdAddDel(netlink.IpsetDel), "delete entry", 2},
    32  	}
    33  
    34  	timeoutVal   *uint32
    35  	timeout      = flag.Int("timeout", -1, "timeout, negative means omit the argument")
    36  	comment      = flag.String("comment", "", "comment")
    37  	withComments = flag.Bool("with-comments", false, "create set with comment support")
    38  	withCounters = flag.Bool("with-counters", false, "create set with counters support")
    39  	withSkbinfo  = flag.Bool("with-skbinfo", false, "create set with skbinfo support")
    40  	replace      = flag.Bool("replace", false, "replace existing set/entry")
    41  )
    42  
    43  func main() {
    44  	flag.Parse()
    45  	args := flag.Args()
    46  
    47  	if len(args) < 1 {
    48  		printUsage()
    49  		os.Exit(1)
    50  	}
    51  
    52  	if *timeout >= 0 {
    53  		v := uint32(*timeout)
    54  		timeoutVal = &v
    55  	}
    56  
    57  	log.SetFlags(log.Lshortfile)
    58  
    59  	cmdName := args[0]
    60  	args = args[1:]
    61  
    62  	cmd, exist := commands[cmdName]
    63  	if !exist {
    64  		fmt.Printf("Unknown command '%s'\n\n", cmdName)
    65  		printUsage()
    66  		os.Exit(1)
    67  	}
    68  
    69  	if cmd.ArgCount != len(args) {
    70  		fmt.Printf("Invalid number of arguments. expected=%d given=%d\n", cmd.ArgCount, len(args))
    71  		os.Exit(1)
    72  	}
    73  
    74  	cmd.Function(args)
    75  }
    76  
    77  func printUsage() {
    78  	fmt.Printf("Usage: %s COMMAND [args] [-flags]\n\n", os.Args[0])
    79  	names := make([]string, 0, len(commands))
    80  	for name := range commands {
    81  		names = append(names, name)
    82  	}
    83  	sort.Strings(names)
    84  	fmt.Println("Available commands:")
    85  	for _, name := range names {
    86  		fmt.Printf("  %-15v %s\n", name, commands[name].Description)
    87  	}
    88  	fmt.Println("\nAvailable flags:")
    89  	flag.PrintDefaults()
    90  }
    91  
    92  func cmdProtocol(_ []string) {
    93  	protocol, minProto, err := netlink.IpsetProtocol()
    94  	check(err)
    95  	log.Println("Protocol:", protocol, "min:", minProto)
    96  }
    97  
    98  func cmdCreate(args []string) {
    99  	err := netlink.IpsetCreate(args[0], args[1], netlink.IpsetCreateOptions{
   100  		Replace:  *replace,
   101  		Timeout:  timeoutVal,
   102  		Comments: *withComments,
   103  		Counters: *withCounters,
   104  		Skbinfo:  *withSkbinfo,
   105  	})
   106  	check(err)
   107  }
   108  
   109  func cmdDestroy(args []string) {
   110  	check(netlink.IpsetDestroy(args[0]))
   111  }
   112  
   113  func cmdList(args []string) {
   114  	result, err := netlink.IpsetList(args[0])
   115  	check(err)
   116  	log.Printf("%+v", result)
   117  }
   118  
   119  func cmdListAll(args []string) {
   120  	result, err := netlink.IpsetListAll()
   121  	check(err)
   122  	for _, ipset := range result {
   123  		log.Printf("%+v", ipset)
   124  	}
   125  }
   126  
   127  func cmdAddDel(f func(string, *netlink.IPSetEntry) error) func([]string) {
   128  	return func(args []string) {
   129  		setName := args[0]
   130  		element := args[1]
   131  
   132  		mac, _ := net.ParseMAC(element)
   133  		entry := netlink.IPSetEntry{
   134  			Timeout: timeoutVal,
   135  			MAC:     mac,
   136  			Comment: *comment,
   137  			Replace: *replace,
   138  		}
   139  
   140  		check(f(setName, &entry))
   141  	}
   142  }
   143  
   144  // panic on error
   145  func check(err error) {
   146  	if err != nil {
   147  		panic(err)
   148  	}
   149  }