github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/test/packetimpact/testbench/connections.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package testbench
    16  
    17  import (
    18  	"fmt"
    19  	"math/rand"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/mohae/deepcopy"
    24  	"go.uber.org/multierr"
    25  	"golang.org/x/sys/unix"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/seqnum"
    29  )
    30  
    31  func portFromSockaddr(sa unix.Sockaddr) (uint16, error) {
    32  	switch sa := sa.(type) {
    33  	case *unix.SockaddrInet4:
    34  		return uint16(sa.Port), nil
    35  	case *unix.SockaddrInet6:
    36  		return uint16(sa.Port), nil
    37  	}
    38  	return 0, fmt.Errorf("sockaddr type %T does not contain port", sa)
    39  }
    40  
    41  // pickPort makes a new socket and returns the socket FD and port. The domain
    42  // should be AF_INET or AF_INET6. The caller must close the FD when done with
    43  // the port if there is no error.
    44  func (n *DUTTestNet) pickPort(domain, typ int) (fd int, port uint16, err error) {
    45  	fd, err = unix.Socket(domain, typ, 0)
    46  	if err != nil {
    47  		return -1, 0, fmt.Errorf("creating socket: %w", err)
    48  	}
    49  	defer func() {
    50  		if err != nil {
    51  			if cerr := unix.Close(fd); cerr != nil {
    52  				err = multierr.Append(err, fmt.Errorf("failed to close socket %d: %w", fd, cerr))
    53  			}
    54  		}
    55  	}()
    56  	var sa unix.Sockaddr
    57  	switch domain {
    58  	case unix.AF_INET:
    59  		var sa4 unix.SockaddrInet4
    60  		copy(sa4.Addr[:], n.LocalIPv4)
    61  		sa = &sa4
    62  	case unix.AF_INET6:
    63  		sa6 := unix.SockaddrInet6{ZoneId: n.LocalDevID}
    64  		copy(sa6.Addr[:], n.LocalIPv6)
    65  		sa = &sa6
    66  	default:
    67  		return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
    68  	}
    69  	if err = unix.Bind(fd, sa); err != nil {
    70  		return -1, 0, fmt.Errorf("binding to %+v: %w", sa, err)
    71  	}
    72  	sa, err = unix.Getsockname(fd)
    73  	if err != nil {
    74  		return -1, 0, fmt.Errorf("unix.Getsocketname(%d): %w", fd, err)
    75  	}
    76  	port, err = portFromSockaddr(sa)
    77  	if err != nil {
    78  		return -1, 0, fmt.Errorf("extracting port from socket address %+v: %w", sa, err)
    79  	}
    80  	return fd, port, nil
    81  }
    82  
    83  // layerState stores the state of a layer of a connection.
    84  type layerState interface {
    85  	// outgoing returns an outgoing layer to be sent in a frame. It should not
    86  	// update layerState, that is done in layerState.sent.
    87  	outgoing() Layer
    88  
    89  	// incoming creates an expected Layer for comparing against a received Layer.
    90  	// Because the expectation can depend on values in the received Layer, it is
    91  	// an input to incoming. For example, the ACK number needs to be checked in a
    92  	// TCP packet but only if the ACK flag is set in the received packet. It
    93  	// should not update layerState, that is done in layerState.received. The
    94  	// caller takes ownership of the returned Layer.
    95  	incoming(received Layer) Layer
    96  
    97  	// sent updates the layerState based on the Layer that was sent. The input is
    98  	// a Layer with all prev and next pointers populated so that the entire frame
    99  	// as it was sent is available.
   100  	sent(sent Layer) error
   101  
   102  	// received updates the layerState based on a Layer that is received. The
   103  	// input is a Layer with all prev and next pointers populated so that the
   104  	// entire frame as it was received is available.
   105  	received(received Layer) error
   106  
   107  	// close frees associated resources held by the LayerState.
   108  	close() error
   109  }
   110  
   111  // etherState maintains state about an Ethernet connection.
   112  type etherState struct {
   113  	out, in Ether
   114  }
   115  
   116  var _ layerState = (*etherState)(nil)
   117  
   118  // newEtherState creates a new etherState.
   119  func (n *DUTTestNet) newEtherState(out, in Ether) (*etherState, error) {
   120  	lmac := tcpip.LinkAddress(n.LocalMAC)
   121  	rmac := tcpip.LinkAddress(n.RemoteMAC)
   122  	s := etherState{
   123  		out: Ether{SrcAddr: &lmac, DstAddr: &rmac},
   124  		in:  Ether{SrcAddr: &rmac, DstAddr: &lmac},
   125  	}
   126  	if err := s.out.merge(&out); err != nil {
   127  		return nil, err
   128  	}
   129  	if err := s.in.merge(&in); err != nil {
   130  		return nil, err
   131  	}
   132  	return &s, nil
   133  }
   134  
   135  func (s *etherState) outgoing() Layer {
   136  	return deepcopy.Copy(&s.out).(Layer)
   137  }
   138  
   139  // incoming implements layerState.incoming.
   140  func (s *etherState) incoming(Layer) Layer {
   141  	return deepcopy.Copy(&s.in).(Layer)
   142  }
   143  
   144  func (*etherState) sent(Layer) error {
   145  	return nil
   146  }
   147  
   148  func (*etherState) received(Layer) error {
   149  	return nil
   150  }
   151  
   152  func (*etherState) close() error {
   153  	return nil
   154  }
   155  
   156  // ipv4State maintains state about an IPv4 connection.
   157  type ipv4State struct {
   158  	out, in IPv4
   159  }
   160  
   161  var _ layerState = (*ipv4State)(nil)
   162  
   163  // newIPv4State creates a new ipv4State.
   164  func (n *DUTTestNet) newIPv4State(out, in IPv4) (*ipv4State, error) {
   165  	lIP := tcpip.Address(n.LocalIPv4)
   166  	rIP := tcpip.Address(n.RemoteIPv4)
   167  	s := ipv4State{
   168  		out: IPv4{SrcAddr: &lIP, DstAddr: &rIP},
   169  		in:  IPv4{SrcAddr: &rIP, DstAddr: &lIP},
   170  	}
   171  	if err := s.out.merge(&out); err != nil {
   172  		return nil, err
   173  	}
   174  	if err := s.in.merge(&in); err != nil {
   175  		return nil, err
   176  	}
   177  	return &s, nil
   178  }
   179  
   180  func (s *ipv4State) outgoing() Layer {
   181  	return deepcopy.Copy(&s.out).(Layer)
   182  }
   183  
   184  // incoming implements layerState.incoming.
   185  func (s *ipv4State) incoming(Layer) Layer {
   186  	return deepcopy.Copy(&s.in).(Layer)
   187  }
   188  
   189  func (*ipv4State) sent(Layer) error {
   190  	return nil
   191  }
   192  
   193  func (*ipv4State) received(Layer) error {
   194  	return nil
   195  }
   196  
   197  func (*ipv4State) close() error {
   198  	return nil
   199  }
   200  
   201  // ipv6State maintains state about an IPv6 connection.
   202  type ipv6State struct {
   203  	out, in IPv6
   204  }
   205  
   206  var _ layerState = (*ipv6State)(nil)
   207  
   208  // newIPv6State creates a new ipv6State.
   209  func (n *DUTTestNet) newIPv6State(out, in IPv6) (*ipv6State, error) {
   210  	lIP := tcpip.Address(n.LocalIPv6)
   211  	rIP := tcpip.Address(n.RemoteIPv6)
   212  	s := ipv6State{
   213  		out: IPv6{SrcAddr: &lIP, DstAddr: &rIP},
   214  		in:  IPv6{SrcAddr: &rIP, DstAddr: &lIP},
   215  	}
   216  	if err := s.out.merge(&out); err != nil {
   217  		return nil, err
   218  	}
   219  	if err := s.in.merge(&in); err != nil {
   220  		return nil, err
   221  	}
   222  	return &s, nil
   223  }
   224  
   225  // outgoing returns an outgoing layer to be sent in a frame.
   226  func (s *ipv6State) outgoing() Layer {
   227  	return deepcopy.Copy(&s.out).(Layer)
   228  }
   229  
   230  func (s *ipv6State) incoming(Layer) Layer {
   231  	return deepcopy.Copy(&s.in).(Layer)
   232  }
   233  
   234  func (s *ipv6State) sent(Layer) error {
   235  	// Nothing to do.
   236  	return nil
   237  }
   238  
   239  func (s *ipv6State) received(Layer) error {
   240  	// Nothing to do.
   241  	return nil
   242  }
   243  
   244  // close cleans up any resources held.
   245  func (s *ipv6State) close() error {
   246  	return nil
   247  }
   248  
   249  // tcpState maintains state about a TCP connection.
   250  type tcpState struct {
   251  	out, in                   TCP
   252  	localSeqNum, remoteSeqNum *seqnum.Value
   253  	synAck                    *TCP
   254  	portPickerFD              int
   255  	finSent                   bool
   256  }
   257  
   258  var _ layerState = (*tcpState)(nil)
   259  
   260  // SeqNumValue is a helper routine that allocates a new seqnum.Value value to
   261  // store v and returns a pointer to it.
   262  func SeqNumValue(v seqnum.Value) *seqnum.Value {
   263  	return &v
   264  }
   265  
   266  // newTCPState creates a new TCPState.
   267  func (n *DUTTestNet) newTCPState(domain int, out, in TCP) (*tcpState, error) {
   268  	portPickerFD, localPort, err := n.pickPort(domain, unix.SOCK_STREAM)
   269  	if err != nil {
   270  		return nil, err
   271  	}
   272  	s := tcpState{
   273  		out:          TCP{SrcPort: &localPort},
   274  		in:           TCP{DstPort: &localPort},
   275  		localSeqNum:  SeqNumValue(seqnum.Value(rand.Uint32())),
   276  		portPickerFD: portPickerFD,
   277  		finSent:      false,
   278  	}
   279  	if err := s.out.merge(&out); err != nil {
   280  		return nil, err
   281  	}
   282  	if err := s.in.merge(&in); err != nil {
   283  		return nil, err
   284  	}
   285  	return &s, nil
   286  }
   287  
   288  func (s *tcpState) outgoing() Layer {
   289  	newOutgoing := deepcopy.Copy(s.out).(TCP)
   290  	if s.localSeqNum != nil {
   291  		newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum))
   292  	}
   293  	if s.remoteSeqNum != nil {
   294  		newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum))
   295  	}
   296  	return &newOutgoing
   297  }
   298  
   299  // incoming implements layerState.incoming.
   300  func (s *tcpState) incoming(received Layer) Layer {
   301  	tcpReceived, ok := received.(*TCP)
   302  	if !ok {
   303  		return nil
   304  	}
   305  	newIn := deepcopy.Copy(s.in).(TCP)
   306  	if s.remoteSeqNum != nil {
   307  		newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum))
   308  	}
   309  	if seq, flags := s.localSeqNum, tcpReceived.Flags; seq != nil && flags != nil && *flags&header.TCPFlagAck != 0 {
   310  		// The caller didn't specify an AckNum so we'll expect the calculated one,
   311  		// but only if the ACK flag is set because the AckNum is not valid in a
   312  		// header if ACK is not set.
   313  		newIn.AckNum = Uint32(uint32(*seq))
   314  	}
   315  	return &newIn
   316  }
   317  
   318  func (s *tcpState) sent(sent Layer) error {
   319  	tcp, ok := sent.(*TCP)
   320  	if !ok {
   321  		return fmt.Errorf("can't update tcpState with %T Layer", sent)
   322  	}
   323  	if !s.finSent {
   324  		// update localSeqNum by the payload only when FIN is not yet sent by us
   325  		for current := tcp.next(); current != nil; current = current.next() {
   326  			s.localSeqNum.UpdateForward(seqnum.Size(current.length()))
   327  		}
   328  	}
   329  	if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
   330  		s.localSeqNum.UpdateForward(1)
   331  	}
   332  	if *tcp.Flags&(header.TCPFlagFin) != 0 {
   333  		s.finSent = true
   334  	}
   335  	return nil
   336  }
   337  
   338  func (s *tcpState) received(l Layer) error {
   339  	tcp, ok := l.(*TCP)
   340  	if !ok {
   341  		return fmt.Errorf("can't update tcpState with %T Layer", l)
   342  	}
   343  	s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum))
   344  	if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
   345  		s.remoteSeqNum.UpdateForward(1)
   346  	}
   347  	for current := tcp.next(); current != nil; current = current.next() {
   348  		s.remoteSeqNum.UpdateForward(seqnum.Size(current.length()))
   349  	}
   350  	return nil
   351  }
   352  
   353  // close frees the port associated with this connection.
   354  func (s *tcpState) close() error {
   355  	if err := unix.Close(s.portPickerFD); err != nil {
   356  		return err
   357  	}
   358  	s.portPickerFD = -1
   359  	return nil
   360  }
   361  
   362  // udpState maintains state about a UDP connection.
   363  type udpState struct {
   364  	out, in      UDP
   365  	portPickerFD int
   366  }
   367  
   368  var _ layerState = (*udpState)(nil)
   369  
   370  // newUDPState creates a new udpState.
   371  func (n *DUTTestNet) newUDPState(domain int, out, in UDP) (*udpState, error) {
   372  	portPickerFD, localPort, err := n.pickPort(domain, unix.SOCK_DGRAM)
   373  	if err != nil {
   374  		return nil, fmt.Errorf("picking port: %w", err)
   375  	}
   376  	s := udpState{
   377  		out:          UDP{SrcPort: &localPort},
   378  		in:           UDP{DstPort: &localPort},
   379  		portPickerFD: portPickerFD,
   380  	}
   381  	if err := s.out.merge(&out); err != nil {
   382  		return nil, err
   383  	}
   384  	if err := s.in.merge(&in); err != nil {
   385  		return nil, err
   386  	}
   387  	return &s, nil
   388  }
   389  
   390  func (s *udpState) outgoing() Layer {
   391  	return deepcopy.Copy(&s.out).(Layer)
   392  }
   393  
   394  // incoming implements layerState.incoming.
   395  func (s *udpState) incoming(Layer) Layer {
   396  	return deepcopy.Copy(&s.in).(Layer)
   397  }
   398  
   399  func (*udpState) sent(l Layer) error {
   400  	return nil
   401  }
   402  
   403  func (*udpState) received(l Layer) error {
   404  	return nil
   405  }
   406  
   407  // close frees the port associated with this connection.
   408  func (s *udpState) close() error {
   409  	if err := unix.Close(s.portPickerFD); err != nil {
   410  		return err
   411  	}
   412  	s.portPickerFD = -1
   413  	return nil
   414  }
   415  
   416  // Connection holds a collection of layer states for maintaining a connection
   417  // along with sockets for sniffer and injecting packets.
   418  type Connection struct {
   419  	layerStates []layerState
   420  	injector    Injector
   421  	sniffer     Sniffer
   422  }
   423  
   424  // Returns the default incoming frame against which to match. If received is
   425  // longer than layerStates then that may still count as a match. The reverse is
   426  // never a match and nil is returned.
   427  func (conn *Connection) incoming(received Layers) Layers {
   428  	if len(received) < len(conn.layerStates) {
   429  		return nil
   430  	}
   431  	in := Layers{}
   432  	for i, s := range conn.layerStates {
   433  		toMatch := s.incoming(received[i])
   434  		if toMatch == nil {
   435  			return nil
   436  		}
   437  		in = append(in, toMatch)
   438  	}
   439  	return in
   440  }
   441  
   442  func (conn *Connection) match(override, received Layers) bool {
   443  	toMatch := conn.incoming(received)
   444  	if toMatch == nil {
   445  		return false // Not enough layers in gotLayers for matching.
   446  	}
   447  	if err := toMatch.merge(override); err != nil {
   448  		return false // Failing to merge is not matching.
   449  	}
   450  	return toMatch.match(received)
   451  }
   452  
   453  // Close frees associated resources held by the Connection.
   454  func (conn *Connection) Close(t *testing.T) {
   455  	t.Helper()
   456  
   457  	errs := multierr.Combine(conn.sniffer.close(), conn.injector.close())
   458  	for _, s := range conn.layerStates {
   459  		if err := s.close(); err != nil {
   460  			errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err))
   461  		}
   462  	}
   463  	if errs != nil {
   464  		t.Fatalf("unable to close %+v: %s", conn, errs)
   465  	}
   466  }
   467  
   468  // CreateFrame builds a frame for the connection with defaults overridden
   469  // from the innermost layer out, and additionalLayers added after it.
   470  //
   471  // Note that overrideLayers can have a length that is less than the number
   472  // of layers in this connection, and in such cases the innermost layers are
   473  // overridden first. As an example, valid values of overrideLayers for a TCP-
   474  // over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and
   475  // [Ethernet, IPv4, TCP].
   476  func (conn *Connection) CreateFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) Layers {
   477  	t.Helper()
   478  
   479  	var layersToSend Layers
   480  	for i, s := range conn.layerStates {
   481  		layer := s.outgoing()
   482  		// overrideLayers and conn.layerStates have their tails aligned, so
   483  		// to find the index we move backwards by the distance i is to the
   484  		// end.
   485  		if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 {
   486  			if err := layer.merge(overrideLayers[j]); err != nil {
   487  				t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err)
   488  			}
   489  		}
   490  		layersToSend = append(layersToSend, layer)
   491  	}
   492  	layersToSend = append(layersToSend, additionalLayers...)
   493  	return layersToSend
   494  }
   495  
   496  // SendFrameStateless sends a frame without updating any of the layer states.
   497  //
   498  // This method is useful for sending out-of-band control messages such as
   499  // ICMP packets, where it would not make sense to update the transport layer's
   500  // state using the ICMP header.
   501  func (conn *Connection) SendFrameStateless(t *testing.T, frame Layers) {
   502  	t.Helper()
   503  
   504  	outBytes, err := frame.ToBytes()
   505  	if err != nil {
   506  		t.Fatalf("can't build outgoing packet: %s", err)
   507  	}
   508  	conn.injector.Send(t, outBytes)
   509  }
   510  
   511  // SendFrame sends a frame on the wire and updates the state of all layers.
   512  func (conn *Connection) SendFrame(t *testing.T, frame Layers) {
   513  	t.Helper()
   514  
   515  	outBytes, err := frame.ToBytes()
   516  	if err != nil {
   517  		t.Fatalf("can't build outgoing packet: %s", err)
   518  	}
   519  	conn.injector.Send(t, outBytes)
   520  
   521  	// frame might have nil values where the caller wanted to use default values.
   522  	// sentFrame will have no nil values in it because it comes from parsing the
   523  	// bytes that were actually sent.
   524  	sentFrame := parse(parseEther, outBytes)
   525  	// Update the state of each layer based on what was sent.
   526  	for i, s := range conn.layerStates {
   527  		if err := s.sent(sentFrame[i]); err != nil {
   528  			t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err)
   529  		}
   530  	}
   531  }
   532  
   533  // send sends a packet, possibly with layers of this connection overridden and
   534  // additional layers added.
   535  //
   536  // Types defined with Connection as the underlying type should expose
   537  // type-safe versions of this method.
   538  func (conn *Connection) send(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) {
   539  	t.Helper()
   540  
   541  	conn.SendFrame(t, conn.CreateFrame(t, overrideLayers, additionalLayers...))
   542  }
   543  
   544  // recvFrame gets the next successfully parsed frame (of type Layers) within the
   545  // timeout provided. If no parsable frame arrives before the timeout, it returns
   546  // nil.
   547  func (conn *Connection) recvFrame(t *testing.T, timeout time.Duration) Layers {
   548  	t.Helper()
   549  
   550  	if timeout <= 0 {
   551  		return nil
   552  	}
   553  	b := conn.sniffer.Recv(t, timeout)
   554  	if b == nil {
   555  		return nil
   556  	}
   557  	return parse(parseEther, b)
   558  }
   559  
   560  // layersError stores the Layers that we got and the Layers that we wanted to
   561  // match.
   562  type layersError struct {
   563  	got, want Layers
   564  }
   565  
   566  func (e *layersError) Error() string {
   567  	return e.got.diff(e.want)
   568  }
   569  
   570  // Expect expects a frame with the final layerStates layer matching the
   571  // provided Layer within the timeout specified. If it doesn't arrive in time,
   572  // an error is returned.
   573  func (conn *Connection) Expect(t *testing.T, layer Layer, timeout time.Duration) (Layer, error) {
   574  	t.Helper()
   575  
   576  	// Make a frame that will ignore all but the final layer.
   577  	layers := make([]Layer, len(conn.layerStates))
   578  	layers[len(layers)-1] = layer
   579  
   580  	gotFrame, err := conn.ExpectFrame(t, layers, timeout)
   581  	if err != nil {
   582  		return nil, err
   583  	}
   584  	if len(conn.layerStates)-1 < len(gotFrame) {
   585  		return gotFrame[len(conn.layerStates)-1], nil
   586  	}
   587  	t.Fatalf("the received frame should be at least as long as the expected layers, got %d layers, want at least %d layers, got frame: %#v", len(gotFrame), len(conn.layerStates), gotFrame)
   588  	panic("unreachable")
   589  }
   590  
   591  // ExpectFrame expects a frame that matches the provided Layers within the
   592  // timeout specified. If one arrives in time, the Layers is returned without an
   593  // error. If it doesn't arrive in time, it returns nil and error is non-nil.
   594  func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Duration) (Layers, error) {
   595  	t.Helper()
   596  
   597  	frames, ok := conn.ListenForFrame(t, layers, timeout)
   598  	if ok {
   599  		return frames[len(frames)-1], nil
   600  	}
   601  	if len(frames) == 0 {
   602  		return nil, fmt.Errorf("got no frames matching %s during %s", layers, timeout)
   603  	}
   604  
   605  	var errs error
   606  	for _, got := range frames {
   607  		want := conn.incoming(layers)
   608  		if err := want.merge(layers); err != nil {
   609  			errs = multierr.Combine(errs, err)
   610  		} else {
   611  			errs = multierr.Combine(errs, &layersError{got: got, want: want})
   612  		}
   613  	}
   614  	return nil, fmt.Errorf("got frames:\n%w want %s during %s", errs, layers, timeout)
   615  }
   616  
   617  // ListenForFrame captures all frames until a frame matches the provided Layers,
   618  // or until the timeout specified. Returns all captured frames, including the
   619  // matched frame, and true if the desired frame was found.
   620  func (conn *Connection) ListenForFrame(t *testing.T, layers Layers, timeout time.Duration) ([]Layers, bool) {
   621  	t.Helper()
   622  
   623  	deadline := time.Now().Add(timeout)
   624  	var frames []Layers
   625  	for {
   626  		var got Layers
   627  		if timeout := time.Until(deadline); timeout > 0 {
   628  			got = conn.recvFrame(t, timeout)
   629  		}
   630  		if got == nil {
   631  			return frames, false
   632  		}
   633  		frames = append(frames, got)
   634  		if conn.match(layers, got) {
   635  			for i, s := range conn.layerStates {
   636  				if err := s.received(got[i]); err != nil {
   637  					t.Fatalf("failed to update test connection's layer states based on received frame: %s", err)
   638  				}
   639  			}
   640  			return frames, true
   641  		}
   642  	}
   643  }
   644  
   645  // Drain drains the sniffer's receive buffer by receiving packets until there's
   646  // nothing else to receive.
   647  func (conn *Connection) Drain(t *testing.T) {
   648  	t.Helper()
   649  
   650  	conn.sniffer.Drain(t)
   651  }
   652  
   653  // TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection.
   654  type TCPIPv4 struct {
   655  	Connection
   656  }
   657  
   658  // NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
   659  func (n *DUTTestNet) NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
   660  	t.Helper()
   661  
   662  	etherState, err := n.newEtherState(Ether{}, Ether{})
   663  	if err != nil {
   664  		t.Fatalf("can't make etherState: %s", err)
   665  	}
   666  	ipv4State, err := n.newIPv4State(IPv4{}, IPv4{})
   667  	if err != nil {
   668  		t.Fatalf("can't make ipv4State: %s", err)
   669  	}
   670  	tcpState, err := n.newTCPState(unix.AF_INET, outgoingTCP, incomingTCP)
   671  	if err != nil {
   672  		t.Fatalf("can't make tcpState: %s", err)
   673  	}
   674  	injector, err := n.NewInjector(t)
   675  	if err != nil {
   676  		t.Fatalf("can't make injector: %s", err)
   677  	}
   678  	sniffer, err := n.NewSniffer(t)
   679  	if err != nil {
   680  		t.Fatalf("can't make sniffer: %s", err)
   681  	}
   682  
   683  	return TCPIPv4{
   684  		Connection: Connection{
   685  			layerStates: []layerState{etherState, ipv4State, tcpState},
   686  			injector:    injector,
   687  			sniffer:     sniffer,
   688  		},
   689  	}
   690  }
   691  
   692  // Connect performs a TCP 3-way handshake. The input Connection should have a
   693  // final TCP Layer.
   694  func (conn *TCPIPv4) Connect(t *testing.T) {
   695  	t.Helper()
   696  
   697  	// Send the SYN.
   698  	conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagSyn)})
   699  
   700  	// Wait for the SYN-ACK.
   701  	synAck, err := conn.Expect(t, TCP{Flags: TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
   702  	if err != nil {
   703  		t.Fatalf("didn't get synack during handshake: %s", err)
   704  	}
   705  	conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
   706  
   707  	// Send an ACK.
   708  	conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagAck)})
   709  }
   710  
   711  // ConnectWithOptions performs a TCP 3-way handshake with given TCP options.
   712  // The input Connection should have a final TCP Layer.
   713  func (conn *TCPIPv4) ConnectWithOptions(t *testing.T, options []byte) {
   714  	t.Helper()
   715  
   716  	// Send the SYN.
   717  	conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagSyn), Options: options})
   718  
   719  	// Wait for the SYN-ACK.
   720  	synAck, err := conn.Expect(t, TCP{Flags: TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
   721  	if err != nil {
   722  		t.Fatalf("didn't get synack during handshake: %s", err)
   723  	}
   724  	conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
   725  
   726  	// Send an ACK.
   727  	conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagAck)})
   728  }
   729  
   730  // ExpectData is a convenient method that expects a Layer and the Layer after
   731  // it. If it doesn't arrive in time, it returns nil.
   732  func (conn *TCPIPv4) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
   733  	t.Helper()
   734  
   735  	expected := make([]Layer, len(conn.layerStates))
   736  	expected[len(expected)-1] = tcp
   737  	if payload != nil {
   738  		expected = append(expected, payload)
   739  	}
   740  	return conn.ExpectFrame(t, expected, timeout)
   741  }
   742  
   743  // ExpectNextData attempts to receive the next incoming segment for the
   744  // connection and expects that to match the given layers.
   745  //
   746  // It differs from ExpectData() in that here we are only interested in the next
   747  // received segment, while ExpectData() can receive multiple segments for the
   748  // connection until there is a match with given layers or a timeout.
   749  func (conn *TCPIPv4) ExpectNextData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
   750  	t.Helper()
   751  
   752  	// Receive the first incoming TCP segment for this connection.
   753  	got, err := conn.ExpectData(t, &TCP{}, nil, timeout)
   754  	if err != nil {
   755  		return nil, err
   756  	}
   757  
   758  	expected := make([]Layer, len(conn.layerStates))
   759  	expected[len(expected)-1] = tcp
   760  	if payload != nil {
   761  		expected = append(expected, payload)
   762  		tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum(t)) - uint32(payload.Length()))
   763  	}
   764  	if !conn.match(expected, got) {
   765  		return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got)
   766  	}
   767  	return got, nil
   768  }
   769  
   770  // Send a packet with reasonable defaults. Potentially override the TCP layer in
   771  // the connection with the provided layer and add additionLayers.
   772  func (conn *TCPIPv4) Send(t *testing.T, tcp TCP, additionalLayers ...Layer) {
   773  	t.Helper()
   774  
   775  	conn.send(t, Layers{&tcp}, additionalLayers...)
   776  }
   777  
   778  // Expect expects a frame with the TCP layer matching the provided TCP within
   779  // the timeout specified. If it doesn't arrive in time, an error is returned.
   780  func (conn *TCPIPv4) Expect(t *testing.T, tcp TCP, timeout time.Duration) (*TCP, error) {
   781  	t.Helper()
   782  
   783  	layer, err := conn.Connection.Expect(t, &tcp, timeout)
   784  	if layer == nil {
   785  		return nil, err
   786  	}
   787  	gotTCP, ok := layer.(*TCP)
   788  	if !ok {
   789  		t.Fatalf("expected %s to be TCP", layer)
   790  	}
   791  	return gotTCP, err
   792  }
   793  
   794  func (conn *TCPIPv4) tcpState(t *testing.T) *tcpState {
   795  	t.Helper()
   796  
   797  	state, ok := conn.layerStates[2].(*tcpState)
   798  	if !ok {
   799  		t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2])
   800  	}
   801  	return state
   802  }
   803  
   804  func (conn *TCPIPv4) ipv4State(t *testing.T) *ipv4State {
   805  	t.Helper()
   806  
   807  	state, ok := conn.layerStates[1].(*ipv4State)
   808  	if !ok {
   809  		t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1])
   810  	}
   811  	return state
   812  }
   813  
   814  // RemoteSeqNum returns the next expected sequence number from the DUT.
   815  func (conn *TCPIPv4) RemoteSeqNum(t *testing.T) *seqnum.Value {
   816  	t.Helper()
   817  
   818  	return conn.tcpState(t).remoteSeqNum
   819  }
   820  
   821  // LocalSeqNum returns the next sequence number to send from the testbench.
   822  func (conn *TCPIPv4) LocalSeqNum(t *testing.T) *seqnum.Value {
   823  	t.Helper()
   824  
   825  	return conn.tcpState(t).localSeqNum
   826  }
   827  
   828  // SynAck returns the SynAck that was part of the handshake.
   829  func (conn *TCPIPv4) SynAck(t *testing.T) *TCP {
   830  	t.Helper()
   831  
   832  	return conn.tcpState(t).synAck
   833  }
   834  
   835  // LocalAddr gets the local socket address of this connection.
   836  func (conn *TCPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 {
   837  	t.Helper()
   838  
   839  	sa := &unix.SockaddrInet4{Port: int(*conn.tcpState(t).out.SrcPort)}
   840  	copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr)
   841  	return sa
   842  }
   843  
   844  // GenerateOTWSeqSegment generates a segment with
   845  // seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only
   846  // acceptable when seqNumOffset is 0, otherwise an ACK is expected from the
   847  // receiver.
   848  func GenerateOTWSeqSegment(t *testing.T, conn *TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) TCP {
   849  	t.Helper()
   850  	lastAcceptable := conn.LocalSeqNum(t).Add(windowSize)
   851  	otwSeq := uint32(lastAcceptable.Add(seqNumOffset))
   852  	return TCP{SeqNum: Uint32(otwSeq), Flags: TCPFlags(header.TCPFlagAck)}
   853  }
   854  
   855  // GenerateUnaccACKSegment generates a segment with
   856  // acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable
   857  // when seqNumOffset is 0, otherwise an ACK is expected from the receiver.
   858  func GenerateUnaccACKSegment(t *testing.T, conn *TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) TCP {
   859  	t.Helper()
   860  	lastAcceptable := conn.RemoteSeqNum(t)
   861  	unaccAck := uint32(lastAcceptable.Add(seqNumOffset))
   862  	return TCP{AckNum: Uint32(unaccAck), Flags: TCPFlags(header.TCPFlagAck)}
   863  }
   864  
