github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/netns/netns_test.go (about) 1 // Copyright (c) 2016 Arista Networks, Inc. 2 // Use of this source code is governed by the Apache License 2.0 3 // that can be found in the COPYING file. 4 5 package netns 6 7 import ( 8 "io/ioutil" 9 "os" 10 "path/filepath" 11 "testing" 12 ) 13 14 type mockHandle int 15 16 func (mh mockHandle) close() error { 17 return nil 18 } 19 20 func (mh mockHandle) fd() int { 21 return 0 22 } 23 24 func TestNetNs(t *testing.T) { 25 setNsCallCount := 0 26 27 // Mock getNs 28 oldGetNs := getNs 29 getNs = func(nsName string) (handle, error) { 30 return mockHandle(1), nil 31 } 32 defer func() { 33 getNs = oldGetNs 34 }() 35 36 // Mock setNs 37 oldSetNs := setNs 38 setNs = func(fd handle) error { 39 setNsCallCount++ 40 return nil 41 } 42 defer func() { 43 setNs = oldSetNs 44 }() 45 46 // Create a tempfile so we can use its name for the network namespace 47 tmpfile, err := ioutil.TempFile("", "") 48 if err != nil { 49 t.Fatalf("Failed to create a temp file: %s", err) 50 } 51 defer os.Remove(tmpfile.Name()) 52 nsName := filepath.Base(tmpfile.Name()) 53 54 // Map of network namespace name to the number of times it should call setNs 55 cases := map[string]int{"": 0, "default": 2, nsName: 2} 56 for name, callCount := range cases { 57 var cbResult string 58 err = Do(name, func() error { 59 cbResult = "Hello" + name 60 return nil 61 }) 62 if err != nil { 63 t.Fatalf("Error calling function in different network namespace: %s", err) 64 } 65 if cbResult != "Hello"+name { 66 t.Fatalf("Failed to call the callback function") 67 } 68 if setNsCallCount != callCount { 69 t.Fatalf("setNs should have been called %d times for %s, but was called %d times", 70 callCount, name, setNsCallCount) 71 } 72 setNsCallCount = 0 73 } 74 }