github.com/FlowerWrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/network/ipv4/ipv4.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package ipv4 contains the implementation of the ipv4 network protocol. To use
    16  // it in the networking stack, this package must be added to the project, and
    17  // activated on the stack by passing ipv4.NewProtocol() as one of the network
    18  // protocols when calling stack.New(). Then endpoints can be created by passing
    19  // ipv4.ProtocolNumber as the network protocol number when calling
    20  // Stack.NewEndpoint().
    21  package ipv4
    22  
    23  import (
    24  	"sync/atomic"
    25  
    26  	"github.com/FlowerWrong/netstack/tcpip"
    27  	"github.com/FlowerWrong/netstack/tcpip/buffer"
    28  	"github.com/FlowerWrong/netstack/tcpip/header"
    29  	"github.com/FlowerWrong/netstack/tcpip/network/fragmentation"
    30  	"github.com/FlowerWrong/netstack/tcpip/network/hash"
    31  	"github.com/FlowerWrong/netstack/tcpip/stack"
    32  )
    33  
    34  const (
    35  	// ProtocolNumber is the ipv4 protocol number.
    36  	ProtocolNumber = header.IPv4ProtocolNumber
    37  
    38  	// MaxTotalSize is maximum size that can be encoded in the 16-bit
    39  	// TotalLength field of the ipv4 header.
    40  	MaxTotalSize = 0xffff
    41  
    42  	// DefaultTTL is the default time-to-live value for this endpoint.
    43  	DefaultTTL = 64
    44  
    45  	// buckets is the number of identifier buckets.
    46  	buckets = 2048
    47  )
    48  
    49  type endpoint struct {
    50  	nicid         tcpip.NICID
    51  	id            stack.NetworkEndpointID
    52  	prefixLen     int
    53  	linkEP        stack.LinkEndpoint
    54  	dispatcher    stack.TransportDispatcher
    55  	fragmentation *fragmentation.Fragmentation
    56  	protocol      *protocol
    57  }
    58  
    59  // NewEndpoint creates a new ipv4 endpoint.
    60  func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
    61  	e := &endpoint{
    62  		nicid:         nicid,
    63  		id:            stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
    64  		prefixLen:     addrWithPrefix.PrefixLen,
    65  		linkEP:        linkEP,
    66  		dispatcher:    dispatcher,
    67  		fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
    68  		protocol:      p,
    69  	}
    70  
    71  	return e, nil
    72  }
    73  
    74  // DefaultTTL is the default time-to-live value for this endpoint.
    75  func (e *endpoint) DefaultTTL() uint8 {
    76  	return e.protocol.DefaultTTL()
    77  }
    78  
    79  // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
    80  // the network layer max header length.
    81  func (e *endpoint) MTU() uint32 {
    82  	return calculateMTU(e.linkEP.MTU())
    83  }
    84  
    85  // Capabilities implements stack.NetworkEndpoint.Capabilities.
    86  func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
    87  	return e.linkEP.Capabilities()
    88  }
    89  
    90  // NICID returns the ID of the NIC this endpoint belongs to.
    91  func (e *endpoint) NICID() tcpip.NICID {
    92  	return e.nicid
    93  }
    94  
    95  // ID returns the ipv4 endpoint ID.
    96  func (e *endpoint) ID() *stack.NetworkEndpointID {
    97  	return &e.id
    98  }
    99  
   100  // PrefixLen returns the ipv4 endpoint subnet prefix length in bits.
   101  func (e *endpoint) PrefixLen() int {
   102  	return e.prefixLen
   103  }
   104  
   105  // MaxHeaderLength returns the maximum length needed by ipv4 headers (and
   106  // underlying protocols).
   107  func (e *endpoint) MaxHeaderLength() uint16 {
   108  	return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
   109  }
   110  
   111  // GSOMaxSize returns the maximum GSO packet size.
   112  func (e *endpoint) GSOMaxSize() uint32 {
   113  	if gso, ok := e.linkEP.(stack.GSOEndpoint); ok {
   114  		return gso.GSOMaxSize()
   115  	}
   116  	return 0
   117  }
   118  
   119  // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
   120  // write. It assumes that the IP header is entirely in hdr but does not assume
   121  // that only the IP header is in hdr. It assumes that the input packet's stated
   122  // length matches the length of the hdr+payload. mtu includes the IP header and
   123  // options. This does not support the DontFragment IP flag.
   124  func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error {
   125  	// This packet is too big, it needs to be fragmented.
   126  	ip := header.IPv4(hdr.View())
   127  	flags := ip.Flags()
   128  
   129  	// Update mtu to take into account the header, which will exist in all
   130  	// fragments anyway.
   131  	innerMTU := mtu - int(ip.HeaderLength())
   132  
   133  	// Round the MTU down to align to 8 bytes. Then calculate the number of
   134  	// fragments. Calculate fragment sizes as in RFC791.
   135  	innerMTU &^= 7
   136  	n := (int(ip.PayloadLength()) + innerMTU - 1) / innerMTU
   137  
   138  	outerMTU := innerMTU + int(ip.HeaderLength())
   139  	offset := ip.FragmentOffset()
   140  	originalAvailableLength := hdr.AvailableLength()
   141  	for i := 0; i < n; i++ {
   142  		// Where possible, the first fragment that is sent has the same
   143  		// hdr.UsedLength() as the input packet. The link-layer endpoint may depends
   144  		// on this for looking at, eg, L4 headers.
   145  		h := ip
   146  		if i > 0 {
   147  			hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
   148  			h = header.IPv4(hdr.Prepend(int(ip.HeaderLength())))
   149  			copy(h, ip[:ip.HeaderLength()])
   150  		}
   151  		if i != n-1 {
   152  			h.SetTotalLength(uint16(outerMTU))
   153  			h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
   154  		} else {
   155  			h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size()))
   156  			h.SetFlagsFragmentOffset(flags, offset)
   157  		}
   158  		h.SetChecksum(0)
   159  		h.SetChecksum(^h.CalculateChecksum())
   160  		offset += uint16(innerMTU)
   161  		if i > 0 {
   162  			newPayload := payload.Clone([]buffer.View{})
   163  			newPayload.CapLength(innerMTU)
   164  			if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
   165  				return err
   166  			}
   167  			r.Stats().IP.PacketsSent.Increment()
   168  			payload.TrimFront(newPayload.Size())
   169  			continue
   170  		}
   171  		// Special handling for the first fragment because it comes from the hdr.
   172  		if outerMTU >= hdr.UsedLength() {
   173  			// This fragment can fit all of hdr and possibly some of payload, too.
   174  			newPayload := payload.Clone([]buffer.View{})
   175  			newPayloadLength := outerMTU - hdr.UsedLength()
   176  			newPayload.CapLength(newPayloadLength)
   177  			if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
   178  				return err
   179  			}
   180  			r.Stats().IP.PacketsSent.Increment()
   181  			payload.TrimFront(newPayloadLength)
   182  		} else {
   183  			// The fragment is too small to fit all of hdr.
   184  			startOfHdr := hdr
   185  			startOfHdr.TrimBack(hdr.UsedLength() - outerMTU)
   186  			emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
   187  			if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil {
   188  				return err
   189  			}
   190  			r.Stats().IP.PacketsSent.Increment()
   191  			// Add the unused bytes of hdr into the payload that remains to be sent.
   192  			restOfHdr := hdr.View()[outerMTU:]
   193  			tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
   194  			tmp.Append(payload)
   195  			payload = tmp
   196  		}
   197  	}
   198  	return nil
   199  }
   200  
   201  // WritePacket writes a packet to the given destination address and protocol.
   202  func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error {
   203  	ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
   204  	length := uint16(hdr.UsedLength() + payload.Size())
   205  	id := uint32(0)
   206  	if length > header.IPv4MaximumHeaderSize+8 {
   207  		// Packets of 68 bytes or less are required by RFC 791 to not be
   208  		// fragmented, so we only assign ids to larger packets.
   209  		id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, protocol, e.protocol.hashIV)%buckets], 1)
   210  	}
   211  	ip.Encode(&header.IPv4Fields{
   212  		IHL:         header.IPv4MinimumSize,
   213  		TotalLength: length,
   214  		ID:          uint16(id),
   215  		TTL:         ttl,
   216  		Protocol:    uint8(protocol),
   217  		SrcAddr:     r.LocalAddress,
   218  		DstAddr:     r.RemoteAddress,
   219  	})
   220  	ip.SetChecksum(^ip.CalculateChecksum())
   221  
   222  	if loop&stack.PacketLoop != 0 {
   223  		views := make([]buffer.View, 1, 1+len(payload.Views()))
   224  		views[0] = hdr.View()
   225  		views = append(views, payload.Views()...)
   226  		vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
   227  		loopedR := r.MakeLoopedRoute()
   228  		e.HandlePacket(&loopedR, vv)
   229  		loopedR.Release()
   230  	}
   231  	if loop&stack.PacketOut == 0 {
   232  		return nil
   233  	}
   234  	if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
   235  		return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU()))
   236  	}
   237  	if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil {
   238  		return err
   239  	}
   240  	r.Stats().IP.PacketsSent.Increment()
   241  	return nil
   242  }
   243  
   244  // WriteHeaderIncludedPacket writes a packet already containing a network
   245  // header through the given route.
   246  func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
   247  	// The packet already has an IP header, but there are a few required
   248  	// checks.
   249  	ip := header.IPv4(payload.First())
   250  	if !ip.IsValid(payload.Size()) {
   251  		return tcpip.ErrInvalidOptionValue
   252  	}
   253  
   254  	// Always set the total length.
   255  	ip.SetTotalLength(uint16(payload.Size()))
   256  
   257  	// Set the source address when zero.
   258  	if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
   259  		ip.SetSourceAddress(r.LocalAddress)
   260  	}
   261  
   262  	// Set the destination. If the packet already included a destination,
   263  	// it will be part of the route.
   264  	ip.SetDestinationAddress(r.RemoteAddress)
   265  
   266  	// Set the packet ID when zero.
   267  	if ip.ID() == 0 {
   268  		id := uint32(0)
   269  		if payload.Size() > header.IPv4MaximumHeaderSize+8 {
   270  			// Packets of 68 bytes or less are required by RFC 791 to not be
   271  			// fragmented, so we only assign ids to larger packets.
   272  			id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
   273  		}
   274  		ip.SetID(uint16(id))
   275  	}
   276  
   277  	// Always set the checksum.
   278  	ip.SetChecksum(0)
   279  	ip.SetChecksum(^ip.CalculateChecksum())
   280  
   281  	if loop&stack.PacketLoop != 0 {
   282  		e.HandlePacket(r, payload)
   283  	}
   284  	if loop&stack.PacketOut == 0 {
   285  		return nil
   286  	}
   287  
   288  	hdr := buffer.NewPrependableFromView(payload.ToView())
   289  	r.Stats().IP.PacketsSent.Increment()
   290  	return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
   291  }
   292  
   293  // HandlePacket is called by the link layer when new ipv4 packets arrive for
   294  // this endpoint.
   295  func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
   296  	headerView := vv.First()
   297  	h := header.IPv4(headerView)
   298  	if !h.IsValid(vv.Size()) {
   299  		return
   300  	}
   301  
   302  	hlen := int(h.HeaderLength())
   303  	tlen := int(h.TotalLength())
   304  	vv.TrimFront(hlen)
   305  	vv.CapLength(tlen - hlen)
   306  
   307  	more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
   308  	if more || h.FragmentOffset() != 0 {
   309  		// The packet is a fragment, let's try to reassemble it.
   310  		last := h.FragmentOffset() + uint16(vv.Size()) - 1
   311  		var ready bool
   312  		vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
   313  		if !ready {
   314  			return
   315  		}
   316  	}
   317  	p := h.TransportProtocol()
   318  	if p == header.ICMPv4ProtocolNumber {
   319  		headerView.CapLength(hlen)
   320  		e.handleICMP(r, headerView, vv)
   321  		return
   322  	}
   323  	r.Stats().IP.PacketsDelivered.Increment()
   324  	e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
   325  }
   326  
   327  // Close cleans up resources associated with the endpoint.
   328  func (e *endpoint) Close() {}
   329  
   330  type protocol struct {
   331  	ids    []uint32
   332  	hashIV uint32
   333  
   334  	// defaultTTL is the current default TTL for the protocol. Only the
   335  	// uint8 portion of it is meaningful and it must be accessed
   336  	// atomically.
   337  	defaultTTL uint32
   338  }
   339  
   340  // Number returns the ipv4 protocol number.
   341  func (p *protocol) Number() tcpip.NetworkProtocolNumber {
   342  	return ProtocolNumber
   343  }
   344  
   345  // MinimumPacketSize returns the minimum valid ipv4 packet size.
   346  func (p *protocol) MinimumPacketSize() int {
   347  	return header.IPv4MinimumSize
   348  }
   349  
   350  // DefaultPrefixLen returns the IPv4 default prefix length.
   351  func (p *protocol) DefaultPrefixLen() int {
   352  	return header.IPv4AddressSize * 8
   353  }
   354  
   355  // ParseAddresses implements NetworkProtocol.ParseAddresses.
   356  func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
   357  	h := header.IPv4(v)
   358  	return h.SourceAddress(), h.DestinationAddress()
   359  }
   360  
   361  // SetOption implements NetworkProtocol.SetOption.
   362  func (p *protocol) SetOption(option interface{}) *tcpip.Error {
   363  	switch v := option.(type) {
   364  	case tcpip.DefaultTTLOption:
   365  		p.SetDefaultTTL(uint8(v))
   366  		return nil
   367  	default:
   368  		return tcpip.ErrUnknownProtocolOption
   369  	}
   370  }
   371  
   372  // Option implements NetworkProtocol.Option.
   373  func (p *protocol) Option(option interface{}) *tcpip.Error {
   374  	switch v := option.(type) {
   375  	case *tcpip.DefaultTTLOption:
   376  		*v = tcpip.DefaultTTLOption(p.DefaultTTL())
   377  		return nil
   378  	default:
   379  		return tcpip.ErrUnknownProtocolOption
   380  	}
   381  }
   382  
   383  // SetDefaultTTL sets the default TTL for endpoints created with this protocol.
   384  func (p *protocol) SetDefaultTTL(ttl uint8) {
   385  	atomic.StoreUint32(&p.defaultTTL, uint32(ttl))
   386  }
   387  
   388  // DefaultTTL returns the default TTL for endpoints created with this protocol.
   389  func (p *protocol) DefaultTTL() uint8 {
   390  	return uint8(atomic.LoadUint32(&p.defaultTTL))
   391  }
   392  
   393  // calculateMTU calculates the network-layer payload MTU based on the link-layer
   394  // payload mtu.
   395  func calculateMTU(mtu uint32) uint32 {
   396  	if mtu > MaxTotalSize {
   397  		mtu = MaxTotalSize
   398  	}
   399  	return mtu - header.IPv4MinimumSize
   400  }
   401  
   402  // hashRoute calculates a hash value for the given route. It uses the source &
   403  // destination address, the transport protocol number, and a random initial
   404  // value (generated once on initialization) to generate the hash.
   405  func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
   406  	t := r.LocalAddress
   407  	a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
   408  	t = r.RemoteAddress
   409  	b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
   410  	return hash.Hash3Words(a, b, uint32(protocol), hashIV)
   411  }
   412  
   413  // NewProtocol returns an IPv4 network protocol.
   414  func NewProtocol() stack.NetworkProtocol {
   415  	ids := make([]uint32, buckets)
   416  
   417  	// Randomly initialize hashIV and the ids.
   418  	r := hash.RandN32(1 + buckets)
   419  	for i := range ids {
   420  		ids[i] = r[i]
   421  	}
   422  	hashIV := r[buckets]
   423  
   424  	return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL}
   425  }