github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/nl/nl_linux.go (about)

     1  // Package nl has low level primitives for making Netlink calls.
     2  package nl
     3  
     4  import (
     5  	"bytes"
     6  	"encoding/binary"
     7  	"fmt"
     8  	"net"
     9  	"os"
    10  	"runtime"
    11  	"sync"
    12  	"sync/atomic"
    13  	"syscall"
    14  	"unsafe"
    15  
    16  	"github.com/vishvananda/netns"
    17  	"golang.org/x/sys/unix"
    18  )
    19  
const (
	// Family type definitions (netlink address families).
	FAMILY_ALL  = unix.AF_UNSPEC
	FAMILY_V4   = unix.AF_INET
	FAMILY_V6   = unix.AF_INET6
	FAMILY_MPLS = unix.AF_MPLS
	// Arbitrary set value (greater than default 4k) to allow receiving
	// from kernel more verbose messages e.g. for statistics,
	// tc rules or filters, or other more memory requiring data.
	// NetlinkSocket.Receive reads at most this many bytes per call.
	RECEIVE_BUFFER_SIZE = 65536
	// Kernel netlink pid: Execute rejects replies whose sender port ID is not
	// PidKernel. SizeofCnMsgOp is the size in bytes of CnMsgOp (0x18 = 24),
	// used by its unsafe Serialize and by DeserializeCnMsgOp.
	PidKernel     uint32 = 0
	SizeofCnMsgOp        = 0x18
)
    34  
    35  // SupportedNlFamilies contains the list of netlink families this netlink package supports
    36  var SupportedNlFamilies = []int{unix.NETLINK_ROUTE, unix.NETLINK_XFRM, unix.NETLINK_NETFILTER}
    37  
    38  var nextSeqNr uint32
    39  
    40  // Default netlink socket timeout, 60s
    41  var SocketTimeoutTv = unix.Timeval{Sec: 60, Usec: 0}
    42  
    43  // ErrorMessageReporting is the default error message reporting configuration for the new netlink sockets
    44  var EnableErrorMessageReporting bool = false
    45  
    46  // GetIPFamily returns the family type of a net.IP.
    47  func GetIPFamily(ip net.IP) int {
    48  	if len(ip) <= net.IPv4len {
    49  		return FAMILY_V4
    50  	}
    51  	if ip.To4() != nil {
    52  		return FAMILY_V4
    53  	}
    54  	return FAMILY_V6
    55  }
    56  
    57  var nativeEndian binary.ByteOrder
    58  
    59  // NativeEndian gets native endianness for the system
    60  func NativeEndian() binary.ByteOrder {
    61  	if nativeEndian == nil {
    62  		var x uint32 = 0x01020304
    63  		if *(*byte)(unsafe.Pointer(&x)) == 0x01 {
    64  			nativeEndian = binary.BigEndian
    65  		} else {
    66  			nativeEndian = binary.LittleEndian
    67  		}
    68  	}
    69  	return nativeEndian
    70  }
    71  
    72  // Byte swap a 16 bit value if we aren't big endian
    73  func Swap16(i uint16) uint16 {
    74  	if NativeEndian() == binary.BigEndian {
    75  		return i
    76  	}
    77  	return (i&0xff00)>>8 | (i&0xff)<<8
    78  }
    79  
    80  // Byte swap a 32 bit value if aren't big endian
    81  func Swap32(i uint32) uint32 {
    82  	if NativeEndian() == binary.BigEndian {
    83  		return i
    84  	}
    85  	return (i&0xff000000)>>24 | (i&0xff0000)>>8 | (i&0xff00)<<8 | (i&0xff)<<24
    86  }
    87  
    88  const (
    89  	NLMSGERR_ATTR_UNUSED = 0
    90  	NLMSGERR_ATTR_MSG    = 1
    91  	NLMSGERR_ATTR_OFFS   = 2
    92  	NLMSGERR_ATTR_COOKIE = 3
    93  	NLMSGERR_ATTR_POLICY = 4
    94  )
    95  
    96  type NetlinkRequestData interface {
    97  	Len() int
    98  	Serialize() []byte
    99  }
   100  
   101  const (
   102  	PROC_CN_MCAST_LISTEN = 1
   103  	PROC_CN_MCAST_IGNORE
   104  )
   105  
// CbID holds the Idx/Val pair identifying the target of a connector message
// (presumably mirroring the kernel's struct cb_id — confirm).
type CbID struct {
	Idx uint32
	Val uint32
}

