github.com/MerlinKodo/gvisor@v0.0.0-20231110090155-957f62ecf90e/pkg/tcpip/transport/icmp/endpoint.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package icmp
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"time"
    21  
    22  	"github.com/MerlinKodo/gvisor/pkg/buffer"
    23  	"github.com/MerlinKodo/gvisor/pkg/sync"
    24  	"github.com/MerlinKodo/gvisor/pkg/tcpip"
    25  	"github.com/MerlinKodo/gvisor/pkg/tcpip/checksum"
    26  	"github.com/MerlinKodo/gvisor/pkg/tcpip/header"
    27  	"github.com/MerlinKodo/gvisor/pkg/tcpip/ports"
    28  	"github.com/MerlinKodo/gvisor/pkg/tcpip/stack"
    29  	"github.com/MerlinKodo/gvisor/pkg/tcpip/transport"
    30  	"github.com/MerlinKodo/gvisor/pkg/tcpip/transport/internal/network"
    31  	"github.com/MerlinKodo/gvisor/pkg/waiter"
    32  )
    33  
// icmpPacket is one received ICMP echo reply waiting on an endpoint's receive
// queue (rcvList).
//
// Fields:
//   - senderAddress: the sender's address, with NIC set to the arrival NIC.
//   - packetInfo: the arrival NIC and the IP header's destination address;
//     Read reports it as a packet-info control message when enabled.
//   - data: the ICMP header plus payload (link and network headers are
//     trimmed off in HandlePacket). The queue holds a reference that is
//     released with DecRef when the packet is dequeued or the endpoint closes.
//   - receivedAt: stack clock time at which the packet was queued; it is
//     saved as an int64 per its state tag.
//   - tosOrTClass / ttlOrHopLimit: copied from the network header so Read can
//     report them as control messages.
//
// +stateify savable
type icmpPacket struct {
	icmpPacketEntry
	senderAddress tcpip.FullAddress
	packetInfo    tcpip.IPPacketInfo
	data          stack.PacketBufferPtr
	receivedAt    time.Time `state:".(int64)"`

	// tosOrTClass stores either the Type of Service for IPv4 or the Traffic Class
	// for IPv6.
	tosOrTClass uint8
	// ttlOrHopLimit stores either the TTL for IPv4 or the HopLimit for IPv6
	ttlOrHopLimit uint8
}
    48  
