github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/conn/sticky_linux_test.go (about)

     1  //go:build linux && !android
     2  
     3  /* SPDX-License-Identifier: MIT
     4   *
     5   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     6   */
     7  
     8  package conn
     9  
    10  import (
    11  	"context"
    12  	"net"
    13  	"net/netip"
    14  	"runtime"
    15  	"testing"
    16  	"unsafe"
    17  
    18  	"golang.org/x/sys/unix"
    19  )
    20  
    21  func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
    22  	var buf []byte
    23  	if addr.Is4() {
    24  		buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
    25  		hdr := unix.Cmsghdr{
    26  			Level: unix.IPPROTO_IP,
    27  			Type:  unix.IP_PKTINFO,
    28  		}
    29  		hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
    30  		copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
    31  
    32  		info := unix.Inet4Pktinfo{
    33  			Ifindex:  ifidx,
    34  			Spec_dst: addr.As4(),
    35  		}
    36  		copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
    37  	} else {
    38  		buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
    39  		hdr := unix.Cmsghdr{
    40  			Level: unix.IPPROTO_IPV6,
    41  			Type:  unix.IPV6_PKTINFO,
    42  		}
    43  		hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
    44  		copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
    45  
    46  		info := unix.Inet6Pktinfo{
    47  			Ifindex: uint32(ifidx),
    48  			Addr:    addr.As16(),
    49  		}
    50  		copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
    51  	}
    52  
    53  	ep.src = buf
    54  }
    55  
    56  func Test_setSrcControl(t *testing.T) {
    57  	t.Run("IPv4", func(t *testing.T) {
    58  		ep := &StdNetEndpoint{
    59  			AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
    60  		}
    61  		setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
    62  
    63  		control := make([]byte, srcControlSize)
    64  
    65  		setSrcControl(&control, ep)
    66  
    67  		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
    68  		if hdr.Level != unix.IPPROTO_IP {
    69  			t.Errorf("unexpected level: %d", hdr.Level)
    70  		}
    71  		if hdr.Type != unix.IP_PKTINFO {
    72  			t.Errorf("unexpected type: %d", hdr.Type)
    73  		}
    74  		if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
    75  			t.Errorf("unexpected length: %d", hdr.Len)
    76  		}
    77  		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
    78  		if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
    79  			t.Errorf("unexpected address: %v", info.Spec_dst)
    80  		}
    81  		if info.Ifindex != 5 {
    82  			t.Errorf("unexpected ifindex: %d", info.Ifindex)
    83  		}
    84  	})
    85  
    86  	t.Run("IPv6", func(t *testing.T) {
    87  		ep := &StdNetEndpoint{
    88  			AddrPort: netip.MustParseAddrPort("[::1]:1234"),
    89  		}
    90  		setSrc(ep, netip.MustParseAddr("::1"), 5)
    91  
    92  		control := make([]byte, srcControlSize)
    93  
    94  		setSrcControl(&control, ep)
    95  
    96  		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
    97  		if hdr.Level != unix.IPPROTO_IPV6 {
    98  			t.Errorf("unexpected level: %d", hdr.Level)
    99  		}
   100  		if hdr.Type != unix.IPV6_PKTINFO {
   101  			t.Errorf("unexpected type: %d", hdr.Type)
   102  		}
   103  		if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
   104  			t.Errorf("unexpected length: %d", hdr.Len)
   105  		}
   106  		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
   107  		if info.Addr != ep.SrcIP().As16() {
   108  			t.Errorf("unexpected address: %v", info.Addr)
   109  		}
   110  		if info.Ifindex != 5 {
   111  			t.Errorf("unexpected ifindex: %d", info.Ifindex)
   112  		}
   113  	})
   114  
   115  	t.Run("ClearOnNoSrc", func(t *testing.T) {
   116  		control := make([]byte, unix.CmsgLen(0))
   117  		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
   118  		hdr.Level = 1
   119  		hdr.Type = 2
   120  		hdr.Len = 3
   121  
   122  		setSrcControl(&control, &StdNetEndpoint{})
   123  
   124  		if len(control) != 0 {
   125  			t.Errorf("unexpected control: %v", control)
   126  		}
   127  	})
   128  }
   129  
   130  func Test_getSrcFromControl(t *testing.T) {
   131  	t.Run("IPv4", func(t *testing.T) {
   132  		control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
   133  		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
   134  		hdr.Level = unix.IPPROTO_IP
   135  		hdr.Type = unix.IP_PKTINFO
   136  		hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
   137  		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
   138  		info.Spec_dst = [4]byte{127, 0, 0, 1}
   139  		info.Ifindex = 5
   140  
   141  		ep := &StdNetEndpoint{}
   142  		getSrcFromControl(control, ep)
   143  
   144  		if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
   145  			t.Errorf("unexpected address: %v", ep.SrcIP())
   146  		}
   147  		if ep.SrcIfidx() != 5 {
   148  			t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
   149  		}
   150  	})
   151  	t.Run("IPv6", func(t *testing.T) {
   152  		control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
   153  		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
   154  		hdr.Level = unix.IPPROTO_IPV6
   155  		hdr.Type = unix.IPV6_PKTINFO
   156  		hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
   157  		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
   158  		info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
   159  		info.Ifindex = 5
   160  
   161  		ep := &StdNetEndpoint{}
   162  		getSrcFromControl(control, ep)
   163  
   164  		if ep.SrcIP() != netip.MustParseAddr("::1") {
   165  			t.Errorf("unexpected address: %v", ep.SrcIP())
   166  		}
   167  		if ep.SrcIfidx() != 5 {
   168  			t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
   169  		}
   170  	})
   171  	t.Run("ClearOnEmpty", func(t *testing.T) {
   172  		var control []byte
   173  		ep := &StdNetEndpoint{}
   174  		setSrc(ep, netip.MustParseAddr("::1"), 5)
   175  
   176  		getSrcFromControl(control, ep)
   177  		if ep.SrcIP().IsValid() {
   178  			t.Errorf("unexpected address: %v", ep.SrcIP())
   179  		}
   180  		if ep.SrcIfidx() != 0 {
   181  			t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
   182  		}
   183  	})
   184  	t.Run("Multiple", func(t *testing.T) {
   185  		zeroControl := make([]byte, unix.CmsgSpace(0))
   186  		zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
   187  		zeroHdr.SetLen(unix.CmsgLen(0))
   188  
   189  		control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
   190  		hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
   191  		hdr.Level = unix.IPPROTO_IP
   192  		hdr.Type = unix.IP_PKTINFO
   193  		hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
   194  		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
   195  		info.Spec_dst = [4]byte{127, 0, 0, 1}
   196  		info.Ifindex = 5
   197  
   198  		combined := make([]byte, 0)
   199  		combined = append(combined, zeroControl...)
   200  		combined = append(combined, control...)
   201  
   202  		ep := &StdNetEndpoint{}
   203  		getSrcFromControl(combined, ep)
   204  
   205  		if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
   206  			t.Errorf("unexpected address: %v", ep.SrcIP())
   207  		}
   208  		if ep.SrcIfidx() != 5 {
   209  			t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
   210  		}
   211  	})
   212  }
   213  
   214  func Test_listenConfig(t *testing.T) {
   215  	t.Run("IPv4", func(t *testing.T) {
   216  		conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
   217  		if err != nil {
   218  			t.Fatal(err)
   219  		}
   220  		defer conn.Close()
   221  		sc, err := conn.(*net.UDPConn).SyscallConn()
   222  		if err != nil {
   223  			t.Fatal(err)
   224  		}
   225  
   226  		if runtime.GOOS == "linux" {
   227  			var i int
   228  			sc.Control(func(fd uintptr) {
   229  				i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
   230  			})
   231  			if err != nil {
   232  				t.Fatal(err)
   233  			}
   234  			if i != 1 {
   235  				t.Error("IP_PKTINFO not set!")
   236  			}
   237  		} else {
   238  			t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
   239  		}
   240  	})
   241  	t.Run("IPv6", func(t *testing.T) {
   242  		conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
   243  		if err != nil {
   244  			t.Fatal(err)
   245  		}
   246  		sc, err := conn.(*net.UDPConn).SyscallConn()
   247  		if err != nil {
   248  			t.Fatal(err)
   249  		}
   250  
   251  		if runtime.GOOS == "linux" {
   252  			var i int
   253  			sc.Control(func(fd uintptr) {
   254  				i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
   255  			})
   256  			if err != nil {
   257  				t.Fatal(err)
   258  			}
   259  			if i != 1 {
   260  				t.Error("IPV6_PKTINFO not set!")
   261  			}
   262  		} else {
   263  			t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
   264  		}
   265  	})
   266  }