gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/test/packetimpact/testbench/layers.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package testbench
    16  
    17  import (
    18  	"encoding/binary"
    19  	"encoding/hex"
    20  	"fmt"
    21  	"net"
    22  	"reflect"
    23  	"strings"
    24  
    25  	"github.com/google/go-cmp/cmp"
    26  	"github.com/google/go-cmp/cmp/cmpopts"
    27  	"go.uber.org/multierr"
    28  	"gvisor.dev/gvisor/pkg/buffer"
    29  	"gvisor.dev/gvisor/pkg/tcpip"
    30  	"gvisor.dev/gvisor/pkg/tcpip/checksum"
    31  	"gvisor.dev/gvisor/pkg/tcpip/header"
    32  )
    33  
    34  // Layer is the interface that all encapsulations must implement.
    35  //
    36  // A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A
    37  // Layer contains all the fields of the encapsulation. Each field is a pointer
    38  // and may be nil.
    39  type Layer interface {
    40  	fmt.Stringer
    41  
    42  	// ToBytes converts the Layer into bytes. In places where the Layer's field
    43  	// isn't nil, the value that is pointed to is used. When the field is nil, a
    44  	// reasonable default for the Layer is used. For example, "64" for IPv4 TTL
    45  	// and a calculated checksum for TCP or IP. Some layers require information
    46  	// from the previous or next layers in order to compute a default, such as
    47  	// TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list
    48  	// to the layer's neighbors.
    49  	ToBytes() ([]byte, error)
    50  
    51  	// match checks if the current Layer matches the provided Layer. If either
    52  	// Layer has a nil in a given field, that field is considered matching.
    53  	// Otherwise, the values pointed to by the fields must match. The LayerBase is
    54  	// ignored.
    55  	match(Layer) bool
    56  
    57  	// length in bytes of the current encapsulation
    58  	length() int
    59  
    60  	// next gets a pointer to the encapsulated Layer.
    61  	next() Layer
    62  
    63  	// prev gets a pointer to the Layer encapsulating this one.
    64  	Prev() Layer
    65  
    66  	// setNext sets the pointer to the encapsulated Layer.
    67  	setNext(Layer)
    68  
    69  	// setPrev sets the pointer to the Layer encapsulating this one.
    70  	setPrev(Layer)
    71  
    72  	// merge overrides the values in the interface with the provided values.
    73  	merge(Layer) error
    74  }
    75  
    76  // LayerBase is the common elements of all layers.
    77  type LayerBase struct {
    78  	nextLayer Layer
    79  	prevLayer Layer
    80  }
    81  
    82  func (lb *LayerBase) next() Layer {
    83  	return lb.nextLayer
    84  }
    85  
    86  // Prev returns the previous layer.
    87  func (lb *LayerBase) Prev() Layer {
    88  	return lb.prevLayer
    89  }
    90  
    91  func (lb *LayerBase) setNext(l Layer) {
    92  	lb.nextLayer = l
    93  }
    94  
    95  func (lb *LayerBase) setPrev(l Layer) {
    96  	lb.prevLayer = l
    97  }
    98  
    99  // equalLayer compares that two Layer structs match while ignoring field in
   100  // which either input has a nil and also ignoring the LayerBase of the inputs.
   101  func equalLayer(x, y Layer) bool {
   102  	if x == nil || y == nil {
   103  		return true
   104  	}
   105  	// opt ignores comparison pairs where either of the inputs is a nil.
   106  	opt := cmp.FilterValues(func(x, y any) bool {
   107  		for _, l := range []any{x, y} {
   108  			v := reflect.ValueOf(l)
   109  			if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() {
   110  				return true
   111  			}
   112  		}
   113  		return false
   114  	}, cmp.Ignore())
   115  	return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{}))
   116  }
   117  
   118  // mergeLayer merges y into x. Any fields for which y has a non-nil value, that
   119  // value overwrite the corresponding fields in x.
   120  func mergeLayer(x, y Layer) error {
   121  	if y == nil {
   122  		return nil
   123  	}
   124  	if reflect.TypeOf(x) != reflect.TypeOf(y) {
   125  		return fmt.Errorf("can't merge %T into %T", y, x)
   126  	}
   127  	vx := reflect.ValueOf(x).Elem()
   128  	vy := reflect.ValueOf(y).Elem()
   129  	t := vy.Type()
   130  	for i := 0; i < vy.NumField(); i++ {
   131  		t := t.Field(i)
   132  		if t.Anonymous {
   133  			// Ignore the LayerBase in the Layer struct.
   134  			continue
   135  		}
   136  		v := vy.Field(i)
   137  		if v.IsNil() {
   138  			continue
   139  		}
   140  		vx.Field(i).Set(v)
   141  	}
   142  	return nil
   143  }
   144  
   145  func stringLayer(l Layer) string {
   146  	v := reflect.ValueOf(l).Elem()
   147  	t := v.Type()
   148  	var ret []string
   149  	for i := 0; i < v.NumField(); i++ {
   150  		t := t.Field(i)
   151  		if t.Anonymous {
   152  			// Ignore the LayerBase in the Layer struct.
   153  			continue
   154  		}
   155  		v := v.Field(i)
   156  		if v.IsNil() {
   157  			continue
   158  		}
   159  		v = reflect.Indirect(v)
   160  		switch {
   161  		// Try to use Stringers appropriately.
   162  		case v.Type().Implements(reflect.TypeOf((*fmt.Stringer)(nil)).Elem()):
   163  			ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v))
   164  		// Print byte slices as hex.
   165  		case v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8:
   166  			ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes())))
   167  		// Otherwise just let Go decide how to print.
   168  		default:
   169  			ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v))
   170  		}
   171  	}
   172  	return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " "))
   173  }
   174  
// Ether can construct and match an ethernet encapsulation.
//
// All fields are optional. In ToBytes, a nil SrcAddr or DstAddr leaves that
// address at its zero value, and a nil Type is inferred from the next layer
// (*IPv4 or *IPv6). In match, nil fields on either side act as wildcards.
// String prints the non-nil fields in declaration order.
type Ether struct {
	LayerBase
	SrcAddr *tcpip.LinkAddress
	DstAddr *tcpip.LinkAddress
	Type    *tcpip.NetworkProtocolNumber
}

// String implements fmt.Stringer by listing the non-nil fields.
func (l *Ether) String() string {
	return stringLayer(l)
}
   186  
   187  // ToBytes implements Layer.ToBytes.
   188  func (l *Ether) ToBytes() ([]byte, error) {
   189  	b := make([]byte, header.EthernetMinimumSize)
   190  	h := header.Ethernet(b)
   191  	fields := &header.EthernetFields{}
   192  	if l.SrcAddr != nil {
   193  		fields.SrcAddr = *l.SrcAddr
   194  	}
   195  	if l.DstAddr != nil {
   196  		fields.DstAddr = *l.DstAddr
   197  	}
   198  	if l.Type != nil {
   199  		fields.Type = *l.Type
   200  	} else {
   201  		switch n := l.next().(type) {
   202  		case *IPv4:
   203  			fields.Type = header.IPv4ProtocolNumber
   204  		case *IPv6:
   205  			fields.Type = header.IPv6ProtocolNumber
   206  		default:
   207  			return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n)
   208  		}
   209  	}
   210  	h.Encode(fields)
   211  	return h, nil
   212  }
   213  
   214  // LinkAddress is a helper routine that allocates a new tcpip.LinkAddress value
   215  // to store v and returns a pointer to it.
   216  func LinkAddress(v tcpip.LinkAddress) *tcpip.LinkAddress {
   217  	return &v
   218  }
   219  
   220  // NetworkProtocolNumber is a helper routine that allocates a new
   221  // tcpip.NetworkProtocolNumber value to store v and returns a pointer to it.
   222  func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocolNumber {
   223  	return &v
   224  }
   225  
   226  // bodySizeHint describes num of bytes left to parse for the rest of layers.
   227  type bodySizeHint int
   228  
   229  const bodySizeUnknown bodySizeHint = -1
   230  
   231  // layerParser parses the input bytes and returns a Layer along with the next
   232  // layerParser to run. If there is no more parsing to do, the returned
   233  // layerParser is nil.
   234  type layerParser func([]byte) (Layer, bodySizeHint, layerParser)
   235  
   236  // parse parses bytes starting with the first layerParser and using successive
   237  // layerParsers until all the bytes are parsed.
   238  func parse(parser layerParser, b []byte) Layers {
   239  	var layers Layers
   240  	for {
   241  		layer, hint, next := parser(b)
   242  		layers = append(layers, layer)
   243  		if parser == nil {
   244  			break
   245  		}
   246  		b = b[layer.length():]
   247  		if hint != bodySizeUnknown {
   248  			b = b[:hint]
   249  		}
   250  		if next == nil {
   251  			break
   252  		}
   253  		parser = next
   254  	}
   255  	layers.linkLayers()
   256  	return layers
   257  }
   258  
   259  // parseEther parses the bytes assuming that they start with an ethernet header
   260  // and continues parsing further encapsulations.
   261  func parseEther(b []byte) (Layer, bodySizeHint, layerParser) {
   262  	h := header.Ethernet(b)
   263  	ether := Ether{
   264  		SrcAddr: LinkAddress(h.SourceAddress()),
   265  		DstAddr: LinkAddress(h.DestinationAddress()),
   266  		Type:    NetworkProtocolNumber(h.Type()),
   267  	}
   268  	var nextParser layerParser
   269  	switch h.Type() {
   270  	case header.IPv4ProtocolNumber:
   271  		nextParser = parseIPv4
   272  	case header.IPv6ProtocolNumber:
   273  		nextParser = parseIPv6
   274  	default:
   275  		// Assume that the rest is a payload.
   276  		nextParser = parsePayload
   277  	}
   278  	return &ether, bodySizeUnknown, nextParser
   279  }
   280  
   281  func (l *Ether) match(other Layer) bool {
   282  	return equalLayer(l, other)
   283  }
   284  
   285  func (l *Ether) length() int {
   286  	return header.EthernetMinimumSize
   287  }
   288  
   289  // merge implements Layer.merge.
   290  func (l *Ether) merge(other Layer) error {
   291  	return mergeLayer(l, other)
   292  }
   293  
