github.com/vishvananda/netlink@v1.3.1/nl/nl_linux.go (about)

     1  // Package nl has low level primitives for making Netlink calls.
     2  package nl
     3  
     4  import (
     5  	"bytes"
     6  	"encoding/binary"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"os"
    11  	"runtime"
    12  	"sync"
    13  	"sync/atomic"
    14  	"syscall"
    15  	"time"
    16  	"unsafe"
    17  
    18  	"github.com/vishvananda/netns"
    19  	"golang.org/x/sys/unix"
    20  )
    21  
    22  const (
    23  	// Family type definitions
    24  	FAMILY_ALL  = unix.AF_UNSPEC
    25  	FAMILY_V4   = unix.AF_INET
    26  	FAMILY_V6   = unix.AF_INET6
    27  	FAMILY_MPLS = unix.AF_MPLS
    28  	// Arbitrary set value (greater than default 4k) to allow receiving
    29  	// from kernel more verbose messages e.g. for statistics,
    30  	// tc rules or filters, or other more memory requiring data.
    31  	RECEIVE_BUFFER_SIZE = 65536
    32  	// Kernel netlink pid
    33  	PidKernel     uint32 = 0
    34  	SizeofCnMsgOp        = 0x18
    35  )
    36  
    37  // SupportedNlFamilies contains the list of netlink families this netlink package supports
    38  var SupportedNlFamilies = []int{unix.NETLINK_ROUTE, unix.NETLINK_XFRM, unix.NETLINK_NETFILTER}
    39  
    40  var nextSeqNr uint32
    41  
    42  // Default netlink socket timeout, 60s
    43  var SocketTimeoutTv = unix.Timeval{Sec: 60, Usec: 0}
    44  
    45  // ErrorMessageReporting is the default error message reporting configuration for the new netlink sockets
    46  var EnableErrorMessageReporting bool = false
    47  
    48  // ErrDumpInterrupted is an instance of errDumpInterrupted, used to report that
    49  // a netlink function has set the NLM_F_DUMP_INTR flag in a response message,
    50  // indicating that the results may be incomplete or inconsistent.
    51  var ErrDumpInterrupted = errDumpInterrupted{}
    52  
    53  // errDumpInterrupted is an error type, used to report that NLM_F_DUMP_INTR was
    54  // set in a netlink response.
    55  type errDumpInterrupted struct{}
    56  
    57  func (errDumpInterrupted) Error() string {
    58  	return "results may be incomplete or inconsistent"
    59  }
    60  
    61  // Before errDumpInterrupted was introduced, EINTR was returned when a netlink
    62  // response had NLM_F_DUMP_INTR. Retain backward compatibility with code that
    63  // may be checking for EINTR using Is.
    64  func (e errDumpInterrupted) Is(target error) bool {
    65  	return target == unix.EINTR
    66  }
    67  
    68  // GetIPFamily returns the family type of a net.IP.
    69  func GetIPFamily(ip net.IP) int {
    70  	if len(ip) <= net.IPv4len {
    71  		return FAMILY_V4
    72  	}
    73  	if ip.To4() != nil {
    74  		return FAMILY_V4
    75  	}
    76  	return FAMILY_V6
    77  }
    78  
    79  var nativeEndian binary.ByteOrder
    80  
    81  // NativeEndian gets native endianness for the system
    82  func NativeEndian() binary.ByteOrder {
    83  	if nativeEndian == nil {
    84  		var x uint32 = 0x01020304
    85  		if *(*byte)(unsafe.Pointer(&x)) == 0x01 {
    86  			nativeEndian = binary.BigEndian
    87  		} else {
    88  			nativeEndian = binary.LittleEndian
    89  		}
    90  	}
    91  	return nativeEndian
    92  }
    93  
    94  // Byte swap a 16 bit value if we aren't big endian
    95  func Swap16(i uint16) uint16 {
    96  	if NativeEndian() == binary.BigEndian {
    97  		return i
    98  	}
    99  	return (i&0xff00)>>8 | (i&0xff)<<8
   100  }
   101  
   102  // Byte swap a 32 bit value if aren't big endian
   103  func Swap32(i uint32) uint32 {
   104  	if NativeEndian() == binary.BigEndian {
   105  		return i
   106  	}
   107  	return (i&0xff000000)>>24 | (i&0xff0000)>>8 | (i&0xff00)<<8 | (i&0xff)<<24
   108  }
   109  
   110  const (
   111  	NLMSGERR_ATTR_UNUSED = 0
   112  	NLMSGERR_ATTR_MSG    = 1
   113  	NLMSGERR_ATTR_OFFS   = 2
   114  	NLMSGERR_ATTR_COOKIE = 3
   115  	NLMSGERR_ATTR_POLICY = 4
   116  )
   117  
   118  type NetlinkRequestData interface {
   119  	Len() int
   120  	Serialize() []byte
   121  }
   122  
   123  const (
   124  	PROC_CN_MCAST_LISTEN = 1
   125  	PROC_CN_MCAST_IGNORE
   126  )
   127  
   128  type CbID struct {
   129  	Idx uint32
   130  	Val uint32
   131  }
   132  
   133  type CnMsg struct {
   134  	ID     CbID
   135  	Seq    uint32
   136  	Ack    uint32
   137  	Length uint16
   138  	Flags  uint16
   139  }
   140  
   141  type CnMsgOp struct {
   142  	CnMsg
   143  	// here we differ from the C header
   144  	Op uint32
   145  }
   146  
   147  func NewCnMsg(idx, val, op uint32) *CnMsgOp {
   148  	var cm CnMsgOp
   149  
   150  	cm.ID.Idx = idx
   151  	cm.ID.Val = val
   152  
   153  	cm.Ack = 0
   154  	cm.Seq = 1
   155  	cm.Length = uint16(binary.Size(op))
   156  	cm.Op = op
   157  
   158  	return &cm
   159  }
   160  
   161  func (msg *CnMsgOp) Serialize() []byte {
   162  	return (*(*[SizeofCnMsgOp]byte)(unsafe.Pointer(msg)))[:]
   163  }
   164  
   165  func DeserializeCnMsgOp(b []byte) *CnMsgOp {
   166  	return (*CnMsgOp)(unsafe.Pointer(&b[0:SizeofCnMsgOp][0]))
   167  }
   168  
   169  func (msg *CnMsgOp) Len() int {
   170  	return SizeofCnMsgOp
   171  }
   172  
   173  // IfInfomsg is related to links, but it is used for list requests as well
   174  type IfInfomsg struct {
   175  	unix.IfInfomsg
   176  }
   177  
   178  // Create an IfInfomsg with family specified
   179  func NewIfInfomsg(family int) *IfInfomsg {
   180  	return &IfInfomsg{
   181  		IfInfomsg: unix.IfInfomsg{
   182  			Family: uint8(family),
   183  		},
   184  	}
   185  }
   186  
   187  func DeserializeIfInfomsg(b []byte) *IfInfomsg {
   188  	return (*IfInfomsg)(unsafe.Pointer(&b[0:unix.SizeofIfInfomsg][0]))
   189  }
   190  
   191  func (msg *IfInfomsg) Serialize() []byte {
   192  	return (*(*[unix.SizeofIfInfomsg]byte)(unsafe.Pointer(msg)))[:]
   193  }
   194  
   195  func (msg *IfInfomsg) Len() int {
   196  	return unix.SizeofIfInfomsg
   197  }
   198  
   199  func (msg *IfInfomsg) EncapType() string {
   200  	switch msg.Type {
   201  	case 0:
   202  		return "generic"
   203  	case unix.ARPHRD_ETHER:
   204  		return "ether"
   205  	case unix.ARPHRD_EETHER:
   206  		return "eether"
   207  	case unix.ARPHRD_AX25:
   208  		return "ax25"
   209  	case unix.ARPHRD_PRONET:
   210  		return "pronet"
   211  	case unix.ARPHRD_CHAOS:
   212  		return "chaos"
   213  	case unix.ARPHRD_IEEE802:
   214  		return "ieee802"
   215  	case unix.ARPHRD_ARCNET:
   216  		return "arcnet"
   217  	case unix.ARPHRD_APPLETLK:
   218  		return "atalk"
   219  	case unix.ARPHRD_DLCI:
   220  		return "dlci"
   221  	case unix.ARPHRD_ATM:
   222  		return "atm"
   223  	case unix.ARPHRD_METRICOM:
   224  		return "metricom"
   225  	case unix.ARPHRD_IEEE1394:
   226  		return "ieee1394"
   227  	case unix.ARPHRD_INFINIBAND:
   228  		return "infiniband"
   229  	case unix.ARPHRD_SLIP:
   230  		return "slip"
   231  	case unix.ARPHRD_CSLIP:
   232  		return "cslip"
   233  	case unix.ARPHRD_SLIP6:
   234  		return "slip6"
   235  	case unix.ARPHRD_CSLIP6:
   236  		return "cslip6"
   237  	case unix.ARPHRD_RSRVD:
   238  		return "rsrvd"
   239  	case unix.ARPHRD_ADAPT:
   240  		return "adapt"
   241  	case unix.ARPHRD_ROSE:
   242  		return "rose"
   243  	case unix.ARPHRD_X25:
   244  		return "x25"
   245  	case unix.ARPHRD_HWX25:
   246  		return "hwx25"
   247  	case unix.ARPHRD_PPP:
   248  		return "ppp"
   249  	case unix.ARPHRD_HDLC:
   250  		return "hdlc"
   251  	case unix.ARPHRD_LAPB:
   252  		return "lapb"
   253  	case unix.ARPHRD_DDCMP:
   254  		return "ddcmp"
   255  	case unix.ARPHRD_RAWHDLC:
   256  		return "rawhdlc"
   257  	case unix.ARPHRD_TUNNEL:
   258  		return "ipip"
   259  	case unix.ARPHRD_TUNNEL6:
   260  		return "tunnel6"
   261  	case unix.ARPHRD_FRAD:
   262  		return "frad"
   263  	case unix.ARPHRD_SKIP:
   264  		return "skip"
   265  	case unix.ARPHRD_LOOPBACK:
   266  		return "loopback"
   267  	case unix.ARPHRD_LOCALTLK:
   268  		return "ltalk"
   269  	case unix.ARPHRD_FDDI:
   270  		return "fddi"
   271  	case unix.ARPHRD_BIF:
   272  		return "bif"
   273  	case unix.ARPHRD_SIT:
   274  		return "sit"
   275  	case unix.ARPHRD_IPDDP:
   276  		return "ip/ddp"
   277  	case unix.ARPHRD_IPGRE:
   278  		return "gre"
   279  	case unix.ARPHRD_PIMREG:
   280  		return "pimreg"
   281  	case unix.ARPHRD_HIPPI:
   282  		return "hippi"
   283  	case unix.ARPHRD_ASH:
   284  		return "ash"
   285  	case unix.ARPHRD_ECONET:
   286  		return "econet"
   287  	case unix.ARPHRD_IRDA:
   288  		return "irda"
   289  	case unix.ARPHRD_FCPP:
   290  		return "fcpp"
   291  	case unix.ARPHRD_FCAL:
   292  		return "fcal"
   293  	case unix.ARPHRD_FCPL:
   294  		return "fcpl"
   295  	case unix.ARPHRD_FCFABRIC:
   296  		return "fcfb0"
   297  	case unix.ARPHRD_FCFABRIC + 1:
   298  		return "fcfb1"
   299  	case unix.ARPHRD_FCFABRIC + 2:
   300  		return "fcfb2"
   301  	case unix.ARPHRD_FCFABRIC + 3:
   302  		return "fcfb3"
   303  	case unix.ARPHRD_FCFABRIC + 4:
   304  		return "fcfb4"
   305  	case unix.ARPHRD_FCFABRIC + 5:
   306  		return "fcfb5"
   307  	case unix.ARPHRD_FCFABRIC + 6:
   308  		return "fcfb6"
   309  	case unix.ARPHRD_FCFABRIC + 7:
   310  		return "fcfb7"
   311  	case unix.ARPHRD_FCFABRIC + 8:
   312  		return "fcfb8"
   313  	case unix.ARPHRD_FCFABRIC + 9:
   314  		return "fcfb9"
   315  	case unix.ARPHRD_FCFABRIC + 10:
   316  		return "fcfb10"
   317  	case unix.ARPHRD_FCFABRIC + 11:
   318  		return "fcfb11"
   319  	case unix.ARPHRD_FCFABRIC + 12:
   320  		return "fcfb12"
   321  	case unix.ARPHRD_IEEE802_TR:
   322  		return "tr"
   323  	case unix.ARPHRD_IEEE80211:
   324  		return "ieee802.11"
   325  	case unix.ARPHRD_IEEE80211_PRISM:
   326  		return "ieee802.11/prism"
   327  	case unix.ARPHRD_IEEE80211_RADIOTAP:
   328  		return "ieee802.11/radiotap"
   329  	case unix.ARPHRD_IEEE802154:
   330  		return "ieee802.15.4"
   331  
   332  	case 65534:
   333  		return "none"
   334  	case 65535:
   335  		return "void"
   336  	}
   337  	return fmt.Sprintf("unknown%d", msg.Type)
   338  }
   339  
   340  // Round the length of a netlink message up to align it properly.
   341  // Taken from syscall/netlink_linux.go by The Go Authors under BSD-style license.
   342  func nlmAlignOf(msglen int) int {
   343  	return (msglen + syscall.NLMSG_ALIGNTO - 1) & ^(syscall.NLMSG_ALIGNTO - 1)
   344  }
   345  
   346  func rtaAlignOf(attrlen int) int {
   347  	return (attrlen + unix.RTA_ALIGNTO - 1) & ^(unix.RTA_ALIGNTO - 1)
   348  }
   349  
   350  func NewIfInfomsgChild(parent *RtAttr, family int) *IfInfomsg {
   351  	msg := NewIfInfomsg(family)
   352  	parent.children = append(parent.children, msg)
   353  	return msg
   354  }
   355  
   356  type Uint32Bitfield struct {
   357  	Value    uint32
   358  	Selector uint32
   359  }
   360  
   361  func (a *Uint32Bitfield) Serialize() []byte {
   362  	return (*(*[SizeofUint32Bitfield]byte)(unsafe.Pointer(a)))[:]
   363  }
   364  
   365  func DeserializeUint32Bitfield(data []byte) *Uint32Bitfield {
   366  	return (*Uint32Bitfield)(unsafe.Pointer(&data[0:SizeofUint32Bitfield][0]))
   367  }
   368  
   369  type Uint32Attribute struct {
   370  	Type  uint16
   371  	Value uint32
   372  }
   373  
   374  func (a *Uint32Attribute) Serialize() []byte {
   375  	native := NativeEndian()
   376  	buf := make([]byte, rtaAlignOf(8))
   377  	native.PutUint16(buf[0:2], 8)
   378  	native.PutUint16(buf[2:4], a.Type)
   379  
   380  	if a.Type&NLA_F_NET_BYTEORDER != 0 {
   381  		binary.BigEndian.PutUint32(buf[4:], a.Value)
   382  	} else {
   383  		native.PutUint32(buf[4:], a.Value)
   384  	}
   385  	return buf
   386  }
   387  
   388  func (a *Uint32Attribute) Len() int {
   389  	return 8
   390  }
   391  
   392  // Extend RtAttr to handle data and children
   393  type RtAttr struct {
   394  	unix.RtAttr
   395  	Data     []byte
   396  	children []NetlinkRequestData
   397  }
   398  
   399  // Create a new Extended RtAttr object
   400  func NewRtAttr(attrType int, data []byte) *RtAttr {
   401  	return &RtAttr{
   402  		RtAttr: unix.RtAttr{
   403  			Type: uint16(attrType),
   404  		},
   405  		children: []NetlinkRequestData{},
   406  		Data:     data,
   407  	}
   408  }
   409  
   410  // NewRtAttrChild adds an RtAttr as a child to the parent and returns the new attribute
   411  //
   412  // Deprecated: Use AddRtAttr() on the parent object
   413  func NewRtAttrChild(parent *RtAttr, attrType int, data []byte) *RtAttr {
   414  	return parent.AddRtAttr(attrType, data)
   415  }
   416  
   417  // AddRtAttr adds an RtAttr as a child and returns the new attribute
   418  func (a *RtAttr) AddRtAttr(attrType int, data []byte) *RtAttr {
   419  	attr := NewRtAttr(attrType, data)
   420  	a.children = append(a.children, attr)
   421  	return attr
   422  }
   423  
   424  // AddChild adds an existing NetlinkRequestData as a child.
   425  func (a *RtAttr) AddChild(attr NetlinkRequestData) {
   426  	a.children = append(a.children, attr)
   427  }
   428  
   429  func (a *RtAttr) Len() int {
   430  	if len(a.children) == 0 {
   431  		return (unix.SizeofRtAttr + len(a.Data))
   432  	}
   433  
   434  	l := 0
   435  	for _, child := range a.children {
   436  		l += rtaAlignOf(child.Len())
   437  	}
   438  	l += unix.SizeofRtAttr
   439  	return rtaAlignOf(l + len(a.Data))
   440  }
   441  
   442  // Serialize the RtAttr into a byte array
   443  // This can't just unsafe.cast because it must iterate through children.
   444  func (a *RtAttr) Serialize() []byte {
   445  	native := NativeEndian()
   446  
   447  	length := a.Len()
   448  	buf := make([]byte, rtaAlignOf(length))
   449  
   450  	next := 4
   451  	if a.Data != nil {
   452  		copy(buf[next:], a.Data)
   453  		next += rtaAlignOf(len(a.Data))
   454  	}
   455  	if len(a.children) > 0 {
   456  		for _, child := range a.children {
   457  			childBuf := child.Serialize()
   458  			copy(buf[next:], childBuf)
   459  			next += rtaAlignOf(len(childBuf))
   460  		}
   461  	}
   462  
   463  	if l := uint16(length); l != 0 {
   464  		native.PutUint16(buf[0:2], l)
   465  	}
   466  	native.PutUint16(buf[2:4], a.Type)
   467  	return buf
   468  }
   469  
   470  type NetlinkRequest struct {
   471  	unix.NlMsghdr
   472  	Data    []NetlinkRequestData
   473  	RawData []byte
   474  	Sockets map[int]*SocketHandle
   475  }
   476  
   477  // Serialize the Netlink Request into a byte array
   478  func (req *NetlinkRequest) Serialize() []byte {
   479  	length := unix.SizeofNlMsghdr
   480  	dataBytes := make([][]byte, len(req.Data))
   481  	for i, data := range req.Data {
   482  		dataBytes[i] = data.Serialize()
   483  		length = length + len(dataBytes[i])
   484  	}
   485  	length += len(req.RawData)
   486  
   487  	req.Len = uint32(length)
   488  	b := make([]byte, length)
   489  	hdr := (*(*[unix.SizeofNlMsghdr]byte)(unsafe.Pointer(req)))[:]
   490  	next := unix.SizeofNlMsghdr
   491  	copy(b[0:next], hdr)
   492  	for _, data := range dataBytes {
   493  		for _, dataByte := range data {
   494  			b[next] = dataByte
   495  			next = next + 1
   496  		}
   497  	}
   498  	// Add the raw data if any
   499  	if len(req.RawData) > 0 {
   500  		copy(b[next:length], req.RawData)
   501  	}
   502  	return b
   503  }
   504  
   505  func (req *NetlinkRequest) AddData(data NetlinkRequestData) {
   506  	req.Data = append(req.Data, data)
   507  }
   508  
   509  // AddRawData adds raw bytes to the end of the NetlinkRequest object during serialization
   510  func (req *NetlinkRequest) AddRawData(data []byte) {
   511  	req.RawData = append(req.RawData, data...)
   512  }
   513  
   514  // Execute the request against the given sockType.
   515  // Returns a list of netlink messages in serialized format, optionally filtered
   516  // by resType.
   517  // If the returned error is [ErrDumpInterrupted], results may be inconsistent
   518  // or incomplete.
   519  func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, error) {
   520  	var res [][]byte
   521  	err := req.ExecuteIter(sockType, resType, func(msg []byte) bool {
   522  		res = append(res, msg)
   523  		return true
   524  	})
   525  	if err != nil && !errors.Is(err, ErrDumpInterrupted) {
   526  		return nil, err
   527  	}
   528  	return res, err
   529  }
   530  
   531  // ExecuteIter executes the request against the given sockType.
   532  // Calls the provided callback func once for each netlink message.
   533  // If the callback returns false, it is not called again, but
   534  // the remaining messages are consumed/discarded.
   535  // If the returned error is [ErrDumpInterrupted], results may be inconsistent
   536  // or incomplete.
   537  //
   538  // Thread safety: ExecuteIter holds a lock on the socket until
   539  // it finishes iteration so the callback must not call back into
   540  // the netlink API.
   541  func (req *NetlinkRequest) ExecuteIter(sockType int, resType uint16, f func(msg []byte) bool) error {
   542  	var (
   543  		s   *NetlinkSocket
   544  		err error
   545  	)
   546  
   547  	if req.Sockets != nil {
   548  		if sh, ok := req.Sockets[sockType]; ok {
   549  			s = sh.Socket
   550  			req.Seq = atomic.AddUint32(&sh.Seq, 1)
   551  		}
   552  	}
   553  	sharedSocket := s != nil
   554  
   555  	if s == nil {
   556  		s, err = getNetlinkSocket(sockType)
   557  		if err != nil {
   558  			return err
   559  		}
   560  
   561  		if err := s.SetSendTimeout(&SocketTimeoutTv); err != nil {
   562  			return err
   563  		}
   564  		if err := s.SetReceiveTimeout(&SocketTimeoutTv); err != nil {
   565  			return err
   566  		}
   567  		if EnableErrorMessageReporting {
   568  			if err := s.SetExtAck(true); err != nil {
   569  				return err
   570  			}
   571  		}
   572  
   573  		defer s.Close()
   574  	} else {
   575  		s.Lock()
   576  		defer s.Unlock()
   577  	}
   578  
   579  	if err := s.Send(req); err != nil {
   580  		return err
   581  	}
   582  
   583  	pid, err := s.GetPid()
   584  	if err != nil {
   585  		return err
   586  	}
   587  
   588  	dumpIntr := false
   589  
   590  done:
   591  	for {
   592  		msgs, from, err := s.Receive()
   593  		if err != nil {
   594  			return err
   595  		}
   596  		if from.Pid != PidKernel {
   597  			return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel)
   598  		}
   599  		for _, m := range msgs {
   600  			if m.Header.Seq != req.Seq {
   601  				if sharedSocket {
   602  					continue
   603  				}
   604  				return fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
   605  			}
   606  			if m.Header.Pid != pid {
   607  				continue
   608  			}
   609  
   610  			if m.Header.Flags&unix.NLM_F_DUMP_INTR != 0 {
   611  				dumpIntr = true
   612  			}
   613  
   614  			if m.Header.Type == unix.NLMSG_DONE || m.Header.Type == unix.NLMSG_ERROR {
   615  				// NLMSG_DONE might have no payload, if so assume no error.
   616  				if m.Header.Type == unix.NLMSG_DONE && len(m.Data) == 0 {
   617  					break done
   618  				}
   619  
   620  				native := NativeEndian()
   621  				errno := int32(native.Uint32(m.Data[0:4]))
   622  				if errno == 0 {
   623  					break done
   624  				}
   625  				var err error
   626  				err = syscall.Errno(-errno)
   627  
   628  				unreadData := m.Data[4:]
   629  				if m.Header.Flags&unix.NLM_F_ACK_TLVS != 0 && len(unreadData) > syscall.SizeofNlMsghdr {
   630  					// Skip the echoed request message.
   631  					echoReqH := (*syscall.NlMsghdr)(unsafe.Pointer(&unreadData[0]))
   632  					unreadData = unreadData[nlmAlignOf(int(echoReqH.Len)):]
   633  
   634  					// Annotate `err` using nlmsgerr attributes.
   635  					for len(unreadData) >= syscall.SizeofRtAttr {
   636  						attr := (*syscall.RtAttr)(unsafe.Pointer(&unreadData[0]))
   637  						attrData := unreadData[syscall.SizeofRtAttr:attr.Len]
   638  
   639  						switch attr.Type {
   640  						case NLMSGERR_ATTR_MSG:
   641  							err = fmt.Errorf("%w: %s", err, unix.ByteSliceToString(attrData))
   642  						default:
   643  							// TODO: handle other NLMSGERR_ATTR types
   644  						}
   645  
   646  						unreadData = unreadData[rtaAlignOf(int(attr.Len)):]
   647  					}
   648  				}
   649  
   650  				return err
   651  			}
   652  			if resType != 0 && m.Header.Type != resType {
   653  				continue
   654  			}
   655  			if cont := f(m.Data); !cont {
   656  				// Drain the rest of the messages from the kernel but don't
   657  				// pass them to the iterator func.
   658  				f = dummyMsgIterFunc
   659  			}
   660  			if m.Header.Flags&unix.NLM_F_MULTI == 0 {
   661  				break done
   662  			}
   663  		}
   664  	}
   665  	if dumpIntr {
   666  		return ErrDumpInterrupted
   667  	}
   668  	return nil
   669  }
   670  
   671  func dummyMsgIterFunc(msg []byte) bool {
   672  	return true
   673  }
   674  
   675  // Create a new netlink request from proto and flags
   676  // Note the Len value will be inaccurate once data is added until
   677  // the message is serialized
   678  func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
   679  	return &NetlinkRequest{
   680  		NlMsghdr: unix.NlMsghdr{
   681  			Len:   uint32(unix.SizeofNlMsghdr),
   682  			Type:  uint16(proto),
   683  			Flags: unix.NLM_F_REQUEST | uint16(flags),
   684  			Seq:   atomic.AddUint32(&nextSeqNr, 1),
   685  		},
   686  	}
   687  }
   688  
   689  type NetlinkSocket struct {
   690  	fd             int32
   691  	file           *os.File
   692  	lsa            unix.SockaddrNetlink
   693  	sendTimeout    int64 // Access using atomic.Load/StoreInt64
   694  	receiveTimeout int64 // Access using atomic.Load/StoreInt64
   695  	sync.Mutex
   696  }
   697  
   698  func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {
   699  	fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, protocol)
   700  	if err != nil {
   701  		return nil, err
   702  	}
   703  	err = unix.SetNonblock(fd, true)
   704  	if err != nil {
   705  		return nil, err
   706  	}
   707  	s := &NetlinkSocket{
   708  		fd:   int32(fd),
   709  		file: os.NewFile(uintptr(fd), "netlink"),
   710  	}
   711  	s.lsa.Family = unix.AF_NETLINK
   712  	if err := unix.Bind(fd, &s.lsa); err != nil {
   713  		unix.Close(fd)
   714  		return nil, err
   715  	}
   716  
   717  	return s, nil
   718  }
   719  
   720  // GetNetlinkSocketAt opens a netlink socket in the network namespace newNs
   721  // and positions the thread back into the network namespace specified by curNs,
   722  // when done. If curNs is close, the function derives the current namespace and
   723  // moves back into it when done. If newNs is close, the socket will be opened
   724  // in the current network namespace.
   725  func GetNetlinkSocketAt(newNs, curNs netns.NsHandle, protocol int) (*NetlinkSocket, error) {
   726  	c, err := executeInNetns(newNs, curNs)
   727  	if err != nil {
   728  		return nil, err
   729  	}
   730  	defer c()
   731  	return getNetlinkSocket(protocol)
   732  }
   733  
   734  // executeInNetns sets execution of the code following this call to the
   735  // network namespace newNs, then moves the thread back to curNs if open,
   736  // otherwise to the current netns at the time the function was invoked
   737  // In case of success, the caller is expected to execute the returned function
   738  // at the end of the code that needs to be executed in the network namespace.
   739  // Example:
   740  //
   741  //	func jobAt(...) error {
   742  //	     d, err := executeInNetns(...)
   743  //	     if err != nil { return err}
   744  //	     defer d()
   745  //	     < code which needs to be executed in specific netns>
   746  //	 }
   747  //
   748  // TODO: his function probably belongs to netns pkg.
   749  func executeInNetns(newNs, curNs netns.NsHandle) (func(), error) {
   750  	var (
   751  		err       error
   752  		moveBack  func(netns.NsHandle) error
   753  		closeNs   func() error
   754  		unlockThd func()
   755  	)
   756  	restore := func() {
   757  		// order matters
   758  		if moveBack != nil {
   759  			moveBack(curNs)
   760  		}
   761  		if closeNs != nil {
   762  			closeNs()
   763  		}
   764  		if unlockThd != nil {
   765  			unlockThd()
   766  		}
   767  	}
   768  	if newNs.IsOpen() {
   769  		runtime.LockOSThread()
   770  		unlockThd = runtime.UnlockOSThread
   771  		if !curNs.IsOpen() {
   772  			if curNs, err = netns.Get(); err != nil {
   773  				restore()
   774  				return nil, fmt.Errorf("could not get current namespace while creating netlink socket: %v", err)
   775  			}
   776  			closeNs = curNs.Close
   777  		}
   778  		if err := netns.Set(newNs); err != nil {
   779  			restore()
   780  			return nil, fmt.Errorf("failed to set into network namespace %d while creating netlink socket: %v", newNs, err)
   781  		}
   782  		moveBack = netns.Set
   783  	}
   784  	return restore, nil
   785  }
   786  
   787  // Create a netlink socket with a given protocol (e.g. NETLINK_ROUTE)
   788  // and subscribe it to multicast groups passed in variable argument list.
   789  // Returns the netlink socket on which Receive() method can be called
   790  // to retrieve the messages from the kernel.
   791  func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) {
   792  	fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, protocol)
   793  	if err != nil {
   794  		return nil, err
   795  	}
   796  	err = unix.SetNonblock(fd, true)
   797  	if err != nil {
   798  		return nil, err
   799  	}
   800  	s := &NetlinkSocket{
   801  		fd:   int32(fd),
   802  		file: os.NewFile(uintptr(fd), "netlink"),
   803  	}
   804  	s.lsa.Family = unix.AF_NETLINK
   805  
   806  	for _, g := range groups {
   807  		s.lsa.Groups |= (1 << (g - 1))
   808  	}
   809  
   810  	if err := unix.Bind(fd, &s.lsa); err != nil {
   811  		unix.Close(fd)
   812  		return nil, err
   813  	}
   814  
   815  	return s, nil
   816  }
   817  
   818  // SubscribeAt works like Subscribe plus let's the caller choose the network
   819  // namespace in which the socket would be opened (newNs). Then control goes back
   820  // to curNs if open, otherwise to the netns at the time this function was called.
   821  func SubscribeAt(newNs, curNs netns.NsHandle, protocol int, groups ...uint) (*NetlinkSocket, error) {
   822  	c, err := executeInNetns(newNs, curNs)
   823  	if err != nil {
   824  		return nil, err
   825  	}
   826  	defer c()
   827  	return Subscribe(protocol, groups...)
   828  }
   829  
   830  func (s *NetlinkSocket) Close() {
   831  	s.file.Close()
   832  }
   833  
   834  func (s *NetlinkSocket) GetFd() int {
   835  	return int(s.fd)
   836  }
   837  
   838  func (s *NetlinkSocket) GetTimeouts() (send, receive time.Duration) {
   839  	return time.Duration(atomic.LoadInt64(&s.sendTimeout)),
   840  		time.Duration(atomic.LoadInt64(&s.receiveTimeout))
   841  }
   842  
   843  func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
   844  	rawConn, err := s.file.SyscallConn()
   845  	if err != nil {
   846  		return err
   847  	}
   848  	var (
   849  		deadline time.Time
   850  		innerErr error
   851  	)
   852  	sendTimeout := atomic.LoadInt64(&s.sendTimeout)
   853  	if sendTimeout != 0 {
   854  		deadline = time.Now().Add(time.Duration(sendTimeout))
   855  	}
   856  	if err := s.file.SetWriteDeadline(deadline); err != nil {
   857  		return err
   858  	}
   859  	serializedReq := request.Serialize()
   860  	err = rawConn.Write(func(fd uintptr) (done bool) {
   861  		innerErr = unix.Sendto(int(s.fd), serializedReq, 0, &s.lsa)
   862  		return innerErr != unix.EWOULDBLOCK
   863  	})
   864  	if innerErr != nil {
   865  		return innerErr
   866  	}
   867  	if err != nil {
   868  		// The timeout was previously implemented using SO_SNDTIMEO on a blocking
   869  		// socket. So, continue to return EAGAIN when the timeout is reached.
   870  		if errors.Is(err, os.ErrDeadlineExceeded) {
   871  			return unix.EAGAIN
   872  		}
   873  		return err
   874  	}
   875  	return nil
   876  }
   877  
   878  func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) {
   879  	rawConn, err := s.file.SyscallConn()
   880  	if err != nil {
   881  		return nil, nil, err
   882  	}
   883  	var (
   884  		deadline time.Time
   885  		fromAddr *unix.SockaddrNetlink
   886  		rb       [RECEIVE_BUFFER_SIZE]byte
   887  		nr       int
   888  		from     unix.Sockaddr
   889  		innerErr error
   890  	)
   891  	receiveTimeout := atomic.LoadInt64(&s.receiveTimeout)
   892  	if receiveTimeout != 0 {
   893  		deadline = time.Now().Add(time.Duration(receiveTimeout))
   894  	}
   895  	if err := s.file.SetReadDeadline(deadline); err != nil {
   896  		return nil, nil, err
   897  	}
   898  	err = rawConn.Read(func(fd uintptr) (done bool) {
   899  		nr, from, innerErr = unix.Recvfrom(int(fd), rb[:], 0)
   900  		return innerErr != unix.EWOULDBLOCK
   901  	})
   902  	if innerErr != nil {
   903  		return nil, nil, innerErr
   904  	}
   905  	if err != nil {
   906  		// The timeout was previously implemented using SO_RCVTIMEO on a blocking
   907  		// socket. So, continue to return EAGAIN when the timeout is reached.
   908  		if errors.Is(err, os.ErrDeadlineExceeded) {
   909  			return nil, nil, unix.EAGAIN
   910  		}
   911  		return nil, nil, err
   912  	}
   913  	fromAddr, ok := from.(*unix.SockaddrNetlink)
   914  	if !ok {
   915  		return nil, nil, fmt.Errorf("Error converting to netlink sockaddr")
   916  	}
   917  	if nr < unix.NLMSG_HDRLEN {
   918  		return nil, nil, fmt.Errorf("Got short response from netlink")
   919  	}
   920  	msgLen := nlmAlignOf(nr)
   921  	rb2 := make([]byte, msgLen)
   922  	copy(rb2, rb[:msgLen])
   923  	nl, err := syscall.ParseNetlinkMessage(rb2)
   924  	if err != nil {
   925  		return nil, nil, err
   926  	}
   927  	return nl, fromAddr, nil
   928  }
   929  
   930  // SetSendTimeout allows to set a send timeout on the socket
   931  func (s *NetlinkSocket) SetSendTimeout(timeout *unix.Timeval) error {
   932  	atomic.StoreInt64(&s.sendTimeout, timeout.Nano())
   933  	return nil
   934  }
   935  
   936  // SetReceiveTimeout allows to set a receive timeout on the socket
   937  func (s *NetlinkSocket) SetReceiveTimeout(timeout *unix.Timeval) error {
   938  	atomic.StoreInt64(&s.receiveTimeout, timeout.Nano())
   939  	return nil
   940  }
   941  
   942  // SetReceiveBufferSize allows to set a receive buffer size on the socket
   943  func (s *NetlinkSocket) SetReceiveBufferSize(size int, force bool) error {
   944  	opt := unix.SO_RCVBUF
   945  	if force {
   946  		opt = unix.SO_RCVBUFFORCE
   947  	}
   948  	return unix.SetsockoptInt(int(s.fd), unix.SOL_SOCKET, opt, size)
   949  }
   950  
   951  // SetExtAck requests error messages to be reported on the socket
   952  func (s *NetlinkSocket) SetExtAck(enable bool) error {
   953  	var enableN int
   954  	if enable {
   955  		enableN = 1
   956  	}
   957  
   958  	return unix.SetsockoptInt(int(s.fd), unix.SOL_NETLINK, unix.NETLINK_EXT_ACK, enableN)
   959  }
   960  
   961  func (s *NetlinkSocket) GetPid() (uint32, error) {
   962  	lsa, err := unix.Getsockname(int(s.fd))
   963  	if err != nil {
   964  		return 0, err
   965  	}
   966  	switch v := lsa.(type) {
   967  	case *unix.SockaddrNetlink:
   968  		return v.Pid, nil
   969  	}
   970  	return 0, fmt.Errorf("Wrong socket type")
   971  }
   972  
   973  func ZeroTerminated(s string) []byte {
   974  	bytes := make([]byte, len(s)+1)
   975  	for i := 0; i < len(s); i++ {
   976  		bytes[i] = s[i]
   977  	}
   978  	bytes[len(s)] = 0
   979  	return bytes
   980  }
   981  
   982  func NonZeroTerminated(s string) []byte {
   983  	bytes := make([]byte, len(s))
   984  	for i := 0; i < len(s); i++ {
   985  		bytes[i] = s[i]
   986  	}
   987  	return bytes
   988  }
   989  
   990  func BytesToString(b []byte) string {
   991  	n := bytes.Index(b, []byte{0})
   992  	return string(b[:n])
   993  }
   994  
   995  func Uint8Attr(v uint8) []byte {
   996  	return []byte{byte(v)}
   997  }
   998  
   999  func Uint16Attr(v uint16) []byte {
  1000  	native := NativeEndian()
  1001  	bytes := make([]byte, 2)
  1002  	native.PutUint16(bytes, v)
  1003  	return bytes
  1004  }
  1005  
  1006  func BEUint16Attr(v uint16) []byte {
  1007  	bytes := make([]byte, 2)
  1008  	binary.BigEndian.PutUint16(bytes, v)
  1009  	return bytes
  1010  }
  1011  
  1012  func Uint32Attr(v uint32) []byte {
  1013  	native := NativeEndian()
  1014  	bytes := make([]byte, 4)
  1015  	native.PutUint32(bytes, v)
  1016  	return bytes
  1017  }
  1018  
  1019  func BEUint32Attr(v uint32) []byte {
  1020  	bytes := make([]byte, 4)
  1021  	binary.BigEndian.PutUint32(bytes, v)
  1022  	return bytes
  1023  }
  1024  
  1025  func Uint64Attr(v uint64) []byte {
  1026  	native := NativeEndian()
  1027  	bytes := make([]byte, 8)
  1028  	native.PutUint64(bytes, v)
  1029  	return bytes
  1030  }
  1031  
  1032  func BEUint64Attr(v uint64) []byte {
  1033  	bytes := make([]byte, 8)
  1034  	binary.BigEndian.PutUint64(bytes, v)
  1035  	return bytes
  1036  }
  1037  
  1038  func ParseRouteAttr(b []byte) ([]syscall.NetlinkRouteAttr, error) {
  1039  	var attrs []syscall.NetlinkRouteAttr
  1040  	for len(b) >= unix.SizeofRtAttr {
  1041  		a, vbuf, alen, err := netlinkRouteAttrAndValue(b)
  1042  		if err != nil {
  1043  			return nil, err
  1044  		}
  1045  		ra := syscall.NetlinkRouteAttr{Attr: syscall.RtAttr(*a), Value: vbuf[:int(a.Len)-unix.SizeofRtAttr]}
  1046  		attrs = append(attrs, ra)
  1047  		b = b[alen:]
  1048  	}
  1049  	return attrs, nil
  1050  }
  1051  
  1052  // ParseRouteAttrAsMap parses provided buffer that contains raw RtAttrs and returns a map of parsed
  1053  // atttributes indexed by attribute type or error if occured.
  1054  func ParseRouteAttrAsMap(b []byte) (map[uint16]syscall.NetlinkRouteAttr, error) {
  1055  	attrMap := make(map[uint16]syscall.NetlinkRouteAttr)
  1056  
  1057  	attrs, err := ParseRouteAttr(b)
  1058  	if err != nil {
  1059  		return nil, err
  1060  	}
  1061  
  1062  	for _, attr := range attrs {
  1063  		attrMap[attr.Attr.Type] = attr
  1064  	}
  1065  	return attrMap, nil
  1066  }
  1067  
  1068  func netlinkRouteAttrAndValue(b []byte) (*unix.RtAttr, []byte, int, error) {
  1069  	a := (*unix.RtAttr)(unsafe.Pointer(&b[0]))
  1070  	if int(a.Len) < unix.SizeofRtAttr || int(a.Len) > len(b) {
  1071  		return nil, nil, 0, unix.EINVAL
  1072  	}
  1073  	return a, b[unix.SizeofRtAttr:], rtaAlignOf(int(a.Len)), nil
  1074  }
  1075  
  1076  // SocketHandle contains the netlink socket and the associated
  1077  // sequence counter for a specific netlink family
  1078  type SocketHandle struct {
  1079  	Seq    uint32
  1080  	Socket *NetlinkSocket
  1081  }
  1082  
  1083  // Close closes the netlink socket
  1084  func (sh *SocketHandle) Close() {
  1085  	if sh.Socket != nil {
  1086  		sh.Socket.Close()
  1087  	}
  1088  }