github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/test/packetimpact/testbench/layers.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package testbench
    16  
    17  import (
    18  	"encoding/binary"
    19  	"encoding/hex"
    20  	"fmt"
    21  	"reflect"
    22  	"strings"
    23  
    24  	"github.com/google/go-cmp/cmp"
    25  	"github.com/google/go-cmp/cmp/cmpopts"
    26  	"go.uber.org/multierr"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    30  )
    31  
    32  // Layer is the interface that all encapsulations must implement.
    33  //
    34  // A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A
    35  // Layer contains all the fields of the encapsulation. Each field is a pointer
    36  // and may be nil.
    37  type Layer interface {
    38  	fmt.Stringer
    39  
    40  	// ToBytes converts the Layer into bytes. In places where the Layer's field
    41  	// isn't nil, the value that is pointed to is used. When the field is nil, a
    42  	// reasonable default for the Layer is used. For example, "64" for IPv4 TTL
    43  	// and a calculated checksum for TCP or IP. Some layers require information
    44  	// from the previous or next layers in order to compute a default, such as
    45  	// TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list
    46  	// to the layer's neighbors.
    47  	ToBytes() ([]byte, error)
    48  
    49  	// match checks if the current Layer matches the provided Layer. If either
    50  	// Layer has a nil in a given field, that field is considered matching.
    51  	// Otherwise, the values pointed to by the fields must match. The LayerBase is
    52  	// ignored.
    53  	match(Layer) bool
    54  
    55  	// length in bytes of the current encapsulation
    56  	length() int
    57  
    58  	// next gets a pointer to the encapsulated Layer.
    59  	next() Layer
    60  
    61  	// prev gets a pointer to the Layer encapsulating this one.
    62  	Prev() Layer
    63  
    64  	// setNext sets the pointer to the encapsulated Layer.
    65  	setNext(Layer)
    66  
    67  	// setPrev sets the pointer to the Layer encapsulating this one.
    68  	setPrev(Layer)
    69  
    70  	// merge overrides the values in the interface with the provided values.
    71  	merge(Layer) error
    72  }
    73  
    74  // LayerBase is the common elements of all layers.
    75  type LayerBase struct {
    76  	nextLayer Layer
    77  	prevLayer Layer
    78  }
    79  
    80  func (lb *LayerBase) next() Layer {
    81  	return lb.nextLayer
    82  }
    83  
    84  // Prev returns the previous layer.
    85  func (lb *LayerBase) Prev() Layer {
    86  	return lb.prevLayer
    87  }
    88  
    89  func (lb *LayerBase) setNext(l Layer) {
    90  	lb.nextLayer = l
    91  }
    92  
    93  func (lb *LayerBase) setPrev(l Layer) {
    94  	lb.prevLayer = l
    95  }
    96  
    97  // equalLayer compares that two Layer structs match while ignoring field in
    98  // which either input has a nil and also ignoring the LayerBase of the inputs.
    99  func equalLayer(x, y Layer) bool {
   100  	if x == nil || y == nil {
   101  		return true
   102  	}
   103  	// opt ignores comparison pairs where either of the inputs is a nil.
   104  	opt := cmp.FilterValues(func(x, y interface{}) bool {
   105  		for _, l := range []interface{}{x, y} {
   106  			v := reflect.ValueOf(l)
   107  			if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() {
   108  				return true
   109  			}
   110  		}
   111  		return false
   112  	}, cmp.Ignore())
   113  	return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{}))
   114  }
   115  
   116  // mergeLayer merges y into x. Any fields for which y has a non-nil value, that
   117  // value overwrite the corresponding fields in x.
   118  func mergeLayer(x, y Layer) error {
   119  	if y == nil {
   120  		return nil
   121  	}
   122  	if reflect.TypeOf(x) != reflect.TypeOf(y) {
   123  		return fmt.Errorf("can't merge %T into %T", y, x)
   124  	}
   125  	vx := reflect.ValueOf(x).Elem()
   126  	vy := reflect.ValueOf(y).Elem()
   127  	t := vy.Type()
   128  	for i := 0; i < vy.NumField(); i++ {
   129  		t := t.Field(i)
   130  		if t.Anonymous {
   131  			// Ignore the LayerBase in the Layer struct.
   132  			continue
   133  		}
   134  		v := vy.Field(i)
   135  		if v.IsNil() {
   136  			continue
   137  		}
   138  		vx.Field(i).Set(v)
   139  	}
   140  	return nil
   141  }
   142  
   143  func stringLayer(l Layer) string {
   144  	v := reflect.ValueOf(l).Elem()
   145  	t := v.Type()
   146  	var ret []string
   147  	for i := 0; i < v.NumField(); i++ {
   148  		t := t.Field(i)
   149  		if t.Anonymous {
   150  			// Ignore the LayerBase in the Layer struct.
   151  			continue
   152  		}
   153  		v := v.Field(i)
   154  		if v.IsNil() {
   155  			continue
   156  		}
   157  		v = reflect.Indirect(v)
   158  		if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
   159  			ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes())))
   160  		} else {
   161  			ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v))
   162  		}
   163  	}
   164  	return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " "))
   165  }
   166  
   167  // Ether can construct and match an ethernet encapsulation.
   168  type Ether struct {
   169  	LayerBase
   170  	SrcAddr *tcpip.LinkAddress
   171  	DstAddr *tcpip.LinkAddress
   172  	Type    *tcpip.NetworkProtocolNumber
   173  }
   174  
   175  func (l *Ether) String() string {
   176  	return stringLayer(l)
   177  }
   178  
   179  // ToBytes implements Layer.ToBytes.
   180  func (l *Ether) ToBytes() ([]byte, error) {
   181  	b := make([]byte, header.EthernetMinimumSize)
   182  	h := header.Ethernet(b)
   183  	fields := &header.EthernetFields{}
   184  	if l.SrcAddr != nil {
   185  		fields.SrcAddr = *l.SrcAddr
   186  	}
   187  	if l.DstAddr != nil {
   188  		fields.DstAddr = *l.DstAddr
   189  	}
   190  	if l.Type != nil {
   191  		fields.Type = *l.Type
   192  	} else {
   193  		switch n := l.next().(type) {
   194  		case *IPv4:
   195  			fields.Type = header.IPv4ProtocolNumber
   196  		case *IPv6:
   197  			fields.Type = header.IPv6ProtocolNumber
   198  		default:
   199  			return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n)
   200  		}
   201  	}
   202  	h.Encode(fields)
   203  	return h, nil
   204  }
   205  
   206  // LinkAddress is a helper routine that allocates a new tcpip.LinkAddress value
   207  // to store v and returns a pointer to it.
   208  func LinkAddress(v tcpip.LinkAddress) *tcpip.LinkAddress {
   209  	return &v
   210  }
   211  
   212  // NetworkProtocolNumber is a helper routine that allocates a new
   213  // tcpip.NetworkProtocolNumber value to store v and returns a pointer to it.
   214  func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocolNumber {
   215  	return &v
   216  }
   217  
   218  // layerParser parses the input bytes and returns a Layer along with the next
   219  // layerParser to run. If there is no more parsing to do, the returned
   220  // layerParser is nil.
   221  type layerParser func([]byte) (Layer, layerParser)
   222  
   223  // parse parses bytes starting with the first layerParser and using successive
   224  // layerParsers until all the bytes are parsed.
   225  func parse(parser layerParser, b []byte) Layers {
   226  	var layers Layers
   227  	for {
   228  		var layer Layer
   229  		layer, parser = parser(b)
   230  		layers = append(layers, layer)
   231  		if parser == nil {
   232  			break
   233  		}
   234  		b = b[layer.length():]
   235  	}
   236  	layers.linkLayers()
   237  	return layers
   238  }
   239  
   240  // parseEther parses the bytes assuming that they start with an ethernet header
   241  // and continues parsing further encapsulations.
   242  func parseEther(b []byte) (Layer, layerParser) {
   243  	h := header.Ethernet(b)
   244  	ether := Ether{
   245  		SrcAddr: LinkAddress(h.SourceAddress()),
   246  		DstAddr: LinkAddress(h.DestinationAddress()),
   247  		Type:    NetworkProtocolNumber(h.Type()),
   248  	}
   249  	var nextParser layerParser
   250  	switch h.Type() {
   251  	case header.IPv4ProtocolNumber:
   252  		nextParser = parseIPv4
   253  	case header.IPv6ProtocolNumber:
   254  		nextParser = parseIPv6
   255  	default:
   256  		// Assume that the rest is a payload.
   257  		nextParser = parsePayload
   258  	}
   259  	return &ether, nextParser
   260  }
   261  
   262  func (l *Ether) match(other Layer) bool {
   263  	return equalLayer(l, other)
   264  }
   265  
   266  func (l *Ether) length() int {
   267  	return header.EthernetMinimumSize
   268  }
   269  
   270  // merge implements Layer.merge.
   271  func (l *Ether) merge(other Layer) error {
   272  	return mergeLayer(l, other)
   273  }
   274  
   275  // IPv4 can construct and match an IPv4 encapsulation.
   276  type IPv4 struct {
   277  	LayerBase
   278  	IHL            *uint8
   279  	TOS            *uint8
   280  	TotalLength    *uint16
   281  	ID             *uint16
   282  	Flags          *uint8
   283  	FragmentOffset *uint16
   284  	TTL            *uint8
   285  	Protocol       *uint8
   286  	Checksum       *uint16
   287  	SrcAddr        *tcpip.Address
   288  	DstAddr        *tcpip.Address
   289  	Options        *header.IPv4Options
   290  }
   291  
   292  func (l *IPv4) String() string {
   293  	return stringLayer(l)
   294  }
   295  
   296  // ToBytes implements Layer.ToBytes.
   297  func (l *IPv4) ToBytes() ([]byte, error) {
   298  	// An IPv4 header is variable length depending on the size of the Options.
   299  	hdrLen := header.IPv4MinimumSize
   300  	if l.Options != nil {
   301  		if len(*l.Options)%4 != 0 {
   302  			return nil, fmt.Errorf("invalid header options '%x (len=%d)'; must be 32 bit aligned", *l.Options, len(*l.Options))
   303  		}
   304  		hdrLen += len(*l.Options)
   305  		if hdrLen > header.IPv4MaximumHeaderSize {
   306  			return nil, fmt.Errorf("IPv4 Options %d bytes, Max %d", len(*l.Options), header.IPv4MaximumOptionsSize)
   307  		}
   308  	}
   309  	b := make([]byte, hdrLen)
   310  	h := header.IPv4(b)
   311  	fields := &header.IPv4Fields{
   312  		TOS:            0,
   313  		TotalLength:    0,
   314  		ID:             0,
   315  		Flags:          0,
   316  		FragmentOffset: 0,
   317  		TTL:            64,
   318  		Protocol:       0,
   319  		Checksum:       0,
   320  		SrcAddr:        tcpip.Address(""),
   321  		DstAddr:        tcpip.Address(""),
   322  		Options:        nil,
   323  	}
   324  	if l.TOS != nil {
   325  		fields.TOS = *l.TOS
   326  	}
   327  	if l.TotalLength != nil {
   328  		fields.TotalLength = *l.TotalLength
   329  	} else {
   330  		fields.TotalLength = uint16(l.length())
   331  		current := l.next()
   332  		for current != nil {
   333  			fields.TotalLength += uint16(current.length())
   334  			current = current.next()
   335  		}
   336  	}
   337  	if l.ID != nil {
   338  		fields.ID = *l.ID
   339  	}
   340  	if l.Flags != nil {
   341  		fields.Flags = *l.Flags
   342  	}
   343  	if l.FragmentOffset != nil {
   344  		fields.FragmentOffset = *l.FragmentOffset
   345  	}
   346  	if l.TTL != nil {
   347  		fields.TTL = *l.TTL
   348  	}
   349  	if l.Protocol != nil {
   350  		fields.Protocol = *l.Protocol
   351  	} else {
   352  		switch n := l.next().(type) {
   353  		case *TCP:
   354  			fields.Protocol = uint8(header.TCPProtocolNumber)
   355  		case *UDP:
   356  			fields.Protocol = uint8(header.UDPProtocolNumber)
   357  		case *ICMPv4:
   358  			fields.Protocol = uint8(header.ICMPv4ProtocolNumber)
   359  		default:
   360  			// We can add support for more protocols as needed.
   361  			return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n)
   362  		}
   363  	}
   364  	if l.SrcAddr != nil {
   365  		fields.SrcAddr = *l.SrcAddr
   366  	}
   367  	if l.DstAddr != nil {
   368  		fields.DstAddr = *l.DstAddr
   369  	}
   370  
   371  	h.Encode(fields)
   372  
   373  	// Put raw option bytes from test definition in header. Options as raw bytes
   374  	// allows us to serialize malformed options, which is not possible with
   375  	// the provided serialization functions.
   376  	if l.Options != nil {
   377  		h.SetHeaderLength(h.HeaderLength() + uint8(len(*l.Options)))
   378  		if got, want := copy(h.Options(), *l.Options), len(*l.Options); got != want {
   379  			return nil, fmt.Errorf("failed to copy option bytes into header, got %d want %d", got, want)
   380  		}
   381  	}
   382  
   383  	// Encode cannot set this incorrectly so we need to overwrite what it wrote
   384  	// in order to test handling of a bad IHL value.
   385  	if l.IHL != nil {
   386  		h.SetHeaderLength(*l.IHL)
   387  	}
   388  
   389  	if l.Checksum == nil {
   390  		h.SetChecksum(^h.CalculateChecksum())
   391  	} else {
   392  		h.SetChecksum(*l.Checksum)
   393  	}
   394  
   395  	return h, nil
   396  }
   397  
   398  // Uint16 is a helper routine that allocates a new
   399  // uint16 value to store v and returns a pointer to it.
   400  func Uint16(v uint16) *uint16 {
   401  	return &v
   402  }
   403  
   404  // Uint8 is a helper routine that allocates a new
   405  // uint8 value to store v and returns a pointer to it.
   406  func Uint8(v uint8) *uint8 {
   407  	return &v
   408  }
   409  
   410  // TCPFlags is a helper routine that allocates a new
   411  // header.TCPFlags value to store v and returns a pointer to it.
   412  func TCPFlags(v header.TCPFlags) *header.TCPFlags {
   413  	return &v
   414  }
   415  
   416  // Address is a helper routine that allocates a new tcpip.Address value to
   417  // store v and returns a pointer to it.
   418  func Address(v tcpip.Address) *tcpip.Address {
   419  	return &v
   420  }
   421  
   422  // parseIPv4 parses the bytes assuming that they start with an ipv4 header and
   423  // continues parsing further encapsulations.
   424  func parseIPv4(b []byte) (Layer, layerParser) {
   425  	h := header.IPv4(b)
   426  	options := h.Options()
   427  	tos, _ := h.TOS()
   428  	ipv4 := IPv4{
   429  		IHL:            Uint8(h.HeaderLength()),
   430  		TOS:            &tos,
   431  		TotalLength:    Uint16(h.TotalLength()),
   432  		ID:             Uint16(h.ID()),
   433  		Flags:          Uint8(h.Flags()),
   434  		FragmentOffset: Uint16(h.FragmentOffset()),
   435  		TTL:            Uint8(h.TTL()),
   436  		Protocol:       Uint8(h.Protocol()),
   437  		Checksum:       Uint16(h.Checksum()),
   438  		SrcAddr:        Address(h.SourceAddress()),
   439  		DstAddr:        Address(h.DestinationAddress()),
   440  		Options:        &options,
   441  	}
   442  	var nextParser layerParser
   443  	// If it is a fragment, don't treat it as having a transport protocol.
   444  	if h.FragmentOffset() != 0 || h.More() {
   445  		return &ipv4, parsePayload
   446  	}
   447  	switch h.TransportProtocol() {
   448  	case header.TCPProtocolNumber:
   449  		nextParser = parseTCP
   450  	case header.UDPProtocolNumber:
   451  		nextParser = parseUDP
   452  	case header.ICMPv4ProtocolNumber:
   453  		nextParser = parseICMPv4
   454  	default:
   455  		// Assume that the rest is a payload.
   456  		nextParser = parsePayload
   457  	}
   458  	return &ipv4, nextParser
   459  }
   460  
   461  func (l *IPv4) match(other Layer) bool {
   462  	return equalLayer(l, other)
   463  }
   464  
   465  func (l *IPv4) length() int {
   466  	if l.IHL == nil {
   467  		return header.IPv4MinimumSize
   468  	}
   469  	return int(*l.IHL)
   470  }
   471  
   472  // merge implements Layer.merge.
   473  func (l *IPv4) merge(other Layer) error {
   474  	return mergeLayer(l, other)
   475  }
   476  
   477  // IPv6 can construct and match an IPv6 encapsulation.
   478  type IPv6 struct {
   479  	LayerBase
   480  	TrafficClass  *uint8
   481  	FlowLabel     *uint32
   482  	PayloadLength *uint16
   483  	NextHeader    *uint8
   484  	HopLimit      *uint8
   485  	SrcAddr       *tcpip.Address
   486  	DstAddr       *tcpip.Address
   487  }
   488  
   489  func (l *IPv6) String() string {
   490  	return stringLayer(l)
   491  }
   492  
   493  // ToBytes implements Layer.ToBytes.
   494  func (l *IPv6) ToBytes() ([]byte, error) {
   495  	b := make([]byte, header.IPv6MinimumSize)
   496  	h := header.IPv6(b)
   497  	fields := &header.IPv6Fields{
   498  		HopLimit: 64,
   499  	}
   500  	if l.TrafficClass != nil {
   501  		fields.TrafficClass = *l.TrafficClass
   502  	}
   503  	if l.FlowLabel != nil {
   504  		fields.FlowLabel = *l.FlowLabel
   505  	}
   506  	if l.PayloadLength != nil {
   507  		fields.PayloadLength = *l.PayloadLength
   508  	} else {
   509  		for current := l.next(); current != nil; current = current.next() {
   510  			fields.PayloadLength += uint16(current.length())
   511  		}
   512  	}
   513  	if l.NextHeader != nil {
   514  		fields.TransportProtocol = tcpip.TransportProtocolNumber(*l.NextHeader)
   515  	} else {
   516  		nh, err := nextHeaderByLayer(l.next())
   517  		if err != nil {
   518  			return nil, err
   519  		}
   520  		fields.TransportProtocol = tcpip.TransportProtocolNumber(nh)
   521  	}
   522  	if l.HopLimit != nil {
   523  		fields.HopLimit = *l.HopLimit
   524  	}
   525  	if l.SrcAddr != nil {
   526  		fields.SrcAddr = *l.SrcAddr
   527  	}
   528  	if l.DstAddr != nil {
   529  		fields.DstAddr = *l.DstAddr
   530  	}
   531  	h.Encode(fields)
   532  	return h, nil
   533  }
   534  
   535  // nextIPv6PayloadParser finds the corresponding parser for nextHeader.
   536  func nextIPv6PayloadParser(nextHeader uint8) layerParser {
   537  	switch tcpip.TransportProtocolNumber(nextHeader) {
   538  	case header.TCPProtocolNumber:
   539  		return parseTCP
   540  	case header.UDPProtocolNumber:
   541  		return parseUDP
   542  	case header.ICMPv6ProtocolNumber:
   543  		return parseICMPv6
   544  	}
   545  	switch header.IPv6ExtensionHeaderIdentifier(nextHeader) {
   546  	case header.IPv6HopByHopOptionsExtHdrIdentifier:
   547  		return parseIPv6HopByHopOptionsExtHdr
   548  	case header.IPv6DestinationOptionsExtHdrIdentifier:
   549  		return parseIPv6DestinationOptionsExtHdr
   550  	case header.IPv6FragmentExtHdrIdentifier:
   551  		return parseIPv6FragmentExtHdr
   552  	}
   553  	return parsePayload
   554  }
   555  
   556  // parseIPv6 parses the bytes assuming that they start with an ipv6 header and
   557  // continues parsing further encapsulations.
   558  func parseIPv6(b []byte) (Layer, layerParser) {
   559  	h := header.IPv6(b)
   560  	tos, flowLabel := h.TOS()
   561  	ipv6 := IPv6{
   562  		TrafficClass:  &tos,
   563  		FlowLabel:     &flowLabel,
   564  		PayloadLength: Uint16(h.PayloadLength()),
   565  		NextHeader:    Uint8(h.NextHeader()),
   566  		HopLimit:      Uint8(h.HopLimit()),
   567  		SrcAddr:       Address(h.SourceAddress()),
   568  		DstAddr:       Address(h.DestinationAddress()),
   569  	}
   570  	nextParser := nextIPv6PayloadParser(h.NextHeader())
   571  	return &ipv6, nextParser
   572  }
   573  
   574  func (l *IPv6) match(other Layer) bool {
   575  	return equalLayer(l, other)
   576  }
   577  
   578  func (l *IPv6) length() int {
   579  	return header.IPv6MinimumSize
   580  }
   581  
   582  // merge overrides the values in l with the values from other but only in fields
   583  // where the value is not nil.
   584  func (l *IPv6) merge(other Layer) error {
   585  	return mergeLayer(l, other)
   586  }
   587  
   588  // IPv6HopByHopOptionsExtHdr can construct and match an IPv6HopByHopOptions
   589  // Extension Header.
   590  type IPv6HopByHopOptionsExtHdr struct {
   591  	LayerBase
   592  	NextHeader *header.IPv6ExtensionHeaderIdentifier
   593  	Options    []byte
   594  }
   595  
   596  // IPv6DestinationOptionsExtHdr can construct and match an IPv6DestinationOptions
   597  // Extension Header.
   598  type IPv6DestinationOptionsExtHdr struct {
   599  	LayerBase
   600  	NextHeader *header.IPv6ExtensionHeaderIdentifier
   601  	Options    []byte
   602  }
   603  
   604  // IPv6FragmentExtHdr can construct and match an IPv6 Fragment Extension Header.
   605  type IPv6FragmentExtHdr struct {
   606  	LayerBase
   607  	NextHeader     *header.IPv6ExtensionHeaderIdentifier
   608  	FragmentOffset *uint16
   609  	MoreFragments  *bool
   610  	Identification *uint32
   611  }
   612  
   613  // nextHeaderByLayer finds the correct next header protocol value for layer l.
   614  func nextHeaderByLayer(l Layer) (uint8, error) {
   615  	if l == nil {
   616  		return uint8(header.IPv6NoNextHeaderIdentifier), nil
   617  	}
   618  	switch l.(type) {
   619  	case *TCP:
   620  		return uint8(header.TCPProtocolNumber), nil
   621  	case *UDP:
   622  		return uint8(header.UDPProtocolNumber), nil
   623  	case *ICMPv6:
   624  		return uint8(header.ICMPv6ProtocolNumber), nil
   625  	case *Payload:
   626  		return uint8(header.IPv6NoNextHeaderIdentifier), nil
   627  	case *IPv6HopByHopOptionsExtHdr:
   628  		return uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), nil
   629  	case *IPv6DestinationOptionsExtHdr:
   630  		return uint8(header.IPv6DestinationOptionsExtHdrIdentifier), nil
   631  	case *IPv6FragmentExtHdr:
   632  		return uint8(header.IPv6FragmentExtHdrIdentifier), nil
   633  	default:
   634  		// TODO(b/161005083): Support more protocols as needed.
   635  		return 0, fmt.Errorf("failed to deduce the IPv6 header's next protocol: %T", l)
   636  	}
   637  }
   638  
   639  // ipv6OptionsExtHdrToBytes serializes an options extension header into bytes.
   640  func ipv6OptionsExtHdrToBytes(nextHeader *header.IPv6ExtensionHeaderIdentifier, nextLayer Layer, options []byte) ([]byte, error) {
   641  	length := len(options) + 2
   642  	if length%8 != 0 {
   643  		return nil, fmt.Errorf("IPv6 extension headers must be a multiple of 8 octets long, but the length given: %d, options: %s", length, hex.Dump(options))
   644  	}
   645  	bytes := make([]byte, length)
   646  	if nextHeader != nil {
   647  		bytes[0] = byte(*nextHeader)
   648  	} else {
   649  		nh, err := nextHeaderByLayer(nextLayer)
   650  		if err != nil {
   651  			return nil, err
   652  		}
   653  		bytes[0] = nh
   654  	}
   655  	// ExtHdrLen field is the length of the extension header
   656  	// in 8-octet unit, ignoring the first 8 octets.
   657  	// https://tools.ietf.org/html/rfc2460#section-4.3
   658  	// https://tools.ietf.org/html/rfc2460#section-4.6
   659  	bytes[1] = uint8((length - 8) / 8)
   660  	copy(bytes[2:], options)
   661  	return bytes, nil
   662  }
   663  
   664  // IPv6ExtHdrIdent is a helper routine that allocates a new
   665  // header.IPv6ExtensionHeaderIdentifier value to store v and returns a pointer
   666  // to it.
   667  func IPv6ExtHdrIdent(id header.IPv6ExtensionHeaderIdentifier) *header.IPv6ExtensionHeaderIdentifier {
   668  	return &id
   669  }
   670  
   671  // ToBytes implements Layer.ToBytes.
   672  func (l *IPv6HopByHopOptionsExtHdr) ToBytes() ([]byte, error) {
   673  	return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options)
   674  }
   675  
   676  // ToBytes implements Layer.ToBytes.
   677  func (l *IPv6DestinationOptionsExtHdr) ToBytes() ([]byte, error) {
   678  	return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options)
   679  }
   680  
   681  // ToBytes implements Layer.ToBytes.
   682  func (l *IPv6FragmentExtHdr) ToBytes() ([]byte, error) {
   683  	var offset, mflag uint16
   684  	var ident uint32
   685  	bytes := make([]byte, header.IPv6FragmentExtHdrLength)
   686  	if l.NextHeader != nil {
   687  		bytes[0] = byte(*l.NextHeader)
   688  	} else {
   689  		nh, err := nextHeaderByLayer(l.next())
   690  		if err != nil {
   691  			return nil, err
   692  		}
   693  		bytes[0] = nh
   694  	}
   695  	bytes[1] = 0 // reserved
   696  	if l.MoreFragments != nil && *l.MoreFragments {
   697  		mflag = 1
   698  	}
   699  	if l.FragmentOffset != nil {
   700  		offset = *l.FragmentOffset
   701  	}
   702  	if l.Identification != nil {
   703  		ident = *l.Identification
   704  	}
   705  	offsetAndMflag := offset<<3 | mflag
   706  	binary.BigEndian.PutUint16(bytes[2:], offsetAndMflag)
   707  	binary.BigEndian.PutUint32(bytes[4:], ident)
   708  
   709  	return bytes, nil
   710  }
   711  
   712  // parseIPv6ExtHdr parses an IPv6 extension header and returns the NextHeader
   713  // field, the rest of the payload and a parser function for the corresponding
   714  // next extension header.
   715  func parseIPv6ExtHdr(b []byte) (header.IPv6ExtensionHeaderIdentifier, []byte, layerParser) {
   716  	nextHeader := b[0]
   717  	// For HopByHop and Destination options extension headers,
   718  	// This field is the length of the extension header in
   719  	// 8-octet units, not including the first 8 octets.
   720  	// https://tools.ietf.org/html/rfc2460#section-4.3
   721  	// https://tools.ietf.org/html/rfc2460#section-4.6
   722  	length := b[1]*8 + 8
   723  	data := b[2:length]
   724  	nextParser := nextIPv6PayloadParser(nextHeader)
   725  	return header.IPv6ExtensionHeaderIdentifier(nextHeader), data, nextParser
   726  }
   727  
   728  // parseIPv6HopByHopOptionsExtHdr parses the bytes assuming that they start
   729  // with an IPv6 HopByHop Options Extension Header.
   730  func parseIPv6HopByHopOptionsExtHdr(b []byte) (Layer, layerParser) {
   731  	nextHeader, options, nextParser := parseIPv6ExtHdr(b)
   732  	return &IPv6HopByHopOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser
   733  }
   734  
   735  // parseIPv6DestinationOptionsExtHdr parses the bytes assuming that they start
   736  // with an IPv6 Destination Options Extension Header.
   737  func parseIPv6DestinationOptionsExtHdr(b []byte) (Layer, layerParser) {
   738  	nextHeader, options, nextParser := parseIPv6ExtHdr(b)
   739  	return &IPv6DestinationOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser
   740  }
   741  
   742  // Bool is a helper routine that allocates a new
   743  // bool value to store v and returns a pointer to it.
   744  func Bool(v bool) *bool {
   745  	return &v
   746  }
   747  
   748  // parseIPv6FragmentExtHdr parses the bytes assuming that they start
   749  // with an IPv6 Fragment Extension Header.
   750  func parseIPv6FragmentExtHdr(b []byte) (Layer, layerParser) {
   751  	nextHeader := b[0]
   752  	var extHdr header.IPv6FragmentExtHdr
   753  	copy(extHdr[:], b[2:])
   754  	fragLayer := IPv6FragmentExtHdr{
   755  		NextHeader:     IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(nextHeader)),
   756  		FragmentOffset: Uint16(extHdr.FragmentOffset()),
   757  		MoreFragments:  Bool(extHdr.More()),
   758  		Identification: Uint32(extHdr.ID()),
   759  	}
   760  	// If it is a fragment, we can't interpret it.
   761  	if extHdr.FragmentOffset() != 0 || extHdr.More() {
   762  		return &fragLayer, parsePayload
   763  	}
   764  	return &fragLayer, nextIPv6PayloadParser(nextHeader)
   765  }
   766  
   767  func (l *IPv6HopByHopOptionsExtHdr) length() int {
   768  	return len(l.Options) + 2
   769  }
   770  
   771  func (l *IPv6HopByHopOptionsExtHdr) match(other Layer) bool {
   772  	return equalLayer(l, other)
   773  }
   774  
   775  // merge overrides the values in l with the values from other but only in fields
   776  // where the value is not nil.
   777  func (l *IPv6HopByHopOptionsExtHdr) merge(other Layer) error {
   778  	return mergeLayer(l, other)
   779  }
   780  
   781  func (l *IPv6HopByHopOptionsExtHdr) String() string {
   782  	return stringLayer(l)
   783  }
   784  
   785  func (l *IPv6DestinationOptionsExtHdr) length() int {
   786  	return len(l.Options) + 2
   787  }
   788  
   789  func (l *IPv6DestinationOptionsExtHdr) match(other Layer) bool {
   790  	return equalLayer(l, other)
   791  }
   792  
   793  // merge overrides the values in l with the values from other but only in fields
   794  // where the value is not nil.
   795  func (l *IPv6DestinationOptionsExtHdr) merge(other Layer) error {
   796  	return mergeLayer(l, other)
   797  }
   798  
   799  func (l *IPv6DestinationOptionsExtHdr) String() string {
   800  	return stringLayer(l)
   801  }
   802  
   803  func (*IPv6FragmentExtHdr) length() int {
   804  	return header.IPv6FragmentExtHdrLength
   805  }
   806  
   807  func (l *IPv6FragmentExtHdr) match(other Layer) bool {
   808  	return equalLayer(l, other)
   809  }
   810  
   811  // merge overrides the values in l with the values from other but only in fields
   812  // where the value is not nil.
   813  func (l *IPv6FragmentExtHdr) merge(other Layer) error {
   814  	return mergeLayer(l, other)
   815  }
   816  
   817  func (l *IPv6FragmentExtHdr) String() string {
   818  	return stringLayer(l)
   819  }
   820  
   821  // ICMPv6 can construct and match an ICMPv6 encapsulation.
   822  type ICMPv6 struct {
   823  	LayerBase
   824  	Type     *header.ICMPv6Type
   825  	Code     *header.ICMPv6Code
   826  	Checksum *uint16
   827  	Ident    *uint16 // Only in Echo Request/Reply.
   828  	Pointer  *uint32 // Only in Parameter Problem.
   829  	Payload  []byte
   830  }
   831  
   832  func (l *ICMPv6) String() string {
   833  	// TODO(eyalsoha): Do something smarter here when *l.Type is ParameterProblem?
   834  	// We could parse the contents of the Payload as if it were an IPv6 packet.
   835  	return stringLayer(l)
   836  }
   837  
   838  // ToBytes implements Layer.ToBytes.
   839  func (l *ICMPv6) ToBytes() ([]byte, error) {
   840  	b := make([]byte, header.ICMPv6MinimumSize+len(l.Payload))
   841  	h := header.ICMPv6(b)
   842  	if l.Type != nil {
   843  		h.SetType(*l.Type)
   844  	}
   845  	if l.Code != nil {
   846  		h.SetCode(*l.Code)
   847  	}
   848  	if n := copy(h.Payload(), l.Payload); n != len(l.Payload) {
   849  		panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, len(l.Payload)))
   850  	}
   851  	typ := h.Type()
   852  	switch typ {
   853  	case header.ICMPv6EchoRequest, header.ICMPv6EchoReply:
   854  		if l.Ident != nil {
   855  			h.SetIdent(*l.Ident)
   856  		}
   857  	case header.ICMPv6ParamProblem:
   858  		if l.Pointer != nil {
   859  			h.SetTypeSpecific(*l.Pointer)
   860  		}
   861  	}
   862  	if l.Checksum != nil {
   863  		h.SetChecksum(*l.Checksum)
   864  	} else {
   865  		// It is possible that the ICMPv6 header does not follow the IPv6 header
   866  		// immediately, there could be one or more extension headers in between.
   867  		// We need to search backwards to find the IPv6 header.
   868  		for layer := l.Prev(); layer != nil; layer = layer.Prev() {
   869  			if ipv6, ok := layer.(*IPv6); ok {
   870  				h.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
   871  					Header:      h[:header.ICMPv6PayloadOffset],
   872  					Src:         *ipv6.SrcAddr,
   873  					Dst:         *ipv6.DstAddr,
   874  					PayloadCsum: header.Checksum(l.Payload, 0 /* initial */),
   875  					PayloadLen:  len(l.Payload),
   876  				}))
   877  				break
   878  			}
   879  		}
   880  	}
   881  	return h, nil
   882  }
   883  
   884  // ICMPv6Type is a helper routine that allocates a new ICMPv6Type value to store
   885  // v and returns a pointer to it.
   886  func ICMPv6Type(v header.ICMPv6Type) *header.ICMPv6Type {
   887  	return &v
   888  }
   889  
   890  // ICMPv6Code is a helper routine that allocates a new ICMPv6Type value to store
   891  // v and returns a pointer to it.
   892  func ICMPv6Code(v header.ICMPv6Code) *header.ICMPv6Code {
   893  	return &v
   894  }
   895  
   896  // parseICMPv6 parses the bytes assuming that they start with an ICMPv6 header.
   897  func parseICMPv6(b []byte) (Layer, layerParser) {
   898  	h := header.ICMPv6(b)
   899  	msgType := h.Type()
   900  	icmpv6 := ICMPv6{
   901  		Type:     ICMPv6Type(msgType),
   902  		Code:     ICMPv6Code(h.Code()),
   903  		Checksum: Uint16(h.Checksum()),
   904  		Payload:  h.Payload(),
   905  	}
   906  	switch msgType {
   907  	case header.ICMPv6EchoRequest, header.ICMPv6EchoReply:
   908  		icmpv6.Ident = Uint16(h.Ident())
   909  	case header.ICMPv6ParamProblem:
   910  		icmpv6.Pointer = Uint32(h.TypeSpecific())
   911  	}
   912  	return &icmpv6, nil
   913  }
   914  
   915  func (l *ICMPv6) match(other Layer) bool {
   916  	return equalLayer(l, other)
   917  }
   918  
   919  func (l *ICMPv6) length() int {
   920  	return header.ICMPv6MinimumSize + len(l.Payload)
   921  }
   922  
   923  // merge overrides the values in l with the values from other but only in fields
   924  // where the value is not nil.
   925  func (l *ICMPv6) merge(other Layer) error {
   926  	return mergeLayer(l, other)
   927  }
   928  
