github.com/vishvananda/netlink@v1.3.0/netns_test.go (about)

     1  // +build linux
     2  
     3  package netlink
     4  
     5  import (
     6  	"os"
     7  	"runtime"
     8  	"syscall"
     9  	"testing"
    10  
    11  	"github.com/vishvananda/netns"
    12  )
    13  
    14  // TestNetNsIdByFd tests setting and getting the network namespace ID
    15  // by file descriptor. It opens a namespace fd, sets it to a random id,
    16  // then retrieves the ID.
    17  // This does not do any namespace switching.
    18  func TestNetNsIdByFd(t *testing.T) {
    19  	skipUnlessRoot(t)
    20  	// create a network namespace
    21  	ns, err := netns.New()
    22  	CheckErrorFail(t, err)
    23  
    24  	// set its ID
    25  	// In an attempt to avoid namespace id collisions, set this to something
    26  	// insanely high. When the kernel assigns IDs, it does so starting from 0
    27  	// So, just use our pid shifted up 16 bits
    28  	wantID := os.Getpid() << 16
    29  
    30  	h, err := NewHandle()
    31  	CheckErrorFail(t, err)
    32  	err = h.SetNetNsIdByFd(int(ns), wantID)
    33  	CheckErrorFail(t, err)
    34  
    35  	// Get the ID back, make sure it matches
    36  	haveID, _ := h.GetNetNsIdByFd(int(ns))
    37  	if haveID != wantID {
    38  		t.Errorf("GetNetNsIdByFd returned %d, want %d", haveID, wantID)
    39  	}
    40  
    41  	ns.Close()
    42  }
    43  
    44  // TestNetNsIdByPid tests manipulating namespace IDs by pid (really, task / thread id)
    45  // Does the same as TestNetNsIdByFd, but we need to change namespaces so we
    46  // actually have a pid in that namespace
    47  func TestNetNsIdByPid(t *testing.T) {
    48  	skipUnlessRoot(t)
    49  	runtime.LockOSThread() // we need a constant OS thread
    50  	origNs, _ := netns.Get()
    51  
    52  	// create and enter a new netns
    53  	ns, err := netns.New()
    54  	CheckErrorFail(t, err)
    55  	err = netns.Set(ns)
    56  	CheckErrorFail(t, err)
    57  	// make sure we go back to the original namespace when done
    58  	defer func() {
    59  		err := netns.Set(origNs)
    60  		if err != nil {
    61  			panic("failed to restore network ns, bailing")
    62  		}
    63  		runtime.UnlockOSThread()
    64  	}()
    65  
    66  	// As above, we'll pick a crazy large netnsid to avoid collisions
    67  	wantID := syscall.Gettid() << 16
    68  
    69  	h, err := NewHandle()
    70  	CheckErrorFail(t, err)
    71  	err = h.SetNetNsIdByPid(syscall.Gettid(), wantID)
    72  	CheckErrorFail(t, err)
    73  
    74  	//Get the ID and see if it worked
    75  	haveID, _ := h.GetNetNsIdByPid(syscall.Gettid())
    76  	if haveID != wantID {
    77  		t.Errorf("GetNetNsIdByPid returned %d, want %d", haveID, wantID)
    78  	}
    79  }