github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/sentry/socket/netlink/message.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package netlink
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  
    21  	"github.com/nicocha30/gvisor-ligolo/pkg/abi/linux"
    22  	"github.com/nicocha30/gvisor-ligolo/pkg/bits"
    23  	"github.com/nicocha30/gvisor-ligolo/pkg/hostarch"
    24  	"github.com/nicocha30/gvisor-ligolo/pkg/marshal"
    25  	"github.com/nicocha30/gvisor-ligolo/pkg/marshal/primitive"
    26  )
    27  
    28  // alignPad returns the length of padding required for alignment.
    29  //
    30  // Preconditions: align is a power of two.
    31  func alignPad(length int, align uint) int {
    32  	return bits.AlignUp(length, align) - length
    33  }
    34  
// Message contains a complete serialized netlink message.
type Message struct {
	// hdr is the netlink header the message was created or parsed with.
	// Finalize writes the final length into buf only; hdr is not updated.
	hdr linux.NetlinkMessageHeader
	// buf holds the serialized message bytes, starting with the serialized
	// header.
	buf []byte
}
    40  
    41  // NewMessage creates a new Message containing the passed header.
    42  //
    43  // The header length will be updated by Finalize.
    44  func NewMessage(hdr linux.NetlinkMessageHeader) *Message {
    45  	return &Message{
    46  		hdr: hdr,
    47  		buf: marshal.Marshal(&hdr),
    48  	}
    49  }
    50  
    51  // ParseMessage parses the first message seen at buf, returning the rest of the
    52  // buffer. If message is malformed, ok of false is returned. For last message,
    53  // padding check is loose, if there isn't enought padding, whole buf is consumed
    54  // and ok is set to true.
    55  func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) {
    56  	b := BytesView(buf)
    57  
    58  	hdrBytes, ok := b.Extract(linux.NetlinkMessageHeaderSize)
    59  	if !ok {
    60  		return
    61  	}
    62  	var hdr linux.NetlinkMessageHeader
    63  	hdr.UnmarshalUnsafe(hdrBytes)
    64  
    65  	// Msg portion.
    66  	totalMsgLen := int(hdr.Length)
    67  	_, ok = b.Extract(totalMsgLen - linux.NetlinkMessageHeaderSize)
    68  	if !ok {
    69  		return
    70  	}
    71  
    72  	// Padding.
    73  	numPad := alignPad(totalMsgLen, linux.NLMSG_ALIGNTO)
    74  	// Linux permits the last message not being aligned, just consume all of it.
    75  	// Ref: net/netlink/af_netlink.c:netlink_rcv_skb
    76  	if numPad > len(b) {
    77  		numPad = len(b)
    78  	}
    79  	_, ok = b.Extract(numPad)
    80  	if !ok {
    81  		return
    82  	}
    83  
    84  	return &Message{
    85  		hdr: hdr,
    86  		buf: buf[:totalMsgLen],
    87  	}, []byte(b), true
    88  }
    89  
    90  // Header returns the header of this message.
    91  func (m *Message) Header() linux.NetlinkMessageHeader {
    92  	return m.hdr
    93  }
    94  
    95  // GetData unmarshals the payload message header from this netlink message, and
    96  // returns the attributes portion.
    97  func (m *Message) GetData(msg marshal.Marshallable) (AttrsView, bool) {
    98  	b := BytesView(m.buf)
    99  
   100  	_, ok := b.Extract(linux.NetlinkMessageHeaderSize)
   101  	if !ok {
   102  		return nil, false
   103  	}
   104  
   105  	size := msg.SizeBytes()
   106  	msgBytes, ok := b.Extract(size)
   107  	if !ok {
   108  		return nil, false
   109  	}
   110  	msg.UnmarshalUnsafe(msgBytes)
   111  
   112  	numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO)
   113  	// Linux permits the last message not being aligned, just consume all of it.
   114  	// Ref: net/netlink/af_netlink.c:netlink_rcv_skb
   115  	if numPad > len(b) {
   116  		numPad = len(b)
   117  	}
   118  	_, ok = b.Extract(numPad)
   119  	if !ok {
   120  		return nil, false
   121  	}
   122  
   123  	return AttrsView(b), true
   124  }
   125  
   126  // Finalize returns the []byte containing the entire message, with the total
   127  // length set in the message header. The Message must not be modified after
   128  // calling Finalize.
   129  func (m *Message) Finalize() []byte {
   130  	// Update length, which is the first 4 bytes of the header.
   131  	hostarch.ByteOrder.PutUint32(m.buf, uint32(len(m.buf)))
   132  
   133  	// Align the message. Note that the message length in the header (set
   134  	// above) is the useful length of the message, not the total aligned
   135  	// length. See net/netlink/af_netlink.c:__nlmsg_put.
   136  	aligned := bits.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
   137  	m.putZeros(aligned - len(m.buf))
   138  	return m.buf
   139  }
   140  
   141  // putZeros adds n zeros to the message.
   142  func (m *Message) putZeros(n int) {
   143  	for n > 0 {
   144  		m.buf = append(m.buf, 0)
   145  		n--
   146  	}
   147  }
   148  
   149  // Put serializes v into the message.
   150  func (m *Message) Put(v marshal.Marshallable) {
   151  	m.buf = append(m.buf, marshal.Marshal(v)...)
   152  }
   153  
   154  // PutAttr adds v to the message as a netlink attribute.
   155  //
   156  // Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize +
   157  // v.SizeBytes()) fits in math.MaxUint16 bytes.
   158  func (m *Message) PutAttr(atype uint16, v marshal.Marshallable) {
   159  	l := linux.NetlinkAttrHeaderSize + v.SizeBytes()
   160  	if l > math.MaxUint16 {
   161  		panic(fmt.Sprintf("attribute too large: %d", l))
   162  	}
   163  
   164  	m.Put(&linux.NetlinkAttrHeader{
   165  		Type:   atype,
   166  		Length: uint16(l),
   167  	})
   168  	m.Put(v)
   169  
   170  	// Align the attribute.
   171  	aligned := bits.AlignUp(l, linux.NLA_ALIGNTO)
   172  	m.putZeros(aligned - l)
   173  }
   174  
   175  // PutAttrString adds s to the message as a netlink attribute.
   176  func (m *Message) PutAttrString(atype uint16, s string) {
   177  	l := linux.NetlinkAttrHeaderSize + len(s) + 1
   178  	m.Put(&linux.NetlinkAttrHeader{
   179  		Type:   atype,
   180  		Length: uint16(l),
   181  	})
   182  
   183  	// String + NUL-termination.
   184  	m.Put(primitive.AsByteSlice([]byte(s)))
   185  	m.putZeros(1)
   186  
   187  	// Align the attribute.
   188  	aligned := bits.AlignUp(l, linux.NLA_ALIGNTO)
   189  	m.putZeros(aligned - l)
   190  }
   191  