// ICMPv4 can construct and match an ICMPv4 encapsulation.
//
// Every field is optional. When serializing, a nil Type or Code is encoded as
// zero, Ident and Sequence are written only for Echo and Echo Reply messages,
// Pointer only for Parameter Problem messages, and a nil Checksum is computed
// over the whole serialized ICMPv4 message. When parsing, Ident, Sequence and
// Pointer are filled in only for those same message types.
type ICMPv4 struct {
	LayerBase
	Type     *header.ICMPv4Type
	Code     *header.ICMPv4Code
	Checksum *uint16
	Ident    *uint16 // Only in Echo Request/Reply.
	Sequence *uint16 // Only in Echo Request/Reply.
	Pointer  *uint8  // Only in Parameter Problem.
	Payload  []byte
}
   940  
   941  func (l *ICMPv4) String() string {
   942  	return stringLayer(l)
   943  }
   944  
   945  // ICMPv4Type is a helper routine that allocates a new header.ICMPv4Type value
   946  // to store t and returns a pointer to it.
   947  func ICMPv4Type(t header.ICMPv4Type) *header.ICMPv4Type {
   948  	return &t
   949  }
   950  
   951  // ICMPv4Code is a helper routine that allocates a new header.ICMPv4Code value
   952  // to store t and returns a pointer to it.
   953  func ICMPv4Code(t header.ICMPv4Code) *header.ICMPv4Code {
   954  	return &t
   955  }
   956  
   957  // ToBytes implements Layer.ToBytes.
   958  func (l *ICMPv4) ToBytes() ([]byte, error) {
   959  	b := make([]byte, header.ICMPv4MinimumSize+len(l.Payload))
   960  	h := header.ICMPv4(b)
   961  	if l.Type != nil {
   962  		h.SetType(*l.Type)
   963  	}
   964  	if l.Code != nil {
   965  		h.SetCode(*l.Code)
   966  	}
   967  	if n := copy(h.Payload(), l.Payload); n != len(l.Payload) {
   968  		panic(fmt.Sprintf("wrong number of bytes copied into h.Payload(): got = %d, want = %d", n, len(l.Payload)))
   969  	}
   970  	typ := h.Type()
   971  	switch typ {
   972  	case header.ICMPv4EchoReply, header.ICMPv4Echo:
   973  		if l.Ident != nil {
   974  			h.SetIdent(*l.Ident)
   975  		}
   976  		if l.Sequence != nil {
   977  			h.SetSequence(*l.Sequence)
   978  		}
   979  	case header.ICMPv4ParamProblem:
   980  		if l.Pointer != nil {
   981  			h.SetPointer(*l.Pointer)
   982  		}
   983  	}
   984  
   985  	// The checksum must be handled last because the ICMPv4 header fields are
   986  	// included in the computation.
   987  	if l.Checksum != nil {
   988  		h.SetChecksum(*l.Checksum)
   989  	} else {
   990  		h.SetChecksum(^header.Checksum(h, 0))
   991  	}
   992  
   993  	return h, nil
   994  }
   995  
   996  // parseICMPv4 parses the bytes as an ICMPv4 header, returning a Layer and a
   997  // parser for the encapsulated payload.
   998  func parseICMPv4(b []byte) (Layer, layerParser) {
   999  	h := header.ICMPv4(b)
  1000  
  1001  	msgType := h.Type()
  1002  	icmpv4 := ICMPv4{
  1003  		Type:     ICMPv4Type(msgType),
  1004  		Code:     ICMPv4Code(h.Code()),
  1005  		Checksum: Uint16(h.Checksum()),
  1006  		Payload:  h.Payload(),
  1007  	}
  1008  	switch msgType {
  1009  	case header.ICMPv4EchoReply, header.ICMPv4Echo:
  1010  		icmpv4.Ident = Uint16(h.Ident())
  1011  		icmpv4.Sequence = Uint16(h.Sequence())
  1012  	case header.ICMPv4ParamProblem:
  1013  		icmpv4.Pointer = Uint8(h.Pointer())
  1014  	}
  1015  	return &icmpv4, nil
  1016  }
  1017  
  1018  func (l *ICMPv4) match(other Layer) bool {
  1019  	return equalLayer(l, other)
  1020  }
  1021  
  1022  func (l *ICMPv4) length() int {
  1023  	return header.ICMPv4MinimumSize + len(l.Payload)
  1024  }
  1025  
  1026  // merge overrides the values in l with the values from other but only in fields
  1027  // where the value is not nil.
  1028  func (l *ICMPv4) merge(other Layer) error {
  1029  	return mergeLayer(l, other)
  1030  }
  1031  
