gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/socket/netlink/nlmsg/message.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package nlmsg provides helpers to parse and construct netlink messages.
    16  package nlmsg
    17  
    18  import (
    19  	"fmt"
    20  	"math"
    21  
    22  	"gvisor.dev/gvisor/pkg/abi/linux"
    23  	"gvisor.dev/gvisor/pkg/bits"
    24  	"gvisor.dev/gvisor/pkg/hostarch"
    25  	"gvisor.dev/gvisor/pkg/marshal"
    26  	"gvisor.dev/gvisor/pkg/marshal/primitive"
    27  )
    28  
    29  // alignPad returns the length of padding required for alignment.
    30  //
    31  // Preconditions: align is a power of two.
    32  func alignPad(length int, align uint) int {
    33  	return bits.AlignUp(length, align) - length
    34  }
    35  
    36  // Message contains a complete serialized netlink message.
    37  type Message struct {
    38  	hdr linux.NetlinkMessageHeader
    39  	buf []byte
    40  }
    41  
    42  // NewMessage creates a new Message containing the passed header.
    43  //
    44  // The header length will be updated by Finalize.
    45  func NewMessage(hdr linux.NetlinkMessageHeader) *Message {
    46  	return &Message{
    47  		hdr: hdr,
    48  		buf: marshal.Marshal(&hdr),
    49  	}
    50  }
    51  
    52  // ParseMessage parses the first message seen at buf, returning the rest of the
    53  // buffer. If message is malformed, ok of false is returned. For last message,
    54  // padding check is loose, if there isn't enough padding, whole buf is consumed
    55  // and ok is set to true.
    56  func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) {
    57  	b := BytesView(buf)
    58  
    59  	hdrBytes, ok := b.Extract(linux.NetlinkMessageHeaderSize)
    60  	if !ok {
    61  		return
    62  	}
    63  	var hdr linux.NetlinkMessageHeader
    64  	hdr.UnmarshalUnsafe(hdrBytes)
    65  
    66  	// Msg portion.
    67  	totalMsgLen := int(hdr.Length)
    68  	_, ok = b.Extract(totalMsgLen - linux.NetlinkMessageHeaderSize)
    69  	if !ok {
    70  		return
    71  	}
    72  
    73  	// Padding.
    74  	numPad := alignPad(totalMsgLen, linux.NLMSG_ALIGNTO)
    75  	// Linux permits the last message not being aligned, just consume all of it.
    76  	// Ref: net/netlink/af_netlink.c:netlink_rcv_skb
    77  	if numPad > len(b) {
    78  		numPad = len(b)
    79  	}
    80  	_, ok = b.Extract(numPad)
    81  	if !ok {
    82  		return
    83  	}
    84  
    85  	return &Message{
    86  		hdr: hdr,
    87  		buf: buf[:totalMsgLen],
    88  	}, []byte(b), true
    89  }
    90  
    91  // Header returns the header of this message.
    92  func (m *Message) Header() linux.NetlinkMessageHeader {
    93  	return m.hdr
    94  }
    95  
    96  // GetData unmarshals the payload message header from this netlink message, and
    97  // returns the attributes portion.
    98  func (m *Message) GetData(msg marshal.Marshallable) (AttrsView, bool) {
    99  	b := BytesView(m.buf)
   100  
   101  	_, ok := b.Extract(linux.NetlinkMessageHeaderSize)
   102  	if !ok {
   103  		return nil, false
   104  	}
   105  
   106  	size := msg.SizeBytes()
   107  	msgBytes, ok := b.Extract(size)
   108  	if !ok {
   109  		return nil, false
   110  	}
   111  	msg.UnmarshalUnsafe(msgBytes)
   112  
   113  	numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO)
   114  	// Linux permits the last message not being aligned, just consume all of it.
   115  	// Ref: net/netlink/af_netlink.c:netlink_rcv_skb
   116  	if numPad > len(b) {
   117  		numPad = len(b)
   118  	}
   119  	_, ok = b.Extract(numPad)
   120  	if !ok {
   121  		return nil, false
   122  	}
   123  
   124  	return AttrsView(b), true
   125  }
   126  
   127  // Finalize returns the []byte containing the entire message, with the total
   128  // length set in the message header. The Message must not be modified after
   129  // calling Finalize.
   130  func (m *Message) Finalize() []byte {
   131  	// Update length, which is the first 4 bytes of the header.
   132  	hostarch.ByteOrder.PutUint32(m.buf, uint32(len(m.buf)))
   133  
   134  	// Align the message. Note that the message length in the header (set
   135  	// above) is the useful length of the message, not the total aligned
   136  	// length. See net/netlink/af_netlink.c:__nlmsg_put.
   137  	aligned := bits.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
   138  	m.putZeros(aligned - len(m.buf))
   139  	return m.buf
   140  }
   141  
   142  // putZeros adds n zeros to the message.
   143  func (m *Message) putZeros(n int) {
   144  	for n > 0 {
   145  		m.buf = append(m.buf, 0)
   146  		n--
   147  	}
   148  }
   149  
   150  // Put serializes v into the message.
   151  func (m *Message) Put(v marshal.Marshallable) {
   152  	m.buf = append(m.buf, marshal.Marshal(v)...)
   153  }
   154  
   155  // PutAttr adds v to the message as a netlink attribute.
   156  //
   157  // Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize +
   158  // v.SizeBytes()) fits in math.MaxUint16 bytes.
   159  func (m *Message) PutAttr(atype uint16, v marshal.Marshallable) {
   160  	l := linux.NetlinkAttrHeaderSize + v.SizeBytes()
   161  	if l > math.MaxUint16 {
   162  		panic(fmt.Sprintf("attribute too large: %d", l))
   163  	}
   164  
   165  	m.Put(&linux.NetlinkAttrHeader{
   166  		Type:   atype,
   167  		Length: uint16(l),
   168  	})
   169  	m.Put(v)
   170  
   171  	// Align the attribute.
   172  	aligned := bits.AlignUp(l, linux.NLA_ALIGNTO)
   173  	m.putZeros(aligned - l)
   174  }
   175  
   176  // PutAttrString adds s to the message as a netlink attribute.
   177  func (m *Message) PutAttrString(atype uint16, s string) {
   178  	l := linux.NetlinkAttrHeaderSize + len(s) + 1
   179  	m.Put(&linux.NetlinkAttrHeader{
   180  		Type:   atype,
   181  		Length: uint16(l),
   182  	})
   183  
   184  	// String + NUL-termination.
   185  	m.Put(primitive.AsByteSlice([]byte(s)))
   186  	m.putZeros(1)
   187  
   188  	// Align the attribute.
   189  	aligned := bits.AlignUp(l, linux.NLA_ALIGNTO)
   190  	m.putZeros(aligned - l)
   191  }
   192  
   193  // MessageSet contains a series of netlink messages.
   194  type MessageSet struct {
   195  	// Multi indicates that this a multi-part message, to be terminated by
   196  	// NLMSG_DONE. NLMSG_DONE is sent even if the set contains only one
   197  	// Message.
   198  	//
   199  	// If Multi is set, all added messages will have NLM_F_MULTI set.
   200  	Multi bool
   201  
   202  	// PortID is the destination port for all messages.
   203  	PortID int32
   204  
   205  	// Seq is the sequence counter for all messages in the set.
   206  	Seq uint32
   207  
   208  	// Messages contains the messages in the set.
   209  	Messages []*Message
   210  }
   211  
   212  // NewMessageSet creates a new MessageSet.
   213  //
   214  // portID is the destination port to set as PortID in all messages.
   215  //
   216  // seq is the sequence counter to set as seq in all messages in the set.
   217  func NewMessageSet(portID int32, seq uint32) *MessageSet {
   218  	return &MessageSet{
   219  		PortID: portID,
   220  		Seq:    seq,
   221  	}
   222  }
   223  
   224  // AddMessage adds a new message to the set and returns it for further
   225  // additions.
   226  //
   227  // The passed header will have Seq, PortID and the multi flag set
   228  // automatically.
   229  func (ms *MessageSet) AddMessage(hdr linux.NetlinkMessageHeader) *Message {
   230  	hdr.Seq = ms.Seq
   231  	hdr.PortID = uint32(ms.PortID)
   232  	if ms.Multi {
   233  		hdr.Flags |= linux.NLM_F_MULTI
   234  	}
   235  
   236  	m := NewMessage(hdr)
   237  	ms.Messages = append(ms.Messages, m)
   238  	return m
   239  }
   240  
   241  // AttrsView is a view into the attributes portion of a netlink message.
   242  type AttrsView []byte
   243  
   244  // Empty returns whether there is no attribute left in v.
   245  func (v AttrsView) Empty() bool {
   246  	return len(v) == 0
   247  }
   248  
   249  // ParseFirst parses first netlink attribute at the beginning of v.
   250  func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest AttrsView, ok bool) {
   251  	b := BytesView(v)
   252  
   253  	hdrBytes, ok := b.Extract(linux.NetlinkAttrHeaderSize)
   254  	if !ok {
   255  		return
   256  	}
   257  	hdr.UnmarshalUnsafe(hdrBytes)
   258  
   259  	value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize)
   260  	if !ok {
   261  		return
   262  	}
   263  
   264  	_, ok = b.Extract(alignPad(int(hdr.Length), linux.NLA_ALIGNTO))
   265  	if !ok {
   266  		return
   267  	}
   268  
   269  	return hdr, value, AttrsView(b), ok
   270  }
   271  
   272  // Parse parses netlink attributes.
   273  func (v AttrsView) Parse() (map[uint16]BytesView, bool) {
   274  	attrs := make(map[uint16]BytesView)
   275  	attrsView := v
   276  	for !attrsView.Empty() {
   277  		// The index is unspecified, search by the interface name.
   278  		ahdr, value, rest, ok := attrsView.ParseFirst()
   279  		if !ok {
   280  			return nil, false
   281  		}
   282  		attrsView = rest
   283  		attrs[ahdr.Type] = BytesView(value)
   284  	}
   285  	return attrs, true
   286  
   287  }
   288  
   289  // BytesView supports extracting data from a byte slice with bounds checking.
   290  type BytesView []byte
   291  
   292  // Extract removes the first n bytes from v and returns it. If n is out of
   293  // bounds, it returns false.
   294  func (v *BytesView) Extract(n int) ([]byte, bool) {
   295  	if n < 0 || n > len(*v) {
   296  		return nil, false
   297  	}
   298  	extracted := (*v)[:n]
   299  	*v = (*v)[n:]
   300  	return extracted, true
   301  }
   302  
   303  // String converts the raw attribute value to string.
   304  func (v *BytesView) String() string {
   305  	b := []byte(*v)
   306  	if len(b) == 0 {
   307  		return ""
   308  	}
   309  	if b[len(b)-1] == 0 {
   310  		b = b[:len(b)-1]
   311  	}
   312  	return string(b)
   313  }
   314  
   315  // Uint32 converts the raw attribute value to uint32.
   316  func (v *BytesView) Uint32() (uint32, bool) {
   317  	attr := []byte(*v)
   318  	val := primitive.Uint32(0)
   319  	if len(attr) != val.SizeBytes() {
   320  		return 0, false
   321  	}
   322  	val.UnmarshalBytes(attr)
   323  	return uint32(val), true
   324  }