// endpoint represents an ICMP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
// synchronized.
//
// Locking: code paths that need both mutexes (Close, Connect, Shutdown, and
// bindLocked, which runs with mu held) always take mu before rcvMu. The
// receive-side fields are guarded by rcvMu. ident is the ICMP echo identifier;
// it is stamped into outgoing requests and exposed as the local "port" by
// GetLocalAddress and Info.
//
// NOTE(review): frozen is listed under mu, but HandlePacket reads it while
// holding only rcvMu.
//
// +stateify savable
type endpoint struct {
	tcpip.DefaultSocketOptionsHandler

	// The following fields are initialized at creation time and are
	// immutable.
	stack       *stack.Stack `state:"manual"`
	transProto  tcpip.TransportProtocolNumber
	waiterQueue *waiter.Queue
	uniqueID    uint64
	net         network.Endpoint
	stats       tcpip.TransportEndpointStats
	ops         tcpip.SocketOptions

	// The following fields are used to manage the receive queue, and are
	// protected by rcvMu.
	rcvMu      sync.Mutex `state:"nosave"`
	rcvReady   bool
	rcvList    icmpPacketList
	rcvBufSize int
	rcvClosed  bool

	// The following fields are protected by the mu mutex.
	mu sync.RWMutex `state:"nosave"`
	// frozen, when true, makes HandlePacket drop incoming packets instead of
	// queueing them; it is set by freeze and cleared by thaw (used around
	// save/restore).
	frozen bool
	ident  uint16
}
    83  
    84  func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) {
    85  	ep := &endpoint{
    86  		stack:       s,
    87  		transProto:  transProto,
    88  		waiterQueue: waiterQueue,
    89  		uniqueID:    s.UniqueID(),
    90  	}
    91  	ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
    92  	ep.ops.SetSendBufferSize(32*1024, false /* notify */)
    93  	ep.ops.SetReceiveBufferSize(32*1024, false /* notify */)
    94  	ep.net.Init(s, netProto, transProto, &ep.ops, waiterQueue)
    95  
    96  	// Override with stack defaults.
    97  	var ss tcpip.SendBufferSizeOption
    98  	if err := s.Option(&ss); err == nil {
    99  		ep.ops.SetSendBufferSize(int64(ss.Default), false /* notify */)
   100  	}
   101  	var rs tcpip.ReceiveBufferSizeOption
   102  	if err := s.Option(&rs); err == nil {
   103  		ep.ops.SetReceiveBufferSize(int64(rs.Default), false /* notify */)
   104  	}
   105  	return ep, nil
   106  }
   107  
   108  // WakeupWriters implements tcpip.SocketOptionsHandler.
   109  func (e *endpoint) WakeupWriters() {
   110  	e.net.MaybeSignalWritable()
   111  }
   112  
   113  // UniqueID implements stack.TransportEndpoint.UniqueID.
   114  func (e *endpoint) UniqueID() uint64 {
   115  	return e.uniqueID
   116  }
   117  
   118  // Abort implements stack.TransportEndpoint.Abort.
   119  func (e *endpoint) Abort() {
   120  	e.Close()
   121  }
   122  
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
//
// If the endpoint was bound or connected, it is unregistered from the stack.
// The network endpoint is then shut down and closed, the receive side is
// marked closed, and every queued packet is released. Waiters get HUp, Err,
// readable and writable events once all locks have been dropped. Calling Close
// on an already-closed endpoint does nothing.
func (e *endpoint) Close() {
	// notify reports whether this call actually closed the endpoint. The work
	// runs in a closure so its deferred unlocks fire before waiters are
	// notified below.
	notify := func() bool {
		e.mu.Lock()
		defer e.mu.Unlock()

		switch state := e.net.State(); state {
		case transport.DatagramEndpointStateInitial:
			// No stack registration to undo; registration happens in
			// bindLocked/Connect via registerWithStack.
		case transport.DatagramEndpointStateClosed:
			return false
		case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
			// Registration used the ICMP ident as the local port, so rebuild
			// that ID to unregister.
			info := e.net.Info()
			info.ID.LocalPort = e.ident
			e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice()))
		default:
			panic(fmt.Sprintf("unhandled state = %s", state))
		}

		// The error returned by Shutdown is ignored here.
		e.net.Shutdown()
		e.net.Close()

		// Lock order: mu (held above) before rcvMu.
		e.rcvMu.Lock()
		defer e.rcvMu.Unlock()
		e.rcvClosed = true
		e.rcvBufSize = 0
		// Drop the queue's reference on every pending packet.
		for !e.rcvList.Empty() {
			p := e.rcvList.Front()
			e.rcvList.Remove(p)
			p.data.DecRef()
		}

		return true
	}()

	if notify {
		e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
	}
}
   162  
   163  // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
   164  func (*endpoint) ModerateRecvBuf(int) {}
   165  
   166  // SetOwner implements tcpip.Endpoint.SetOwner.
   167  func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
   168  	e.net.SetOwner(owner)
   169  }
   170  
   171  // Read implements tcpip.Endpoint.Read.
   172  func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
   173  	e.rcvMu.Lock()
   174  
   175  	if e.rcvList.Empty() {
   176  		var err tcpip.Error = &tcpip.ErrWouldBlock{}
   177  		if e.rcvClosed {
   178  			e.stats.ReadErrors.ReadClosed.Increment()
   179  			err = &tcpip.ErrClosedForReceive{}
   180  		}
   181  		e.rcvMu.Unlock()
   182  		return tcpip.ReadResult{}, err
   183  	}
   184  
   185  	p := e.rcvList.Front()
   186  	if !opts.Peek {
   187  		e.rcvList.Remove(p)
   188  		defer p.data.DecRef()
   189  		e.rcvBufSize -= p.data.Data().Size()
   190  	}
   191  
   192  	e.rcvMu.Unlock()
   193  
   194  	// Control Messages
   195  	// TODO(https://gvisor.dev/issue/7012): Share control message code with other
   196  	// network endpoints.
   197  	cm := tcpip.ReceivableControlMessages{
   198  		HasTimestamp: true,
   199  		Timestamp:    p.receivedAt,
   200  	}
   201  	switch netProto := e.net.NetProto(); netProto {
   202  	case header.IPv4ProtocolNumber:
   203  		if e.ops.GetReceiveTOS() {
   204  			cm.HasTOS = true
   205  			cm.TOS = p.tosOrTClass
   206  		}
   207  		if e.ops.GetReceivePacketInfo() {
   208  			cm.HasIPPacketInfo = true
   209  			cm.PacketInfo = p.packetInfo
   210  		}
   211  		if e.ops.GetReceiveTTL() {
   212  			cm.HasTTL = true
   213  			cm.TTL = p.ttlOrHopLimit
   214  		}
   215  	case header.IPv6ProtocolNumber:
   216  		if e.ops.GetReceiveTClass() {
   217  			cm.HasTClass = true
   218  			// Although TClass is an 8-bit value it's read in the CMsg as a uint32.
   219  			cm.TClass = uint32(p.tosOrTClass)
   220  		}
   221  		if e.ops.GetIPv6ReceivePacketInfo() {
   222  			cm.HasIPv6PacketInfo = true
   223  			cm.IPv6PacketInfo = tcpip.IPv6PacketInfo{
   224  				NIC:  p.packetInfo.NIC,
   225  				Addr: p.packetInfo.DestinationAddr,
   226  			}
   227  		}
   228  		if e.ops.GetReceiveHopLimit() {
   229  			cm.HasHopLimit = true
   230  			cm.HopLimit = p.ttlOrHopLimit
   231  		}
   232  	default:
   233  		panic(fmt.Sprintf("unrecognized network protocol = %d", netProto))
   234  	}
   235  
   236  	res := tcpip.ReadResult{
   237  		Total:           p.data.Data().Size(),
   238  		ControlMessages: cm,
   239  	}
   240  	if opts.NeedRemoteAddr {
   241  		res.RemoteAddr = p.senderAddress
   242  	}
   243  
   244  	n, err := p.data.Data().ReadTo(dst, opts.Peek)
   245  	if n == 0 && err != nil {
   246  		return res, &tcpip.ErrBadBuffer{}
   247  	}
   248  	res.Count = n
   249  	return res, nil
   250  }
   251  