// TCP can construct and match a TCP encapsulation.
//
// Every field is optional. When serializing, nil fields are encoded as zero
// except that DataOffset defaults to the header length (fixed header plus
// Options padded to a 4-byte boundary), WindowSize defaults to 32768, and
// Checksum is computed from the immediately preceding IPv4/IPv6 layer's
// pseudo-header. Options holds the raw option bytes that follow the fixed
// header.
type TCP struct {
	LayerBase
	SrcPort       *uint16
	DstPort       *uint16
	SeqNum        *uint32
	AckNum        *uint32
	DataOffset    *uint8
	Flags         *header.TCPFlags
	WindowSize    *uint16
	Checksum      *uint16
	UrgentPointer *uint16
	Options       []byte
}
  1046  
  1047  func (l *TCP) String() string {
  1048  	return stringLayer(l)
  1049  }
  1050  
  1051  // ToBytes implements Layer.ToBytes.
  1052  func (l *TCP) ToBytes() ([]byte, error) {
  1053  	b := make([]byte, l.length())
  1054  	h := header.TCP(b)
  1055  	if l.SrcPort != nil {
  1056  		h.SetSourcePort(*l.SrcPort)
  1057  	}
  1058  	if l.DstPort != nil {
  1059  		h.SetDestinationPort(*l.DstPort)
  1060  	}
  1061  	if l.SeqNum != nil {
  1062  		h.SetSequenceNumber(*l.SeqNum)
  1063  	}
  1064  	if l.AckNum != nil {
  1065  		h.SetAckNumber(*l.AckNum)
  1066  	}
  1067  	if l.DataOffset != nil {
  1068  		h.SetDataOffset(*l.DataOffset)
  1069  	} else {
  1070  		h.SetDataOffset(uint8(l.length()))
  1071  	}
  1072  	if l.Flags != nil {
  1073  		h.SetFlags(uint8(*l.Flags))
  1074  	}
  1075  	if l.WindowSize != nil {
  1076  		h.SetWindowSize(*l.WindowSize)
  1077  	} else {
  1078  		h.SetWindowSize(32768)
  1079  	}
  1080  	if l.UrgentPointer != nil {
  1081  		h.SetUrgentPoiner(*l.UrgentPointer)
  1082  	}
  1083  	copy(b[header.TCPMinimumSize:], l.Options)
  1084  	header.AddTCPOptionPadding(b[header.TCPMinimumSize:], len(l.Options))
  1085  	if l.Checksum != nil {
  1086  		h.SetChecksum(*l.Checksum)
  1087  		return h, nil
  1088  	}
  1089  	if err := setTCPChecksum(&h, l); err != nil {
  1090  		return nil, err
  1091  	}
  1092  	return h, nil
  1093  }
  1094  
  1095  // totalLength returns the length of the provided layer and all following
  1096  // layers.
  1097  func totalLength(l Layer) int {
  1098  	var totalLength int
  1099  	for ; l != nil; l = l.next() {
  1100  		totalLength += l.length()
  1101  	}
  1102  	return totalLength
  1103  }
  1104  
  1105  // payload returns a buffer.VectorisedView of l's payload.
  1106  func payload(l Layer) (buffer.VectorisedView, error) {
  1107  	var payloadBytes buffer.VectorisedView
  1108  	for current := l.next(); current != nil; current = current.next() {
  1109  		payload, err := current.ToBytes()
  1110  		if err != nil {
  1111  			return buffer.VectorisedView{}, fmt.Errorf("can't get bytes for next header: %s", payload)
  1112  		}
  1113  		payloadBytes.AppendView(payload)
  1114  	}
  1115  	return payloadBytes, nil
  1116  }
  1117  
  1118  // layerChecksum calculates the checksum of the Layer header, including the
  1119  // peusdeochecksum of the layer before it and all the bytes after it.
  1120  func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) {
  1121  	totalLength := uint16(totalLength(l))
  1122  	var xsum uint16
  1123  	switch p := l.Prev().(type) {
  1124  	case *IPv4:
  1125  		xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength)
  1126  	case *IPv6:
  1127  		xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength)
  1128  	default:
  1129  		// TODO(b/161246171): Support more protocols.
  1130  		return 0, fmt.Errorf("checksum for protocol %d is not supported when previous layer is %T", protoNumber, p)
  1131  	}
  1132  	payloadBytes, err := payload(l)
  1133  	if err != nil {
  1134  		return 0, err
  1135  	}
  1136  	xsum = header.ChecksumVV(payloadBytes, xsum)
  1137  	return xsum, nil
  1138  }
  1139  
  1140  // setTCPChecksum calculates the checksum of the TCP header and sets it in h.
  1141  func setTCPChecksum(h *header.TCP, tcp *TCP) error {
  1142  	h.SetChecksum(0)
  1143  	xsum, err := layerChecksum(tcp, header.TCPProtocolNumber)
  1144  	if err != nil {
  1145  		return err
  1146  	}
  1147  	h.SetChecksum(^h.CalculateChecksum(xsum))
  1148  	return nil
  1149  }
  1150  
  1151  // Uint32 is a helper routine that allocates a new
  1152  // uint32 value to store v and returns a pointer to it.
  1153  func Uint32(v uint32) *uint32 {
  1154  	return &v
  1155  }
  1156  
  1157  // parseTCP parses the bytes assuming that they start with a tcp header and
  1158  // continues parsing further encapsulations.
  1159  func parseTCP(b []byte) (Layer, layerParser) {
  1160  	h := header.TCP(b)
  1161  	tcp := TCP{
  1162  		SrcPort:       Uint16(h.SourcePort()),
  1163  		DstPort:       Uint16(h.DestinationPort()),
  1164  		SeqNum:        Uint32(h.SequenceNumber()),
  1165  		AckNum:        Uint32(h.AckNumber()),
  1166  		DataOffset:    Uint8(h.DataOffset()),
  1167  		Flags:         TCPFlags(h.Flags()),
  1168  		WindowSize:    Uint16(h.WindowSize()),
  1169  		Checksum:      Uint16(h.Checksum()),
  1170  		UrgentPointer: Uint16(h.UrgentPointer()),
  1171  		Options:       b[header.TCPMinimumSize:h.DataOffset()],
  1172  	}
  1173  	return &tcp, parsePayload
  1174  }
  1175  
  1176  func (l *TCP) match(other Layer) bool {
  1177  	return equalLayer(l, other)
  1178  }
  1179  
  1180  func (l *TCP) length() int {
  1181  	if l.DataOffset == nil {
  1182  		// TCP header including the options must end on a 32-bit
  1183  		// boundary; the user could potentially give us a slice
  1184  		// whose length is not a multiple of 4 bytes, so we have
  1185  		// to do the alignment here.
  1186  		optlen := (len(l.Options) + 3) & ^3
  1187  		return header.TCPMinimumSize + optlen
  1188  	}
  1189  	return int(*l.DataOffset)
  1190  }
  1191  
  1192  // merge implements Layer.merge.
  1193  func (l *TCP) merge(other Layer) error {
  1194  	return mergeLayer(l, other)
  1195  }
  1196  
