github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/transport/tcp/testing/context/context.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package context provides a test context for use in tcp tests. It also
    16  // provides helper methods to assert/check certain behaviours.
    17  package context
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/SagerNet/gvisor/pkg/tcpip"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/checker"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip/link/channel"
    30  	"github.com/SagerNet/gvisor/pkg/tcpip/link/sniffer"
    31  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4"
    32  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6"
    33  	"github.com/SagerNet/gvisor/pkg/tcpip/seqnum"
    34  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    35  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/tcp"
    36  	"github.com/SagerNet/gvisor/pkg/waiter"
    37  )
    38  
const (
	// StackAddr is the IPv4 address assigned to the stack (10.0.0.1).
	StackAddr = "\x0a\x00\x00\x01"

	// StackPort is used as the listening port in tests for passive
	// connects.
	StackPort = 1234

	// TestAddr is the source address for packets sent to the stack via the
	// link layer endpoint (10.0.0.2).
	TestAddr = "\x0a\x00\x00\x02"

	// TestPort is the TCP port used for packets sent to the stack
	// via the link layer endpoint.
	TestPort = 4096

	// StackV6Addr is the IPv6 address assigned to the stack (a00::1).
	StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"

	// TestV6Addr is the source address for packets sent to the stack via
	// the link layer endpoint (a00::2).
	TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"

	// StackV4MappedAddr is StackAddr as a mapped v6 address
	// (::ffff:10.0.0.1).
	StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr

	// TestV4MappedAddr is TestAddr as a mapped v6 address
	// (::ffff:10.0.0.2).
	TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr

	// V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0.
	V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"

	// TestInitialSequenceNumber is the initial sequence number sent in packets that
	// are sent in response to a SYN or in the initial SYN sent to the stack.
	TestInitialSequenceNumber = 789
)
    75  
// StackAddrWithPrefix is StackAddr with its associated prefix length (a /24
// network).
var StackAddrWithPrefix = tcpip.AddressWithPrefix{
	Address:   StackAddr,
	PrefixLen: 24,
}
    81  
// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length.
var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{
	Address: StackV6Addr,
	// The prefix covers every bit that precedes the interface identifier;
	// the byte offset is converted to a bit count.
	PrefixLen: header.IIDOffsetInIPv6Address * 8,
}
    87  
// Headers is used to represent the TCP header fields when building a
// new packet.
type Headers struct {
	// SrcPort holds the src port value to be used in the packet.
	SrcPort uint16

	// DstPort holds the destination port value to be used in the packet.
	DstPort uint16

	// SeqNum is the value of the sequence number field in the TCP header.
	SeqNum seqnum.Value

	// AckNum represents the acknowledgement number field in the TCP header.
	AckNum seqnum.Value

	// Flags are the TCP flags in the TCP header.
	Flags header.TCPFlags

	// RcvWnd is the window to be advertised in the ReceiveWindow field of
	// the TCP header. The packet builders in this file convert it to a
	// uint16 when encoding the header.
	RcvWnd seqnum.Size

	// TCPOpts holds the options to be sent in the option field of the TCP
	// header.
	TCPOpts []byte
}
   114  
// Options contains options for creating a new test context.
type Options struct {
	// EnableV4 indicates whether IPv4 should be enabled.
	EnableV4 bool

	// EnableV6 indicates whether IPv6 should be enabled.
	EnableV6 bool

	// MTU indicates the maximum transmission unit on the link layer. It
	// must be non-zero; NewWithOpts panics otherwise.
	MTU uint32
}
   126  
