github.com/psiphon-inc/goarista@v0.0.0-20160825065156-d002785f4c67/netns/netns_test.go (about)

     1  // Copyright (C) 2016  Arista Networks, Inc.
     2  // Use of this source code is governed by the Apache License 2.0
     3  // that can be found in the COPYING file.
     4  
     5  package netns
     6  
     7  import (
     8  	"io/ioutil"
     9  	"os"
    10  	"path/filepath"
    11  	"testing"
    12  )
    13  
    14  type mockHandle int
    15  
    16  func (mh mockHandle) close() error {
    17  	return nil
    18  }
    19  
    20  func (mh mockHandle) fd() int {
    21  	return 0
    22  }
    23  
    24  func TestNetNs(t *testing.T) {
    25  	setNsCallCount := 0
    26  
    27  	// Mock getNs
    28  	oldGetNs := getNs
    29  	getNs = func(nsName string) (handle, error) {
    30  		return mockHandle(1), nil
    31  	}
    32  	defer func() {
    33  		getNs = oldGetNs
    34  	}()
    35  
    36  	// Mock setNs
    37  	oldSetNs := setNs
    38  	setNs = func(fd handle) error {
    39  		setNsCallCount++
    40  		return nil
    41  	}
    42  	defer func() {
    43  		setNs = oldSetNs
    44  	}()
    45  
    46  	// Create a tempfile so we can use its name for the network namespace
    47  	tmpfile, err := ioutil.TempFile("", "")
    48  	if err != nil {
    49  		t.Fatalf("Failed to create a temp file: %s", err)
    50  	}
    51  	defer os.Remove(tmpfile.Name())
    52  	nsName := filepath.Base(tmpfile.Name())
    53  
    54  	// Map of network namespace name to the number of times it should call setNs
    55  	cases := map[string]int{"": 0, "default": 2, nsName: 2}
    56  	for name, callCount := range cases {
    57  		var cbResult string
    58  		err = Do(name, func() {
    59  			cbResult = "Hello" + name
    60  		})
    61  		if err != nil {
    62  			t.Fatalf("Error calling function in different network namespace: %s", err)
    63  		}
    64  		if cbResult != "Hello"+name {
    65  			t.Fatalf("Failed to call the callback function")
    66  		}
    67  		if setNsCallCount != callCount {
    68  			t.Fatalf("setNs should have been called %d times for %s, but was called %d times",
    69  				callCount, name, setNsCallCount)
    70  		}
    71  		setNsCallCount = 0
    72  	}
    73  }