// UDP can construct and match a UDP encapsulation.
//
// Every field is optional. When serializing, nil ports are encoded as zero, a
// nil Length defaults to the length of this header plus every following
// layer, and a nil Checksum is computed from the immediately preceding
// IPv4/IPv6 layer's pseudo-header. The UDP payload is carried by the following
// layers, not by this struct.
type UDP struct {
	LayerBase
	SrcPort  *uint16
	DstPort  *uint16
	Length   *uint16
	Checksum *uint16
}
  1205  
  1206  func (l *UDP) String() string {
  1207  	return stringLayer(l)
  1208  }
  1209  
  1210  // ToBytes implements Layer.ToBytes.
  1211  func (l *UDP) ToBytes() ([]byte, error) {
  1212  	b := make([]byte, header.UDPMinimumSize)
  1213  	h := header.UDP(b)
  1214  	if l.SrcPort != nil {
  1215  		h.SetSourcePort(*l.SrcPort)
  1216  	}
  1217  	if l.DstPort != nil {
  1218  		h.SetDestinationPort(*l.DstPort)
  1219  	}
  1220  	if l.Length != nil {
  1221  		h.SetLength(*l.Length)
  1222  	} else {
  1223  		h.SetLength(uint16(totalLength(l)))
  1224  	}
  1225  	if l.Checksum != nil {
  1226  		h.SetChecksum(*l.Checksum)
  1227  		return h, nil
  1228  	}
  1229  	if err := setUDPChecksum(&h, l); err != nil {
  1230  		return nil, err
  1231  	}
  1232  	return h, nil
  1233  }
  1234  
  1235  // setUDPChecksum calculates the checksum of the UDP header and sets it in h.
  1236  func setUDPChecksum(h *header.UDP, udp *UDP) error {
  1237  	h.SetChecksum(0)
  1238  	xsum, err := layerChecksum(udp, header.UDPProtocolNumber)
  1239  	if err != nil {
  1240  		return err
  1241  	}
  1242  	h.SetChecksum(^h.CalculateChecksum(xsum))
  1243  	return nil
  1244  }
  1245  
  1246  // parseUDP parses the bytes assuming that they start with a udp header and
  1247  // returns the parsed layer and the next parser to use.
  1248  func parseUDP(b []byte) (Layer, layerParser) {
  1249  	h := header.UDP(b)
  1250  	udp := UDP{
  1251  		SrcPort:  Uint16(h.SourcePort()),
  1252  		DstPort:  Uint16(h.DestinationPort()),
  1253  		Length:   Uint16(h.Length()),
  1254  		Checksum: Uint16(h.Checksum()),
  1255  	}
  1256  	return &udp, parsePayload
  1257  }
  1258  
  1259  func (l *UDP) match(other Layer) bool {
  1260  	return equalLayer(l, other)
  1261  }
  1262  
  1263  func (l *UDP) length() int {
  1264  	return header.UDPMinimumSize
  1265  }
  1266  
  1267  // merge implements Layer.merge.
  1268  func (l *UDP) merge(other Layer) error {
  1269  	return mergeLayer(l, other)
  1270  }
  1271  
