github.com/polevpn/netstack@v1.10.9/tcpip/network/ipv4/ipv4.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package ipv4 contains the implementation of the ipv4 network protocol. To use
    16  // it in the networking stack, this package must be added to the project, and
    17  // activated on the stack by passing ipv4.NewProtocol() as one of the network
    18  // protocols when calling stack.New(). Then endpoints can be created by passing
    19  // ipv4.ProtocolNumber as the network protocol number when calling
    20  // Stack.NewEndpoint().
    21  package ipv4
    22  
    23  import (
    24  	"sync/atomic"
    25  
    26  	"github.com/polevpn/netstack/tcpip"
    27  	"github.com/polevpn/netstack/tcpip/buffer"
    28  	"github.com/polevpn/netstack/tcpip/header"
    29  	"github.com/polevpn/netstack/tcpip/network/fragmentation"
    30  	"github.com/polevpn/netstack/tcpip/network/hash"
    31  	"github.com/polevpn/netstack/tcpip/stack"
    32  )
    33  
    34  const (
    35  	// ProtocolNumber is the ipv4 protocol number.
    36  	ProtocolNumber = header.IPv4ProtocolNumber
    37  
    38  	// MaxTotalSize is maximum size that can be encoded in the 16-bit
    39  	// TotalLength field of the ipv4 header.
    40  	MaxTotalSize = 0xffff
    41  
    42  	// DefaultTTL is the default time-to-live value for this endpoint.
    43  	DefaultTTL = 64
    44  
    45  	// buckets is the number of identifier buckets.
    46  	buckets = 2048
    47  )
    48  
