github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/handle_test.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package netlink
     5  
     6  import (
     7  	"crypto/rand"
     8  	"encoding/hex"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"sync"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  	"unsafe"
    17  
    18  	"github.com/sagernet/netlink/nl"
    19  	"github.com/vishvananda/netns"
    20  	"golang.org/x/sys/unix"
    21  )
    22  
    23  func TestHandleCreateClose(t *testing.T) {
    24  	h, err := NewHandle()
    25  	if err != nil {
    26  		t.Fatal(err)
    27  	}
    28  	for _, f := range nl.SupportedNlFamilies {
    29  		sh, ok := h.sockets[f]
    30  		if !ok {
    31  			t.Fatalf("Handle socket(s) for family %d was not created", f)
    32  		}
    33  		if sh.Socket == nil {
    34  			t.Fatalf("Socket for family %d was not created", f)
    35  		}
    36  	}
    37  
    38  	h.Close()
    39  	if h.sockets != nil {
    40  		t.Fatalf("Handle socket(s) were not closed")
    41  	}
    42  }
    43  
    44  func TestHandleCreateNetns(t *testing.T) {
    45  	skipUnlessRoot(t)
    46  
    47  	id := make([]byte, 4)
    48  	if _, err := io.ReadFull(rand.Reader, id); err != nil {
    49  		t.Fatal(err)
    50  	}
    51  	ifName := "dummy-" + hex.EncodeToString(id)
    52  
    53  	// Create an handle on the current netns
    54  	curNs, err := netns.Get()
    55  	if err != nil {
    56  		t.Fatal(err)
    57  	}
    58  	defer curNs.Close()
    59  
    60  	ch, err := NewHandleAt(curNs)
    61  	if err != nil {
    62  		t.Fatal(err)
    63  	}
    64  	defer ch.Close()
    65  
    66  	// Create an handle on a custom netns
    67  	newNs, err := netns.New()
    68  	if err != nil {
    69  		t.Fatal(err)
    70  	}
    71  	defer newNs.Close()
    72  
    73  	nh, err := NewHandleAt(newNs)
    74  	if err != nil {
    75  		t.Fatal(err)
    76  	}
    77  	defer nh.Close()
    78  
    79  	// Create an interface using the current handle
    80  	err = ch.LinkAdd(&Dummy{LinkAttrs{Name: ifName}})
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	l, err := ch.LinkByName(ifName)
    85  	if err != nil {
    86  		t.Fatal(err)
    87  	}
    88  	if l.Type() != "dummy" {
    89  		t.Fatalf("Unexpected link type: %s", l.Type())
    90  	}
    91  
    92  	// Verify the new handle cannot find the interface
    93  	ll, err := nh.LinkByName(ifName)
    94  	if err == nil {
    95  		t.Fatalf("Unexpected link found on netns %s: %v", newNs, ll)
    96  	}
    97  
    98  	// Move the interface to the new netns
    99  	err = ch.LinkSetNsFd(l, int(newNs))
   100  	if err != nil {
   101  		t.Fatal(err)
   102  	}
   103  
   104  	// Verify new netns handle can find the interface while current cannot
   105  	ll, err = nh.LinkByName(ifName)
   106  	if err != nil {
   107  		t.Fatal(err)
   108  	}
   109  	if ll.Type() != "dummy" {
   110  		t.Fatalf("Unexpected link type: %s", ll.Type())
   111  	}
   112  	ll, err = ch.LinkByName(ifName)
   113  	if err == nil {
   114  		t.Fatalf("Unexpected link found on netns %s: %v", curNs, ll)
   115  	}
   116  }
   117  
   118  func TestHandleTimeout(t *testing.T) {
   119  	h, err := NewHandle()
   120  	if err != nil {
   121  		t.Fatal(err)
   122  	}
   123  	defer h.Close()
   124  
   125  	for _, sh := range h.sockets {
   126  		verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 0, Usec: 0})
   127  	}
   128  
   129  	h.SetSocketTimeout(2*time.Second + 8*time.Millisecond)
   130  
   131  	for _, sh := range h.sockets {
   132  		verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 2, Usec: 8000})
   133  	}
   134  }
   135  
   136  func TestHandleReceiveBuffer(t *testing.T) {
   137  	h, err := NewHandle()
   138  	if err != nil {
   139  		t.Fatal(err)
   140  	}
   141  	defer h.Close()
   142  	if err := h.SetSocketReceiveBufferSize(65536, false); err != nil {
   143  		t.Fatal(err)
   144  	}
   145  	sizes, err := h.GetSocketReceiveBufferSize()
   146  	if err != nil {
   147  		t.Fatal(err)
   148  	}
   149  	if len(sizes) != len(h.sockets) {
   150  		t.Fatalf("Unexpected number of socket buffer sizes: %d (expected %d)",
   151  			len(sizes), len(h.sockets))
   152  	}
   153  	for _, s := range sizes {
   154  		if s < 65536 || s > 2*65536 {
   155  			t.Fatalf("Unexpected socket receive buffer size: %d (expected around %d)",
   156  				s, 65536)
   157  		}
   158  	}
   159  }
   160  
   161  func verifySockTimeVal(t *testing.T, fd int, tv unix.Timeval) {
   162  	var (
   163  		tr unix.Timeval
   164  		v  = uint32(0x10)
   165  	)
   166  	_, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0)
   167  	if errno != 0 {
   168  		t.Fatal(errno)
   169  	}
   170  
   171  	if tr.Sec != tv.Sec || tr.Usec != tv.Usec {
   172  		t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv)
   173  	}
   174  
   175  	_, _, errno = unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0)
   176  	if errno != 0 {
   177  		t.Fatal(errno)
   178  	}
   179  
   180  	if tr.Sec != tv.Sec || tr.Usec != tv.Usec {
   181  		t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv)
   182  	}
   183  }
   184  