// Context provides an initialized Network stack and a link layer endpoint
// for use in TCP tests.
type Context struct {
	// t is the test that owns this context; helper failures are reported
	// through it.
	t *testing.T
	// linkEP is the channel endpoint used to inject packets into the stack
	// and to read the packets the stack sends.
	linkEP *channel.Endpoint
	// s is the network stack under test.
	s *stack.Stack

	// IRS holds the initial sequence number in the SYN sent by endpoint in
	// case of an active connect or the sequence number sent by the endpoint
	// in the SYN-ACK sent in response to a SYN when listening in passive
	// mode.
	IRS seqnum.Value

	// Port holds the port bound by EP below in case of an active connect or
	// the listening port number in case of a passive connect.
	Port uint16

	// EP is the test endpoint in the stack owned by this context. This endpoint
	// is used in various tests to either initiate an active connect or is used
	// as a passive listening endpoint to accept inbound connections.
	EP tcpip.Endpoint

	// WQ is the wait queue associated with EP and is used to block for events
	// on EP.
	WQ waiter.Queue

	// TimeStampEnabled is true if EP is connected with the timestamp option
	// enabled.
	TimeStampEnabled bool

	// WindowScale is the expected window scale in SYN packets sent by
	// the stack.
	WindowScale uint8

	// RcvdWindowScale is the actual window scale sent by the stack in
	// SYN/SYN-ACK.
	RcvdWindowScale uint8
}
   165  
   166  // New allocates and initializes a test context containing a new
   167  // stack and a link-layer endpoint.
   168  func New(t *testing.T, mtu uint32) *Context {
   169  	return NewWithOpts(t, Options{
   170  		EnableV4: true,
   171  		EnableV6: true,
   172  		MTU:      mtu,
   173  	})
   174  }
   175  
   176  // NewWithOpts allocates and initializes a test context containing a new
   177  // stack and a link-layer endpoint with specific options.
   178  func NewWithOpts(t *testing.T, opts Options) *Context {
   179  	if opts.MTU == 0 {
   180  		panic("MTU must be greater than 0")
   181  	}
   182  
   183  	stackOpts := stack.Options{
   184  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
   185  	}
   186  	if opts.EnableV4 {
   187  		stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol)
   188  	}
   189  	if opts.EnableV6 {
   190  		stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol)
   191  	}
   192  	s := stack.New(stackOpts)
   193  
   194  	const sendBufferSize = 1 << 20 // 1 MiB
   195  	const recvBufferSize = 1 << 20 // 1 MiB
   196  	// Allow minimum send/receive buffer sizes to be 1 during tests.
   197  	sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}
   198  	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil {
   199  		t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err)
   200  	}
   201  
   202  	rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}
   203  	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil {
   204  		t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err)
   205  	}
   206  
   207  	// Increase minimum RTO in tests to avoid test flakes due to early
   208  	// retransmit in case the test executors are overloaded and cause timers
   209  	// to fire earlier than expected.
   210  	minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second)
   211  	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil {
   212  		t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err)
   213  	}
   214  
   215  	// Some of the congestion control tests send up to 640 packets, we so
   216  	// set the channel size to 1000.
   217  	ep := channel.New(1000, opts.MTU, "")
   218  	wep := stack.LinkEndpoint(ep)
   219  	if testing.Verbose() {
   220  		wep = sniffer.New(ep)
   221  	}
   222  	nicOpts := stack.NICOptions{Name: "nic1"}
   223  	if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil {
   224  		t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
   225  	}
   226  	wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, ""))
   227  	if testing.Verbose() {
   228  		wep2 = sniffer.New(channel.New(1000, opts.MTU, ""))
   229  	}
   230  	opts2 := stack.NICOptions{Name: "nic2"}
   231  	if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
   232  		t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
   233  	}
   234  
   235  	var routeTable []tcpip.Route
   236  
   237  	if opts.EnableV4 {
   238  		v4ProtocolAddr := tcpip.ProtocolAddress{
   239  			Protocol:          ipv4.ProtocolNumber,
   240  			AddressWithPrefix: StackAddrWithPrefix,
   241  		}
   242  		if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil {
   243  			t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err)
   244  		}
   245  		routeTable = append(routeTable, tcpip.Route{
   246  			Destination: header.IPv4EmptySubnet,
   247  			NIC:         1,
   248  		})
   249  	}
   250  
   251  	if opts.EnableV6 {
   252  		v6ProtocolAddr := tcpip.ProtocolAddress{
   253  			Protocol:          ipv6.ProtocolNumber,
   254  			AddressWithPrefix: StackV6AddrWithPrefix,
   255  		}
   256  		if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil {
   257  			t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err)
   258  		}
   259  		routeTable = append(routeTable, tcpip.Route{
   260  			Destination: header.IPv6EmptySubnet,
   261  			NIC:         1,
   262  		})
   263  	}
   264  
   265  	s.SetRouteTable(routeTable)
   266  
   267  	return &Context{
   268  		t:           t,
   269  		s:           s,
   270  		linkEP:      ep,
   271  		WindowScale: uint8(tcp.FindWndScale(recvBufferSize)),
   272  	}
   273  }
   274  
   275  // Cleanup closes the context endpoint if required.
   276  func (c *Context) Cleanup() {
   277  	if c.EP != nil {
   278  		c.EP.Close()
   279  	}
   280  	c.Stack().Close()
   281  }
   282  
