gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/testing/context/context.go (about)

     1  // Copyright 2022 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package context provides a context used by tests of datagram-based network
// endpoints. It also defines the TestFlow type to facilitate IP configurations.
    17  package context
    18  
    19  import (
    20  	"bytes"
    21  	"reflect"
    22  	"testing"
    23  
    24  	"github.com/google/go-cmp/cmp"
    25  	"golang.org/x/time/rate"
    26  	"gvisor.dev/gvisor/pkg/buffer"
    27  	"gvisor.dev/gvisor/pkg/refs"
    28  	"gvisor.dev/gvisor/pkg/tcpip"
    29  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    30  	"gvisor.dev/gvisor/pkg/tcpip/faketime"
    31  	"gvisor.dev/gvisor/pkg/tcpip/header"
    32  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    33  	"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
    34  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    35  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    36  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    37  	"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
    38  	"gvisor.dev/gvisor/pkg/waiter"
    39  )
    40  
    41  const (
    42  	// NICID is the id of the nic created by the Context.
    43  	NICID = 1
    44  
    45  	// DefaultMTU is the MTU used by the Context, except where another value is
    46  	// explicitly specified during initialization. It is chosen to match the MTU
    47  	// of loopback interfaces on linux systems.
    48  	DefaultMTU = 65536
    49  )
    50  
// Context is a testing context for datagram-based network endpoints. It owns a
// stack with a single NIC (NICID) backed by a channel link endpoint, plus an
// optional transport endpoint created through one of the Create*Endpoint*
// methods. Call Cleanup when the test is done with it.
type Context struct {
	// T is the testing context used to report failures.
	T *testing.T

	// LinkEP is the link endpoint that is attached to the stack's NIC. When
	// tests run verbosely it is wrapped in a sniffer before being attached.
	// Inbound packets are injected through it (see InjectPacket).
	LinkEP *channel.Endpoint

	// Stack is the networking stack owned by the context. Cleanup destroys it
	// and sets this field to nil.
	Stack *stack.Stack

	// EP is the transport endpoint owned by the context. It stays nil until a
	// Create*Endpoint method is called, and Cleanup closes it if it is set.
	EP tcpip.Endpoint

	// WQ is the wait queue handed to the stack when EP is created.
	// readFromEndpoint registers on it for readable events, but it polls the
	// queue rather than blocking on it.
	WQ waiter.Queue
}
    69  