// Shared state used by the TestHandleParallel* tests through
// runParallelTests, initParallel and parallelDone.
var (
	iter      = 10           // number of loop iterations each runParallelTests call performs
	numThread = uint32(4)    // value of done at which parallelDone releases the shared resources
	prefix    = "iface"      // leading part of the interface names built in runParallelTests
	handle1   *Handle        // handle created on ns1 by initParallel
	handle2   *Handle        // handle created on ns2 by initParallel
	ns1       netns.NsHandle // first namespace created by initParallel
	ns2       netns.NsHandle // second namespace created by initParallel
	done      uint32         // counter incremented by parallelDone
	initError error          // error recorded by initParallel, checked by runParallelTests
	once      sync.Once      // ensures initParallel runs a single time
)
   197  
   198  func getXfrmState(thread int) *XfrmState {
   199  	return &XfrmState{
   200  		Src:   net.IPv4(byte(192), byte(168), 1, byte(1+thread)),
   201  		Dst:   net.IPv4(byte(192), byte(168), 2, byte(1+thread)),
   202  		Proto: XFRM_PROTO_AH,
   203  		Mode:  XFRM_MODE_TUNNEL,
   204  		Spi:   thread,
   205  		Auth: &XfrmStateAlgo{
   206  			Name: "hmac(sha256)",
   207  			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
   208  		},
   209  	}
   210  }
   211  
   212  func getXfrmPolicy(thread int) *XfrmPolicy {
   213  	return &XfrmPolicy{
   214  		Src:     &net.IPNet{IP: net.IPv4(byte(10), byte(10), byte(thread), 0), Mask: []byte{255, 255, 255, 0}},
   215  		Dst:     &net.IPNet{IP: net.IPv4(byte(10), byte(10), byte(thread), 0), Mask: []byte{255, 255, 255, 0}},
   216  		Proto:   17,
   217  		DstPort: 1234,
   218  		SrcPort: 5678,
   219  		Dir:     XFRM_DIR_OUT,
   220  		Tmpls: []XfrmPolicyTmpl{
   221  			{
   222  				Src:   net.IPv4(byte(192), byte(168), 1, byte(thread)),
   223  				Dst:   net.IPv4(byte(192), byte(168), 2, byte(thread)),
   224  				Proto: XFRM_PROTO_ESP,
   225  				Mode:  XFRM_MODE_TUNNEL,
   226  			},
   227  		},
   228  	}
   229  }
   230  func initParallel() {
   231  	ns1, initError = netns.New()
   232  	if initError != nil {
   233  		return
   234  	}
   235  	handle1, initError = NewHandleAt(ns1)
   236  	if initError != nil {
   237  		return
   238  	}
   239  	ns2, initError = netns.New()
   240  	if initError != nil {
   241  		return
   242  	}
   243  	handle2, initError = NewHandleAt(ns2)
   244  	if initError != nil {
   245  		return
   246  	}
   247  }
   248  
   249  func parallelDone() {
   250  	atomic.AddUint32(&done, 1)
   251  	if done == numThread {
   252  		if ns1.IsOpen() {
   253  			ns1.Close()
   254  		}
   255  		if ns2.IsOpen() {
   256  			ns2.Close()
   257  		}
   258  		if handle1 != nil {
   259  			handle1.Close()
   260  		}
   261  		if handle2 != nil {
   262  			handle2.Close()
   263  		}
   264  	}
   265  }
   266  
   267  // Do few route and xfrm operation on the two handles in parallel
   268  func runParallelTests(t *testing.T, thread int) {
   269  	skipUnlessRoot(t)
   270  	defer parallelDone()
   271  
   272  	t.Parallel()
   273  
   274  	once.Do(initParallel)
   275  	if initError != nil {
   276  		t.Fatal(initError)
   277  	}
   278  
   279  	state := getXfrmState(thread)
   280  	policy := getXfrmPolicy(thread)
   281  	for i := 0; i < iter; i++ {
   282  		ifName := fmt.Sprintf("%s_%d_%d", prefix, thread, i)
   283  		link := &Dummy{LinkAttrs{Name: ifName}}
   284  		err := handle1.LinkAdd(link)
   285  		if err != nil {
   286  			t.Fatal(err)
   287  		}
   288  		l, err := handle1.LinkByName(ifName)
   289  		if err != nil {
   290  			t.Fatal(err)
   291  		}
   292  		err = handle1.LinkSetUp(l)
   293  		if err != nil {
   294  			t.Fatal(err)
   295  		}
   296  		handle1.LinkSetNsFd(l, int(ns2))
   297  		if err != nil {
   298  			t.Fatal(err)
   299  		}
   300  		err = handle1.XfrmStateAdd(state)
   301  		if err != nil {
   302  			t.Fatal(err)
   303  		}
   304  		err = handle1.XfrmPolicyAdd(policy)
   305  		if err != nil {
   306  			t.Fatal(err)
   307  		}
   308  		err = handle2.LinkSetDown(l)
   309  		if err != nil {
   310  			t.Fatal(err)
   311  		}
   312  		err = handle2.XfrmStateAdd(state)
   313  		if err != nil {
   314  			t.Fatal(err)
   315  		}
   316  		err = handle2.XfrmPolicyAdd(policy)
   317  		if err != nil {
   318  			t.Fatal(err)
   319  		}
   320  		_, err = handle2.LinkByName(ifName)
   321  		if err != nil {
   322  			t.Fatal(err)
   323  		}
   324  		handle2.LinkSetNsFd(l, int(ns1))
   325  		if err != nil {
   326  			t.Fatal(err)
   327  		}
   328  		err = handle1.LinkSetUp(l)
   329  		if err != nil {
   330  			t.Fatal(err)
   331  		}
   332  		_, err = handle1.LinkByName(ifName)
   333  		if err != nil {
   334  			t.Fatal(err)
   335  		}
   336  		err = handle1.XfrmPolicyDel(policy)
   337  		if err != nil {
   338  			t.Fatal(err)
   339  		}
   340  		err = handle2.XfrmPolicyDel(policy)
   341  		if err != nil {
   342  			t.Fatal(err)
   343  		}
   344  		err = handle1.XfrmStateDel(state)
   345  		if err != nil {
   346  			t.Fatal(err)
   347  		}
   348  		err = handle2.XfrmStateDel(state)
   349  		if err != nil {
   350  			t.Fatal(err)
   351  		}
   352  	}
   353  }
   354  
// TestHandleParallel1 runs runParallelTests with thread id 1.
func TestHandleParallel1(t *testing.T) {
	runParallelTests(t, 1)
}
   358  
// TestHandleParallel2 runs runParallelTests with thread id 2.
func TestHandleParallel2(t *testing.T) {
	runParallelTests(t, 2)
}
   362  
// TestHandleParallel3 runs runParallelTests with thread id 3.
func TestHandleParallel3(t *testing.T) {
	runParallelTests(t, 3)
}
   366  
// TestHandleParallel4 runs runParallelTests with thread id 4.
func TestHandleParallel4(t *testing.T) {
	runParallelTests(t, 4)
}