github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/device/sticky_linux.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 * 5 * This implements userspace semantics of "sticky sockets", modeled after 6 * WireGuard's kernelspace implementation. This is more or less a straight port 7 * of the sticky-sockets.c example code: 8 * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c 9 * 10 * Currently there is no way to achieve this within the net package: 11 * See e.g. https://github.com/golang/go/issues/17930 12 * So this code is remains platform dependent. 13 */ 14 15 package device 16 17 import ( 18 "sync" 19 "unsafe" 20 21 "github.com/sagernet/wireguard-go/conn" 22 "github.com/sagernet/wireguard-go/rwcancel" 23 "golang.org/x/sys/unix" 24 ) 25 26 func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 27 if !conn.StdNetSupportsStickySockets { 28 return nil, nil 29 } 30 if _, ok := bind.(*conn.StdNetBind); !ok { 31 return nil, nil 32 } 33 34 netlinkSock, err := createNetlinkRouteSocket() 35 if err != nil { 36 return nil, err 37 } 38 netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) 39 if err != nil { 40 unix.Close(netlinkSock) 41 return nil, err 42 } 43 44 go device.routineRouteListener(bind, netlinkSock, netlinkCancel) 45 46 return netlinkCancel, nil 47 } 48 49 func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { 50 type peerEndpointPtr struct { 51 peer *Peer 52 endpoint *conn.Endpoint 53 } 54 var reqPeer map[uint32]peerEndpointPtr 55 var reqPeerLock sync.Mutex 56 57 defer netlinkCancel.Close() 58 defer unix.Close(netlinkSock) 59 60 for msg := make([]byte, 1<<16); ; { 61 var err error 62 var msgn int 63 for { 64 msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) 65 if err == nil || !rwcancel.RetryAfterError(err) { 66 break 67 } 68 if !netlinkCancel.ReadyRead() { 69 return 70 } 71 } 72 if err != nil { 73 return 74 } 75 76 for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { 77 78 hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) 79 80 if uint(hdr.Len) > uint(len(remain)) { 81 break 82 } 83 84 switch hdr.Type { 85 case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: 86 if hdr.Seq <= MaxPeers && hdr.Seq > 0 { 87 if uint(len(remain)) < uint(hdr.Len) { 88 break 89 } 90 if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { 91 attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] 92 for { 93 if uint(len(attr)) < uint(unix.SizeofRtAttr) { 94 break 95 } 96 attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) 97 if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { 98 break 99 } 100 if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { 101 ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) 102 reqPeerLock.Lock() 103 if reqPeer == nil { 104 reqPeerLock.Unlock() 105 break 106 } 107 pePtr, ok := reqPeer[hdr.Seq] 108 reqPeerLock.Unlock() 109 if !ok { 110 break 111 } 112 pePtr.peer.Lock() 113 if &pePtr.peer.endpoint != pePtr.endpoint { 114 pePtr.peer.Unlock() 115 break 116 } 117 if uint32(pePtr.peer.endpoint.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { 118 pePtr.peer.Unlock() 119 break 120 } 121 pePtr.peer.endpoint.(*conn.StdNetEndpoint).ClearSrc() 122 pePtr.peer.Unlock() 123 } 124 attr = attr[attrhdr.Len:] 125 } 126 } 127 break 128 } 129 reqPeerLock.Lock() 130 reqPeer = make(map[uint32]peerEndpointPtr) 131 reqPeerLock.Unlock() 132 go func() { 133 device.peers.RLock() 134 i := uint32(1) 135 for _, peer := range device.peers.keyMap { 136 peer.RLock() 137 if peer.endpoint == nil { 138 peer.RUnlock() 139 continue 140 } 141 nativeEP, _ := peer.endpoint.(*conn.StdNetEndpoint) 142 if nativeEP == nil { 143 peer.RUnlock() 144 continue 145 } 146 if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { 147 peer.RUnlock() 148 break 149 } 150 nlmsg := struct { 151 hdr unix.NlMsghdr 152 msg unix.RtMsg 153 dsthdr unix.RtAttr 154 dst [4]byte 155 srchdr unix.RtAttr 156 src [4]byte 157 markhdr unix.RtAttr 158 mark uint32 159 }{ 160 unix.NlMsghdr{ 161 Type: uint16(unix.RTM_GETROUTE), 162 Flags: unix.NLM_F_REQUEST, 163 Seq: i, 164 }, 165 unix.RtMsg{ 166 Family: unix.AF_INET, 167 Dst_len: 32, 168 Src_len: 32, 169 }, 170 unix.RtAttr{ 171 Len: 8, 172 Type: unix.RTA_DST, 173 }, 174 nativeEP.DstIP().As4(), 175 unix.RtAttr{ 176 Len: 8, 177 Type: unix.RTA_SRC, 178 }, 179 nativeEP.SrcIP().As4(), 180 unix.RtAttr{ 181 Len: 8, 182 Type: unix.RTA_MARK, 183 }, 184 device.net.fwmark, 185 } 186 nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) 187 reqPeerLock.Lock() 188 reqPeer[i] = peerEndpointPtr{ 189 peer: peer, 190 endpoint: &peer.endpoint, 191 } 192 reqPeerLock.Unlock() 193 peer.RUnlock() 194 i++ 195 _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) 196 if err != nil { 197 break 198 } 199 } 200 device.peers.RUnlock() 201 }() 202 } 203 remain = remain[hdr.Len:] 204 } 205 } 206 } 207 208 func createNetlinkRouteSocket() (int, error) { 209 sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) 210 if err != nil { 211 return -1, err 212 } 213 saddr := &unix.SockaddrNetlink{ 214 Family: unix.AF_NETLINK, 215 Groups: unix.RTMGRP_IPV4_ROUTE, 216 } 217 err = unix.Bind(sock, saddr) 218 if err != nil { 219 unix.Close(sock) 220 return -1, err 221 } 222 return sock, nil 223 }