github.com/slackhq/nebula@v1.9.0/udp/udp_linux.go (about) 1 //go:build !android && !e2e_testing 2 // +build !android,!e2e_testing 3 4 package udp 5 6 import ( 7 "encoding/binary" 8 "fmt" 9 "net" 10 "syscall" 11 "unsafe" 12 13 "github.com/rcrowley/go-metrics" 14 "github.com/sirupsen/logrus" 15 "github.com/slackhq/nebula/config" 16 "github.com/slackhq/nebula/firewall" 17 "github.com/slackhq/nebula/header" 18 "golang.org/x/sys/unix" 19 ) 20 21 //TODO: make it support reload as best you can! 22 23 type StdConn struct { 24 sysFd int 25 isV4 bool 26 l *logrus.Logger 27 batch int 28 } 29 30 var x int 31 32 // From linux/sock_diag.h 33 const ( 34 _SK_MEMINFO_RMEM_ALLOC = iota 35 _SK_MEMINFO_RCVBUF 36 _SK_MEMINFO_WMEM_ALLOC 37 _SK_MEMINFO_SNDBUF 38 _SK_MEMINFO_FWD_ALLOC 39 _SK_MEMINFO_WMEM_QUEUED 40 _SK_MEMINFO_OPTMEM 41 _SK_MEMINFO_BACKLOG 42 _SK_MEMINFO_DROPS 43 44 _SK_MEMINFO_VARS 45 ) 46 47 type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32 48 49 func maybeIPV4(ip net.IP) (net.IP, bool) { 50 ip4 := ip.To4() 51 if ip4 != nil { 52 return ip4, true 53 } 54 return ip, false 55 } 56 57 func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) { 58 ipV4, isV4 := maybeIPV4(ip) 59 af := unix.AF_INET6 60 if isV4 { 61 af = unix.AF_INET 62 } 63 syscall.ForkLock.RLock() 64 fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP) 65 if err == nil { 66 unix.CloseOnExec(fd) 67 } 68 syscall.ForkLock.RUnlock() 69 70 if err != nil { 71 unix.Close(fd) 72 return nil, fmt.Errorf("unable to open socket: %s", err) 73 } 74 75 if multi { 76 if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { 77 return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) 78 } 79 } 80 81 //TODO: support multiple listening IPs (for limiting ipv6) 82 var sa unix.Sockaddr 83 if isV4 { 84 sa4 := &unix.SockaddrInet4{Port: port} 85 copy(sa4.Addr[:], ipV4) 86 sa = sa4 87 } else { 88 sa6 := &unix.SockaddrInet6{Port: port} 89 copy(sa6.Addr[:], ip.To16()) 90 sa = sa6 91 } 92 if err = unix.Bind(fd, sa); err != nil { 93 return nil, fmt.Errorf("unable to bind to socket: %s", err) 94 } 95 96 //TODO: this may be useful for forcing threads into specific cores 97 //unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU, x) 98 //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) 99 //l.Println(v, err) 100 101 return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err 102 } 103 104 func (u *StdConn) Rebind() error { 105 return nil 106 } 107 108 func (u *StdConn) SetRecvBuffer(n int) error { 109 return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) 110 } 111 112 func (u *StdConn) SetSendBuffer(n int) error { 113 return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) 114 } 115 116 func (u *StdConn) GetRecvBuffer() (int, error) { 117 return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) 118 } 119 120 func (u *StdConn) GetSendBuffer() (int, error) { 121 return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) 122 } 123 124 func (u *StdConn) LocalAddr() (*Addr, error) { 125 sa, err := unix.Getsockname(u.sysFd) 126 if err != nil { 127 return nil, err 128 } 129 130 addr := &Addr{} 131 switch sa := sa.(type) { 132 case *unix.SockaddrInet4: 133 addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16() 134 addr.Port = uint16(sa.Port) 135 case *unix.SockaddrInet6: 136 addr.IP = sa.Addr[0:] 137 addr.Port = uint16(sa.Port) 138 } 139 140 return addr, nil 141 } 142 143 func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { 144 plaintext := make([]byte, MTU) 145 h := &header.H{} 146 fwPacket := &firewall.Packet{} 147 udpAddr := &Addr{} 148 nb := make([]byte, 12, 12) 149 150 //TODO: should we track this? 151 //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) 152 msgs, buffers, names := u.PrepareRawMessages(u.batch) 153 read := u.ReadMulti 154 if u.batch == 1 { 155 read = u.ReadSingle 156 } 157 158 for { 159 n, err := read(msgs) 160 if err != nil { 161 u.l.WithError(err).Debug("udp socket is closed, exiting read loop") 162 return 163 } 164 165 //metric.Update(int64(n)) 166 for i := 0; i < n; i++ { 167 if u.isV4 { 168 udpAddr.IP = names[i][4:8] 169 } else { 170 udpAddr.IP = names[i][8:24] 171 } 172 udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4]) 173 r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l)) 174 } 175 } 176 } 177 178 func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { 179 for { 180 n, _, err := unix.Syscall6( 181 unix.SYS_RECVMSG, 182 uintptr(u.sysFd), 183 uintptr(unsafe.Pointer(&(msgs[0].Hdr))), 184 0, 185 0, 186 0, 187 0, 188 ) 189 190 if err != 0 { 191 return 0, &net.OpError{Op: "recvmsg", Err: err} 192 } 193 194 msgs[0].Len = uint32(n) 195 return 1, nil 196 } 197 } 198 199 func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { 200 for { 201 n, _, err := unix.Syscall6( 202 unix.SYS_RECVMMSG, 203 uintptr(u.sysFd), 204 uintptr(unsafe.Pointer(&msgs[0])), 205 uintptr(len(msgs)), 206 unix.MSG_WAITFORONE, 207 0, 208 0, 209 ) 210 211 if err != 0 { 212 return 0, &net.OpError{Op: "recvmmsg", Err: err} 213 } 214 215 return int(n), nil 216 } 217 } 218 219 func (u *StdConn) WriteTo(b []byte, addr *Addr) error { 220 if u.isV4 { 221 return u.writeTo4(b, addr) 222 } 223 return u.writeTo6(b, addr) 224 } 225 226 func (u *StdConn) writeTo6(b []byte, addr *Addr) error { 227 var rsa unix.RawSockaddrInet6 228 rsa.Family = unix.AF_INET6 229 // Little Endian -> Network Endian 230 rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) 231 copy(rsa.Addr[:], addr.IP.To16()) 232 233 for { 234 _, _, err := unix.Syscall6( 235 unix.SYS_SENDTO, 236 uintptr(u.sysFd), 237 uintptr(unsafe.Pointer(&b[0])), 238 uintptr(len(b)), 239 uintptr(0), 240 uintptr(unsafe.Pointer(&rsa)), 241 uintptr(unix.SizeofSockaddrInet6), 242 ) 243 244 if err != 0 { 245 return &net.OpError{Op: "sendto", Err: err} 246 } 247 248 //TODO: handle incomplete writes 249 250 return nil 251 } 252 } 253 254 func (u *StdConn) writeTo4(b []byte, addr *Addr) error { 255 addrV4, isAddrV4 := maybeIPV4(addr.IP) 256 if !isAddrV4 { 257 return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote") 258 } 259 260 var rsa unix.RawSockaddrInet4 261 rsa.Family = unix.AF_INET 262 // Little Endian -> Network Endian 263 rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8) 264 copy(rsa.Addr[:], addrV4) 265 266 for { 267 _, _, err := unix.Syscall6( 268 unix.SYS_SENDTO, 269 uintptr(u.sysFd), 270 uintptr(unsafe.Pointer(&b[0])), 271 uintptr(len(b)), 272 uintptr(0), 273 uintptr(unsafe.Pointer(&rsa)), 274 uintptr(unix.SizeofSockaddrInet4), 275 ) 276 277 if err != 0 { 278 return &net.OpError{Op: "sendto", Err: err} 279 } 280 281 //TODO: handle incomplete writes 282 283 return nil 284 } 285 } 286 287 func (u *StdConn) ReloadConfig(c *config.C) { 288 b := c.GetInt("listen.read_buffer", 0) 289 if b > 0 { 290 err := u.SetRecvBuffer(b) 291 if err == nil { 292 s, err := u.GetRecvBuffer() 293 if err == nil { 294 u.l.WithField("size", s).Info("listen.read_buffer was set") 295 } else { 296 u.l.WithError(err).Warn("Failed to get listen.read_buffer") 297 } 298 } else { 299 u.l.WithError(err).Error("Failed to set listen.read_buffer") 300 } 301 } 302 303 b = c.GetInt("listen.write_buffer", 0) 304 if b > 0 { 305 err := u.SetSendBuffer(b) 306 if err == nil { 307 s, err := u.GetSendBuffer() 308 if err == nil { 309 u.l.WithField("size", s).Info("listen.write_buffer was set") 310 } else { 311 u.l.WithError(err).Warn("Failed to get listen.write_buffer") 312 } 313 } else { 314 u.l.WithError(err).Error("Failed to set listen.write_buffer") 315 } 316 } 317 } 318 319 func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error { 320 var vallen uint32 = 4 * _SK_MEMINFO_VARS 321 _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) 322 if err != 0 { 323 return err 324 } 325 return nil 326 } 327 328 func (u *StdConn) Close() error { 329 //TODO: this will not interrupt the read loop 330 return syscall.Close(u.sysFd) 331 } 332 333 func NewUDPStatsEmitter(udpConns []Conn) func() { 334 // Check if our kernel supports SO_MEMINFO before registering the gauges 335 var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge 336 var meminfo _SK_MEMINFO 337 if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil { 338 udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns)) 339 for i := range udpConns { 340 udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{ 341 metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil), 342 metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil), 343 metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil), 344 metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", i), nil), 345 metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", i), nil), 346 metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", i), nil), 347 metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", i), nil), 348 metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", i), nil), 349 metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", i), nil), 350 } 351 } 352 } 353 354 return func() { 355 for i, gauges := range udpGauges { 356 if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil { 357 for j := 0; j < _SK_MEMINFO_VARS; j++ { 358 gauges[j].Update(int64(meminfo[j])) 359 } 360 } 361 } 362 } 363 }