// prepareForWriteInner prepares the endpoint for sending data. In particular,
// it binds the endpoint if it is still in the initial state. To do so, it must
// first reacquire the mutex in exclusive mode.
//
// The caller must hold e.mu for reading. When a bind is needed, the read lock
// is released, e.mu is taken exclusively, and the lock is downgraded back to a
// read lock on return. Other goroutines may therefore change the endpoint
// state in between, which is why the state is re-checked.
//
// Returns true for retry if preparation should be retried; callers loop until
// retry is false or an error is returned. A bound but unconnected endpoint
// requires an explicit destination (to != nil).
// +checklocksread:e.mu
func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) {
	switch e.net.State() {
	case transport.DatagramEndpointStateInitial:
	case transport.DatagramEndpointStateConnected:
		return false, nil
	case transport.DatagramEndpointStateBound:
		if to == nil {
			return false, &tcpip.ErrDestinationRequired{}
		}
		return false, nil
	default:
		return false, &tcpip.ErrInvalidEndpointState{}
	}

	e.mu.RUnlock()
	e.mu.Lock()
	defer e.mu.DowngradeLock()

	// The state changed when we released the shared lock and re-acquired
	// it in exclusive mode. Try again.
	if e.net.State() != transport.DatagramEndpointStateInitial {
		return true, nil
	}

	// The state is still 'initial', so try to bind the endpoint.
	if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
		return false, err
	}

	// Now bound; retry so the state switch above validates the destination.
	return true, nil
}
   289  
   290  // Write writes data to the endpoint's peer. This method does not block
   291  // if the data cannot be written.
   292  func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
   293  	n, err := e.write(p, opts)
   294  	switch err.(type) {
   295  	case nil:
   296  		e.stats.PacketsSent.Increment()
   297  	case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue:
   298  		e.stats.WriteErrors.InvalidArgs.Increment()
   299  	case *tcpip.ErrClosedForSend:
   300  		e.stats.WriteErrors.WriteClosed.Increment()
   301  	case *tcpip.ErrInvalidEndpointState:
   302  		e.stats.WriteErrors.InvalidEndpointState.Increment()
   303  	case *tcpip.ErrHostUnreachable, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable:
   304  		// Errors indicating any problem with IP routing of the packet.
   305  		e.stats.SendErrors.NoRoute.Increment()
   306  	default:
   307  		// For all other errors when writing to the network layer.
   308  		e.stats.SendErrors.SendToNetworkFailed.Increment()
   309  	}
   310  	return n, err
   311  }
   312  
   313  func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) {
   314  	e.mu.RLock()
   315  	defer e.mu.RUnlock()
   316  
   317  	// Prepare for write.
   318  	for {
   319  		retry, err := e.prepareForWriteInner(opts.To)
   320  		if err != nil {
   321  			return network.WriteContext{}, 0, err
   322  		}
   323  
   324  		if !retry {
   325  			break
   326  		}
   327  	}
   328  
   329  	ctx, err := e.net.AcquireContextForWrite(opts)
   330  	return ctx, e.ident, err
   331  }
   332  
   333  func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) {
   334  	ctx, ident, err := e.prepareForWrite(opts)
   335  	if err != nil {
   336  		return 0, err
   337  	}
   338  	defer ctx.Release()
   339  
   340  	// Prevents giant buffer allocations.
   341  	if p.Len() > header.DatagramMaximumSize {
   342  		return 0, &tcpip.ErrMessageTooLong{}
   343  	}
   344  
   345  	v := buffer.NewView(p.Len())
   346  	defer v.Release()
   347  	if _, err := io.CopyN(v, p, int64(p.Len())); err != nil {
   348  		return 0, &tcpip.ErrBadBuffer{}
   349  	}
   350  	n := v.Size()
   351  
   352  	switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto {
   353  	case header.IPv4ProtocolNumber:
   354  		if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil {
   355  			return 0, err
   356  		}
   357  
   358  	case header.IPv6ProtocolNumber:
   359  		if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil {
   360  			return 0, err
   361  		}
   362  	default:
   363  		panic(fmt.Sprintf("unhandled network protocol = %d", netProto))
   364  	}
   365  
   366  	return int64(n), nil
   367  }
   368  
   369  var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
   370  
   371  // HasNIC implements tcpip.SocketOptionsHandler.
   372  func (e *endpoint) HasNIC(id int32) bool {
   373  	return e.stack.HasNIC(tcpip.NICID(id))
   374  }
   375  
   376  // SetSockOpt implements tcpip.Endpoint.
   377  func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
   378  	return e.net.SetSockOpt(opt)
   379  }
   380  
   381  // SetSockOptInt implements tcpip.Endpoint.
   382  func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
   383  	return e.net.SetSockOptInt(opt, v)
   384  }
   385  
   386  // GetSockOptInt implements tcpip.Endpoint.
   387  func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
   388  	switch opt {
   389  	case tcpip.ReceiveQueueSizeOption:
   390  		v := 0
   391  		e.rcvMu.Lock()
   392  		if !e.rcvList.Empty() {
   393  			p := e.rcvList.Front()
   394  			v = p.data.Data().Size()
   395  		}
   396  		e.rcvMu.Unlock()
   397  		return v, nil
   398  
   399  	default:
   400  		return e.net.GetSockOptInt(opt)
   401  	}
   402  }
   403  
   404  // GetSockOpt implements tcpip.Endpoint.
   405  func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
   406  	return e.net.GetSockOpt(opt)
   407  }
   408  
