github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/device/sticky_linux.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   *
     5   * This implements userspace semantics of "sticky sockets", modeled after
     6   * WireGuard's kernelspace implementation. This is more or less a straight port
     7   * of the sticky-sockets.c example code:
     8   * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
     9   *
    10   * Currently there is no way to achieve this within the net package:
    11   * See e.g. https://github.com/golang/go/issues/17930
    12   * So this code remains platform dependent.
    13   */
    14  
    15  package device
    16  
    17  import (
    18  	"sync"
    19  	"unsafe"
    20  
    21  	"github.com/tailscale/wireguard-go/conn"
    22  	"github.com/tailscale/wireguard-go/rwcancel"
    23  	"golang.org/x/sys/unix"
    24  )
    25  
    26  func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
    27  	if _, ok := bind.(*conn.LinuxSocketBind); !ok {
    28  		return nil, nil
    29  	}
    30  
    31  	netlinkSock, err := createNetlinkRouteSocket()
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  	netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
    36  	if err != nil {
    37  		unix.Close(netlinkSock)
    38  		return nil, err
    39  	}
    40  
    41  	go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
    42  
    43  	return netlinkCancel, nil
    44  }
    45  
    46  func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
    47  	type peerEndpointPtr struct {
    48  		peer     *Peer
    49  		endpoint *conn.Endpoint
    50  	}
    51  	var reqPeer map[uint32]peerEndpointPtr
    52  	var reqPeerLock sync.Mutex
    53  
    54  	defer netlinkCancel.Close()
    55  	defer unix.Close(netlinkSock)
    56  
    57  	for msg := make([]byte, 1<<16); ; {
    58  		var err error
    59  		var msgn int
    60  		for {
    61  			msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
    62  			if err == nil || !rwcancel.RetryAfterError(err) {
    63  				break
    64  			}
    65  			if !netlinkCancel.ReadyRead() {
    66  				return
    67  			}
    68  		}
    69  		if err != nil {
    70  			return
    71  		}
    72  
    73  		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
    74  
    75  			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
    76  
    77  			if uint(hdr.Len) > uint(len(remain)) {
    78  				break
    79  			}
    80  
    81  			switch hdr.Type {
    82  			case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
    83  				if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
    84  					if uint(len(remain)) < uint(hdr.Len) {
    85  						break
    86  					}
    87  					if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
    88  						attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
    89  						for {
    90  							if uint(len(attr)) < uint(unix.SizeofRtAttr) {
    91  								break
    92  							}
    93  							attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
    94  							if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
    95  								break
    96  							}
    97  							if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
    98  								ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
    99  								reqPeerLock.Lock()
   100  								if reqPeer == nil {
   101  									reqPeerLock.Unlock()
   102  									break
   103  								}
   104  								pePtr, ok := reqPeer[hdr.Seq]
   105  								reqPeerLock.Unlock()
   106  								if !ok {
   107  									break
   108  								}
   109  								pePtr.peer.Lock()
   110  								if &pePtr.peer.endpoint != pePtr.endpoint {
   111  									pePtr.peer.Unlock()
   112  									break
   113  								}
   114  								if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
   115  									pePtr.peer.Unlock()
   116  									break
   117  								}
   118  								pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
   119  								pePtr.peer.Unlock()
   120  							}
   121  							attr = attr[attrhdr.Len:]
   122  						}
   123  					}
   124  					break
   125  				}
   126  				reqPeerLock.Lock()
   127  				reqPeer = make(map[uint32]peerEndpointPtr)
   128  				reqPeerLock.Unlock()
   129  				go func() {
   130  					device.peers.RLock()
   131  					i := uint32(1)
   132  					for _, peer := range device.peers.keyMap {
   133  						peer.RLock()
   134  						if peer.endpoint == nil {
   135  							peer.RUnlock()
   136  							continue
   137  						}
   138  						nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
   139  						if nativeEP == nil {
   140  							peer.RUnlock()
   141  							continue
   142  						}
   143  						if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
   144  							peer.RUnlock()
   145  							break
   146  						}
   147  						nlmsg := struct {
   148  							hdr     unix.NlMsghdr
   149  							msg     unix.RtMsg
   150  							dsthdr  unix.RtAttr
   151  							dst     [4]byte
   152  							srchdr  unix.RtAttr
   153  							src     [4]byte
   154  							markhdr unix.RtAttr
   155  							mark    uint32
   156  						}{
   157  							unix.NlMsghdr{
   158  								Type:  uint16(unix.RTM_GETROUTE),
   159  								Flags: unix.NLM_F_REQUEST,
   160  								Seq:   i,
   161  							},
   162  							unix.RtMsg{
   163  								Family:  unix.AF_INET,
   164  								Dst_len: 32,
   165  								Src_len: 32,
   166  							},
   167  							unix.RtAttr{
   168  								Len:  8,
   169  								Type: unix.RTA_DST,
   170  							},
   171  							nativeEP.Dst4().Addr,
   172  							unix.RtAttr{
   173  								Len:  8,
   174  								Type: unix.RTA_SRC,
   175  							},
   176  							nativeEP.Src4().Src,
   177  							unix.RtAttr{
   178  								Len:  8,
   179  								Type: unix.RTA_MARK,
   180  							},
   181  							device.net.fwmark,
   182  						}
   183  						nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
   184  						reqPeerLock.Lock()
   185  						reqPeer[i] = peerEndpointPtr{
   186  							peer:     peer,
   187  							endpoint: &peer.endpoint,
   188  						}
   189  						reqPeerLock.Unlock()
   190  						peer.RUnlock()
   191  						i++
   192  						_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
   193  						if err != nil {
   194  							break
   195  						}
   196  					}
   197  					device.peers.RUnlock()
   198  				}()
   199  			}
   200  			remain = remain[hdr.Len:]
   201  		}
   202  	}
   203  }
   204  
   205  func createNetlinkRouteSocket() (int, error) {
   206  	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
   207  	if err != nil {
   208  		return -1, err
   209  	}
   210  	saddr := &unix.SockaddrNetlink{
   211  		Family: unix.AF_NETLINK,
   212  		Groups: unix.RTMGRP_IPV4_ROUTE,
   213  	}
   214  	err = unix.Bind(sock, saddr)
   215  	if err != nil {
   216  		unix.Close(sock)
   217  		return -1, err
   218  	}
   219  	return sock, nil
   220  }