// IPv4 can construct and match an IPv4 encapsulation.
//
// Every field is optional: nil means "use a default" in ToBytes and "match
// anything" in match. Field notes:
//   - IHL is the header length in bytes (not 32-bit words), as filled in by
//     parseIPv4 from header.IPv4.HeaderLength. When set, ToBytes writes it
//     verbatim after encoding, which allows emitting bogus values.
//   - TotalLength, Protocol and Checksum are computed by ToBytes when nil;
//     TTL defaults to 64.
//   - Options are raw option bytes copied verbatim into the header; their
//     length must be a multiple of 4.
//   - A nil or empty SrcAddr/DstAddr is left as the zero tcpip.Address.
type IPv4 struct {
	LayerBase
	IHL            *uint8
	TOS            *uint8
	TotalLength    *uint16
	ID             *uint16
	Flags          *uint8
	FragmentOffset *uint16
	TTL            *uint8
	Protocol       *uint8
	Checksum       *uint16
	SrcAddr        *net.IP
	DstAddr        *net.IP
	Options        *header.IPv4Options
}

// String implements fmt.Stringer by listing the non-nil fields.
func (l *IPv4) String() string {
	return stringLayer(l)
}
   314  
   315  // ToBytes implements Layer.ToBytes.
   316  func (l *IPv4) ToBytes() ([]byte, error) {
   317  	// An IPv4 header is variable length depending on the size of the Options.
   318  	hdrLen := header.IPv4MinimumSize
   319  	if l.Options != nil {
   320  		if len(*l.Options)%4 != 0 {
   321  			return nil, fmt.Errorf("invalid header options '%x (len=%d)'; must be 32 bit aligned", *l.Options, len(*l.Options))
   322  		}
   323  		hdrLen += len(*l.Options)
   324  		if hdrLen > header.IPv4MaximumHeaderSize {
   325  			return nil, fmt.Errorf("IPv4 Options %d bytes, Max %d", len(*l.Options), header.IPv4MaximumOptionsSize)
   326  		}
   327  	}
   328  	b := make([]byte, hdrLen)
   329  	h := header.IPv4(b)
   330  	fields := &header.IPv4Fields{
   331  		TOS:            0,
   332  		TotalLength:    0,
   333  		ID:             0,
   334  		Flags:          0,
   335  		FragmentOffset: 0,
   336  		TTL:            64,
   337  		Protocol:       0,
   338  		Checksum:       0,
   339  		SrcAddr:        tcpip.Address{},
   340  		DstAddr:        tcpip.Address{},
   341  		Options:        nil,
   342  	}
   343  	if l.TOS != nil {
   344  		fields.TOS = *l.TOS
   345  	}
   346  	if l.TotalLength != nil {
   347  		fields.TotalLength = *l.TotalLength
   348  	} else {
   349  		fields.TotalLength = uint16(l.length())
   350  		current := l.next()
   351  		for current != nil {
   352  			fields.TotalLength += uint16(current.length())
   353  			current = current.next()
   354  		}
   355  	}
   356  	if l.ID != nil {
   357  		fields.ID = *l.ID
   358  	}
   359  	if l.Flags != nil {
   360  		fields.Flags = *l.Flags
   361  	}
   362  	if l.FragmentOffset != nil {
   363  		fields.FragmentOffset = *l.FragmentOffset
   364  	}
   365  	if l.TTL != nil {
   366  		fields.TTL = *l.TTL
   367  	}
   368  	if l.Protocol != nil {
   369  		fields.Protocol = *l.Protocol
   370  	} else {
   371  		switch n := l.next().(type) {
   372  		case *TCP:
   373  			fields.Protocol = uint8(header.TCPProtocolNumber)
   374  		case *UDP:
   375  			fields.Protocol = uint8(header.UDPProtocolNumber)
   376  		case *ICMPv4:
   377  			fields.Protocol = uint8(header.ICMPv4ProtocolNumber)
   378  		default:
   379  			// We can add support for more protocols as needed.
   380  			return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n)
   381  		}
   382  	}
   383  	if l.SrcAddr != nil && len(*l.SrcAddr) > 0 {
   384  		fields.SrcAddr = tcpip.AddrFrom4Slice(*l.SrcAddr)
   385  	}
   386  	if l.DstAddr != nil && len(*l.DstAddr) > 0 {
   387  		fields.DstAddr = tcpip.AddrFrom4Slice(*l.DstAddr)
   388  	}
   389  
   390  	h.Encode(fields)
   391  
   392  	// Put raw option bytes from test definition in header. Options as raw bytes
   393  	// allows us to serialize malformed options, which is not possible with
   394  	// the provided serialization functions.
   395  	if l.Options != nil {
   396  		h.SetHeaderLength(h.HeaderLength() + uint8(len(*l.Options)))
   397  		if got, want := copy(h.Options(), *l.Options), len(*l.Options); got != want {
   398  			return nil, fmt.Errorf("failed to copy option bytes into header, got %d want %d", got, want)
   399  		}
   400  	}
   401  
   402  	// Encode cannot set this incorrectly so we need to overwrite what it wrote
   403  	// in order to test handling of a bad IHL value.
   404  	if l.IHL != nil {
   405  		h.SetHeaderLength(*l.IHL)
   406  	}
   407  
   408  	if l.Checksum == nil {
   409  		h.SetChecksum(^h.CalculateChecksum())
   410  	} else {
   411  		h.SetChecksum(*l.Checksum)
   412  	}
   413  
   414  	return h, nil
   415  }
   416  
   417  // Uint16 is a helper routine that allocates a new
   418  // uint16 value to store v and returns a pointer to it.
   419  func Uint16(v uint16) *uint16 {
   420  	return &v
   421  }
   422  
   423  // Uint8 is a helper routine that allocates a new
   424  // uint8 value to store v and returns a pointer to it.
   425  func Uint8(v uint8) *uint8 {
   426  	return &v
   427  }
   428  
   429  // TCPFlags is a helper routine that allocates a new
   430  // header.TCPFlags value to store v and returns a pointer to it.
   431  func TCPFlags(v header.TCPFlags) *header.TCPFlags {
   432  	return &v
   433  }
   434  
   435  // Address is a helper routine that allocates a new net.IP value to
   436  // store v and returns a pointer to it.
   437  func Address(v tcpip.Address) *net.IP {
   438  	bs := make([]byte, v.Len())
   439  	copy(bs, v.AsSlice())
   440  	ret := net.IP(bs)
   441  	return &ret
   442  }
   443  
   444  // parseIPv4 parses the bytes assuming that they start with an ipv4 header and
   445  // continues parsing further encapsulations.
   446  func parseIPv4(b []byte) (Layer, bodySizeHint, layerParser) {
   447  	h := header.IPv4(b)
   448  	options := h.Options()
   449  	tos, _ := h.TOS()
   450  	ipv4 := IPv4{
   451  		IHL:            Uint8(h.HeaderLength()),
   452  		TOS:            &tos,
   453  		TotalLength:    Uint16(h.TotalLength()),
   454  		ID:             Uint16(h.ID()),
   455  		Flags:          Uint8(h.Flags()),
   456  		FragmentOffset: Uint16(h.FragmentOffset()),
   457  		TTL:            Uint8(h.TTL()),
   458  		Protocol:       Uint8(h.Protocol()),
   459  		Checksum:       Uint16(h.Checksum()),
   460  		SrcAddr:        Address(h.SourceAddress()),
   461  		DstAddr:        Address(h.DestinationAddress()),
   462  		Options:        &options,
   463  	}
   464  	var nextParser layerParser
   465  	// If it is a fragment, don't treat it as having a transport protocol.
   466  	if h.FragmentOffset() != 0 || h.More() {
   467  		return &ipv4, bodySizeHint(h.PayloadLength()), parsePayload
   468  	}
   469  	switch h.TransportProtocol() {
   470  	case header.TCPProtocolNumber:
   471  		nextParser = parseTCP
   472  	case header.UDPProtocolNumber:
   473  		nextParser = parseUDP
   474  	case header.ICMPv4ProtocolNumber:
   475  		nextParser = parseICMPv4
   476  	default:
   477  		// Assume that the rest is a payload.
   478  		nextParser = parsePayload
   479  	}
   480  	return &ipv4, bodySizeHint(h.PayloadLength()), nextParser
   481  }
   482  
   483  func (l *IPv4) match(other Layer) bool {
   484  	return equalLayer(l, other)
   485  }
   486  
   487  func (l *IPv4) length() int {
   488  	if l.IHL == nil {
   489  		return header.IPv4MinimumSize
   490  	}
   491  	return int(*l.IHL)
   492  }
   493  
   494  // merge implements Layer.merge.
   495  func (l *IPv4) merge(other Layer) error {
   496  	return mergeLayer(l, other)
   497  }
   498  