// send4 builds an ICMPv4 echo request from data and writes it through ctx.
//
// data must begin with a complete ICMPv4 header supplied by the user (type,
// code and sequence number). The ident field is overwritten with ident and the
// checksum is computed here. data is modified: its header bytes are trimmed off
// and a clone of the remainder becomes the packet payload.
//
// It returns ErrInvalidEndpointState if data is shorter than an ICMPv4 header,
// or if it is not an echo request with code 0. It returns ErrWouldBlock if no
// packet buffer could be allocated. Otherwise it returns the error from
// ctx.WritePacket. Stack-wide ICMPv4 sent stats are updated.
func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data *buffer.View, maxHeaderLength uint16) tcpip.Error {
	if data.Size() < header.ICMPv4MinimumSize {
		return &tcpip.ErrInvalidEndpointState{}
	}

	pkt := ctx.TryNewPacketBuffer(header.ICMPv4MinimumSize+int(maxHeaderLength), buffer.Buffer{})
	if pkt.IsNil() {
		return &tcpip.ErrWouldBlock{}
	}
	defer pkt.DecRef()

	icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
	pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber
	// icmpv4 is ICMPv4MinimumSize bytes long, so this copies only the header.
	copy(icmpv4, data.AsSlice())
	// Set the ident to the user-specified port. Sequence number should
	// already be set by the user.
	icmpv4.SetIdent(ident)
	data.TrimFront(header.ICMPv4MinimumSize)

	// Linux performs these basic checks.
	if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 {
		return &tcpip.ErrInvalidEndpointState{}
	}

	// Zero the checksum field first so it does not contribute to the sum,
	// which covers the header plus the (already trimmed) payload.
	icmpv4.SetChecksum(0)
	icmpv4.SetChecksum(^checksum.Checksum(icmpv4, checksum.Checksum(data.AsSlice(), 0)))
	pkt.Data().AppendView(data.Clone())

	// Because this icmp endpoint is implemented in the transport layer, we can
	// only increment the 'stack-wide' stats but we can't increment the
	// 'per-NetworkEndpoint' stats.
	stats := s.Stats().ICMP.V4.PacketsSent

	if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
		stats.Dropped.Increment()
		return err
	}

	stats.EchoRequest.Increment()
	return nil
}
   450  