// Stack returns a reference to the stack in the Context. The same stack is
// returned on every call.
func (c *Context) Stack() *stack.Stack {
	return c.s
}
   287  
   288  // CheckNoPacketTimeout verifies that no packet is received during the time
   289  // specified by wait.
   290  func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
   291  	c.t.Helper()
   292  
   293  	ctx, cancel := context.WithTimeout(context.Background(), wait)
   294  	defer cancel()
   295  	if _, ok := c.linkEP.ReadContext(ctx); ok {
   296  		c.t.Fatal(errMsg)
   297  	}
   298  }
   299  
   300  // CheckNoPacket verifies that no packet is received for 1 second.
   301  func (c *Context) CheckNoPacket(errMsg string) {
   302  	c.CheckNoPacketTimeout(errMsg, 1*time.Second)
   303  }
   304  
   305  // GetPacketWithTimeout reads a packet from the link layer endpoint and verifies
   306  // that it is an IPv4 packet with the expected source and destination
   307  // addresses. If no packet is received in the specified timeout it will return
   308  // nil.
   309  func (c *Context) GetPacketWithTimeout(timeout time.Duration) []byte {
   310  	c.t.Helper()
   311  
   312  	ctx, cancel := context.WithTimeout(context.Background(), timeout)
   313  	defer cancel()
   314  	p, ok := c.linkEP.ReadContext(ctx)
   315  	if !ok {
   316  		return nil
   317  	}
   318  
   319  	if p.Proto != ipv4.ProtocolNumber {
   320  		c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
   321  	}
   322  
   323  	// Just check that the stack set the transport protocol number for outbound
   324  	// TCP messages.
   325  	// TODO(github.com/SagerNet/issues/3810): Remove when protocol numbers are part
   326  	// of the headerinfo.
   327  	if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
   328  		c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
   329  	}
   330  
   331  	vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
   332  	b := vv.ToView()
   333  
   334  	if p.Pkt.GSOOptions.Type != stack.GSONone && p.Pkt.GSOOptions.L3HdrLen != header.IPv4MinimumSize {
   335  		c.t.Errorf("got L3HdrLen = %d, want = %d", p.Pkt.GSOOptions.L3HdrLen, header.IPv4MinimumSize)
   336  	}
   337  
   338  	checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
   339  	return b
   340  }
   341  
   342  // GetPacket reads a packet from the link layer endpoint and verifies
   343  // that it is an IPv4 packet with the expected source and destination
   344  // addresses.
   345  func (c *Context) GetPacket() []byte {
   346  	c.t.Helper()
   347  
   348  	p := c.GetPacketWithTimeout(5 * time.Second)
   349  	if p == nil {
   350  		c.t.Fatalf("Packet wasn't written out")
   351  		return nil
   352  	}
   353  
   354  	return p
   355  }
   356  
   357  // GetPacketNonBlocking reads a packet from the link layer endpoint
   358  // and verifies that it is an IPv4 packet with the expected source
   359  // and destination address. If no packet is available it will return
   360  // nil immediately.
   361  func (c *Context) GetPacketNonBlocking() []byte {
   362  	c.t.Helper()
   363  
   364  	p, ok := c.linkEP.Read()
   365  	if !ok {
   366  		return nil
   367  	}
   368  
   369  	if p.Proto != ipv4.ProtocolNumber {
   370  		c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
   371  	}
   372  
   373  	// Just check that the stack set the transport protocol number for outbound
   374  	// TCP messages.
   375  	// TODO(github.com/SagerNet/issues/3810): Remove when protocol numbers are part
   376  	// of the headerinfo.
   377  	if p.Pkt.TransportProtocolNumber != tcp.ProtocolNumber {
   378  		c.t.Fatalf("got p.Pkt.TransportProtocolNumber = %d, want = %d", p.Pkt.TransportProtocolNumber, tcp.ProtocolNumber)
   379  	}
   380  
   381  	vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
   382  	b := vv.ToView()
   383  
   384  	checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
   385  	return b
   386  }
   387  
   388  // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
   389  func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, p1, p2 []byte, maxTotalSize int) {
   390  	// Allocate a buffer data and headers.
   391  	buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2))
   392  	if len(buf) > maxTotalSize {
   393  		buf = buf[:maxTotalSize]
   394  	}
   395  
   396  	ip := header.IPv4(buf)
   397  	ip.Encode(&header.IPv4Fields{
   398  		TotalLength: uint16(len(buf)),
   399  		TTL:         65,
   400  		Protocol:    uint8(header.ICMPv4ProtocolNumber),
   401  		SrcAddr:     TestAddr,
   402  		DstAddr:     StackAddr,
   403  	})
   404  	ip.SetChecksum(^ip.CalculateChecksum())
   405  
   406  	icmp := header.ICMPv4(buf[header.IPv4MinimumSize:])
   407  	icmp.SetType(typ)
   408  	icmp.SetCode(code)
   409  	const icmpv4VariableHeaderOffset = 4
   410  	copy(icmp[icmpv4VariableHeaderOffset:], p1)
   411  	copy(icmp[header.ICMPv4PayloadOffset:], p2)
   412  	icmp.SetChecksum(0)
   413  	checksum := ^header.Checksum(icmp, 0 /* initial */)
   414  	icmp.SetChecksum(checksum)
   415  
   416  	// Inject packet.
   417  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   418  		Data: buf.ToVectorisedView(),
   419  	})
   420  	c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
   421  }
   422  
// BuildSegment builds a TCP segment based on the given Headers and payload,
// addressed from TestAddr to StackAddr. The segment is returned, not sent.
func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
	return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr)
}
   427  
   428  // BuildSegmentWithAddrs builds a TCP segment based on the given Headers,
   429  // payload and source and destination IPv4 addresses.
   430  func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
   431  	// Allocate a buffer for data and headers.
   432  	buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
   433  	copy(buf[len(buf)-len(payload):], payload)
   434  	copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts)
   435  
   436  	// Initialize the IP header.
   437  	ip := header.IPv4(buf)
   438  	ip.Encode(&header.IPv4Fields{
   439  		TotalLength: uint16(len(buf)),
   440  		TTL:         65,
   441  		Protocol:    uint8(tcp.ProtocolNumber),
   442  		SrcAddr:     src,
   443  		DstAddr:     dst,
   444  	})
   445  	ip.SetChecksum(^ip.CalculateChecksum())
   446  
   447  	// Initialize the TCP header.
   448  	t := header.TCP(buf[header.IPv4MinimumSize:])
   449  	t.Encode(&header.TCPFields{
   450  		SrcPort:    h.SrcPort,
   451  		DstPort:    h.DstPort,
   452  		SeqNum:     uint32(h.SeqNum),
   453  		AckNum:     uint32(h.AckNum),
   454  		DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)),
   455  		Flags:      h.Flags,
   456  		WindowSize: uint16(h.RcvWnd),
   457  	})
   458  
   459  	// Calculate the TCP pseudo-header checksum.
   460  	xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
   461  
   462  	// Calculate the TCP checksum and set it.
   463  	xsum = header.Checksum(payload, xsum)
   464  	t.SetChecksum(^t.CalculateChecksum(xsum))
   465  
   466  	// Inject packet.
   467  	return buf.ToVectorisedView()
   468  }
   469  
   470  // SendSegment sends a TCP segment that has already been built and written to a
   471  // buffer.VectorisedView.
   472  func (c *Context) SendSegment(s buffer.VectorisedView) {
   473  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   474  		Data: s,
   475  	})
   476  	c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
   477  }
   478  
   479  // SendPacket builds and sends a TCP segment(with the provided payload & TCP
   480  // headers) in an IPv4 packet via the link layer endpoint.
   481  func (c *Context) SendPacket(payload []byte, h *Headers) {
   482  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   483  		Data: c.BuildSegment(payload, h),
   484  	})
   485  	c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
   486  }
   487  
   488  // SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
   489  // & TCPheaders) in an IPv4 packet via the link layer endpoint using the
   490  // provided source and destination IPv4 addresses.
   491  func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
   492  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   493  		Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
   494  	})
   495  	c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
   496  }
   497  
