github.com/cawidtu/notwireguard-go/conn@v0.0.0-20230523131112-68e8e5ce9cdf/bind_linux.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conn
     7  
     8  import (
     9  	"errors"
    10  	"net"
    11  	"net/netip"
    12  	"strconv"
    13  	"sync"
    14  	"syscall"
    15  	"unsafe"
    16  
    17  	"golang.org/x/sys/unix"
    18  )
    19  
// ipv4Source caches the local address and interface index on which an IPv4
// packet was received. receive4 fills it from the IP_PKTINFO control message
// (Spec_dst / Ifindex), and send4 passes it back as IP_PKTINFO so replies leave
// from the same local address.
//
// In this file it is only reached through LinuxSocketEndpoint.src4, which
// reinterprets the endpoint's raw src storage via unsafe, so the field layout
// is load-bearing.
type ipv4Source struct {
	// Src is the local IPv4 address (pktinfo Spec_dst).
	Src [4]byte
	// Ifindex is the interface index reported by IP_PKTINFO.
	Ifindex int32
}

// ipv6Source caches the local IPv6 address on which a packet was received.
// receive6 fills it from IPV6_PKTINFO and send6 uses it as the source address.
// Like ipv4Source, it overlays LinuxSocketEndpoint.src via unsafe (see src6).
type ipv6Source struct {
	src [16]byte
	// ifindex belongs in dst.ZoneId
	// (receive6 stores the IPV6_PKTINFO ifindex in dst6().ZoneId and send6 reads it from there).
}
    29  