// endpoint is the IPv4 network endpoint for one local address on one NIC. It
// is created by protocol.NewEndpoint.
type endpoint struct {
	// nicID is the NIC this endpoint belongs to (reported by NICID).
	nicID         tcpip.NICID
	// id carries the endpoint's local address in LocalAddress.
	id            stack.NetworkEndpointID
	// prefixLen is the subnet prefix length, in bits, of the local address.
	prefixLen     int
	// linkEP is the link-layer endpoint that outgoing packets are written
	// to; it is also consulted for MTU, capabilities, header room and GSO.
	linkEP        stack.LinkEndpoint
	// dispatcher receives incoming non-ICMP transport packets.
	dispatcher    stack.TransportDispatcher
	// fragmentation reassembles incoming IPv4 fragments.
	fragmentation *fragmentation.Fragmentation
	// protocol is the owning protocol; it supplies the default TTL and the
	// shared identification counters and hash seed.
	protocol      *protocol
}
    58  
    59  // NewEndpoint creates a new ipv4 endpoint.
    60  func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
    61  	e := &endpoint{
    62  		nicID:         nicID,
    63  		id:            stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
    64  		prefixLen:     addrWithPrefix.PrefixLen,
    65  		linkEP:        linkEP,
    66  		dispatcher:    dispatcher,
    67  		fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
    68  		protocol:      p,
    69  	}
    70  
    71  	return e, nil
    72  }
    73  
    74  // DefaultTTL is the default time-to-live value for this endpoint.
    75  func (e *endpoint) DefaultTTL() uint8 {
    76  	return e.protocol.DefaultTTL()
    77  }
    78  
    79  // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
    80  // the network layer max header length.
    81  func (e *endpoint) MTU() uint32 {
    82  	return calculateMTU(e.linkEP.MTU())
    83  }
    84  
    85  // Capabilities implements stack.NetworkEndpoint.Capabilities.
    86  func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
    87  	return e.linkEP.Capabilities()
    88  }
    89  
    90  // NICID returns the ID of the NIC this endpoint belongs to.
    91  func (e *endpoint) NICID() tcpip.NICID {
    92  	return e.nicID
    93  }
    94  
    95  // ID returns the ipv4 endpoint ID.
    96  func (e *endpoint) ID() *stack.NetworkEndpointID {
    97  	return &e.id
    98  }
    99  
   100  // PrefixLen returns the ipv4 endpoint subnet prefix length in bits.
   101  func (e *endpoint) PrefixLen() int {
   102  	return e.prefixLen
   103  }
   104  
   105  // MaxHeaderLength returns the maximum length needed by ipv4 headers (and
   106  // underlying protocols).
   107  func (e *endpoint) MaxHeaderLength() uint16 {
   108  	return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
   109  }
   110  
   111  // GSOMaxSize returns the maximum GSO packet size.
   112  func (e *endpoint) GSOMaxSize() uint32 {
   113  	if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
   114  		return gso.GSOMaxSize()
   115  	}
   116  	return 0
   117  }
   118  
   119  // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
   120  // write. It assumes that the IP header is entirely in pkt.Header but does not
   121  // assume that only the IP header is in pkt.Header. It assumes that the input
   122  // packet's stated length matches the length of the header+payload. mtu
   123  // includes the IP header and options. This does not support the DontFragment
   124  // IP flag.
   125  func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error {
   126  	// This packet is too big, it needs to be fragmented.
   127  	ip := header.IPv4(pkt.Header.View())
   128  	flags := ip.Flags()
   129  
   130  	// Update mtu to take into account the header, which will exist in all
   131  	// fragments anyway.
   132  	innerMTU := mtu - int(ip.HeaderLength())
   133  
   134  	// Round the MTU down to align to 8 bytes. Then calculate the number of
   135  	// fragments. Calculate fragment sizes as in RFC791.
   136  	innerMTU &^= 7
   137  	n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
   138  
   139  	outerMTU := innerMTU + int(ip.HeaderLength())
   140  	offset := ip.FragmentOffset()
   141  	originalAvailableLength := pkt.Header.AvailableLength()
   142  	for i := 0; i < n; i++ {
   143  		// Where possible, the first fragment that is sent has the same
   144  		// pkt.Header.UsedLength() as the input packet. The link-layer
   145  		// endpoint may depend on this for looking at, eg, L4 headers.
   146  		h := ip
   147  		if i > 0 {
   148  			pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
   149  			h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength())))
   150  			copy(h, ip[:ip.HeaderLength()])
   151  		}
   152  		if i != n-1 {
   153  			h.SetTotalLength(uint16(outerMTU))
   154  			h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
   155  		} else {
   156  			h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size()))
   157  			h.SetFlagsFragmentOffset(flags, offset)
   158  		}
   159  		h.SetChecksum(0)
   160  		h.SetChecksum(^h.CalculateChecksum())
   161  		offset += uint16(innerMTU)
   162  		if i > 0 {
   163  			newPayload := pkt.Data.Clone(nil)
   164  			newPayload.CapLength(innerMTU)
   165  			if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
   166  				Header:        pkt.Header,
   167  				Data:          newPayload,
   168  				NetworkHeader: buffer.View(h),
   169  			}); err != nil {
   170  				return err
   171  			}
   172  			r.Stats().IP.PacketsSent.Increment()
   173  			pkt.Data.TrimFront(newPayload.Size())
   174  			continue
   175  		}
   176  		// Special handling for the first fragment because it comes
   177  		// from the header.
   178  		if outerMTU >= pkt.Header.UsedLength() {
   179  			// This fragment can fit all of pkt.Header and possibly
   180  			// some of pkt.Data, too.
   181  			newPayload := pkt.Data.Clone(nil)
   182  			newPayloadLength := outerMTU - pkt.Header.UsedLength()
   183  			newPayload.CapLength(newPayloadLength)
   184  			if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
   185  				Header:        pkt.Header,
   186  				Data:          newPayload,
   187  				NetworkHeader: buffer.View(h),
   188  			}); err != nil {
   189  				return err
   190  			}
   191  			r.Stats().IP.PacketsSent.Increment()
   192  			pkt.Data.TrimFront(newPayloadLength)
   193  		} else {
   194  			// The fragment is too small to fit all of pkt.Header.
   195  			startOfHdr := pkt.Header
   196  			startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU)
   197  			emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
   198  			if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{
   199  				Header:        startOfHdr,
   200  				Data:          emptyVV,
   201  				NetworkHeader: buffer.View(h),
   202  			}); err != nil {
   203  				return err
   204  			}
   205  			r.Stats().IP.PacketsSent.Increment()
   206  			// Add the unused bytes of pkt.Header into the pkt.Data
   207  			// that remains to be sent.
   208  			restOfHdr := pkt.Header.View()[outerMTU:]
   209  			tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
   210  			tmp.Append(pkt.Data)
   211  			pkt.Data = tmp
   212  		}
   213  	}
   214  	return nil
   215  }
   216  
   217  func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 {
   218  	ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
   219  	length := uint16(hdr.UsedLength() + payloadSize)
   220  	id := uint32(0)
   221  	if length > header.IPv4MaximumHeaderSize+8 {
   222  		// Packets of 68 bytes or less are required by RFC 791 to not be
   223  		// fragmented, so we only assign ids to larger packets.
   224  		id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
   225  	}
   226  	ip.Encode(&header.IPv4Fields{
   227  		IHL:         header.IPv4MinimumSize,
   228  		TotalLength: length,
   229  		ID:          uint16(id),
   230  		TTL:         params.TTL,
   231  		TOS:         params.TOS,
   232  		Protocol:    uint8(params.Protocol),
   233  		SrcAddr:     r.LocalAddress,
   234  		DstAddr:     r.RemoteAddress,
   235  	})
   236  	ip.SetChecksum(^ip.CalculateChecksum())
   237  	return ip
   238  }
   239  
   240  // WritePacket writes a packet to the given destination address and protocol.
   241  func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
   242  	ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
   243  
   244  	if loop&stack.PacketLoop != 0 {
   245  		views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
   246  		views[0] = pkt.Header.View()
   247  		views = append(views, pkt.Data.Views()...)
   248  		loopedR := r.MakeLoopedRoute()
   249  
   250  		e.HandlePacket(&loopedR, tcpip.PacketBuffer{
   251  			Data:          buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
   252  			NetworkHeader: buffer.View(ip),
   253  		})
   254  
   255  		loopedR.Release()
   256  	}
   257  	if loop&stack.PacketOut == 0 {
   258  		return nil
   259  	}
   260  	if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
   261  		return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
   262  	}
   263  	if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
   264  		return err
   265  	}
   266  	r.Stats().IP.PacketsSent.Increment()
   267  	return nil
   268  }
   269  
   270  // WritePackets implements stack.NetworkEndpoint.WritePackets.
   271  func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
   272  	if loop&stack.PacketLoop != 0 {
   273  		panic("multiple packets in local loop")
   274  	}
   275  	if loop&stack.PacketOut == 0 {
   276  		return len(hdrs), nil
   277  	}
   278  
   279  	for i := range hdrs {
   280  		e.addIPHeader(r, &hdrs[i].Hdr, hdrs[i].Size, params)
   281  	}
   282  	n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber)
   283  	r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
   284  	return n, err
   285  }
   286  
   287  // WriteHeaderIncludedPacket writes a packet already containing a network
   288  // header through the given route.
   289  func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error {
   290  	// The packet already has an IP header, but there are a few required
   291  	// checks.
   292  	ip := header.IPv4(pkt.Data.First())
   293  	if !ip.IsValid(pkt.Data.Size()) {
   294  		return tcpip.ErrInvalidOptionValue
   295  	}
   296  
   297  	// Always set the total length.
   298  	ip.SetTotalLength(uint16(pkt.Data.Size()))
   299  
   300  	// Set the source address when zero.
   301  	if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
   302  		ip.SetSourceAddress(r.LocalAddress)
   303  	}
   304  
   305  	// Set the destination. If the packet already included a destination,
   306  	// it will be part of the route.
   307  	ip.SetDestinationAddress(r.RemoteAddress)
   308  
   309  	// Set the packet ID when zero.
   310  	if ip.ID() == 0 {
   311  		id := uint32(0)
   312  		if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 {
   313  			// Packets of 68 bytes or less are required by RFC 791 to not be
   314  			// fragmented, so we only assign ids to larger packets.
   315  			id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
   316  		}
   317  		ip.SetID(uint16(id))
   318  	}
   319  
   320  	// Always set the checksum.
   321  	ip.SetChecksum(0)
   322  	ip.SetChecksum(^ip.CalculateChecksum())
   323  
   324  	if loop&stack.PacketLoop != 0 {
   325  		e.HandlePacket(r, pkt.Clone())
   326  	}
   327  	if loop&stack.PacketOut == 0 {
   328  		return nil
   329  	}
   330  
   331  	r.Stats().IP.PacketsSent.Increment()
   332  
   333  	ip = ip[:ip.HeaderLength()]
   334  	pkt.Header = buffer.NewPrependableFromView(buffer.View(ip))
   335  	pkt.Data.TrimFront(int(ip.HeaderLength()))
   336  	return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
   337  }
   338  
   339  // HandlePacket is called by the link layer when new ipv4 packets arrive for
   340  // this endpoint.
   341  func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
   342  	headerView := pkt.Data.First()
   343  	h := header.IPv4(headerView)
   344  	if !h.IsValid(pkt.Data.Size()) {
   345  		r.Stats().IP.MalformedPacketsReceived.Increment()
   346  		return
   347  	}
   348  	pkt.NetworkHeader = headerView[:h.HeaderLength()]
   349  
   350  	hlen := int(h.HeaderLength())
   351  	tlen := int(h.TotalLength())
   352  	pkt.Data.TrimFront(hlen)
   353  	pkt.Data.CapLength(tlen - hlen)
   354  
   355  	more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
   356  	if more || h.FragmentOffset() != 0 {
   357  		if pkt.Data.Size() == 0 {
   358  			// Drop the packet as it's marked as a fragment but has
   359  			// no payload.
   360  			r.Stats().IP.MalformedPacketsReceived.Increment()
   361  			r.Stats().IP.MalformedFragmentsReceived.Increment()
   362  			return
   363  		}
   364  		// The packet is a fragment, let's try to reassemble it.
   365  		last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1
   366  		// Drop the packet if the fragmentOffset is incorrect. i.e the
   367  		// combination of fragmentOffset and pkt.Data.size() causes a
   368  		// wrap around resulting in last being less than the offset.
   369  		if last < h.FragmentOffset() {
   370  			r.Stats().IP.MalformedPacketsReceived.Increment()
   371  			r.Stats().IP.MalformedFragmentsReceived.Increment()
   372  			return
   373  		}
   374  		var ready bool
   375  		var err error
   376  		pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data)
   377  		if err != nil {
   378  			r.Stats().IP.MalformedPacketsReceived.Increment()
   379  			r.Stats().IP.MalformedFragmentsReceived.Increment()
   380  			return
   381  		}
   382  		if !ready {
   383  			return
   384  		}
   385  	}
   386  	p := h.TransportProtocol()
   387  	if p == header.ICMPv4ProtocolNumber {
   388  		headerView.CapLength(hlen)
   389  		e.handleICMP(r, pkt)
   390  		return
   391  	}
   392  	r.Stats().IP.PacketsDelivered.Increment()
   393  	e.dispatcher.DeliverTransportPacket(r, p, pkt)
   394  }
   395  
   396  // Close cleans up resources associated with the endpoint.
   397  func (e *endpoint) Close() {}
   398  
