github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/device/sticky_linux.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 * 5 * This implements userspace semantics of "sticky sockets", modeled after 6 * WireGuard's kernelspace implementation. This is more or less a straight port 7 * of the sticky-sockets.c example code: 8 * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c 9 * 10 * Currently there is no way to achieve this within the net package: 11 * See e.g. https://github.com/golang/go/issues/17930 12 * So this code is remains platform dependent. 13 */ 14 15 package device 16 17 import ( 18 "sync" 19 "unsafe" 20 21 "github.com/tailscale/wireguard-go/conn" 22 "github.com/tailscale/wireguard-go/rwcancel" 23 "golang.org/x/sys/unix" 24 ) 25 26 func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { 27 if _, ok := bind.(*conn.LinuxSocketBind); !ok { 28 return nil, nil 29 } 30 31 netlinkSock, err := createNetlinkRouteSocket() 32 if err != nil { 33 return nil, err 34 } 35 netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) 36 if err != nil { 37 unix.Close(netlinkSock) 38 return nil, err 39 } 40 41 go device.routineRouteListener(bind, netlinkSock, netlinkCancel) 42 43 return netlinkCancel, nil 44 } 45 46 func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { 47 type peerEndpointPtr struct { 48 peer *Peer 49 endpoint *conn.Endpoint 50 } 51 var reqPeer map[uint32]peerEndpointPtr 52 var reqPeerLock sync.Mutex 53 54 defer netlinkCancel.Close() 55 defer unix.Close(netlinkSock) 56 57 for msg := make([]byte, 1<<16); ; { 58 var err error 59 var msgn int 60 for { 61 msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) 62 if err == nil || !rwcancel.RetryAfterError(err) { 63 break 64 } 65 if !netlinkCancel.ReadyRead() { 66 return 67 } 68 } 69 if err != nil { 70 return 71 } 72 73 for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { 74 75 hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) 76 77 if uint(hdr.Len) > uint(len(remain)) { 78 break 79 } 80 81 switch hdr.Type { 82 case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: 83 if hdr.Seq <= MaxPeers && hdr.Seq > 0 { 84 if uint(len(remain)) < uint(hdr.Len) { 85 break 86 } 87 if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { 88 attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] 89 for { 90 if uint(len(attr)) < uint(unix.SizeofRtAttr) { 91 break 92 } 93 attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) 94 if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { 95 break 96 } 97 if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { 98 ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) 99 reqPeerLock.Lock() 100 if reqPeer == nil { 101 reqPeerLock.Unlock() 102 break 103 } 104 pePtr, ok := reqPeer[hdr.Seq] 105 reqPeerLock.Unlock() 106 if !ok { 107 break 108 } 109 pePtr.peer.Lock() 110 if &pePtr.peer.endpoint != pePtr.endpoint { 111 pePtr.peer.Unlock() 112 break 113 } 114 if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx { 115 pePtr.peer.Unlock() 116 break 117 } 118 pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc() 119 pePtr.peer.Unlock() 120 } 121 attr = attr[attrhdr.Len:] 122 } 123 } 124 break 125 } 126 reqPeerLock.Lock() 127 reqPeer = make(map[uint32]peerEndpointPtr) 128 reqPeerLock.Unlock() 129 go func() { 130 device.peers.RLock() 131 i := uint32(1) 132 for _, peer := range device.peers.keyMap { 133 peer.RLock() 134 if peer.endpoint == nil { 135 peer.RUnlock() 136 continue 137 } 138 nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint) 139 if nativeEP == nil { 140 peer.RUnlock() 141 continue 142 } 143 if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { 144 peer.RUnlock() 145 break 146 } 147 nlmsg := struct { 148 hdr unix.NlMsghdr 149 msg unix.RtMsg 150 dsthdr unix.RtAttr 151 dst [4]byte 152 srchdr unix.RtAttr 153 src [4]byte 154 markhdr unix.RtAttr 155 mark uint32 156 }{ 157 unix.NlMsghdr{ 158 Type: uint16(unix.RTM_GETROUTE), 159 Flags: unix.NLM_F_REQUEST, 160 Seq: i, 161 }, 162 unix.RtMsg{ 163 Family: unix.AF_INET, 164 Dst_len: 32, 165 Src_len: 32, 166 }, 167 unix.RtAttr{ 168 Len: 8, 169 Type: unix.RTA_DST, 170 }, 171 nativeEP.Dst4().Addr, 172 unix.RtAttr{ 173 Len: 8, 174 Type: unix.RTA_SRC, 175 }, 176 nativeEP.Src4().Src, 177 unix.RtAttr{ 178 Len: 8, 179 Type: unix.RTA_MARK, 180 }, 181 device.net.fwmark, 182 } 183 nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) 184 reqPeerLock.Lock() 185 reqPeer[i] = peerEndpointPtr{ 186 peer: peer, 187 endpoint: &peer.endpoint, 188 } 189 reqPeerLock.Unlock() 190 peer.RUnlock() 191 i++ 192 _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) 193 if err != nil { 194 break 195 } 196 } 197 device.peers.RUnlock() 198 }() 199 } 200 remain = remain[hdr.Len:] 201 } 202 } 203 } 204 205 func createNetlinkRouteSocket() (int, error) { 206 sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) 207 if err != nil { 208 return -1, err 209 } 210 saddr := &unix.SockaddrNetlink{ 211 Family: unix.AF_NETLINK, 212 Groups: unix.RTMGRP_IPV4_ROUTE, 213 } 214 err = unix.Bind(sock, saddr) 215 if err != nil { 216 unix.Close(sock) 217 return -1, err 218 } 219 return sock, nil 220 }