github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/netlink_test.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package netlink
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/rand"
     9  	"encoding/hex"
    10  	"fmt"
    11  	"io/ioutil"
    12  	"log"
    13  	"os"
    14  	"os/exec"
    15  	"runtime"
    16  	"strings"
    17  	"testing"
    18  
    19  	"github.com/vishvananda/netns"
    20  	"golang.org/x/sys/unix"
    21  )
    22  
// tearDownNetlinkTest is returned by the setUp*NetlinkTest helpers in this
// file; calling it undoes the namespace setup that helper performed.
type tearDownNetlinkTest func()
    24  
    25  func skipUnlessRoot(t *testing.T) {
    26  	if os.Getuid() != 0 {
    27  		t.Skip("Test requires root privileges.")
    28  	}
    29  }
    30  
    31  func setUpNetlinkTest(t *testing.T) tearDownNetlinkTest {
    32  	skipUnlessRoot(t)
    33  
    34  	// new temporary namespace so we don't pollute the host
    35  	// lock thread since the namespace is thread local
    36  	runtime.LockOSThread()
    37  	var err error
    38  	ns, err := netns.New()
    39  	if err != nil {
    40  		t.Fatal("Failed to create newns", ns)
    41  	}
    42  
    43  	return func() {
    44  		ns.Close()
    45  		runtime.UnlockOSThread()
    46  	}
    47  }
    48  
    49  // setUpNamedNetlinkTest create a temporary named names space with a random name
    50  func setUpNamedNetlinkTest(t *testing.T) (string, tearDownNetlinkTest) {
    51  	skipUnlessRoot(t)
    52  
    53  	origNS, err := netns.Get()
    54  	if err != nil {
    55  		t.Fatal("Failed saving orig namespace")
    56  	}
    57  
    58  	// create a random name
    59  	rnd := make([]byte, 4)
    60  	if _, err := rand.Read(rnd); err != nil {
    61  		t.Fatal("failed creating random ns name")
    62  	}
    63  	name := "netlinktest-" + hex.EncodeToString(rnd)
    64  
    65  	ns, err := netns.NewNamed(name)
    66  	if err != nil {
    67  		t.Fatal("Failed to create new ns", err)
    68  	}
    69  
    70  	runtime.LockOSThread()
    71  	cleanup := func() {
    72  		ns.Close()
    73  		netns.DeleteNamed(name)
    74  		netns.Set(origNS)
    75  		runtime.UnlockOSThread()
    76  	}
    77  
    78  	if err := netns.Set(ns); err != nil {
    79  		cleanup()
    80  		t.Fatal("Failed entering new namespace", err)
    81  	}
    82  
    83  	return name, cleanup
    84  }
    85  
    86  func setUpNetlinkTestWithLoopback(t *testing.T) tearDownNetlinkTest {
    87  	skipUnlessRoot(t)
    88  
    89  	runtime.LockOSThread()
    90  	ns, err := netns.New()
    91  	if err != nil {
    92  		t.Fatal("Failed to create new netns", ns)
    93  	}
    94  
    95  	link, err := LinkByName("lo")
    96  	if err != nil {
    97  		t.Fatalf("Failed to find \"lo\" in new netns: %v", err)
    98  	}
    99  	if err := LinkSetUp(link); err != nil {
   100  		t.Fatalf("Failed to bring up \"lo\" in new netns: %v", err)
   101  	}
   102  
   103  	return func() {
   104  		ns.Close()
   105  		runtime.UnlockOSThread()
   106  	}
   107  }
   108  
   109  func setUpF(t *testing.T, path, value string) {
   110  	file, err := os.Create(path)
   111  	if err != nil {
   112  		t.Fatalf("Failed to open %s: %s", path, err)
   113  	}
   114  	defer file.Close()
   115  	file.WriteString(value)
   116  }
   117  
   118  func setUpMPLSNetlinkTest(t *testing.T) tearDownNetlinkTest {
   119  	if _, err := os.Stat("/proc/sys/net/mpls/platform_labels"); err != nil {
   120  		t.Skip("Test requires MPLS support.")
   121  	}
   122  	f := setUpNetlinkTest(t)
   123  	setUpF(t, "/proc/sys/net/mpls/platform_labels", "1024")
   124  	setUpF(t, "/proc/sys/net/mpls/conf/lo/input", "1")
   125  	return f
   126  }
   127  
   128  func setUpSEG6NetlinkTest(t *testing.T) tearDownNetlinkTest {
   129  	// check if SEG6 options are enabled in Kernel Config
   130  	cmd := exec.Command("uname", "-r")
   131  	var out bytes.Buffer
   132  	cmd.Stdout = &out
   133  	if err := cmd.Run(); err != nil {
   134  		t.Fatal("Failed to run: uname -r")
   135  	}
   136  	s := []string{"/boot/config-", strings.TrimRight(out.String(), "\n")}
   137  	filename := strings.Join(s, "")
   138  
   139  	grepKey := func(key, fname string) (string, error) {
   140  		cmd := exec.Command("grep", key, filename)
   141  		var out bytes.Buffer
   142  		cmd.Stdout = &out
   143  		err := cmd.Run() // "err != nil" if no line matched with grep
   144  		return strings.TrimRight(out.String(), "\n"), err
   145  	}
   146  	key := string("CONFIG_IPV6_SEG6_LWTUNNEL=y")
   147  	if _, err := grepKey(key, filename); err != nil {
   148  		msg := "Skipped test because it requires SEG6_LWTUNNEL support."
   149  		log.Println(msg)
   150  		t.Skip(msg)
   151  	}
   152  	// Add CONFIG_IPV6_SEG6_HMAC to support seg6_hamc
   153  	// key := string("CONFIG_IPV6_SEG6_HMAC=y")
   154  
   155  	return setUpNetlinkTest(t)
   156  }
   157  
   158  func setUpNetlinkTestWithKModule(t *testing.T, name string) tearDownNetlinkTest {
   159  	file, err := ioutil.ReadFile("/proc/modules")
   160  	if err != nil {
   161  		t.Fatal("Failed to open /proc/modules", err)
   162  	}
   163  	found := false
   164  	for _, line := range strings.Split(string(file), "\n") {
   165  		n := strings.Split(line, " ")[0]
   166  		if n == name {
   167  			found = true
   168  			break
   169  		}
   170  
   171  	}
   172  	if !found {
   173  		t.Skipf("Test requires kmodule %q.", name)
   174  	}
   175  	return setUpNetlinkTest(t)
   176  }
   177  
   178  func remountSysfs() error {
   179  	if err := unix.Mount("", "/", "none", unix.MS_SLAVE|unix.MS_REC, ""); err != nil {
   180  		return err
   181  	}
   182  	if err := unix.Unmount("/sys", unix.MNT_DETACH); err != nil {
   183  		return err
   184  	}
   185  	return unix.Mount("", "/sys", "sysfs", 0, "")
   186  }
   187  
   188  func minKernelRequired(t *testing.T, kernel, major int) {
   189  	k, m, err := KernelVersion()
   190  	if err != nil {
   191  		t.Fatal(err)
   192  	}
   193  	if k < kernel || k == kernel && m < major {
   194  		t.Skipf("Host Kernel (%d.%d) does not meet test's minimum required version: (%d.%d)",
   195  			k, m, kernel, major)
   196  	}
   197  }
   198  
   199  func KernelVersion() (kernel, major int, err error) {
   200  	uts := unix.Utsname{}
   201  	if err = unix.Uname(&uts); err != nil {
   202  		return
   203  	}
   204  
   205  	ba := make([]byte, 0, len(uts.Release))
   206  	for _, b := range uts.Release {
   207  		if b == 0 {
   208  			break
   209  		}
   210  		ba = append(ba, byte(b))
   211  	}
   212  	var rest string
   213  	if n, _ := fmt.Sscanf(string(ba), "%d.%d%s", &kernel, &major, &rest); n < 2 {
   214  		err = fmt.Errorf("can't parse kernel version in %q", string(ba))
   215  	}
   216  	return
   217  }