// IPv6 can construct and match an IPv6 encapsulation.
//
// Every field is optional: nil means "use a default" in ToBytes and "match
// anything" in match. In ToBytes, PayloadLength defaults to the summed length
// of all following layers, NextHeader is derived from the next layer via
// nextHeaderByLayer, and HopLimit defaults to 64. A nil or empty
// SrcAddr/DstAddr is left as the zero tcpip.Address. Extension headers are
// modeled as separate layers that follow this one.
type IPv6 struct {
	LayerBase
	TrafficClass  *uint8
	FlowLabel     *uint32
	PayloadLength *uint16
	NextHeader    *uint8
	HopLimit      *uint8
	SrcAddr       *net.IP
	DstAddr       *net.IP
}

// String implements fmt.Stringer by listing the non-nil fields.
func (l *IPv6) String() string {
	return stringLayer(l)
}
   514  
   515  // ToBytes implements Layer.ToBytes.
   516  func (l *IPv6) ToBytes() ([]byte, error) {
   517  	b := make([]byte, header.IPv6MinimumSize)
   518  	h := header.IPv6(b)
   519  	fields := &header.IPv6Fields{
   520  		HopLimit: 64,
   521  	}
   522  	if l.TrafficClass != nil {
   523  		fields.TrafficClass = *l.TrafficClass
   524  	}
   525  	if l.FlowLabel != nil {
   526  		fields.FlowLabel = *l.FlowLabel
   527  	}
   528  	if l.PayloadLength != nil {
   529  		fields.PayloadLength = *l.PayloadLength
   530  	} else {
   531  		for current := l.next(); current != nil; current = current.next() {
   532  			fields.PayloadLength += uint16(current.length())
   533  		}
   534  	}
   535  	if l.NextHeader != nil {
   536  		fields.TransportProtocol = tcpip.TransportProtocolNumber(*l.NextHeader)
   537  	} else {
   538  		nh, err := nextHeaderByLayer(l.next())
   539  		if err != nil {
   540  			return nil, err
   541  		}
   542  		fields.TransportProtocol = tcpip.TransportProtocolNumber(nh)
   543  	}
   544  	if l.HopLimit != nil {
   545  		fields.HopLimit = *l.HopLimit
   546  	}
   547  	if l.SrcAddr != nil && len(*l.SrcAddr) > 0 {
   548  		fields.SrcAddr = tcpip.AddrFrom16Slice(*l.SrcAddr)
   549  	}
   550  	if l.DstAddr != nil && len(*l.DstAddr) > 0 {
   551  		fields.DstAddr = tcpip.AddrFrom16Slice(*l.DstAddr)
   552  	}
   553  	h.Encode(fields)
   554  	return h, nil
   555  }
   556  
   557  // nextIPv6PayloadParser finds the corresponding parser for nextHeader.
   558  func nextIPv6PayloadParser(nextHeader uint8) layerParser {
   559  	switch tcpip.TransportProtocolNumber(nextHeader) {
   560  	case header.TCPProtocolNumber:
   561  		return parseTCP
   562  	case header.UDPProtocolNumber:
   563  		return parseUDP
   564  	case header.ICMPv6ProtocolNumber:
   565  		return parseICMPv6
   566  	}
   567  	switch header.IPv6ExtensionHeaderIdentifier(nextHeader) {
   568  	case header.IPv6HopByHopOptionsExtHdrIdentifier:
   569  		return parseIPv6HopByHopOptionsExtHdr
   570  	case header.IPv6DestinationOptionsExtHdrIdentifier:
   571  		return parseIPv6DestinationOptionsExtHdr
   572  	case header.IPv6FragmentExtHdrIdentifier:
   573  		return parseIPv6FragmentExtHdr
   574  	}
   575  	return parsePayload
   576  }
   577  
   578  // parseIPv6 parses the bytes assuming that they start with an ipv6 header and
   579  // continues parsing further encapsulations.
   580  func parseIPv6(b []byte) (Layer, bodySizeHint, layerParser) {
   581  	h := header.IPv6(b)
   582  	tos, flowLabel := h.TOS()
   583  	ipv6 := IPv6{
   584  		TrafficClass:  &tos,
   585  		FlowLabel:     &flowLabel,
   586  		PayloadLength: Uint16(h.PayloadLength()),
   587  		NextHeader:    Uint8(h.NextHeader()),
   588  		HopLimit:      Uint8(h.HopLimit()),
   589  		SrcAddr:       Address(h.SourceAddress()),
   590  		DstAddr:       Address(h.DestinationAddress()),
   591  	}
   592  	nextParser := nextIPv6PayloadParser(h.NextHeader())
   593  	return &ipv6, bodySizeHint(h.PayloadLength()), nextParser
   594  }
   595  
   596  func (l *IPv6) match(other Layer) bool {
   597  	return equalLayer(l, other)
   598  }
   599  
   600  func (l *IPv6) length() int {
   601  	return header.IPv6MinimumSize
   602  }
   603  
   604  // merge overrides the values in l with the values from other but only in fields
   605  // where the value is not nil.
   606  func (l *IPv6) merge(other Layer) error {
   607  	return mergeLayer(l, other)
   608  }
   609  
// IPv6HopByHopOptionsExtHdr can construct and match an IPv6HopByHopOptions
// Extension Header.
//
// When NextHeader is nil, ToBytes derives it from the following layer with
// nextHeaderByLayer. Options holds the raw bytes after the Next Header and Hdr
// Ext Len octets; ToBytes requires len(Options)+2 to be a multiple of 8.
type IPv6HopByHopOptionsExtHdr struct {
	LayerBase
	NextHeader *header.IPv6ExtensionHeaderIdentifier
	Options    []byte
}

// IPv6DestinationOptionsExtHdr can construct and match an IPv6DestinationOptions
// Extension Header.
//
// When NextHeader is nil, ToBytes derives it from the following layer with
// nextHeaderByLayer. Options holds the raw bytes after the Next Header and Hdr
// Ext Len octets; ToBytes requires len(Options)+2 to be a multiple of 8.
type IPv6DestinationOptionsExtHdr struct {
	LayerBase
	NextHeader *header.IPv6ExtensionHeaderIdentifier
	Options    []byte
}