// CnMsg is the fixed header of a connector message.
//
// Field order and sizes define the in-memory layout (8+4+4+2+2 = 20 bytes).
// CnMsgOp.Serialize and DeserializeCnMsgOp reinterpret this memory directly,
// so do not reorder or resize fields.
type CnMsg struct {
	ID     CbID
	Seq    uint32
	Ack    uint32
	Length uint16 // payload size after the header; NewCnMsg sets it to the size of Op
	Flags  uint16
}

// CnMsgOp is a CnMsg header immediately followed by a 32-bit operation code;
// its total size is SizeofCnMsgOp (0x18 = 24) bytes.
type CnMsgOp struct {
	CnMsg
	// here we differ from the C header
	Op uint32
}
   124  
   125  func NewCnMsg(idx, val, op uint32) *CnMsgOp {
   126  	var cm CnMsgOp
   127  
   128  	cm.ID.Idx = idx
   129  	cm.ID.Val = val
   130  
   131  	cm.Ack = 0
   132  	cm.Seq = 1
   133  	cm.Length = uint16(binary.Size(op))
   134  	cm.Op = op
   135  
   136  	return &cm
   137  }
   138  
   139  func (msg *CnMsgOp) Serialize() []byte {
   140  	return (*(*[SizeofCnMsgOp]byte)(unsafe.Pointer(msg)))[:]
   141  }
   142  
   143  func DeserializeCnMsgOp(b []byte) *CnMsgOp {
   144  	return (*CnMsgOp)(unsafe.Pointer(&b[0:SizeofCnMsgOp][0]))
   145  }
   146  
   147  func (msg *CnMsgOp) Len() int {
   148  	return SizeofCnMsgOp
   149  }
   150  
   151  // IfInfomsg is related to links, but it is used for list requests as well
   152  type IfInfomsg struct {
   153  	unix.IfInfomsg
   154  }
   155  
   156  // Create an IfInfomsg with family specified
   157  func NewIfInfomsg(family int) *IfInfomsg {
   158  	return &IfInfomsg{
   159  		IfInfomsg: unix.IfInfomsg{
   160  			Family: uint8(family),
   161  		},
   162  	}
   163  }
   164  
   165  func DeserializeIfInfomsg(b []byte) *IfInfomsg {
   166  	return (*IfInfomsg)(unsafe.Pointer(&b[0:unix.SizeofIfInfomsg][0]))
   167  }
   168  
   169  func (msg *IfInfomsg) Serialize() []byte {
   170  	return (*(*[unix.SizeofIfInfomsg]byte)(unsafe.Pointer(msg)))[:]
   171  }
   172  
   173  func (msg *IfInfomsg) Len() int {
   174  	return unix.SizeofIfInfomsg
   175  }
   176  
   177  func (msg *IfInfomsg) EncapType() string {
   178  	switch msg.Type {
   179  	case 0:
   180  		return "generic"
   181  	case unix.ARPHRD_ETHER:
   182  		return "ether"
   183  	case unix.ARPHRD_EETHER:
   184  		return "eether"
   185  	case unix.ARPHRD_AX25:
   186  		return "ax25"
   187  	case unix.ARPHRD_PRONET:
   188  		return "pronet"
   189  	case unix.ARPHRD_CHAOS:
   190  		return "chaos"
   191  	case unix.ARPHRD_IEEE802:
   192  		return "ieee802"
   193  	case unix.ARPHRD_ARCNET:
   194  		return "arcnet"
   195  	case unix.ARPHRD_APPLETLK:
   196  		return "atalk"
   197  	case unix.ARPHRD_DLCI:
   198  		return "dlci"
   199  	case unix.ARPHRD_ATM:
   200  		return "atm"
   201  	case unix.ARPHRD_METRICOM:
   202  		return "metricom"
   203  	case unix.ARPHRD_IEEE1394:
   204  		return "ieee1394"
   205  	case unix.ARPHRD_INFINIBAND:
   206  		return "infiniband"
   207  	case unix.ARPHRD_SLIP:
   208  		return "slip"
   209  	case unix.ARPHRD_CSLIP:
   210  		return "cslip"
   211  	case unix.ARPHRD_SLIP6:
   212  		return "slip6"
   213  	case unix.ARPHRD_CSLIP6:
   214  		return "cslip6"
   215  	case unix.ARPHRD_RSRVD:
   216  		return "rsrvd"
   217  	case unix.ARPHRD_ADAPT:
   218  		return "adapt"
   219  	case unix.ARPHRD_ROSE:
   220  		return "rose"
   221  	case unix.ARPHRD_X25:
   222  		return "x25"
   223  	case unix.ARPHRD_HWX25:
   224  		return "hwx25"
   225  	case unix.ARPHRD_PPP:
   226  		return "ppp"
   227  	case unix.ARPHRD_HDLC:
   228  		return "hdlc"
   229  	case unix.ARPHRD_LAPB:
   230  		return "lapb"
   231  	case unix.ARPHRD_DDCMP:
   232  		return "ddcmp"
   233  	case unix.ARPHRD_RAWHDLC:
   234  		return "rawhdlc"
   235  	case unix.ARPHRD_TUNNEL:
   236  		return "ipip"
   237  	case unix.ARPHRD_TUNNEL6:
   238  		return "tunnel6"
   239  	case unix.ARPHRD_FRAD:
   240  		return "frad"
   241  	case unix.ARPHRD_SKIP:
   242  		return "skip"
   243  	case unix.ARPHRD_LOOPBACK:
   244  		return "loopback"
   245  	case unix.ARPHRD_LOCALTLK:
   246  		return "ltalk"
   247  	case unix.ARPHRD_FDDI:
   248  		return "fddi"
   249  	case unix.ARPHRD_BIF:
   250  		return "bif"
   251  	case unix.ARPHRD_SIT:
   252  		return "sit"
   253  	case unix.ARPHRD_IPDDP:
   254  		return "ip/ddp"
   255  	case unix.ARPHRD_IPGRE:
   256  		return "gre"
   257  	case unix.ARPHRD_PIMREG:
   258  		return "pimreg"
   259  	case unix.ARPHRD_HIPPI:
   260  		return "hippi"
   261  	case unix.ARPHRD_ASH:
   262  		return "ash"
   263  	case unix.ARPHRD_ECONET:
   264  		return "econet"
   265  	case unix.ARPHRD_IRDA:
   266  		return "irda"
   267  	case unix.ARPHRD_FCPP:
   268  		return "fcpp"
   269  	case unix.ARPHRD_FCAL:
   270  		return "fcal"
   271  	case unix.ARPHRD_FCPL:
   272  		return "fcpl"
   273  	case unix.ARPHRD_FCFABRIC:
   274  		return "fcfb0"
   275  	case unix.ARPHRD_FCFABRIC + 1:
   276  		return "fcfb1"
   277  	case unix.ARPHRD_FCFABRIC + 2:
   278  		return "fcfb2"
   279  	case unix.ARPHRD_FCFABRIC + 3:
   280  		return "fcfb3"
   281  	case unix.ARPHRD_FCFABRIC + 4:
   282  		return "fcfb4"
   283  	case unix.ARPHRD_FCFABRIC + 5:
   284  		return "fcfb5"
   285  	case unix.ARPHRD_FCFABRIC + 6:
   286  		return "fcfb6"
   287  	case unix.ARPHRD_FCFABRIC + 7:
   288  		return "fcfb7"
   289  	case unix.ARPHRD_FCFABRIC + 8:
   290  		return "fcfb8"
   291  	case unix.ARPHRD_FCFABRIC + 9:
   292  		return "fcfb9"
   293  	case unix.ARPHRD_FCFABRIC + 10:
   294  		return "fcfb10"
   295  	case unix.ARPHRD_FCFABRIC + 11:
   296  		return "fcfb11"
   297  	case unix.ARPHRD_FCFABRIC + 12:
   298  		return "fcfb12"
   299  	case unix.ARPHRD_IEEE802_TR:
   300  		return "tr"
   301  	case unix.ARPHRD_IEEE80211:
   302  		return "ieee802.11"
   303  	case unix.ARPHRD_IEEE80211_PRISM:
   304  		return "ieee802.11/prism"
   305  	case unix.ARPHRD_IEEE80211_RADIOTAP:
   306  		return "ieee802.11/radiotap"
   307  	case unix.ARPHRD_IEEE802154:
   308  		return "ieee802.15.4"
   309  
   310  	case 65534:
   311  		return "none"
   312  	case 65535:
   313  		return "void"
   314  	}
   315  	return fmt.Sprintf("unknown%d", msg.Type)
   316  }
   317  
   318  // Round the length of a netlink message up to align it properly.
   319  // Taken from syscall/netlink_linux.go by The Go Authors under BSD-style license.
   320  func nlmAlignOf(msglen int) int {
   321  	return (msglen + syscall.NLMSG_ALIGNTO - 1) & ^(syscall.NLMSG_ALIGNTO - 1)
   322  }
   323  
   324  func rtaAlignOf(attrlen int) int {
   325  	return (attrlen + unix.RTA_ALIGNTO - 1) & ^(unix.RTA_ALIGNTO - 1)
   326  }
   327  
   328  func NewIfInfomsgChild(parent *RtAttr, family int) *IfInfomsg {
   329  	msg := NewIfInfomsg(family)
   330  	parent.children = append(parent.children, msg)
   331  	return msg
   332  }
   333  
   334  type Uint32Attribute struct {
   335  	Type  uint16
   336  	Value uint32
   337  }
   338  
   339  func (a *Uint32Attribute) Serialize() []byte {
   340  	native := NativeEndian()
   341  	buf := make([]byte, rtaAlignOf(8))
   342  	native.PutUint16(buf[0:2], 8)
   343  	native.PutUint16(buf[2:4], a.Type)
   344  
   345  	if a.Type&NLA_F_NET_BYTEORDER != 0 {
   346  		binary.BigEndian.PutUint32(buf[4:], a.Value)
   347  	} else {
   348  		native.PutUint32(buf[4:], a.Value)
   349  	}
   350  	return buf
   351  }
   352  
   353  func (a *Uint32Attribute) Len() int {
   354  	return 8
   355  }
   356  