// IPv4Conn maintains the state for all the layers in an IPv4 connection. It
// embeds Connection, so Connection's fields and methods are promoted.
type IPv4Conn struct {
	Connection
}
   869  
   870  // NewIPv4Conn creates a new IPv4Conn connection with reasonable defaults.
   871  func (n *DUTTestNet) NewIPv4Conn(t *testing.T, outgoingIPv4, incomingIPv4 IPv4) IPv4Conn {
   872  	t.Helper()
   873  
   874  	etherState, err := n.newEtherState(Ether{}, Ether{})
   875  	if err != nil {
   876  		t.Fatalf("can't make EtherState: %s", err)
   877  	}
   878  	ipv4State, err := n.newIPv4State(outgoingIPv4, incomingIPv4)
   879  	if err != nil {
   880  		t.Fatalf("can't make IPv4State: %s", err)
   881  	}
   882  
   883  	injector, err := n.NewInjector(t)
   884  	if err != nil {
   885  		t.Fatalf("can't make injector: %s", err)
   886  	}
   887  	sniffer, err := n.NewSniffer(t)
   888  	if err != nil {
   889  		t.Fatalf("can't make sniffer: %s", err)
   890  	}
   891  
   892  	return IPv4Conn{
   893  		Connection: Connection{
   894  			layerStates: []layerState{etherState, ipv4State},
   895  			injector:    injector,
   896  			sniffer:     sniffer,
   897  		},
   898  	}
   899  }
   900  
   901  // Send sends a frame with ipv4 overriding the IPv4 layer defaults and
   902  // additionalLayers added after it.
   903  func (c *IPv4Conn) Send(t *testing.T, ipv4 IPv4, additionalLayers ...Layer) {
   904  	t.Helper()
   905  
   906  	c.send(t, Layers{&ipv4}, additionalLayers...)
   907  }
   908  