// LinuxSocketEndpoint is the Endpoint implementation used by LinuxSocketBind.
//
// The remote address (dst) and the cached local source (src) are kept as raw
// byte arrays. They are reinterpreted via unsafe as the IPv4 or IPv6 variant
// selected by isV6 (see dst4/dst6/src4/src6):
//   - dst holds a unix.SockaddrInet4 or unix.SockaddrInet6 and is sized for the
//     SockaddrInet6.
//   - src holds an ipv4Source or ipv6Source and is sized for the larger ipv6Source.
//
// mu is taken by send4/send6 around their sendmsg calls. It is not taken by
// ClearSrc, ClearDst, or the receive path.
//
// NOTE(review): the unsafe casts need dst/src to be suitably aligned for the
// overlaid structs. That currently follows from the field order (mu precedes
// them), so keep the order when editing.
type LinuxSocketEndpoint struct {
	mu   sync.Mutex
	dst  [unsafe.Sizeof(unix.SockaddrInet6{})]byte
	src  [unsafe.Sizeof(ipv6Source{})]byte
	isV6 bool
}
    36  
    37  func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source         { return endpoint.src4() }
    38  func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
    39  func (endpoint *LinuxSocketEndpoint) IsV6() bool                { return endpoint.isV6 }
    40  
    41  func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
    42  	return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
    43  }
    44  
    45  func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
    46  	return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
    47  }
    48  
    49  func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
    50  	return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
    51  }
    52  
    53  func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
    54  	return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
    55  }
    56  
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
// "Sticky" means replies go out from the local address/interface that the peer's
// packets arrived on. receive4/receive6 record that via IP_PKTINFO/IPV6_PKTINFO,
// and send4/send6 pass it back as a control message.
type LinuxSocketBind struct {
	// mu guards sock4 and sock6 and the associated fds.
	// As long as someone holds mu (read or write), the associated fds are valid.
	// A value of -1 means the corresponding socket is not open.
	mu    sync.RWMutex
	sock4 int
	sock6 int
}
    65  
    66  func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
    67  func NewDefaultBind() Bind     { return NewLinuxSocketBind() }
    68  
    69  var (
    70  	_ Endpoint = (*LinuxSocketEndpoint)(nil)
    71  	_ Bind     = (*LinuxSocketBind)(nil)
    72  )
    73  
    74  func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
    75  	var end LinuxSocketEndpoint
    76  	e, err := netip.ParseAddrPort(s)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	if e.Addr().Is4() {
    82  		dst := end.dst4()
    83  		end.isV6 = false
    84  		dst.Port = int(e.Port())
    85  		dst.Addr = e.Addr().As4()
    86  		end.ClearSrc()
    87  		return &end, nil
    88  	}
    89  
    90  	if e.Addr().Is6() {
    91  		zone, err := zoneToUint32(e.Addr().Zone())
    92  		if err != nil {
    93  			return nil, err
    94  		}
    95  		dst := end.dst6()
    96  		end.isV6 = true
    97  		dst.Port = int(e.Port())
    98  		dst.ZoneId = zone
    99  		dst.Addr = e.Addr().As16()
   100  		end.ClearSrc()
   101  		return &end, nil
   102  	}
   103  
   104  	return nil, errors.New("invalid IP address")
   105  }
   106  
   107  func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
   108  	bind.mu.Lock()
   109  	defer bind.mu.Unlock()
   110  
   111  	var err error
   112  	var newPort uint16
   113  	var tries int
   114  
   115  	if bind.sock4 != -1 || bind.sock6 != -1 {
   116  		return nil, 0, ErrBindAlreadyOpen
   117  	}
   118  
   119  	originalPort := port
   120  
   121  again:
   122  	port = originalPort
   123  	var sock4, sock6 int
   124  	// Attempt ipv6 bind, update port if successful.
   125  	sock6, newPort, err = create6(port)
   126  	if err != nil {
   127  		if !errors.Is(err, syscall.EAFNOSUPPORT) {
   128  			return nil, 0, err
   129  		}
   130  	} else {
   131  		port = newPort
   132  	}
   133  
   134  	// Attempt ipv4 bind, update port if successful.
   135  	sock4, newPort, err = create4(port)
   136  	if err != nil {
   137  		if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
   138  			unix.Close(sock6)
   139  			tries++
   140  			goto again
   141  		}
   142  		if !errors.Is(err, syscall.EAFNOSUPPORT) {
   143  			unix.Close(sock6)
   144  			return nil, 0, err
   145  		}
   146  	} else {
   147  		port = newPort
   148  	}
   149  
   150  	var fns []ReceiveFunc
   151  	if sock4 != -1 {
   152  		bind.sock4 = sock4
   153  		fns = append(fns, bind.receiveIPv4)
   154  	}
   155  	if sock6 != -1 {
   156  		bind.sock6 = sock6
   157  		fns = append(fns, bind.receiveIPv6)
   158  	}
   159  	if len(fns) == 0 {
   160  		return nil, 0, syscall.EAFNOSUPPORT
   161  	}
   162  	return fns, port, nil
   163  }
   164  
   165  func (bind *LinuxSocketBind) SetMark(value uint32) error {
   166  	bind.mu.RLock()
   167  	defer bind.mu.RUnlock()
   168  
   169  	if bind.sock6 != -1 {
   170  		err := unix.SetsockoptInt(
   171  			bind.sock6,
   172  			unix.SOL_SOCKET,
   173  			unix.SO_MARK,
   174  			int(value),
   175  		)
   176  		if err != nil {
   177  			return err
   178  		}
   179  	}
   180  
   181  	if bind.sock4 != -1 {
   182  		err := unix.SetsockoptInt(
   183  			bind.sock4,
   184  			unix.SOL_SOCKET,
   185  			unix.SO_MARK,
   186  			int(value),
   187  		)
   188  		if err != nil {
   189  			return err
   190  		}
   191  	}
   192  
   193  	return nil
   194  }
   195  
   196  func (bind *LinuxSocketBind) Close() error {
   197  	// Take a readlock to shut down the sockets...
   198  	bind.mu.RLock()
   199  	if bind.sock6 != -1 {
   200  		unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
   201  	}
   202  	if bind.sock4 != -1 {
   203  		unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
   204  	}
   205  	bind.mu.RUnlock()
   206  	// ...and a write lock to close the fd.
   207  	// This ensures that no one else is using the fd.
   208  	bind.mu.Lock()
   209  	defer bind.mu.Unlock()
   210  	var err1, err2 error
   211  	if bind.sock6 != -1 {
   212  		err1 = unix.Close(bind.sock6)
   213  		bind.sock6 = -1
   214  	}
   215  	if bind.sock4 != -1 {
   216  		err2 = unix.Close(bind.sock4)
   217  		bind.sock4 = -1
   218  	}
   219  
   220  	if err1 != nil {
   221  		return err1
   222  	}
   223  	return err2
   224  }
   225  
   226  func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
   227  	bind.mu.RLock()
   228  	defer bind.mu.RUnlock()
   229  	if bind.sock4 == -1 {
   230  		return 0, nil, net.ErrClosed
   231  	}
   232  	var end LinuxSocketEndpoint
   233  	n, err := receive4(bind.sock4, buf, &end)
   234  	return n, &end, err
   235  }
   236  
   237  func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
   238  	bind.mu.RLock()
   239  	defer bind.mu.RUnlock()
   240  	if bind.sock6 == -1 {
   241  		return 0, nil, net.ErrClosed
   242  	}
   243  	var end LinuxSocketEndpoint
   244  	n, err := receive6(bind.sock6, buf, &end)
   245  	return n, &end, err
   246  }
   247  
   248  func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
   249  	nend, ok := end.(*LinuxSocketEndpoint)
   250  	if !ok {
   251  		return ErrWrongEndpointType
   252  	}
   253  	bind.mu.RLock()
   254  	defer bind.mu.RUnlock()
   255  	if !nend.isV6 {
   256  		if bind.sock4 == -1 {
   257  			return net.ErrClosed
   258  		}
   259  		return send4(bind.sock4, nend, buff)
   260  	} else {
   261  		if bind.sock6 == -1 {
   262  			return net.ErrClosed
   263  		}
   264  		return send6(bind.sock6, nend, buff)
   265  	}
   266  }
   267  
   268  func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
   269  	if !end.isV6 {
   270  		return netip.AddrFrom4(end.src4().Src)
   271  	} else {
   272  		return netip.AddrFrom16(end.src6().src)
   273  	}
   274  }
   275  
   276  func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
   277  	if !end.isV6 {
   278  		return netip.AddrFrom4(end.dst4().Addr)
   279  	} else {
   280  		return netip.AddrFrom16(end.dst6().Addr)
   281  	}
   282  }
   283  
