github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/tcpip/network/internal/fragmentation/fragmentation.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package fragmentation contains the implementation of IP fragmentation.
    16  // It is based on RFC 791, RFC 815 and RFC 8200.
    17  package fragmentation
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"time"
    23  
    24  	"github.com/sagernet/gvisor/pkg/buffer"
    25  	"github.com/sagernet/gvisor/pkg/log"
    26  	"github.com/sagernet/gvisor/pkg/sync"
    27  	"github.com/sagernet/gvisor/pkg/tcpip"
    28  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
    29  )
    30  
    31  const (
    32  	// HighFragThreshold is the threshold at which we start trimming old
    33  	// fragmented packets. Linux uses a default value of 4 MB. See
    34  	// net.ipv4.ipfrag_high_thresh for more information.
    35  	HighFragThreshold = 4 << 20 // 4MB
    36  
    37  	// LowFragThreshold is the threshold we reach to when we start dropping
    38  	// older fragmented packets. It's important that we keep enough room for newer
    39  	// packets to be re-assembled. Hence, this needs to be lower than
    40  	// HighFragThreshold enough. Linux uses a default value of 3 MB. See
    41  	// net.ipv4.ipfrag_low_thresh for more information.
    42  	LowFragThreshold = 3 << 20 // 3MB
    43  
    44  	// minBlockSize is the minimum block size for fragments.
    45  	minBlockSize = 1
    46  )
    47  
    48  var (
    49  	// ErrInvalidArgs indicates to the caller that an invalid argument was
    50  	// provided.
    51  	ErrInvalidArgs = errors.New("invalid args")
    52  
    53  	// ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps
    54  	// with another one.
    55  	ErrFragmentOverlap = errors.New("overlapping fragments")
    56  
    57  	// ErrFragmentConflict indicates that, during reassembly, some fragments are
    58  	// in conflict with one another.
    59  	ErrFragmentConflict = errors.New("conflicting fragments")
    60  )
    61  
    62  // FragmentID is the identifier for a fragment.
    63  type FragmentID struct {
    64  	// Source is the source address of the fragment.
    65  	Source tcpip.Address
    66  
    67  	// Destination is the destination address of the fragment.
    68  	Destination tcpip.Address
    69  
    70  	// ID is the identification value of the fragment.
    71  	//
    72  	// This is a uint32 because IPv6 uses a 32-bit identification value.
    73  	ID uint32
    74  
    75  	// The protocol for the packet.
    76  	Protocol uint8
    77  }
    78  
    79  // Fragmentation is the main structure that other modules
    80  // of the stack should use to implement IP Fragmentation.
    81  type Fragmentation struct {
    82  	mu             sync.Mutex
    83  	highLimit      int
    84  	lowLimit       int
    85  	reassemblers   map[FragmentID]*reassembler
    86  	rList          reassemblerList
    87  	memSize        int
    88  	timeout        time.Duration
    89  	blockSize      uint16
    90  	clock          tcpip.Clock
    91  	releaseJob     *tcpip.Job
    92  	timeoutHandler TimeoutHandler
    93  }
    94  
