github.com/slackhq/nebula@v1.9.0/udp/udp_linux.go (about)

     1  //go:build !android && !e2e_testing
     2  // +build !android,!e2e_testing
     3  
     4  package udp
     5  
     6  import (
     7  	"encoding/binary"
     8  	"fmt"
     9  	"net"
    10  	"syscall"
    11  	"unsafe"
    12  
    13  	"github.com/rcrowley/go-metrics"
    14  	"github.com/sirupsen/logrus"
    15  	"github.com/slackhq/nebula/config"
    16  	"github.com/slackhq/nebula/firewall"
    17  	"github.com/slackhq/nebula/header"
    18  	"golang.org/x/sys/unix"
    19  )
    20  
    21  //TODO: make it support reload as best you can!
    22  
// StdConn is a UDP connection backed directly by a raw socket file
// descriptor and driven through syscalls rather than the net package.
type StdConn struct {
	// sysFd is the socket file descriptor created in NewListener.
	sysFd int
	// isV4 is true when the socket was opened as AF_INET; it selects the
	// address layout used by ListenOut and the send path used by WriteTo.
	isV4 bool
	// l receives log output from the read loop and ReloadConfig.
	l *logrus.Logger
	// batch is the number of messages prepared for each read in ListenOut;
	// a value of 1 selects ReadSingle instead of ReadMulti.
	batch int
}
    29  
// x is only referenced, within this file, by the commented-out
// SO_INCOMING_CPU experiment in NewListener.
var x int
    31  
// From linux/sock_diag.h
//
// These constants are the indexes into the _SK_MEMINFO array filled in by
// getMemInfo. Their order matches the order of the gauges registered in
// NewUDPStatsEmitter.
const (
	_SK_MEMINFO_RMEM_ALLOC = iota
	_SK_MEMINFO_RCVBUF
	_SK_MEMINFO_WMEM_ALLOC
	_SK_MEMINFO_SNDBUF
	_SK_MEMINFO_FWD_ALLOC
	_SK_MEMINFO_WMEM_QUEUED
	_SK_MEMINFO_OPTMEM
	_SK_MEMINFO_BACKLOG
	_SK_MEMINFO_DROPS

	// _SK_MEMINFO_VARS is the number of entries above and therefore the
	// length of the _SK_MEMINFO array.
	_SK_MEMINFO_VARS
)
    46  
// _SK_MEMINFO is the buffer handed to getsockopt(SO_MEMINFO): one uint32 per
// _SK_MEMINFO_* index.
type _SK_MEMINFO [_SK_MEMINFO_VARS]uint32
    48  
    49  func maybeIPV4(ip net.IP) (net.IP, bool) {
    50  	ip4 := ip.To4()
    51  	if ip4 != nil {
    52  		return ip4, true
    53  	}
    54  	return ip, false
    55  }
    56  
    57  func NewListener(l *logrus.Logger, ip net.IP, port int, multi bool, batch int) (Conn, error) {
    58  	ipV4, isV4 := maybeIPV4(ip)
    59  	af := unix.AF_INET6
    60  	if isV4 {
    61  		af = unix.AF_INET
    62  	}
    63  	syscall.ForkLock.RLock()
    64  	fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
    65  	if err == nil {
    66  		unix.CloseOnExec(fd)
    67  	}
    68  	syscall.ForkLock.RUnlock()
    69  
    70  	if err != nil {
    71  		unix.Close(fd)
    72  		return nil, fmt.Errorf("unable to open socket: %s", err)
    73  	}
    74  
    75  	if multi {
    76  		if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
    77  			return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err)
    78  		}
    79  	}
    80  
    81  	//TODO: support multiple listening IPs (for limiting ipv6)
    82  	var sa unix.Sockaddr
    83  	if isV4 {
    84  		sa4 := &unix.SockaddrInet4{Port: port}
    85  		copy(sa4.Addr[:], ipV4)
    86  		sa = sa4
    87  	} else {
    88  		sa6 := &unix.SockaddrInet6{Port: port}
    89  		copy(sa6.Addr[:], ip.To16())
    90  		sa = sa6
    91  	}
    92  	if err = unix.Bind(fd, sa); err != nil {
    93  		return nil, fmt.Errorf("unable to bind to socket: %s", err)
    94  	}
    95  
    96  	//TODO: this may be useful for forcing threads into specific cores
    97  	//unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU, x)
    98  	//v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU)
    99  	//l.Println(v, err)
   100  
   101  	return &StdConn{sysFd: fd, isV4: isV4, l: l, batch: batch}, err
   102  }
   103  
// Rebind is a no-op in this implementation and always returns nil.
func (u *StdConn) Rebind() error {
	return nil
}
   107  
