github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/socket/netlink/message_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package message_test
    16  
    17  import (
    18  	"bytes"
    19  	"reflect"
    20  	"testing"
    21  
    22  	"github.com/SagerNet/gvisor/pkg/abi/linux"
    23  	"github.com/SagerNet/gvisor/pkg/marshal"
    24  	"github.com/SagerNet/gvisor/pkg/marshal/primitive"
    25  	"github.com/SagerNet/gvisor/pkg/sentry/socket/netlink"
    26  )
    27  
    28  type dummyNetlinkMsg struct {
    29  	marshal.StubMarshallable
    30  	Foo uint16
    31  }
    32  
    33  func (*dummyNetlinkMsg) SizeBytes() int {
    34  	return 2
    35  }
    36  
    37  func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) {
    38  	p := primitive.Uint16(m.Foo)
    39  	p.MarshalUnsafe(dst)
    40  }
    41  
    42  func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) {
    43  	var p primitive.Uint16
    44  	p.UnmarshalUnsafe(src)
    45  	m.Foo = uint16(p)
    46  }
    47  
    48  func TestParseMessage(t *testing.T) {
    49  	tests := []struct {
    50  		desc  string
    51  		input []byte
    52  
    53  		header  linux.NetlinkMessageHeader
    54  		dataMsg *dummyNetlinkMsg
    55  		restLen int
    56  		ok      bool
    57  	}{
    58  		{
    59  			desc: "valid",
    60  			input: []byte{
    61  				0x14, 0x00, 0x00, 0x00, // Length
    62  				0x01, 0x00, // Type
    63  				0x02, 0x00, // Flags
    64  				0x03, 0x00, 0x00, 0x00, // Seq
    65  				0x04, 0x00, 0x00, 0x00, // PortID
    66  				0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
    67  			},
    68  			header: linux.NetlinkMessageHeader{
    69  				Length: 20,
    70  				Type:   1,
    71  				Flags:  2,
    72  				Seq:    3,
    73  				PortID: 4,
    74  			},
    75  			dataMsg: &dummyNetlinkMsg{
    76  				Foo: 0x3130,
    77  			},
    78  			restLen: 0,
    79  			ok:      true,
    80  		},
    81  		{
    82  			desc: "valid with next message",
    83  			input: []byte{
    84  				0x14, 0x00, 0x00, 0x00, // Length
    85  				0x01, 0x00, // Type
    86  				0x02, 0x00, // Flags
    87  				0x03, 0x00, 0x00, 0x00, // Seq
    88  				0x04, 0x00, 0x00, 0x00, // PortID
    89  				0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
    90  				0xFF, // Next message (rest)
    91  			},
    92  			header: linux.NetlinkMessageHeader{
    93  				Length: 20,
    94  				Type:   1,
    95  				Flags:  2,
    96  				Seq:    3,
    97  				PortID: 4,
    98  			},
    99  			dataMsg: &dummyNetlinkMsg{
   100  				Foo: 0x3130,
   101  			},
   102  			restLen: 1,
   103  			ok:      true,
   104  		},
   105  		{
   106  			desc: "valid for last message without padding",
   107  			input: []byte{
   108  				0x12, 0x00, 0x00, 0x00, // Length
   109  				0x01, 0x00, // Type
   110  				0x02, 0x00, // Flags
   111  				0x03, 0x00, 0x00, 0x00, // Seq
   112  				0x04, 0x00, 0x00, 0x00, // PortID
   113  				0x30, 0x31, // Data message
   114  			},
   115  			header: linux.NetlinkMessageHeader{
   116  				Length: 18,
   117  				Type:   1,
   118  				Flags:  2,
   119  				Seq:    3,
   120  				PortID: 4,
   121  			},
   122  			dataMsg: &dummyNetlinkMsg{
   123  				Foo: 0x3130,
   124  			},
   125  			restLen: 0,
   126  			ok:      true,
   127  		},
   128  		{
   129  			desc: "valid for last message not to be aligned",
   130  			input: []byte{
   131  				0x13, 0x00, 0x00, 0x00, // Length
   132  				0x01, 0x00, // Type
   133  				0x02, 0x00, // Flags
   134  				0x03, 0x00, 0x00, 0x00, // Seq
   135  				0x04, 0x00, 0x00, 0x00, // PortID
   136  				0x30, 0x31, // Data message
   137  				0x00, // Excessive 1 byte permitted at end
   138  			},
   139  			header: linux.NetlinkMessageHeader{
   140  				Length: 19,
   141  				Type:   1,
   142  				Flags:  2,
   143  				Seq:    3,
   144  				PortID: 4,
   145  			},
   146  			dataMsg: &dummyNetlinkMsg{
   147  				Foo: 0x3130,
   148  			},
   149  			restLen: 0,
   150  			ok:      true,
   151  		},
   152  		{
   153  			desc: "header.Length too short",
   154  			input: []byte{
   155  				0x04, 0x00, 0x00, 0x00, // Length
   156  				0x01, 0x00, // Type
   157  				0x02, 0x00, // Flags
   158  				0x03, 0x00, 0x00, 0x00, // Seq
   159  				0x04, 0x00, 0x00, 0x00, // PortID
   160  				0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
   161  			},
   162  			ok: false,
   163  		},
   164  		{
   165  			desc: "header.Length too long",
   166  			input: []byte{
   167  				0xFF, 0xFF, 0x00, 0x00, // Length
   168  				0x01, 0x00, // Type
   169  				0x02, 0x00, // Flags
   170  				0x03, 0x00, 0x00, 0x00, // Seq
   171  				0x04, 0x00, 0x00, 0x00, // PortID
   172  				0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
   173  			},
   174  			ok: false,
   175  		},
   176  		{
   177  			desc: "header incomplete",
   178  			input: []byte{
   179  				0x04, 0x00, 0x00, 0x00, // Length
   180  			},
   181  			ok: false,
   182  		},
   183  		{
   184  			desc:  "empty message",
   185  			input: []byte{},
   186  			ok:    false,
   187  		},
   188  	}
   189  	for _, test := range tests {
   190  		msg, rest, ok := netlink.ParseMessage(test.input)
   191  		if ok != test.ok {
   192  			t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok)
   193  			continue
   194  		}
   195  		if !test.ok {
   196  			continue
   197  		}
   198  		if !reflect.DeepEqual(msg.Header(), test.header) {
   199  			t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header)
   200  		}
   201  
   202  		dataMsg := &dummyNetlinkMsg{}
   203  		_, dataOk := msg.GetData(dataMsg)
   204  		if !dataOk {
   205  			t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk)
   206  		} else if !reflect.DeepEqual(dataMsg, test.dataMsg) {
   207  			t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg)
   208  		}
   209  
   210  		if got, want := rest, test.input[len(test.input)-test.restLen:]; !bytes.Equal(got, want) {
   211  			t.Errorf("%v: got rest = %v, want = %v", test.desc, got, want)
   212  		}
   213  	}
   214  }
   215  
   216  func TestAttrView(t *testing.T) {
   217  	tests := []struct {
   218  		desc  string
   219  		input []byte
   220  
   221  		// Outputs for ParseFirst.
   222  		hdr     linux.NetlinkAttrHeader
   223  		value   []byte
   224  		restLen int
   225  		ok      bool
   226  
   227  		// Outputs for Empty.
   228  		isEmpty bool
   229  	}{
   230  		{
   231  			desc: "valid",
   232  			input: []byte{
   233  				0x06, 0x00, // Length
   234  				0x01, 0x00, // Type
   235  				0x30, 0x31, 0x00, 0x00, // Data with 2 bytes padding
   236  			},
   237  			hdr: linux.NetlinkAttrHeader{
   238  				Length: 6,
   239  				Type:   1,
   240  			},
   241  			value:   []byte{0x30, 0x31},
   242  			restLen: 0,
   243  			ok:      true,
   244  			isEmpty: false,
   245  		},
   246  		{
   247  			desc: "at alignment",
   248  			input: []byte{
   249  				0x08, 0x00, // Length
   250  				0x01, 0x00, // Type
   251  				0x30, 0x31, 0x32, 0x33, // Data
   252  			},
   253  			hdr: linux.NetlinkAttrHeader{
   254  				Length: 8,
   255  				Type:   1,
   256  			},
   257  			value:   []byte{0x30, 0x31, 0x32, 0x33},
   258  			restLen: 0,
   259  			ok:      true,
   260  			isEmpty: false,
   261  		},
   262  		{
   263  			desc: "at alignment with rest data",
   264  			input: []byte{
   265  				0x08, 0x00, // Length
   266  				0x01, 0x00, // Type
   267  				0x30, 0x31, 0x32, 0x33, // Data
   268  				0xFF, 0xFE, // Rest data
   269  			},
   270  			hdr: linux.NetlinkAttrHeader{
   271  				Length: 8,
   272  				Type:   1,
   273  			},
   274  			value:   []byte{0x30, 0x31, 0x32, 0x33},
   275  			restLen: 2,
   276  			ok:      true,
   277  			isEmpty: false,
   278  		},
   279  		{
   280  			desc: "hdr.Length too long",
   281  			input: []byte{
   282  				0xFF, 0x00, // Length
   283  				0x01, 0x00, // Type
   284  				0x30, 0x31, 0x32, 0x33, // Data
   285  			},
   286  			ok:      false,
   287  			isEmpty: false,
   288  		},
   289  		{
   290  			desc: "hdr.Length too short",
   291  			input: []byte{
   292  				0x01, 0x00, // Length
   293  				0x01, 0x00, // Type
   294  				0x30, 0x31, 0x32, 0x33, // Data
   295  			},
   296  			ok:      false,
   297  			isEmpty: false,
   298  		},
   299  		{
   300  			desc:    "empty",
   301  			input:   []byte{},
   302  			ok:      false,
   303  			isEmpty: true,
   304  		},
   305  	}
   306  	for _, test := range tests {
   307  		attrs := netlink.AttrsView(test.input)
   308  
   309  		// Test ParseFirst().
   310  		hdr, value, rest, ok := attrs.ParseFirst()
   311  		if ok != test.ok {
   312  			t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok)
   313  		} else if test.ok {
   314  			if !reflect.DeepEqual(hdr, test.hdr) {
   315  				t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, hdr, test.hdr)
   316  			}
   317  			if !bytes.Equal(value, test.value) {
   318  				t.Errorf("%v: got value = %v, want = %v", test.desc, value, test.value)
   319  			}
   320  			if wantRest := test.input[len(test.input)-test.restLen:]; !bytes.Equal(rest, wantRest) {
   321  				t.Errorf("%v: got rest = %v, want = %v", test.desc, rest, wantRest)
   322  			}
   323  		}
   324  
   325  		// Test Empty().
   326  		if got, want := attrs.Empty(), test.isEmpty; got != want {
   327  			t.Errorf("%v: got empty = %v, want = %v", test.desc, got, want)
   328  		}
   329  	}
   330  }