// MessageSet contains a series of netlink messages that share a destination
// port and sequence number.
type MessageSet struct {
	// Multi indicates that this is a multi-part message, to be terminated by
	// NLMSG_DONE. NLMSG_DONE is sent even if the set contains only one
	// Message.
	//
	// If Multi is set, all added messages will have NLM_F_MULTI set.
	Multi bool

	// PortID is the destination port for all messages.
	PortID int32

	// Seq is the sequence counter for all messages in the set.
	Seq uint32

	// Messages contains the messages in the set, in the order they were
	// added.
	Messages []*Message
}
   210  
   211  // NewMessageSet creates a new MessageSet.
   212  //
   213  // portID is the destination port to set as PortID in all messages.
   214  //
   215  // seq is the sequence counter to set as seq in all messages in the set.
   216  func NewMessageSet(portID int32, seq uint32) *MessageSet {
   217  	return &MessageSet{
   218  		PortID: portID,
   219  		Seq:    seq,
   220  	}
   221  }
   222  
   223  // AddMessage adds a new message to the set and returns it for further
   224  // additions.
   225  //
   226  // The passed header will have Seq, PortID and the multi flag set
   227  // automatically.
   228  func (ms *MessageSet) AddMessage(hdr linux.NetlinkMessageHeader) *Message {
   229  	hdr.Seq = ms.Seq
   230  	hdr.PortID = uint32(ms.PortID)
   231  	if ms.Multi {
   232  		hdr.Flags |= linux.NLM_F_MULTI
   233  	}
   234  
   235  	m := NewMessage(hdr)
   236  	ms.Messages = append(ms.Messages, m)
   237  	return m
   238  }
   239  
   240  // AttrsView is a view into the attributes portion of a netlink message.
   241  type AttrsView []byte
   242  
   243  // Empty returns whether there is no attribute left in v.
   244  func (v AttrsView) Empty() bool {
   245  	return len(v) == 0
   246  }
   247  
   248  // ParseFirst parses first netlink attribute at the beginning of v.
   249  func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest AttrsView, ok bool) {
   250  	b := BytesView(v)
   251  
   252  	hdrBytes, ok := b.Extract(linux.NetlinkAttrHeaderSize)
   253  	if !ok {
   254  		return
   255  	}
   256  	hdr.UnmarshalUnsafe(hdrBytes)
   257  
   258  	value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize)
   259  	if !ok {
   260  		return
   261  	}
   262  
   263  	_, ok = b.Extract(alignPad(int(hdr.Length), linux.NLA_ALIGNTO))
   264  	if !ok {
   265  		return
   266  	}
   267  
   268  	return hdr, value, AttrsView(b), ok
   269  }
   270  
   271  // BytesView supports extracting data from a byte slice with bounds checking.
   272  type BytesView []byte
   273  
   274  // Extract removes the first n bytes from v and returns it. If n is out of
   275  // bounds, it returns false.
   276  func (v *BytesView) Extract(n int) ([]byte, bool) {
   277  	if n < 0 || n > len(*v) {
   278  		return nil, false
   279  	}
   280  	extracted := (*v)[:n]
   281  	*v = (*v)[n:]
   282  	return extracted, true
   283  }