github.com/cawidtu/notwireguard-go/device@v0.0.0-20230523131112-68e8e5ce9cdf/sticky_linux.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   *
     5   * This implements userspace semantics of "sticky sockets", modeled after
     6   * WireGuard's kernelspace implementation. This is more or less a straight port
     7   * of the sticky-sockets.c example code:
     8   * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
     9   *
    10   * Currently there is no way to achieve this within the net package:
    11   * See e.g. https://github.com/golang/go/issues/17930
    12   * So this code remains platform dependent.
    13   */
    14  
    15  package device
    16  
    17  import (
    18  	"sync"
    19  	"unsafe"
    20  
    21  	"golang.org/x/sys/unix"
    22  
    23  	"github.com/cawidtu/notwireguard-go/conn"
    24  	"github.com/cawidtu/notwireguard-go/rwcancel"
    25  )
    26  
    27  func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
    28  	if _, ok := bind.(*conn.LinuxSocketBind); !ok {
    29  		return nil, nil
    30  	}
    31  
    32  	netlinkSock, err := createNetlinkRouteSocket()
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  	netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
    37  	if err != nil {
    38  		unix.Close(netlinkSock)
    39  		return nil, err
    40  	}
    41  
    42  	go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
    43  
    44  	return netlinkCancel, nil
    45  }
    46  
    47  func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
    48  	type peerEndpointPtr struct {
    49  		peer     *Peer
    50  		endpoint *conn.Endpoint
    51  	}
    52  	var reqPeer map[uint32]peerEndpointPtr
    53  	var reqPeerLock sync.Mutex
    54  
    55  	defer netlinkCancel.Close()
    56  	defer unix.Close(netlinkSock)
    57  
    58  	for msg := make([]byte, 1<<16); ; {
    59  		var err error
    60  		var msgn int
    61  		for {
    62  			msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
    63  			if err == nil || !rwcancel.RetryAfterError(err) {
    64  				break
    65  			}
    66  			if !netlinkCancel.ReadyRead() {
    67  				return
    68  			}
    69  		}
    70  		if err != nil {
    71  			return
    72  		}
    73  
    74  		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
    75  
    76  			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
    77  
    78  			if uint(hdr.Len) > uint(len(remain)) {
    79  				break
    80  			}
    81  
    82  			switch hdr.Type {
    83  			case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
    84  				if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
    85  					if uint(len(remain)) < uint(hdr.Len) {
    86  						break
    87  					}
    88  					if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
    89  						attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
    90  						for {
    91  							if uint(len(attr)) < uint(unix.SizeofRtAttr) {
    92  								break
    93  							}
    94  							attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
    95  							if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
    96  								break
    97  							}
    98  							if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
    99  								ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
   100  								reqPeerLock.Lock()
   101  								if reqPeer == nil {
   102  									reqPeerLock.Unlock()
   103  									break
   104  								}
   105  								pePtr, ok := reqPeer[hdr.Seq]
   106  								reqPeerLock.Unlock()
   107  								if !ok {
   108  									break
   109  								}
   110  								pePtr.peer.Lock()
   111  								if &pePtr.peer.endpoint != pePtr.endpoint {
   112  									pePtr.peer.Unlock()
   113  									break
   114  								}
   115  								if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
   116  									pePtr.peer.Unlock()
   117  									break
   118  								}
   119  								pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
   120  								pePtr.peer.Unlock()
   121  							}
   122  							attr = attr[attrhdr.Len:]
   123  						}
   124  					}
   125  					break
   126  				}
   127  				reqPeerLock.Lock()
   128  				reqPeer = make(map[uint32]peerEndpointPtr)
   129  				reqPeerLock.Unlock()
   130  				go func() {
   131  					device.peers.RLock()
   132  					i := uint32(1)
   133  					for _, peer := range device.peers.keyMap {
   134  						peer.RLock()
   135  						if peer.endpoint == nil {
   136  							peer.RUnlock()
   137  							continue
   138  						}
   139  						nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
   140  						if nativeEP == nil {
   141  							peer.RUnlock()
   142  							continue
   143  						}
   144  						if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
   145  							peer.RUnlock()
   146  							break
   147  						}
   148  						nlmsg := struct {
   149  							hdr     unix.NlMsghdr
   150  							msg     unix.RtMsg
   151  							dsthdr  unix.RtAttr
   152  							dst     [4]byte
   153  							srchdr  unix.RtAttr
   154  							src     [4]byte
   155  							markhdr unix.RtAttr
   156  							mark    uint32
   157  						}{
   158  							unix.NlMsghdr{
   159  								Type:  uint16(unix.RTM_GETROUTE),
   160  								Flags: unix.NLM_F_REQUEST,
   161  								Seq:   i,
   162  							},
   163  							unix.RtMsg{
   164  								Family:  unix.AF_INET,
   165  								Dst_len: 32,
   166  								Src_len: 32,
   167  							},
   168  							unix.RtAttr{
   169  								Len:  8,
   170  								Type: unix.RTA_DST,
   171  							},
   172  							nativeEP.Dst4().Addr,
   173  							unix.RtAttr{
   174  								Len:  8,
   175  								Type: unix.RTA_SRC,
   176  							},
   177  							nativeEP.Src4().Src,
   178  							unix.RtAttr{
   179  								Len:  8,
   180  								Type: unix.RTA_MARK,
   181  							},
   182  							device.net.fwmark,
   183  						}
   184  						nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
   185  						reqPeerLock.Lock()
   186  						reqPeer[i] = peerEndpointPtr{
   187  							peer:     peer,
   188  							endpoint: &peer.endpoint,
   189  						}
   190  						reqPeerLock.Unlock()
   191  						peer.RUnlock()
   192  						i++
   193  						_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
   194  						if err != nil {
   195  							break
   196  						}
   197  					}
   198  					device.peers.RUnlock()
   199  				}()
   200  			}
   201  			remain = remain[hdr.Len:]
   202  		}
   203  	}
   204  }
   205  
   206  func createNetlinkRouteSocket() (int, error) {
   207  	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
   208  	if err != nil {
   209  		return -1, err
   210  	}
   211  	saddr := &unix.SockaddrNetlink{
   212  		Family: unix.AF_NETLINK,
   213  		Groups: unix.RTMGRP_IPV4_ROUTE,
   214  	}
   215  	err = unix.Bind(sock, saddr)
   216  	if err != nil {
   217  		unix.Close(sock)
   218  		return -1, err
   219  	}
   220  	return sock, nil
   221  }