// Payload has bytes beyond OSI layer 4. It is always the innermost layer: its
// Bytes are returned verbatim by ToBytes, and parsing a payload never yields a
// further layer.
type Payload struct {
	LayerBase
	Bytes []byte
}
  1277  
  1278  func (l *Payload) String() string {
  1279  	return stringLayer(l)
  1280  }
  1281  
  1282  // parsePayload parses the bytes assuming that they start with a payload and
  1283  // continue to the end. There can be no further encapsulations.
  1284  func parsePayload(b []byte) (Layer, layerParser) {
  1285  	payload := Payload{
  1286  		Bytes: b,
  1287  	}
  1288  	return &payload, nil
  1289  }
  1290  
  1291  // ToBytes implements Layer.ToBytes.
  1292  func (l *Payload) ToBytes() ([]byte, error) {
  1293  	return l.Bytes, nil
  1294  }
  1295  
  1296  // Length returns payload byte length.
  1297  func (l *Payload) Length() int {
  1298  	return l.length()
  1299  }
  1300  
  1301  func (l *Payload) match(other Layer) bool {
  1302  	return equalLayer(l, other)
  1303  }
  1304  
  1305  func (l *Payload) length() int {
  1306  	return len(l.Bytes)
  1307  }
  1308  
  1309  // merge implements Layer.merge.
  1310  func (l *Payload) merge(other Layer) error {
  1311  	return mergeLayer(l, other)
  1312  }
  1313  
