gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/raw/endpoint.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package raw provides the implementation of raw sockets (see raw(7)). Raw
    16  // sockets allow applications to:
    17  //
    18  //   - manually write and inspect transport layer headers and payloads
    19  //   - receive all traffic of a given transport protocol (e.g. ICMP or UDP)
    20  //   - optionally write and inspect network layer headers of packets
    21  //
    22  // Raw sockets don't have any notion of ports, and incoming packets are
    23  // demultiplexed solely by protocol number. Thus, a raw UDP endpoint will
    24  // receive every UDP packet received by netstack. bind(2) and connect(2) can be
    25  // used to filter incoming packets by source and destination.
    26  package raw
    27  
    28  import (
    29  	"fmt"
    30  	"io"
    31  	"time"
    32  
    33  	"gvisor.dev/gvisor/pkg/buffer"
    34  	"gvisor.dev/gvisor/pkg/sync"
    35  	"gvisor.dev/gvisor/pkg/tcpip"
    36  	"gvisor.dev/gvisor/pkg/tcpip/checksum"
    37  	"gvisor.dev/gvisor/pkg/tcpip/header"
    38  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    39  	"gvisor.dev/gvisor/pkg/tcpip/transport"
    40  	"gvisor.dev/gvisor/pkg/tcpip/transport/internal/network"
    41  	"gvisor.dev/gvisor/pkg/waiter"
    42  )
    43  
// rawPacket is one received packet queued on a raw endpoint's receive list,
// together with the metadata recorded for it when it was queued.
//
// +stateify savable
type rawPacket struct {
	// rawPacketEntry is embedded so the packet can be linked into a
	// rawPacketList (presumably the intrusive list entry; its definition is
	// outside this file).
	rawPacketEntry
	// data holds the actual packet data, including any headers and
	// payload.
	data *stack.PacketBuffer
	// receivedAt is the stack clock reading taken when the packet was queued.
	receivedAt time.Time `state:".(int64)"`
	// senderAddr is the network address of the sender.
	senderAddr tcpip.FullAddress
	// packetInfo records the destination address and the NIC the packet
	// arrived on; it backs the packet-info control messages returned by Read.
	packetInfo tcpip.IPPacketInfo

	// tosOrTClass stores either the Type of Service for IPv4 or the Traffic Class
	// for IPv6.
	tosOrTClass uint8
	// ttlOrHopLimit stores either the TTL for IPv4 or the HopLimit for IPv6.
	ttlOrHopLimit uint8
}
    61  
// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to
// have goroutines make concurrent calls into the endpoint.
//
// Lock order:
//
//	endpoint.mu
//	  endpoint.rcvMu
//
// +stateify savable
type endpoint struct {
	tcpip.DefaultSocketOptionsHandler

	// The following fields are initialized at creation time and are
	// immutable.
	stack      *stack.Stack `state:"manual"`
	transProto tcpip.TransportProtocolNumber
	// waiterQueue is notified of readiness changes. It is nil for
	// unassociated endpoints.
	waiterQueue *waiter.Queue
	// associated is false for write-only endpoints whose callers supply the
	// IP header themselves; such endpoints are never registered with the
	// stack and never queue received packets.
	associated bool

	// net is the shared network-layer endpoint that tracks bind/connect state
	// and performs the actual packet writes.
	net network.Endpoint
	// stats holds the per-endpoint counters exposed through Stats.
	stats tcpip.TransportEndpointStats
	// ops holds the socket options exposed through SocketOptions.
	ops tcpip.SocketOptions

	// rcvMu protects the receive-side state below.
	rcvMu sync.Mutex `state:"nosave"`
	// rcvList is the queue of received packets waiting to be read.
	//
	// +checklocks:rcvMu
	rcvList rawPacketList
	// rcvBufSize is the total payload size, in bytes, of the packets in
	// rcvList.
	//
	// +checklocks:rcvMu
	rcvBufSize int
	// rcvClosed is set by Close; once set, incoming packets are dropped and
	// reads on an empty queue report that the endpoint is closed.
	//
	// +checklocks:rcvMu
	rcvClosed bool
	// rcvDisabled makes HandlePacket drop every incoming packet while true.
	//
	// +checklocks:rcvMu
	rcvDisabled bool

	// mu protects the fields below.
	mu sync.RWMutex `state:"nosave"`

	// ipv6ChecksumOffset indicates the offset to populate the IPv6 checksum at.
	//
	// A negative value indicates no checksum should be calculated.
	//
	// +checklocks:mu
	ipv6ChecksumOffset int
	// icmpv6Filter holds the filter for ICMPv6 packets.
	//
	// +checklocks:mu
	icmpv6Filter tcpip.ICMPv6Filter
}
   108  
   109  // NewEndpoint returns a raw  endpoint for the given protocols.
   110  func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
   111  	return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
   112  }
   113  
   114  func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, tcpip.Error) {
   115  	// Calculating the upper-layer checksum is disabled by default for raw IPv6
   116  	// endpoints, unless the upper-layer protocol is ICMPv6.
   117  	//
   118  	// As per RFC 3542 section 3.1,
   119  	//
   120  	//   The kernel will calculate and insert the ICMPv6 checksum for ICMPv6
   121  	//   raw sockets, since this checksum is mandatory.
   122  	ipv6ChecksumOffset := -1
   123  	if netProto == header.IPv6ProtocolNumber && transProto == header.ICMPv6ProtocolNumber {
   124  		ipv6ChecksumOffset = header.ICMPv6ChecksumOffset
   125  	}
   126  
   127  	e := &endpoint{
   128  		stack:              s,
   129  		transProto:         transProto,
   130  		waiterQueue:        waiterQueue,
   131  		associated:         associated,
   132  		ipv6ChecksumOffset: ipv6ChecksumOffset,
   133  	}
   134  	e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
   135  	e.ops.SetMulticastLoop(true)
   136  	e.ops.SetHeaderIncluded(!associated)
   137  	e.ops.SetSendBufferSize(32*1024, false /* notify */)
   138  	e.ops.SetReceiveBufferSize(32*1024, false /* notify */)
   139  	e.net.Init(s, netProto, transProto, &e.ops, waiterQueue)
   140  
   141  	// Override with stack defaults.
   142  	var ss tcpip.SendBufferSizeOption
   143  	if err := s.Option(&ss); err == nil {
   144  		e.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
   145  	}
   146  
   147  	var rs tcpip.ReceiveBufferSizeOption
   148  	if err := s.Option(&rs); err == nil {
   149  		e.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
   150  	}
   151  
   152  	// Unassociated endpoints are write-only and users call Write() with IP
   153  	// headers included. Because they're write-only, We don't need to
   154  	// register with the stack.
   155  	if !associated {
   156  		e.ops.SetReceiveBufferSize(0, false /* notify */)
   157  		e.waiterQueue = nil
   158  		return e, nil
   159  	}
   160  
   161  	if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
   162  		return nil, err
   163  	}
   164  
   165  	return e, nil
   166  }
   167  
   168  // WakeupWriters implements tcpip.SocketOptionsHandler.
   169  func (e *endpoint) WakeupWriters() {
   170  	e.net.MaybeSignalWritable()
   171  }
   172  
   173  // HasNIC implements tcpip.SocketOptionsHandler.
   174  func (e *endpoint) HasNIC(id int32) bool {
   175  	return e.stack.HasNIC(tcpip.NICID(id))
   176  }
   177  
   178  // Abort implements stack.TransportEndpoint.Abort.
   179  func (e *endpoint) Abort() {
   180  	e.Close()
   181  }
   182  
   183  // Close implements tcpip.Endpoint.Close.
   184  func (e *endpoint) Close() {
   185  	e.mu.Lock()
   186  	defer e.mu.Unlock()
   187  
   188  	if e.net.State() == transport.DatagramEndpointStateClosed {
   189  		return
   190  	}
   191  
   192  	e.net.Close()
   193  
   194  	if !e.associated {
   195  		return
   196  	}
   197  
   198  	e.stack.UnregisterRawTransportEndpoint(e.net.NetProto(), e.transProto, e)
   199  
   200  	e.rcvMu.Lock()
   201  	defer e.rcvMu.Unlock()
   202  
   203  	// Clear the receive list.
   204  	e.rcvClosed = true
   205  	e.rcvBufSize = 0
   206  	for !e.rcvList.Empty() {
   207  		p := e.rcvList.Front()
   208  		e.rcvList.Remove(p)
   209  		p.data.DecRef()
   210  	}
   211  
   212  	e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
   213  }
   214  
   215  // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
   216  func (*endpoint) ModerateRecvBuf(int) {}
   217  
   218  func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
   219  	e.net.SetOwner(owner)
   220  }
   221  
   222  // Read implements tcpip.Endpoint.Read.
   223  func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
   224  	e.rcvMu.Lock()
   225  
   226  	// If there's no data to read, return that read would block or that the
   227  	// endpoint is closed.
   228  	if e.rcvList.Empty() {
   229  		var err tcpip.Error = &tcpip.ErrWouldBlock{}
   230  		if e.rcvClosed {
   231  			e.stats.ReadErrors.ReadClosed.Increment()
   232  			err = &tcpip.ErrClosedForReceive{}
   233  		}
   234  		e.rcvMu.Unlock()
   235  		return tcpip.ReadResult{}, err
   236  	}
   237  
   238  	pkt := e.rcvList.Front()
   239  	if !opts.Peek {
   240  		e.rcvList.Remove(pkt)
   241  		defer pkt.data.DecRef()
   242  		e.rcvBufSize -= pkt.data.Data().Size()
   243  	}
   244  
   245  	e.rcvMu.Unlock()
   246  
   247  	// Control Messages
   248  	// TODO(https://gvisor.dev/issue/7012): Share control message code with other
   249  	// network endpoints.
   250  	cm := tcpip.ReceivableControlMessages{
   251  		HasTimestamp: true,
   252  		Timestamp:    pkt.receivedAt,
   253  	}
   254  	switch netProto := e.net.NetProto(); netProto {
   255  	case header.IPv4ProtocolNumber:
   256  		if e.ops.GetReceiveTOS() {
   257  			cm.HasTOS = true
   258  			cm.TOS = pkt.tosOrTClass
   259  		}
   260  		if e.ops.GetReceiveTTL() {
   261  			cm.HasTTL = true
   262  			cm.TTL = pkt.ttlOrHopLimit
   263  		}
   264  		if e.ops.GetReceivePacketInfo() {
   265  			cm.HasIPPacketInfo = true
   266  			cm.PacketInfo = pkt.packetInfo
   267  		}
   268  	case header.IPv6ProtocolNumber:
   269  		if e.ops.GetReceiveTClass() {
   270  			cm.HasTClass = true
   271  			// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
   272  			cm.TClass = uint32(pkt.tosOrTClass)
   273  		}
   274  		if e.ops.GetReceiveHopLimit() {
   275  			cm.HasHopLimit = true
   276  			cm.HopLimit = pkt.ttlOrHopLimit
   277  		}
   278  		if e.ops.GetIPv6ReceivePacketInfo() {
   279  			cm.HasIPv6PacketInfo = true
   280  			cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{
   281  				NIC:  pkt.packetInfo.NIC,
   282  				Addr: pkt.packetInfo.DestinationAddr,
   283  			}
   284  		}
   285  	default:
   286  		panic(fmt.Sprintf("unrecognized network protocol = %d", netProto))
   287  	}
   288  
   289  	res := tcpip.ReadResult{
   290  		Total:           pkt.data.Data().Size(),
   291  		ControlMessages: cm,
   292  	}
   293  	if opts.NeedRemoteAddr {
   294  		res.RemoteAddr = pkt.senderAddr
   295  	}
   296  
   297  	n, err := pkt.data.Data().ReadTo(dst, opts.Peek)
   298  	if n == 0 && err != nil {
   299  		return res, &tcpip.ErrBadBuffer{}
   300  	}
   301  	res.Count = n
   302  	return res, nil
   303  }
   304  
   305  // Write implements tcpip.Endpoint.Write.
   306  func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
   307  	netProto := e.net.NetProto()
   308  	// We can create, but not write to, unassociated IPv6 endpoints.
   309  	if !e.associated && netProto == header.IPv6ProtocolNumber {
   310  		return 0, &tcpip.ErrInvalidOptionValue{}
   311  	}
   312  
   313  	if opts.To != nil {
   314  		// Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint.
   315  		if netProto == header.IPv6ProtocolNumber && opts.To.Addr.BitLen() != header.IPv6AddressSizeBits {
   316  			return 0, &tcpip.ErrInvalidOptionValue{}
   317  		}
   318  	}
   319  
   320  	n, err := e.write(p, opts)
   321  	switch err.(type) {
   322  	case nil:
   323  		e.stats.PacketsSent.Increment()
   324  	case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
   325  		e.stats.WriteErrors.InvalidArgs.Increment()
   326  	case *tcpip.ErrClosedForSend:
   327  		e.stats.WriteErrors.WriteClosed.Increment()
   328  	case *tcpip.ErrInvalidEndpointState:
   329  		e.stats.WriteErrors.InvalidEndpointState.Increment()
   330  	case *tcpip.ErrHostUnreachable, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
   331  		// Errors indicating any problem with IP routing of the packet.
   332  		e.stats.SendErrors.NoRoute.Increment()
   333  	default:
   334  		// For all other errors when writing to the network layer.
   335  		e.stats.SendErrors.SendToNetworkFailed.Increment()
   336  	}
   337  	return n, err
   338  }
   339  
   340  func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
   341  	e.mu.Lock()
   342  	ctx, err := e.net.AcquireContextForWrite(opts)
   343  	ipv6ChecksumOffset := e.ipv6ChecksumOffset
   344  	e.mu.Unlock()
   345  	if err != nil {
   346  		return 0, err
   347  	}
   348  	defer ctx.Release()
   349  
   350  	if p.Len() > int(ctx.MTU()) {
   351  		return 0, &tcpip.ErrMessageTooLong{}
   352  	}
   353  
   354  	// Prevents giant buffer allocations.
   355  	if p.Len() > header.DatagramMaximumSize {
   356  		return 0, &tcpip.ErrMessageTooLong{}
   357  	}
   358  
   359  	var payload buffer.Buffer
   360  	defer payload.Release()
   361  	if _, err := payload.WriteFromReader(p, int64(p.Len())); err != nil {
   362  		return 0, &tcpip.ErrBadBuffer{}
   363  	}
   364  	payloadSz := payload.Size()
   365  
   366  	if packetInfo := ctx.PacketInfo(); packetInfo.NetProto == header.IPv6ProtocolNumber && ipv6ChecksumOffset >= 0 {
   367  		// Make sure we can fit the checksum.
   368  		if payload.Size() < int64(ipv6ChecksumOffset+checksum.Size) {
   369  			return 0, &tcpip.ErrInvalidOptionValue{}
   370  		}
   371  
   372  		payloadView, _ := payload.PullUp(ipv6ChecksumOffset, int(payload.Size())-ipv6ChecksumOffset)
   373  		xsum := header.PseudoHeaderChecksum(e.transProto, packetInfo.LocalAddress, packetInfo.RemoteAddress, uint16(payload.Size()))
   374  		checksum.Put(payloadView.AsSlice(), 0)
   375  		xsum = checksum.Combine(payload.Checksum(0), xsum)
   376  		checksum.Put(payloadView.AsSlice(), ^xsum)
   377  	}
   378  
   379  	pkt := ctx.TryNewPacketBuffer(int(ctx.PacketInfo().MaxHeaderLength), payload.Clone())
   380  	if pkt == nil {
   381  		return 0, &tcpip.ErrWouldBlock{}
   382  	}
   383  	defer pkt.DecRef()
   384  
   385  	if err := ctx.WritePacket(pkt, e.ops.GetHeaderIncluded()); err != nil {
   386  		return 0, err
   387  	}
   388  
   389  	return payloadSz, nil
   390  }
   391  
   392  // Disconnect implements tcpip.Endpoint.Disconnect.
   393  func (*endpoint) Disconnect() tcpip.Error {
   394  	return &tcpip.ErrNotSupported{}
   395  }
   396  
   397  // Connect implements tcpip.Endpoint.Connect.
   398  func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
   399  	netProto := e.net.NetProto()
   400  
   401  	// Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint.
   402  	if netProto == header.IPv6ProtocolNumber && addr.Addr.BitLen() != header.IPv6AddressSizeBits {
   403  		return &tcpip.ErrAddressFamilyNotSupported{}
   404  	}
   405  
   406  	return e.net.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
   407  		if e.associated {
   408  			// Re-register the endpoint with the appropriate NIC.
   409  			if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
   410  				return err
   411  			}
   412  			e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e)
   413  		}
   414  
   415  		return nil
   416  	})
   417  }
   418  
   419  // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets.
   420  func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
   421  	if e.net.State() != transport.DatagramEndpointStateConnected {
   422  		return &tcpip.ErrNotConnected{}
   423  	}
   424  	return nil
   425  }
   426  
   427  // Listen implements tcpip.Endpoint.Listen.
   428  func (*endpoint) Listen(int) tcpip.Error {
   429  	return &tcpip.ErrNotSupported{}
   430  }
   431  
   432  // Accept implements tcpip.Endpoint.Accept.
   433  func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
   434  	return nil, nil, &tcpip.ErrNotSupported{}
   435  }
   436  
   437  // Bind implements tcpip.Endpoint.Bind.
   438  func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
   439  	return e.net.BindAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, _ tcpip.Address) tcpip.Error {
   440  		if !e.associated {
   441  			return nil
   442  		}
   443  
   444  		// Re-register the endpoint with the appropriate NIC.
   445  		if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
   446  			return err
   447  		}
   448  		e.stack.UnregisterRawTransportEndpoint(netProto, e.transProto, e)
   449  		return nil
   450  	})
   451  }
   452  
   453  // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
   454  func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
   455  	a := e.net.GetLocalAddress()
   456  	// Linux returns the protocol in the port field.
   457  	a.Port = uint16(e.transProto)
   458  	return a, nil
   459  }
   460  
   461  // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
   462  func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
   463  	// Even a connected socket doesn't return a remote address.
   464  	return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
   465  }
   466  
   467  // Readiness implements tcpip.Endpoint.Readiness.
   468  func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
   469  	var result waiter.EventMask
   470  
   471  	if e.net.HasSendSpace() {
   472  		result |= waiter.WritableEvents & mask
   473  	}
   474  
   475  	// Determine whether the endpoint is readable.
   476  	if (mask & waiter.ReadableEvents) != 0 {
   477  		e.rcvMu.Lock()
   478  		if !e.rcvList.Empty() || e.rcvClosed {
   479  			result |= waiter.ReadableEvents
   480  		}
   481  		e.rcvMu.Unlock()
   482  	}
   483  
   484  	return result
   485  }
   486  
   487  // SetSockOpt implements tcpip.Endpoint.SetSockOpt.
   488  func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
   489  	switch opt := opt.(type) {
   490  	case *tcpip.SocketDetachFilterOption:
   491  		return nil
   492  
   493  	case *tcpip.ICMPv6Filter:
   494  		if e.net.NetProto() != header.IPv6ProtocolNumber {
   495  			return &tcpip.ErrUnknownProtocolOption{}
   496  		}
   497  
   498  		if e.transProto != header.ICMPv6ProtocolNumber {
   499  			return &tcpip.ErrInvalidOptionValue{}
   500  		}
   501  
   502  		e.mu.Lock()
   503  		defer e.mu.Unlock()
   504  		e.icmpv6Filter = *opt
   505  		return nil
   506  	default:
   507  		return e.net.SetSockOpt(opt)
   508  	}
   509  }
   510  
   511  func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
   512  	switch opt {
   513  	case tcpip.IPv6Checksum:
   514  		if e.net.NetProto() != header.IPv6ProtocolNumber {
   515  			return &tcpip.ErrUnknownProtocolOption{}
   516  		}
   517  
   518  		if e.transProto == header.ICMPv6ProtocolNumber {
   519  			// As per RFC 3542 section 3.1,
   520  			//
   521  			//  An attempt to set IPV6_CHECKSUM for an ICMPv6 socket will fail.
   522  			return &tcpip.ErrInvalidOptionValue{}
   523  		}
   524  
   525  		// Make sure the offset is aligned properly if checksum is requested.
   526  		if v > 0 && v%checksum.Size != 0 {
   527  			return &tcpip.ErrInvalidOptionValue{}
   528  		}
   529  
   530  		e.mu.Lock()
   531  		defer e.mu.Unlock()
   532  		e.ipv6ChecksumOffset = v
   533  		return nil
   534  	default:
   535  		return e.net.SetSockOptInt(opt, v)
   536  	}
   537  }
   538  
   539  // GetSockOpt implements tcpip.Endpoint.GetSockOpt.
   540  func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
   541  	switch opt := opt.(type) {
   542  	case *tcpip.ICMPv6Filter:
   543  		if e.net.NetProto() != header.IPv6ProtocolNumber {
   544  			return &tcpip.ErrUnknownProtocolOption{}
   545  		}
   546  
   547  		if e.transProto != header.ICMPv6ProtocolNumber {
   548  			return &tcpip.ErrInvalidOptionValue{}
   549  		}
   550  
   551  		e.mu.RLock()
   552  		defer e.mu.RUnlock()
   553  		*opt = e.icmpv6Filter
   554  		return nil
   555  
   556  	default:
   557  		return e.net.GetSockOpt(opt)
   558  	}
   559  }
   560  
   561  // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
   562  func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
   563  	switch opt {
   564  	case tcpip.ReceiveQueueSizeOption:
   565  		v := 0
   566  		e.rcvMu.Lock()
   567  		if !e.rcvList.Empty() {
   568  			p := e.rcvList.Front()
   569  			v = p.data.Data().Size()
   570  		}
   571  		e.rcvMu.Unlock()
   572  		return v, nil
   573  
   574  	case tcpip.IPv6Checksum:
   575  		if e.net.NetProto() != header.IPv6ProtocolNumber {
   576  			return 0, &tcpip.ErrUnknownProtocolOption{}
   577  		}
   578  
   579  		e.mu.Lock()
   580  		defer e.mu.Unlock()
   581  		return e.ipv6ChecksumOffset, nil
   582  
   583  	default:
   584  		return e.net.GetSockOptInt(opt)
   585  	}
   586  }
   587  
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
//
// It decides whether pkt should be delivered to this endpoint and, if so,
// queues a copy of it on the receive list. Waiters are notified of readability
// only when the queue held no data before this packet was added.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
	// The closure does all the work under e.mu (read) and e.rcvMu, and
	// reports whether waiters need to be notified once the locks are dropped.
	notifyReadableEvents := func() bool {
		e.mu.RLock()
		defer e.mu.RUnlock()
		e.rcvMu.Lock()
		defer e.rcvMu.Unlock()

		// Drop the packet if the receive side is closed or if this is an unassociated
		// endpoint (i.e endpoint created  w/ IPPROTO_RAW). Such endpoints are send only
		// See: https://man7.org/linux/man-pages/man7/raw.7.html
		//
		//    An IPPROTO_RAW socket is send only.  If you really want to receive
		//    all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
		//    Note that packet sockets don't reassemble IP fragments, unlike raw
		//    sockets.
		if e.rcvClosed || !e.associated {
			e.stack.Stats().DroppedPackets.Increment()
			e.stats.ReceiveErrors.ClosedReceiver.Increment()
			return false
		}

		// Drop the packet if receiving is disabled or the receive buffer is
		// already full.
		rcvBufSize := e.ops.GetReceiveBufferSize()
		if e.rcvDisabled || e.rcvBufSize >= int(rcvBufSize) {
			e.stack.Stats().DroppedPackets.Increment()
			e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
			return false
		}

		net := pkt.Network()
		dstAddr := net.DestinationAddress()
		srcAddr := net.SourceAddress()
		info := e.net.Info()

		// Filter by connect/bind state. Packets rejected here are dropped
		// silently, without touching the drop counters.
		switch state := e.net.State(); state {
		case transport.DatagramEndpointStateInitial:
		case transport.DatagramEndpointStateConnected:
			// If connected, only accept packets from the remote address we
			// connected to.
			if info.ID.RemoteAddress != srcAddr {
				return false
			}

			// Connected sockets may also have been bound to a specific
			// address/NIC.
			fallthrough
		case transport.DatagramEndpointStateBound:
			// If bound to a NIC, only accept data for that NIC.
			if info.BindNICID != 0 && info.BindNICID != pkt.NICID {
				return false
			}

			// If bound to an address, only accept data for that address.
			if info.BindAddr != (tcpip.Address{}) && info.BindAddr != dstAddr {
				return false
			}
		default:
			panic(fmt.Sprintf("unhandled state = %s", state))
		}

		// Remember whether the queue held no data, so that waiters are only
		// woken on the empty -> non-empty transition.
		wasEmpty := e.rcvBufSize == 0

		// Push new packet into receive list and increment the buffer size.
		packet := &rawPacket{
			senderAddr: tcpip.FullAddress{
				NIC:  pkt.NICID,
				Addr: srcAddr,
			},
			packetInfo: tcpip.IPPacketInfo{
				// TODO(gvisor.dev/issue/3556): dstAddr may be a multicast or broadcast
				// address. LocalAddr should hold a unicast address that can be
				// used to respond to the incoming packet.
				LocalAddr:       dstAddr,
				DestinationAddr: dstAddr,
				NIC:             pkt.NICID,
			},
		}

		// Save any useful information from the network header to the packet.
		packet.tosOrTClass, _ = pkt.Network().TOS()
		switch pkt.NetworkProtocolNumber {
		case header.IPv4ProtocolNumber:
			packet.ttlOrHopLimit = header.IPv4(pkt.NetworkHeader().Slice()).TTL()
		case header.IPv6ProtocolNumber:
			packet.ttlOrHopLimit = header.IPv6(pkt.NetworkHeader().Slice()).HopLimit()
		}

		// Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
		// We copy headers' underlying bytes because pkt.*Header may point to
		// the middle of a slice, and another struct may point to the "outer"
		// slice. Save/restore doesn't support overlapping slices and will fail.
		//
		// TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports
		// overlapping slices.
		transportHeader := pkt.TransportHeader().Slice()
		var combinedBuf buffer.Buffer
		defer combinedBuf.Release()
		switch info.NetProto {
		case header.IPv4ProtocolNumber:
			// IPv4: network header + transport header + payload.
			networkHeader := pkt.NetworkHeader().Slice()
			headers := buffer.NewView(len(networkHeader) + len(transportHeader))
			headers.Write(networkHeader)
			headers.Write(transportHeader)
			combinedBuf = buffer.MakeWithView(headers)
			pktBuf := pkt.Data().ToBuffer()
			combinedBuf.Merge(&pktBuf)
		case header.IPv6ProtocolNumber:
			// ICMPv6 endpoints drop packets that are too short to carry an
			// ICMPv6 header or that the configured filter denies.
			if e.transProto == header.ICMPv6ProtocolNumber {
				if len(transportHeader) < header.ICMPv6MinimumSize {
					return false
				}

				if e.icmpv6Filter.ShouldDeny(uint8(header.ICMPv6(transportHeader).Type())) {
					return false
				}
			}

			// IPv6: transport header + payload only.
			combinedBuf = buffer.MakeWithView(pkt.TransportHeader().View())
			pktBuf := pkt.Data().ToBuffer()
			combinedBuf.Merge(&pktBuf)

			// When a checksum offset is configured, verify the upper-layer
			// checksum and drop packets that fail it.
			if checksumOffset := e.ipv6ChecksumOffset; checksumOffset >= 0 {
				bufSize := int(combinedBuf.Size())
				if bufSize < checksumOffset+checksum.Size {
					// Message too small to fit checksum.
					return false
				}

				xsum := header.PseudoHeaderChecksum(e.transProto, srcAddr, dstAddr, uint16(bufSize))
				xsum = checksum.Combine(combinedBuf.Checksum(0), xsum)
				if xsum != 0xFFFF {
					// Invalid checksum.
					return false
				}
			}
		default:
			panic(fmt.Sprintf("unrecognized protocol number = %d", info.NetProto))
		}

		// The queued packet gets its own clone of the buffer; combinedBuf
		// itself is released by the deferred call above.
		packet.data = stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: combinedBuf.Clone()})
		packet.receivedAt = e.stack.Clock().Now()

		e.rcvList.PushBack(packet)
		e.rcvBufSize += packet.data.Data().Size()
		e.stats.PacketsReceived.Increment()

		// Notify waiters that there is data to be read now.
		return wasEmpty
	}()

	if notifyReadableEvents {
		e.waiterQueue.Notify(waiter.ReadableEvents)
	}
}
   742  
   743  // State implements socket.Socket.State.
   744  func (e *endpoint) State() uint32 {
   745  	return uint32(e.net.State())
   746  }
   747  
   748  // Info returns a copy of the endpoint info.
   749  func (e *endpoint) Info() tcpip.EndpointInfo {
   750  	ret := e.net.Info()
   751  	return &ret
   752  }
   753  
   754  // Stats returns a pointer to the endpoint stats.
   755  func (e *endpoint) Stats() tcpip.EndpointStats {
   756  	return &e.stats
   757  }
   758  
   759  // Wait implements stack.TransportEndpoint.Wait.
   760  func (*endpoint) Wait() {}
   761  
   762  // LastError implements tcpip.Endpoint.LastError.
   763  func (*endpoint) LastError() tcpip.Error {
   764  	return nil
   765  }
   766  
   767  // SocketOptions implements tcpip.Endpoint.SocketOptions.
   768  func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
   769  	return &e.ops
   770  }
   771  
   772  func (e *endpoint) setReceiveDisabled(v bool) {
   773  	e.rcvMu.Lock()
   774  	defer e.rcvMu.Unlock()
   775  	e.rcvDisabled = v
   776  }