// DstToBytes returns the raw bytes of the destination sockaddr. They run from
// the start of the unix.SockaddrInet4/SockaddrInet6 value through the end of
// its Addr field.
//
// The returned slice aliases the endpoint's internal dst storage (no copy is
// made), so writing to it mutates the endpoint.
// NOTE(review): the bytes depend on x/sys/unix's struct layout (e.g. fields
// such as Port that precede Addr, stored in native int width/endianness).
// Confirm callers treat the result only as an opaque key.
func (end *LinuxSocketEndpoint) DstToBytes() []byte {
	if !end.isV6 {
		return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
	} else {
		return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
	}
}
   291  
   292  func (end *LinuxSocketEndpoint) SrcToString() string {
   293  	return end.SrcIP().String()
   294  }
   295  
   296  func (end *LinuxSocketEndpoint) DstToString() string {
   297  	var port int
   298  	if !end.isV6 {
   299  		port = end.dst4().Port
   300  	} else {
   301  		port = end.dst6().Port
   302  	}
   303  	return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
   304  }
   305  
   306  func (end *LinuxSocketEndpoint) ClearDst() {
   307  	for i := range end.dst {
   308  		end.dst[i] = 0
   309  	}
   310  }
   311  
   312  func (end *LinuxSocketEndpoint) ClearSrc() {
   313  	for i := range end.src {
   314  		end.src[i] = 0
   315  	}
   316  }
   317  
   318  func zoneToUint32(zone string) (uint32, error) {
   319  	if zone == "" {
   320  		return 0, nil
   321  	}
   322  	if intr, err := net.InterfaceByName(zone); err == nil {
   323  		return uint32(intr.Index), nil
   324  	}
   325  	n, err := strconv.ParseUint(zone, 10, 32)
   326  	return uint32(n), err
   327  }
   328  
   329  func create4(port uint16) (int, uint16, error) {
   330  	// create socket
   331  
   332  	fd, err := unix.Socket(
   333  		unix.AF_INET,
   334  		unix.SOCK_DGRAM,
   335  		0,
   336  	)
   337  	if err != nil {
   338  		return -1, 0, err
   339  	}
   340  
   341  	addr := unix.SockaddrInet4{
   342  		Port: int(port),
   343  	}
   344  
   345  	// set sockopts and bind
   346  
   347  	if err := func() error {
   348  		if err := unix.SetsockoptInt(
   349  			fd,
   350  			unix.IPPROTO_IP,
   351  			unix.IP_PKTINFO,
   352  			1,
   353  		); err != nil {
   354  			return err
   355  		}
   356  
   357  		return unix.Bind(fd, &addr)
   358  	}(); err != nil {
   359  		unix.Close(fd)
   360  		return -1, 0, err
   361  	}
   362  
   363  	sa, err := unix.Getsockname(fd)
   364  	if err == nil {
   365  		addr.Port = sa.(*unix.SockaddrInet4).Port
   366  	}
   367  
   368  	return fd, uint16(addr.Port), err
   369  }
   370  
   371  func create6(port uint16) (int, uint16, error) {
   372  	// create socket
   373  
   374  	fd, err := unix.Socket(
   375  		unix.AF_INET6,
   376  		unix.SOCK_DGRAM,
   377  		0,
   378  	)
   379  	if err != nil {
   380  		return -1, 0, err
   381  	}
   382  
   383  	// set sockopts and bind
   384  
   385  	addr := unix.SockaddrInet6{
   386  		Port: int(port),
   387  	}
   388  
   389  	if err := func() error {
   390  		if err := unix.SetsockoptInt(
   391  			fd,
   392  			unix.IPPROTO_IPV6,
   393  			unix.IPV6_RECVPKTINFO,
   394  			1,
   395  		); err != nil {
   396  			return err
   397  		}
   398  
   399  		if err := unix.SetsockoptInt(
   400  			fd,
   401  			unix.IPPROTO_IPV6,
   402  			unix.IPV6_V6ONLY,
   403  			1,
   404  		); err != nil {
   405  			return err
   406  		}
   407  
   408  		return unix.Bind(fd, &addr)
   409  	}(); err != nil {
   410  		unix.Close(fd)
   411  		return -1, 0, err
   412  	}
   413  
   414  	sa, err := unix.Getsockname(fd)
   415  	if err == nil {
   416  		addr.Port = sa.(*unix.SockaddrInet6).Port
   417  	}
   418  
   419  	return fd, uint16(addr.Port), err
   420  }
   421  
   422  func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
   423  	// construct message header
   424  
   425  	cmsg := struct {
   426  		cmsghdr unix.Cmsghdr
   427  		pktinfo unix.Inet4Pktinfo
   428  	}{
   429  		unix.Cmsghdr{
   430  			Level: unix.IPPROTO_IP,
   431  			Type:  unix.IP_PKTINFO,
   432  			Len:   unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
   433  		},
   434  		unix.Inet4Pktinfo{
   435  			Spec_dst: end.src4().Src,
   436  			Ifindex:  end.src4().Ifindex,
   437  		},
   438  	}
   439  
   440  	end.mu.Lock()
   441  	_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
   442  	end.mu.Unlock()
   443  
   444  	if err == nil {
   445  		return nil
   446  	}
   447  
   448  	// clear src and retry
   449  
   450  	if err == unix.EINVAL {
   451  		end.ClearSrc()
   452  		cmsg.pktinfo = unix.Inet4Pktinfo{}
   453  		end.mu.Lock()
   454  		_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
   455  		end.mu.Unlock()
   456  	}
   457  
   458  	return err
   459  }
   460  
   461  func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
   462  	// construct message header
   463  
   464  	cmsg := struct {
   465  		cmsghdr unix.Cmsghdr
   466  		pktinfo unix.Inet6Pktinfo
   467  	}{
   468  		unix.Cmsghdr{
   469  			Level: unix.IPPROTO_IPV6,
   470  			Type:  unix.IPV6_PKTINFO,
   471  			Len:   unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
   472  		},
   473  		unix.Inet6Pktinfo{
   474  			Addr:    end.src6().src,
   475  			Ifindex: end.dst6().ZoneId,
   476  		},
   477  	}
   478  
   479  	if cmsg.pktinfo.Addr == [16]byte{} {
   480  		cmsg.pktinfo.Ifindex = 0
   481  	}
   482  
   483  	end.mu.Lock()
   484  	_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
   485  	end.mu.Unlock()
   486  
   487  	if err == nil {
   488  		return nil
   489  	}
   490  
   491  	// clear src and retry
   492  
   493  	if err == unix.EINVAL {
   494  		end.ClearSrc()
   495  		cmsg.pktinfo = unix.Inet6Pktinfo{}
   496  		end.mu.Lock()
   497  		_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
   498  		end.mu.Unlock()
   499  	}
   500  
   501  	return err
   502  }
   503  
   504  func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
   505  	// construct message header
   506  
   507  	var cmsg struct {
   508  		cmsghdr unix.Cmsghdr
   509  		pktinfo unix.Inet4Pktinfo
   510  	}
   511  
   512  	size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
   513  	if err != nil {
   514  		return 0, err
   515  	}
   516  	end.isV6 = false
   517  
   518  	if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
   519  		*end.dst4() = *newDst4
   520  	}
   521  
   522  	// update source cache
   523  
   524  	if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
   525  		cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
   526  		cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
   527  		end.src4().Src = cmsg.pktinfo.Spec_dst
   528  		end.src4().Ifindex = cmsg.pktinfo.Ifindex
   529  	}
   530  
   531  	return size, nil
   532  }
   533  
   534  func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
   535  	// construct message header
   536  
   537  	var cmsg struct {
   538  		cmsghdr unix.Cmsghdr
   539  		pktinfo unix.Inet6Pktinfo
   540  	}
   541  
   542  	size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
   543  	if err != nil {
   544  		return 0, err
   545  	}
   546  	end.isV6 = true
   547  
   548  	if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
   549  		*end.dst6() = *newDst6
   550  	}
   551  
   552  	// update source cache
   553  
   554  	if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
   555  		cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
   556  		cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
   557  		end.src6().src = cmsg.pktinfo.Addr
   558  		end.dst6().ZoneId = cmsg.pktinfo.Ifindex
   559  	}
   560  
   561  	return size, nil
   562  }