// Extend RtAttr to handle data and children.
//
// RtAttr embeds unix.RtAttr (the Len/Type header) and adds the attribute's
// payload (Data) and nested attributes (children). Serialize writes Data
// first and the children after it, each padded to RTA_ALIGNTO.
type RtAttr struct {
	unix.RtAttr
	Data     []byte
	children []NetlinkRequestData
}
   363  
   364  // Create a new Extended RtAttr object
   365  func NewRtAttr(attrType int, data []byte) *RtAttr {
   366  	return &RtAttr{
   367  		RtAttr: unix.RtAttr{
   368  			Type: uint16(attrType),
   369  		},
   370  		children: []NetlinkRequestData{},
   371  		Data:     data,
   372  	}
   373  }
   374  
   375  // NewRtAttrChild adds an RtAttr as a child to the parent and returns the new attribute
   376  //
   377  // Deprecated: Use AddRtAttr() on the parent object
   378  func NewRtAttrChild(parent *RtAttr, attrType int, data []byte) *RtAttr {
   379  	return parent.AddRtAttr(attrType, data)
   380  }
   381  
   382  // AddRtAttr adds an RtAttr as a child and returns the new attribute
   383  func (a *RtAttr) AddRtAttr(attrType int, data []byte) *RtAttr {
   384  	attr := NewRtAttr(attrType, data)
   385  	a.children = append(a.children, attr)
   386  	return attr
   387  }
   388  
   389  // AddChild adds an existing NetlinkRequestData as a child.
   390  func (a *RtAttr) AddChild(attr NetlinkRequestData) {
   391  	a.children = append(a.children, attr)
   392  }
   393  
   394  func (a *RtAttr) Len() int {
   395  	if len(a.children) == 0 {
   396  		return (unix.SizeofRtAttr + len(a.Data))
   397  	}
   398  
   399  	l := 0
   400  	for _, child := range a.children {
   401  		l += rtaAlignOf(child.Len())
   402  	}
   403  	l += unix.SizeofRtAttr
   404  	return rtaAlignOf(l + len(a.Data))
   405  }
   406  
   407  // Serialize the RtAttr into a byte array
   408  // This can't just unsafe.cast because it must iterate through children.
   409  func (a *RtAttr) Serialize() []byte {
   410  	native := NativeEndian()
   411  
   412  	length := a.Len()
   413  	buf := make([]byte, rtaAlignOf(length))
   414  
   415  	next := 4
   416  	if a.Data != nil {
   417  		copy(buf[next:], a.Data)
   418  		next += rtaAlignOf(len(a.Data))
   419  	}
   420  	if len(a.children) > 0 {
   421  		for _, child := range a.children {
   422  			childBuf := child.Serialize()
   423  			copy(buf[next:], childBuf)
   424  			next += rtaAlignOf(len(childBuf))
   425  		}
   426  	}
   427  
   428  	if l := uint16(length); l != 0 {
   429  		native.PutUint16(buf[0:2], l)
   430  	}
   431  	native.PutUint16(buf[2:4], a.Type)
   432  	return buf
   433  }
   434  
