github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/conn/bind_linux.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conn
     7  
     8  import (
     9  	"errors"
    10  	"net"
    11  	"strconv"
    12  	"sync"
    13  	"syscall"
    14  	"unsafe"
    15  
    16  	"golang.org/x/sys/unix"
    17  )
    18  
// ipv4Source is the cached local source of an IPv4 endpoint.
//
// receive4 fills Src from Inet4Pktinfo.Spec_dst and Ifindex from
// Inet4Pktinfo.Ifindex; send4 hands both back to the kernel in the
// IP_PKTINFO control message it attaches to outgoing packets.
// It is stored inside LinuxSocketEndpoint.src and accessed through src4().
type ipv4Source struct {
	Src     [4]byte
	Ifindex int32
}
    23  
// ipv6Source is the cached local source address of an IPv6 endpoint.
//
// receive6 fills src from Inet6Pktinfo.Addr and send6 passes it back in the
// IPV6_PKTINFO control message. It is stored inside LinuxSocketEndpoint.src
// and accessed through src6().
type ipv6Source struct {
	src [16]byte
	// There is no ifindex field here: for IPv6 the interface index is kept in
	// the destination sockaddr's ZoneId (see receive6/send6).
}
    28  
// LinuxSocketEndpoint holds a peer's destination address together with the
// cached local source used to reach it.
//
// dst and src are raw byte buffers that are reinterpreted through unsafe
// pointer casts: dst as unix.SockaddrInet4 or unix.SockaddrInet6 (dst4/dst6),
// src as ipv4Source or ipv6Source (src4/src6). Each buffer is sized for the
// IPv6 variant. isV6 records which interpretation is currently valid.
type LinuxSocketEndpoint struct {
	// mu is held by send4/send6 around the SendmsgN call that reads dst.
	mu   sync.Mutex
	dst  [unsafe.Sizeof(unix.SockaddrInet6{})]byte
	src  [unsafe.Sizeof(ipv6Source{})]byte
	isV6 bool
}
    35  
    36  func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source         { return endpoint.src4() }
    37  func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
    38  func (endpoint *LinuxSocketEndpoint) IsV6() bool                { return endpoint.isV6 }
    39  
    40  func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
    41  	return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
    42  }
    43  
    44  func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
    45  	return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
    46  }
    47  
    48  func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
    49  	return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
    50  }
    51  
    52  func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
    53  	return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
    54  }
    55  
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
//
// It owns one IPv4 and one IPv6 UDP socket. A value of -1 in sock4 or sock6
// means that socket is not open.
type LinuxSocketBind struct {
	// mu guards sock4 and sock6 and the associated fds.
	// As long as someone holds mu (read or write), the associated fds are valid.
	// Close only closes the fds while holding the write lock.
	mu    sync.RWMutex
	sock4 int
	sock6 int
}
    64  
    65  func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
    66  func NewDefaultBind() Bind     { return NewLinuxSocketBind() }
    67  
// Compile-time assertions that the Linux types implement the package's
// Endpoint and Bind interfaces.
var _ Endpoint = (*LinuxSocketEndpoint)(nil)
var _ Bind = (*LinuxSocketBind)(nil)
    70  
    71  func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
    72  	var end LinuxSocketEndpoint
    73  	addr, err := parseEndpoint(s)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	ipv4 := addr.IP.To4()
    79  	if ipv4 != nil {
    80  		dst := end.dst4()
    81  		end.isV6 = false
    82  		dst.Port = addr.Port
    83  		copy(dst.Addr[:], ipv4)
    84  		end.ClearSrc()
    85  		return &end, nil
    86  	}
    87  
    88  	ipv6 := addr.IP.To16()
    89  	if ipv6 != nil {
    90  		zone, err := zoneToUint32(addr.Zone)
    91  		if err != nil {
    92  			return nil, err
    93  		}
    94  		dst := end.dst6()
    95  		end.isV6 = true
    96  		dst.Port = addr.Port
    97  		dst.ZoneId = zone
    98  		copy(dst.Addr[:], ipv6[:])
    99  		end.ClearSrc()
   100  		return &end, nil
   101  	}
   102  
   103  	return nil, errors.New("invalid IP address")
   104  }
   105  
// Open creates and binds the IPv6 and IPv4 UDP sockets on port and returns
// one ReceiveFunc per socket that was opened, plus the port actually bound.
//
// It returns ErrBindAlreadyOpen if either socket is already open. An address
// family that is unsupported (EAFNOSUPPORT) is skipped; if neither family is
// available, EAFNOSUPPORT is returned. Any other socket error aborts the call.
func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
	bind.mu.Lock()
	defer bind.mu.Unlock()

	var err error
	var newPort uint16
	var tries int

	if bind.sock4 != -1 || bind.sock6 != -1 {
		return nil, 0, ErrBindAlreadyOpen
	}

	// Remember the requested port so that every retry starts from it rather
	// than from the port a previous attempt ended up with.
	originalPort := port