// Layers is an ordered list of Layer values, outermost encapsulation first,
// that supports operations similar to those of a single Layer (ToBytes,
// match, merge, diff). Elements may be nil; diff treats a nil element as
// matching anything.
type Layers []Layer
  1316  
  1317  // linkLayers sets the linked-list ponters in ls.
  1318  func (ls *Layers) linkLayers() {
  1319  	for i, l := range *ls {
  1320  		if i > 0 {
  1321  			l.setPrev((*ls)[i-1])
  1322  		} else {
  1323  			l.setPrev(nil)
  1324  		}
  1325  		if i+1 < len(*ls) {
  1326  			l.setNext((*ls)[i+1])
  1327  		} else {
  1328  			l.setNext(nil)
  1329  		}
  1330  	}
  1331  }
  1332  
  1333  // ToBytes converts the Layers into bytes. It creates a linked list of the Layer
  1334  // structs and then concatentates the output of ToBytes on each Layer.
  1335  func (ls *Layers) ToBytes() ([]byte, error) {
  1336  	ls.linkLayers()
  1337  	outBytes := []byte{}
  1338  	for _, l := range *ls {
  1339  		layerBytes, err := l.ToBytes()
  1340  		if err != nil {
  1341  			return nil, err
  1342  		}
  1343  		outBytes = append(outBytes, layerBytes...)
  1344  	}
  1345  	return outBytes, nil
  1346  }
  1347  
  1348  func (ls *Layers) match(other Layers) bool {
  1349  	if len(*ls) > len(other) {
  1350  		return false
  1351  	}
  1352  	for i, l := range *ls {
  1353  		if !equalLayer(l, other[i]) {
  1354  			return false
  1355  		}
  1356  	}
  1357  	return true
  1358  }
  1359  
