github.com/vishvananda/netlink@v1.3.0/nl/nl_linux.go (about)

     1  // Package nl has low level primitives for making Netlink calls.
     2  package nl
     3  
     4  import (
     5  	"bytes"
     6  	"encoding/binary"
     7  	"fmt"
     8  	"net"
     9  	"os"
    10  	"runtime"
    11  	"sync"
    12  	"sync/atomic"
    13  	"syscall"
    14  	"unsafe"
    15  
    16  	"github.com/vishvananda/netns"
    17  	"golang.org/x/sys/unix"
    18  )
    19  
const (
	// Family type definitions: address families accepted by the helpers in
	// this package (GetIPFamily returns FAMILY_V4 or FAMILY_V6).
	FAMILY_ALL  = unix.AF_UNSPEC
	FAMILY_V4   = unix.AF_INET
	FAMILY_V6   = unix.AF_INET6
	FAMILY_MPLS = unix.AF_MPLS
	// Arbitrary set value (greater than default 4k) to allow receiving
	// from kernel more verbose messages e.g. for statistics,
	// tc rules or filters, or other more memory requiring data.
	// NetlinkSocket.Receive reads each datagram into a buffer of this size.
	RECEIVE_BUFFER_SIZE = 65536
	// Kernel netlink pid, and the size in bytes of a CnMsgOp.
	PidKernel     uint32 = 0    // ExecuteIter rejects replies whose sender port ID differs from this
	SizeofCnMsgOp        = 0x18 // 20 bytes of CnMsg plus the 4-byte Op field
)
    34  
// SupportedNlFamilies contains the list of netlink families this netlink package supports
var SupportedNlFamilies = []int{unix.NETLINK_ROUTE, unix.NETLINK_XFRM, unix.NETLINK_NETFILTER}

// nextSeqNr is the package-wide sequence counter. NewNetlinkRequest
// atomically increments it to number each request it creates.
var nextSeqNr uint32

// SocketTimeoutTv is the default netlink socket timeout, 60s. ExecuteIter
// applies it as both the send and the receive timeout on sockets it opens
// itself.
var SocketTimeoutTv = unix.Timeval{Sec: 60, Usec: 0}

// EnableErrorMessageReporting is the default error message reporting
// configuration for new netlink sockets: when true, ExecuteIter calls
// SetExtAck(true) on every socket it opens itself.
var EnableErrorMessageReporting bool = false
    45  
    46  // GetIPFamily returns the family type of a net.IP.
    47  func GetIPFamily(ip net.IP) int {
    48  	if len(ip) <= net.IPv4len {
    49  		return FAMILY_V4
    50  	}
    51  	if ip.To4() != nil {
    52  		return FAMILY_V4
    53  	}
    54  	return FAMILY_V6
    55  }
    56  
    57  var nativeEndian binary.ByteOrder
    58  
    59  // NativeEndian gets native endianness for the system
    60  func NativeEndian() binary.ByteOrder {
    61  	if nativeEndian == nil {
    62  		var x uint32 = 0x01020304
    63  		if *(*byte)(unsafe.Pointer(&x)) == 0x01 {
    64  			nativeEndian = binary.BigEndian
    65  		} else {
    66  			nativeEndian = binary.LittleEndian
    67  		}
    68  	}
    69  	return nativeEndian
    70  }
    71  
    72  // Byte swap a 16 bit value if we aren't big endian
    73  func Swap16(i uint16) uint16 {
    74  	if NativeEndian() == binary.BigEndian {
    75  		return i
    76  	}
    77  	return (i&0xff00)>>8 | (i&0xff)<<8
    78  }
    79  
    80  // Byte swap a 32 bit value if aren't big endian
    81  func Swap32(i uint32) uint32 {
    82  	if NativeEndian() == binary.BigEndian {
    83  		return i
    84  	}
    85  	return (i&0xff000000)>>24 | (i&0xff0000)>>8 | (i&0xff00)<<8 | (i&0xff)<<24
    86  }
    87  
// Attribute types found in the TLVs appended to a netlink error message when
// extended ACK reporting is enabled. ExecuteIter currently interprets only
// NLMSGERR_ATTR_MSG, whose payload it appends to the returned error as a
// string; the other types are skipped.
// NOTE(review): values presumably mirror enum nlmsgerr_attrs in
// linux/netlink.h — confirm against the kernel header.
const (
	NLMSGERR_ATTR_UNUSED = 0
	NLMSGERR_ATTR_MSG    = 1
	NLMSGERR_ATTR_OFFS   = 2
	NLMSGERR_ATTR_COOKIE = 3
	NLMSGERR_ATTR_POLICY = 4
)
    95  
// NetlinkRequestData is a piece of payload that can be attached to a
// NetlinkRequest (via AddData) or nested under an RtAttr (via AddChild).
type NetlinkRequestData interface {
	// Len returns the length in bytes that the data contributes; RtAttr.Len
	// rounds it up to the attribute alignment when summing children.
	Len() int
	// Serialize returns the wire representation of the data.
	Serialize() []byte
}
   100  
   101  const (
   102  	PROC_CN_MCAST_LISTEN = 1
   103  	PROC_CN_MCAST_IGNORE
   104  )
   105  