// IPv6FragmentExtHdr can construct and match an IPv6 Fragment Extension Header.
//
// When NextHeader is nil, ToBytes derives it from the following layer. ToBytes
// writes FragmentOffset shifted left by 3 with MoreFragments in the lowest bit,
// followed by Identification as a big-endian 32-bit value. Other nil fields
// encode as zero.
type IPv6FragmentExtHdr struct {
	LayerBase
	NextHeader     *header.IPv6ExtensionHeaderIdentifier
	FragmentOffset *uint16
	MoreFragments  *bool
	Identification *uint32
}
   634  
   635  // nextHeaderByLayer finds the correct next header protocol value for layer l.
   636  func nextHeaderByLayer(l Layer) (uint8, error) {
   637  	if l == nil {
   638  		return uint8(header.IPv6NoNextHeaderIdentifier), nil
   639  	}
   640  	switch l.(type) {
   641  	case *TCP:
   642  		return uint8(header.TCPProtocolNumber), nil
   643  	case *UDP:
   644  		return uint8(header.UDPProtocolNumber), nil
   645  	case *ICMPv6:
   646  		return uint8(header.ICMPv6ProtocolNumber), nil
   647  	case *Payload:
   648  		return uint8(header.IPv6NoNextHeaderIdentifier), nil
   649  	case *IPv6HopByHopOptionsExtHdr:
   650  		return uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), nil
   651  	case *IPv6DestinationOptionsExtHdr:
   652  		return uint8(header.IPv6DestinationOptionsExtHdrIdentifier), nil
   653  	case *IPv6FragmentExtHdr:
   654  		return uint8(header.IPv6FragmentExtHdrIdentifier), nil
   655  	default:
   656  		// TODO(b/161005083): Support more protocols as needed.
   657  		return 0, fmt.Errorf("failed to deduce the IPv6 header's next protocol: %T", l)
   658  	}
   659  }
   660  
   661  // ipv6OptionsExtHdrToBytes serializes an options extension header into bytes.
   662  func ipv6OptionsExtHdrToBytes(nextHeader *header.IPv6ExtensionHeaderIdentifier, nextLayer Layer, options []byte) ([]byte, error) {
   663  	length := len(options) + 2
   664  	if length%8 != 0 {
   665  		return nil, fmt.Errorf("IPv6 extension headers must be a multiple of 8 octets long, but the length given: %d, options: %s", length, hex.Dump(options))
   666  	}
   667  	bytes := make([]byte, length)
   668  	if nextHeader != nil {
   669  		bytes[0] = byte(*nextHeader)
   670  	} else {
   671  		nh, err := nextHeaderByLayer(nextLayer)
   672  		if err != nil {
   673  			return nil, err
   674  		}
   675  		bytes[0] = nh
   676  	}
   677  	// ExtHdrLen field is the length of the extension header
   678  	// in 8-octet unit, ignoring the first 8 octets.
   679  	// https://tools.ietf.org/html/rfc2460#section-4.3
   680  	// https://tools.ietf.org/html/rfc2460#section-4.6
   681  	bytes[1] = uint8((length - 8) / 8)
   682  	copy(bytes[2:], options)
   683  	return bytes, nil
   684  }
   685  
   686  // IPv6ExtHdrIdent is a helper routine that allocates a new
   687  // header.IPv6ExtensionHeaderIdentifier value to store v and returns a pointer
   688  // to it.
   689  func IPv6ExtHdrIdent(id header.IPv6ExtensionHeaderIdentifier) *header.IPv6ExtensionHeaderIdentifier {
   690  	return &id
   691  }
   692  
   693  // ToBytes implements Layer.ToBytes.
   694  func (l *IPv6HopByHopOptionsExtHdr) ToBytes() ([]byte, error) {
   695  	return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options)
   696  }
   697  
   698  // ToBytes implements Layer.ToBytes.
   699  func (l *IPv6DestinationOptionsExtHdr) ToBytes() ([]byte, error) {
   700  	return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options)
   701  }
   702  
   703  // ToBytes implements Layer.ToBytes.
   704  func (l *IPv6FragmentExtHdr) ToBytes() ([]byte, error) {
   705  	var offset, mflag uint16
   706  	var ident uint32
   707  	bytes := make([]byte, header.IPv6FragmentExtHdrLength)
   708  	if l.NextHeader != nil {
   709  		bytes[0] = byte(*l.NextHeader)
   710  	} else {
   711  		nh, err := nextHeaderByLayer(l.next())
   712  		if err != nil {
   713  			return nil, err
   714  		}
   715  		bytes[0] = nh
   716  	}
   717  	bytes[1] = 0 // reserved
   718  	if l.MoreFragments != nil && *l.MoreFragments {
   719  		mflag = 1
   720  	}
   721  	if l.FragmentOffset != nil {
   722  		offset = *l.FragmentOffset
   723  	}
   724  	if l.Identification != nil {
   725  		ident = *l.Identification
   726  	}
   727  	offsetAndMflag := offset<<3 | mflag
   728  	binary.BigEndian.PutUint16(bytes[2:], offsetAndMflag)
   729  	binary.BigEndian.PutUint32(bytes[4:], ident)
   730  
   731  	return bytes, nil
   732  }
   733  
   734  // parseIPv6ExtHdr parses an IPv6 extension header and returns the NextHeader
   735  // field, the rest of the payload and a parser function for the corresponding
   736  // next extension header.
   737  func parseIPv6ExtHdr(b []byte) (header.IPv6ExtensionHeaderIdentifier, []byte, layerParser) {
   738  	nextHeader := b[0]
   739  	// For HopByHop and Destination options extension headers,
   740  	// This field is the length of the extension header in
   741  	// 8-octet units, not including the first 8 octets.
   742  	// https://tools.ietf.org/html/rfc2460#section-4.3
   743  	// https://tools.ietf.org/html/rfc2460#section-4.6
   744  	length := b[1]*8 + 8
   745  	data := b[2:length]
   746  	nextParser := nextIPv6PayloadParser(nextHeader)
   747  	return header.IPv6ExtensionHeaderIdentifier(nextHeader), data, nextParser
   748  }
   749  
   750  // parseIPv6HopByHopOptionsExtHdr parses the bytes assuming that they start
   751  // with an IPv6 HopByHop Options Extension Header.
   752  func parseIPv6HopByHopOptionsExtHdr(b []byte) (Layer, bodySizeHint, layerParser) {
   753  	nextHeader, options, nextParser := parseIPv6ExtHdr(b)
   754  	return &IPv6HopByHopOptionsExtHdr{NextHeader: &nextHeader, Options: options}, bodySizeUnknown, nextParser
   755  }
   756  
   757  // parseIPv6DestinationOptionsExtHdr parses the bytes assuming that they start
   758  // with an IPv6 Destination Options Extension Header.
   759  func parseIPv6DestinationOptionsExtHdr(b []byte) (Layer, bodySizeHint, layerParser) {
   760  	nextHeader, options, nextParser := parseIPv6ExtHdr(b)
   761  	return &IPv6DestinationOptionsExtHdr{NextHeader: &nextHeader, Options: options}, bodySizeUnknown, nextParser
   762  }
   763  
   764  // Bool is a helper routine that allocates a new
   765  // bool value to store v and returns a pointer to it.
   766  func Bool(v bool) *bool {
   767  	return &v
   768  }
   769  
   770  // parseIPv6FragmentExtHdr parses the bytes assuming that they start
   771  // with an IPv6 Fragment Extension Header.
   772  func parseIPv6FragmentExtHdr(b []byte) (Layer, bodySizeHint, layerParser) {
   773  	nextHeader := b[0]
   774  	var extHdr header.IPv6FragmentExtHdr
   775  	copy(extHdr[:], b[2:])
   776  	fragLayer := IPv6FragmentExtHdr{
   777  		NextHeader:     IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(nextHeader)),
   778  		FragmentOffset: Uint16(extHdr.FragmentOffset()),
   779  		MoreFragments:  Bool(extHdr.More()),
   780  		Identification: Uint32(extHdr.ID()),
   781  	}
   782  	// If it is a fragment, we can't interpret it.
   783  	if extHdr.FragmentOffset() != 0 || extHdr.More() {
   784  		return &fragLayer, bodySizeUnknown, parsePayload
   785  	}
   786  	return &fragLayer, bodySizeUnknown, nextIPv6PayloadParser(nextHeader)
   787  }
   788  
   789  func (l *IPv6HopByHopOptionsExtHdr) length() int {
   790  	return len(l.Options) + 2
   791  }
   792  
   793  func (l *IPv6HopByHopOptionsExtHdr) match(other Layer) bool {
   794  	return equalLayer(l, other)
   795  }
   796  
   797  // merge overrides the values in l with the values from other but only in fields
   798  // where the value is not nil.
   799  func (l *IPv6HopByHopOptionsExtHdr) merge(other Layer) error {
   800  	return mergeLayer(l, other)
   801  }
   802  
   803  func (l *IPv6HopByHopOptionsExtHdr) String() string {
   804  	return stringLayer(l)
   805  }
   806  
   807  func (l *IPv6DestinationOptionsExtHdr) length() int {
   808  	return len(l.Options) + 2
   809  }
   810  
   811  func (l *IPv6DestinationOptionsExtHdr) match(other Layer) bool {
   812  	return equalLayer(l, other)
   813  }
   814  
   815  // merge overrides the values in l with the values from other but only in fields
   816  // where the value is not nil.
   817  func (l *IPv6DestinationOptionsExtHdr) merge(other Layer) error {
   818  	return mergeLayer(l, other)
   819  }
   820  
   821  func (l *IPv6DestinationOptionsExtHdr) String() string {
   822  	return stringLayer(l)
   823  }
   824  
   825  func (*IPv6FragmentExtHdr) length() int {
   826  	return header.IPv6FragmentExtHdrLength
   827  }
   828  
   829  func (l *IPv6FragmentExtHdr) match(other Layer) bool {
   830  	return equalLayer(l, other)
   831  }
   832  
   833  // merge overrides the values in l with the values from other but only in fields
   834  // where the value is not nil.
   835  func (l *IPv6FragmentExtHdr) merge(other Layer) error {
   836  	return mergeLayer(l, other)
   837  }
   838  
   839  func (l *IPv6FragmentExtHdr) String() string {
   840  	return stringLayer(l)
   841  }
   842  