// layerDiff stores the diffs for each field along with the label for the Layer.
// A nil rows means that there was no diff. A non-nil rows, even an empty one,
// marks a difference; Layers.diff relies on this distinction. When rendered,
// an entry with no rows is printed as just its label in parentheses.
type layerDiff struct {
	label string
	rows  []layerDiffRow
}
  1366  
// layerDiffRow stores one field name and that field's rendered value in the got
// and want layers. A nil field value is stored as the empty string; otherwise
// the value is rendered with fmt.Sprint.
type layerDiffRow struct {
	field, got, want string
}
  1372  
  1373  // diffLayer extracts all differing fields between two layers.
  1374  func diffLayer(got, want Layer) []layerDiffRow {
  1375  	vGot := reflect.ValueOf(got).Elem()
  1376  	vWant := reflect.ValueOf(want).Elem()
  1377  	if vGot.Type() != vWant.Type() {
  1378  		return nil
  1379  	}
  1380  	t := vGot.Type()
  1381  	var result []layerDiffRow
  1382  	for i := 0; i < t.NumField(); i++ {
  1383  		t := t.Field(i)
  1384  		if t.Anonymous {
  1385  			// Ignore the LayerBase in the Layer struct.
  1386  			continue
  1387  		}
  1388  		vGot := vGot.Field(i)
  1389  		vWant := vWant.Field(i)
  1390  		gotString := ""
  1391  		if !vGot.IsNil() {
  1392  			gotString = fmt.Sprint(reflect.Indirect(vGot))
  1393  		}
  1394  		wantString := ""
  1395  		if !vWant.IsNil() {
  1396  			wantString = fmt.Sprint(reflect.Indirect(vWant))
  1397  		}
  1398  		result = append(result, layerDiffRow{t.Name, gotString, wantString})
  1399  	}
  1400  	return result
  1401  }
  1402  
  1403  // layerType returns a concise string describing the type of the Layer, like
  1404  // "TCP", or "IPv6".
  1405  func layerType(l Layer) string {
  1406  	return reflect.TypeOf(l).Elem().Name()
  1407  }
  1408  