// Options contains options for creating a new test context with
// NewWithOptions.
type Options struct {
	// MTU is the mtu that the link endpoint will be initialized with. New
	// uses DefaultMTU.
	MTU uint32

	// HandleLocal specifies if non-loopback interfaces are allowed to loop
	// packets. It is forwarded to stack.Options.HandleLocal; New sets it to
	// true.
	HandleLocal bool
}
    79  
    80  // New allocates and initializes a test context containing a configured stack.
    81  func New(t *testing.T, transportProtocols []stack.TransportProtocolFactory) *Context {
    82  	t.Helper()
    83  
    84  	options := Options{
    85  		MTU:         DefaultMTU,
    86  		HandleLocal: true,
    87  	}
    88  
    89  	return NewWithOptions(t, transportProtocols, options)
    90  }
    91  
    92  // NewWithOptions allocates and initializes a test context containing a
    93  // configured stack with the provided options.
    94  func NewWithOptions(t *testing.T, transportProtocols []stack.TransportProtocolFactory, options Options) *Context {
    95  	t.Helper()
    96  
    97  	stackOptions := stack.Options{
    98  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    99  		TransportProtocols: transportProtocols,
   100  		HandleLocal:        options.HandleLocal,
   101  		Clock:              &faketime.NullClock{},
   102  		RawFactory:         &raw.EndpointFactory{},
   103  	}
   104  
   105  	s := stack.New(stackOptions)
   106  	// Disable ICMP rate limiter since we're using Null clock, which never
   107  	// advances time and thus never allows ICMP messages.
   108  	s.SetICMPLimit(rate.Inf)
   109  	ep := channel.New(256, options.MTU, "")
   110  	wep := stack.LinkEndpoint(ep)
   111  
   112  	if testing.Verbose() {
   113  		wep = sniffer.New(ep)
   114  	}
   115  	if err := s.CreateNIC(NICID, wep); err != nil {
   116  		t.Fatalf("CreateNIC(%d, _): %s", NICID, err)
   117  	}
   118  
   119  	protocolAddrV4 := tcpip.ProtocolAddress{
   120  		Protocol:          ipv4.ProtocolNumber,
   121  		AddressWithPrefix: tcpip.Address(StackAddr).WithPrefix(),
   122  	}
   123  	if err := s.AddProtocolAddress(NICID, protocolAddrV4, stack.AddressProperties{}); err != nil {
   124  		t.Fatalf("AddProtocolAddress(%d, %#v, {}): %s", NICID, protocolAddrV4, err)
   125  	}
   126  
   127  	protocolAddrV6 := tcpip.ProtocolAddress{
   128  		Protocol:          ipv6.ProtocolNumber,
   129  		AddressWithPrefix: tcpip.Address(StackV6Addr).WithPrefix(),
   130  	}
   131  	if err := s.AddProtocolAddress(NICID, protocolAddrV6, stack.AddressProperties{}); err != nil {
   132  		t.Fatalf("AddProtocolAddress(%d, %#v, {}): %s", NICID, protocolAddrV6, err)
   133  	}
   134  
   135  	s.SetRouteTable([]tcpip.Route{
   136  		{
   137  			Destination: header.IPv4EmptySubnet,
   138  			NIC:         NICID,
   139  		},
   140  		{
   141  			Destination: header.IPv6EmptySubnet,
   142  			NIC:         NICID,
   143  		},
   144  	})
   145  
   146  	return &Context{
   147  		T:      t,
   148  		Stack:  s,
   149  		LinkEP: ep,
   150  	}
   151  }
   152  
   153  // Cleanup closes the context endpoint if required.
   154  func (c *Context) Cleanup() {
   155  	_ = c.LinkEP.Drain()
   156  	if c.EP != nil {
   157  		c.EP.Close()
   158  	}
   159  	c.Stack.Destroy()
   160  	c.Stack = nil
   161  	refs.DoRepeatedLeakCheck()
   162  }
   163  
   164  // CreateEndpoint creates the Context's Endpoint.
   165  func (c *Context) CreateEndpoint(network tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber) {
   166  	c.T.Helper()
   167  
   168  	var err tcpip.Error
   169  	c.EP, err = c.Stack.NewEndpoint(transport, network, &c.WQ)
   170  	if err != nil {
   171  		c.T.Fatalf("c.Stack.NewEndpoint(%d, %d, _) failed: %s", transport, network, err)
   172  	}
   173  }
   174  
   175  // CreateEndpointForFlow creates the Context's Endpoint and configured it
   176  // according to the given TestFlow.
   177  func (c *Context) CreateEndpointForFlow(flow TestFlow, transport tcpip.TransportProtocolNumber) {
   178  	c.T.Helper()
   179  
   180  	c.CreateEndpoint(flow.SockProto(), transport)
   181  	if flow.isV6Only() {
   182  		c.EP.SocketOptions().SetV6Only(true)
   183  	} else if flow.isBroadcast() {
   184  		c.EP.SocketOptions().SetBroadcast(true)
   185  	}
   186  }
   187  
   188  // CreateRawEndpoint creates the Context's Endpoint.
   189  func (c *Context) CreateRawEndpoint(network tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber) {
   190  	c.T.Helper()
   191  
   192  	var err tcpip.Error
   193  	c.EP, err = c.Stack.NewRawEndpoint(transport, network, &c.WQ, true /* associated */)
   194  	if err != nil {
   195  		c.T.Fatal("c.Stack.NewRawEndpoint failed: ", err)
   196  	}
   197  }
   198  
   199  // CreateRawEndpointForFlow creates the Context's Endpoint and configured it
   200  // according to the given TestFlow.
   201  func (c *Context) CreateRawEndpointForFlow(flow TestFlow, transport tcpip.TransportProtocolNumber) {
   202  	c.T.Helper()
   203  
   204  	c.CreateRawEndpoint(flow.SockProto(), transport)
   205  	if flow.isV6Only() {
   206  		c.EP.SocketOptions().SetV6Only(true)
   207  	} else if flow.isBroadcast() {
   208  		c.EP.SocketOptions().SetBroadcast(true)
   209  	}
   210  }
   211  
   212  // CheckEndpointWriteStats checks that the write statistic related to the given
   213  // error has been incremented as expected.
   214  func (c *Context) CheckEndpointWriteStats(incr uint64, want *tcpip.TransportEndpointStats, err tcpip.Error) {
   215  	var got tcpip.TransportEndpointStats
   216  	c.EP.Stats().(*tcpip.TransportEndpointStats).Clone(&got)
   217  	switch err.(type) {
   218  	case nil:
   219  		want.PacketsSent.IncrementBy(incr)
   220  	case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
   221  		want.WriteErrors.InvalidArgs.IncrementBy(incr)
   222  	case *tcpip.ErrClosedForSend:
   223  		want.WriteErrors.WriteClosed.IncrementBy(incr)
   224  	case *tcpip.ErrInvalidEndpointState:
   225  		want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
   226  	case *tcpip.ErrHostUnreachable, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
   227  		want.SendErrors.NoRoute.IncrementBy(incr)
   228  	default:
   229  		want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
   230  	}
   231  	if !reflect.DeepEqual(&got, want) {
   232  		c.T.Errorf("Endpoint stats not matching for error %s: got %#v, want %#v", err, &got, want)
   233  	}
   234  }
   235  
   236  // CheckEndpointReadStats checks that the read statistic related to the given
   237  // error has been incremented as expected.
   238  func (c *Context) CheckEndpointReadStats(incr uint64, want *tcpip.TransportEndpointStats, err tcpip.Error) {
   239  	c.T.Helper()
   240  
   241  	var got tcpip.TransportEndpointStats
   242  	c.EP.Stats().(*tcpip.TransportEndpointStats).Clone(&got)
   243  	switch err.(type) {
   244  	case nil, *tcpip.ErrWouldBlock:
   245  	case *tcpip.ErrClosedForReceive:
   246  		want.ReadErrors.ReadClosed.IncrementBy(incr)
   247  	default:
   248  		c.T.Errorf("Endpoint error missing stats update for err %s", err)
   249  	}
   250  	if !reflect.DeepEqual(&got, want) {
   251  		c.T.Errorf("Endpoint stats not matching for error %s: got %#v, want %#v", err, &got, want)
   252  	}
   253  }
   254  
   255  // InjectPacket injects a packet into the context's link endpoint.
   256  func (c *Context) InjectPacket(netProto tcpip.NetworkProtocolNumber, buf []byte) {
   257  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   258  		Payload: buffer.MakeWithData(buf),
   259  	})
   260  	defer pkt.DecRef()
   261  	c.LinkEP.InjectInbound(netProto, pkt)
   262  }
   263  