again:
	port = originalPort
	var sock4, sock6 int
	// Attempt ipv6 bind, update port if successful.
	sock6, newPort, err = create6(port)
	if err != nil {
		// A missing IPv6 stack is tolerated; anything else is fatal.
		if !errors.Is(err, syscall.EAFNOSUPPORT) {
			return nil, 0, err
		}
	} else {
		port = newPort
	}

	// Attempt ipv4 bind, update port if successful.
	// port now carries whatever the IPv6 socket bound to, so both sockets
	// end up on the same port.
	sock4, newPort, err = create4(port)
	if err != nil {
		// The caller asked for any port (0) and the one obtained above is
		// already taken on IPv4: drop the IPv6 socket and start over, giving
		// up after 100 retries.
		if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
			unix.Close(sock6)
			tries++
			goto again
		}
		if !errors.Is(err, syscall.EAFNOSUPPORT) {
			unix.Close(sock6)
			return nil, 0, err
		}
	} else {
		port = newPort
	}

	// Publish only the sockets that were actually created.
	var fns []ReceiveFunc
	if sock4 != -1 {
		bind.sock4 = sock4
		fns = append(fns, bind.receiveIPv4)
	}
	if sock6 != -1 {
		bind.sock6 = sock6
		fns = append(fns, bind.receiveIPv6)
	}
	if len(fns) == 0 {
		return nil, 0, syscall.EAFNOSUPPORT
	}
	return fns, port, nil
}
   163  
   164  func (bind *LinuxSocketBind) SetMark(value uint32) error {
   165  	bind.mu.RLock()
   166  	defer bind.mu.RUnlock()
   167  
   168  	if bind.sock6 != -1 {
   169  		err := unix.SetsockoptInt(
   170  			bind.sock6,
   171  			unix.SOL_SOCKET,
   172  			unix.SO_MARK,
   173  			int(value),
   174  		)
   175  
   176  		if err != nil {
   177  			return err
   178  		}
   179  	}
   180  
   181  	if bind.sock4 != -1 {
   182  		err := unix.SetsockoptInt(
   183  			bind.sock4,
   184  			unix.SOL_SOCKET,
   185  			unix.SO_MARK,
   186  			int(value),
   187  		)
   188  
   189  		if err != nil {
   190  			return err
   191  		}
   192  	}
   193  
   194  	return nil
   195  }
   196  
   197  func (bind *LinuxSocketBind) Close() error {
   198  	// Take a readlock to shut down the sockets...
   199  	bind.mu.RLock()
   200  	if bind.sock6 != -1 {
   201  		unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
   202  	}
   203  	if bind.sock4 != -1 {
   204  		unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
   205  	}
   206  	bind.mu.RUnlock()
   207  	// ...and a write lock to close the fd.
   208  	// This ensures that no one else is using the fd.
   209  	bind.mu.Lock()
   210  	defer bind.mu.Unlock()
   211  	var err1, err2 error
   212  	if bind.sock6 != -1 {
   213  		err1 = unix.Close(bind.sock6)
   214  		bind.sock6 = -1
   215  	}
   216  	if bind.sock4 != -1 {
   217  		err2 = unix.Close(bind.sock4)
   218  		bind.sock4 = -1
   219  	}
   220  
   221  	if err1 != nil {
   222  		return err1
   223  	}
   224  	return err2
   225  }
   226  
   227  func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
   228  	bind.mu.RLock()
   229  	defer bind.mu.RUnlock()
   230  	if bind.sock4 == -1 {
   231  		return 0, nil, net.ErrClosed
   232  	}
   233  	var end LinuxSocketEndpoint
   234  	n, err := receive4(bind.sock4, buf, &end)
   235  	return n, &end, err
   236  }
   237  
   238  func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
   239  	bind.mu.RLock()
   240  	defer bind.mu.RUnlock()
   241  	if bind.sock6 == -1 {
   242  		return 0, nil, net.ErrClosed
   243  	}
   244  	var end LinuxSocketEndpoint
   245  	n, err := receive6(bind.sock6, buf, &end)
   246  	return n, &end, err
   247  }
   248  
   249  func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
   250  	nend, ok := end.(*LinuxSocketEndpoint)
   251  	if !ok {
   252  		return ErrWrongEndpointType
   253  	}
   254  	bind.mu.RLock()
   255  	defer bind.mu.RUnlock()
   256  	if !nend.isV6 {
   257  		if bind.sock4 == -1 {
   258  			return net.ErrClosed
   259  		}
   260  		return send4(bind.sock4, nend, buff)
   261  	} else {
   262  		if bind.sock6 == -1 {
   263  			return net.ErrClosed
   264  		}
   265  		return send6(bind.sock6, nend, buff)
   266  	}
   267  }
   268  
   269  func (end *LinuxSocketEndpoint) SrcIP() net.IP {
   270  	if !end.isV6 {
   271  		return net.IPv4(
   272  			end.src4().Src[0],
   273  			end.src4().Src[1],
   274  			end.src4().Src[2],
   275  			end.src4().Src[3],
   276  		)
   277  	} else {
   278  		return end.src6().src[:]
   279  	}
   280  }
   281  
   282  func (end *LinuxSocketEndpoint) DstIP() net.IP {
   283  	if !end.isV6 {
   284  		return net.IPv4(
   285  			end.dst4().Addr[0],
   286  			end.dst4().Addr[1],
   287  			end.dst4().Addr[2],
   288  			end.dst4().Addr[3],
   289  		)
   290  	} else {
   291  		return end.dst6().Addr[:]
   292  	}
   293  }
   294  
   295  func (end *LinuxSocketEndpoint) DstToBytes() []byte {
   296  	if !end.isV6 {
   297  		return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
   298  	} else {
   299  		return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
   300  	}
   301  }
   302  
   303  func (end *LinuxSocketEndpoint) SrcToString() string {
   304  	return end.SrcIP().String()
   305  }
   306  
   307  func (end *LinuxSocketEndpoint) DstToString() string {
   308  	var udpAddr net.UDPAddr
   309  	udpAddr.IP = end.DstIP()
   310  	if !end.isV6 {
   311  		udpAddr.Port = end.dst4().Port
   312  	} else {
   313  		udpAddr.Port = end.dst6().Port
   314  	}
   315  	return udpAddr.String()
   316  }
   317  
   318  func (end *LinuxSocketEndpoint) ClearDst() {
   319  	for i := range end.dst {
   320  		end.dst[i] = 0
   321  	}
   322  }
   323  
   324  func (end *LinuxSocketEndpoint) ClearSrc() {
   325  	for i := range end.src {
   326  		end.src[i] = 0
   327  	}
   328  }
   329  
   330  func zoneToUint32(zone string) (uint32, error) {
   331  	if zone == "" {
   332  		return 0, nil
   333  	}
   334  	if intr, err := net.InterfaceByName(zone); err == nil {
   335  		return uint32(intr.Index), nil
   336  	}
   337  	n, err := strconv.ParseUint(zone, 10, 32)
   338  	return uint32(n), err
   339  }
   340  
   341  func create4(port uint16) (int, uint16, error) {
   342  
   343  	// create socket
   344  
   345  	fd, err := unix.Socket(
   346  		unix.AF_INET,
   347  		unix.SOCK_DGRAM,
   348  		0,
   349  	)
   350  
   351  	if err != nil {
   352  		return -1, 0, err
   353  	}
   354  
   355  	addr := unix.SockaddrInet4{
   356  		Port: int(port),
   357  	}
   358  
   359  	// set sockopts and bind
   360  
   361  	if err := func() error {
   362  		if err := unix.SetsockoptInt(
   363  			fd,
   364  			unix.IPPROTO_IP,
   365  			unix.IP_PKTINFO,
   366  			1,
   367  		); err != nil {
   368  			return err
   369  		}
   370  
   371  		return unix.Bind(fd, &addr)
   372  	}(); err != nil {
   373  		unix.Close(fd)
   374  		return -1, 0, err
   375  	}
   376  
   377  	sa, err := unix.Getsockname(fd)
   378  	if err == nil {
   379  		addr.Port = sa.(*unix.SockaddrInet4).Port
   380  	}
   381  
   382  	return fd, uint16(addr.Port), err
   383  }
   384  
   385  func create6(port uint16) (int, uint16, error) {
   386  
   387  	// create socket
   388  
   389  	fd, err := unix.Socket(
   390  		unix.AF_INET6,
   391  		unix.SOCK_DGRAM,
   392  		0,
   393  	)
   394  
   395  	if err != nil {
   396  		return -1, 0, err
   397  	}
   398  
   399  	// set sockopts and bind
   400  
   401  	addr := unix.SockaddrInet6{
   402  		Port: int(port),
   403  	}
   404  
   405  	if err := func() error {
   406  		if err := unix.SetsockoptInt(
   407  			fd,
   408  			unix.IPPROTO_IPV6,
   409  			unix.IPV6_RECVPKTINFO,
   410  			1,
   411  		); err != nil {
   412  			return err
   413  		}
   414  
   415  		if err := unix.SetsockoptInt(
   416  			fd,
   417  			unix.IPPROTO_IPV6,
   418  			unix.IPV6_V6ONLY,
   419  			1,
   420  		); err != nil {
   421  			return err
   422  		}
   423  
   424  		return unix.Bind(fd, &addr)
   425  
   426  	}(); err != nil {
   427  		unix.Close(fd)
   428  		return -1, 0, err
   429  	}
   430  
   431  	sa, err := unix.Getsockname(fd)
   432  	if err == nil {
   433  		addr.Port = sa.(*unix.SockaddrInet6).Port
   434  	}
   435  
   436  	return fd, uint16(addr.Port), err
   437  }
   438  
   439  func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
   440  
   441  	// construct message header
   442  
   443  	cmsg := struct {
   444  		cmsghdr unix.Cmsghdr
   445  		pktinfo unix.Inet4Pktinfo
   446  	}{
   447  		unix.Cmsghdr{
   448  			Level: unix.IPPROTO_IP,
   449  			Type:  unix.IP_PKTINFO,
   450  			Len:   unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
   451  		},
   452  		unix.Inet4Pktinfo{
   453  			Spec_dst: end.src4().Src,
   454  			Ifindex:  end.src4().Ifindex,
   455  		},
   456  	}
   457  
   458  	end.mu.Lock()
   459  	_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
   460  	end.mu.Unlock()
   461  
   462  	if err == nil {
   463  		return nil
   464  	}
   465  
   466  	// clear src and retry
   467  
   468  	if err == unix.EINVAL {
   469  		end.ClearSrc()
   470  		cmsg.pktinfo = unix.Inet4Pktinfo{}
   471  		end.mu.Lock()
   472  		_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
   473  		end.mu.Unlock()
   474  	}
   475  
   476  	return err
   477  }
   478  
   479  func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
   480  
   481  	// construct message header
   482  
   483  	cmsg := struct {
   484  		cmsghdr unix.Cmsghdr
   485  		pktinfo unix.Inet6Pktinfo
   486  	}{
   487  		unix.Cmsghdr{
   488  			Level: unix.IPPROTO_IPV6,
   489  			Type:  unix.IPV6_PKTINFO,
   490  			Len:   unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
   491  		},
   492  		unix.Inet6Pktinfo{
   493  			Addr:    end.src6().src,
   494  			Ifindex: end.dst6().ZoneId,
   495  		},
   496  	}
   497  
   498  	if cmsg.pktinfo.Addr == [16]byte{} {
   499  		cmsg.pktinfo.Ifindex = 0
   500  	}
   501  
   502  	end.mu.Lock()
   503  	_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
   504  	end.mu.Unlock()
   505  
   506  	if err == nil {
   507  		return nil
   508  	}
   509  
   510  	// clear src and retry
   511  
   512  	if err == unix.EINVAL {
   513  		end.ClearSrc()
   514  		cmsg.pktinfo = unix.Inet6Pktinfo{}
   515  		end.mu.Lock()
   516  		_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
   517  		end.mu.Unlock()
   518  	}
   519  
   520  	return err
   521  }
   522  
   523  func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
   524  
   525  	// construct message header
   526  
   527  	var cmsg struct {
   528  		cmsghdr unix.Cmsghdr
   529  		pktinfo unix.Inet4Pktinfo
   530  	}
   531  
   532  	size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
   533  
   534  	if err != nil {
   535  		return 0, err
   536  	}
   537  	end.isV6 = false
   538  
   539  	if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
   540  		*end.dst4() = *newDst4
   541  	}
   542  
   543  	// update source cache
   544  
   545  	if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
   546  		cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
   547  		cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
   548  		end.src4().Src = cmsg.pktinfo.Spec_dst
   549  		end.src4().Ifindex = cmsg.pktinfo.Ifindex
   550  	}
   551  
   552  	return size, nil
   553  }
   554  
   555  func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
   556  
   557  	// construct message header
   558  
   559  	var cmsg struct {
   560  		cmsghdr unix.Cmsghdr
   561  		pktinfo unix.Inet6Pktinfo
   562  	}
   563  
   564  	size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
   565  
   566  	if err != nil {
   567  		return 0, err
   568  	}
   569  	end.isV6 = true
   570  
   571  	if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
   572  		*end.dst6() = *newDst6
   573  	}
   574  
   575  	// update source cache
   576  
   577  	if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
   578  		cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
   579  		cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
   580  		end.src6().src = cmsg.pktinfo.Addr
   581  		end.dst6().ZoneId = cmsg.pktinfo.Ifindex
   582  	}
   583  
   584  	return size, nil
   585  }