// SendAck sends an ACK packet with sequence number seq that acknowledges
// bytesReceived bytes of data from the stack. No SACK blocks are included.
func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
	c.SendAckWithSACK(seq, bytesReceived, nil)
}
   502  
   503  // SendAckWithSACK sends an ACK packet which includes the sackBlocks specified.
   504  func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) {
   505  	options := make([]byte, 40)
   506  	offset := 0
   507  	if len(sackBlocks) > 0 {
   508  		offset += header.EncodeNOP(options[offset:])
   509  		offset += header.EncodeNOP(options[offset:])
   510  		offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
   511  	}
   512  
   513  	c.SendPacket(nil, &Headers{
   514  		SrcPort: TestPort,
   515  		DstPort: c.Port,
   516  		Flags:   header.TCPFlagAck,
   517  		SeqNum:  seq,
   518  		AckNum:  c.IRS.Add(1 + seqnum.Size(bytesReceived)),
   519  		RcvWnd:  30000,
   520  		TCPOpts: options[:offset],
   521  	})
   522  }
   523  
// ReceiveAndCheckPacket reads a packet from the link layer endpoint and
// verifies that the payload of the packet matches the slice of data
// indicated by offset & size. The packet is expected to carry no TCP options.
func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
	c.t.Helper()

	c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0)
}
   532  
   533  // ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint
   534  // and verifies that the packet packet payload of packet matches the slice of
   535  // data indicated by offset & size and skips optlen bytes in addition to the IP
   536  // TCP headers when comparing the data.
   537  func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) {
   538  	c.t.Helper()
   539  
   540  	b := c.GetPacket()
   541  	checker.IPv4(c.t, b,
   542  		checker.PayloadLen(size+header.TCPMinimumSize+optlen),
   543  		checker.TCP(
   544  			checker.DstPort(TestPort),
   545  			checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
   546  			checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
   547  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
   548  		),
   549  	)
   550  
   551  	pdata := data[offset:][:size]
   552  	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 {
   553  		c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
   554  	}
   555  }
   556  
   557  // ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint
   558  // and verifies that the packet packet payload of packet matches the slice of
   559  // data indicated by offset & size. It returns true if a packet was received and
   560  // processed.
   561  func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
   562  	c.t.Helper()
   563  
   564  	b := c.GetPacketNonBlocking()
   565  	if b == nil {
   566  		return false
   567  	}
   568  	checker.IPv4(c.t, b,
   569  		checker.PayloadLen(size+header.TCPMinimumSize),
   570  		checker.TCP(
   571  			checker.DstPort(TestPort),
   572  			checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
   573  			checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
   574  			checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
   575  		),
   576  	)
   577  
   578  	pdata := data[offset:][:size]
   579  	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 {
   580  		c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
   581  	}
   582  	return true
   583  }
   584  
   585  // CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only
   586  // is true then it sets the IP_V6ONLY option on the socket to make it a IPv6
   587  // only endpoint instead of a default dual stack socket.
   588  func (c *Context) CreateV6Endpoint(v6only bool) {
   589  	var err tcpip.Error
   590  	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ)
   591  	if err != nil {
   592  		c.t.Fatalf("NewEndpoint failed: %v", err)
   593  	}
   594  
   595  	c.EP.SocketOptions().SetV6Only(v6only)
   596  }
   597  
   598  // GetV6Packet reads a single packet from the link layer endpoint of the context
   599  // and asserts that it is an IPv6 Packet with the expected src/dest addresses.
   600  func (c *Context) GetV6Packet() []byte {
   601  	c.t.Helper()
   602  
   603  	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   604  	defer cancel()
   605  	p, ok := c.linkEP.ReadContext(ctx)
   606  	if !ok {
   607  		c.t.Fatalf("Packet wasn't written out")
   608  		return nil
   609  	}
   610  
   611  	if p.Proto != ipv6.ProtocolNumber {
   612  		c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
   613  	}
   614  	vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
   615  	b := vv.ToView()
   616  
   617  	checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
   618  	return b
   619  }
   620  
// SendV6Packet builds and sends an IPv6 Packet, addressed from TestV6Addr to
// StackV6Addr, via the link layer endpoint of the context.
func (c *Context) SendV6Packet(payload []byte, h *Headers) {
	c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
}
   626  
   627  // SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
   628  // endpoint of the context using the provided source and destination IPv6
   629  // addresses.
   630  func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
   631  	// Allocate a buffer for data and headers.
   632  	buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
   633  	copy(buf[len(buf)-len(payload):], payload)
   634  
   635  	// Initialize the IP header.
   636  	ip := header.IPv6(buf)
   637  	ip.Encode(&header.IPv6Fields{
   638  		PayloadLength:     uint16(header.TCPMinimumSize + len(payload)),
   639  		TransportProtocol: tcp.ProtocolNumber,
   640  		HopLimit:          65,
   641  		SrcAddr:           src,
   642  		DstAddr:           dst,
   643  	})
   644  
   645  	// Initialize the TCP header.
   646  	t := header.TCP(buf[header.IPv6MinimumSize:])
   647  	t.Encode(&header.TCPFields{
   648  		SrcPort:    h.SrcPort,
   649  		DstPort:    h.DstPort,
   650  		SeqNum:     uint32(h.SeqNum),
   651  		AckNum:     uint32(h.AckNum),
   652  		DataOffset: header.TCPMinimumSize,
   653  		Flags:      h.Flags,
   654  		WindowSize: uint16(h.RcvWnd),
   655  	})
   656  
   657  	// Calculate the TCP pseudo-header checksum.
   658  	xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
   659  
   660  	// Calculate the TCP checksum and set it.
   661  	xsum = header.Checksum(payload, xsum)
   662  	t.SetChecksum(^t.CalculateChecksum(xsum))
   663  
   664  	// Inject packet.
   665  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   666  		Data: buf.ToVectorisedView(),
   667  	})
   668  	c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt)
   669  }
   670  
