github.com/rish1988/moby@v25.0.2+incompatible/internal/testutils/netnsutils/context_unix.go (about)

     1  //go:build linux || freebsd
     2  
     3  package netnsutils
     4  
     5  import (
     6  	"fmt"
     7  	"runtime"
     8  	"strconv"
     9  	"testing"
    10  
    11  	"github.com/docker/docker/internal/testutils"
    12  	"github.com/docker/docker/libnetwork/ns"
    13  	"github.com/pkg/errors"
    14  	"github.com/vishvananda/netns"
    15  	"golang.org/x/sys/unix"
    16  )
    17  
    18  // OSContext is a handle to a test OS context.
    19  type OSContext struct {
    20  	origNS, newNS netns.NsHandle
    21  
    22  	tid    int
    23  	caller string // The file:line where SetupTestOSContextEx was called, for interpolating into error messages.
    24  }
    25  
    26  // SetupTestOSContext joins the current goroutine to a new network namespace,
    27  // and returns its associated teardown function.
    28  //
    29  // Example usage:
    30  //
    31  //	defer SetupTestOSContext(t)()
    32  func SetupTestOSContext(t *testing.T) func() {
    33  	c := SetupTestOSContextEx(t)
    34  	return func() { c.Cleanup(t) }
    35  }
    36  
    37  // SetupTestOSContextEx joins the current goroutine to a new network namespace.
    38  //
    39  // Compared to [SetupTestOSContext], this function allows goroutines to be
    40  // spawned which are associated with the same OS context via the returned
    41  // OSContext value.
    42  //
    43  // Example usage:
    44  //
    45  //	c := SetupTestOSContext(t)
    46  //	defer c.Cleanup(t)
    47  func SetupTestOSContextEx(t *testing.T) *OSContext {
    48  	runtime.LockOSThread()
    49  	origNS, err := netns.Get()
    50  	if err != nil {
    51  		runtime.UnlockOSThread()
    52  		t.Fatalf("Failed to open initial netns: %v", err)
    53  	}
    54  
    55  	c := OSContext{
    56  		tid:    unix.Gettid(),
    57  		origNS: origNS,
    58  	}
    59  	c.newNS, err = netns.New()
    60  	if err != nil {
    61  		// netns.New() is not atomic: it could have encountered an error
    62  		// after unsharing the current thread's network namespace.
    63  		c.restore(t)
    64  		t.Fatalf("Failed to enter netns: %v", err)
    65  	}
    66  
    67  	// Since we are switching to a new test namespace make
    68  	// sure to re-initialize initNs context
    69  	ns.Init()
    70  
    71  	nl := ns.NlHandle()
    72  	lo, err := nl.LinkByName("lo")
    73  	if err != nil {
    74  		c.restore(t)
    75  		t.Fatalf("Failed to get handle to loopback interface 'lo' in new netns: %v", err)
    76  	}
    77  	if err := nl.LinkSetUp(lo); err != nil {
    78  		c.restore(t)
    79  		t.Fatalf("Failed to enable loopback interface in new netns: %v", err)
    80  	}
    81  
    82  	_, file, line, ok := runtime.Caller(0)
    83  	if ok {
    84  		c.caller = file + ":" + strconv.Itoa(line)
    85  	}
    86  
    87  	return &c
    88  }
    89  
    90  // Cleanup tears down the OS context. It must be called from the same goroutine
    91  // as the [SetupTestOSContextEx] call which returned c.
    92  //
    93  // Explicit cleanup is required as (*testing.T).Cleanup() makes no guarantees
    94  // about which goroutine the cleanup functions are invoked on.
    95  func (c *OSContext) Cleanup(t *testing.T) {
    96  	t.Helper()
    97  	if unix.Gettid() != c.tid {
    98  		t.Fatalf("c.Cleanup() must be called from the same goroutine as SetupTestOSContextEx() (%s)", c.caller)
    99  	}
   100  	if err := c.newNS.Close(); err != nil {
   101  		t.Logf("Warning: netns closing failed (%v)", err)
   102  	}
   103  	c.restore(t)
   104  	ns.Init()
   105  }
   106  
   107  func (c *OSContext) restore(t *testing.T) {
   108  	t.Helper()
   109  	if err := netns.Set(c.origNS); err != nil {
   110  		t.Logf("Warning: failed to restore thread netns (%v)", err)
   111  	} else {
   112  		runtime.UnlockOSThread()
   113  	}
   114  
   115  	if err := c.origNS.Close(); err != nil {
   116  		t.Logf("Warning: netns closing failed (%v)", err)
   117  	}
   118  }
   119  
   120  // Set sets the OS context of the calling goroutine to c and returns a teardown
   121  // function to restore the calling goroutine's OS context and release resources.
   122  // The teardown function accepts an optional Logger argument.
   123  //
   124  // This is a lower-level interface which is less ergonomic than c.Go() but more
   125  // composable with other goroutine-spawning utilities such as [sync.WaitGroup]
   126  // or [golang.org/x/sync/errgroup.Group].
   127  //
   128  // Example usage:
   129  //
   130  //	func TestFoo(t *testing.T) {
   131  //		osctx := testutils.SetupTestOSContextEx(t)
   132  //		defer osctx.Cleanup(t)
   133  //		var eg errgroup.Group
   134  //		eg.Go(func() error {
   135  //			teardown, err := osctx.Set()
   136  //			if err != nil {
   137  //				return err
   138  //			}
   139  //			defer teardown(t)
   140  //			// ...
   141  //		})
   142  //		if err := eg.Wait(); err != nil {
   143  //			t.Fatalf("%+v", err)
   144  //		}
   145  //	}
   146  func (c *OSContext) Set() (func(testutils.Logger), error) {
   147  	runtime.LockOSThread()
   148  	orig, err := netns.Get()
   149  	if err != nil {
   150  		runtime.UnlockOSThread()
   151  		return nil, errors.Wrap(err, "failed to open initial netns for goroutine")
   152  	}
   153  	if err := errors.WithStack(netns.Set(c.newNS)); err != nil {
   154  		runtime.UnlockOSThread()
   155  		return nil, errors.Wrap(err, "failed to set goroutine network namespace")
   156  	}
   157  
   158  	tid := unix.Gettid()
   159  	_, file, line, callerOK := runtime.Caller(0)
   160  
   161  	return func(log testutils.Logger) {
   162  		if unix.Gettid() != tid {
   163  			msg := "teardown function must be called from the same goroutine as c.Set()"
   164  			if callerOK {
   165  				msg += fmt.Sprintf(" (%s:%d)", file, line)
   166  			}
   167  			panic(msg)
   168  		}
   169  
   170  		if err := netns.Set(orig); err != nil && log != nil {
   171  			log.Logf("Warning: failed to restore goroutine thread netns (%v)", err)
   172  		} else {
   173  			runtime.UnlockOSThread()
   174  		}
   175  
   176  		if err := orig.Close(); err != nil && log != nil {
   177  			log.Logf("Warning: netns closing failed (%v)", err)
   178  		}
   179  	}, nil
   180  }
   181  
   182  // Go starts running fn in a new goroutine inside the test OS context.
   183  func (c *OSContext) Go(t *testing.T, fn func()) {
   184  	t.Helper()
   185  	errCh := make(chan error, 1)
   186  	go func() {
   187  		teardown, err := c.Set()
   188  		if err != nil {
   189  			errCh <- err
   190  			return
   191  		}
   192  		defer teardown(t)
   193  		close(errCh)
   194  		fn()
   195  	}()
   196  
   197  	if err := <-errCh; err != nil {
   198  		t.Fatalf("%+v", err)
   199  	}
   200  }