gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/network/internal/testutil/testutil.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package testutil defines types and functions used to test Network Layer
    16  // functionality such as IP fragmentation.
    17  package testutil
    18  
    19  import (
    20  	"fmt"
    21  	"math/rand"
    22  	"testing"
    23  
    24  	"github.com/google/go-cmp/cmp"
    25  	"gvisor.dev/gvisor/pkg/buffer"
    26  	"gvisor.dev/gvisor/pkg/tcpip"
    27  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    28  	"gvisor.dev/gvisor/pkg/tcpip/header"
    29  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    30  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    31  )
    32  
    33  // MockLinkEndpoint is an endpoint used for testing, it stores packets written
    34  // to it and can mock errors.
    35  type MockLinkEndpoint struct {
    36  	// WrittenPackets is where packets written to the endpoint are stored.
    37  	WrittenPackets []*stack.PacketBuffer
    38  
    39  	mtu          uint32
    40  	err          tcpip.Error
    41  	allowPackets int
    42  }
    43  
    44  // NewMockLinkEndpoint creates a new MockLinkEndpoint.
    45  //
    46  // err is the error that will be returned once allowPackets packets are written
    47  // to the endpoint.
    48  func NewMockLinkEndpoint(mtu uint32, err tcpip.Error, allowPackets int) *MockLinkEndpoint {
    49  	return &MockLinkEndpoint{
    50  		mtu:          mtu,
    51  		err:          err,
    52  		allowPackets: allowPackets,
    53  	}
    54  }
    55  
    56  // MTU implements LinkEndpoint.MTU.
    57  func (ep *MockLinkEndpoint) MTU() uint32 { return ep.mtu }
    58  
    59  // Capabilities implements LinkEndpoint.Capabilities.
    60  func (*MockLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return 0 }
    61  
    62  // MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
    63  func (*MockLinkEndpoint) MaxHeaderLength() uint16 { return 0 }
    64  
    65  // LinkAddress implements LinkEndpoint.LinkAddress.
    66  func (*MockLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
    67  
    68  // WritePackets implements LinkEndpoint.WritePackets.
    69  func (ep *MockLinkEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
    70  	var n int
    71  	for _, pkt := range pkts.AsSlice() {
    72  		if ep.allowPackets == 0 {
    73  			return n, ep.err
    74  		}
    75  		ep.allowPackets--
    76  		ep.WrittenPackets = append(ep.WrittenPackets, pkt.IncRef())
    77  		n++
    78  	}
    79  	return n, nil
    80  }
    81  
    82  // Attach implements LinkEndpoint.Attach.
    83  func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {}
    84  
    85  // IsAttached implements LinkEndpoint.IsAttached.
    86  func (*MockLinkEndpoint) IsAttached() bool { return false }
    87  
    88  // Wait implements LinkEndpoint.Wait.
    89  func (*MockLinkEndpoint) Wait() {}
    90  
    91  // ARPHardwareType implements LinkEndpoint.ARPHardwareType.
    92  func (*MockLinkEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
    93  
    94  // AddHeader implements LinkEndpoint.AddHeader.
    95  func (*MockLinkEndpoint) AddHeader(*stack.PacketBuffer) {}
    96  
    97  // ParseHeader implements LinkEndpoint.ParseHeader.
    98  func (*MockLinkEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
    99  
   100  // Close releases all resources.
   101  func (ep *MockLinkEndpoint) Close() {
   102  	for _, pkt := range ep.WrittenPackets {
   103  		pkt.DecRef()
   104  	}
   105  	ep.WrittenPackets = nil
   106  }
   107  
   108  // MakeRandPkt generates a randomized packet. transportHeaderLength indicates
   109  // how many random bytes will be copied in the Transport Header.
   110  // extraHeaderReserveLength indicates how much extra space will be reserved for
   111  // the other headers. The payload is made from Views of the sizes listed in
   112  // viewSizes.
   113  func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer {
   114  	var buf buffer.Buffer
   115  
   116  	for _, s := range viewSizes {
   117  		newView := buffer.NewViewSize(s)
   118  		if _, err := rand.Read(newView.AsSlice()); err != nil {
   119  			panic(fmt.Sprintf("rand.Read: %s", err))
   120  		}
   121  		buf.Append(newView)
   122  	}
   123  
   124  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   125  		ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength,
   126  		Payload:            buf,
   127  	})
   128  	pkt.NetworkProtocolNumber = proto
   129  	if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil {
   130  		panic(fmt.Sprintf("rand.Read: %s", err))
   131  	}
   132  	return pkt
   133  }
   134  
   135  func checkIGMPStats(t *testing.T, s *stack.Stack, reports, leaves, reportsV2 uint64) {
   136  	t.Helper()
   137  
   138  	if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != reports {
   139  		t.Errorf("got s.Stats().IGMP.PacketsSent.V2MembershipReport.Value() = %d, want = %d", got, reports)
   140  	}
   141  	if got := s.Stats().IGMP.PacketsSent.V3MembershipReport.Value(); got != reportsV2 {
   142  		t.Errorf("got s.Stats().IGMP.PacketsSent.V3MembershipReport.Value() = %d, want = %d", got, reportsV2)
   143  	}
   144  	if got := s.Stats().IGMP.PacketsSent.LeaveGroup.Value(); got != leaves {
   145  		t.Errorf("got s.Stats().IGMP.PacketsSent.LeaveGroup.Value() = %d, want = %d", got, leaves)
   146  	}
   147  }
   148  
   149  // CheckIGMPv2Stats checks IGMPv2 stats.
   150  func CheckIGMPv2Stats(t *testing.T, s *stack.Stack, reports, leaves, reportsV2 uint64) {
   151  	t.Helper()
   152  	// We still check V3 stats in V2 compatibility tests because the test may send
   153  	// V3 reports before we drop into compatibility mode.
   154  	checkIGMPStats(t, s, reports, leaves, reportsV2)
   155  }
   156  
   157  // CheckIGMPv3Stats checks IGMPv3 stats.
   158  func CheckIGMPv3Stats(t *testing.T, s *stack.Stack, reports, leaves, reportsV2 uint64) {
   159  	t.Helper()
   160  	// In IGMPv3 tests, reports/leaves are just IGMPv3 reports.
   161  	checkIGMPStats(t, s, 0 /* reports */, 0 /* leaves */, reports+leaves+reportsV2)
   162  }
   163  
   164  func checkMLDStats(t *testing.T, s *stack.Stack, reports, leaves, reportsV2 uint64) {
   165  	t.Helper()
   166  
   167  	if got := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport.Value(); got != reports {
   168  		t.Errorf("got s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport.Value() = %d, want = %d", got, reports)
   169  	}
   170  	if got := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReportV2.Value(); got != reportsV2 {
   171  		t.Errorf("got s.Stats().ICMP.V6.PacketsSent.MulticastListenerReportV2.Value() = %d, want = %d", got, reportsV2)
   172  	}
   173  	if got := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone.Value(); got != leaves {
   174  		t.Errorf("got s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone.Value() = %d, want = %d", got, leaves)
   175  	}
   176  }
   177  
   178  // CheckMLDv1Stats checks MLDv1 stats.
   179  func CheckMLDv1Stats(t *testing.T, s *stack.Stack, reports, leaves, reportsV2 uint64) {
   180  	t.Helper()
   181  	// We still check V2 stats in V1 compatibility tests because the test may send
   182  	// V2 reports before we drop into compatibility mode.
   183  	checkMLDStats(t, s, reports, leaves, reportsV2)
   184  }
   185  
   186  // CheckMLDv2Stats checks MLDv2 stats.
   187  func CheckMLDv2Stats(t *testing.T, s *stack.Stack, reports, leaves, reportsV2 uint64) {
   188  	t.Helper()
   189  	// In MLDv2 tests, reports/leaves are just MLDv2 reports.
   190  	checkMLDStats(t, s, 0 /* reports */, 0 /* leaves */, reports+leaves+reportsV2)
   191  }
   192  
   193  // ValidateIGMPv3ReportWithRecords validates an IGMPv3 report.
   194  //
   195  // Note that observed records are removed from expectedRecords. No error is
   196  // logged if the report does not have all the records expected.
   197  func ValidateIGMPv3ReportWithRecords(t *testing.T, v *buffer.View, srcAddr tcpip.Address, expectedRecords map[tcpip.Address]header.IGMPv3ReportRecordType) {
   198  	t.Helper()
   199  
   200  	checker.IPv4(t, v,
   201  		checker.SrcAddr(srcAddr),
   202  		checker.DstAddr(header.IGMPv3RoutersAddress),
   203  		checker.TTL(header.IGMPTTL),
   204  		checker.IPv4RouterAlert(),
   205  		checker.IGMPv3Report(expectedRecords),
   206  	)
   207  }
   208  
   209  // ValidateIGMPv3Report validates an IGMPv3 report.
   210  func ValidateIGMPv3Report(t *testing.T, v *buffer.View, srcAddr tcpip.Address, addrs []tcpip.Address, recordType header.IGMPv3ReportRecordType) {
   211  	t.Helper()
   212  
   213  	records := make(map[tcpip.Address]header.IGMPv3ReportRecordType)
   214  	for _, addr := range addrs {
   215  		records[addr] = recordType
   216  	}
   217  
   218  	ValidateIGMPv3ReportWithRecords(t, v, srcAddr, records)
   219  
   220  	if diff := cmp.Diff(map[tcpip.Address]header.IGMPv3ReportRecordType{}, records); diff != "" {
   221  		t.Errorf("post-validation records map mismatch (-want +got):\n%s", diff)
   222  	}
   223  }
   224  
   225  // ValidateIGMPv3RecordsAcrossReports validates IGMPv3 records across one or
   226  // more reports.
   227  func ValidateIGMPv3RecordsAcrossReports(t *testing.T, e *channel.Endpoint, srcAddr tcpip.Address, addrs []tcpip.Address, recordType header.IGMPv3ReportRecordType) {
   228  	t.Helper()
   229  
   230  	expectedRecords := make(map[tcpip.Address]header.IGMPv3ReportRecordType)
   231  	for _, addr := range addrs {
   232  		expectedRecords[addr] = recordType
   233  	}
   234  
   235  	for len(expectedRecords) != 0 {
   236  		p := e.Read()
   237  		if p == nil {
   238  			t.Fatalf("expected IGMP message with expectedRecords = %#v", expectedRecords)
   239  		}
   240  		v := stack.PayloadSince(p.NetworkHeader())
   241  		ValidateIGMPv3ReportWithRecords(t, v, srcAddr, expectedRecords)
   242  		v.Release()
   243  		p.DecRef()
   244  	}
   245  
   246  	if diff := cmp.Diff(map[tcpip.Address]header.IGMPv3ReportRecordType{}, expectedRecords); diff != "" {
   247  		t.Errorf("post-validation records map mismatch (-want +got):\n%s", diff)
   248  	}
   249  }
   250  
   251  // ValidMultipleIGMPv2ReportLeaves validates the reception of multiple IGMPv2
   252  // report/leave messages.
   253  func ValidMultipleIGMPv2ReportLeaves(t *testing.T, e *channel.Endpoint, srcAddr tcpip.Address, addrs []tcpip.Address, leave bool) {
   254  	t.Helper()
   255  
   256  	expectedGroups := make(map[tcpip.Address]struct{})
   257  	for _, addr := range addrs {
   258  		expectedGroups[addr] = struct{}{}
   259  	}
   260  
   261  	igmpType := header.IGMPv2MembershipReport
   262  	if leave {
   263  		igmpType = header.IGMPLeaveGroup
   264  	}
   265  
   266  	for len(expectedGroups) != 0 {
   267  		p := e.Read()
   268  		if p == nil {
   269  			t.Fatalf("expected IGMP message with expectedGroups = %#v", expectedGroups)
   270  		}
   271  		v := stack.PayloadSince(p.NetworkHeader())
   272  		checker.IPv4(t, v,
   273  			checker.SrcAddr(srcAddr),
   274  			checker.TTL(header.IGMPTTL),
   275  			checker.IPv4RouterAlert(),
   276  			checker.IGMP(
   277  				checker.IGMPType(igmpType),
   278  				checker.IGMPMaxRespTime(0),
   279  				checker.IGMPGroupAddressUnordered(expectedGroups),
   280  			),
   281  		)
   282  		v.Release()
   283  		p.DecRef()
   284  	}
   285  
   286  	if diff := cmp.Diff(map[tcpip.Address]struct{}{}, expectedGroups); diff != "" {
   287  		t.Errorf("post-validation groups map mismatch (-want +got):\n%s", diff)
   288  	}
   289  }
   290  
   291  // ValidateMLDv2ReportWithRecords validates an MLDv2 report.
   292  //
   293  // Note that observed records are removed from expectedRecords. No error is
   294  // logged if the report does not have all the records expected.
   295  func ValidateMLDv2ReportWithRecords(t *testing.T, v *buffer.View, srcAddr tcpip.Address, expectedRecords map[tcpip.Address]header.MLDv2ReportRecordType) {
   296  	t.Helper()
   297  
   298  	checker.IPv6WithExtHdr(t, v,
   299  		checker.IPv6ExtHdr(
   300  			checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
   301  		),
   302  		checker.SrcAddr(srcAddr),
   303  		checker.DstAddr(header.MLDv2RoutersAddress),
   304  		checker.TTL(header.MLDHopLimit),
   305  		checker.MLDv2Report(expectedRecords),
   306  	)
   307  }
   308  
   309  // ValidateMLDv2Report validates an MLDv2 report.
   310  func ValidateMLDv2Report(t *testing.T, v *buffer.View, srcAddr tcpip.Address, addrs []tcpip.Address, recordType header.MLDv2ReportRecordType) {
   311  	t.Helper()
   312  
   313  	records := make(map[tcpip.Address]header.MLDv2ReportRecordType)
   314  	for _, addr := range addrs {
   315  		records[addr] = recordType
   316  	}
   317  
   318  	ValidateMLDv2ReportWithRecords(t, v, srcAddr, records)
   319  
   320  	if diff := cmp.Diff(map[tcpip.Address]header.MLDv2ReportRecordType{}, records); diff != "" {
   321  		t.Errorf("post-validation records map mismatch (-want +got):\n%s", diff)
   322  	}
   323  }
   324  
   325  // ValidateMLDv2RecordsAcrossReports validates MLDv2 records across one or more
   326  // reports.
   327  func ValidateMLDv2RecordsAcrossReports(t *testing.T, e *channel.Endpoint, srcAddr tcpip.Address, addrs []tcpip.Address, recordType header.MLDv2ReportRecordType) {
   328  	t.Helper()
   329  
   330  	expectedRecords := make(map[tcpip.Address]header.MLDv2ReportRecordType)
   331  	for _, addr := range addrs {
   332  		expectedRecords[addr] = recordType
   333  	}
   334  
   335  	for len(expectedRecords) != 0 {
   336  		p := e.Read()
   337  		if p == nil {
   338  			t.Fatalf("expected MLD Message with expectedRecords = %#v", expectedRecords)
   339  		}
   340  		v := stack.PayloadSince(p.NetworkHeader())
   341  		ValidateMLDv2ReportWithRecords(t, v, srcAddr, expectedRecords)
   342  		v.Release()
   343  		p.DecRef()
   344  	}
   345  
   346  	if diff := cmp.Diff(map[tcpip.Address]header.MLDv2ReportRecordType{}, expectedRecords); diff != "" {
   347  		t.Errorf("post-validation records map mismatch (-want +got):\n%s", diff)
   348  	}
   349  }
   350  
   351  // ValidMultipleMLDv1ReportLeaves validates the reception of multiple MLDv1
   352  // report/leave messages.
   353  func ValidMultipleMLDv1ReportLeaves(t *testing.T, e *channel.Endpoint, srcAddr tcpip.Address, addrs []tcpip.Address, leave bool) {
   354  	t.Helper()
   355  
   356  	expectedGroups := make(map[tcpip.Address]struct{})
   357  	for _, addr := range addrs {
   358  		expectedGroups[addr] = struct{}{}
   359  	}
   360  
   361  	mldType := header.ICMPv6MulticastListenerReport
   362  	if leave {
   363  		mldType = header.ICMPv6MulticastListenerDone
   364  	}
   365  
   366  	for len(expectedGroups) != 0 {
   367  		p := e.Read()
   368  		if p == nil {
   369  			t.Fatalf("expected MLD Message with expectedGroups = %#v", expectedGroups)
   370  		}
   371  		v := stack.PayloadSince(p.NetworkHeader())
   372  		checker.IPv6WithExtHdr(t, v,
   373  			checker.IPv6ExtHdr(
   374  				checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)),
   375  			),
   376  			checker.SrcAddr(srcAddr),
   377  			checker.TTL(header.MLDHopLimit),
   378  			checker.MLD(mldType, header.MLDMinimumSize,
   379  				checker.MLDMaxRespDelay(0),
   380  				checker.MLDMulticastAddressUnordered(expectedGroups),
   381  			),
   382  		)
   383  		v.Release()
   384  		p.DecRef()
   385  	}
   386  
   387  	if diff := cmp.Diff(map[tcpip.Address]struct{}{}, expectedGroups); diff != "" {
   388  		t.Errorf("post-validation groups map mismatch (-want +got):\n%s", diff)
   389  	}
   390  }