// ICMPv6 can construct and match an ICMPv6 encapsulation.
//
// In ToBytes, nil Type and Code encode as zero. Ident is written only for Echo
// Request/Reply, and Pointer (via SetTypeSpecific) only for Parameter Problem.
// A nil Checksum is computed against the nearest preceding IPv6 layer. Payload
// holds the bytes that follow the minimum ICMPv6 header.
type ICMPv6 struct {
	LayerBase
	Type     *header.ICMPv6Type
	Code     *header.ICMPv6Code
	Checksum *uint16
	Ident    *uint16 // Only in Echo Request/Reply.
	Pointer  *uint32 // Only in Parameter Problem.
	Payload  []byte
}

// String implements fmt.Stringer by listing the non-nil fields.
func (l *ICMPv6) String() string {
	// TODO(eyalsoha): Do something smarter here when *l.Type is ParameterProblem?
	// We could parse the contents of the Payload as if it were an IPv6 packet.
	return stringLayer(l)
}
   859  
   860  // ToBytes implements Layer.ToBytes.
   861  func (l *ICMPv6) ToBytes() ([]byte, error) {
   862  	b := make([]byte, header.ICMPv6MinimumSize+len(l.Payload))
   863  	h := header.ICMPv6(b)
   864  	if l.Type != nil {
   865  		h.SetType(*l.Type)
   866  	}
   867  	if l.Code != nil {
   868  		h.SetCode(*l.Code)
   869  	}
   870  	if n := copy(h.Payload(), l.Payload); n != len(l.Payload) {
   871  		panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, len(l.Payload)))
   872  	}
   873  	typ := h.Type()
   874  	switch typ {
   875  	case header.ICMPv6EchoRequest, header.ICMPv6EchoReply:
   876  		if l.Ident != nil {
   877  			h.SetIdent(*l.Ident)
   878  		}
   879  	case header.ICMPv6ParamProblem:
   880  		if l.Pointer != nil {
   881  			h.SetTypeSpecific(*l.Pointer)
   882  		}
   883  	}
   884  	if l.Checksum != nil {
   885  		h.SetChecksum(*l.Checksum)
   886  	} else {
   887  		// It is possible that the ICMPv6 header does not follow the IPv6 header
   888  		// immediately, there could be one or more extension headers in between.
   889  		// We need to search backwards to find the IPv6 header.
   890  		for layer := l.Prev(); layer != nil; layer = layer.Prev() {
   891  			if ipv6, ok := layer.(*IPv6); ok {
   892  				h.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
   893  					Header:      h[:header.ICMPv6PayloadOffset],
   894  					Src:         tcpip.AddrFrom16Slice(*ipv6.SrcAddr),
   895  					Dst:         tcpip.AddrFrom16Slice(*ipv6.DstAddr),
   896  					PayloadCsum: checksum.Checksum(l.Payload, 0 /* initial */),
   897  					PayloadLen:  len(l.Payload),
   898  				}))
   899  				break
   900  			}
   901  		}
   902  	}
   903  	return h, nil
   904  }
   905  
   906  // ICMPv6Type is a helper routine that allocates a new ICMPv6Type value to store
   907  // v and returns a pointer to it.
   908  func ICMPv6Type(v header.ICMPv6Type) *header.ICMPv6Type {
   909  	return &v
   910  }
   911  
   912  // ICMPv6Code is a helper routine that allocates a new ICMPv6Type value to store
   913  // v and returns a pointer to it.
   914  func ICMPv6Code(v header.ICMPv6Code) *header.ICMPv6Code {
   915  	return &v
   916  }
   917  
   918  // parseICMPv6 parses the bytes assuming that they start with an ICMPv6 header.
   919  func parseICMPv6(b []byte) (Layer, bodySizeHint, layerParser) {
   920  	h := header.ICMPv6(b)
   921  	msgType := h.Type()
   922  	icmpv6 := ICMPv6{
   923  		Type:     ICMPv6Type(msgType),
   924  		Code:     ICMPv6Code(h.Code()),
   925  		Checksum: Uint16(h.Checksum()),
   926  		Payload:  h.Payload(),
   927  	}
   928  	switch msgType {
   929  	case header.ICMPv6EchoRequest, header.ICMPv6EchoReply:
   930  		icmpv6.Ident = Uint16(h.Ident())
   931  	case header.ICMPv6ParamProblem:
   932  		icmpv6.Pointer = Uint32(h.TypeSpecific())
   933  	}
   934  	return &icmpv6, bodySizeUnknown, nil
   935  }
   936  
   937  func (l *ICMPv6) match(other Layer) bool {
   938  	return equalLayer(l, other)
   939  }
   940  
   941  func (l *ICMPv6) length() int {
   942  	return header.ICMPv6MinimumSize + len(l.Payload)
   943  }
   944  
   945  // merge overrides the values in l with the values from other but only in fields
   946  // where the value is not nil.
   947  func (l *ICMPv6) merge(other Layer) error {
   948  	return mergeLayer(l, other)
   949  }
   950  
// ICMPv4 can construct and match an ICMPv4 encapsulation.
//
// All fields are optional. When serialized by ToBytes, a nil Type or Code is
// encoded as zero, and a nil Checksum is computed over the ICMPv4 header and
// payload. Ident, Sequence and Pointer are only encoded (and only populated by
// parseICMPv4) for the message types noted below.
type ICMPv4 struct {
	LayerBase
	Type     *header.ICMPv4Type
	Code     *header.ICMPv4Code
	Checksum *uint16
	Ident    *uint16 // Only in Echo Request/Reply.
	Sequence *uint16 // Only in Echo Request/Reply.
	Pointer  *uint8  // Only in Parameter Problem.
	Payload  []byte
}
   962  
   963  func (l *ICMPv4) String() string {
   964  	return stringLayer(l)
   965  }
   966  
   967  // ICMPv4Type is a helper routine that allocates a new header.ICMPv4Type value
   968  // to store t and returns a pointer to it.
   969  func ICMPv4Type(t header.ICMPv4Type) *header.ICMPv4Type {
   970  	return &t
   971  }
   972  
   973  // ICMPv4Code is a helper routine that allocates a new header.ICMPv4Code value
   974  // to store t and returns a pointer to it.
   975  func ICMPv4Code(t header.ICMPv4Code) *header.ICMPv4Code {
   976  	return &t
   977  }
   978  
   979  // ToBytes implements Layer.ToBytes.
   980  func (l *ICMPv4) ToBytes() ([]byte, error) {
   981  	b := make([]byte, header.ICMPv4MinimumSize+len(l.Payload))
   982  	h := header.ICMPv4(b)
   983  	if l.Type != nil {
   984  		h.SetType(*l.Type)
   985  	}
   986  	if l.Code != nil {
   987  		h.SetCode(*l.Code)
   988  	}
   989  	if n := copy(h.Payload(), l.Payload); n != len(l.Payload) {
   990  		panic(fmt.Sprintf("wrong number of bytes copied into h.Payload(): got = %d, want = %d", n, len(l.Payload)))
   991  	}
   992  	typ := h.Type()
   993  	switch typ {
   994  	case header.ICMPv4EchoReply, header.ICMPv4Echo:
   995  		if l.Ident != nil {
   996  			h.SetIdent(*l.Ident)
   997  		}
   998  		if l.Sequence != nil {
   999  			h.SetSequence(*l.Sequence)
  1000  		}
  1001  	case header.ICMPv4ParamProblem:
  1002  		if l.Pointer != nil {
  1003  			h.SetPointer(*l.Pointer)
  1004  		}
  1005  	}
  1006  
  1007  	// The checksum must be handled last because the ICMPv4 header fields are
  1008  	// included in the computation.
  1009  	if l.Checksum != nil {
  1010  		h.SetChecksum(*l.Checksum)
  1011  	} else {
  1012  		h.SetChecksum(^checksum.Checksum(h, 0))
  1013  	}
  1014  
  1015  	return h, nil
  1016  }
  1017  
  1018  // parseICMPv4 parses the bytes as an ICMPv4 header, returning a Layer and a
  1019  // parser for the encapsulated payload.
  1020  func parseICMPv4(b []byte) (Layer, bodySizeHint, layerParser) {
  1021  	h := header.ICMPv4(b)
  1022  
  1023  	msgType := h.Type()
  1024  	icmpv4 := ICMPv4{
  1025  		Type:     ICMPv4Type(msgType),
  1026  		Code:     ICMPv4Code(h.Code()),
  1027  		Checksum: Uint16(h.Checksum()),
  1028  		Payload:  h.Payload(),
  1029  	}
  1030  	switch msgType {
  1031  	case header.ICMPv4EchoReply, header.ICMPv4Echo:
  1032  		icmpv4.Ident = Uint16(h.Ident())
  1033  		icmpv4.Sequence = Uint16(h.Sequence())
  1034  	case header.ICMPv4ParamProblem:
  1035  		icmpv4.Pointer = Uint8(h.Pointer())
  1036  	}
  1037  	return &icmpv4, bodySizeUnknown, nil
  1038  }
  1039  
  1040  func (l *ICMPv4) match(other Layer) bool {
  1041  	return equalLayer(l, other)
  1042  }
  1043  
  1044  func (l *ICMPv4) length() int {
  1045  	return header.ICMPv4MinimumSize + len(l.Payload)
  1046  }
  1047  
  1048  // merge overrides the values in l with the values from other but only in fields
  1049  // where the value is not nil.
  1050  func (l *ICMPv4) merge(other Layer) error {
  1051  	return mergeLayer(l, other)
  1052  }
  1053  
