github.com/rawahars/moby@v24.0.4+incompatible/libnetwork/testutils/context_unix.go (about)

     1  //go:build linux || freebsd
     2  // +build linux freebsd
     3  
     4  package testutils
     5  
     6  import (
     7  	"fmt"
     8  	"runtime"
     9  	"strconv"
    10  	"testing"
    11  
    12  	"github.com/docker/docker/libnetwork/ns"
    13  	"github.com/pkg/errors"
    14  	"github.com/vishvananda/netns"
    15  	"golang.org/x/sys/unix"
    16  )
    17  
    18  // OSContext is a handle to a test OS context.
    19  type OSContext struct {
    20  	origNS, newNS netns.NsHandle
    21  
    22  	tid    int
    23  	caller string // The file:line where SetupTestOSContextEx was called, for interpolating into error messages.
    24  }
    25  
    26  // SetupTestOSContextEx joins the current goroutine to a new network namespace.
    27  //
    28  // Compared to [SetupTestOSContext], this function allows goroutines to be
    29  // spawned which are associated with the same OS context via the returned
    30  // OSContext value.
    31  //
    32  // Example usage:
    33  //
    34  //	c := SetupTestOSContext(t)
    35  //	defer c.Cleanup(t)
    36  func SetupTestOSContextEx(t *testing.T) *OSContext {
    37  	runtime.LockOSThread()
    38  	origNS, err := netns.Get()
    39  	if err != nil {
    40  		runtime.UnlockOSThread()
    41  		t.Fatalf("Failed to open initial netns: %v", err)
    42  	}
    43  
    44  	c := OSContext{
    45  		tid:    unix.Gettid(),
    46  		origNS: origNS,
    47  	}
    48  	c.newNS, err = netns.New()
    49  	if err != nil {
    50  		// netns.New() is not atomic: it could have encountered an error
    51  		// after unsharing the current thread's network namespace.
    52  		c.restore(t)
    53  		t.Fatalf("Failed to enter netns: %v", err)
    54  	}
    55  
    56  	// Since we are switching to a new test namespace make
    57  	// sure to re-initialize initNs context
    58  	ns.Init()
    59  
    60  	nl := ns.NlHandle()
    61  	lo, err := nl.LinkByName("lo")
    62  	if err != nil {
    63  		c.restore(t)
    64  		t.Fatalf("Failed to get handle to loopback interface 'lo' in new netns: %v", err)
    65  	}
    66  	if err := nl.LinkSetUp(lo); err != nil {
    67  		c.restore(t)
    68  		t.Fatalf("Failed to enable loopback interface in new netns: %v", err)
    69  	}
    70  
    71  	_, file, line, ok := runtime.Caller(0)
    72  	if ok {
    73  		c.caller = file + ":" + strconv.Itoa(line)
    74  	}
    75  
    76  	return &c
    77  }
    78  
    79  // Cleanup tears down the OS context. It must be called from the same goroutine
    80  // as the [SetupTestOSContextEx] call which returned c.
    81  //
    82  // Explicit cleanup is required as (*testing.T).Cleanup() makes no guarantees
    83  // about which goroutine the cleanup functions are invoked on.
    84  func (c *OSContext) Cleanup(t *testing.T) {
    85  	t.Helper()
    86  	if unix.Gettid() != c.tid {
    87  		t.Fatalf("c.Cleanup() must be called from the same goroutine as SetupTestOSContextEx() (%s)", c.caller)
    88  	}
    89  	if err := c.newNS.Close(); err != nil {
    90  		t.Logf("Warning: netns closing failed (%v)", err)
    91  	}
    92  	c.restore(t)
    93  	ns.Init()
    94  }
    95  
    96  func (c *OSContext) restore(t *testing.T) {
    97  	t.Helper()
    98  	if err := netns.Set(c.origNS); err != nil {
    99  		t.Logf("Warning: failed to restore thread netns (%v)", err)
   100  	} else {
   101  		runtime.UnlockOSThread()
   102  	}
   103  
   104  	if err := c.origNS.Close(); err != nil {
   105  		t.Logf("Warning: netns closing failed (%v)", err)
   106  	}
   107  }
   108  
   109  // Set sets the OS context of the calling goroutine to c and returns a teardown
   110  // function to restore the calling goroutine's OS context and release resources.
   111  // The teardown function accepts an optional Logger argument.
   112  //
   113  // This is a lower-level interface which is less ergonomic than c.Go() but more
   114  // composable with other goroutine-spawning utilities such as [sync.WaitGroup]
   115  // or [golang.org/x/sync/errgroup.Group].
   116  //
   117  // Example usage:
   118  //
   119  //	func TestFoo(t *testing.T) {
   120  //		osctx := testutils.SetupTestOSContextEx(t)
   121  //		defer osctx.Cleanup(t)
   122  //		var eg errgroup.Group
   123  //		eg.Go(func() error {
   124  //			teardown, err := osctx.Set()
   125  //			if err != nil {
   126  //				return err
   127  //			}
   128  //			defer teardown(t)
   129  //			// ...
   130  //		})
   131  //		if err := eg.Wait(); err != nil {
   132  //			t.Fatalf("%+v", err)
   133  //		}
   134  //	}
   135  func (c *OSContext) Set() (func(Logger), error) {
   136  	runtime.LockOSThread()
   137  	orig, err := netns.Get()
   138  	if err != nil {
   139  		runtime.UnlockOSThread()
   140  		return nil, errors.Wrap(err, "failed to open initial netns for goroutine")
   141  	}
   142  	if err := errors.WithStack(netns.Set(c.newNS)); err != nil {
   143  		runtime.UnlockOSThread()
   144  		return nil, errors.Wrap(err, "failed to set goroutine network namespace")
   145  	}
   146  
   147  	tid := unix.Gettid()
   148  	_, file, line, callerOK := runtime.Caller(0)
   149  
   150  	return func(log Logger) {
   151  		if unix.Gettid() != tid {
   152  			msg := "teardown function must be called from the same goroutine as c.Set()"
   153  			if callerOK {
   154  				msg += fmt.Sprintf(" (%s:%d)", file, line)
   155  			}
   156  			panic(msg)
   157  		}
   158  
   159  		if err := netns.Set(orig); err != nil && log != nil {
   160  			log.Logf("Warning: failed to restore goroutine thread netns (%v)", err)
   161  		} else {
   162  			runtime.UnlockOSThread()
   163  		}
   164  
   165  		if err := orig.Close(); err != nil && log != nil {
   166  			log.Logf("Warning: netns closing failed (%v)", err)
   167  		}
   168  	}, nil
   169  }