// protocol is the IPv4 network protocol. It holds state shared by every
// endpoint created through NewEndpoint. Construct it with NewProtocol.
type protocol struct {
	// ids holds one identification counter per hash bucket (buckets
	// entries). NewProtocol seeds them randomly, and they are incremented
	// atomically when IDs are assigned.
	ids    []uint32
	// hashIV is the random seed passed to hashRoute when picking a bucket.
	hashIV uint32

	// defaultTTL is the current default TTL for the protocol. Only the
	// uint8 portion of it is meaningful and it must be accessed
	// atomically.
	defaultTTL uint32
}
   408  
   409  // Number returns the ipv4 protocol number.
   410  func (p *protocol) Number() tcpip.NetworkProtocolNumber {
   411  	return ProtocolNumber
   412  }
   413  
   414  // MinimumPacketSize returns the minimum valid ipv4 packet size.
   415  func (p *protocol) MinimumPacketSize() int {
   416  	return header.IPv4MinimumSize
   417  }
   418  
   419  // DefaultPrefixLen returns the IPv4 default prefix length.
   420  func (p *protocol) DefaultPrefixLen() int {
   421  	return header.IPv4AddressSize * 8
   422  }
   423  
   424  // ParseAddresses implements NetworkProtocol.ParseAddresses.
   425  func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
   426  	h := header.IPv4(v)
   427  	return h.SourceAddress(), h.DestinationAddress()
   428  }
   429  
   430  // SetOption implements NetworkProtocol.SetOption.
   431  func (p *protocol) SetOption(option interface{}) *tcpip.Error {
   432  	switch v := option.(type) {
   433  	case tcpip.DefaultTTLOption:
   434  		p.SetDefaultTTL(uint8(v))
   435  		return nil
   436  	default:
   437  		return tcpip.ErrUnknownProtocolOption
   438  	}
   439  }
   440  
   441  // Option implements NetworkProtocol.Option.
   442  func (p *protocol) Option(option interface{}) *tcpip.Error {
   443  	switch v := option.(type) {
   444  	case *tcpip.DefaultTTLOption:
   445  		*v = tcpip.DefaultTTLOption(p.DefaultTTL())
   446  		return nil
   447  	default:
   448  		return tcpip.ErrUnknownProtocolOption
   449  	}
   450  }
   451  
   452  // SetDefaultTTL sets the default TTL for endpoints created with this protocol.
   453  func (p *protocol) SetDefaultTTL(ttl uint8) {
   454  	atomic.StoreUint32(&p.defaultTTL, uint32(ttl))
   455  }
   456  
   457  // DefaultTTL returns the default TTL for endpoints created with this protocol.
   458  func (p *protocol) DefaultTTL() uint8 {
   459  	return uint8(atomic.LoadUint32(&p.defaultTTL))
   460  }
   461  
   462  // calculateMTU calculates the network-layer payload MTU based on the link-layer
   463  // payload mtu.
   464  func calculateMTU(mtu uint32) uint32 {
   465  	if mtu > MaxTotalSize {
   466  		mtu = MaxTotalSize
   467  	}
   468  	return mtu - header.IPv4MinimumSize
   469  }
   470  
   471  // hashRoute calculates a hash value for the given route. It uses the source &
   472  // destination address, the transport protocol number, and a random initial
   473  // value (generated once on initialization) to generate the hash.
   474  func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
   475  	t := r.LocalAddress
   476  	a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
   477  	t = r.RemoteAddress
   478  	b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
   479  	return hash.Hash3Words(a, b, uint32(protocol), hashIV)
   480  }
   481  
   482  // NewProtocol returns an IPv4 network protocol.
   483  func NewProtocol() stack.NetworkProtocol {
   484  	ids := make([]uint32, buckets)
   485  
   486  	// Randomly initialize hashIV and the ids.
   487  	r := hash.RandN32(1 + buckets)
   488  	for i := range ids {
   489  		ids[i] = r[i]
   490  	}
   491  	hashIV := r[buckets]
   492  
   493  	return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL}
   494  }