github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/transport/tcp/segment.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tcp
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  
    21  	"github.com/nicocha30/gvisor-ligolo/pkg/buffer"
    22  	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
    23  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip"
    24  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/header"
    25  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/seqnum"
    26  	"github.com/nicocha30/gvisor-ligolo/pkg/tcpip/stack"
    27  )
    28  
// queueFlags indicates which queue of an endpoint a particular segment belongs
// to (recvQ or sendQ). setOwner and DecRef switch on it to decide whether
// receive-buffer memory must be charged or released, so memory accounting
// stays correct.
type queueFlags uint8

const (
	// SegOverheadSize is the size of an empty seg in memory including packet
	// buffer overhead. It is advised to use SegOverheadSize instead of segSize
	// in all cases where accounting for segment memory overhead is important.
	SegOverheadSize = segSize + stack.PacketBufferStructSize + header.IPv4MaximumHeaderSize

	// recvQ marks a segment owned by an endpoint's receive queue; setOwner
	// charges its segMemSize to the endpoint and DecRef credits it back.
	//
	// iota counts ConstSpecs from the top of this block and SegOverheadSize is
	// spec 0, so recvQ == 1<<1 == 2 and sendQ == 1<<2 == 4 (not 1 and 2).
	// Within this file the values are only compared for equality, but adding
	// a spec above them would shift both.
	recvQ queueFlags = 1 << iota
	// sendQ marks a segment owned by an endpoint's send queue; no memory is
	// accounted for it yet.
	sendQ
)
    42  
    43  var segmentPool = sync.Pool{
    44  	New: func() any {
    45  		return &segment{}
    46  	},
    47  }
    48  
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, can be added to intrusive lists (embedded segmentEntry), and is
// reference counted (embedded segmentRefs; see newSegment and DecRef).
//
// The header-derived fields (sequenceNumber, ackNumber, flags, window, options)
// are only assigned by the constructors and clone in this file. The payload in
// pkt is modified by merge and TrimFront, merge also recomputes dataMemSize,
// and setOwner assigns ep and qFlags.
//
// NOTE(review): segSize/SegOverheadSize and the stateify-generated save/restore
// code presumably depend on this field layout; reorder fields with care.
//
// +stateify savable
type segment struct {
	segmentEntry
	segmentRefs

	// ep is the owning endpoint, assigned by setOwner (or copied by clone).
	ep *endpoint
	// qFlags records which of ep's queues owns the segment.
	qFlags queueFlags
	// id is the transport endpoint ID the segment was created with.
	id stack.TransportEndpointID `state:"manual"`

	// pkt holds the packet; pkt.Data() is the segment payload. DecRef drops
	// the segment's reference to it.
	pkt stack.PacketBufferPtr

	// Fields copied from the TCP header for incoming segments (and by clone).
	sequenceNumber seqnum.Value
	ackNumber      seqnum.Value
	flags          header.TCPFlags
	window         seqnum.Size
	// csum is only populated for received segments, and only when the packet's
	// checksum was not already validated on receive.
	csum uint16
	// csumValid is true if the csum in the received segment is valid.
	csumValid bool

	// parsedOptions stores the parsed values from the options in the segment.
	parsedOptions header.TCPOptions
	// options is the raw header bytes past header.TCPMinimumSize for incoming
	// segments.
	options        []byte `state:".([]byte)"`
	hasNewSACKInfo bool
	// rcvdTime is read from the clock when the segment is constructed.
	rcvdTime tcpip.MonotonicTime
	// xmitTime is the last transmit time of this segment.
	xmitTime tcpip.MonotonicTime
	// xmitCount is not updated in this file; presumably it counts
	// transmissions of this segment — confirm against the sender.
	xmitCount uint32

	// acked indicates if the segment has already been SACKed.
	acked bool

	// dataMemSize is the memory used by pkt initially. The value is used for
	// memory accounting in the receive buffer instead of pkt.MemSize() because
	// packet contents can be modified, so relying on the computed memory size
	// to "free" reserved bytes could leak memory in the receiver.
	dataMemSize int

	// lost indicates if the segment is marked as lost by RACK.
	lost bool
}
    94  
    95  func newIncomingSegment(id stack.TransportEndpointID, clock tcpip.Clock, pkt stack.PacketBufferPtr) (*segment, error) {
    96  	hdr := header.TCP(pkt.TransportHeader().Slice())
    97  	netHdr := pkt.Network()
    98  	csum, csumValid, ok := header.TCPValid(
    99  		hdr,
   100  		func() uint16 { return pkt.Data().Checksum() },
   101  		uint16(pkt.Data().Size()),
   102  		netHdr.SourceAddress(),
   103  		netHdr.DestinationAddress(),
   104  		pkt.RXChecksumValidated)
   105  	if !ok {
   106  		return nil, fmt.Errorf("header data offset does not respect size constraints: %d < offset < %d, got offset=%d", header.TCPMinimumSize, len(hdr), hdr.DataOffset())
   107  	}
   108  
   109  	s := newSegment()
   110  	s.id = id
   111  	s.options = hdr[header.TCPMinimumSize:]
   112  	s.parsedOptions = header.ParseTCPOptions(hdr[header.TCPMinimumSize:])
   113  	s.sequenceNumber = seqnum.Value(hdr.SequenceNumber())
   114  	s.ackNumber = seqnum.Value(hdr.AckNumber())
   115  	s.flags = hdr.Flags()
   116  	s.window = seqnum.Size(hdr.WindowSize())
   117  	s.rcvdTime = clock.NowMonotonic()
   118  	s.dataMemSize = pkt.MemSize()
   119  	s.pkt = pkt.IncRef()
   120  	s.csumValid = csumValid
   121  
   122  	if !s.pkt.RXChecksumValidated {
   123  		s.csum = csum
   124  	}
   125  	return s, nil
   126  }
   127  
   128  func newOutgoingSegment(id stack.TransportEndpointID, clock tcpip.Clock, buf buffer.Buffer) *segment {
   129  	s := newSegment()
   130  	s.id = id
   131  	s.rcvdTime = clock.NowMonotonic()
   132  	s.pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf})
   133  	s.dataMemSize = s.pkt.MemSize()
   134  	return s
   135  }
   136  
   137  func (s *segment) clone() *segment {
   138  	t := newSegment()
   139  	t.id = s.id
   140  	t.sequenceNumber = s.sequenceNumber
   141  	t.ackNumber = s.ackNumber
   142  	t.flags = s.flags
   143  	t.window = s.window
   144  	t.rcvdTime = s.rcvdTime
   145  	t.xmitTime = s.xmitTime
   146  	t.xmitCount = s.xmitCount
   147  	t.ep = s.ep
   148  	t.qFlags = s.qFlags
   149  	t.dataMemSize = s.dataMemSize
   150  	t.pkt = s.pkt.Clone()
   151  	return t
   152  }
   153  
   154  func newSegment() *segment {
   155  	s := segmentPool.Get().(*segment)
   156  	*s = segment{}
   157  	s.InitRefs()
   158  	return s
   159  }
   160  
   161  // merge merges data in oth and clears oth.
   162  func (s *segment) merge(oth *segment) {
   163  	s.pkt.Data().Merge(oth.pkt.Data())
   164  	s.dataMemSize = s.pkt.MemSize()
   165  	oth.dataMemSize = oth.pkt.MemSize()
   166  }
   167  
   168  // setOwner sets the owning endpoint for this segment. Its required
   169  // to be called to ensure memory accounting for receive/send buffer
   170  // queues is done properly.
   171  func (s *segment) setOwner(ep *endpoint, qFlags queueFlags) {
   172  	switch qFlags {
   173  	case recvQ:
   174  		ep.updateReceiveMemUsed(s.segMemSize())
   175  	case sendQ:
   176  		// no memory account for sendQ yet.
   177  	default:
   178  		panic(fmt.Sprintf("unexpected queue flag %b", qFlags))
   179  	}
   180  	s.ep = ep
   181  	s.qFlags = qFlags
   182  }
   183  
   184  func (s *segment) DecRef() {
   185  	s.segmentRefs.DecRef(func() {
   186  		if s.ep != nil {
   187  			switch s.qFlags {
   188  			case recvQ:
   189  				s.ep.updateReceiveMemUsed(-s.segMemSize())
   190  			case sendQ:
   191  				// no memory accounting for sendQ yet.
   192  			default:
   193  				panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags))
   194  			}
   195  		}
   196  		s.pkt.DecRef()
   197  		s.pkt = nil
   198  		segmentPool.Put(s)
   199  	})
   200  }
   201  
   202  // logicalLen is the segment length in the sequence number space. It's defined
   203  // as the data length plus one for each of the SYN and FIN bits set.
   204  func (s *segment) logicalLen() seqnum.Size {
   205  	l := seqnum.Size(s.payloadSize())
   206  	if s.flags.Contains(header.TCPFlagSyn) {
   207  		l++
   208  	}
   209  	if s.flags.Contains(header.TCPFlagFin) {
   210  		l++
   211  	}
   212  	return l
   213  }
   214  
   215  // payloadSize is the size of s.data.
   216  func (s *segment) payloadSize() int {
   217  	return s.pkt.Data().Size()
   218  }
   219  
   220  // segMemSize is the amount of memory used to hold the segment data and
   221  // the associated metadata.
   222  func (s *segment) segMemSize() int {
   223  	return segSize + s.dataMemSize
   224  }
   225  
   226  // sackBlock returns a header.SACKBlock that represents this segment.
   227  func (s *segment) sackBlock() header.SACKBlock {
   228  	return header.SACKBlock{Start: s.sequenceNumber, End: s.sequenceNumber.Add(s.logicalLen())}
   229  }
   230  
   231  func (s *segment) TrimFront(ackLeft seqnum.Size) {
   232  	s.pkt.Data().TrimFront(int(ackLeft))
   233  }
   234  
   235  func (s *segment) ReadTo(dst io.Writer, peek bool) (int, error) {
   236  	return s.pkt.Data().ReadTo(dst, peek)
   237  }