github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/device/sticky_linux.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   *
     5   * This implements userspace semantics of "sticky sockets", modeled after
     6   * WireGuard's kernelspace implementation. This is more or less a straight port
     7   * of the sticky-sockets.c example code:
     8   * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
     9   *
    10   * Currently there is no way to achieve this within the net package:
    11   * See e.g. https://github.com/golang/go/issues/17930
    12   * So this code remains platform dependent.
    13   */
    14  
    15  package device
    16  
    17  import (
    18  	"sync"
    19  	"unsafe"
    20  
    21  	"github.com/sagernet/wireguard-go/conn"
    22  	"github.com/sagernet/wireguard-go/rwcancel"
    23  	"golang.org/x/sys/unix"
    24  )
    25  
    26  func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
    27  	if !conn.StdNetSupportsStickySockets {
    28  		return nil, nil
    29  	}
    30  	if _, ok := bind.(*conn.StdNetBind); !ok {
    31  		return nil, nil
    32  	}
    33  
    34  	netlinkSock, err := createNetlinkRouteSocket()
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  	netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
    39  	if err != nil {
    40  		unix.Close(netlinkSock)
    41  		return nil, err
    42  	}
    43  
    44  	go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
    45  
    46  	return netlinkCancel, nil
    47  }
    48  
    49  func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
    50  	type peerEndpointPtr struct {
    51  		peer     *Peer
    52  		endpoint *conn.Endpoint
    53  	}
    54  	var reqPeer map[uint32]peerEndpointPtr
    55  	var reqPeerLock sync.Mutex
    56  
    57  	defer netlinkCancel.Close()
    58  	defer unix.Close(netlinkSock)
    59  
    60  	for msg := make([]byte, 1<<16); ; {
    61  		var err error
    62  		var msgn int
    63  		for {
    64  			msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
    65  			if err == nil || !rwcancel.RetryAfterError(err) {
    66  				break
    67  			}
    68  			if !netlinkCancel.ReadyRead() {
    69  				return
    70  			}
    71  		}
    72  		if err != nil {
    73  			return
    74  		}
    75  
    76  		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
    77  
    78  			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
    79  
    80  			if uint(hdr.Len) > uint(len(remain)) {
    81  				break
    82  			}
    83  
    84  			switch hdr.Type {
    85  			case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
    86  				if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
    87  					if uint(len(remain)) < uint(hdr.Len) {
    88  						break
    89  					}
    90  					if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
    91  						attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
    92  						for {
    93  							if uint(len(attr)) < uint(unix.SizeofRtAttr) {
    94  								break
    95  							}
    96  							attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
    97  							if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
    98  								break
    99  							}
   100  							if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
   101  								ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
   102  								reqPeerLock.Lock()
   103  								if reqPeer == nil {
   104  									reqPeerLock.Unlock()
   105  									break
   106  								}
   107  								pePtr, ok := reqPeer[hdr.Seq]
   108  								reqPeerLock.Unlock()
   109  								if !ok {
   110  									break
   111  								}
   112  								pePtr.peer.Lock()
   113  								if &pePtr.peer.endpoint != pePtr.endpoint {
   114  									pePtr.peer.Unlock()
   115  									break
   116  								}
   117  								if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
   118  									pePtr.peer.Unlock()
   119  									break
   120  								}
   121  								pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc()
   122  								pePtr.peer.Unlock()
   123  							}
   124  							attr = attr[attrhdr.Len:]
   125  						}
   126  					}
   127  					break
   128  				}
   129  				reqPeerLock.Lock()
   130  				reqPeer = make(map[uint32]peerEndpointPtr)
   131  				reqPeerLock.Unlock()
   132  				go func() {
   133  					device.peers.RLock()
   134  					i := uint32(1)
   135  					for _, peer := range device.peers.keyMap {
   136  						peer.RLock()
   137  						if peer.endpoint == nil {
   138  							peer.RUnlock()
   139  							continue
   140  						}
   141  						nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint)
   142  						if nativeEP == nil {
   143  							peer.RUnlock()
   144  							continue
   145  						}
   146  						if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
   147  							peer.RUnlock()
   148  							break
   149  						}
   150  						nlmsg := struct {
   151  							hdr     unix.NlMsghdr
   152  							msg     unix.RtMsg
   153  							dsthdr  unix.RtAttr
   154  							dst     [4]byte
   155  							srchdr  unix.RtAttr
   156  							src     [4]byte
   157  							markhdr unix.RtAttr
   158  							mark    uint32
   159  						}{
   160  							unix.NlMsghdr{
   161  								Type:  uint16(unix.RTM_GETROUTE),
   162  								Flags: unix.NLM_F_REQUEST,
   163  								Seq:   i,
   164  							},
   165  							unix.RtMsg{
   166  								Family:  unix.AF_INET,
   167  								Dst_len: 32,
   168  								Src_len: 32,
   169  							},
   170  							unix.RtAttr{
   171  								Len:  8,
   172  								Type: unix.RTA_DST,
   173  							},
   174  							nativeEP.DstIP().As4(),
   175  							unix.RtAttr{
   176  								Len:  8,
   177  								Type: unix.RTA_SRC,
   178  							},
   179  							nativeEP.SrcIP().As4(),
   180  							unix.RtAttr{
   181  								Len:  8,
   182  								Type: unix.RTA_MARK,
   183  							},
   184  							device.net.fwmark,
   185  						}
   186  						nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
   187  						reqPeerLock.Lock()
   188  						reqPeer[i] = peerEndpointPtr{
   189  							peer:     peer,
   190  							endpoint: &peer.endpoint,
   191  						}
   192  						reqPeerLock.Unlock()
   193  						peer.RUnlock()
   194  						i++
   195  						_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
   196  						if err != nil {
   197  							break
   198  						}
   199  					}
   200  					device.peers.RUnlock()
   201  				}()
   202  			}
   203  			remain = remain[hdr.Len:]
   204  		}
   205  	}
   206  }
   207  
   208  func createNetlinkRouteSocket() (int, error) {
   209  	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
   210  	if err != nil {
   211  		return -1, err
   212  	}
   213  	saddr := &unix.SockaddrNetlink{
   214  		Family: unix.AF_NETLINK,
   215  		Groups: unix.RTMGRP_IPV4_ROUTE,
   216  	}
   217  	err = unix.Bind(sock, saddr)
   218  	if err != nil {
   219  		unix.Close(sock)
   220  		return -1, err
   221  	}
   222  	return sock, nil
   223  }