// SetRecvBuffer requests a receive buffer of n bytes on the socket using the
// SO_RCVBUFFORCE option and returns any error from setsockopt.
//
// NOTE(review): per socket(7) the *FORCE variant needs CAP_NET_ADMIN —
// confirm that is expected wherever this runs.
func (u *StdConn) SetRecvBuffer(n int) error {
	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n)
}
   111  
// SetSendBuffer requests a send buffer of n bytes on the socket using the
// SO_SNDBUFFORCE option and returns any error from setsockopt.
//
// NOTE(review): per socket(7) the *FORCE variant needs CAP_NET_ADMIN —
// confirm that is expected wherever this runs.
func (u *StdConn) SetSendBuffer(n int) error {
	return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n)
}
   115  
   116  func (u *StdConn) GetRecvBuffer() (int, error) {
   117  	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF)
   118  }
   119  
   120  func (u *StdConn) GetSendBuffer() (int, error) {
   121  	return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF)
   122  }
   123  
   124  func (u *StdConn) LocalAddr() (*Addr, error) {
   125  	sa, err := unix.Getsockname(u.sysFd)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	addr := &Addr{}
   131  	switch sa := sa.(type) {
   132  	case *unix.SockaddrInet4:
   133  		addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
   134  		addr.Port = uint16(sa.Port)
   135  	case *unix.SockaddrInet6:
   136  		addr.IP = sa.Addr[0:]
   137  		addr.Port = uint16(sa.Port)
   138  	}
   139  
   140  	return addr, nil
   141  }
   142  
   143  func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) {
   144  	plaintext := make([]byte, MTU)
   145  	h := &header.H{}
   146  	fwPacket := &firewall.Packet{}
   147  	udpAddr := &Addr{}
   148  	nb := make([]byte, 12, 12)
   149  
   150  	//TODO: should we track this?
   151  	//metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015))
   152  	msgs, buffers, names := u.PrepareRawMessages(u.batch)
   153  	read := u.ReadMulti
   154  	if u.batch == 1 {
   155  		read = u.ReadSingle
   156  	}
   157  
   158  	for {
   159  		n, err := read(msgs)
   160  		if err != nil {
   161  			u.l.WithError(err).Debug("udp socket is closed, exiting read loop")
   162  			return
   163  		}
   164  
   165  		//metric.Update(int64(n))
   166  		for i := 0; i < n; i++ {
   167  			if u.isV4 {
   168  				udpAddr.IP = names[i][4:8]
   169  			} else {
   170  				udpAddr.IP = names[i][8:24]
   171  			}
   172  			udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
   173  			r(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], h, fwPacket, lhf, nb, q, cache.Get(u.l))
   174  		}
   175  	}
   176  }
   177  
   178  func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) {
   179  	for {
   180  		n, _, err := unix.Syscall6(
   181  			unix.SYS_RECVMSG,
   182  			uintptr(u.sysFd),
   183  			uintptr(unsafe.Pointer(&(msgs[0].Hdr))),
   184  			0,
   185  			0,
   186  			0,
   187  			0,
   188  		)
   189  
   190  		if err != 0 {
   191  			return 0, &net.OpError{Op: "recvmsg", Err: err}
   192  		}
   193  
   194  		msgs[0].Len = uint32(n)
   195  		return 1, nil
   196  	}
   197  }
   198  
   199  func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) {
   200  	for {
   201  		n, _, err := unix.Syscall6(
   202  			unix.SYS_RECVMMSG,
   203  			uintptr(u.sysFd),
   204  			uintptr(unsafe.Pointer(&msgs[0])),
   205  			uintptr(len(msgs)),
   206  			unix.MSG_WAITFORONE,
   207  			0,
   208  			0,
   209  		)
   210  
   211  		if err != 0 {
   212  			return 0, &net.OpError{Op: "recvmmsg", Err: err}
   213  		}
   214  
   215  		return int(n), nil
   216  	}
   217  }
   218  
   219  func (u *StdConn) WriteTo(b []byte, addr *Addr) error {
   220  	if u.isV4 {
   221  		return u.writeTo4(b, addr)
   222  	}
   223  	return u.writeTo6(b, addr)
   224  }
   225  
   226  func (u *StdConn) writeTo6(b []byte, addr *Addr) error {
   227  	var rsa unix.RawSockaddrInet6
   228  	rsa.Family = unix.AF_INET6
   229  	// Little Endian -> Network Endian
   230  	rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
   231  	copy(rsa.Addr[:], addr.IP.To16())
   232  
   233  	for {
   234  		_, _, err := unix.Syscall6(
   235  			unix.SYS_SENDTO,
   236  			uintptr(u.sysFd),
   237  			uintptr(unsafe.Pointer(&b[0])),
   238  			uintptr(len(b)),
   239  			uintptr(0),
   240  			uintptr(unsafe.Pointer(&rsa)),
   241  			uintptr(unix.SizeofSockaddrInet6),
   242  		)
   243  
   244  		if err != 0 {
   245  			return &net.OpError{Op: "sendto", Err: err}
   246  		}
   247  
   248  		//TODO: handle incomplete writes
   249  
   250  		return nil
   251  	}
   252  }
   253  
   254  func (u *StdConn) writeTo4(b []byte, addr *Addr) error {
   255  	addrV4, isAddrV4 := maybeIPV4(addr.IP)
   256  	if !isAddrV4 {
   257  		return fmt.Errorf("Listener is IPv4, but writing to IPv6 remote")
   258  	}
   259  
   260  	var rsa unix.RawSockaddrInet4
   261  	rsa.Family = unix.AF_INET
   262  	// Little Endian -> Network Endian
   263  	rsa.Port = (addr.Port >> 8) | ((addr.Port & 0xff) << 8)
   264  	copy(rsa.Addr[:], addrV4)
   265  
   266  	for {
   267  		_, _, err := unix.Syscall6(
   268  			unix.SYS_SENDTO,
   269  			uintptr(u.sysFd),
   270  			uintptr(unsafe.Pointer(&b[0])),
   271  			uintptr(len(b)),
   272  			uintptr(0),
   273  			uintptr(unsafe.Pointer(&rsa)),
   274  			uintptr(unix.SizeofSockaddrInet4),
   275  		)
   276  
   277  		if err != 0 {
   278  			return &net.OpError{Op: "sendto", Err: err}
   279  		}
   280  
   281  		//TODO: handle incomplete writes
   282  
   283  		return nil
   284  	}
   285  }
   286  
   287  func (u *StdConn) ReloadConfig(c *config.C) {
   288  	b := c.GetInt("listen.read_buffer", 0)
   289  	if b > 0 {
   290  		err := u.SetRecvBuffer(b)
   291  		if err == nil {
   292  			s, err := u.GetRecvBuffer()
   293  			if err == nil {
   294  				u.l.WithField("size", s).Info("listen.read_buffer was set")
   295  			} else {
   296  				u.l.WithError(err).Warn("Failed to get listen.read_buffer")
   297  			}
   298  		} else {
   299  			u.l.WithError(err).Error("Failed to set listen.read_buffer")
   300  		}
   301  	}
   302  
   303  	b = c.GetInt("listen.write_buffer", 0)
   304  	if b > 0 {
   305  		err := u.SetSendBuffer(b)
   306  		if err == nil {
   307  			s, err := u.GetSendBuffer()
   308  			if err == nil {
   309  				u.l.WithField("size", s).Info("listen.write_buffer was set")
   310  			} else {
   311  				u.l.WithError(err).Warn("Failed to get listen.write_buffer")
   312  			}
   313  		} else {
   314  			u.l.WithError(err).Error("Failed to set listen.write_buffer")
   315  		}
   316  	}
   317  }
   318  
   319  func (u *StdConn) getMemInfo(meminfo *_SK_MEMINFO) error {
   320  	var vallen uint32 = 4 * _SK_MEMINFO_VARS
   321  	_, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0)
   322  	if err != 0 {
   323  		return err
   324  	}
   325  	return nil
   326  }
   327  