// readExpectations holds information about the expected outcome when reading
// from the context's endpoint via readFromEndpoint.
type readExpectations struct {
	// nothingToRead: the read is expected to find no data; readFromEndpoint
	// returns successfully as soon as Read would block with no readable event
	// pending, and fails the test if data is read.
	nothingToRead  bool
	// payload is the exact data the read is expected to return.
	payload        []byte
	// addresses holds the expected packet addresses; only Src.Addr is
	// compared with the read result's remote address.
	addresses      Header4Tuple
	// readShouldFail: Read is expected to return an error, whose effect on
	// the endpoint's read statistics is then verified.
	readShouldFail bool
}
   272  
   273  // readFromEndpoint attempts to read a packet from the endpoint and compares the
   274  // outcome with the given expectations.
   275  func (c *Context) readFromEndpoint(expectations readExpectations, checkers ...checker.ControlMessagesChecker) {
   276  	c.T.Helper()
   277  
   278  	// Try to receive the data.
   279  	we, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
   280  	c.WQ.EventRegister(&we)
   281  	defer c.WQ.EventUnregister(&we)
   282  
   283  	// Take a snapshot of the stats to validate them at the end of the test.
   284  	var epstats tcpip.TransportEndpointStats
   285  	c.EP.Stats().(*tcpip.TransportEndpointStats).Clone(&epstats)
   286  
   287  	var buf bytes.Buffer
   288  	res, err := c.EP.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
   289  	if _, ok := err.(*tcpip.ErrWouldBlock); ok {
   290  		select {
   291  		case <-ch:
   292  			res, err = c.EP.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
   293  		default:
   294  			if expectations.nothingToRead {
   295  				return
   296  			}
   297  			c.T.Fatal("timed out waiting for data")
   298  		}
   299  	}
   300  
   301  	if expectations.readShouldFail && err != nil {
   302  		c.CheckEndpointReadStats(1, &epstats, err)
   303  		return
   304  	}
   305  
   306  	if err != nil {
   307  		c.T.Fatal("Read failed:", err)
   308  	}
   309  
   310  	if expectations.nothingToRead {
   311  		c.T.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr)
   312  	}
   313  
   314  	// Check the read result.
   315  	if diff := cmp.Diff(tcpip.ReadResult{
   316  		Count:      buf.Len(),
   317  		Total:      buf.Len(),
   318  		RemoteAddr: tcpip.FullAddress{Addr: expectations.addresses.Src.Addr},
   319  	}, res, checker.IgnoreCmpPath(
   320  		"ControlMessages", // ControlMessages are checked below.
   321  		"RemoteAddr.NIC",
   322  		"RemoteAddr.Port",
   323  	)); diff != "" {
   324  		c.T.Fatalf("Read: unexpected result (-want +got):\n%s", diff)
   325  	}
   326  
   327  	// Check the payload.
   328  	v := buf.Bytes()
   329  	if !bytes.Equal(expectations.payload, v) {
   330  		c.T.Fatalf("got payload = %x, want = %x", v, expectations.payload)
   331  	}
   332  
   333  	// Run any checkers against the ControlMessages.
   334  	for _, f := range checkers {
   335  		f(c.T, res.ControlMessages)
   336  	}
   337  
   338  	c.CheckEndpointReadStats(1, &epstats, err)
   339  }
   340  
   341  // ReadFromEndpointExpectSuccess attempts to reads from the endpoint and
   342  // performs checks on the received packet, according to the given flow and
   343  // checkers.
   344  func (c *Context) ReadFromEndpointExpectSuccess(payload []byte, flow TestFlow, checkers ...checker.ControlMessagesChecker) {
   345  	c.T.Helper()
   346  
   347  	c.readFromEndpoint(readExpectations{
   348  		payload:   payload,
   349  		addresses: flow.MakeHeader4Tuple(Incoming),
   350  	}, checkers...)
   351  }
   352  
   353  // ReadFromEndpointExpectNoPacket reads from the endpoint and checks that no
   354  // packets was received.
   355  func (c *Context) ReadFromEndpointExpectNoPacket() {
   356  	c.T.Helper()
   357  
   358  	c.readFromEndpoint(readExpectations{
   359  		nothingToRead: true,
   360  	})
   361  }
   362  
   363  // ReadFromEndpointExpectError reads from the endpoint and checks that an
   364  // error was returned.
   365  func (c *Context) ReadFromEndpointExpectError() {
   366  	c.T.Helper()
   367  
   368  	c.readFromEndpoint(readExpectations{
   369  		readShouldFail: true,
   370  	})
   371  }