// TCP can construct and match a TCP encapsulation.
//
// All fields are optional. When serialized by ToBytes, nil fields are encoded
// as zero, with these exceptions:
//   - DataOffset defaults to the header length including padded Options.
//   - WindowSize defaults to 32768.
//   - Checksum is computed from a pseudo-header built from the immediately
//     preceding IPv4 or IPv6 layer, plus the TCP header and all following
//     layers. ToBytes returns an error if the preceding layer is of any other
//     kind.
type TCP struct {
	LayerBase
	SrcPort       *uint16
	DstPort       *uint16
	SeqNum        *uint32
	AckNum        *uint32
	DataOffset    *uint8 // Header length in bytes, including options.
	Flags         *header.TCPFlags
	WindowSize    *uint16
	Checksum      *uint16
	UrgentPointer *uint16
	Options       []byte // Padded to a 4-byte boundary by ToBytes.
}
  1068  
  1069  func (l *TCP) String() string {
  1070  	return stringLayer(l)
  1071  }
  1072  
  1073  // ToBytes implements Layer.ToBytes.
  1074  func (l *TCP) ToBytes() ([]byte, error) {
  1075  	b := make([]byte, l.length())
  1076  	h := header.TCP(b)
  1077  	if l.SrcPort != nil {
  1078  		h.SetSourcePort(*l.SrcPort)
  1079  	}
  1080  	if l.DstPort != nil {
  1081  		h.SetDestinationPort(*l.DstPort)
  1082  	}
  1083  	if l.SeqNum != nil {
  1084  		h.SetSequenceNumber(*l.SeqNum)
  1085  	}
  1086  	if l.AckNum != nil {
  1087  		h.SetAckNumber(*l.AckNum)
  1088  	}
  1089  	if l.DataOffset != nil {
  1090  		h.SetDataOffset(*l.DataOffset)
  1091  	} else {
  1092  		h.SetDataOffset(uint8(l.length()))
  1093  	}
  1094  	if l.Flags != nil {
  1095  		h.SetFlags(uint8(*l.Flags))
  1096  	}
  1097  	if l.WindowSize != nil {
  1098  		h.SetWindowSize(*l.WindowSize)
  1099  	} else {
  1100  		h.SetWindowSize(32768)
  1101  	}
  1102  	if l.UrgentPointer != nil {
  1103  		h.SetUrgentPointer(*l.UrgentPointer)
  1104  	}
  1105  	copy(b[header.TCPMinimumSize:], l.Options)
  1106  	header.AddTCPOptionPadding(b[header.TCPMinimumSize:], len(l.Options))
  1107  	if l.Checksum != nil {
  1108  		h.SetChecksum(*l.Checksum)
  1109  		return h, nil
  1110  	}
  1111  	if err := setTCPChecksum(&h, l); err != nil {
  1112  		return nil, err
  1113  	}
  1114  	return h, nil
  1115  }
  1116  
  1117  // totalLength returns the length of the provided layer and all following
  1118  // layers.
  1119  func totalLength(l Layer) int {
  1120  	var totalLength int
  1121  	for ; l != nil; l = l.next() {
  1122  		totalLength += l.length()
  1123  	}
  1124  	return totalLength
  1125  }
  1126  
  1127  // payload returns a buffer.Buffer of l's payload.
  1128  func payload(l Layer) (buffer.Buffer, error) {
  1129  	var payloadBytes buffer.Buffer
  1130  	for current := l.next(); current != nil; current = current.next() {
  1131  		payload, err := current.ToBytes()
  1132  		if err != nil {
  1133  			return buffer.Buffer{}, fmt.Errorf("can't get bytes for next header: %s", payload)
  1134  		}
  1135  		payloadBytes.Append(buffer.NewViewWithData(payload))
  1136  	}
  1137  	return payloadBytes, nil
  1138  }
  1139  
  1140  // layerChecksum calculates the checksum of the Layer header, including the
  1141  // peusdeochecksum of the layer before it and all the bytes after it.
  1142  func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) {
  1143  	totalLength := uint16(totalLength(l))
  1144  	var xsum uint16
  1145  	switch p := l.Prev().(type) {
  1146  	case *IPv4:
  1147  		xsum = header.PseudoHeaderChecksum(protoNumber, tcpip.AddrFrom4Slice(*p.SrcAddr), tcpip.AddrFrom4Slice(*p.DstAddr), totalLength)
  1148  	case *IPv6:
  1149  		xsum = header.PseudoHeaderChecksum(protoNumber, tcpip.AddrFrom16Slice(*p.SrcAddr), tcpip.AddrFrom16Slice(*p.DstAddr), totalLength)
  1150  	default:
  1151  		// TODO(b/161246171): Support more protocols.
  1152  		return 0, fmt.Errorf("checksum for protocol %d is not supported when previous layer is %T", protoNumber, p)
  1153  	}
  1154  	payloadBytes, err := payload(l)
  1155  	if err != nil {
  1156  		return 0, err
  1157  	}
  1158  	xsum = checksum.Checksum(payloadBytes.Flatten(), xsum)
  1159  	return xsum, nil
  1160  }
  1161  
  1162  // setTCPChecksum calculates the checksum of the TCP header and sets it in h.
  1163  func setTCPChecksum(h *header.TCP, tcp *TCP) error {
  1164  	h.SetChecksum(0)
  1165  	xsum, err := layerChecksum(tcp, header.TCPProtocolNumber)
  1166  	if err != nil {
  1167  		return err
  1168  	}
  1169  	h.SetChecksum(^h.CalculateChecksum(xsum))
  1170  	return nil
  1171  }
  1172  
  1173  // Uint32 is a helper routine that allocates a new
  1174  // uint32 value to store v and returns a pointer to it.
  1175  func Uint32(v uint32) *uint32 {
  1176  	return &v
  1177  }
  1178  
  1179  // parseTCP parses the bytes assuming that they start with a tcp header and
  1180  // continues parsing further encapsulations.
  1181  func parseTCP(b []byte) (Layer, bodySizeHint, layerParser) {
  1182  	h := header.TCP(b)
  1183  	tcp := TCP{
  1184  		SrcPort:       Uint16(h.SourcePort()),
  1185  		DstPort:       Uint16(h.DestinationPort()),
  1186  		SeqNum:        Uint32(h.SequenceNumber()),
  1187  		AckNum:        Uint32(h.AckNumber()),
  1188  		DataOffset:    Uint8(h.DataOffset()),
  1189  		Flags:         TCPFlags(h.Flags()),
  1190  		WindowSize:    Uint16(h.WindowSize()),
  1191  		Checksum:      Uint16(h.Checksum()),
  1192  		UrgentPointer: Uint16(h.UrgentPointer()),
  1193  		Options:       b[header.TCPMinimumSize:h.DataOffset()],
  1194  	}
  1195  	return &tcp, bodySizeUnknown, parsePayload
  1196  }
  1197  
  1198  func (l *TCP) match(other Layer) bool {
  1199  	return equalLayer(l, other)
  1200  }
  1201  
  1202  func (l *TCP) length() int {
  1203  	if l.DataOffset == nil {
  1204  		// TCP header including the options must end on a 32-bit
  1205  		// boundary; the user could potentially give us a slice
  1206  		// whose length is not a multiple of 4 bytes, so we have
  1207  		// to do the alignment here.
  1208  		optlen := (len(l.Options) + 3) & ^3
  1209  		return header.TCPMinimumSize + optlen
  1210  	}
  1211  	return int(*l.DataOffset)
  1212  }
  1213  
  1214  // merge implements Layer.merge.
  1215  func (l *TCP) merge(other Layer) error {
  1216  	return mergeLayer(l, other)
  1217  }
  1218  
