github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/tcpip/transport/packet/endpoint.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package packet provides the implementation of packet sockets (see
    16  // packet(7)). Packet sockets allow applications to:
    17  //
    18  //   - manually write and inspect link, network, and transport headers
    19  //   - receive all traffic of a given network protocol, or all protocols
    20  //
    21  // Packet sockets are similar to raw sockets, but provide even more power to
    22  // users, letting them effectively talk directly to the network device.
    23  //
    24  // Packet sockets skip the input and output iptables chains.
    25  package packet
    26  
    27  import (
    28  	"io"
    29  	"time"
    30  
    31  	"github.com/sagernet/gvisor/pkg/buffer"
    32  	"github.com/sagernet/gvisor/pkg/sync"
    33  	"github.com/sagernet/gvisor/pkg/tcpip"
    34  	"github.com/sagernet/gvisor/pkg/tcpip/header"
    35  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
    36  	"github.com/sagernet/gvisor/pkg/waiter"
    37  )
    38  
// packet is a single entry in an endpoint's receive queue, created by
// HandlePacket and consumed by Read.
//
// +stateify savable
type packet struct {
	packetEntry
	// data holds the packet data, including any headers and payload. For
	// cooked endpoints the link-layer (and virtio-net) headers have already
	// been trimmed off by HandlePacket.
	data       *stack.PacketBuffer
	// receivedAt is the stack clock time at which HandlePacket queued the
	// packet; Read reports it as the receive timestamp.
	receivedAt time.Time `state:".(int64)"`
	// senderAddr is the network address of the sender.
	senderAddr tcpip.FullAddress
	// packetInfo holds additional information like the protocol
	// of the packet etc.
	packetInfo tcpip.LinkPacketInfo
}
    51  
// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
// to have goroutines make concurrent calls into the endpoint.
//
// Lock order:
//
//	endpoint.mu
//	  endpoint.rcvMu
//
// lastErrorMu is never held together with mu or rcvMu in this file.
//
// +stateify savable
type endpoint struct {
	tcpip.DefaultSocketOptionsHandler

	// The following fields are initialized at creation time and are
	// immutable.
	stack       *stack.Stack `state:"manual"`
	waiterQueue *waiter.Queue
	// cooked endpoints strip link-layer headers from received packets and
	// send via WritePacketToRemote; non-cooked (raw) endpoints keep link
	// headers on received packets and send via WriteRawPacket.
	cooked      bool
	ops         tcpip.SocketOptions
	stats       tcpip.TransportEndpointStats

	// The following fields are used to manage the receive queue.
	rcvMu sync.Mutex `state:"nosave"`
	// rcvList holds received packets in arrival order.
	// +checklocks:rcvMu
	rcvList packetList
	// rcvBufSize is the sum of data.Size() over all packets in rcvList.
	// +checklocks:rcvMu
	rcvBufSize int
	// rcvClosed is set by Close; afterwards HandlePacket drops incoming
	// packets and Read on an empty queue reports ErrClosedForReceive.
	// +checklocks:rcvMu
	rcvClosed bool
	// rcvDisabled makes HandlePacket drop incoming packets.
	// NOTE(review): not assigned anywhere in this file; confirm who sets it.
	// +checklocks:rcvMu
	rcvDisabled bool

	mu sync.RWMutex `state:"nosave"`
	// closed is set once by Close and makes Write fail with
	// ErrClosedForSend.
	// +checklocks:mu
	closed bool
	// boundNetProto and boundNIC are the (NIC, protocol) pair the endpoint is
	// currently registered with on the stack. A NIC of 0 is what NewEndpoint
	// registers with (presumably meaning "all NICs" — confirm in stack).
	// +checklocks:mu
	boundNetProto tcpip.NetworkProtocolNumber
	// +checklocks:mu
	boundNIC tcpip.NICID

	lastErrorMu sync.Mutex `state:"nosave"`
	// lastError is the most recent error recorded by UpdateLastError; it is
	// cleared when returned by LastError.
	// +checklocks:lastErrorMu
	lastError tcpip.Error
}
    95  
    96  // NewEndpoint returns a new packet endpoint.
    97  func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) tcpip.Endpoint {
    98  	ep := &endpoint{
    99  		stack:         s,
   100  		cooked:        cooked,
   101  		boundNetProto: netProto,
   102  		waiterQueue:   waiterQueue,
   103  	}
   104  	ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
   105  	ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
   106  
   107  	// Override with stack defaults.
   108  	var ss tcpip.SendBufferSizeOption
   109  	if err := s.Option(&ss); err == nil {
   110  		ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
   111  	}
   112  
   113  	var rs tcpip.ReceiveBufferSizeOption
   114  	if err := s.Option(&rs); err == nil {
   115  		ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
   116  	}
   117  
   118  	s.RegisterPacketEndpoint(0, netProto, ep)
   119  
   120  	return ep
   121  }
   122  
   123  // Abort implements stack.TransportEndpoint.Abort.
   124  func (ep *endpoint) Abort() {
   125  	ep.Close()
   126  }
   127  
   128  // Close implements tcpip.Endpoint.Close.
   129  func (ep *endpoint) Close() {
   130  	ep.mu.Lock()
   131  	defer ep.mu.Unlock()
   132  
   133  	if ep.closed {
   134  		return
   135  	}
   136  
   137  	ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
   138  
   139  	ep.rcvMu.Lock()
   140  	defer ep.rcvMu.Unlock()
   141  
   142  	// Clear the receive list.
   143  	ep.rcvClosed = true
   144  	ep.rcvBufSize = 0
   145  	for !ep.rcvList.Empty() {
   146  		p := ep.rcvList.Front()
   147  		ep.rcvList.Remove(p)
   148  		p.data.DecRef()
   149  	}
   150  
   151  	ep.closed = true
   152  	ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
   153  }
   154  
   155  // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
   156  func (*endpoint) ModerateRecvBuf(int) {}
   157  
   158  // Read implements tcpip.Endpoint.Read.
   159  func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
   160  	ep.rcvMu.Lock()
   161  
   162  	// If there's no data to read, return that read would block or that the
   163  	// endpoint is closed.
   164  	if ep.rcvList.Empty() {
   165  		var err tcpip.Error = &tcpip.ErrWouldBlock{}
   166  		if ep.rcvClosed {
   167  			ep.stats.ReadErrors.ReadClosed.Increment()
   168  			err = &tcpip.ErrClosedForReceive{}
   169  		}
   170  		ep.rcvMu.Unlock()
   171  		return tcpip.ReadResult{}, err
   172  	}
   173  
   174  	packet := ep.rcvList.Front()
   175  	if !opts.Peek {
   176  		ep.rcvList.Remove(packet)
   177  		defer packet.data.DecRef()
   178  		ep.rcvBufSize -= packet.data.Size()
   179  	}
   180  
   181  	ep.rcvMu.Unlock()
   182  
   183  	res := tcpip.ReadResult{
   184  		Total: packet.data.Size(),
   185  		ControlMessages: tcpip.ReceivableControlMessages{
   186  			HasTimestamp: true,
   187  			Timestamp:    packet.receivedAt,
   188  		},
   189  	}
   190  	if opts.NeedRemoteAddr {
   191  		res.RemoteAddr = packet.senderAddr
   192  	}
   193  	if opts.NeedLinkPacketInfo {
   194  		res.LinkPacketInfo = packet.packetInfo
   195  	}
   196  
   197  	n, err := packet.data.Data().ReadTo(dst, opts.Peek)
   198  	if n == 0 && err != nil {
   199  		return res, &tcpip.ErrBadBuffer{}
   200  	}
   201  	res.Count = n
   202  	return res, nil
   203  }
   204  
   205  func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
   206  	if !ep.stack.PacketEndpointWriteSupported() {
   207  		return 0, &tcpip.ErrNotSupported{}
   208  	}
   209  
   210  	ep.mu.Lock()
   211  	closed := ep.closed
   212  	nicID := ep.boundNIC
   213  	proto := ep.boundNetProto
   214  	ep.mu.Unlock()
   215  	if closed {
   216  		return 0, &tcpip.ErrClosedForSend{}
   217  	}
   218  
   219  	var remote tcpip.LinkAddress
   220  	if to := opts.To; to != nil {
   221  		remote = to.LinkAddr
   222  
   223  		if n := to.NIC; n != 0 {
   224  			nicID = n
   225  		}
   226  
   227  		if p := to.Port; p != 0 {
   228  			proto = tcpip.NetworkProtocolNumber(p)
   229  		}
   230  	}
   231  
   232  	if nicID == 0 {
   233  		return 0, &tcpip.ErrInvalidOptionValue{}
   234  	}
   235  
   236  	// Prevents giant buffer allocations.
   237  	if p.Len() > header.DatagramMaximumSize {
   238  		return 0, &tcpip.ErrMessageTooLong{}
   239  	}
   240  
   241  	var payload buffer.Buffer
   242  	if _, err := payload.WriteFromReader(p, int64(p.Len())); err != nil {
   243  		return 0, &tcpip.ErrBadBuffer{}
   244  	}
   245  	payloadSz := payload.Size()
   246  
   247  	if err := func() tcpip.Error {
   248  		if ep.cooked {
   249  			return ep.stack.WritePacketToRemote(nicID, remote, proto, payload)
   250  		}
   251  		return ep.stack.WriteRawPacket(nicID, proto, payload)
   252  	}(); err != nil {
   253  		return 0, err
   254  	}
   255  	return payloadSz, nil
   256  }
   257  
   258  // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
   259  // disconnected, and this function always returns tpcip.ErrNotSupported.
   260  func (*endpoint) Disconnect() tcpip.Error {
   261  	return &tcpip.ErrNotSupported{}
   262  }
   263  
   264  // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be
   265  // connected, and this function always returns *tcpip.ErrNotSupported.
   266  func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error {
   267  	return &tcpip.ErrNotSupported{}
   268  }
   269  
   270  // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used
   271  // with Shutdown, and this function always returns *tcpip.ErrNotSupported.
   272  func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error {
   273  	return &tcpip.ErrNotSupported{}
   274  }
   275  
   276  // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with
   277  // Listen, and this function always returns *tcpip.ErrNotSupported.
   278  func (*endpoint) Listen(int) tcpip.Error {
   279  	return &tcpip.ErrNotSupported{}
   280  }
   281  
   282  // Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with
   283  // Accept, and this function always returns *tcpip.ErrNotSupported.
   284  func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
   285  	return nil, nil, &tcpip.ErrNotSupported{}
   286  }
   287  
   288  // Bind implements tcpip.Endpoint.Bind.
   289  func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
   290  	// "By default, all packets of the specified protocol type are passed
   291  	// to a packet socket.  To get packets only from a specific interface
   292  	// use bind(2) specifying an address in a struct sockaddr_ll to bind
   293  	// the packet socket  to  an interface.  Fields used for binding are
   294  	// sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex."
   295  	// - packet(7).
   296  
   297  	ep.mu.Lock()
   298  	defer ep.mu.Unlock()
   299  
   300  	netProto := tcpip.NetworkProtocolNumber(addr.Port)
   301  	if netProto == 0 {
   302  		// Do not allow unbinding the network protocol.
   303  		netProto = ep.boundNetProto
   304  	}
   305  
   306  	if ep.boundNIC == addr.NIC && ep.boundNetProto == netProto {
   307  		// Already bound to the requested NIC and network protocol.
   308  		return nil
   309  	}
   310  
   311  	// TODO(https://gvisor.dev/issue/6618): Unregister after registering the new
   312  	// binding.
   313  	ep.stack.UnregisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep)
   314  	ep.boundNIC = 0
   315  	ep.boundNetProto = 0
   316  
   317  	// Bind endpoint to receive packets from specific interface.
   318  	if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil {
   319  		return err
   320  	}
   321  
   322  	ep.boundNIC = addr.NIC
   323  	ep.boundNetProto = netProto
   324  	return nil
   325  }
   326  
   327  // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
   328  func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
   329  	ep.mu.RLock()
   330  	defer ep.mu.RUnlock()
   331  
   332  	return tcpip.FullAddress{
   333  		NIC:  ep.boundNIC,
   334  		Port: uint16(ep.boundNetProto),
   335  	}, nil
   336  }
   337  
   338  // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
   339  func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
   340  	// Even a connected socket doesn't return a remote address.
   341  	return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
   342  }
   343  
   344  // Readiness implements tcpip.Endpoint.Readiness.
   345  func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
   346  	// The endpoint is always writable.
   347  	result := waiter.WritableEvents & mask
   348  
   349  	// Determine whether the endpoint is readable.
   350  	if (mask & waiter.ReadableEvents) != 0 {
   351  		ep.rcvMu.Lock()
   352  		if !ep.rcvList.Empty() || ep.rcvClosed {
   353  			result |= waiter.ReadableEvents
   354  		}
   355  		ep.rcvMu.Unlock()
   356  	}
   357  
   358  	return result
   359  }
   360  
   361  // SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be
   362  // used with SetSockOpt, and this function always returns
   363  // *tcpip.ErrNotSupported.
   364  func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
   365  	switch opt.(type) {
   366  	case *tcpip.SocketDetachFilterOption:
   367  		return nil
   368  
   369  	default:
   370  		return &tcpip.ErrUnknownProtocolOption{}
   371  	}
   372  }
   373  
   374  // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
   375  func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error {
   376  	return &tcpip.ErrUnknownProtocolOption{}
   377  }
   378  
   379  func (ep *endpoint) LastError() tcpip.Error {
   380  	ep.lastErrorMu.Lock()
   381  	defer ep.lastErrorMu.Unlock()
   382  
   383  	err := ep.lastError
   384  	ep.lastError = nil
   385  	return err
   386  }
   387  
   388  // UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
   389  func (ep *endpoint) UpdateLastError(err tcpip.Error) {
   390  	ep.lastErrorMu.Lock()
   391  	ep.lastError = err
   392  	ep.lastErrorMu.Unlock()
   393  }
   394  
   395  // GetSockOpt implements tcpip.Endpoint.GetSockOpt.
   396  func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error {
   397  	return &tcpip.ErrNotSupported{}
   398  }
   399  
   400  // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
   401  func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
   402  	switch opt {
   403  	case tcpip.ReceiveQueueSizeOption:
   404  		v := 0
   405  		ep.rcvMu.Lock()
   406  		if !ep.rcvList.Empty() {
   407  			p := ep.rcvList.Front()
   408  			v = p.data.Size()
   409  		}
   410  		ep.rcvMu.Unlock()
   411  		return v, nil
   412  
   413  	default:
   414  		return -1, &tcpip.ErrUnknownProtocolOption{}
   415  	}
   416  }
   417  
   418  // HandlePacket implements stack.PacketEndpoint.HandlePacket.
   419  func (ep *endpoint) HandlePacket(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
   420  	ep.rcvMu.Lock()
   421  
   422  	// Drop the packet if our buffer is currently full.
   423  	if ep.rcvClosed {
   424  		ep.rcvMu.Unlock()
   425  		ep.stack.Stats().DroppedPackets.Increment()
   426  		ep.stats.ReceiveErrors.ClosedReceiver.Increment()
   427  		return
   428  	}
   429  
   430  	rcvBufSize := ep.ops.GetReceiveBufferSize()
   431  	if ep.rcvDisabled || ep.rcvBufSize >= int(rcvBufSize) {
   432  		ep.rcvMu.Unlock()
   433  		ep.stack.Stats().DroppedPackets.Increment()
   434  		ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
   435  		return
   436  	}
   437  
   438  	wasEmpty := ep.rcvBufSize == 0
   439  
   440  	rcvdPkt := packet{
   441  		packetInfo: tcpip.LinkPacketInfo{
   442  			Protocol: netProto,
   443  			PktType:  pkt.PktType,
   444  		},
   445  		senderAddr: tcpip.FullAddress{
   446  			NIC: nicID,
   447  		},
   448  		receivedAt: ep.stack.Clock().Now(),
   449  	}
   450  
   451  	if len(pkt.LinkHeader().Slice()) != 0 {
   452  		hdr := header.Ethernet(pkt.LinkHeader().Slice())
   453  		rcvdPkt.senderAddr.LinkAddr = hdr.SourceAddress()
   454  	}
   455  
   456  	// Raw packet endpoints include link-headers in received packets.
   457  	pktBuf := pkt.ToBuffer()
   458  	if ep.cooked {
   459  		// Cooked packet endpoints don't include the link-headers in received
   460  		// packets.
   461  		pktBuf.TrimFront(int64(len(pkt.LinkHeader().Slice()) + len(pkt.VirtioNetHeader().Slice())))
   462  	}
   463  	rcvdPkt.data = stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: pktBuf})
   464  
   465  	ep.rcvList.PushBack(&rcvdPkt)
   466  	ep.rcvBufSize += rcvdPkt.data.Size()
   467  
   468  	ep.rcvMu.Unlock()
   469  	ep.stats.PacketsReceived.Increment()
   470  	// Notify waiters that there's data to be read.
   471  	if wasEmpty {
   472  		ep.waiterQueue.Notify(waiter.ReadableEvents)
   473  	}
   474  }
   475  
   476  // State implements socket.Socket.State.
   477  func (*endpoint) State() uint32 {
   478  	return 0
   479  }
   480  
   481  // Info returns a copy of the endpoint info.
   482  func (ep *endpoint) Info() tcpip.EndpointInfo {
   483  	ep.mu.RLock()
   484  	defer ep.mu.RUnlock()
   485  	return &stack.TransportEndpointInfo{NetProto: ep.boundNetProto}
   486  }
   487  
   488  // Stats returns a pointer to the endpoint stats.
   489  func (ep *endpoint) Stats() tcpip.EndpointStats {
   490  	return &ep.stats
   491  }
   492  
   493  // SetOwner implements tcpip.Endpoint.SetOwner.
   494  func (*endpoint) SetOwner(tcpip.PacketOwner) {}
   495  
   496  // SocketOptions implements tcpip.Endpoint.SocketOptions.
   497  func (ep *endpoint) SocketOptions() *tcpip.SocketOptions {
   498  	return &ep.ops
   499  }