// CbID identifies a connector message by an (Idx, Val) pair; NewCnMsg fills
// both fields from its idx and val arguments.
// NOTE(review): presumably mirrors struct cb_id from linux/connector.h — confirm.
type CbID struct {
	Idx uint32
	Val uint32
}
   110  
// CnMsg is the 20-byte header embedded at the start of CnMsgOp.
// NewCnMsg initialises Seq to 1, Ack to 0 and Length to the size of the
// operation word that follows the header.
// NOTE(review): presumably mirrors struct cn_msg from linux/connector.h — confirm.
type CnMsg struct {
	ID     CbID
	Seq    uint32
	Ack    uint32
	Length uint16
	Flags  uint16
}
   118  
// CnMsgOp is a connector message header followed by a 32-bit operation
// code. Its in-memory size is SizeofCnMsgOp, and Serialize exposes that
// memory directly as the wire format.
type CnMsgOp struct {
	CnMsg
	// here we differ from the C header
	Op uint32
}
   124  
   125  func NewCnMsg(idx, val, op uint32) *CnMsgOp {
   126  	var cm CnMsgOp
   127  
   128  	cm.ID.Idx = idx
   129  	cm.ID.Val = val
   130  
   131  	cm.Ack = 0
   132  	cm.Seq = 1
   133  	cm.Length = uint16(binary.Size(op))
   134  	cm.Op = op
   135  
   136  	return &cm
   137  }
   138  
   139  func (msg *CnMsgOp) Serialize() []byte {
   140  	return (*(*[SizeofCnMsgOp]byte)(unsafe.Pointer(msg)))[:]
   141  }
   142  
   143  func DeserializeCnMsgOp(b []byte) *CnMsgOp {
   144  	return (*CnMsgOp)(unsafe.Pointer(&b[0:SizeofCnMsgOp][0]))
   145  }
   146  
   147  func (msg *CnMsgOp) Len() int {
   148  	return SizeofCnMsgOp
   149  }
   150  
