github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/network/ipv4/igmp_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ipv4_test
    16  
    17  import (
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/SagerNet/gvisor/pkg/tcpip"
    22  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    23  	"github.com/SagerNet/gvisor/pkg/tcpip/checker"
    24  	"github.com/SagerNet/gvisor/pkg/tcpip/faketime"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/link/channel"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip/testutil"
    30  )
    31  
    32  const (
    33  	linkAddr            = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
    34  	nicID               = 1
    35  	defaultTTL          = 1
    36  	defaultPrefixLength = 24
    37  )
    38  
    39  var (
    40  	stackAddr     = testutil.MustParse4("10.0.0.1")
    41  	remoteAddr    = testutil.MustParse4("10.0.0.2")
    42  	multicastAddr = testutil.MustParse4("224.0.0.3")
    43  )
    44  
    45  // validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet
    46  // sent to the provided address with the passed fields set. Raises a t.Error if
    47  // any field does not match.
    48  func validateIgmpPacket(t *testing.T, p channel.PacketInfo, igmpType header.IGMPType, maxRespTime byte, srcAddr, dstAddr, groupAddress tcpip.Address) {
    49  	t.Helper()
    50  
    51  	payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader()))
    52  	checker.IPv4(t, payload,
    53  		checker.SrcAddr(srcAddr),
    54  		checker.DstAddr(dstAddr),
    55  		// TTL for an IGMP message must be 1 as per RFC 2236 section 2.
    56  		checker.TTL(1),
    57  		checker.IPv4RouterAlert(),
    58  		checker.IGMP(
    59  			checker.IGMPType(igmpType),
    60  			checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)),
    61  			checker.IGMPGroupAddress(groupAddress),
    62  		),
    63  	)
    64  }
    65  
    66  func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) {
    67  	t.Helper()
    68  
    69  	// Create an endpoint of queue size 1, since no more than 1 packets are ever
    70  	// queued in the tests in this file.
    71  	e := channel.New(1, 1280, linkAddr)
    72  	clock := faketime.NewManualClock()
    73  	s := stack.New(stack.Options{
    74  		NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{
    75  			IGMP: ipv4.IGMPOptions{
    76  				Enabled: igmpEnabled,
    77  			},
    78  		})},
    79  		Clock: clock,
    80  	})
    81  	if err := s.CreateNIC(nicID, e); err != nil {
    82  		t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
    83  	}
    84  	return e, s, clock
    85  }
    86  
    87  func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, ttl uint8, srcAddr, dstAddr, groupAddress tcpip.Address, hasRouterAlertOption bool) {
    88  	var options header.IPv4OptionsSerializer
    89  	if hasRouterAlertOption {
    90  		options = header.IPv4OptionsSerializer{
    91  			&header.IPv4SerializableRouterAlertOption{},
    92  		}
    93  	}
    94  	buf := buffer.NewView(header.IPv4MinimumSize + int(options.Length()) + header.IGMPQueryMinimumSize)
    95  
    96  	ip := header.IPv4(buf)
    97  	ip.Encode(&header.IPv4Fields{
    98  		TotalLength: uint16(len(buf)),
    99  		TTL:         ttl,
   100  		Protocol:    uint8(header.IGMPProtocolNumber),
   101  		SrcAddr:     srcAddr,
   102  		DstAddr:     dstAddr,
   103  		Options:     options,
   104  	})
   105  	ip.SetChecksum(^ip.CalculateChecksum())
   106  
   107  	igmp := header.IGMP(ip.Payload())
   108  	igmp.SetType(igmpType)
   109  	igmp.SetMaxRespTime(maxRespTime)
   110  	igmp.SetGroupAddress(groupAddress)
   111  	igmp.SetChecksum(header.IGMPCalculateChecksum(igmp))
   112  
   113  	e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
   114  		Data: buf.ToVectorisedView(),
   115  	}))
   116  }
   117  
   118  // TestIGMPV1Present tests the node's ability to fallback to V1 when a V1
   119  // router is detected. V1 present status is expected to be reset when the NIC
   120  // cycles.
   121  func TestIGMPV1Present(t *testing.T) {
   122  	e, s, clock := createStack(t, true)
   123  	addr := tcpip.AddressWithPrefix{Address: stackAddr, PrefixLen: defaultPrefixLength}
   124  	if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
   125  		t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err)
   126  	}
   127  
   128  	if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
   129  		t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err)
   130  	}
   131  
   132  	// This NIC will send an IGMPv2 report immediately, before this test can get
   133  	// the IGMPv1 General Membership Query in.
   134  	{
   135  		p, ok := e.Read()
   136  		if !ok {
   137  			t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
   138  		}
   139  		if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 {
   140  			t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got)
   141  		}
   142  		validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
   143  	}
   144  	if t.Failed() {
   145  		t.FailNow()
   146  	}
   147  
   148  	// Inject an IGMPv1 General Membership Query which is identical to a standard
   149  	// membership query except the Max Response Time is set to 0, which will tell
   150  	// the stack that this is a router using IGMPv1. Send it to the all systems
   151  	// group which is the only group this host belongs to.
   152  	createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, defaultTTL, remoteAddr, stackAddr, header.IPv4AllSystems, true /* hasRouterAlertOption */)
   153  	if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 {
   154  		t.Fatalf("got Membership Queries received = %d, want = 1", got)
   155  	}
   156  
   157  	// Before advancing the clock, verify that this host has not sent a
   158  	// V1MembershipReport yet.
   159  	if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 {
   160  		t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got)
   161  	}
   162  
   163  	// Verify the solicited Membership Report is sent. Now that this NIC has seen
   164  	// an IGMPv1 query, it should send an IGMPv1 Membership Report.
   165  	if p, ok := e.Read(); ok {
   166  		t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt)
   167  	}
   168  	clock.Advance(ipv4.UnsolicitedReportIntervalMax)
   169  	{
   170  		p, ok := e.Read()
   171  		if !ok {
   172  			t.Fatal("unable to Read IGMP packet, expected V1MembershipReport")
   173  		}
   174  		if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 {
   175  			t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got)
   176  		}
   177  		validateIgmpPacket(t, p, header.IGMPv1MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
   178  	}
   179  
   180  	// Cycling the interface should reset the V1 present flag.
   181  	if err := s.DisableNIC(nicID); err != nil {
   182  		t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
   183  	}
   184  	if err := s.EnableNIC(nicID); err != nil {
   185  		t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
   186  	}
   187  	{
   188  		p, ok := e.Read()
   189  		if !ok {
   190  			t.Fatal("unable to Read IGMP packet, expected V2MembershipReport")
   191  		}
   192  		if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 2 {
   193  			t.Fatalf("got V2MembershipReport messages sent = %d, want = 2", got)
   194  		}
   195  		validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
   196  	}
   197  }
   198  
   199  func TestSendQueuedIGMPReports(t *testing.T) {
   200  	e, s, clock := createStack(t, true)
   201  
   202  	// Joining a group without an assigned address should queue IGMP packets; none
   203  	// should be sent without an assigned address.
   204  	if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil {
   205  		t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err)
   206  	}
   207  	reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport
   208  	if got := reportStat.Value(); got != 0 {
   209  		t.Errorf("got reportStat.Value() = %d, want = 0", got)
   210  	}
   211  	clock.Advance(time.Hour)
   212  	if p, ok := e.Read(); ok {
   213  		t.Fatalf("got unexpected packet = %#v", p)
   214  	}
   215  
   216  	// The initial set of IGMP reports that were queued should be sent once an
   217  	// address is assigned.
   218  	if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
   219  		t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err)
   220  	}
   221  	if got := reportStat.Value(); got != 1 {
   222  		t.Errorf("got reportStat.Value() = %d, want = 1", got)
   223  	}
   224  	if p, ok := e.Read(); !ok {
   225  		t.Error("expected to send an IGMP membership report")
   226  	} else {
   227  		validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
   228  	}
   229  	if t.Failed() {
   230  		t.FailNow()
   231  	}
   232  	clock.Advance(ipv4.UnsolicitedReportIntervalMax)
   233  	if got := reportStat.Value(); got != 2 {
   234  		t.Errorf("got reportStat.Value() = %d, want = 2", got)
   235  	}
   236  	if p, ok := e.Read(); !ok {
   237  		t.Error("expected to send an IGMP membership report")
   238  	} else {
   239  		validateIgmpPacket(t, p, header.IGMPv2MembershipReport, 0, stackAddr, multicastAddr, multicastAddr)
   240  	}
   241  	if t.Failed() {
   242  		t.FailNow()
   243  	}
   244  
   245  	// Should have no more packets to send after the initial set of unsolicited
   246  	// reports.
   247  	clock.Advance(time.Hour)
   248  	if p, ok := e.Read(); ok {
   249  		t.Fatalf("got unexpected packet = %#v", p)
   250  	}
   251  }
   252  
   253  func TestIGMPPacketValidation(t *testing.T) {
   254  	tests := []struct {
   255  		name                     string
   256  		messageType              header.IGMPType
   257  		stackAddresses           []tcpip.AddressWithPrefix
   258  		srcAddr                  tcpip.Address
   259  		includeRouterAlertOption bool
   260  		ttl                      uint8
   261  		expectValidIGMP          bool
   262  		getMessageTypeStatValue  func(tcpip.Stats) uint64
   263  	}{
   264  		{
   265  			name:                     "valid",
   266  			messageType:              header.IGMPLeaveGroup,
   267  			includeRouterAlertOption: true,
   268  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   269  			srcAddr:                  remoteAddr,
   270  			ttl:                      1,
   271  			expectValidIGMP:          true,
   272  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() },
   273  		},
   274  		{
   275  			name:                     "bad ttl",
   276  			messageType:              header.IGMPv1MembershipReport,
   277  			includeRouterAlertOption: true,
   278  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   279  			srcAddr:                  remoteAddr,
   280  			ttl:                      2,
   281  			expectValidIGMP:          false,
   282  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() },
   283  		},
   284  		{
   285  			name:                     "missing router alert ip option",
   286  			messageType:              header.IGMPv2MembershipReport,
   287  			includeRouterAlertOption: false,
   288  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   289  			srcAddr:                  remoteAddr,
   290  			ttl:                      1,
   291  			expectValidIGMP:          false,
   292  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
   293  		},
   294  		{
   295  			name:                     "igmp leave group and src ip does not belong to nic subnet",
   296  			messageType:              header.IGMPLeaveGroup,
   297  			includeRouterAlertOption: true,
   298  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   299  			srcAddr:                  testutil.MustParse4("10.0.1.2"),
   300  			ttl:                      1,
   301  			expectValidIGMP:          false,
   302  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.LeaveGroup.Value() },
   303  		},
   304  		{
   305  			name:                     "igmp query and src ip does not belong to nic subnet",
   306  			messageType:              header.IGMPMembershipQuery,
   307  			includeRouterAlertOption: true,
   308  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   309  			srcAddr:                  testutil.MustParse4("10.0.1.2"),
   310  			ttl:                      1,
   311  			expectValidIGMP:          true,
   312  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.MembershipQuery.Value() },
   313  		},
   314  		{
   315  			name:                     "igmp report v1 and src ip does not belong to nic subnet",
   316  			messageType:              header.IGMPv1MembershipReport,
   317  			includeRouterAlertOption: true,
   318  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   319  			srcAddr:                  testutil.MustParse4("10.0.1.2"),
   320  			ttl:                      1,
   321  			expectValidIGMP:          false,
   322  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V1MembershipReport.Value() },
   323  		},
   324  		{
   325  			name:                     "igmp report v2 and src ip does not belong to nic subnet",
   326  			messageType:              header.IGMPv2MembershipReport,
   327  			includeRouterAlertOption: true,
   328  			stackAddresses:           []tcpip.AddressWithPrefix{{Address: stackAddr, PrefixLen: 24}},
   329  			srcAddr:                  testutil.MustParse4("10.0.1.2"),
   330  			ttl:                      1,
   331  			expectValidIGMP:          false,
   332  			getMessageTypeStatValue:  func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
   333  		},
   334  		{
   335  			name:                     "src ip belongs to the subnet of the nic's second address",
   336  			messageType:              header.IGMPv2MembershipReport,
   337  			includeRouterAlertOption: true,
   338  			stackAddresses: []tcpip.AddressWithPrefix{
   339  				{Address: testutil.MustParse4("10.0.15.1"), PrefixLen: 24},
   340  				{Address: stackAddr, PrefixLen: 24},
   341  			},
   342  			srcAddr:                 remoteAddr,
   343  			ttl:                     1,
   344  			expectValidIGMP:         true,
   345  			getMessageTypeStatValue: func(stats tcpip.Stats) uint64 { return stats.IGMP.PacketsReceived.V2MembershipReport.Value() },
   346  		},
   347  	}
   348  
   349  	for _, test := range tests {
   350  		t.Run(test.name, func(t *testing.T) {
   351  			e, s, _ := createStack(t, true)
   352  			for _, address := range test.stackAddresses {
   353  				if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, address); err != nil {
   354  					t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, address, err)
   355  				}
   356  			}
   357  			stats := s.Stats()
   358  			// Verify that every relevant stats is zero'd before we send a packet.
   359  			if got := test.getMessageTypeStatValue(s.Stats()); got != 0 {
   360  				t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 0", got)
   361  			}
   362  			if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != 0 {
   363  				t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = 0", got)
   364  			}
   365  			if got := stats.IP.PacketsDelivered.Value(); got != 0 {
   366  				t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 0", got)
   367  			}
   368  			createAndInjectIGMPPacket(e, test.messageType, 0, test.ttl, test.srcAddr, header.IPv4AllSystems, header.IPv4AllSystems, test.includeRouterAlertOption)
   369  			// We always expect the packet to pass IP validation.
   370  			if got := stats.IP.PacketsDelivered.Value(); got != 1 {
   371  				t.Fatalf("got stats.IP.PacketsDelivered.Value() = %d, want = 1", got)
   372  			}
   373  			// Even when the IGMP-specific validation checks fail, we expect the
   374  			// corresponding IGMP counter to be incremented.
   375  			if got := test.getMessageTypeStatValue(s.Stats()); got != 1 {
   376  				t.Errorf("got test.getMessageTypeStatValue(s.Stats()) = %d, want = 1", got)
   377  			}
   378  			var expectedInvalidCount uint64
   379  			if !test.expectValidIGMP {
   380  				expectedInvalidCount = 1
   381  			}
   382  			if got := stats.IGMP.PacketsReceived.Invalid.Value(); got != expectedInvalidCount {
   383  				t.Errorf("got stats.IGMP.PacketsReceived.Invalid.Value() = %d, want = %d", got, expectedInvalidCount)
   384  			}
   385  		})
   386  	}
   387  }