// IPv6Conn maintains the state for all the layers in an IPv6 connection. It
// embeds Connection, so Connection's fields and methods are promoted.
type IPv6Conn struct {
	Connection
}
   913  
   914  // NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults.
   915  func (n *DUTTestNet) NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn {
   916  	t.Helper()
   917  
   918  	etherState, err := n.newEtherState(Ether{}, Ether{})
   919  	if err != nil {
   920  		t.Fatalf("can't make EtherState: %s", err)
   921  	}
   922  	ipv6State, err := n.newIPv6State(outgoingIPv6, incomingIPv6)
   923  	if err != nil {
   924  		t.Fatalf("can't make IPv6State: %s", err)
   925  	}
   926  
   927  	injector, err := n.NewInjector(t)
   928  	if err != nil {
   929  		t.Fatalf("can't make injector: %s", err)
   930  	}
   931  	sniffer, err := n.NewSniffer(t)
   932  	if err != nil {
   933  		t.Fatalf("can't make sniffer: %s", err)
   934  	}
   935  
   936  	return IPv6Conn{
   937  		Connection: Connection{
   938  			layerStates: []layerState{etherState, ipv6State},
   939  			injector:    injector,
   940  			sniffer:     sniffer,
   941  		},
   942  	}
   943  }
   944  
   945  // Send sends a frame with ipv6 overriding the IPv6 layer defaults and
   946  // additionalLayers added after it.
   947  func (conn *IPv6Conn) Send(t *testing.T, ipv6 IPv6, additionalLayers ...Layer) {
   948  	t.Helper()
   949  
   950  	conn.send(t, Layers{&ipv6}, additionalLayers...)
   951  }
   952  
// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. It
// embeds Connection, so Connection's fields and methods are promoted.
type UDPIPv4 struct {
	Connection
}
   957  
   958  // NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults.
   959  func (n *DUTTestNet) NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
   960  	t.Helper()
   961  
   962  	etherState, err := n.newEtherState(Ether{}, Ether{})
   963  	if err != nil {
   964  		t.Fatalf("can't make etherState: %s", err)
   965  	}
   966  	ipv4State, err := n.newIPv4State(IPv4{}, IPv4{})
   967  	if err != nil {
   968  		t.Fatalf("can't make ipv4State: %s", err)
   969  	}
   970  	udpState, err := n.newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
   971  	if err != nil {
   972  		t.Fatalf("can't make udpState: %s", err)
   973  	}
   974  	injector, err := n.NewInjector(t)
   975  	if err != nil {
   976  		t.Fatalf("can't make injector: %s", err)
   977  	}
   978  	sniffer, err := n.NewSniffer(t)
   979  	if err != nil {
   980  		t.Fatalf("can't make sniffer: %s", err)
   981  	}
   982  
   983  	return UDPIPv4{
   984  		Connection: Connection{
   985  			layerStates: []layerState{etherState, ipv4State, udpState},
   986  			injector:    injector,
   987  			sniffer:     sniffer,
   988  		},
   989  	}
   990  }
   991  
   992  func (conn *UDPIPv4) udpState(t *testing.T) *udpState {
   993  	t.Helper()
   994  
   995  	state, ok := conn.layerStates[2].(*udpState)
   996  	if !ok {
   997  		t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
   998  	}
   999  	return state
  1000  }
  1001  
  1002  func (conn *UDPIPv4) ipv4State(t *testing.T) *ipv4State {
  1003  	t.Helper()
  1004  
  1005  	state, ok := conn.layerStates[1].(*ipv4State)
  1006  	if !ok {
  1007  		t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1])
  1008  	}
  1009  	return state
  1010  }
  1011  
  1012  // LocalAddr gets the local socket address of this connection.
  1013  func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 {
  1014  	t.Helper()
  1015  
  1016  	sa := &unix.SockaddrInet4{Port: int(*conn.udpState(t).out.SrcPort)}
  1017  	copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr)
  1018  	return sa
  1019  }
  1020  
  1021  // SrcPort returns the source port of this connection.
  1022  func (conn *UDPIPv4) SrcPort(t *testing.T) uint16 {
  1023  	t.Helper()
  1024  
  1025  	return *conn.udpState(t).out.SrcPort
  1026  }
  1027  
  1028  // Send sends a packet with reasonable defaults, potentially overriding the UDP
  1029  // layer and adding additionLayers.
  1030  func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) {
  1031  	t.Helper()
  1032  
  1033  	conn.send(t, Layers{&udp}, additionalLayers...)
  1034  }
  1035  
  1036  // SendIP sends a packet with reasonable defaults, potentially overriding the
  1037  // UDP and IPv4 headers and adding additionLayers.
  1038  func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ...Layer) {
  1039  	t.Helper()
  1040  
  1041  	conn.send(t, Layers{&ip, &udp}, additionalLayers...)
  1042  }
  1043  
  1044  // SendFrame sends a frame on the wire and updates the state of all layers.
  1045  func (conn *UDPIPv4) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) {
  1046  	t.Helper()
  1047  
  1048  	conn.send(t, overrideLayers, additionalLayers...)
  1049  }
  1050  
  1051  // Expect expects a frame with the UDP layer matching the provided UDP within
  1052  // the timeout specified. If it doesn't arrive in time, an error is returned.
  1053  func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) {
  1054  	t.Helper()
  1055  
  1056  	layer, err := conn.Connection.Expect(t, &udp, timeout)
  1057  	if err != nil {
  1058  		return nil, err
  1059  	}
  1060  	gotUDP, ok := layer.(*UDP)
  1061  	if !ok {
  1062  		t.Fatalf("expected %s to be UDP", layer)
  1063  	}
  1064  	return gotUDP, nil
  1065  }
  1066  
  1067  // ExpectData is a convenient method that expects a Layer and the Layer after
  1068  // it. If it doesn't arrive in time, it returns nil.
  1069  func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
  1070  	t.Helper()
  1071  
  1072  	expected := make([]Layer, len(conn.layerStates))
  1073  	expected[len(expected)-1] = &udp
  1074  	if payload.length() != 0 {
  1075  		expected = append(expected, &payload)
  1076  	}
  1077  	return conn.ExpectFrame(t, expected, timeout)
  1078  }
  1079  