// Close closes the underlying socket file descriptor and returns the error
// from close, if any.
func (u *StdConn) Close() error {
	//TODO: this will not interrupt the read loop
	return syscall.Close(u.sysFd)
}
   332  
   333  func NewUDPStatsEmitter(udpConns []Conn) func() {
   334  	// Check if our kernel supports SO_MEMINFO before registering the gauges
   335  	var udpGauges [][_SK_MEMINFO_VARS]metrics.Gauge
   336  	var meminfo _SK_MEMINFO
   337  	if err := udpConns[0].(*StdConn).getMemInfo(&meminfo); err == nil {
   338  		udpGauges = make([][_SK_MEMINFO_VARS]metrics.Gauge, len(udpConns))
   339  		for i := range udpConns {
   340  			udpGauges[i] = [_SK_MEMINFO_VARS]metrics.Gauge{
   341  				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rmem_alloc", i), nil),
   342  				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.rcvbuf", i), nil),
   343  				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_alloc", i), nil),
   344  				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.sndbuf", i), nil),
   345  				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.fwd_alloc", i), nil),
   346  				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.wmem_queued", i), nil),
   347  				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.optmem", i), nil),
   348  				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.backlog", i), nil),
   349  				metrics.GetOrRegisterGauge(fmt.Sprintf("udp.%d.drops", i), nil),
   350  			}
   351  		}
   352  	}
   353  
   354  	return func() {
   355  		for i, gauges := range udpGauges {
   356  			if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil {
   357  				for j := 0; j < _SK_MEMINFO_VARS; j++ {
   358  					gauges[j].Update(int64(meminfo[j]))
   359  				}
   360  			}
   361  		}
   362  	}
   363  }