github.com/vishvananda/netlink@v1.3.0/netlink_test.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package netlink
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/rand"
     9  	"encoding/hex"
    10  	"fmt"
    11  	"io/ioutil"
    12  	"log"
    13  	"os"
    14  	"os/exec"
    15  	"runtime"
    16  	"strings"
    17  	"testing"
    18  
    19  	"github.com/vishvananda/netlink/nl"
    20  	"github.com/vishvananda/netns"
    21  	"golang.org/x/sys/unix"
    22  )
    23  
// tearDownNetlinkTest is the cleanup function returned by the setUp*NetlinkTest
// helpers in this file; callers invoke it (usually via defer) once the test
// that requested the environment has finished.
type tearDownNetlinkTest func()
    25  
    26  func skipUnlessRoot(t testing.TB) {
    27  	t.Helper()
    28  
    29  	if os.Getuid() != 0 {
    30  		t.Skip("Test requires root privileges.")
    31  	}
    32  }
    33  
    34  func skipUnlessKModuleLoaded(t *testing.T, moduleNames ...string) {
    35  	t.Helper()
    36  	file, err := ioutil.ReadFile("/proc/modules")
    37  	if err != nil {
    38  		t.Fatal("Failed to open /proc/modules", err)
    39  	}
    40  
    41  	foundRequiredMods := make(map[string]bool)
    42  	lines := strings.Split(string(file), "\n")
    43  
    44  	for _, name := range moduleNames {
    45  		foundRequiredMods[name] = false
    46  		for _, line := range lines {
    47  			n := strings.Split(line, " ")[0]
    48  			if n == name {
    49  				foundRequiredMods[name] = true
    50  				break
    51  			}
    52  		}
    53  	}
    54  
    55  	failed := false
    56  	for _, name := range moduleNames {
    57  		if found, _ := foundRequiredMods[name]; !found {
    58  			t.Logf("Test requires missing kmodule %q.", name)
    59  			failed = true
    60  		}
    61  	}
    62  	if failed {
    63  		t.SkipNow()
    64  	}
    65  }
    66  
    67  func setUpNetlinkTest(t testing.TB) tearDownNetlinkTest {
    68  	skipUnlessRoot(t)
    69  
    70  	// new temporary namespace so we don't pollute the host
    71  	// lock thread since the namespace is thread local
    72  	runtime.LockOSThread()
    73  	var err error
    74  	ns, err := netns.New()
    75  	if err != nil {
    76  		t.Fatal("Failed to create newns", ns)
    77  	}
    78  
    79  	return func() {
    80  		ns.Close()
    81  		runtime.UnlockOSThread()
    82  	}
    83  }
    84  
    85  // setUpNamedNetlinkTest create a temporary named names space with a random name
    86  func setUpNamedNetlinkTest(t *testing.T) (string, tearDownNetlinkTest) {
    87  	skipUnlessRoot(t)
    88  
    89  	origNS, err := netns.Get()
    90  	if err != nil {
    91  		t.Fatal("Failed saving orig namespace")
    92  	}
    93  
    94  	// create a random name
    95  	rnd := make([]byte, 4)
    96  	if _, err := rand.Read(rnd); err != nil {
    97  		t.Fatal("failed creating random ns name")
    98  	}
    99  	name := "netlinktest-" + hex.EncodeToString(rnd)
   100  
   101  	ns, err := netns.NewNamed(name)
   102  	if err != nil {
   103  		t.Fatal("Failed to create new ns", err)
   104  	}
   105  
   106  	runtime.LockOSThread()
   107  	cleanup := func() {
   108  		ns.Close()
   109  		netns.DeleteNamed(name)
   110  		netns.Set(origNS)
   111  		runtime.UnlockOSThread()
   112  	}
   113  
   114  	if err := netns.Set(ns); err != nil {
   115  		cleanup()
   116  		t.Fatal("Failed entering new namespace", err)
   117  	}
   118  
   119  	return name, cleanup
   120  }
   121  
   122  func setUpNetlinkTestWithLoopback(t *testing.T) tearDownNetlinkTest {
   123  	skipUnlessRoot(t)
   124  
   125  	runtime.LockOSThread()
   126  	ns, err := netns.New()
   127  	if err != nil {
   128  		t.Fatal("Failed to create new netns", ns)
   129  	}
   130  
   131  	link, err := LinkByName("lo")
   132  	if err != nil {
   133  		t.Fatalf("Failed to find \"lo\" in new netns: %v", err)
   134  	}
   135  	if err := LinkSetUp(link); err != nil {
   136  		t.Fatalf("Failed to bring up \"lo\" in new netns: %v", err)
   137  	}
   138  
   139  	return func() {
   140  		ns.Close()
   141  		runtime.UnlockOSThread()
   142  	}
   143  }
   144  
   145  func setUpF(t *testing.T, path, value string) {
   146  	file, err := os.Create(path)
   147  	if err != nil {
   148  		t.Fatalf("Failed to open %s: %s", path, err)
   149  	}
   150  	defer file.Close()
   151  	file.WriteString(value)
   152  }
   153  
   154  func setUpMPLSNetlinkTest(t *testing.T) tearDownNetlinkTest {
   155  	if _, err := os.Stat("/proc/sys/net/mpls/platform_labels"); err != nil {
   156  		t.Skip("Test requires MPLS support.")
   157  	}
   158  	f := setUpNetlinkTest(t)
   159  	setUpF(t, "/proc/sys/net/mpls/platform_labels", "1024")
   160  	setUpF(t, "/proc/sys/net/mpls/conf/lo/input", "1")
   161  	return f
   162  }
   163  
   164  func setUpSEG6NetlinkTest(t *testing.T) tearDownNetlinkTest {
   165  	// check if SEG6 options are enabled in Kernel Config
   166  	cmd := exec.Command("uname", "-r")
   167  	var out bytes.Buffer
   168  	cmd.Stdout = &out
   169  	if err := cmd.Run(); err != nil {
   170  		t.Fatal("Failed to run: uname -r")
   171  	}
   172  	s := []string{"/boot/config-", strings.TrimRight(out.String(), "\n")}
   173  	filename := strings.Join(s, "")
   174  
   175  	grepKey := func(key, fname string) (string, error) {
   176  		cmd := exec.Command("grep", key, filename)
   177  		var out bytes.Buffer
   178  		cmd.Stdout = &out
   179  		err := cmd.Run() // "err != nil" if no line matched with grep
   180  		return strings.TrimRight(out.String(), "\n"), err
   181  	}
   182  	key := string("CONFIG_IPV6_SEG6_LWTUNNEL=y")
   183  	if _, err := grepKey(key, filename); err != nil {
   184  		msg := "Skipped test because it requires SEG6_LWTUNNEL support."
   185  		log.Println(msg)
   186  		t.Skip(msg)
   187  	}
   188  	// Add CONFIG_IPV6_SEG6_HMAC to support seg6_hamc
   189  	// key := string("CONFIG_IPV6_SEG6_HMAC=y")
   190  
   191  	return setUpNetlinkTest(t)
   192  }
   193  
   194  func setUpNetlinkTestWithKModule(t *testing.T, moduleNames ...string) tearDownNetlinkTest {
   195  	skipUnlessKModuleLoaded(t, moduleNames...)
   196  	return setUpNetlinkTest(t)
   197  }
   198  func setUpNamedNetlinkTestWithKModule(t *testing.T, moduleNames ...string) (string, tearDownNetlinkTest) {
   199  	file, err := ioutil.ReadFile("/proc/modules")
   200  	if err != nil {
   201  		t.Fatal("Failed to open /proc/modules", err)
   202  	}
   203  
   204  	foundRequiredMods := make(map[string]bool)
   205  	lines := strings.Split(string(file), "\n")
   206  
   207  	for _, name := range moduleNames {
   208  		foundRequiredMods[name] = false
   209  		for _, line := range lines {
   210  			n := strings.Split(line, " ")[0]
   211  			if n == name {
   212  				foundRequiredMods[name] = true
   213  				break
   214  			}
   215  		}
   216  	}
   217  
   218  	failed := false
   219  	for _, name := range moduleNames {
   220  		if found, _ := foundRequiredMods[name]; !found {
   221  			t.Logf("Test requires missing kmodule %q.", name)
   222  			failed = true
   223  		}
   224  	}
   225  	if failed {
   226  		t.SkipNow()
   227  	}
   228  
   229  	return setUpNamedNetlinkTest(t)
   230  }
   231  
   232  func remountSysfs() error {
   233  	if err := unix.Mount("", "/", "none", unix.MS_SLAVE|unix.MS_REC, ""); err != nil {
   234  		return err
   235  	}
   236  	if err := unix.Unmount("/sys", unix.MNT_DETACH); err != nil {
   237  		return err
   238  	}
   239  	return unix.Mount("", "/sys", "sysfs", 0, "")
   240  }
   241  
   242  func minKernelRequired(t *testing.T, kernel, major int) {
   243  	t.Helper()
   244  
   245  	k, m, err := KernelVersion()
   246  	if err != nil {
   247  		t.Fatal(err)
   248  	}
   249  	if k < kernel || k == kernel && m < major {
   250  		t.Skipf("Host Kernel (%d.%d) does not meet test's minimum required version: (%d.%d)",
   251  			k, m, kernel, major)
   252  	}
   253  }
   254  
   255  func KernelVersion() (kernel, major int, err error) {
   256  	uts := unix.Utsname{}
   257  	if err = unix.Uname(&uts); err != nil {
   258  		return
   259  	}
   260  
   261  	ba := make([]byte, 0, len(uts.Release))
   262  	for _, b := range uts.Release {
   263  		if b == 0 {
   264  			break
   265  		}
   266  		ba = append(ba, byte(b))
   267  	}
   268  	var rest string
   269  	if n, _ := fmt.Sscanf(string(ba), "%d.%d%s", &kernel, &major, &rest); n < 2 {
   270  		err = fmt.Errorf("can't parse kernel version in %q", string(ba))
   271  	}
   272  	return
   273  }
   274  
   275  func TestMain(m *testing.M) {
   276  	nl.EnableErrorMessageReporting = true
   277  	os.Exit(m.Run())
   278  }