// IfInfomsg is related to links, but it is used for list requests as well.
// It wraps unix.IfInfomsg so that it can implement NetlinkRequestData
// (Len and Serialize) and be attached to a NetlinkRequest.
type IfInfomsg struct {
	unix.IfInfomsg
}
   155  
   156  // Create an IfInfomsg with family specified
   157  func NewIfInfomsg(family int) *IfInfomsg {
   158  	return &IfInfomsg{
   159  		IfInfomsg: unix.IfInfomsg{
   160  			Family: uint8(family),
   161  		},
   162  	}
   163  }
   164  
   165  func DeserializeIfInfomsg(b []byte) *IfInfomsg {
   166  	return (*IfInfomsg)(unsafe.Pointer(&b[0:unix.SizeofIfInfomsg][0]))
   167  }
   168  
   169  func (msg *IfInfomsg) Serialize() []byte {
   170  	return (*(*[unix.SizeofIfInfomsg]byte)(unsafe.Pointer(msg)))[:]
   171  }
   172  
   173  func (msg *IfInfomsg) Len() int {
   174  	return unix.SizeofIfInfomsg
   175  }
   176  
   177  func (msg *IfInfomsg) EncapType() string {
   178  	switch msg.Type {
   179  	case 0:
   180  		return "generic"
   181  	case unix.ARPHRD_ETHER:
   182  		return "ether"
   183  	case unix.ARPHRD_EETHER:
   184  		return "eether"
   185  	case unix.ARPHRD_AX25:
   186  		return "ax25"
   187  	case unix.ARPHRD_PRONET:
   188  		return "pronet"
   189  	case unix.ARPHRD_CHAOS:
   190  		return "chaos"
   191  	case unix.ARPHRD_IEEE802:
   192  		return "ieee802"
   193  	case unix.ARPHRD_ARCNET:
   194  		return "arcnet"
   195  	case unix.ARPHRD_APPLETLK:
   196  		return "atalk"
   197  	case unix.ARPHRD_DLCI:
   198  		return "dlci"
   199  	case unix.ARPHRD_ATM:
   200  		return "atm"
   201  	case unix.ARPHRD_METRICOM:
   202  		return "metricom"
   203  	case unix.ARPHRD_IEEE1394:
   204  		return "ieee1394"
   205  	case unix.ARPHRD_INFINIBAND:
   206  		return "infiniband"
   207  	case unix.ARPHRD_SLIP:
   208  		return "slip"
   209  	case unix.ARPHRD_CSLIP:
   210  		return "cslip"
   211  	case unix.ARPHRD_SLIP6:
   212  		return "slip6"
   213  	case unix.ARPHRD_CSLIP6:
   214  		return "cslip6"
   215  	case unix.ARPHRD_RSRVD:
   216  		return "rsrvd"
   217  	case unix.ARPHRD_ADAPT:
   218  		return "adapt"
   219  	case unix.ARPHRD_ROSE:
   220  		return "rose"
   221  	case unix.ARPHRD_X25:
   222  		return "x25"
   223  	case unix.ARPHRD_HWX25:
   224  		return "hwx25"
   225  	case unix.ARPHRD_PPP:
   226  		return "ppp"
   227  	case unix.ARPHRD_HDLC:
   228  		return "hdlc"
   229  	case unix.ARPHRD_LAPB:
   230  		return "lapb"
   231  	case unix.ARPHRD_DDCMP:
   232  		return "ddcmp"
   233  	case unix.ARPHRD_RAWHDLC:
   234  		return "rawhdlc"
   235  	case unix.ARPHRD_TUNNEL:
   236  		return "ipip"
   237  	case unix.ARPHRD_TUNNEL6:
   238  		return "tunnel6"
   239  	case unix.ARPHRD_FRAD:
   240  		return "frad"
   241  	case unix.ARPHRD_SKIP:
   242  		return "skip"
   243  	case unix.ARPHRD_LOOPBACK:
   244  		return "loopback"
   245  	case unix.ARPHRD_LOCALTLK:
   246  		return "ltalk"
   247  	case unix.ARPHRD_FDDI:
   248  		return "fddi"
   249  	case unix.ARPHRD_BIF:
   250  		return "bif"
   251  	case unix.ARPHRD_SIT:
   252  		return "sit"
   253  	case unix.ARPHRD_IPDDP:
   254  		return "ip/ddp"
   255  	case unix.ARPHRD_IPGRE:
   256  		return "gre"
   257  	case unix.ARPHRD_PIMREG:
   258  		return "pimreg"
   259  	case unix.ARPHRD_HIPPI:
   260  		return "hippi"
   261  	case unix.ARPHRD_ASH:
   262  		return "ash"
   263  	case unix.ARPHRD_ECONET:
   264  		return "econet"
   265  	case unix.ARPHRD_IRDA:
   266  		return "irda"
   267  	case unix.ARPHRD_FCPP:
   268  		return "fcpp"
   269  	case unix.ARPHRD_FCAL:
   270  		return "fcal"
   271  	case unix.ARPHRD_FCPL:
   272  		return "fcpl"
   273  	case unix.ARPHRD_FCFABRIC:
   274  		return "fcfb0"
   275  	case unix.ARPHRD_FCFABRIC + 1:
   276  		return "fcfb1"
   277  	case unix.ARPHRD_FCFABRIC + 2:
   278  		return "fcfb2"
   279  	case unix.ARPHRD_FCFABRIC + 3:
   280  		return "fcfb3"
   281  	case unix.ARPHRD_FCFABRIC + 4:
   282  		return "fcfb4"
   283  	case unix.ARPHRD_FCFABRIC + 5:
   284  		return "fcfb5"
   285  	case unix.ARPHRD_FCFABRIC + 6:
   286  		return "fcfb6"
   287  	case unix.ARPHRD_FCFABRIC + 7:
   288  		return "fcfb7"
   289  	case unix.ARPHRD_FCFABRIC + 8:
   290  		return "fcfb8"
   291  	case unix.ARPHRD_FCFABRIC + 9:
   292  		return "fcfb9"
   293  	case unix.ARPHRD_FCFABRIC + 10:
   294  		return "fcfb10"
   295  	case unix.ARPHRD_FCFABRIC + 11:
   296  		return "fcfb11"
   297  	case unix.ARPHRD_FCFABRIC + 12:
   298  		return "fcfb12"
   299  	case unix.ARPHRD_IEEE802_TR:
   300  		return "tr"
   301  	case unix.ARPHRD_IEEE80211:
   302  		return "ieee802.11"
   303  	case unix.ARPHRD_IEEE80211_PRISM:
   304  		return "ieee802.11/prism"
   305  	case unix.ARPHRD_IEEE80211_RADIOTAP:
   306  		return "ieee802.11/radiotap"
   307  	case unix.ARPHRD_IEEE802154:
   308  		return "ieee802.15.4"
   309  
   310  	case 65534:
   311  		return "none"
   312  	case 65535:
   313  		return "void"
   314  	}
   315  	return fmt.Sprintf("unknown%d", msg.Type)
   316  }
   317  
   318  // Round the length of a netlink message up to align it properly.
   319  // Taken from syscall/netlink_linux.go by The Go Authors under BSD-style license.
   320  func nlmAlignOf(msglen int) int {
   321  	return (msglen + syscall.NLMSG_ALIGNTO - 1) & ^(syscall.NLMSG_ALIGNTO - 1)
   322  }
   323  
   324  func rtaAlignOf(attrlen int) int {
   325  	return (attrlen + unix.RTA_ALIGNTO - 1) & ^(unix.RTA_ALIGNTO - 1)
   326  }
   327  
   328  func NewIfInfomsgChild(parent *RtAttr, family int) *IfInfomsg {
   329  	msg := NewIfInfomsg(family)
   330  	parent.children = append(parent.children, msg)
   331  	return msg
   332  }
   333  
// Uint32Bitfield is a pair of 32-bit words that Serialize exposes as
// SizeofUint32Bitfield raw bytes.
// NOTE(review): presumably mirrors the kernel's struct nla_bitfield32, with
// Selector marking which bits of Value are meaningful — confirm.
type Uint32Bitfield struct {
	Value    uint32
	Selector uint32
}
   338  
   339  func (a *Uint32Bitfield) Serialize() []byte {
   340  	return (*(*[SizeofUint32Bitfield]byte)(unsafe.Pointer(a)))[:]
   341  }
   342  
   343  func DeserializeUint32Bitfield(data []byte) *Uint32Bitfield {
   344  	return (*Uint32Bitfield)(unsafe.Pointer(&data[0:SizeofUint32Bitfield][0]))
   345  }
   346  
