github.com/GFW-knocker/wireguard@v1.0.1/device/sticky_linux.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 * 5 * This implements userspace semantics of "sticky sockets", modeled after 6 * WireGuard's kernelspace implementation. This is more or less a straight port 7 * of the sticky-sockets.c example code: 8 * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c 9 * 10 * Currently there is no way to achieve this within the net package: 11 * See e.g. https://github.com/golang/go/issues/17930 12 * So this code is remains platform dependent. 13 */ 14 15 package device 16 17 import ( 18 "sync" 19 "unsafe" 20 21 "golang.org/x/sys/unix" 22 23 "github.com/GFW-knocker/wireguard/conn" 24 "github.com/GFW-knocker/wireguard/rwcancel" 25 ) 26 27 func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 28 if !conn.StdNetSupportsStickySockets { 29 return nil, nil 30 } 31 if _, ok := bind.(*conn.StdNetBind); !ok { 32 return nil, nil 33 } 34 35 netlinkSock, err := createNetlinkRouteSocket() 36 if err != nil { 37 return nil, err 38 } 39 netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) 40 if err != nil { 41 unix.Close(netlinkSock) 42 return nil, err 43 } 44 45 go device.routineRouteListener(bind, netlinkSock, netlinkCancel) 46 47 return netlinkCancel, nil 48 } 49 50 func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { 51 type peerEndpointPtr struct { 52 peer *Peer 53 endpoint *conn.Endpoint 54 } 55 var reqPeer map[uint32]peerEndpointPtr 56 var reqPeerLock sync.Mutex 57 58 defer netlinkCancel.Close() 59 defer unix.Close(netlinkSock) 60 61 for msg := make([]byte, 1<<16); ; { 62 var err error 63 var msgn int 64 for { 65 msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) 66 if err == nil || !rwcancel.RetryAfterError(err) { 67 break 68 } 69 if !netlinkCancel.ReadyRead() { 70 return 71 } 72 } 73 if err != nil { 74 return 75 } 76 77 for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { 78 79 hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) 80 81 if uint(hdr.Len) > uint(len(remain)) { 82 break 83 } 84 85 switch hdr.Type { 86 case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: 87 if hdr.Seq <= MaxPeers && hdr.Seq > 0 { 88 if uint(len(remain)) < uint(hdr.Len) { 89 break 90 } 91 if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { 92 attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] 93 for { 94 if uint(len(attr)) < uint(unix.SizeofRtAttr) { 95 break 96 } 97 attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) 98 if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { 99 break 100 } 101 if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { 102 ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) 103 reqPeerLock.Lock() 104 if reqPeer == nil { 105 reqPeerLock.Unlock() 106 break 107 } 108 pePtr, ok := reqPeer[hdr.Seq] 109 reqPeerLock.Unlock() 110 if !ok { 111 break 112 } 113 pePtr.peer.endpoint.Lock() 114 if &pePtr.peer.endpoint.val != pePtr.endpoint { 115 pePtr.peer.endpoint.Unlock() 116 break 117 } 118 if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { 119 pePtr.peer.endpoint.Unlock() 120 break 121 } 122 pePtr.peer.endpoint.clearSrcOnTx = true 123 pePtr.peer.endpoint.Unlock() 124 } 125 attr = attr[attrhdr.Len:] 126 } 127 } 128 break 129 } 130 reqPeerLock.Lock() 131 reqPeer = make(map[uint32]peerEndpointPtr) 132 reqPeerLock.Unlock() 133 go func() { 134 device.peers.RLock() 135 i := uint32(1) 136 for _, peer := range device.peers.keyMap { 137 peer.endpoint.Lock() 138 if peer.endpoint.val == nil { 139 peer.endpoint.Unlock() 140 continue 141 } 142 nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint) 143 if nativeEP == nil { 144 peer.endpoint.Unlock() 145 continue 146 } 147 if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { 148 peer.endpoint.Unlock() 149 break 150 } 151 nlmsg := struct { 152 hdr unix.NlMsghdr 153 msg unix.RtMsg 154 dsthdr unix.RtAttr 155 dst [4]byte 156 srchdr unix.RtAttr 157 src [4]byte 158 markhdr unix.RtAttr 159 mark uint32 160 }{ 161 unix.NlMsghdr{ 162 Type: uint16(unix.RTM_GETROUTE), 163 Flags: unix.NLM_F_REQUEST, 164 Seq: i, 165 }, 166 unix.RtMsg{ 167 Family: unix.AF_INET, 168 Dst_len: 32, 169 Src_len: 32, 170 }, 171 unix.RtAttr{ 172 Len: 8, 173 Type: unix.RTA_DST, 174 }, 175 nativeEP.DstIP().As4(), 176 unix.RtAttr{ 177 Len: 8, 178 Type: unix.RTA_SRC, 179 }, 180 nativeEP.SrcIP().As4(), 181 unix.RtAttr{ 182 Len: 8, 183 Type: unix.RTA_MARK, 184 }, 185 device.net.fwmark, 186 } 187 nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) 188 reqPeerLock.Lock() 189 reqPeer[i] = peerEndpointPtr{ 190 peer: peer, 191 endpoint: &peer.endpoint.val, 192 } 193 reqPeerLock.Unlock() 194 peer.endpoint.Unlock() 195 i++ 196 _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) 197 if err != nil { 198 break 199 } 200 } 201 device.peers.RUnlock() 202 }() 203 } 204 remain = remain[hdr.Len:] 205 } 206 } 207 } 208 209 func createNetlinkRouteSocket() (int, error) { 210 sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) 211 if err != nil { 212 return -1, err 213 } 214 saddr := &unix.SockaddrNetlink{ 215 Family: unix.AF_NETLINK, 216 Groups: unix.RTMGRP_IPV4_ROUTE, 217 } 218 err = unix.Bind(sock, saddr) 219 if err != nil { 220 unix.Close(sock) 221 return -1, err 222 } 223 return sock, nil 224 }