// UDPIPv6 maintains the state for all the layers in a UDP/IPv6 connection. It
// embeds Connection, so Connection's fields and methods are promoted.
type UDPIPv6 struct {
	Connection
}
  1084  
  1085  // NewUDPIPv6 creates a new UDPIPv6 connection with reasonable defaults.
  1086  func (n *DUTTestNet) NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 {
  1087  	t.Helper()
  1088  
  1089  	etherState, err := n.newEtherState(Ether{}, Ether{})
  1090  	if err != nil {
  1091  		t.Fatalf("can't make etherState: %s", err)
  1092  	}
  1093  	ipv6State, err := n.newIPv6State(IPv6{}, IPv6{})
  1094  	if err != nil {
  1095  		t.Fatalf("can't make IPv6State: %s", err)
  1096  	}
  1097  	udpState, err := n.newUDPState(unix.AF_INET6, outgoingUDP, incomingUDP)
  1098  	if err != nil {
  1099  		t.Fatalf("can't make udpState: %s", err)
  1100  	}
  1101  	injector, err := n.NewInjector(t)
  1102  	if err != nil {
  1103  		t.Fatalf("can't make injector: %s", err)
  1104  	}
  1105  	sniffer, err := n.NewSniffer(t)
  1106  	if err != nil {
  1107  		t.Fatalf("can't make sniffer: %s", err)
  1108  	}
  1109  	return UDPIPv6{
  1110  		Connection: Connection{
  1111  			layerStates: []layerState{etherState, ipv6State, udpState},
  1112  			injector:    injector,
  1113  			sniffer:     sniffer,
  1114  		},
  1115  	}
  1116  }
  1117  
  1118  func (conn *UDPIPv6) udpState(t *testing.T) *udpState {
  1119  	t.Helper()
  1120  
  1121  	state, ok := conn.layerStates[2].(*udpState)
  1122  	if !ok {
  1123  		t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
  1124  	}
  1125  	return state
  1126  }
  1127  
  1128  func (conn *UDPIPv6) ipv6State(t *testing.T) *ipv6State {
  1129  	t.Helper()
  1130  
  1131  	state, ok := conn.layerStates[1].(*ipv6State)
  1132  	if !ok {
  1133  		t.Fatalf("got network-layer state type=%T, expected ipv6State", conn.layerStates[1])
  1134  	}
  1135  	return state
  1136  }
  1137  
  1138  // LocalAddr gets the local socket address of this connection.
  1139  func (conn *UDPIPv6) LocalAddr(t *testing.T, zoneID uint32) *unix.SockaddrInet6 {
  1140  	t.Helper()
  1141  
  1142  	sa := &unix.SockaddrInet6{
  1143  		Port: int(*conn.udpState(t).out.SrcPort),
  1144  		// Local address is in perspective to the remote host, so it's scoped to the
  1145  		// ID of the remote interface.
  1146  		ZoneId: zoneID,
  1147  	}
  1148  	copy(sa.Addr[:], *conn.ipv6State(t).out.SrcAddr)
  1149  	return sa
  1150  }
  1151  
  1152  // SrcPort returns the source port of this connection.
  1153  func (conn *UDPIPv6) SrcPort(t *testing.T) uint16 {
  1154  	t.Helper()
  1155  
  1156  	return *conn.udpState(t).out.SrcPort
  1157  }
  1158  
  1159  // Send sends a packet with reasonable defaults, potentially overriding the UDP
  1160  // layer and adding additionLayers.
  1161  func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) {
  1162  	t.Helper()
  1163  
  1164  	conn.send(t, Layers{&udp}, additionalLayers...)
  1165  }
  1166  
  1167  // SendIPv6 sends a packet with reasonable defaults, potentially overriding the
  1168  // UDP and IPv6 headers and adding additionLayers.
  1169  func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers ...Layer) {
  1170  	t.Helper()
  1171  
  1172  	conn.send(t, Layers{&ip, &udp}, additionalLayers...)
  1173  }
  1174  
  1175  // SendFrame sends a frame on the wire and updates the state of all layers.
  1176  func (conn *UDPIPv6) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) {
  1177  	conn.send(t, overrideLayers, additionalLayers...)
  1178  }
  1179  
  1180  // Expect expects a frame with the UDP layer matching the provided UDP within
  1181  // the timeout specified. If it doesn't arrive in time, an error is returned.
  1182  func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) {
  1183  	t.Helper()
  1184  
  1185  	layer, err := conn.Connection.Expect(t, &udp, timeout)
  1186  	if err != nil {
  1187  		return nil, err
  1188  	}
  1189  	gotUDP, ok := layer.(*UDP)
  1190  	if !ok {
  1191  		t.Fatalf("expected %s to be UDP", layer)
  1192  	}
  1193  	return gotUDP, nil
  1194  }
  1195  
  1196  // ExpectData is a convenient method that expects a Layer and the Layer after
  1197  // it. If it doesn't arrive in time, it returns nil.
  1198  func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
  1199  	t.Helper()
  1200  
  1201  	expected := make([]Layer, len(conn.layerStates))
  1202  	expected[len(expected)-1] = &udp
  1203  	if payload.length() != 0 {
  1204  		expected = append(expected, &payload)
  1205  	}
  1206  	return conn.ExpectFrame(t, expected, timeout)
  1207  }
  1208  