// UDP can construct and match a UDP encapsulation.
//
// All fields are optional. When serialized by ToBytes, nil ports are encoded
// as zero. A nil Length defaults to the size of this layer plus all following
// layers. A nil Checksum is computed from a pseudo-header built from the
// immediately preceding IPv4 or IPv6 layer, plus the UDP header and all
// following layers.
type UDP struct {
	LayerBase
	SrcPort  *uint16
	DstPort  *uint16
	Length   *uint16
	Checksum *uint16
}
  1227  
  1228  func (l *UDP) String() string {
  1229  	return stringLayer(l)
  1230  }
  1231  
  1232  // ToBytes implements Layer.ToBytes.
  1233  func (l *UDP) ToBytes() ([]byte, error) {
  1234  	b := make([]byte, header.UDPMinimumSize)
  1235  	h := header.UDP(b)
  1236  	if l.SrcPort != nil {
  1237  		h.SetSourcePort(*l.SrcPort)
  1238  	}
  1239  	if l.DstPort != nil {
  1240  		h.SetDestinationPort(*l.DstPort)
  1241  	}
  1242  	if l.Length != nil {
  1243  		h.SetLength(*l.Length)
  1244  	} else {
  1245  		h.SetLength(uint16(totalLength(l)))
  1246  	}
  1247  	if l.Checksum != nil {
  1248  		h.SetChecksum(*l.Checksum)
  1249  		return h, nil
  1250  	}
  1251  	if err := setUDPChecksum(&h, l); err != nil {
  1252  		return nil, err
  1253  	}
  1254  	return h, nil
  1255  }
  1256  
  1257  // setUDPChecksum calculates the checksum of the UDP header and sets it in h.
  1258  func setUDPChecksum(h *header.UDP, udp *UDP) error {
  1259  	h.SetChecksum(0)
  1260  	xsum, err := layerChecksum(udp, header.UDPProtocolNumber)
  1261  	if err != nil {
  1262  		return err
  1263  	}
  1264  	h.SetChecksum(^h.CalculateChecksum(xsum))
  1265  	return nil
  1266  }
  1267  
  1268  // parseUDP parses the bytes assuming that they start with a udp header and
  1269  // returns the parsed layer and the next parser to use.
  1270  func parseUDP(b []byte) (Layer, bodySizeHint, layerParser) {
  1271  	h := header.UDP(b)
  1272  	udp := UDP{
  1273  		SrcPort:  Uint16(h.SourcePort()),
  1274  		DstPort:  Uint16(h.DestinationPort()),
  1275  		Length:   Uint16(h.Length()),
  1276  		Checksum: Uint16(h.Checksum()),
  1277  	}
  1278  	return &udp, bodySizeUnknown, parsePayload
  1279  }
  1280  
  1281  func (l *UDP) match(other Layer) bool {
  1282  	return equalLayer(l, other)
  1283  }
  1284  
  1285  func (l *UDP) length() int {
  1286  	return header.UDPMinimumSize
  1287  }
  1288  
  1289  // merge implements Layer.merge.
  1290  func (l *UDP) merge(other Layer) error {
  1291  	return mergeLayer(l, other)
  1292  }
  1293  
// Payload has bytes beyond OSI layer 4.
//
// ToBytes returns Bytes as-is (without copying), and parsePayload consumes all
// remaining input into a single Payload, so it is always the last layer.
type Payload struct {
	LayerBase
	Bytes []byte
}
  1299  
  1300  func (l *Payload) String() string {
  1301  	return stringLayer(l)
  1302  }
  1303  
  1304  // parsePayload parses the bytes assuming that they start with a payload and
  1305  // continue to the end. There can be no further encapsulations.
  1306  func parsePayload(b []byte) (Layer, bodySizeHint, layerParser) {
  1307  	payload := Payload{
  1308  		Bytes: b,
  1309  	}
  1310  	return &payload, bodySizeUnknown, nil
  1311  }
  1312  
  1313  // ToBytes implements Layer.ToBytes.
  1314  func (l *Payload) ToBytes() ([]byte, error) {
  1315  	return l.Bytes, nil
  1316  }
  1317  
  1318  // Length returns payload byte length.
  1319  func (l *Payload) Length() int {
  1320  	return l.length()
  1321  }
  1322  
  1323  func (l *Payload) match(other Layer) bool {
  1324  	return equalLayer(l, other)
  1325  }
  1326  
  1327  func (l *Payload) length() int {
  1328  	return len(l.Bytes)
  1329  }
  1330  
  1331  // merge implements Layer.merge.
  1332  func (l *Payload) merge(other Layer) error {
  1333  	return mergeLayer(l, other)
  1334  }
  1335  
// Layers is a slice of Layer, ordered from outermost to innermost
// encapsulation. It supports counterparts of the Layer operations: ToBytes,
// match, diff and merge.
type Layers []Layer
  1338  
  1339  // linkLayers sets the linked-list ponters in ls.
  1340  func (ls *Layers) linkLayers() {
  1341  	for i, l := range *ls {
  1342  		if i > 0 {
  1343  			l.setPrev((*ls)[i-1])
  1344  		} else {
  1345  			l.setPrev(nil)
  1346  		}
  1347  		if i+1 < len(*ls) {
  1348  			l.setNext((*ls)[i+1])
  1349  		} else {
  1350  			l.setNext(nil)
  1351  		}
  1352  	}
  1353  }
  1354  
  1355  // ToBytes converts the Layers into bytes. It creates a linked list of the Layer
  1356  // structs and then concatentates the output of ToBytes on each Layer.
  1357  func (ls *Layers) ToBytes() ([]byte, error) {
  1358  	ls.linkLayers()
  1359  	outBytes := []byte{}
  1360  	for _, l := range *ls {
  1361  		layerBytes, err := l.ToBytes()
  1362  		if err != nil {
  1363  			return nil, err
  1364  		}
  1365  		outBytes = append(outBytes, layerBytes...)
  1366  	}
  1367  	return outBytes, nil
  1368  }
  1369  
  1370  func (ls *Layers) match(other Layers) bool {
  1371  	if len(*ls) > len(other) {
  1372  		return false
  1373  	}
  1374  	for i, l := range *ls {
  1375  		if !equalLayer(l, other[i]) {
  1376  			return false
  1377  		}
  1378  	}
  1379  	return true
  1380  }
  1381  
// layerDiff stores the diffs for each field along with the label for the Layer.
// If rows is nil, that means that there was no diff. A non-nil but empty rows
// slice, as produced in Layers.diff for layers of differing types, still
// counts as a difference.
type layerDiff struct {
	label string
	rows  []layerDiffRow
}
  1388  
