github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/cmd/ipset-test/main.go (about) 1 //go:build linux 2 // +build linux 3 4 package main 5 6 import ( 7 "flag" 8 "fmt" 9 "log" 10 "net" 11 "os" 12 "sort" 13 14 "github.com/sagernet/netlink" 15 ) 16 17 type command struct { 18 Function func([]string) 19 Description string 20 ArgCount int 21 } 22 23 var ( 24 commands = map[string]command{ 25 "protocol": {cmdProtocol, "prints the protocol version", 0}, 26 "create": {cmdCreate, "creates a new ipset", 2}, 27 "destroy": {cmdDestroy, "creates a new ipset", 1}, 28 "list": {cmdList, "list specific ipset", 1}, 29 "listall": {cmdListAll, "list all ipsets", 0}, 30 "add": {cmdAddDel(netlink.IpsetAdd), "add entry", 2}, 31 "del": {cmdAddDel(netlink.IpsetDel), "delete entry", 2}, 32 } 33 34 timeoutVal *uint32 35 timeout = flag.Int("timeout", -1, "timeout, negative means omit the argument") 36 comment = flag.String("comment", "", "comment") 37 withComments = flag.Bool("with-comments", false, "create set with comment support") 38 withCounters = flag.Bool("with-counters", false, "create set with counters support") 39 withSkbinfo = flag.Bool("with-skbinfo", false, "create set with skbinfo support") 40 replace = flag.Bool("replace", false, "replace existing set/entry") 41 ) 42 43 func main() { 44 flag.Parse() 45 args := flag.Args() 46 47 if len(args) < 1 { 48 printUsage() 49 os.Exit(1) 50 } 51 52 if *timeout >= 0 { 53 v := uint32(*timeout) 54 timeoutVal = &v 55 } 56 57 log.SetFlags(log.Lshortfile) 58 59 cmdName := args[0] 60 args = args[1:] 61 62 cmd, exist := commands[cmdName] 63 if !exist { 64 fmt.Printf("Unknown command '%s'\n\n", cmdName) 65 printUsage() 66 os.Exit(1) 67 } 68 69 if cmd.ArgCount != len(args) { 70 fmt.Printf("Invalid number of arguments. expected=%d given=%d\n", cmd.ArgCount, len(args)) 71 os.Exit(1) 72 } 73 74 cmd.Function(args) 75 } 76 77 func printUsage() { 78 fmt.Printf("Usage: %s COMMAND [args] [-flags]\n\n", os.Args[0]) 79 names := make([]string, 0, len(commands)) 80 for name := range commands { 81 names = append(names, name) 82 } 83 sort.Strings(names) 84 fmt.Println("Available commands:") 85 for _, name := range names { 86 fmt.Printf(" %-15v %s\n", name, commands[name].Description) 87 } 88 fmt.Println("\nAvailable flags:") 89 flag.PrintDefaults() 90 } 91 92 func cmdProtocol(_ []string) { 93 protocol, minProto, err := netlink.IpsetProtocol() 94 check(err) 95 log.Println("Protocol:", protocol, "min:", minProto) 96 } 97 98 func cmdCreate(args []string) { 99 err := netlink.IpsetCreate(args[0], args[1], netlink.IpsetCreateOptions{ 100 Replace: *replace, 101 Timeout: timeoutVal, 102 Comments: *withComments, 103 Counters: *withCounters, 104 Skbinfo: *withSkbinfo, 105 }) 106 check(err) 107 } 108 109 func cmdDestroy(args []string) { 110 check(netlink.IpsetDestroy(args[0])) 111 } 112 113 func cmdList(args []string) { 114 result, err := netlink.IpsetList(args[0]) 115 check(err) 116 log.Printf("%+v", result) 117 } 118 119 func cmdListAll(args []string) { 120 result, err := netlink.IpsetListAll() 121 check(err) 122 for _, ipset := range result { 123 log.Printf("%+v", ipset) 124 } 125 } 126 127 func cmdAddDel(f func(string, *netlink.IPSetEntry) error) func([]string) { 128 return func(args []string) { 129 setName := args[0] 130 element := args[1] 131 132 mac, _ := net.ParseMAC(element) 133 entry := netlink.IPSetEntry{ 134 Timeout: timeoutVal, 135 MAC: mac, 136 Comment: *comment, 137 Replace: *replace, 138 } 139 140 check(f(setName, &entry)) 141 } 142 } 143 144 // panic on error 145 func check(err error) { 146 if err != nil { 147 panic(err) 148 } 149 }