// NetlinkRequest is a netlink message under construction.
//
// The embedded unix.NlMsghdr must remain the first field: Serialize copies the
// header by reinterpreting the struct's leading bytes. Data payloads are
// serialized after the header in order, followed by RawData.
//
// Sockets, when non-nil, maps a netlink protocol (Execute's sockType) to a
// shared socket that Execute uses instead of opening a one-shot socket.
type NetlinkRequest struct {
	unix.NlMsghdr
	Data    []NetlinkRequestData
	RawData []byte
	Sockets map[int]*SocketHandle
}
   441  
   442  // Serialize the Netlink Request into a byte array
   443  func (req *NetlinkRequest) Serialize() []byte {
   444  	length := unix.SizeofNlMsghdr
   445  	dataBytes := make([][]byte, len(req.Data))
   446  	for i, data := range req.Data {
   447  		dataBytes[i] = data.Serialize()
   448  		length = length + len(dataBytes[i])
   449  	}
   450  	length += len(req.RawData)
   451  
   452  	req.Len = uint32(length)
   453  	b := make([]byte, length)
   454  	hdr := (*(*[unix.SizeofNlMsghdr]byte)(unsafe.Pointer(req)))[:]
   455  	next := unix.SizeofNlMsghdr
   456  	copy(b[0:next], hdr)
   457  	for _, data := range dataBytes {
   458  		for _, dataByte := range data {
   459  			b[next] = dataByte
   460  			next = next + 1
   461  		}
   462  	}
   463  	// Add the raw data if any
   464  	if len(req.RawData) > 0 {
   465  		copy(b[next:length], req.RawData)
   466  	}
   467  	return b
   468  }
   469  
   470  func (req *NetlinkRequest) AddData(data NetlinkRequestData) {
   471  	req.Data = append(req.Data, data)
   472  }
   473  
   474  // AddRawData adds raw bytes to the end of the NetlinkRequest object during serialization
   475  func (req *NetlinkRequest) AddRawData(data []byte) {
   476  	req.RawData = append(req.RawData, data...)
   477  }
   478  
   479  // Execute the request against a the given sockType.
   480  // Returns a list of netlink messages in serialized format, optionally filtered
   481  // by resType.
   482  func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, error) {
   483  	var (
   484  		s   *NetlinkSocket
   485  		err error
   486  	)
   487  
   488  	if req.Sockets != nil {
   489  		if sh, ok := req.Sockets[sockType]; ok {
   490  			s = sh.Socket
   491  			req.Seq = atomic.AddUint32(&sh.Seq, 1)
   492  		}
   493  	}
   494  	sharedSocket := s != nil
   495  
   496  	if s == nil {
   497  		s, err = getNetlinkSocket(sockType)
   498  		if err != nil {
   499  			return nil, err
   500  		}
   501  
   502  		if err := s.SetSendTimeout(&SocketTimeoutTv); err != nil {
   503  			return nil, err
   504  		}
   505  		if err := s.SetReceiveTimeout(&SocketTimeoutTv); err != nil {
   506  			return nil, err
   507  		}
   508  		if EnableErrorMessageReporting {
   509  			if err := s.SetExtAck(true); err != nil {
   510  				return nil, err
   511  			}
   512  		}
   513  
   514  		defer s.Close()
   515  	} else {
   516  		s.Lock()
   517  		defer s.Unlock()
   518  	}
   519  
   520  	if err := s.Send(req); err != nil {
   521  		return nil, err
   522  	}
   523  
   524  	pid, err := s.GetPid()
   525  	if err != nil {
   526  		return nil, err
   527  	}
   528  
   529  	var res [][]byte
   530  
   531  done:
   532  	for {
   533  		msgs, from, err := s.Receive()
   534  		if err != nil {
   535  			return nil, err
   536  		}
   537  		if from.Pid != PidKernel {
   538  			return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel)
   539  		}
   540  		for _, m := range msgs {
   541  			if m.Header.Seq != req.Seq {
   542  				if sharedSocket {
   543  					continue
   544  				}
   545  				return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
   546  			}
   547  			if m.Header.Pid != pid {
   548  				continue
   549  			}
   550  			if m.Header.Type == unix.NLMSG_DONE || m.Header.Type == unix.NLMSG_ERROR {
   551  				native := NativeEndian()
   552  				errno := int32(native.Uint32(m.Data[0:4]))
   553  				if errno == 0 {
   554  					break done
   555  				}
   556  				var err error
   557  				err = syscall.Errno(-errno)
   558  
   559  				unreadData := m.Data[4:]
   560  				if m.Header.Flags|unix.NLM_F_ACK_TLVS != 0 && len(unreadData) > syscall.SizeofNlMsghdr {
   561  					// Skip the echoed request message.
   562  					echoReqH := (*syscall.NlMsghdr)(unsafe.Pointer(&unreadData[0]))
   563  					unreadData = unreadData[nlmAlignOf(int(echoReqH.Len)):]
   564  
   565  					// Annotate `err` using nlmsgerr attributes.
   566  					for len(unreadData) >= syscall.SizeofRtAttr {
   567  						attr := (*syscall.RtAttr)(unsafe.Pointer(&unreadData[0]))
   568  						attrData := unreadData[syscall.SizeofRtAttr:attr.Len]
   569  
   570  						switch attr.Type {
   571  						case NLMSGERR_ATTR_MSG:
   572  							err = fmt.Errorf("%w: %s", err, string(attrData))
   573  
   574  						default:
   575  							// TODO: handle other NLMSGERR_ATTR types
   576  						}
   577  
   578  						unreadData = unreadData[rtaAlignOf(int(attr.Len)):]
   579  					}
   580  				}
   581  
   582  				return nil, err
   583  			}
   584  			if resType != 0 && m.Header.Type != resType {
   585  				continue
   586  			}
   587  			res = append(res, m.Data)
   588  			if m.Header.Flags&unix.NLM_F_MULTI == 0 {
   589  				break done
   590  			}
   591  		}
   592  	}
   593  	return res, nil
   594  }
   595  
   596  // Create a new netlink request from proto and flags
   597  // Note the Len value will be inaccurate once data is added until
   598  // the message is serialized
   599  func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
   600  	return &NetlinkRequest{
   601  		NlMsghdr: unix.NlMsghdr{
   602  			Len:   uint32(unix.SizeofNlMsghdr),
   603  			Type:  uint16(proto),
   604  			Flags: unix.NLM_F_REQUEST | uint16(flags),
   605  			Seq:   atomic.AddUint32(&nextSeqNr, 1),
   606  		},
   607  	}
   608  }
   609  