// Uint32Attribute is a netlink attribute with a 32-bit payload. Serialize
// emits it as an 8-byte record: length, Type, then Value. Value is written
// big endian when Type carries NLA_F_NET_BYTEORDER and in host byte order
// otherwise.
type Uint32Attribute struct {
	Type  uint16
	Value uint32
}
   351  
   352  func (a *Uint32Attribute) Serialize() []byte {
   353  	native := NativeEndian()
   354  	buf := make([]byte, rtaAlignOf(8))
   355  	native.PutUint16(buf[0:2], 8)
   356  	native.PutUint16(buf[2:4], a.Type)
   357  
   358  	if a.Type&NLA_F_NET_BYTEORDER != 0 {
   359  		binary.BigEndian.PutUint32(buf[4:], a.Value)
   360  	} else {
   361  		native.PutUint32(buf[4:], a.Value)
   362  	}
   363  	return buf
   364  }
   365  
   366  func (a *Uint32Attribute) Len() int {
   367  	return 8
   368  }
   369  
// RtAttr extends unix.RtAttr to handle data and children. On serialization
// the header is followed by Data and then by every child in insertion order,
// each padded to the attribute alignment.
type RtAttr struct {
	unix.RtAttr
	// Data is the attribute's own payload; it may be nil.
	Data []byte
	// children are nested payloads appended via AddRtAttr or AddChild.
	children []NetlinkRequestData
}
   376  
   377  // Create a new Extended RtAttr object
   378  func NewRtAttr(attrType int, data []byte) *RtAttr {
   379  	return &RtAttr{
   380  		RtAttr: unix.RtAttr{
   381  			Type: uint16(attrType),
   382  		},
   383  		children: []NetlinkRequestData{},
   384  		Data:     data,
   385  	}
   386  }
   387  
   388  // NewRtAttrChild adds an RtAttr as a child to the parent and returns the new attribute
   389  //
   390  // Deprecated: Use AddRtAttr() on the parent object
   391  func NewRtAttrChild(parent *RtAttr, attrType int, data []byte) *RtAttr {
   392  	return parent.AddRtAttr(attrType, data)
   393  }
   394  
   395  // AddRtAttr adds an RtAttr as a child and returns the new attribute
   396  func (a *RtAttr) AddRtAttr(attrType int, data []byte) *RtAttr {
   397  	attr := NewRtAttr(attrType, data)
   398  	a.children = append(a.children, attr)
   399  	return attr
   400  }
   401  
   402  // AddChild adds an existing NetlinkRequestData as a child.
   403  func (a *RtAttr) AddChild(attr NetlinkRequestData) {
   404  	a.children = append(a.children, attr)
   405  }
   406  
   407  func (a *RtAttr) Len() int {
   408  	if len(a.children) == 0 {
   409  		return (unix.SizeofRtAttr + len(a.Data))
   410  	}
   411  
   412  	l := 0
   413  	for _, child := range a.children {
   414  		l += rtaAlignOf(child.Len())
   415  	}
   416  	l += unix.SizeofRtAttr
   417  	return rtaAlignOf(l + len(a.Data))
   418  }
   419  
   420  // Serialize the RtAttr into a byte array
   421  // This can't just unsafe.cast because it must iterate through children.
   422  func (a *RtAttr) Serialize() []byte {
   423  	native := NativeEndian()
   424  
   425  	length := a.Len()
   426  	buf := make([]byte, rtaAlignOf(length))
   427  
   428  	next := 4
   429  	if a.Data != nil {
   430  		copy(buf[next:], a.Data)
   431  		next += rtaAlignOf(len(a.Data))
   432  	}
   433  	if len(a.children) > 0 {
   434  		for _, child := range a.children {
   435  			childBuf := child.Serialize()
   436  			copy(buf[next:], childBuf)
   437  			next += rtaAlignOf(len(childBuf))
   438  		}
   439  	}
   440  
   441  	if l := uint16(length); l != 0 {
   442  		native.PutUint16(buf[0:2], l)
   443  	}
   444  	native.PutUint16(buf[2:4], a.Type)
   445  	return buf
   446  }
   447  