// TCPIPv6 maintains the state for all the layers in a TCP/IPv6 connection. It
// embeds Connection, so Connection's fields and methods are promoted.
type TCPIPv6 struct {
	Connection
}
  1213  
  1214  // NewTCPIPv6 creates a new TCPIPv6 connection with reasonable defaults.
  1215  func (n *DUTTestNet) NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 {
  1216  	etherState, err := n.newEtherState(Ether{}, Ether{})
  1217  	if err != nil {
  1218  		t.Fatalf("can't make etherState: %s", err)
  1219  	}
  1220  	ipv6State, err := n.newIPv6State(IPv6{}, IPv6{})
  1221  	if err != nil {
  1222  		t.Fatalf("can't make ipv6State: %s", err)
  1223  	}
  1224  	tcpState, err := n.newTCPState(unix.AF_INET6, outgoingTCP, incomingTCP)
  1225  	if err != nil {
  1226  		t.Fatalf("can't make tcpState: %s", err)
  1227  	}
  1228  	injector, err := n.NewInjector(t)
  1229  	if err != nil {
  1230  		t.Fatalf("can't make injector: %s", err)
  1231  	}
  1232  	sniffer, err := n.NewSniffer(t)
  1233  	if err != nil {
  1234  		t.Fatalf("can't make sniffer: %s", err)
  1235  	}
  1236  
  1237  	return TCPIPv6{
  1238  		Connection: Connection{
  1239  			layerStates: []layerState{etherState, ipv6State, tcpState},
  1240  			injector:    injector,
  1241  			sniffer:     sniffer,
  1242  		},
  1243  	}
  1244  }
  1245  
  1246  // SrcPort returns the source port from the given Connection.
  1247  func (conn *TCPIPv6) SrcPort() uint16 {
  1248  	state := conn.layerStates[2].(*tcpState)
  1249  	return *state.out.SrcPort
  1250  }
  1251  
  1252  // ExpectData is a convenient method that expects a Layer and the Layer after
  1253  // it. If it doesn't arrive in time, it returns nil.
  1254  func (conn *TCPIPv6) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
  1255  	t.Helper()
  1256  
  1257  	expected := make([]Layer, len(conn.layerStates))
  1258  	expected[len(expected)-1] = tcp
  1259  	if payload != nil {
  1260  		expected = append(expected, payload)
  1261  	}
  1262  	return conn.ExpectFrame(t, expected, timeout)
  1263  }