github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/netns_test.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package netlink
     5  
     6  import (
     7  	"os"
     8  	"runtime"
     9  	"syscall"
    10  	"testing"
    11  
    12  	"github.com/vishvananda/netns"
    13  )
    14  
    15  // TestNetNsIdByFd tests setting and getting the network namespace ID
    16  // by file descriptor. It opens a namespace fd, sets it to a random id,
    17  // then retrieves the ID.
    18  // This does not do any namespace switching.
    19  func TestNetNsIdByFd(t *testing.T) {
    20  	skipUnlessRoot(t)
    21  	// create a network namespace
    22  	ns, err := netns.New()
    23  	CheckErrorFail(t, err)
    24  
    25  	// set its ID
    26  	// In an attempt to avoid namespace id collisions, set this to something
    27  	// insanely high. When the kernel assigns IDs, it does so starting from 0
    28  	// So, just use our pid shifted up 16 bits
    29  	wantID := os.Getpid() << 16
    30  
    31  	h, err := NewHandle()
    32  	CheckErrorFail(t, err)
    33  	err = h.SetNetNsIdByFd(int(ns), wantID)
    34  	CheckErrorFail(t, err)
    35  
    36  	// Get the ID back, make sure it matches
    37  	haveID, _ := h.GetNetNsIdByFd(int(ns))
    38  	if haveID != wantID {
    39  		t.Errorf("GetNetNsIdByFd returned %d, want %d", haveID, wantID)
    40  	}
    41  
    42  	ns.Close()
    43  }
    44  
    45  // TestNetNsIdByPid tests manipulating namespace IDs by pid (really, task / thread id)
    46  // Does the same as TestNetNsIdByFd, but we need to change namespaces so we
    47  // actually have a pid in that namespace
    48  func TestNetNsIdByPid(t *testing.T) {
    49  	skipUnlessRoot(t)
    50  	runtime.LockOSThread() // we need a constant OS thread
    51  	origNs, _ := netns.Get()
    52  
    53  	// create and enter a new netns
    54  	ns, err := netns.New()
    55  	CheckErrorFail(t, err)
    56  	err = netns.Set(ns)
    57  	CheckErrorFail(t, err)
    58  	// make sure we go back to the original namespace when done
    59  	defer func() {
    60  		err := netns.Set(origNs)
    61  		if err != nil {
    62  			panic("failed to restore network ns, bailing")
    63  		}
    64  		runtime.UnlockOSThread()
    65  	}()
    66  
    67  	// As above, we'll pick a crazy large netnsid to avoid collisions
    68  	wantID := syscall.Gettid() << 16
    69  
    70  	h, err := NewHandle()
    71  	CheckErrorFail(t, err)
    72  	err = h.SetNetNsIdByPid(syscall.Gettid(), wantID)
    73  	CheckErrorFail(t, err)
    74  
    75  	//Get the ID and see if it worked
    76  	haveID, _ := h.GetNetNsIdByPid(syscall.Gettid())
    77  	if haveID != wantID {
    78  		t.Errorf("GetNetNsIdByPid returned %d, want %d", haveID, wantID)
    79  	}
    80  }