github.com/vishvananda/netlink@v1.3.0/netlink_test.go (about) 1 //go:build linux 2 // +build linux 3 4 package netlink 5 6 import ( 7 "bytes" 8 "crypto/rand" 9 "encoding/hex" 10 "fmt" 11 "io/ioutil" 12 "log" 13 "os" 14 "os/exec" 15 "runtime" 16 "strings" 17 "testing" 18 19 "github.com/vishvananda/netlink/nl" 20 "github.com/vishvananda/netns" 21 "golang.org/x/sys/unix" 22 ) 23 24 type tearDownNetlinkTest func() 25 26 func skipUnlessRoot(t testing.TB) { 27 t.Helper() 28 29 if os.Getuid() != 0 { 30 t.Skip("Test requires root privileges.") 31 } 32 } 33 34 func skipUnlessKModuleLoaded(t *testing.T, moduleNames ...string) { 35 t.Helper() 36 file, err := ioutil.ReadFile("/proc/modules") 37 if err != nil { 38 t.Fatal("Failed to open /proc/modules", err) 39 } 40 41 foundRequiredMods := make(map[string]bool) 42 lines := strings.Split(string(file), "\n") 43 44 for _, name := range moduleNames { 45 foundRequiredMods[name] = false 46 for _, line := range lines { 47 n := strings.Split(line, " ")[0] 48 if n == name { 49 foundRequiredMods[name] = true 50 break 51 } 52 } 53 } 54 55 failed := false 56 for _, name := range moduleNames { 57 if found, _ := foundRequiredMods[name]; !found { 58 t.Logf("Test requires missing kmodule %q.", name) 59 failed = true 60 } 61 } 62 if failed { 63 t.SkipNow() 64 } 65 } 66 67 func setUpNetlinkTest(t testing.TB) tearDownNetlinkTest { 68 skipUnlessRoot(t) 69 70 // new temporary namespace so we don't pollute the host 71 // lock thread since the namespace is thread local 72 runtime.LockOSThread() 73 var err error 74 ns, err := netns.New() 75 if err != nil { 76 t.Fatal("Failed to create newns", ns) 77 } 78 79 return func() { 80 ns.Close() 81 runtime.UnlockOSThread() 82 } 83 } 84 85 // setUpNamedNetlinkTest create a temporary named names space with a random name 86 func setUpNamedNetlinkTest(t *testing.T) (string, tearDownNetlinkTest) { 87 skipUnlessRoot(t) 88 89 origNS, err := netns.Get() 90 if err != nil { 91 t.Fatal("Failed saving orig namespace") 92 } 93 94 // create a random name 95 rnd := make([]byte, 4) 96 if _, err := rand.Read(rnd); err != nil { 97 t.Fatal("failed creating random ns name") 98 } 99 name := "netlinktest-" + hex.EncodeToString(rnd) 100 101 ns, err := netns.NewNamed(name) 102 if err != nil { 103 t.Fatal("Failed to create new ns", err) 104 } 105 106 runtime.LockOSThread() 107 cleanup := func() { 108 ns.Close() 109 netns.DeleteNamed(name) 110 netns.Set(origNS) 111 runtime.UnlockOSThread() 112 } 113 114 if err := netns.Set(ns); err != nil { 115 cleanup() 116 t.Fatal("Failed entering new namespace", err) 117 } 118 119 return name, cleanup 120 } 121 122 func setUpNetlinkTestWithLoopback(t *testing.T) tearDownNetlinkTest { 123 skipUnlessRoot(t) 124 125 runtime.LockOSThread() 126 ns, err := netns.New() 127 if err != nil { 128 t.Fatal("Failed to create new netns", ns) 129 } 130 131 link, err := LinkByName("lo") 132 if err != nil { 133 t.Fatalf("Failed to find \"lo\" in new netns: %v", err) 134 } 135 if err := LinkSetUp(link); err != nil { 136 t.Fatalf("Failed to bring up \"lo\" in new netns: %v", err) 137 } 138 139 return func() { 140 ns.Close() 141 runtime.UnlockOSThread() 142 } 143 } 144 145 func setUpF(t *testing.T, path, value string) { 146 file, err := os.Create(path) 147 if err != nil { 148 t.Fatalf("Failed to open %s: %s", path, err) 149 } 150 defer file.Close() 151 file.WriteString(value) 152 } 153 154 func setUpMPLSNetlinkTest(t *testing.T) tearDownNetlinkTest { 155 if _, err := os.Stat("/proc/sys/net/mpls/platform_labels"); err != nil { 156 t.Skip("Test requires MPLS support.") 157 } 158 f := setUpNetlinkTest(t) 159 setUpF(t, "/proc/sys/net/mpls/platform_labels", "1024") 160 setUpF(t, "/proc/sys/net/mpls/conf/lo/input", "1") 161 return f 162 } 163 164 func setUpSEG6NetlinkTest(t *testing.T) tearDownNetlinkTest { 165 // check if SEG6 options are enabled in Kernel Config 166 cmd := exec.Command("uname", "-r") 167 var out bytes.Buffer 168 cmd.Stdout = &out 169 if err := cmd.Run(); err != nil { 170 t.Fatal("Failed to run: uname -r") 171 } 172 s := []string{"/boot/config-", strings.TrimRight(out.String(), "\n")} 173 filename := strings.Join(s, "") 174 175 grepKey := func(key, fname string) (string, error) { 176 cmd := exec.Command("grep", key, filename) 177 var out bytes.Buffer 178 cmd.Stdout = &out 179 err := cmd.Run() // "err != nil" if no line matched with grep 180 return strings.TrimRight(out.String(), "\n"), err 181 } 182 key := string("CONFIG_IPV6_SEG6_LWTUNNEL=y") 183 if _, err := grepKey(key, filename); err != nil { 184 msg := "Skipped test because it requires SEG6_LWTUNNEL support." 185 log.Println(msg) 186 t.Skip(msg) 187 } 188 // Add CONFIG_IPV6_SEG6_HMAC to support seg6_hamc 189 // key := string("CONFIG_IPV6_SEG6_HMAC=y") 190 191 return setUpNetlinkTest(t) 192 } 193 194 func setUpNetlinkTestWithKModule(t *testing.T, moduleNames ...string) tearDownNetlinkTest { 195 skipUnlessKModuleLoaded(t, moduleNames...) 196 return setUpNetlinkTest(t) 197 } 198 func setUpNamedNetlinkTestWithKModule(t *testing.T, moduleNames ...string) (string, tearDownNetlinkTest) { 199 file, err := ioutil.ReadFile("/proc/modules") 200 if err != nil { 201 t.Fatal("Failed to open /proc/modules", err) 202 } 203 204 foundRequiredMods := make(map[string]bool) 205 lines := strings.Split(string(file), "\n") 206 207 for _, name := range moduleNames { 208 foundRequiredMods[name] = false 209 for _, line := range lines { 210 n := strings.Split(line, " ")[0] 211 if n == name { 212 foundRequiredMods[name] = true 213 break 214 } 215 } 216 } 217 218 failed := false 219 for _, name := range moduleNames { 220 if found, _ := foundRequiredMods[name]; !found { 221 t.Logf("Test requires missing kmodule %q.", name) 222 failed = true 223 } 224 } 225 if failed { 226 t.SkipNow() 227 } 228 229 return setUpNamedNetlinkTest(t) 230 } 231 232 func remountSysfs() error { 233 if err := unix.Mount("", "/", "none", unix.MS_SLAVE|unix.MS_REC, ""); err != nil { 234 return err 235 } 236 if err := unix.Unmount("/sys", unix.MNT_DETACH); err != nil { 237 return err 238 } 239 return unix.Mount("", "/sys", "sysfs", 0, "") 240 } 241 242 func minKernelRequired(t *testing.T, kernel, major int) { 243 t.Helper() 244 245 k, m, err := KernelVersion() 246 if err != nil { 247 t.Fatal(err) 248 } 249 if k < kernel || k == kernel && m < major { 250 t.Skipf("Host Kernel (%d.%d) does not meet test's minimum required version: (%d.%d)", 251 k, m, kernel, major) 252 } 253 } 254 255 func KernelVersion() (kernel, major int, err error) { 256 uts := unix.Utsname{} 257 if err = unix.Uname(&uts); err != nil { 258 return 259 } 260 261 ba := make([]byte, 0, len(uts.Release)) 262 for _, b := range uts.Release { 263 if b == 0 { 264 break 265 } 266 ba = append(ba, byte(b)) 267 } 268 var rest string 269 if n, _ := fmt.Sscanf(string(ba), "%d.%d%s", &kernel, &major, &rest); n < 2 { 270 err = fmt.Errorf("can't parse kernel version in %q", string(ba)) 271 } 272 return 273 } 274 275 func TestMain(m *testing.M) { 276 nl.EnableErrorMessageReporting = true 277 os.Exit(m.Run()) 278 }