// send6 builds an ICMPv6 echo request from data and writes it through ctx.
//
// data must begin with a complete ICMPv6 echo header supplied by the user. The
// ident field is overwritten with ident. The checksum is computed by
// header.ICMPv6Checksum over the header, the payload, and the src/dst
// addresses. data is modified: its header bytes are trimmed off and a clone of
// the remainder becomes the packet payload.
//
// It returns ErrInvalidEndpointState if data is too short, or if it is not an
// echo request with code 0. It returns ErrWouldBlock if no packet buffer could
// be allocated. Otherwise it returns the error from ctx.WritePacket.
// Stack-wide ICMPv6 sent stats are updated.
func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data *buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error {
	if data.Size() < header.ICMPv6EchoMinimumSize {
		return &tcpip.ErrInvalidEndpointState{}
	}

	pkt := ctx.TryNewPacketBuffer(header.ICMPv6MinimumSize+int(maxHeaderLength), buffer.Buffer{})
	if pkt.IsNil() {
		return &tcpip.ErrWouldBlock{}
	}
	defer pkt.DecRef()

	icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
	pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber
	// icmpv6 is ICMPv6MinimumSize bytes long, so this copies only the header.
	copy(icmpv6, data.AsSlice())
	// Set the ident. Sequence number is provided by the user.
	icmpv6.SetIdent(ident)
	data.TrimFront(header.ICMPv6MinimumSize)

	if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
		return &tcpip.ErrInvalidEndpointState{}
	}

	// The payload must be attached before computing the checksum, since the
	// checksum covers it.
	pkt.Data().AppendView(data.Clone())
	pktData := pkt.Data()
	icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
		Header:      icmpv6,
		Src:         src,
		Dst:         dst,
		PayloadCsum: pktData.Checksum(),
		PayloadLen:  pktData.Size(),
	}))

	// Because this icmp endpoint is implemented in the transport layer, we can
	// only increment the 'stack-wide' stats but we can't increment the
	// 'per-NetworkEndpoint' stats.
	stats := s.Stats().ICMP.V6.PacketsSent

	if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil {
		stats.Dropped.Increment()
		return err
	}

	stats.EchoRequest.Increment()
	return nil
}
   496  
   497  // Disconnect implements tcpip.Endpoint.Disconnect.
   498  func (*endpoint) Disconnect() tcpip.Error {
   499  	return &tcpip.ErrNotSupported{}
   500  }
   501  
   502  // Connect connects the endpoint to its peer. Specifying a NIC is optional.
   503  func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
   504  	e.mu.Lock()
   505  	defer e.mu.Unlock()
   506  
   507  	err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error {
   508  		nextID.LocalPort = e.ident
   509  
   510  		nextID, err := e.registerWithStack(netProto, nextID)
   511  		if err != nil {
   512  			return err
   513  		}
   514  
   515  		e.ident = nextID.LocalPort
   516  		return nil
   517  	})
   518  	if err != nil {
   519  		return err
   520  	}
   521  
   522  	e.rcvMu.Lock()
   523  	e.rcvReady = true
   524  	e.rcvMu.Unlock()
   525  
   526  	return nil
   527  }
   528  
   529  // ConnectEndpoint is not supported.
   530  func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error {
   531  	return &tcpip.ErrInvalidEndpointState{}
   532  }
   533  
   534  // Shutdown closes the read and/or write end of the endpoint connection
   535  // to its peer.
   536  func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error {
   537  	e.mu.Lock()
   538  	defer e.mu.Unlock()
   539  
   540  	switch state := e.net.State(); state {
   541  	case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
   542  		return &tcpip.ErrNotConnected{}
   543  	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
   544  	default:
   545  		panic(fmt.Sprintf("unhandled state = %s", state))
   546  	}
   547  
   548  	if flags&tcpip.ShutdownWrite != 0 {
   549  		if err := e.net.Shutdown(); err != nil {
   550  			return err
   551  		}
   552  	}
   553  
   554  	if flags&tcpip.ShutdownRead != 0 {
   555  		e.rcvMu.Lock()
   556  		wasClosed := e.rcvClosed
   557  		e.rcvClosed = true
   558  		e.rcvMu.Unlock()
   559  
   560  		if !wasClosed {
   561  			e.waiterQueue.Notify(waiter.ReadableEvents)
   562  		}
   563  	}
   564  
   565  	return nil
   566  }
   567  
   568  // Listen is not supported by UDP, it just fails.
   569  func (*endpoint) Listen(int) tcpip.Error {
   570  	return &tcpip.ErrNotSupported{}
   571  }
   572  
   573  // Accept is not supported by UDP, it just fails.
   574  func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpip.Error) {
   575  	return nil, nil, &tcpip.ErrNotSupported{}
   576  }
   577  
   578  func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
   579  	bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
   580  	if id.LocalPort != 0 {
   581  		// The endpoint already has a local port, just attempt to
   582  		// register it.
   583  		return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
   584  	}
   585  
   586  	// We need to find a port for the endpoint.
   587  	_, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
   588  		id.LocalPort = p
   589  		err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice)
   590  		switch err.(type) {
   591  		case nil:
   592  			return true, nil
   593  		case *tcpip.ErrPortInUse:
   594  			return false, nil
   595  		default:
   596  			return false, err
   597  		}
   598  	})
   599  
   600  	return id, err
   601  }
   602  
   603  func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error {
   604  	// Don't allow binding once endpoint is not in the initial state
   605  	// anymore.
   606  	if e.net.State() != transport.DatagramEndpointStateInitial {
   607  		return &tcpip.ErrInvalidEndpointState{}
   608  	}
   609  
   610  	err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error {
   611  		id := stack.TransportEndpointID{
   612  			LocalPort:    addr.Port,
   613  			LocalAddress: addr.Addr,
   614  		}
   615  		id, err := e.registerWithStack(boundNetProto, id)
   616  		if err != nil {
   617  			return err
   618  		}
   619  
   620  		e.ident = id.LocalPort
   621  		return nil
   622  	})
   623  	if err != nil {
   624  		return err
   625  	}
   626  
   627  	e.rcvMu.Lock()
   628  	e.rcvReady = true
   629  	e.rcvMu.Unlock()
   630  
   631  	return nil
   632  }
   633  
   634  func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address) bool {
   635  	return addr == header.IPv4Broadcast ||
   636  		header.IsV4MulticastAddress(addr) ||
   637  		header.IsV6MulticastAddress(addr) ||
   638  		e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr)
   639  }
   640  
   641  // Bind binds the endpoint to a specific local address and port.
   642  // Specifying a NIC is optional.
   643  func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
   644  	if addr.Addr.BitLen() != 0 && e.isBroadcastOrMulticast(addr.NIC, addr.Addr) {
   645  		return &tcpip.ErrBadLocalAddress{}
   646  	}
   647  
   648  	e.mu.Lock()
   649  	defer e.mu.Unlock()
   650  
   651  	return e.bindLocked(addr)
   652  }
   653  
   654  // GetLocalAddress returns the address to which the endpoint is bound.
   655  func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
   656  	e.mu.RLock()
   657  	defer e.mu.RUnlock()
   658  
   659  	addr := e.net.GetLocalAddress()
   660  	addr.Port = e.ident
   661  	return addr, nil
   662  }
   663  
   664  // GetRemoteAddress returns the address to which the endpoint is connected.
   665  func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) {
   666  	e.mu.RLock()
   667  	defer e.mu.RUnlock()
   668  
   669  	if addr, connected := e.net.GetRemoteAddress(); connected {
   670  		return addr, nil
   671  	}
   672  
   673  	return tcpip.FullAddress{}, &tcpip.ErrNotConnected{}
   674  }
   675  
   676  // Readiness returns the current readiness of the endpoint. For example, if
   677  // waiter.EventIn is set, the endpoint is immediately readable.
   678  func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
   679  	var result waiter.EventMask
   680  
   681  	if e.net.HasSendSpace() {
   682  		result |= waiter.WritableEvents & mask
   683  	}
   684  
   685  	// Determine if the endpoint is readable if requested.
   686  	if (mask & waiter.ReadableEvents) != 0 {
   687  		e.rcvMu.Lock()
   688  		if !e.rcvList.Empty() || e.rcvClosed {
   689  			result |= waiter.ReadableEvents
   690  		}
   691  		e.rcvMu.Unlock()
   692  	}
   693  
   694  	return result
   695  }
   696  
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
//
// Only echo replies are queued. A packet is dropped, and counted both in the
// stack-wide DroppedPackets stat and in an endpoint receive-error stat, if:
//   - it is not an echo reply;
//   - the endpoint is not yet bound/connected, or its receive side is closed;
//   - the endpoint is frozen; or
//   - the receive buffer is already at or above its limit.
//
// Queued packets keep only the ICMP header and payload, plus the metadata Read
// needs for control messages. Waiters are notified only when the queue goes
// from empty to non-empty.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) {
	// Only accept echo replies.
	switch e.net.NetProto() {
	case header.IPv4ProtocolNumber:
		h := header.ICMPv4(pkt.TransportHeader().Slice())
		if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
			e.stack.Stats().DroppedPackets.Increment()
			e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
			return
		}
	case header.IPv6ProtocolNumber:
		h := header.ICMPv6(pkt.TransportHeader().Slice())
		if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
			e.stack.Stats().DroppedPackets.Increment()
			e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
			return
		}
	}

	e.rcvMu.Lock()

	// Drop the packet if the endpoint cannot receive yet (rcvReady is set only
	// by bindLocked/Connect) or if its receive side has been closed.
	if !e.rcvReady || e.rcvClosed {
		e.rcvMu.Unlock()
		e.stack.Stats().DroppedPackets.Increment()
		e.stats.ReceiveErrors.ClosedReceiver.Increment()
		return
	}

	// Drop the packet while frozen or when the receive buffer is full. The
	// limit is checked before this packet is added, so the queue can exceed it
	// by up to one packet.
	// NOTE(review): frozen is read here under rcvMu only; confirm that writers
	// (freeze/thaw) also hold rcvMu.
	rcvBufSize := e.ops.GetReceiveBufferSize()
	if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
		e.rcvMu.Unlock()
		e.stack.Stats().DroppedPackets.Increment()
		e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
		return
	}

	// Used below to wake readers only on the empty -> non-empty transition.
	wasEmpty := e.rcvBufSize == 0

	net := pkt.Network()
	dstAddr := net.DestinationAddress()
	// Push new packet into receive list and increment the buffer size.
	packet := &icmpPacket{
		senderAddress: tcpip.FullAddress{
			NIC:  pkt.NICID,
			Addr: id.RemoteAddress,
		},
		packetInfo: tcpip.IPPacketInfo{
			// Linux does not 'prepare' [1] in_pktinfo on socket buffers destined to
			// ping sockets (unlike UDP/RAW sockets). However the interface index [2]
			// and the Header Destination Address [3] are always filled.
			// [1] https://github.com/torvalds/linux/blob/dcb85f85fa6/net/ipv4/ip_sockglue.c#L1392
			// [2] https://github.com/torvalds/linux/blob/dcb85f85fa6/net/ipv4/ip_input.c#L510
			// [3] https://github.com/torvalds/linux/blob/dcb85f85fa6/net/ipv4/ip_sockglue.c#L60
			NIC:             pkt.NICID,
			DestinationAddr: dstAddr,
		},
	}

	// Save any useful information from the network header to the packet.
	packet.tosOrTClass, _ = net.TOS()
	switch pkt.NetworkProtocolNumber {
	case header.IPv4ProtocolNumber:
		packet.ttlOrHopLimit = header.IPv4(pkt.NetworkHeader().Slice()).TTL()
	case header.IPv6ProtocolNumber:
		packet.ttlOrHopLimit = header.IPv6(pkt.NetworkHeader().Slice()).HopLimit()
	}

	// ICMP socket's data includes ICMP header but no others. Trim all other
	// headers from the front of the packet.
	pktBuf := pkt.ToBuffer()
	pktBuf.TrimFront(int64(pkt.HeaderSize() - len(pkt.TransportHeader().Slice())))
	packet.data = stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: pktBuf})

	e.rcvList.PushBack(packet)
	e.rcvBufSize += packet.data.Data().Size()

	packet.receivedAt = e.stack.Clock().Now()

	e.rcvMu.Unlock()
	e.stats.PacketsReceived.Increment()
	// Notify any waiters that there's data to be read now.
	if wasEmpty {
		e.waiterQueue.Notify(waiter.ReadableEvents)
	}
}
   785  
   786  // HandleError implements stack.TransportEndpoint.
   787  func (*endpoint) HandleError(stack.TransportError, stack.PacketBufferPtr) {}
   788  
   789  // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
   790  // expose internal socket state.
   791  func (e *endpoint) State() uint32 {
   792  	return uint32(e.net.State())
   793  }
   794  
   795  // Info returns a copy of the endpoint info.
   796  func (e *endpoint) Info() tcpip.EndpointInfo {
   797  	e.mu.RLock()
   798  	defer e.mu.RUnlock()
   799  	ret := e.net.Info()
   800  	ret.ID.LocalPort = e.ident
   801  	return &ret
   802  }
   803  
   804  // Stats returns a pointer to the endpoint stats.
   805  func (e *endpoint) Stats() tcpip.EndpointStats {
   806  	return &e.stats
   807  }
   808  
   809  // Wait implements stack.TransportEndpoint.Wait.
   810  func (*endpoint) Wait() {}
   811  
   812  // LastError implements tcpip.Endpoint.LastError.
   813  func (*endpoint) LastError() tcpip.Error {
   814  	return nil
   815  }
   816  
   817  // SocketOptions implements tcpip.Endpoint.SocketOptions.
   818  func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
   819  	return &e.ops
   820  }
   821  
   822  // freeze prevents any more packets from being delivered to the endpoint.
   823  func (e *endpoint) freeze() {
   824  	e.mu.Lock()
   825  	e.frozen = true
   826  	e.mu.Unlock()
   827  }
   828  
   829  // thaw unfreezes a previously frozen endpoint using endpoint.freeze() allows
   830  // new packets to be delivered again.
   831  func (e *endpoint) thaw() {
   832  	e.mu.Lock()
   833  	e.frozen = false
   834  	e.mu.Unlock()
   835  }