// CreateConnected creates a connected TCP endpoint. It is
// CreateConnectedWithRawOptions with no TCP options.
func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
	c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
   675  
// Connect performs the 3-way handshake for c.EP with the provided Initial
// Sequence Number (iss) and receive window(rcvWnd) and any options if
// specified.
//
// The stack side initiates the connection to TestAddr:TestPort. This helper
// plays the remote peer: it reads the SYN, answers with a SYN-ACK that carries
// iss, rcvWnd and options, and checks the final ACK. On success it records
// the stack's initial sequence number in c.IRS, the window scale from the SYN
// in c.RcvdWindowScale and the stack's source port in c.Port.
//
// PreCondition: c.EP must already be created.
func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
	c.t.Helper()

	// Start connection attempt. Register for writable events first so the
	// completion notification cannot be missed.
	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
	c.WQ.EventRegister(&waitEntry, waiter.WritableEvents)
	defer c.WQ.EventUnregister(&waitEntry)

	// The only acceptable result here is ErrConnectStarted.
	err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
	if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
		c.t.Fatalf("Unexpected return value from Connect: %v", err)
	}

	// Receive SYN packet.
	b := c.GetPacket()
	checker.IPv4(c.t, b,
		checker.TCP(
			checker.DstPort(TestPort),
			checker.TCPFlags(header.TCPFlagSyn),
		),
	)
	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
	}

	// Remember the sequence number and options the stack put in its SYN.
	tcpHdr := header.TCP(header.IPv4(b).Payload())
	synOpts := header.ParseSynOptions(tcpHdr.Options(), false /* isAck */)
	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())

	// Reply with a SYN-ACK, swapping the ports seen in the SYN.
	c.SendPacket(nil, &Headers{
		SrcPort: tcpHdr.DestinationPort(),
		DstPort: tcpHdr.SourcePort(),
		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
		SeqNum:  iss,
		AckNum:  c.IRS.Add(1),
		RcvWnd:  rcvWnd,
		TCPOpts: options,
	})

	// Receive ACK packet.
	checker.IPv4(c.t, c.GetPacket(),
		checker.TCP(
			checker.DstPort(TestPort),
			checker.TCPFlags(header.TCPFlagAck),
			checker.TCPSeqNum(uint32(c.IRS)+1),
			checker.TCPAckNum(uint32(iss)+1),
		),
	)

	// Wait for connection to be established.
	select {
	case <-notifyCh:
		if err := c.EP.LastError(); err != nil {
			c.t.Fatalf("Unexpected error when connecting: %v", err)
		}
	case <-time.After(1 * time.Second):
		c.t.Fatalf("Timed out waiting for connection")
	}
	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
	}

	c.RcvdWindowScale = uint8(synOpts.WS)
	c.Port = tcpHdr.SourcePort()
}
   749  
   750  // Create creates a TCP endpoint.
   751  func (c *Context) Create(epRcvBuf int) {
   752  	// Create TCP endpoint.
   753  	var err tcpip.Error
   754  	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
   755  	if err != nil {
   756  		c.t.Fatalf("NewEndpoint failed: %v", err)
   757  	}
   758  
   759  	if epRcvBuf != -1 {
   760  		c.EP.SocketOptions().SetReceiveBufferSize(int64(epRcvBuf), true /* notify */)
   761  	}
   762  }
   763  
// CreateConnectedWithRawOptions creates a connected TCP endpoint. The
// specified option bytes are sent as the Option field of the SYN-ACK with
// which the test answers the endpoint's SYN.
//
// It also sets the receive buffer for the endpoint to the specified
// value in epRcvBuf, unless epRcvBuf is -1.
func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
	c.Create(epRcvBuf)
	c.Connect(iss, rcvWnd, options)
}
   773  
// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
// sending data and ACK packets easy while being able to manipulate the sequence
// numbers and timestamp values as needed.
type RawEndpoint struct {
	// C is the test context used to send and receive packets.
	C *Context
	// SrcPort and DstPort are the TCP ports placed in outgoing packets.
	SrcPort uint16
	DstPort uint16
	// Flags are the TCP flags placed in outgoing packets.
	Flags header.TCPFlags
	// NextSeqNum is the sequence number of the next packet to send; it is
	// advanced by the payload length after each send.
	NextSeqNum seqnum.Value
	// AckNum is the acknowledgement number placed in outgoing packets.
	AckNum seqnum.Value
	// WndSize is the receive window advertised in outgoing packets.
	WndSize  seqnum.Size
	RecentTS uint32 // Stores the latest timestamp to echo back.
	TSVal    uint32 // TSVal stores the last timestamp sent by this endpoint.

	// SACKPermitted is true if SACKPermitted option was negotiated for this endpoint.
	SACKPermitted bool
}
   791  
   792  // SendPacketWithTS embeds the provided tsVal in the Timestamp option
   793  // for the packet to be sent out.
   794  func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) {
   795  	r.TSVal = tsVal
   796  	tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
   797  	header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:])
   798  	r.SendPacket(payload, tsOpt[:])
   799  }
   800  
   801  // SendPacket is a small wrapper function to build and send packets.
   802  func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
   803  	packetHeaders := &Headers{
   804  		SrcPort: r.SrcPort,
   805  		DstPort: r.DstPort,
   806  		Flags:   r.Flags,
   807  		SeqNum:  r.NextSeqNum,
   808  		AckNum:  r.AckNum,
   809  		RcvWnd:  r.WndSize,
   810  		TCPOpts: opts,
   811  	}
   812  	r.C.SendPacket(payload, packetHeaders)
   813  	r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
   814  }
   815  
   816  // VerifyAndReturnACKWithTS verifies that the tsEcr field int he ACK matches
   817  // the provided tsVal as well as returns the original packet.
   818  func (r *RawEndpoint) VerifyAndReturnACKWithTS(tsVal uint32) []byte {
   819  	r.C.t.Helper()
   820  	// Read ACK and verify that tsEcr of ACK packet is [1,2,3,4]
   821  	ackPacket := r.C.GetPacket()
   822  	checker.IPv4(r.C.t, ackPacket,
   823  		checker.TCP(
   824  			checker.DstPort(r.SrcPort),
   825  			checker.TCPFlags(header.TCPFlagAck),
   826  			checker.TCPSeqNum(uint32(r.AckNum)),
   827  			checker.TCPAckNum(uint32(r.NextSeqNum)),
   828  			checker.TCPTimestampChecker(true, 0, tsVal),
   829  		),
   830  	)
   831  	// Store the parsed TSVal from the ack as recentTS.
   832  	tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
   833  	opts := tcpSeg.ParsedOptions()
   834  	r.RecentTS = opts.TSVal
   835  	return ackPacket
   836  }
   837  
   838  // VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
   839  // tsVal.
   840  func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
   841  	r.C.t.Helper()
   842  	_ = r.VerifyAndReturnACKWithTS(tsVal)
   843  }
   844  
   845  // VerifyACKRcvWnd verifies that the window advertised by the incoming ACK
   846  // matches the provided rcvWnd.
   847  func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) {
   848  	r.C.t.Helper()
   849  	ackPacket := r.C.GetPacket()
   850  	checker.IPv4(r.C.t, ackPacket,
   851  		checker.TCP(
   852  			checker.DstPort(r.SrcPort),
   853  			checker.TCPFlags(header.TCPFlagAck),
   854  			checker.TCPSeqNum(uint32(r.AckNum)),
   855  			checker.TCPAckNum(uint32(r.NextSeqNum)),
   856  			checker.TCPWindow(rcvWnd),
   857  		),
   858  	)
   859  }
   860  
   861  // VerifyACKNoSACK verifies that the ACK does not contain a SACK block.
   862  func (r *RawEndpoint) VerifyACKNoSACK() {
   863  	r.VerifyACKHasSACK(nil)
   864  }
   865  
   866  // VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks.
   867  func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
   868  	// Read ACK and verify that the TCP options in the segment do
   869  	// not contain a SACK block.
   870  	ackPacket := r.C.GetPacket()
   871  	checker.IPv4(r.C.t, ackPacket,
   872  		checker.TCP(
   873  			checker.DstPort(r.SrcPort),
   874  			checker.TCPFlags(header.TCPFlagAck),
   875  			checker.TCPSeqNum(uint32(r.AckNum)),
   876  			checker.TCPAckNum(uint32(r.NextSeqNum)),
   877  			checker.TCPSACKBlockChecker(sackBlocks),
   878  		),
   879  	)
   880  }
   881  
   882  // CreateConnectedWithOptions creates and connects c.ep with the specified TCP
   883  // options enabled and returns a RawEndpoint which represents the other end of
   884  // the connection.
   885  //
   886  // It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
   887  // does not carry an option that was not requested.
   888  func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
   889  	var err tcpip.Error
   890  	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
   891  	if err != nil {
   892  		c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
   893  	}
   894  	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want {
   895  		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
   896  	}
   897  
   898  	// Start connection attempt.
   899  	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
   900  	c.WQ.EventRegister(&waitEntry, waiter.WritableEvents)
   901  	defer c.WQ.EventUnregister(&waitEntry)
   902  
   903  	testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort}
   904  	err = c.EP.Connect(testFullAddr)
   905  	if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
   906  		c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err)
   907  	}
   908  	// Receive SYN packet.
   909  	b := c.GetPacket()
   910  	// Validate that the syn has the timestamp option and a valid
   911  	// TS value.
   912  	mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
   913  
   914  	checker.IPv4(c.t, b,
   915  		checker.TCP(
   916  			checker.DstPort(TestPort),
   917  			checker.TCPFlags(header.TCPFlagSyn),
   918  			checker.TCPSynOptions(header.TCPSynOptions{
   919  				MSS:           mss,
   920  				TS:            true,
   921  				WS:            int(c.WindowScale),
   922  				SACKPermitted: c.SACKEnabled(),
   923  			}),
   924  		),
   925  	)
   926  	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
   927  		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
   928  	}
   929  
   930  	tcpSeg := header.TCP(header.IPv4(b).Payload())
   931  	synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
   932  
   933  	// Build options w/ tsVal to be sent in the SYN-ACK.
   934  	synAckOptions := make([]byte, header.TCPOptionsMaximumSize)
   935  	offset := 0
   936  	if wantOptions.WS != -1 {
   937  		offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:])
   938  	}
   939  	if wantOptions.TS {
   940  		offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:])
   941  	}
   942  	if wantOptions.SACKPermitted {
   943  		offset += header.EncodeSACKPermittedOption(synAckOptions[offset:])
   944  	}
   945  
   946  	offset += header.AddTCPOptionPadding(synAckOptions, offset)
   947  
   948  	// Build SYN-ACK.
   949  	c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
   950  	iss := seqnum.Value(TestInitialSequenceNumber)
   951  	c.SendPacket(nil, &Headers{
   952  		SrcPort: tcpSeg.DestinationPort(),
   953  		DstPort: tcpSeg.SourcePort(),
   954  		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
   955  		SeqNum:  iss,
   956  		AckNum:  c.IRS.Add(1),
   957  		RcvWnd:  30000,
   958  		TCPOpts: synAckOptions[:offset],
   959  	})
   960  
   961  	// Read ACK.
   962  	ackPacket := c.GetPacket()
   963  
   964  	// Verify TCP header fields.
   965  	tcpCheckers := []checker.TransportChecker{
   966  		checker.DstPort(TestPort),
   967  		checker.TCPFlags(header.TCPFlagAck),
   968  		checker.TCPSeqNum(uint32(c.IRS) + 1),
   969  		checker.TCPAckNum(uint32(iss) + 1),
   970  	}
   971  
   972  	// Verify that tsEcr of ACK packet is wantOptions.TSVal if the
   973  	// timestamp option was enabled, if not then we verify that
   974  	// there is no timestamp in the ACK packet.
   975  	if wantOptions.TS {
   976  		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal))
   977  	} else {
   978  		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
   979  	}
   980  
   981  	checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...))
   982  
   983  	ackSeg := header.TCP(header.IPv4(ackPacket).Payload())
   984  	ackOptions := ackSeg.ParsedOptions()
   985  
   986  	// Wait for connection to be established.
   987  	select {
   988  	case <-notifyCh:
   989  		if err := c.EP.LastError(); err != nil {
   990  			c.t.Fatalf("Unexpected error when connecting: %v", err)
   991  		}
   992  	case <-time.After(1 * time.Second):
   993  		c.t.Fatalf("Timed out waiting for connection")
   994  	}
   995  	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
   996  		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
   997  	}
   998  
   999  	// Store the source port in use by the endpoint.
  1000  	c.Port = tcpSeg.SourcePort()
  1001  
  1002  	// Mark in context that timestamp option is enabled for this endpoint.
  1003  	c.TimeStampEnabled = true
  1004  	c.RcvdWindowScale = uint8(synOptions.WS)
  1005  	return &RawEndpoint{
  1006  		C:             c,
  1007  		SrcPort:       tcpSeg.DestinationPort(),
  1008  		DstPort:       tcpSeg.SourcePort(),
  1009  		Flags:         header.TCPFlagAck | header.TCPFlagPsh,
  1010  		NextSeqNum:    iss + 1,
  1011  		AckNum:        c.IRS.Add(1),
  1012  		WndSize:       30000,
  1013  		RecentTS:      ackOptions.TSVal,
  1014  		TSVal:         wantOptions.TSVal,
  1015  		SACKPermitted: wantOptions.SACKPermitted,
  1016  	}
  1017  }
  1018  
  1019  // AcceptWithOptions initializes a listening endpoint and connects to it with the
  1020  // provided options enabled. It also verifies that the SYN-ACK has the expected
  1021  // values for the provided options.
  1022  //
  1023  // The function returns a RawEndpoint representing the other end of the accepted
  1024  // endpoint.
  1025  func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
  1026  	// Create EP and start listening.
  1027  	wq := &waiter.Queue{}
  1028  	ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
  1029  	if err != nil {
  1030  		c.t.Fatalf("NewEndpoint failed: %v", err)
  1031  	}
  1032  	defer ep.Close()
  1033  
  1034  	if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil {
  1035  		c.t.Fatalf("Bind failed: %v", err)
  1036  	}
  1037  	if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
  1038  		c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
  1039  	}
  1040  
  1041  	if err := ep.Listen(10); err != nil {
  1042  		c.t.Fatalf("Listen failed: %v", err)
  1043  	}
  1044  	if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
  1045  		c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
  1046  	}
  1047  
  1048  	rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
  1049  
  1050  	// Try to accept the connection.
  1051  	we, ch := waiter.NewChannelEntry(nil)
  1052  	wq.EventRegister(&we, waiter.ReadableEvents)
  1053  	defer wq.EventUnregister(&we)
  1054  
  1055  	c.EP, _, err = ep.Accept(nil)
  1056  	if _, ok := err.(*tcpip.ErrWouldBlock); ok {
  1057  		// Wait for connection to be established.
  1058  		select {
  1059  		case <-ch:
  1060  			c.EP, _, err = ep.Accept(nil)
  1061  			if err != nil {
  1062  				c.t.Fatalf("Accept failed: %v", err)
  1063  			}
  1064  
  1065  		case <-time.After(1 * time.Second):
  1066  			c.t.Fatalf("Timed out waiting for accept")
  1067  		}
  1068  	}
  1069  	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
  1070  		c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
  1071  	}
  1072  
  1073  	return rep
  1074  }
  1075  
  1076  // PassiveConnect just disables WindowScaling and delegates the call to
  1077  // PassiveConnectWithOptions.
  1078  func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
  1079  	synOptions.WS = -1
  1080  	c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions)
  1081  }
  1082  
  1083  // PassiveConnectWithOptions initiates a new connection (with the specified TCP
  1084  // options enabled) to the port on which the Context.ep is listening for new
  1085  // connections. It also validates that the SYN-ACK has the expected values for
  1086  // the enabled options.
  1087  //
  1088  // NOTE: MSS is not a negotiated option and it can be asymmetric
  1089  // in each direction. This function uses the maxPayload to set the MSS to be
  1090  // sent to the peer on a connect and validates that the MSS in the SYN-ACK
  1091  // response is equal to the MTU - (tcphdr len + iphdr len).
  1092  //
  1093  // wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
  1094  // value of the window scaling option to be sent in the SYN. If synOptions.WS >
  1095  // 0 then we send the WindowScale option.
  1096  func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
  1097  	c.t.Helper()
  1098  	opts := make([]byte, header.TCPOptionsMaximumSize)
  1099  	offset := 0
  1100  	offset += header.EncodeMSSOption(uint32(maxPayload), opts)
  1101  
  1102  	if synOptions.WS >= 0 {
  1103  		offset += header.EncodeWSOption(3, opts[offset:])
  1104  	}
  1105  	if synOptions.TS {
  1106  		offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:])
  1107  	}
  1108  
  1109  	if synOptions.SACKPermitted {
  1110  		offset += header.EncodeSACKPermittedOption(opts[offset:])
  1111  	}
  1112  
  1113  	paddingToAdd := 4 - offset%4
  1114  	// Now add any padding bytes that might be required to quad align the
  1115  	// options.
  1116  	for i := offset; i < offset+paddingToAdd; i++ {
  1117  		opts[i] = header.TCPOptionNOP
  1118  	}
  1119  	offset += paddingToAdd
  1120  
  1121  	// Send a SYN request.
  1122  	iss := seqnum.Value(TestInitialSequenceNumber)
  1123  	c.SendPacket(nil, &Headers{
  1124  		SrcPort: TestPort,
  1125  		DstPort: StackPort,
  1126  		Flags:   header.TCPFlagSyn,
  1127  		SeqNum:  iss,
  1128  		RcvWnd:  30000,
  1129  		TCPOpts: opts[:offset],
  1130  	})
  1131  
  1132  	// Receive the SYN-ACK reply. Make sure MSS and other expected options
  1133  	// are present.
  1134  	b := c.GetPacket()
  1135  	tcp := header.TCP(header.IPv4(b).Payload())
  1136  	rcvdSynOptions := header.ParseSynOptions(tcp.Options(), true /* isAck */)
  1137  	c.IRS = seqnum.Value(tcp.SequenceNumber())
  1138  
  1139  	tcpCheckers := []checker.TransportChecker{
  1140  		checker.SrcPort(StackPort),
  1141  		checker.DstPort(TestPort),
  1142  		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
  1143  		checker.TCPAckNum(uint32(iss) + 1),
  1144  		checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
  1145  	}
  1146  
  1147  	// If TS option was enabled in the original SYN then add a checker to
  1148  	// validate the Timestamp option in the SYN-ACK.
  1149  	if synOptions.TS {
  1150  		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal))
  1151  	} else {
  1152  		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
  1153  	}
  1154  
  1155  	checker.IPv4(c.t, b, checker.TCP(tcpCheckers...))
  1156  	rcvWnd := seqnum.Size(30000)
  1157  	ackHeaders := &Headers{
  1158  		SrcPort: TestPort,
  1159  		DstPort: StackPort,
  1160  		Flags:   header.TCPFlagAck,
  1161  		SeqNum:  iss + 1,
  1162  		AckNum:  c.IRS + 1,
  1163  		RcvWnd:  rcvWnd,
  1164  	}
  1165  
  1166  	// If WS was expected to be in effect then scale the advertised window
  1167  	// correspondingly.
  1168  	if synOptions.WS > 0 {
  1169  		ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS)
  1170  	}
  1171  
  1172  	parsedOpts := tcp.ParsedOptions()
  1173  	if synOptions.TS {
  1174  		// Echo the tsVal back to the peer in the tsEcr field of the
  1175  		// timestamp option.
  1176  		// Increment TSVal by 1 from the value sent in the SYN and echo
  1177  		// the TSVal in the SYN-ACK in the TSEcr field.
  1178  		opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
  1179  		header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:])
  1180  		ackHeaders.TCPOpts = opts[:]
  1181  	}
  1182  
  1183  	// Send ACK.
  1184  	c.SendPacket(nil, ackHeaders)
  1185  
  1186  	c.RcvdWindowScale = uint8(rcvdSynOptions.WS)
  1187  	c.Port = StackPort
  1188  
  1189  	return &RawEndpoint{
  1190  		C:             c,
  1191  		SrcPort:       TestPort,
  1192  		DstPort:       StackPort,
  1193  		Flags:         header.TCPFlagPsh | header.TCPFlagAck,
  1194  		NextSeqNum:    iss + 1,
  1195  		AckNum:        c.IRS + 1,
  1196  		WndSize:       rcvWnd,
  1197  		SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(),
  1198  		RecentTS:      parsedOpts.TSVal,
  1199  		TSVal:         synOptions.TSVal + 1,
  1200  	}
  1201  }
  1202  
  1203  // SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
  1204  // for the Stack in the context.
  1205  func (c *Context) SACKEnabled() bool {
  1206  	var v tcpip.TCPSACKEnabled
  1207  	if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
  1208  		// Stack doesn't support SACK. So just return.
  1209  		return false
  1210  	}
  1211  	return bool(v)
  1212  }
  1213  
  1214  // SetGSOEnabled enables or disables generic segmentation offload.
  1215  func (c *Context) SetGSOEnabled(enable bool) {
  1216  	if enable {
  1217  		c.linkEP.SupportedGSOKind = stack.HWGSOSupported
  1218  	} else {
  1219  		c.linkEP.SupportedGSOKind = stack.GSONotSupported
  1220  	}
  1221  }
  1222  
  1223  // MSSWithoutOptions returns the value for the MSS used by the stack when no
  1224  // options are in use.
  1225  func (c *Context) MSSWithoutOptions() uint16 {
  1226  	return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
  1227  }
  1228  
  1229  // MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no
  1230  // options are in use for IPv6 packets.
  1231  func (c *Context) MSSWithoutOptionsV6() uint16 {
  1232  	return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize)
  1233  }