github.com/bepass-org/wireguard-go@v1.0.4-rc2.0.20240304192354-ebce6572bc24/device/sticky_linux.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   *
     5   * This implements userspace semantics of "sticky sockets", modeled after
     6   * WireGuard's kernelspace implementation. This is more or less a straight port
     7   * of the sticky-sockets.c example code:
     8   * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
     9   *
    10   * Currently there is no way to achieve this within the net package:
    11   * See e.g. https://github.com/golang/go/issues/17930
    12   * So this code remains platform dependent.
    13   */
    14  
    15  package device
    16  
    17  import (
    18  	"sync"
    19  	"unsafe"
    20  
    21  	"golang.org/x/sys/unix"
    22  
    23  	"github.com/bepass-org/wireguard-go/conn"
    24  	"github.com/bepass-org/wireguard-go/rwcancel"
    25  )
    26  
    27  func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
    28  	if !conn.StdNetSupportsStickySockets {
    29  		return nil, nil
    30  	}
    31  	if _, ok := bind.(*conn.StdNetBind); !ok {
    32  		return nil, nil
    33  	}
    34  
    35  	netlinkSock, err := createNetlinkRouteSocket()
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  	netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
    40  	if err != nil {
    41  		unix.Close(netlinkSock)
    42  		return nil, err
    43  	}
    44  
    45  	go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
    46  
    47  	return netlinkCancel, nil
    48  }
    49  
    50  func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
    51  	type peerEndpointPtr struct {
    52  		peer     *Peer
    53  		endpoint *conn.Endpoint
    54  	}
    55  	var reqPeer map[uint32]peerEndpointPtr
    56  	var reqPeerLock sync.Mutex
    57  
    58  	defer netlinkCancel.Close()
    59  	defer unix.Close(netlinkSock)
    60  
    61  	for msg := make([]byte, 1<<16); ; {
    62  		var err error
    63  		var msgn int
    64  		for {
    65  			msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
    66  			if err == nil || !rwcancel.RetryAfterError(err) {
    67  				break
    68  			}
    69  			if !netlinkCancel.ReadyRead() {
    70  				return
    71  			}
    72  		}
    73  		if err != nil {
    74  			return
    75  		}
    76  
    77  		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
    78  
    79  			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
    80  
    81  			if uint(hdr.Len) > uint(len(remain)) {
    82  				break
    83  			}
    84  
    85  			switch hdr.Type {
    86  			case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
    87  				if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
    88  					if uint(len(remain)) < uint(hdr.Len) {
    89  						break
    90  					}
    91  					if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
    92  						attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
    93  						for {
    94  							if uint(len(attr)) < uint(unix.SizeofRtAttr) {
    95  								break
    96  							}
    97  							attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
    98  							if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
    99  								break
   100  							}
   101  							if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
   102  								ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
   103  								reqPeerLock.Lock()
   104  								if reqPeer == nil {
   105  									reqPeerLock.Unlock()
   106  									break
   107  								}
   108  								pePtr, ok := reqPeer[hdr.Seq]
   109  								reqPeerLock.Unlock()
   110  								if !ok {
   111  									break
   112  								}
   113  								pePtr.peer.endpoint.Lock()
   114  								if &pePtr.peer.endpoint.val != pePtr.endpoint {
   115  									pePtr.peer.endpoint.Unlock()
   116  									break
   117  								}
   118  								if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
   119  									pePtr.peer.endpoint.Unlock()
   120  									break
   121  								}
   122  								pePtr.peer.endpoint.clearSrcOnTx = true
   123  								pePtr.peer.endpoint.Unlock()
   124  							}
   125  							attr = attr[attrhdr.Len:]
   126  						}
   127  					}
   128  					break
   129  				}
   130  				reqPeerLock.Lock()
   131  				reqPeer = make(map[uint32]peerEndpointPtr)
   132  				reqPeerLock.Unlock()
   133  				go func() {
   134  					device.peers.RLock()
   135  					i := uint32(1)
   136  					for _, peer := range device.peers.keyMap {
   137  						peer.endpoint.Lock()
   138  						if peer.endpoint.val == nil {
   139  							peer.endpoint.Unlock()
   140  							continue
   141  						}
   142  						nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
   143  						if nativeEP == nil {
   144  							peer.endpoint.Unlock()
   145  							continue
   146  						}
   147  						if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
   148  							peer.endpoint.Unlock()
   149  							break
   150  						}
   151  						nlmsg := struct {
   152  							hdr     unix.NlMsghdr
   153  							msg     unix.RtMsg
   154  							dsthdr  unix.RtAttr
   155  							dst     [4]byte
   156  							srchdr  unix.RtAttr
   157  							src     [4]byte
   158  							markhdr unix.RtAttr
   159  							mark    uint32
   160  						}{
   161  							unix.NlMsghdr{
   162  								Type:  uint16(unix.RTM_GETROUTE),
   163  								Flags: unix.NLM_F_REQUEST,
   164  								Seq:   i,
   165  							},
   166  							unix.RtMsg{
   167  								Family:  unix.AF_INET,
   168  								Dst_len: 32,
   169  								Src_len: 32,
   170  							},
   171  							unix.RtAttr{
   172  								Len:  8,
   173  								Type: unix.RTA_DST,
   174  							},
   175  							nativeEP.DstIP().As4(),
   176  							unix.RtAttr{
   177  								Len:  8,
   178  								Type: unix.RTA_SRC,
   179  							},
   180  							nativeEP.SrcIP().As4(),
   181  							unix.RtAttr{
   182  								Len:  8,
   183  								Type: unix.RTA_MARK,
   184  							},
   185  							device.net.fwmark,
   186  						}
   187  						nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
   188  						reqPeerLock.Lock()
   189  						reqPeer[i] = peerEndpointPtr{
   190  							peer:     peer,
   191  							endpoint: &peer.endpoint.val,
   192  						}
   193  						reqPeerLock.Unlock()
   194  						peer.endpoint.Unlock()
   195  						i++
   196  						_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
   197  						if err != nil {
   198  							break
   199  						}
   200  					}
   201  					device.peers.RUnlock()
   202  				}()
   203  			}
   204  			remain = remain[hdr.Len:]
   205  		}
   206  	}
   207  }
   208  
   209  func createNetlinkRouteSocket() (int, error) {
   210  	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
   211  	if err != nil {
   212  		return -1, err
   213  	}
   214  	saddr := &unix.SockaddrNetlink{
   215  		Family: unix.AF_NETLINK,
   216  		Groups: unix.RTMGRP_IPV4_ROUTE,
   217  	}
   218  	err = unix.Bind(sock, saddr)
   219  	if err != nil {
   220  		unix.Close(sock)
   221  		return -1, err
   222  	}
   223  	return sock, nil
   224  }