github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/sticky_linux.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 * 5 * This implements userspace semantics of "sticky sockets", modeled after 6 * WireGuard's kernelspace implementation. This is more or less a straight port 7 * of the sticky-sockets.c example code: 8 * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c 9 * 10 * Currently there is no way to achieve this within the net package: 11 * See e.g. https://github.com/golang/go/issues/17930 12 * So this code is remains platform dependent. 13 */ 14 15 package device 16 17 import ( 18 "sync" 19 "unsafe" 20 21 "golang.org/x/sys/unix" 22 23 "github.com/liloew/wireguard-go/conn" 24 "github.com/liloew/wireguard-go/rwcancel" 25 ) 26 27 func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 28 if _, ok := bind.(*conn.LinuxSocketBind); !ok { 29 return nil, nil 30 } 31 32 netlinkSock, err := createNetlinkRouteSocket() 33 if err != nil { 34 return nil, err 35 } 36 netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) 37 if err != nil { 38 unix.Close(netlinkSock) 39 return nil, err 40 } 41 42 go device.routineRouteListener(bind, netlinkSock, netlinkCancel) 43 44 return netlinkCancel, nil 45 } 46 47 func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { 48 type peerEndpointPtr struct { 49 peer *Peer 50 endpoint *conn.Endpoint 51 } 52 var reqPeer map[uint32]peerEndpointPtr 53 var reqPeerLock sync.Mutex 54 55 defer netlinkCancel.Close() 56 defer unix.Close(netlinkSock) 57 58 for msg := make([]byte, 1<<16); ; { 59 var err error 60 var msgn int 61 for { 62 msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) 63 if err == nil || !rwcancel.RetryAfterError(err) { 64 break 65 } 66 if !netlinkCancel.ReadyRead() { 67 return 68 } 69 } 70 if err != nil { 71 return 72 } 73 74 for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { 75 76 hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) 77 78 if uint(hdr.Len) > uint(len(remain)) { 79 break 80 } 81 82 switch hdr.Type { 83 case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: 84 if hdr.Seq <= MaxPeers && hdr.Seq > 0 { 85 if uint(len(remain)) < uint(hdr.Len) { 86 break 87 } 88 if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { 89 attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] 90 for { 91 if uint(len(attr)) < uint(unix.SizeofRtAttr) { 92 break 93 } 94 attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) 95 if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { 96 break 97 } 98 if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { 99 ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) 100 reqPeerLock.Lock() 101 if reqPeer == nil { 102 reqPeerLock.Unlock() 103 break 104 } 105 pePtr, ok := reqPeer[hdr.Seq] 106 reqPeerLock.Unlock() 107 if !ok { 108 break 109 } 110 pePtr.peer.Lock() 111 if &pePtr.peer.endpoint != pePtr.endpoint { 112 pePtr.peer.Unlock() 113 break 114 } 115 if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx { 116 pePtr.peer.Unlock() 117 break 118 } 119 pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc() 120 pePtr.peer.Unlock() 121 } 122 attr = attr[attrhdr.Len:] 123 } 124 } 125 break 126 } 127 reqPeerLock.Lock() 128 reqPeer = make(map[uint32]peerEndpointPtr) 129 reqPeerLock.Unlock() 130 go func() { 131 device.peers.RLock() 132 i := uint32(1) 133 for _, peer := range device.peers.keyMap { 134 peer.RLock() 135 if peer.endpoint == nil { 136 peer.RUnlock() 137 continue 138 } 139 nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint) 140 if nativeEP == nil { 141 peer.RUnlock() 142 continue 143 } 144 if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { 145 peer.RUnlock() 146 break 147 } 148 nlmsg := struct { 149 hdr unix.NlMsghdr 150 msg unix.RtMsg 151 dsthdr unix.RtAttr 152 dst [4]byte 153 srchdr unix.RtAttr 154 src [4]byte 155 markhdr unix.RtAttr 156 mark uint32 157 }{ 158 unix.NlMsghdr{ 159 Type: uint16(unix.RTM_GETROUTE), 160 Flags: unix.NLM_F_REQUEST, 161 Seq: i, 162 }, 163 unix.RtMsg{ 164 Family: unix.AF_INET, 165 Dst_len: 32, 166 Src_len: 32, 167 }, 168 unix.RtAttr{ 169 Len: 8, 170 Type: unix.RTA_DST, 171 }, 172 nativeEP.Dst4().Addr, 173 unix.RtAttr{ 174 Len: 8, 175 Type: unix.RTA_SRC, 176 }, 177 nativeEP.Src4().Src, 178 unix.RtAttr{ 179 Len: 8, 180 Type: unix.RTA_MARK, 181 }, 182 device.net.fwmark, 183 } 184 nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) 185 reqPeerLock.Lock() 186 reqPeer[i] = peerEndpointPtr{ 187 peer: peer, 188 endpoint: &peer.endpoint, 189 } 190 reqPeerLock.Unlock() 191 peer.RUnlock() 192 i++ 193 _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) 194 if err != nil { 195 break 196 } 197 } 198 device.peers.RUnlock() 199 }() 200 } 201 remain = remain[hdr.Len:] 202 } 203 } 204 } 205 206 func createNetlinkRouteSocket() (int, error) { 207 sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) 208 if err != nil { 209 return -1, err 210 } 211 saddr := &unix.SockaddrNetlink{ 212 Family: unix.AF_NETLINK, 213 Groups: unix.RTMGRP_IPV4_ROUTE, 214 } 215 err = unix.Bind(sock, saddr) 216 if err != nil { 217 unix.Close(sock) 218 return -1, err 219 } 220 return sock, nil 221 }