// diff compares Layers and returns a human-readable representation of the
// difference. Each Layer in the Layers is pairwise compared with the Layer at
// the same index in other:
//   - If one list is longer, the unpaired Layers are treated as matching
//     ("missing matches X" / "X matches missing").
//   - If an element in either list is nil, it is considered a match with the
//     other Layer.
//   - If two Layers have differing types, they don't match regardless of the
//     contents, and every non-nil field of both sides is listed.
//   - If two Layers have the same type, their fields are pairwise compared
//     using diffLayer. Fields that are nil on either side always match; two
//     non-nil fields match only if their fmt.Sprint renderings are equal.
//
// diff returns an empty string if and only if no pair produced a mismatch.
// Otherwise every pair is reported, one entry per pair (three entries for a
// type mismatch). Matching pairs are printed as "(label)" lines, and mismatches
// as right-aligned columns of label, field, got value and want value.
//
// NOTE(review): Layers.match returns false when *ls is longer than other, but
// diff treats the extra Layers in *ls as matching ("X matches missing"), so it
// can return "" for Layers that match rejects; confirm which is intended.
func (ls *Layers) diff(other Layers) string {
	var allDiffs []layerDiff
	// Check the cases where one list is longer than the other, where one or both
	// elements are nil, where the sides have different types, and where the sides
	// have the same type.
	for i := 0; i < len(*ls) || i < len(other); i++ {
		if i >= len(*ls) {
			// Matching ls against other where other is longer than ls. missing
			// matches everything so we just include a label without any rows. Having
			// no rows is a sign that there was no diff.
			allDiffs = append(allDiffs, layerDiff{
				label: "missing matches " + layerType(other[i]),
			})
			continue
		}

		if i >= len(other) {
			// Matching ls against other where ls is longer than other. missing
			// matches everything so we just include a label without any rows. Having
			// no rows is a sign that there was no diff.
			allDiffs = append(allDiffs, layerDiff{
				label: layerType((*ls)[i]) + " matches missing",
			})
			continue
		}

		if (*ls)[i] == nil && other[i] == nil {
			// Matching ls against other where both elements are nil. nil matches
			// everything so we just include a label without any rows. Having no rows
			// is a sign that there was no diff.
			allDiffs = append(allDiffs, layerDiff{
				label: "nil matches nil",
			})
			continue
		}

		if (*ls)[i] == nil {
			// Matching ls against other where the element in ls is nil. nil matches
			// everything so we just include a label without any rows. Having no rows
			// is a sign that there was no diff.
			allDiffs = append(allDiffs, layerDiff{
				label: "nil matches " + layerType(other[i]),
			})
			continue
		}

		if other[i] == nil {
			// Matching ls against other where the element in other is nil. nil
			// matches everything so we just include a label without any rows. Having
			// no rows is a sign that there was no diff.
			allDiffs = append(allDiffs, layerDiff{
				label: layerType((*ls)[i]) + " matches nil",
			})
			continue
		}

		if reflect.TypeOf((*ls)[i]) == reflect.TypeOf(other[i]) {
			// Matching ls against other where both elements have the same type. Match
			// each field pairwise and only report a diff if there is a mismatch,
			// which is only when both sides are non-nil and have differing values.
			diff := diffLayer((*ls)[i], other[i])
			var layerDiffRows []layerDiffRow
			for _, d := range diff {
				if d.got == "" || d.want == "" || d.got == d.want {
					continue
				}
				layerDiffRows = append(layerDiffRows, layerDiffRow{
					d.field,
					d.got,
					d.want,
				})
			}
			if len(layerDiffRows) > 0 {
				allDiffs = append(allDiffs, layerDiff{
					label: layerType((*ls)[i]),
					rows:  layerDiffRows,
				})
			} else {
				allDiffs = append(allDiffs, layerDiff{
					label: layerType((*ls)[i]) + " matches " + layerType(other[i]),
					// Having no rows is a sign that there was no diff.
				})
			}
			continue
		}
		// Neither side is nil and the types are different, so we'll display one
		// side then the other.
		allDiffs = append(allDiffs, layerDiff{
			label: layerType((*ls)[i]) + " doesn't match " + layerType(other[i]),
		})
		// Each side is diffed against itself only to enumerate its non-nil
		// fields. The rows slices are deliberately initialized non-nil so that
		// even a Layer with no fields set still counts as a difference in the
		// foundOne check below.
		diff := diffLayer((*ls)[i], (*ls)[i])
		layerDiffRows := []layerDiffRow{}
		for _, d := range diff {
			if len(d.got) == 0 {
				continue
			}
			layerDiffRows = append(layerDiffRows, layerDiffRow{
				d.field,
				d.got,
				"",
			})
		}
		allDiffs = append(allDiffs, layerDiff{
			label: layerType((*ls)[i]),
			rows:  layerDiffRows,
		})

		layerDiffRows = []layerDiffRow{}
		diff = diffLayer(other[i], other[i])
		for _, d := range diff {
			if len(d.want) == 0 {
				continue
			}
			layerDiffRows = append(layerDiffRows, layerDiffRow{
				d.field,
				"",
				d.want,
			})
		}
		allDiffs = append(allDiffs, layerDiff{
			label: layerType(other[i]),
			rows:  layerDiffRows,
		})
	}

	output := ""
	// These are for output formatting.
	maxLabelLen, maxFieldLen, maxGotLen, maxWantLen := 0, 0, 0, 0
	foundOne := false
	for _, l := range allDiffs {
		// Only labels that head rows are column-aligned; row-less labels are
		// printed on their own line in parentheses.
		if len(l.label) > maxLabelLen && len(l.rows) > 0 {
			maxLabelLen = len(l.label)
		}
		// A non-nil rows slice (even an empty one) marks a real difference; a
		// nil one marks a match.
		if l.rows != nil {
			foundOne = true
		}
		for _, r := range l.rows {
			if len(r.field) > maxFieldLen {
				maxFieldLen = len(r.field)
			}
			if l := len(fmt.Sprint(r.got)); l > maxGotLen {
				maxGotLen = l
			}
			if l := len(fmt.Sprint(r.want)); l > maxWantLen {
				maxWantLen = l
			}
		}
	}
	if !foundOne {
		return ""
	}
	for _, l := range allDiffs {
		if len(l.rows) == 0 {
			output += "(" + l.label + ")\n"
			continue
		}
		for i, r := range l.rows {
			// The label is printed only on a Layer's first row.
			var label string
			if i == 0 {
				label = l.label + ":"
			}
			// %*s and %*v take their width from the preceding argument, which
			// right-aligns every column to its widest entry.
			output += fmt.Sprintf(
				"%*s %*s %*v %*v\n",
				maxLabelLen+1, label,
				maxFieldLen+1, r.field+":",
				maxGotLen, r.got,
				maxWantLen, r.want,
			)
		}
	}
	return output
}
  1588  
  1589  // merge merges the other Layers into ls. If the other Layers is longer, those
  1590  // additional Layer structs are added to ls. The errors from merging are
  1591  // collected and returned.
  1592  func (ls *Layers) merge(other Layers) error {
  1593  	var errs error
  1594  	for i, o := range other {
  1595  		if i < len(*ls) {
  1596  			errs = multierr.Combine(errs, (*ls)[i].merge(o))
  1597  		} else {
  1598  			*ls = append(*ls, o)
  1599  		}
  1600  	}
  1601  	return errs
  1602  }