// NetlinkRequest is a netlink message under construction: a header plus the
// payload parts that Serialize concatenates after it.
type NetlinkRequest struct {
	unix.NlMsghdr
	// Data holds structured payload parts, serialized in order after the header.
	Data []NetlinkRequestData
	// RawData is appended verbatim after all Data parts.
	RawData []byte
	// Sockets optionally maps a netlink protocol to a shared socket handle.
	// ExecuteIter uses the matching handle's socket instead of opening a new one.
	Sockets map[int]*SocketHandle
}
   454  
   455  // Serialize the Netlink Request into a byte array
   456  func (req *NetlinkRequest) Serialize() []byte {
   457  	length := unix.SizeofNlMsghdr
   458  	dataBytes := make([][]byte, len(req.Data))
   459  	for i, data := range req.Data {
   460  		dataBytes[i] = data.Serialize()
   461  		length = length + len(dataBytes[i])
   462  	}
   463  	length += len(req.RawData)
   464  
   465  	req.Len = uint32(length)
   466  	b := make([]byte, length)
   467  	hdr := (*(*[unix.SizeofNlMsghdr]byte)(unsafe.Pointer(req)))[:]
   468  	next := unix.SizeofNlMsghdr
   469  	copy(b[0:next], hdr)
   470  	for _, data := range dataBytes {
   471  		for _, dataByte := range data {
   472  			b[next] = dataByte
   473  			next = next + 1
   474  		}
   475  	}
   476  	// Add the raw data if any
   477  	if len(req.RawData) > 0 {
   478  		copy(b[next:length], req.RawData)
   479  	}
   480  	return b
   481  }
   482  
   483  func (req *NetlinkRequest) AddData(data NetlinkRequestData) {
   484  	req.Data = append(req.Data, data)
   485  }
   486  
   487  // AddRawData adds raw bytes to the end of the NetlinkRequest object during serialization
   488  func (req *NetlinkRequest) AddRawData(data []byte) {
   489  	req.RawData = append(req.RawData, data...)
   490  }
   491  
   492  // Execute the request against the given sockType.
   493  // Returns a list of netlink messages in serialized format, optionally filtered
   494  // by resType.
   495  func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, error) {
   496  	var res [][]byte
   497  	err := req.ExecuteIter(sockType, resType, func(msg []byte) bool {
   498  		res = append(res, msg)
   499  		return true
   500  	})
   501  	if err != nil {
   502  		return nil, err
   503  	}
   504  	return res, nil
   505  }
   506  
   507  // ExecuteIter executes the request against the given sockType.
   508  // Calls the provided callback func once for each netlink message.
   509  // If the callback returns false, it is not called again, but
   510  // the remaining messages are consumed/discarded.
   511  //
   512  // Thread safety: ExecuteIter holds a lock on the socket until
   513  // it finishes iteration so the callback must not call back into
   514  // the netlink API.
   515  func (req *NetlinkRequest) ExecuteIter(sockType int, resType uint16, f func(msg []byte) bool) error {
   516  	var (
   517  		s   *NetlinkSocket
   518  		err error
   519  	)
   520  
   521  	if req.Sockets != nil {
   522  		if sh, ok := req.Sockets[sockType]; ok {
   523  			s = sh.Socket
   524  			req.Seq = atomic.AddUint32(&sh.Seq, 1)
   525  		}
   526  	}
   527  	sharedSocket := s != nil
   528  
   529  	if s == nil {
   530  		s, err = getNetlinkSocket(sockType)
   531  		if err != nil {
   532  			return err
   533  		}
   534  
   535  		if err := s.SetSendTimeout(&SocketTimeoutTv); err != nil {
   536  			return err
   537  		}
   538  		if err := s.SetReceiveTimeout(&SocketTimeoutTv); err != nil {
   539  			return err
   540  		}
   541  		if EnableErrorMessageReporting {
   542  			if err := s.SetExtAck(true); err != nil {
   543  				return err
   544  			}
   545  		}
   546  
   547  		defer s.Close()
   548  	} else {
   549  		s.Lock()
   550  		defer s.Unlock()
   551  	}
   552  
   553  	if err := s.Send(req); err != nil {
   554  		return err
   555  	}
   556  
   557  	pid, err := s.GetPid()
   558  	if err != nil {
   559  		return err
   560  	}
   561  
   562  done:
   563  	for {
   564  		msgs, from, err := s.Receive()
   565  		if err != nil {
   566  			return err
   567  		}
   568  		if from.Pid != PidKernel {
   569  			return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel)
   570  		}
   571  		for _, m := range msgs {
   572  			if m.Header.Seq != req.Seq {
   573  				if sharedSocket {
   574  					continue
   575  				}
   576  				return fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
   577  			}
   578  			if m.Header.Pid != pid {
   579  				continue
   580  			}
   581  
   582  			if m.Header.Flags&unix.NLM_F_DUMP_INTR != 0 {
   583  				return syscall.Errno(unix.EINTR)
   584  			}
   585  
   586  			if m.Header.Type == unix.NLMSG_DONE || m.Header.Type == unix.NLMSG_ERROR {
   587  				// NLMSG_DONE might have no payload, if so assume no error.
   588  				if m.Header.Type == unix.NLMSG_DONE && len(m.Data) == 0 {
   589  					break done
   590  				}
   591  
   592  				native := NativeEndian()
   593  				errno := int32(native.Uint32(m.Data[0:4]))
   594  				if errno == 0 {
   595  					break done
   596  				}
   597  				var err error
   598  				err = syscall.Errno(-errno)
   599  
   600  				unreadData := m.Data[4:]
   601  				if m.Header.Flags&unix.NLM_F_ACK_TLVS != 0 && len(unreadData) > syscall.SizeofNlMsghdr {
   602  					// Skip the echoed request message.
   603  					echoReqH := (*syscall.NlMsghdr)(unsafe.Pointer(&unreadData[0]))
   604  					unreadData = unreadData[nlmAlignOf(int(echoReqH.Len)):]
   605  
   606  					// Annotate `err` using nlmsgerr attributes.
   607  					for len(unreadData) >= syscall.SizeofRtAttr {
   608  						attr := (*syscall.RtAttr)(unsafe.Pointer(&unreadData[0]))
   609  						attrData := unreadData[syscall.SizeofRtAttr:attr.Len]
   610  
   611  						switch attr.Type {
   612  						case NLMSGERR_ATTR_MSG:
   613  							err = fmt.Errorf("%w: %s", err, unix.ByteSliceToString(attrData))
   614  						default:
   615  							// TODO: handle other NLMSGERR_ATTR types
   616  						}
   617  
   618  						unreadData = unreadData[rtaAlignOf(int(attr.Len)):]
   619  					}
   620  				}
   621  
   622  				return err
   623  			}
   624  			if resType != 0 && m.Header.Type != resType {
   625  				continue
   626  			}
   627  			if cont := f(m.Data); !cont {
   628  				// Drain the rest of the messages from the kernel but don't
   629  				// pass them to the iterator func.
   630  				f = dummyMsgIterFunc
   631  			}
   632  			if m.Header.Flags&unix.NLM_F_MULTI == 0 {
   633  				break done
   634  			}
   635  		}
   636  	}
   637  	return nil
   638  }
   639  
   640  func dummyMsgIterFunc(msg []byte) bool {
   641  	return true
   642  }
   643  
   644  // Create a new netlink request from proto and flags
   645  // Note the Len value will be inaccurate once data is added until
   646  // the message is serialized
   647  func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
   648  	return &NetlinkRequest{
   649  		NlMsghdr: unix.NlMsghdr{
   650  			Len:   uint32(unix.SizeofNlMsghdr),
   651  			Type:  uint16(proto),
   652  			Flags: unix.NLM_F_REQUEST | uint16(flags),
   653  			Seq:   atomic.AddUint32(&nextSeqNr, 1),
   654  		},
   655  	}
   656  }
   657  