// TimeoutHandler is consulted if a packet reassembly has timed out.
//
// Fragmentation invokes the handler while holding its internal mutex, so an
// implementation must not call back into the same Fragmentation (for example
// Process or Release); doing so would deadlock.
type TimeoutHandler interface {
	// OnReassemblyTimeout will be called with the first fragment (or nil, if the
	// first fragment has not been received) of a packet whose reassembly has
	// timed out.
	//
	// Fragmentation drops its own reference to pkt (DecRef) right after this
	// call returns. NOTE(review): an implementation that retains pkt
	// presumably needs to take its own reference first — confirm against
	// stack.PacketBuffer's reference-counting rules.
	OnReassemblyTimeout(pkt *stack.PacketBuffer)
}
   102  
   103  // NewFragmentation creates a new Fragmentation.
   104  //
   105  // blockSize specifies the fragment block size, in bytes.
   106  //
   107  // highMemoryLimit specifies the limit on the memory consumed
   108  // by the fragments stored by Fragmentation (overhead of internal data-structures
   109  // is not accounted). Fragments are dropped when the limit is reached.
   110  //
   111  // lowMemoryLimit specifies the limit on which we will reach by dropping
   112  // fragments after reaching highMemoryLimit.
   113  //
   114  // reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
   115  // Fragments are lazily evicted only when a new a packet with an
   116  // already existing fragmentation-id arrives after the timeout.
   117  func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock, timeoutHandler TimeoutHandler) *Fragmentation {
   118  	if lowMemoryLimit >= highMemoryLimit {
   119  		lowMemoryLimit = highMemoryLimit
   120  	}
   121  
   122  	if lowMemoryLimit < 0 {
   123  		lowMemoryLimit = 0
   124  	}
   125  
   126  	if blockSize < minBlockSize {
   127  		blockSize = minBlockSize
   128  	}
   129  
   130  	f := &Fragmentation{
   131  		reassemblers:   make(map[FragmentID]*reassembler),
   132  		highLimit:      highMemoryLimit,
   133  		lowLimit:       lowMemoryLimit,
   134  		timeout:        reassemblingTimeout,
   135  		blockSize:      blockSize,
   136  		clock:          clock,
   137  		timeoutHandler: timeoutHandler,
   138  	}
   139  	f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked)
   140  
   141  	return f
   142  }
   143  
   144  // Process processes an incoming fragment belonging to an ID and returns a
   145  // complete packet and its protocol number when all the packets belonging to
   146  // that ID have been received.
   147  //
   148  // [first, last] is the range of the fragment bytes.
   149  //
   150  // first must be a multiple of the block size f is configured with. The size
   151  // of the fragment data must be a multiple of the block size, unless there are
   152  // no fragments following this fragment (more set to false).
   153  //
   154  // proto is the protocol number marked in the fragment being processed. It has
   155  // to be given here outside of the FragmentID struct because IPv6 should not use
   156  // the protocol to identify a fragment.
   157  func (f *Fragmentation) Process(
   158  	id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (
   159  	*stack.PacketBuffer, uint8, bool, error) {
   160  	if first > last {
   161  		return nil, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
   162  	}
   163  
   164  	if first%f.blockSize != 0 {
   165  		return nil, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
   166  	}
   167  
   168  	fragmentSize := last - first + 1
   169  	if more && fragmentSize%f.blockSize != 0 {
   170  		return nil, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
   171  	}
   172  
   173  	if l := pkt.Data().Size(); l != int(fragmentSize) {
   174  		return nil, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
   175  	}
   176  
   177  	f.mu.Lock()
   178  	if f.reassemblers == nil {
   179  		return nil, 0, false, fmt.Errorf("Release() called before fragmentation processing could finish")
   180  	}
   181  
   182  	r, ok := f.reassemblers[id]
   183  	if !ok {
   184  		r = newReassembler(id, f.clock)
   185  		f.reassemblers[id] = r
   186  		wasEmpty := f.rList.Empty()
   187  		f.rList.PushFront(r)
   188  		if wasEmpty {
   189  			// If we have just pushed a first reassembler into an empty list, we
   190  			// should kickstart the release job. The release job will keep
   191  			// rescheduling itself until the list becomes empty.
   192  			f.releaseReassemblersLocked()
   193  		}
   194  	}
   195  	f.mu.Unlock()
   196  
   197  	resPkt, firstFragmentProto, done, memConsumed, err := r.process(first, last, more, proto, pkt)
   198  	if err != nil {
   199  		// We probably got an invalid sequence of fragments. Just
   200  		// discard the reassembler and move on.
   201  		f.mu.Lock()
   202  		f.release(r, false /* timedOut */)
   203  		f.mu.Unlock()
   204  		return nil, 0, false, fmt.Errorf("fragmentation processing error: %w", err)
   205  	}
   206  	f.mu.Lock()
   207  	f.memSize += memConsumed
   208  	if done {
   209  		f.release(r, false /* timedOut */)
   210  	}
   211  	// Evict reassemblers if we are consuming more memory than highLimit until
   212  	// we reach lowLimit.
   213  	if f.memSize > f.highLimit {
   214  		for f.memSize > f.lowLimit {
   215  			tail := f.rList.Back()
   216  			if tail == nil {
   217  				break
   218  			}
   219  			f.release(tail, false /* timedOut */)
   220  		}
   221  	}
   222  	f.mu.Unlock()
   223  	return resPkt, firstFragmentProto, done, nil
   224  }
   225  
   226  // Release releases all underlying resources.
   227  func (f *Fragmentation) Release() {
   228  	f.mu.Lock()
   229  	defer f.mu.Unlock()
   230  	for _, r := range f.reassemblers {
   231  		f.release(r, false /* timedOut */)
   232  	}
   233  	f.reassemblers = nil
   234  }
   235  
   236  func (f *Fragmentation) release(r *reassembler, timedOut bool) {
   237  	// Before releasing a fragment we need to check if r is already marked as done.
   238  	// Otherwise, we would delete it twice.
   239  	if r.checkDoneOrMark() {
   240  		return
   241  	}
   242  
   243  	delete(f.reassemblers, r.id)
   244  	f.rList.Remove(r)
   245  	f.memSize -= r.memSize
   246  	if f.memSize < 0 {
   247  		log.Warningf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.memSize)
   248  		f.memSize = 0
   249  	}
   250  
   251  	if h := f.timeoutHandler; timedOut && h != nil {
   252  		h.OnReassemblyTimeout(r.pkt)
   253  	}
   254  	if r.pkt != nil {
   255  		r.pkt.DecRef()
   256  		r.pkt = nil
   257  	}
   258  	for _, h := range r.holes {
   259  		if h.pkt != nil {
   260  			h.pkt.DecRef()
   261  			h.pkt = nil
   262  		}
   263  	}
   264  	r.holes = nil
   265  }
   266  
   267  // releaseReassemblersLocked releases already-expired reassemblers, then
   268  // schedules the job to call back itself for the remaining reassemblers if
   269  // any. This function must be called with f.mu locked.
   270  func (f *Fragmentation) releaseReassemblersLocked() {
   271  	now := f.clock.NowMonotonic()
   272  	for {
   273  		// The reassembler at the end of the list is the oldest.
   274  		r := f.rList.Back()
   275  		if r == nil {
   276  			// The list is empty.
   277  			break
   278  		}
   279  		elapsed := now.Sub(r.createdAt)
   280  		if f.timeout > elapsed {
   281  			// If the oldest reassembler has not expired, schedule the release
   282  			// job so that this function is called back when it has expired.
   283  			f.releaseJob.Schedule(f.timeout - elapsed)
   284  			break
   285  		}
   286  		// If the oldest reassembler has already expired, release it.
   287  		f.release(r, true /* timedOut*/)
   288  	}
   289  }
   290  
   291  // PacketFragmenter is the book-keeping struct for packet fragmentation.
   292  type PacketFragmenter struct {
   293  	transportHeader    []byte
   294  	data               buffer.Buffer
   295  	reserve            int
   296  	fragmentPayloadLen int
   297  	fragmentCount      int
   298  	currentFragment    int
   299  	fragmentOffset     int
   300  }
   301  
   302  // MakePacketFragmenter prepares the struct needed for packet fragmentation.
   303  //
   304  // pkt is the packet to be fragmented.
   305  //
   306  // fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can
   307  // have.
   308  //
   309  // reserve is the number of bytes that should be reserved for the headers in
   310  // each generated fragment.
   311  func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter {
   312  	// As per RFC 8200 Section 4.5, some IPv6 extension headers should not be
   313  	// repeated in each fragment. However we do not currently support any header
   314  	// of that kind yet, so the following computation is valid for both IPv4 and
   315  	// IPv6.
   316  	// TODO(gvisor.dev/issue/3912): Once Authentication or ESP Headers are
   317  	// supported for outbound packets, the fragmentable data should not include
   318  	// these headers.
   319  	var fragmentableData buffer.Buffer
   320  	fragmentableData.Append(pkt.TransportHeader().View())
   321  	pktBuf := pkt.Data().ToBuffer()
   322  	fragmentableData.Merge(&pktBuf)
   323  	fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen
   324  
   325  	return PacketFragmenter{
   326  		data:               fragmentableData,
   327  		reserve:            reserve,
   328  		fragmentPayloadLen: int(fragmentPayloadLen),
   329  		fragmentCount:      int(fragmentCount),
   330  	}
   331  }
   332  
   333  // BuildNextFragment returns a packet with the payload of the next fragment,
   334  // along with the fragment's offset, the number of bytes copied and a boolean
   335  // indicating if there are more fragments left or not. If this function is
   336  // called again after it indicated that no more fragments were left, it will
   337  // panic.
   338  //
   339  // Note that the returned packet will not have its network and link headers
   340  // populated, but space for them will be reserved. The transport header will be
   341  // stored in the packet's data.
   342  func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, bool) {
   343  	if pf.currentFragment >= pf.fragmentCount {
   344  		panic("BuildNextFragment should not be called again after the last fragment was returned")
   345  	}
   346  
   347  	fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   348  		ReserveHeaderBytes: pf.reserve,
   349  	})
   350  
   351  	// Copy data for the fragment.
   352  	copied := fragPkt.Data().ReadFrom(&pf.data, pf.fragmentPayloadLen)
   353  
   354  	offset := pf.fragmentOffset
   355  	pf.fragmentOffset += copied
   356  	pf.currentFragment++
   357  	more := pf.currentFragment != pf.fragmentCount
   358  
   359  	return fragPkt, offset, copied, more
   360  }
   361  
   362  // RemainingFragmentCount returns the number of fragments left to be built.
   363  func (pf *PacketFragmenter) RemainingFragmentCount() int {
   364  	return pf.fragmentCount - pf.currentFragment
   365  }
   366  
// Release frees resources owned by the packet fragmenter. It releases
// pf.data, the buffer that BuildNextFragment reads fragment data from.
func (pf *PacketFragmenter) Release() {
	pf.data.Release()
}