github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/tcpip/transport/tcp/segment.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  
    21  	"github.com/sagernet/gvisor/pkg/buffer"
    22  	"github.com/sagernet/gvisor/pkg/sync"
    23  	"github.com/sagernet/gvisor/pkg/tcpip"
    24  	"github.com/sagernet/gvisor/pkg/tcpip/header"
    25  	"github.com/sagernet/gvisor/pkg/tcpip/seqnum"
    26  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
    27  )
    28  
// queueFlags are used to indicate which queue of an endpoint a particular segment
// belongs to. This is used to track memory accounting correctly: setOwner and
// DecRef switch on the value to decide whether the endpoint's receive memory
// usage has to be updated.
type queueFlags uint8
    32  
const (
	// SegOverheadSize is the size of an empty seg in memory including packet
	// buffer overhead. It is advised to use SegOverheadSize instead of segSize
	// in all cases where accounting for segment memory overhead is important.
	SegOverheadSize = segSize + stack.PacketBufferStructSize + header.IPv4MaximumHeaderSize

	// recvQ and sendQ identify the receive and send queues respectively.
	//
	// Note that iota counts every constant specification in this block, and
	// SegOverheadSize above occupies index 0. recvQ is therefore 1<<1 (2) and
	// sendQ is 1<<2 (4); moving or inserting constants in this block changes
	// these values.
	recvQ queueFlags = 1 << iota
	sendQ
)
    42  
    43  var segmentPool = sync.Pool{
    44  	New: func() any {
    45  		return &segment{}
    46  	},
    47  }
    48  
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
//
// NOTE(review): an earlier version of this comment said "the only field
// allowed to change is data", but there is no field named data in this struct;
// the payload lives in pkt and is modified through pkt.Data() (see merge and
// TrimFront). Confirm which fields are expected to stay immutable.
//
// +stateify savable
type segment struct {
	// segmentEntry and segmentRefs are embedded. segment.DecRef wraps
	// segmentRefs.DecRef, and newSegment calls InitRefs on every segment it
	// hands out.
	segmentEntry
	segmentRefs

	// ep and qFlags are set by setOwner and consulted by DecRef to undo
	// receive-queue memory accounting. ep is nil until setOwner is called.
	ep     *Endpoint
	qFlags queueFlags
	id     stack.TransportEndpointID `state:"manual"`

	// pkt is the packet buffer backing this segment; the payload is accessed
	// through pkt.Data(). DecRef releases the reference and sets pkt to nil.
	pkt *stack.PacketBuffer

	// The following fields are copied from the TCP header for incoming
	// segments (see newIncomingSegment).
	sequenceNumber seqnum.Value
	ackNumber      seqnum.Value
	flags          header.TCPFlags
	window         seqnum.Size
	// csum is only populated for received segments.
	csum uint16
	// csumValid is true if the csum in the received segment is valid.
	csumValid bool

	// parsedOptions stores the parsed values from the options in the segment.
	parsedOptions  header.TCPOptions
	options        []byte `state:".([]byte)"`
	hasNewSACKInfo bool
	rcvdTime       tcpip.MonotonicTime
	// xmitTime is the last transmit time of this segment.
	xmitTime  tcpip.MonotonicTime
	xmitCount uint32

	// acked indicates if the segment has already been SACKed.
	acked bool

	// dataMemSize is the memory used by pkt initially. The value is used for
	// memory accounting in the receive buffer instead of pkt.MemSize() because
	// packet contents can be modified, so relying on the computed memory size
	// to "free" reserved bytes could leak memory in the receiver.
	dataMemSize int

	// lost indicates if the segment is marked as lost by RACK.
	lost bool
}
    94  
    95  func newIncomingSegment(id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) (*segment, error) {
    96  	hdr := header.TCP(pkt.TransportHeader().Slice())
    97  	var srcAddr tcpip.Address
    98  	var dstAddr tcpip.Address
    99  	switch netProto := pkt.NetworkProtocolNumber; netProto {
   100  	case header.IPv4ProtocolNumber:
   101  		hdr := header.IPv4(pkt.NetworkHeader().Slice())
   102  		srcAddr = hdr.SourceAddress()
   103  		dstAddr = hdr.DestinationAddress()
   104  	case header.IPv6ProtocolNumber:
   105  		hdr := header.IPv6(pkt.NetworkHeader().Slice())
   106  		srcAddr = hdr.SourceAddress()
   107  		dstAddr = hdr.DestinationAddress()
   108  	default:
   109  		panic(fmt.Sprintf("unknown network protocol number %d", netProto))
   110  	}
   111  
   112  	csum, csumValid, ok := header.TCPValid(
   113  		hdr,
   114  		func() uint16 { return pkt.Data().Checksum() },
   115  		uint16(pkt.Data().Size()),
   116  		srcAddr,
   117  		dstAddr,
   118  		pkt.RXChecksumValidated)
   119  	if !ok {
   120  		return nil, fmt.Errorf("header data offset does not respect size constraints: %d < offset < %d, got offset=%d", header.TCPMinimumSize, len(hdr), hdr.DataOffset())
   121  	}
   122  
   123  	s := newSegment()
   124  	s.id = id
   125  	s.options = hdr[header.TCPMinimumSize:]
   126  	s.parsedOptions = header.ParseTCPOptions(hdr[header.TCPMinimumSize:])
   127  	s.sequenceNumber = seqnum.Value(hdr.SequenceNumber())
   128  	s.ackNumber = seqnum.Value(hdr.AckNumber())
   129  	s.flags = hdr.Flags()
   130  	s.window = seqnum.Size(hdr.WindowSize())
   131  	s.rcvdTime = clock.NowMonotonic()
   132  	s.dataMemSize = pkt.MemSize()
   133  	s.pkt = pkt.IncRef()
   134  	s.csumValid = csumValid
   135  
   136  	if !s.pkt.RXChecksumValidated {
   137  		s.csum = csum
   138  	}
   139  	return s, nil
   140  }
   141  
   142  func newOutgoingSegment(id stack.TransportEndpointID, clock tcpip.Clock, buf buffer.Buffer) *segment {
   143  	s := newSegment()
   144  	s.id = id
   145  	s.rcvdTime = clock.NowMonotonic()
   146  	s.pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf})
   147  	s.dataMemSize = s.pkt.MemSize()
   148  	return s
   149  }
   150  
   151  func (s *segment) clone() *segment {
   152  	t := newSegment()
   153  	t.id = s.id
   154  	t.sequenceNumber = s.sequenceNumber
   155  	t.ackNumber = s.ackNumber
   156  	t.flags = s.flags
   157  	t.window = s.window
   158  	t.rcvdTime = s.rcvdTime
   159  	t.xmitTime = s.xmitTime
   160  	t.xmitCount = s.xmitCount
   161  	t.ep = s.ep
   162  	t.qFlags = s.qFlags
   163  	t.dataMemSize = s.dataMemSize
   164  	t.pkt = s.pkt.Clone()
   165  	return t
   166  }
   167  
   168  func newSegment() *segment {
   169  	s := segmentPool.Get().(*segment)
   170  	*s = segment{}
   171  	s.InitRefs()
   172  	return s
   173  }
   174  
   175  // merge merges data in oth and clears oth.
   176  func (s *segment) merge(oth *segment) {
   177  	s.pkt.Data().Merge(oth.pkt.Data())
   178  	s.dataMemSize = s.pkt.MemSize()
   179  	oth.dataMemSize = oth.pkt.MemSize()
   180  }
   181  
   182  // setOwner sets the owning endpoint for this segment. Its required
   183  // to be called to ensure memory accounting for receive/send buffer
   184  // queues is done properly.
   185  func (s *segment) setOwner(ep *Endpoint, qFlags queueFlags) {
   186  	switch qFlags {
   187  	case recvQ:
   188  		ep.updateReceiveMemUsed(s.segMemSize())
   189  	case sendQ:
   190  		// no memory account for sendQ yet.
   191  	default:
   192  		panic(fmt.Sprintf("unexpected queue flag %b", qFlags))
   193  	}
   194  	s.ep = ep
   195  	s.qFlags = qFlags
   196  }
   197  
   198  func (s *segment) DecRef() {
   199  	s.segmentRefs.DecRef(func() {
   200  		if s.ep != nil {
   201  			switch s.qFlags {
   202  			case recvQ:
   203  				s.ep.updateReceiveMemUsed(-s.segMemSize())
   204  			case sendQ:
   205  				// no memory accounting for sendQ yet.
   206  			default:
   207  				panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags))
   208  			}
   209  		}
   210  		s.pkt.DecRef()
   211  		s.pkt = nil
   212  		segmentPool.Put(s)
   213  	})
   214  }
   215  
   216  // logicalLen is the segment length in the sequence number space. It's defined
   217  // as the data length plus one for each of the SYN and FIN bits set.
   218  func (s *segment) logicalLen() seqnum.Size {
   219  	l := seqnum.Size(s.payloadSize())
   220  	if s.flags.Contains(header.TCPFlagSyn) {
   221  		l++
   222  	}
   223  	if s.flags.Contains(header.TCPFlagFin) {
   224  		l++
   225  	}
   226  	return l
   227  }
   228  
   229  // payloadSize is the size of s.data.
   230  func (s *segment) payloadSize() int {
   231  	return s.pkt.Data().Size()
   232  }
   233  
   234  // segMemSize is the amount of memory used to hold the segment data and
   235  // the associated metadata.
   236  func (s *segment) segMemSize() int {
   237  	return segSize + s.dataMemSize
   238  }
   239  
   240  // sackBlock returns a header.SACKBlock that represents this segment.
   241  func (s *segment) sackBlock() header.SACKBlock {
   242  	return header.SACKBlock{Start: s.sequenceNumber, End: s.sequenceNumber.Add(s.logicalLen())}
   243  }
   244  
   245  func (s *segment) TrimFront(ackLeft seqnum.Size) {
   246  	s.pkt.Data().TrimFront(int(ackLeft))
   247  }
   248  
   249  func (s *segment) ReadTo(dst io.Writer, peek bool) (int, error) {
   250  	return s.pkt.Data().ReadTo(dst, peek)
   251  }