// NetlinkSocket is an open netlink socket.
type NetlinkSocket struct {
	// fd is the raw descriptor, used for sendto/setsockopt/getsockname.
	fd int32
	// file wraps the same descriptor; Receive reads through it and Close
	// closes it.
	file *os.File
	// lsa is the address the socket was bound with; Send also passes it to
	// Sendto as the destination address.
	lsa unix.SockaddrNetlink
	// The embedded mutex is held by ExecuteIter while it uses a shared socket.
	sync.Mutex
}
   664  
   665  func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {
   666  	fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, protocol)
   667  	if err != nil {
   668  		return nil, err
   669  	}
   670  	err = unix.SetNonblock(fd, true)
   671  	if err != nil {
   672  		return nil, err
   673  	}
   674  	s := &NetlinkSocket{
   675  		fd:   int32(fd),
   676  		file: os.NewFile(uintptr(fd), "netlink"),
   677  	}
   678  	s.lsa.Family = unix.AF_NETLINK
   679  	if err := unix.Bind(fd, &s.lsa); err != nil {
   680  		unix.Close(fd)
   681  		return nil, err
   682  	}
   683  
   684  	return s, nil
   685  }
   686  
   687  // GetNetlinkSocketAt opens a netlink socket in the network namespace newNs
   688  // and positions the thread back into the network namespace specified by curNs,
   689  // when done. If curNs is close, the function derives the current namespace and
   690  // moves back into it when done. If newNs is close, the socket will be opened
   691  // in the current network namespace.
   692  func GetNetlinkSocketAt(newNs, curNs netns.NsHandle, protocol int) (*NetlinkSocket, error) {
   693  	c, err := executeInNetns(newNs, curNs)
   694  	if err != nil {
   695  		return nil, err
   696  	}
   697  	defer c()
   698  	return getNetlinkSocket(protocol)
   699  }
   700  
