github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/netns_test.go (about) 1 //go:build linux 2 // +build linux 3 4 package netlink 5 6 import ( 7 "os" 8 "runtime" 9 "syscall" 10 "testing" 11 12 "github.com/vishvananda/netns" 13 ) 14 15 // TestNetNsIdByFd tests setting and getting the network namespace ID 16 // by file descriptor. It opens a namespace fd, sets it to a random id, 17 // then retrieves the ID. 18 // This does not do any namespace switching. 19 func TestNetNsIdByFd(t *testing.T) { 20 skipUnlessRoot(t) 21 // create a network namespace 22 ns, err := netns.New() 23 CheckErrorFail(t, err) 24 25 // set its ID 26 // In an attempt to avoid namespace id collisions, set this to something 27 // insanely high. When the kernel assigns IDs, it does so starting from 0 28 // So, just use our pid shifted up 16 bits 29 wantID := os.Getpid() << 16 30 31 h, err := NewHandle() 32 CheckErrorFail(t, err) 33 err = h.SetNetNsIdByFd(int(ns), wantID) 34 CheckErrorFail(t, err) 35 36 // Get the ID back, make sure it matches 37 haveID, _ := h.GetNetNsIdByFd(int(ns)) 38 if haveID != wantID { 39 t.Errorf("GetNetNsIdByFd returned %d, want %d", haveID, wantID) 40 } 41 42 ns.Close() 43 } 44 45 // TestNetNsIdByPid tests manipulating namespace IDs by pid (really, task / thread id) 46 // Does the same as TestNetNsIdByFd, but we need to change namespaces so we 47 // actually have a pid in that namespace 48 func TestNetNsIdByPid(t *testing.T) { 49 skipUnlessRoot(t) 50 runtime.LockOSThread() // we need a constant OS thread 51 origNs, _ := netns.Get() 52 53 // create and enter a new netns 54 ns, err := netns.New() 55 CheckErrorFail(t, err) 56 err = netns.Set(ns) 57 CheckErrorFail(t, err) 58 // make sure we go back to the original namespace when done 59 defer func() { 60 err := netns.Set(origNs) 61 if err != nil { 62 panic("failed to restore network ns, bailing") 63 } 64 runtime.UnlockOSThread() 65 }() 66 67 // As above, we'll pick a crazy large netnsid to avoid collisions 68 wantID := syscall.Gettid() << 16 69 70 h, err := NewHandle() 71 CheckErrorFail(t, err) 72 err = h.SetNetNsIdByPid(syscall.Gettid(), wantID) 73 CheckErrorFail(t, err) 74 75 //Get the ID and see if it worked 76 haveID, _ := h.GetNetNsIdByPid(syscall.Gettid()) 77 if haveID != wantID { 78 t.Errorf("GetNetNsIdByPid returned %d, want %d", haveID, wantID) 79 } 80 }