// layerDiffRow stores one field of a got/want layer pair: the field name and
// the printed value on each side. If the value was nil then the string stored
// is the empty string.
type layerDiffRow struct {
	field, got, want string
}
  1394  
  1395  // diffLayer extracts all differing fields between two layers.
  1396  func diffLayer(got, want Layer) []layerDiffRow {
  1397  	vGot := reflect.ValueOf(got).Elem()
  1398  	vWant := reflect.ValueOf(want).Elem()
  1399  	if vGot.Type() != vWant.Type() {
  1400  		return nil
  1401  	}
  1402  	t := vGot.Type()
  1403  	var result []layerDiffRow
  1404  	for i := 0; i < t.NumField(); i++ {
  1405  		t := t.Field(i)
  1406  		if t.Anonymous {
  1407  			// Ignore the LayerBase in the Layer struct.
  1408  			continue
  1409  		}
  1410  		vGot := vGot.Field(i)
  1411  		vWant := vWant.Field(i)
  1412  		gotString := ""
  1413  		if !vGot.IsNil() {
  1414  			gotString = fmt.Sprint(reflect.Indirect(vGot))
  1415  		}
  1416  		wantString := ""
  1417  		if !vWant.IsNil() {
  1418  			wantString = fmt.Sprint(reflect.Indirect(vWant))
  1419  		}
  1420  		result = append(result, layerDiffRow{t.Name, gotString, wantString})
  1421  	}
  1422  	return result
  1423  }
  1424  
  1425  // layerType returns a concise string describing the type of the Layer, like
  1426  // "TCP", or "IPv6".
  1427  func layerType(l Layer) string {
  1428  	return reflect.TypeOf(l).Elem().Name()
  1429  }
  1430  
  1431  // diff compares Layers and returns a representation of the difference. Each
  1432  // Layer in the Layers is pairwise compared. If an element in either is nil, it
  1433  // is considered a match with the other Layer. If two Layers have differing
  1434  // types, they don't match regardless of the contents. If two Layers have the
  1435  // same type then the fields in the Layer are pairwise compared. Fields that are
  1436  // nil always match. Two non-nil fields only match if they point to equal
  1437  // values. diff returns an empty string if and only if *ls and other match.
  1438  func (ls *Layers) diff(other Layers) string {
  1439  	var allDiffs []layerDiff
  1440  	// Check the cases where one list is longer than the other, where one or both
  1441  	// elements are nil, where the sides have different types, and where the sides
  1442  	// have the same type.
  1443  	for i := 0; i < len(*ls) || i < len(other); i++ {
  1444  		if i >= len(*ls) {
  1445  			// Matching ls against other where other is longer than ls. missing
  1446  			// matches everything so we just include a label without any rows. Having
  1447  			// no rows is a sign that there was no diff.
  1448  			allDiffs = append(allDiffs, layerDiff{
  1449  				label: "missing matches " + layerType(other[i]),
  1450  			})
  1451  			continue
  1452  		}
  1453  
  1454  		if i >= len(other) {
  1455  			// Matching ls against other where ls is longer than other. missing
  1456  			// matches everything so we just include a label without any rows. Having
  1457  			// no rows is a sign that there was no diff.
  1458  			allDiffs = append(allDiffs, layerDiff{
  1459  				label: layerType((*ls)[i]) + " matches missing",
  1460  			})
  1461  			continue
  1462  		}
  1463  
  1464  		if (*ls)[i] == nil && other[i] == nil {
  1465  			// Matching ls against other where both elements are nil. nil matches
  1466  			// everything so we just include a label without any rows. Having no rows
  1467  			// is a sign that there was no diff.
  1468  			allDiffs = append(allDiffs, layerDiff{
  1469  				label: "nil matches nil",
  1470  			})
  1471  			continue
  1472  		}
  1473  
  1474  		if (*ls)[i] == nil {
  1475  			// Matching ls against other where the element in ls is nil. nil matches
  1476  			// everything so we just include a label without any rows. Having no rows
  1477  			// is a sign that there was no diff.
  1478  			allDiffs = append(allDiffs, layerDiff{
  1479  				label: "nil matches " + layerType(other[i]),
  1480  			})
  1481  			continue
  1482  		}
  1483  
  1484  		if other[i] == nil {
  1485  			// Matching ls against other where the element in other is nil. nil
  1486  			// matches everything so we just include a label without any rows. Having
  1487  			// no rows is a sign that there was no diff.
  1488  			allDiffs = append(allDiffs, layerDiff{
  1489  				label: layerType((*ls)[i]) + " matches nil",
  1490  			})
  1491  			continue
  1492  		}
  1493  
  1494  		if reflect.TypeOf((*ls)[i]) == reflect.TypeOf(other[i]) {
  1495  			// Matching ls against other where both elements have the same type. Match
  1496  			// each field pairwise and only report a diff if there is a mismatch,
  1497  			// which is only when both sides are non-nil and have differring values.
  1498  			diff := diffLayer((*ls)[i], other[i])
  1499  			var layerDiffRows []layerDiffRow
  1500  			for _, d := range diff {
  1501  				if d.got == "" || d.want == "" || d.got == d.want {
  1502  					continue
  1503  				}
  1504  				layerDiffRows = append(layerDiffRows, layerDiffRow{
  1505  					d.field,
  1506  					d.got,
  1507  					d.want,
  1508  				})
  1509  			}
  1510  			if len(layerDiffRows) > 0 {
  1511  				allDiffs = append(allDiffs, layerDiff{
  1512  					label: layerType((*ls)[i]),
  1513  					rows:  layerDiffRows,
  1514  				})
  1515  			} else {
  1516  				allDiffs = append(allDiffs, layerDiff{
  1517  					label: layerType((*ls)[i]) + " matches " + layerType(other[i]),
  1518  					// Having no rows is a sign that there was no diff.
  1519  				})
  1520  			}
  1521  			continue
  1522  		}
  1523  		// Neither side is nil and the types are different, so we'll display one
  1524  		// side then the other.
  1525  		allDiffs = append(allDiffs, layerDiff{
  1526  			label: layerType((*ls)[i]) + " doesn't match " + layerType(other[i]),
  1527  		})
  1528  		diff := diffLayer((*ls)[i], (*ls)[i])
  1529  		layerDiffRows := []layerDiffRow{}
  1530  		for _, d := range diff {
  1531  			if len(d.got) == 0 {
  1532  				continue
  1533  			}
  1534  			layerDiffRows = append(layerDiffRows, layerDiffRow{
  1535  				d.field,
  1536  				d.got,
  1537  				"",
  1538  			})
  1539  		}
  1540  		allDiffs = append(allDiffs, layerDiff{
  1541  			label: layerType((*ls)[i]),
  1542  			rows:  layerDiffRows,
  1543  		})
  1544  
  1545  		layerDiffRows = []layerDiffRow{}
  1546  		diff = diffLayer(other[i], other[i])
  1547  		for _, d := range diff {
  1548  			if len(d.want) == 0 {
  1549  				continue
  1550  			}
  1551  			layerDiffRows = append(layerDiffRows, layerDiffRow{
  1552  				d.field,
  1553  				"",
  1554  				d.want,
  1555  			})
  1556  		}
  1557  		allDiffs = append(allDiffs, layerDiff{
  1558  			label: layerType(other[i]),
  1559  			rows:  layerDiffRows,
  1560  		})
  1561  	}
  1562  
  1563  	output := ""
  1564  	// These are for output formatting.
  1565  	maxLabelLen, maxFieldLen, maxGotLen, maxWantLen := 0, 0, 0, 0
  1566  	foundOne := false
  1567  	for _, l := range allDiffs {
  1568  		if len(l.label) > maxLabelLen && len(l.rows) > 0 {
  1569  			maxLabelLen = len(l.label)
  1570  		}
  1571  		if l.rows != nil {
  1572  			foundOne = true
  1573  		}
  1574  		for _, r := range l.rows {
  1575  			if len(r.field) > maxFieldLen {
  1576  				maxFieldLen = len(r.field)
  1577  			}
  1578  			if l := len(fmt.Sprint(r.got)); l > maxGotLen {
  1579  				maxGotLen = l
  1580  			}
  1581  			if l := len(fmt.Sprint(r.want)); l > maxWantLen {
  1582  				maxWantLen = l
  1583  			}
  1584  		}
  1585  	}
  1586  	if !foundOne {
  1587  		return ""
  1588  	}
  1589  	for _, l := range allDiffs {
  1590  		if len(l.rows) == 0 {
  1591  			output += "(" + l.label + ")\n"
  1592  			continue
  1593  		}
  1594  		for i, r := range l.rows {
  1595  			var label string
  1596  			if i == 0 {
  1597  				label = l.label + ":"
  1598  			}
  1599  			output += fmt.Sprintf(
  1600  				"%*s %*s %*v %*v\n",
  1601  				maxLabelLen+1, label,
  1602  				maxFieldLen+1, r.field+":",
  1603  				maxGotLen, r.got,
  1604  				maxWantLen, r.want,
  1605  			)
  1606  		}
  1607  	}
  1608  	return output
  1609  }
  1610  
  1611  // merge merges the other Layers into ls. If the other Layers is longer, those
  1612  // additional Layer structs are added to ls. The errors from merging are
  1613  // collected and returned.
  1614  func (ls *Layers) merge(other Layers) error {
  1615  	var errs error
  1616  	for i, o := range other {
  1617  		if i < len(*ls) {
  1618  			errs = multierr.Combine(errs, (*ls)[i].merge(o))
  1619  		} else {
  1620  			*ls = append(*ls, o)
  1621  		}
  1622  	}
  1623  	return errs
  1624  }