// executeInNetns sets execution of the code following this call to the
// network namespace newNs, then moves the thread back to curNs if open,
// otherwise to the current netns at the time the function was invoked.
// If newNs is not open, nothing is changed and the returned function is a
// no-op.
// In case of success, the caller is expected to execute the returned function
// at the end of the code that needs to be executed in the network namespace.
// In case of failure, everything done so far is undone and a nil function is
// returned together with the error.
// Example:
//
//	func jobAt(...) error {
//	     d, err := executeInNetns(...)
//	     if err != nil { return err}
//	     defer d()
//	     < code which needs to be executed in specific netns>
//	 }
//
// TODO: this function probably belongs to netns pkg.
func executeInNetns(newNs, curNs netns.NsHandle) (func(), error) {
	var (
		err       error
		moveBack  func(netns.NsHandle) error
		closeNs   func() error
		unlockThd func()
	)
	// restore undoes whichever steps have been performed so far; each step
	// registers its undo action by setting the corresponding variable.
	// Errors from moving back and from closing the handle are ignored.
	restore := func() {
		// order matters
		if moveBack != nil {
			moveBack(curNs)
		}
		if closeNs != nil {
			closeNs()
		}
		if unlockThd != nil {
			unlockThd()
		}
	}
	if newNs.IsOpen() {
		// The goroutine is locked to its OS thread until restore runs.
		runtime.LockOSThread()
		unlockThd = runtime.UnlockOSThread
		if !curNs.IsOpen() {
			// No namespace to return to was given: use the current one. The
			// handle obtained here is owned by this function and is closed
			// by restore.
			if curNs, err = netns.Get(); err != nil {
				restore()
				return nil, fmt.Errorf("could not get current namespace while creating netlink socket: %v", err)
			}
			closeNs = curNs.Close
		}
		if err := netns.Set(newNs); err != nil {
			restore()
			return nil, fmt.Errorf("failed to set into network namespace %d while creating netlink socket: %v", newNs, err)
		}
		moveBack = netns.Set
	}
	return restore, nil
}
   753  
   754  // Create a netlink socket with a given protocol (e.g. NETLINK_ROUTE)
   755  // and subscribe it to multicast groups passed in variable argument list.
   756  // Returns the netlink socket on which Receive() method can be called
   757  // to retrieve the messages from the kernel.
   758  func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) {
   759  	fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, protocol)
   760  	if err != nil {
   761  		return nil, err
   762  	}
   763  	err = unix.SetNonblock(fd, true)
   764  	if err != nil {
   765  		return nil, err
   766  	}
   767  	s := &NetlinkSocket{
   768  		fd:   int32(fd),
   769  		file: os.NewFile(uintptr(fd), "netlink"),
   770  	}
   771  	s.lsa.Family = unix.AF_NETLINK
   772  
   773  	for _, g := range groups {
   774  		s.lsa.Groups |= (1 << (g - 1))
   775  	}
   776  
   777  	if err := unix.Bind(fd, &s.lsa); err != nil {
   778  		unix.Close(fd)
   779  		return nil, err
   780  	}
   781  
   782  	return s, nil
   783  }
   784  
   785  // SubscribeAt works like Subscribe plus let's the caller choose the network
   786  // namespace in which the socket would be opened (newNs). Then control goes back
   787  // to curNs if open, otherwise to the netns at the time this function was called.
   788  func SubscribeAt(newNs, curNs netns.NsHandle, protocol int, groups ...uint) (*NetlinkSocket, error) {
   789  	c, err := executeInNetns(newNs, curNs)
   790  	if err != nil {
   791  		return nil, err
   792  	}
   793  	defer c()
   794  	return Subscribe(protocol, groups...)
   795  }
   796  
   797  func (s *NetlinkSocket) Close() {
   798  	s.file.Close()
   799  }
   800  
   801  func (s *NetlinkSocket) GetFd() int {
   802  	return int(s.fd)
   803  }
   804  
   805  func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
   806  	return unix.Sendto(int(s.fd), request.Serialize(), 0, &s.lsa)
   807  }
   808  
   809  func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) {
   810  	rawConn, err := s.file.SyscallConn()
   811  	if err != nil {
   812  		return nil, nil, err
   813  	}
   814  	var (
   815  		fromAddr *unix.SockaddrNetlink
   816  		rb       [RECEIVE_BUFFER_SIZE]byte
   817  		nr       int
   818  		from     unix.Sockaddr
   819  		innerErr error
   820  	)
   821  	err = rawConn.Read(func(fd uintptr) (done bool) {
   822  		nr, from, innerErr = unix.Recvfrom(int(fd), rb[:], 0)
   823  		return innerErr != unix.EWOULDBLOCK
   824  	})
   825  	if innerErr != nil {
   826  		err = innerErr
   827  	}
   828  	if err != nil {
   829  		return nil, nil, err
   830  	}
   831  	fromAddr, ok := from.(*unix.SockaddrNetlink)
   832  	if !ok {
   833  		return nil, nil, fmt.Errorf("Error converting to netlink sockaddr")
   834  	}
   835  	if nr < unix.NLMSG_HDRLEN {
   836  		return nil, nil, fmt.Errorf("Got short response from netlink")
   837  	}
   838  	msgLen := nlmAlignOf(nr)
   839  	rb2 := make([]byte, msgLen)
   840  	copy(rb2, rb[:msgLen])
   841  	nl, err := syscall.ParseNetlinkMessage(rb2)
   842  	if err != nil {
   843  		return nil, nil, err
   844  	}
   845  	return nl, fromAddr, nil
   846  }
   847  
   848  // SetSendTimeout allows to set a send timeout on the socket
   849  func (s *NetlinkSocket) SetSendTimeout(timeout *unix.Timeval) error {
   850  	// Set a send timeout of SOCKET_SEND_TIMEOUT, this will allow the Send to periodically unblock and avoid that a routine
   851  	// remains stuck on a send on a closed fd
   852  	return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, timeout)
   853  }
   854  
   855  // SetReceiveTimeout allows to set a receive timeout on the socket
   856  func (s *NetlinkSocket) SetReceiveTimeout(timeout *unix.Timeval) error {
   857  	// Set a read timeout of SOCKET_READ_TIMEOUT, this will allow the Read to periodically unblock and avoid that a routine
   858  	// remains stuck on a recvmsg on a closed fd
   859  	return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, timeout)
   860  }
   861  
   862  // SetReceiveBufferSize allows to set a receive buffer size on the socket
   863  func (s *NetlinkSocket) SetReceiveBufferSize(size int, force bool) error {
   864  	opt := unix.SO_RCVBUF
   865  	if force {
   866  		opt = unix.SO_RCVBUFFORCE
   867  	}
   868  	return unix.SetsockoptInt(int(s.fd), unix.SOL_SOCKET, opt, size)
   869  }
   870  
   871  // SetExtAck requests error messages to be reported on the socket
   872  func (s *NetlinkSocket) SetExtAck(enable bool) error {
   873  	var enableN int
   874  	if enable {
   875  		enableN = 1
   876  	}
   877  
   878  	return unix.SetsockoptInt(int(s.fd), unix.SOL_NETLINK, unix.NETLINK_EXT_ACK, enableN)
   879  }
   880  
   881  func (s *NetlinkSocket) GetPid() (uint32, error) {
   882  	lsa, err := unix.Getsockname(int(s.fd))
   883  	if err != nil {
   884  		return 0, err
   885  	}
   886  	switch v := lsa.(type) {
   887  	case *unix.SockaddrNetlink:
   888  		return v.Pid, nil
   889  	}
   890  	return 0, fmt.Errorf("Wrong socket type")
   891  }
   892  
   893  func ZeroTerminated(s string) []byte {
   894  	bytes := make([]byte, len(s)+1)
   895  	for i := 0; i < len(s); i++ {
   896  		bytes[i] = s[i]
   897  	}
   898  	bytes[len(s)] = 0
   899  	return bytes
   900  }
   901  
   902  func NonZeroTerminated(s string) []byte {
   903  	bytes := make([]byte, len(s))
   904  	for i := 0; i < len(s); i++ {
   905  		bytes[i] = s[i]
   906  	}
   907  	return bytes
   908  }
   909  
   910  func BytesToString(b []byte) string {
   911  	n := bytes.Index(b, []byte{0})
   912  	return string(b[:n])
   913  }
   914  
   915  func Uint8Attr(v uint8) []byte {
   916  	return []byte{byte(v)}
   917  }
   918  
   919  func Uint16Attr(v uint16) []byte {
   920  	native := NativeEndian()
   921  	bytes := make([]byte, 2)
   922  	native.PutUint16(bytes, v)
   923  	return bytes
   924  }
   925  
   926  func BEUint16Attr(v uint16) []byte {
   927  	bytes := make([]byte, 2)
   928  	binary.BigEndian.PutUint16(bytes, v)
   929  	return bytes
   930  }
   931  
   932  func Uint32Attr(v uint32) []byte {
   933  	native := NativeEndian()
   934  	bytes := make([]byte, 4)
   935  	native.PutUint32(bytes, v)
   936  	return bytes
   937  }
   938  
   939  func BEUint32Attr(v uint32) []byte {
   940  	bytes := make([]byte, 4)
   941  	binary.BigEndian.PutUint32(bytes, v)
   942  	return bytes
   943  }
   944  
   945  func Uint64Attr(v uint64) []byte {
   946  	native := NativeEndian()
   947  	bytes := make([]byte, 8)
   948  	native.PutUint64(bytes, v)
   949  	return bytes
   950  }
   951  
   952  func BEUint64Attr(v uint64) []byte {
   953  	bytes := make([]byte, 8)
   954  	binary.BigEndian.PutUint64(bytes, v)
   955  	return bytes
   956  }
   957  
   958  func ParseRouteAttr(b []byte) ([]syscall.NetlinkRouteAttr, error) {
   959  	var attrs []syscall.NetlinkRouteAttr
   960  	for len(b) >= unix.SizeofRtAttr {
   961  		a, vbuf, alen, err := netlinkRouteAttrAndValue(b)
   962  		if err != nil {
   963  			return nil, err
   964  		}
   965  		ra := syscall.NetlinkRouteAttr{Attr: syscall.RtAttr(*a), Value: vbuf[:int(a.Len)-unix.SizeofRtAttr]}
   966  		attrs = append(attrs, ra)
   967  		b = b[alen:]
   968  	}
   969  	return attrs, nil
   970  }
   971  
   972  // ParseRouteAttrAsMap parses provided buffer that contains raw RtAttrs and returns a map of parsed
   973  // atttributes indexed by attribute type or error if occured.
   974  func ParseRouteAttrAsMap(b []byte) (map[uint16]syscall.NetlinkRouteAttr, error) {
   975  	attrMap := make(map[uint16]syscall.NetlinkRouteAttr)
   976  
   977  	attrs, err := ParseRouteAttr(b)
   978  	if err != nil {
   979  		return nil, err
   980  	}
   981  
   982  	for _, attr := range attrs {
   983  		attrMap[attr.Attr.Type] = attr
   984  	}
   985  	return attrMap, nil
   986  }
   987  
   988  func netlinkRouteAttrAndValue(b []byte) (*unix.RtAttr, []byte, int, error) {
   989  	a := (*unix.RtAttr)(unsafe.Pointer(&b[0]))
   990  	if int(a.Len) < unix.SizeofRtAttr || int(a.Len) > len(b) {
   991  		return nil, nil, 0, unix.EINVAL
   992  	}
   993  	return a, b[unix.SizeofRtAttr:], rtaAlignOf(int(a.Len)), nil
   994  }
   995  
// SocketHandle contains the netlink socket and the associated
// sequence counter for a specific netlink family
type SocketHandle struct {
	// Seq is incremented atomically by ExecuteIter for every request sent
	// through Socket; the incremented value becomes the request's sequence
	// number.
	Seq uint32
	// Socket is the shared socket; it may be nil, in which case ExecuteIter
	// opens a private socket instead.
	Socket *NetlinkSocket
}
  1002  
  1003  // Close closes the netlink socket
  1004  func (sh *SocketHandle) Close() {
  1005  	if sh.Socket != nil {
  1006  		sh.Socket.Close()
  1007  	}
  1008  }