// NetlinkSocket is an AF_NETLINK socket.
//
// fd is the raw descriptor used for direct syscalls (sendto, setsockopt,
// getsockname). file wraps the same descriptor: Receive reads through it and
// Close closes it. lsa is the address the socket was bound to, and Send also
// passes it as the destination. The embedded Mutex is taken by
// NetlinkRequest.Execute when the socket is shared.
type NetlinkSocket struct {
	fd   int32
	file *os.File
	lsa  unix.SockaddrNetlink
	sync.Mutex
}
   616  
   617  func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {
   618  	fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, protocol)
   619  	if err != nil {
   620  		return nil, err
   621  	}
   622  	err = unix.SetNonblock(fd, true)
   623  	if err != nil {
   624  		return nil, err
   625  	}
   626  	s := &NetlinkSocket{
   627  		fd:   int32(fd),
   628  		file: os.NewFile(uintptr(fd), "netlink"),
   629  	}
   630  	s.lsa.Family = unix.AF_NETLINK
   631  	if err := unix.Bind(fd, &s.lsa); err != nil {
   632  		unix.Close(fd)
   633  		return nil, err
   634  	}
   635  
   636  	return s, nil
   637  }
   638  
   639  // GetNetlinkSocketAt opens a netlink socket in the network namespace newNs
   640  // and positions the thread back into the network namespace specified by curNs,
   641  // when done. If curNs is close, the function derives the current namespace and
   642  // moves back into it when done. If newNs is close, the socket will be opened
   643  // in the current network namespace.
   644  func GetNetlinkSocketAt(newNs, curNs netns.NsHandle, protocol int) (*NetlinkSocket, error) {
   645  	c, err := executeInNetns(newNs, curNs)
   646  	if err != nil {
   647  		return nil, err
   648  	}
   649  	defer c()
   650  	return getNetlinkSocket(protocol)
   651  }
   652  
   653  // executeInNetns sets execution of the code following this call to the
   654  // network namespace newNs, then moves the thread back to curNs if open,
   655  // otherwise to the current netns at the time the function was invoked
   656  // In case of success, the caller is expected to execute the returned function
   657  // at the end of the code that needs to be executed in the network namespace.
   658  // Example:
   659  //
   660  //	func jobAt(...) error {
   661  //	     d, err := executeInNetns(...)
   662  //	     if err != nil { return err}
   663  //	     defer d()
   664  //	     < code which needs to be executed in specific netns>
   665  //	 }
   666  //
   667  // TODO: his function probably belongs to netns pkg.
   668  func executeInNetns(newNs, curNs netns.NsHandle) (func(), error) {
   669  	var (
   670  		err       error
   671  		moveBack  func(netns.NsHandle) error
   672  		closeNs   func() error
   673  		unlockThd func()
   674  	)
   675  	restore := func() {
   676  		// order matters
   677  		if moveBack != nil {
   678  			moveBack(curNs)
   679  		}
   680  		if closeNs != nil {
   681  			closeNs()
   682  		}
   683  		if unlockThd != nil {
   684  			unlockThd()
   685  		}
   686  	}
   687  	if newNs.IsOpen() {
   688  		runtime.LockOSThread()
   689  		unlockThd = runtime.UnlockOSThread
   690  		if !curNs.IsOpen() {
   691  			if curNs, err = netns.Get(); err != nil {
   692  				restore()
   693  				return nil, fmt.Errorf("could not get current namespace while creating netlink socket: %v", err)
   694  			}
   695  			closeNs = curNs.Close
   696  		}
   697  		if err := netns.Set(newNs); err != nil {
   698  			restore()
   699  			return nil, fmt.Errorf("failed to set into network namespace %d while creating netlink socket: %v", newNs, err)
   700  		}
   701  		moveBack = netns.Set
   702  	}
   703  	return restore, nil
   704  }
   705  
   706  // Create a netlink socket with a given protocol (e.g. NETLINK_ROUTE)
   707  // and subscribe it to multicast groups passed in variable argument list.
   708  // Returns the netlink socket on which Receive() method can be called
   709  // to retrieve the messages from the kernel.
   710  func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) {
   711  	fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, protocol)
   712  	if err != nil {
   713  		return nil, err
   714  	}
   715  	err = unix.SetNonblock(fd, true)
   716  	if err != nil {
   717  		return nil, err
   718  	}
   719  	s := &NetlinkSocket{
   720  		fd:   int32(fd),
   721  		file: os.NewFile(uintptr(fd), "netlink"),
   722  	}
   723  	s.lsa.Family = unix.AF_NETLINK
   724  
   725  	for _, g := range groups {
   726  		s.lsa.Groups |= (1 << (g - 1))
   727  	}
   728  
   729  	if err := unix.Bind(fd, &s.lsa); err != nil {
   730  		unix.Close(fd)
   731  		return nil, err
   732  	}
   733  
   734  	return s, nil
   735  }
   736  
   737  // SubscribeAt works like Subscribe plus let's the caller choose the network
   738  // namespace in which the socket would be opened (newNs). Then control goes back
   739  // to curNs if open, otherwise to the netns at the time this function was called.
   740  func SubscribeAt(newNs, curNs netns.NsHandle, protocol int, groups ...uint) (*NetlinkSocket, error) {
   741  	c, err := executeInNetns(newNs, curNs)
   742  	if err != nil {
   743  		return nil, err
   744  	}
   745  	defer c()
   746  	return Subscribe(protocol, groups...)
   747  }
   748  
   749  func (s *NetlinkSocket) Close() {
   750  	s.file.Close()
   751  }
   752  
   753  func (s *NetlinkSocket) GetFd() int {
   754  	return int(s.fd)
   755  }
   756  
   757  func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
   758  	return unix.Sendto(int(s.fd), request.Serialize(), 0, &s.lsa)
   759  }
   760  
   761  func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) {
   762  	rawConn, err := s.file.SyscallConn()
   763  	if err != nil {
   764  		return nil, nil, err
   765  	}
   766  	var (
   767  		fromAddr *unix.SockaddrNetlink
   768  		rb       [RECEIVE_BUFFER_SIZE]byte
   769  		nr       int
   770  		from     unix.Sockaddr
   771  		innerErr error
   772  	)
   773  	err = rawConn.Read(func(fd uintptr) (done bool) {
   774  		nr, from, innerErr = unix.Recvfrom(int(fd), rb[:], 0)
   775  		return innerErr != unix.EWOULDBLOCK
   776  	})
   777  	if innerErr != nil {
   778  		err = innerErr
   779  	}
   780  	if err != nil {
   781  		return nil, nil, err
   782  	}
   783  	fromAddr, ok := from.(*unix.SockaddrNetlink)
   784  	if !ok {
   785  		return nil, nil, fmt.Errorf("Error converting to netlink sockaddr")
   786  	}
   787  	if nr < unix.NLMSG_HDRLEN {
   788  		return nil, nil, fmt.Errorf("Got short response from netlink")
   789  	}
   790  	rb2 := make([]byte, nr)
   791  	copy(rb2, rb[:nr])
   792  	nl, err := syscall.ParseNetlinkMessage(rb2)
   793  	if err != nil {
   794  		return nil, nil, err
   795  	}
   796  	return nl, fromAddr, nil
   797  }
   798  
   799  // SetSendTimeout allows to set a send timeout on the socket
   800  func (s *NetlinkSocket) SetSendTimeout(timeout *unix.Timeval) error {
   801  	// Set a send timeout of SOCKET_SEND_TIMEOUT, this will allow the Send to periodically unblock and avoid that a routine
   802  	// remains stuck on a send on a closed fd
   803  	return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, timeout)
   804  }
   805  
   806  // SetReceiveTimeout allows to set a receive timeout on the socket
   807  func (s *NetlinkSocket) SetReceiveTimeout(timeout *unix.Timeval) error {
   808  	// Set a read timeout of SOCKET_READ_TIMEOUT, this will allow the Read to periodically unblock and avoid that a routine
   809  	// remains stuck on a recvmsg on a closed fd
   810  	return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, timeout)
   811  }
   812  
   813  // SetExtAck requests error messages to be reported on the socket
   814  func (s *NetlinkSocket) SetExtAck(enable bool) error {
   815  	var enableN int
   816  	if enable {
   817  		enableN = 1
   818  	}
   819  
   820  	return unix.SetsockoptInt(int(s.fd), unix.SOL_NETLINK, unix.NETLINK_EXT_ACK, enableN)
   821  }
   822  
   823  func (s *NetlinkSocket) GetPid() (uint32, error) {
   824  	lsa, err := unix.Getsockname(int(s.fd))
   825  	if err != nil {
   826  		return 0, err
   827  	}
   828  	switch v := lsa.(type) {
   829  	case *unix.SockaddrNetlink:
   830  		return v.Pid, nil
   831  	}
   832  	return 0, fmt.Errorf("Wrong socket type")
   833  }
   834  
   835  func ZeroTerminated(s string) []byte {
   836  	bytes := make([]byte, len(s)+1)
   837  	for i := 0; i < len(s); i++ {
   838  		bytes[i] = s[i]
   839  	}
   840  	bytes[len(s)] = 0
   841  	return bytes
   842  }
   843  
   844  func NonZeroTerminated(s string) []byte {
   845  	bytes := make([]byte, len(s))
   846  	for i := 0; i < len(s); i++ {
   847  		bytes[i] = s[i]
   848  	}
   849  	return bytes
   850  }
   851  
   852  func BytesToString(b []byte) string {
   853  	n := bytes.Index(b, []byte{0})
   854  	return string(b[:n])
   855  }
   856  
   857  func Uint8Attr(v uint8) []byte {
   858  	return []byte{byte(v)}
   859  }
   860  
   861  func Uint16Attr(v uint16) []byte {
   862  	native := NativeEndian()
   863  	bytes := make([]byte, 2)
   864  	native.PutUint16(bytes, v)
   865  	return bytes
   866  }
   867  
   868  func Uint32Attr(v uint32) []byte {
   869  	native := NativeEndian()
   870  	bytes := make([]byte, 4)
   871  	native.PutUint32(bytes, v)
   872  	return bytes
   873  }
   874  
   875  func Uint64Attr(v uint64) []byte {
   876  	native := NativeEndian()
   877  	bytes := make([]byte, 8)
   878  	native.PutUint64(bytes, v)
   879  	return bytes
   880  }
   881  
   882  func ParseRouteAttr(b []byte) ([]syscall.NetlinkRouteAttr, error) {
   883  	var attrs []syscall.NetlinkRouteAttr
   884  	for len(b) >= unix.SizeofRtAttr {
   885  		a, vbuf, alen, err := netlinkRouteAttrAndValue(b)
   886  		if err != nil {
   887  			return nil, err
   888  		}
   889  		ra := syscall.NetlinkRouteAttr{Attr: syscall.RtAttr(*a), Value: vbuf[:int(a.Len)-unix.SizeofRtAttr]}
   890  		attrs = append(attrs, ra)
   891  		b = b[alen:]
   892  	}
   893  	return attrs, nil
   894  }
   895  
   896  func netlinkRouteAttrAndValue(b []byte) (*unix.RtAttr, []byte, int, error) {
   897  	a := (*unix.RtAttr)(unsafe.Pointer(&b[0]))
   898  	if int(a.Len) < unix.SizeofRtAttr || int(a.Len) > len(b) {
   899  		return nil, nil, 0, unix.EINVAL
   900  	}
   901  	return a, b[unix.SizeofRtAttr:], rtaAlignOf(int(a.Len)), nil
   902  }
   903  
// SocketHandle contains the netlink socket and the associated
// sequence counter for a specific netlink family.
//
// NetlinkRequest.Execute advances Seq with atomic.AddUint32 to number the
// requests it sends on the shared Socket.
type SocketHandle struct {
	Seq    uint32
	Socket *NetlinkSocket
}
   910  
   911  // Close closes the netlink socket
   912  func (sh *SocketHandle) Close() {
   913  	if sh.Socket != nil {
   914  		sh.Socket.Close()
   915  	}
   916  }