gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/internal/network/endpoint.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package network provides facilities to support tcpip.Endpoints that operate
    16  // at the network layer or above.
    17  package network
    18  
    19  import (
    20  	"fmt"
    21  
    22  	"gvisor.dev/gvisor/pkg/atomicbitops"
    23  	"gvisor.dev/gvisor/pkg/buffer"
    24  	"gvisor.dev/gvisor/pkg/sync"
    25  	"gvisor.dev/gvisor/pkg/tcpip"
    26  	"gvisor.dev/gvisor/pkg/tcpip/header"
    27  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    28  	"gvisor.dev/gvisor/pkg/tcpip/transport"
    29  	"gvisor.dev/gvisor/pkg/waiter"
    30  )
    31  
// Endpoint is a datagram-based endpoint. It only supports sending datagrams to
// a peer.
//
// Init populates the set-once fields and the default option values below.
//
// +stateify savable
type Endpoint struct {
	// The following fields must only be set once then never changed.
	stack       *stack.Stack `state:"manual"`
	ops         *tcpip.SocketOptions
	netProto    tcpip.NetworkProtocolNumber
	transProto  tcpip.TransportProtocolNumber
	waiterQueue *waiter.Queue

	// mu protects the fields annotated with +checklocks:mu below.
	mu sync.RWMutex `state:"nosave"`
	// wasBound is set by BindAndThen. Disconnect consults it to decide whether
	// the endpoint returns to the bound state or to the initial state.
	//
	// +checklocks:mu
	wasBound bool
	// owner is the owner of transmitted packets.
	//
	// +checklocks:mu
	owner tcpip.PacketOwner
	// writeShutdown is set by Shutdown; once set, AcquireContextForWrite fails
	// with ErrClosedForSend.
	//
	// +checklocks:mu
	writeShutdown bool
	// effectiveNetProto starts out as netProto and is replaced by the protocol
	// reported by checkV4Mapped when the endpoint is bound or connected.
	//
	// +checklocks:mu
	effectiveNetProto tcpip.NetworkProtocolNumber
	// connectedRoute is the route installed by ConnectAndThen. It is released
	// and reset to nil by Disconnect and Close.
	//
	// +checklocks:mu
	connectedRoute *stack.Route `state:"manual"`
	// multicastMemberships holds the groups that Close leaves. Init also uses
	// a non-nil map as the marker that the endpoint was already initialized.
	//
	// +checklocks:mu
	multicastMemberships map[multicastMembership]struct{}
	// ipv4TTL is the TTL for unicast IPv4 packets. Init sets it to
	// tcpip.UseDefaultIPv4TTL; calculateTTL uses the route's default TTL when
	// the value is 0.
	//
	// +checklocks:mu
	ipv4TTL uint8
	// ipv6HopLimit is the hop limit for unicast IPv6 packets. Init sets it to
	// tcpip.UseDefaultIPv6HopLimit; calculateTTL uses the route's default TTL
	// when the value is -1.
	//
	// +checklocks:mu
	ipv6HopLimit int16
	// multicastTTL is returned by calculateTTL for multicast destinations.
	// Init sets it to 1.
	//
	// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
	// +checklocks:mu
	multicastTTL uint8
	// multicastAddr is the local address connectRouteRLocked falls back to for
	// multicast destinations when neither a NIC nor a local address is known.
	//
	// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
	// +checklocks:mu
	multicastAddr tcpip.Address
	// multicastNICID is the NIC connectRouteRLocked falls back to for
	// multicast destinations when no NIC was specified.
	//
	// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
	// +checklocks:mu
	multicastNICID tcpip.NICID
	// ipv4TOS is the TOS value AcquireContextForWrite uses for IPv4 routes.
	//
	// +checklocks:mu
	ipv4TOS uint8
	// ipv6TClass is the traffic class AcquireContextForWrite uses for IPv6
	// routes.
	//
	// +checklocks:mu
	ipv6TClass uint8

	// Lock ordering: mu > infoMu.
	infoMu sync.RWMutex `state:"nosave"`
	// info has a dedicated mutex so that we can avoid lock ordering violations
	// when reading the endpoint's info. If we used mu, we would need to
	// guarantee that any lock taken while mu is held is not held when calling
	// Info(), which is not true as of writing: we hold mu while registering
	// transport endpoints (taking the transport demuxer lock), but we also hold
	// the demuxer lock when delivering packets/errors to endpoints.
	//
	// Writes must be performed through setInfo.
	//
	// +checklocks:infoMu
	info stack.TransportEndpointInfo

	// state holds a transport.DatagramBasedEndpointState.
	//
	// state must be accessed with atomics so that we can avoid lock ordering
	// violations when reading the state. If we used mu, we would need to
	// guarantee that any lock taken while mu is held is not held when calling
	// State(), which is not true as of writing: we hold mu while registering
	// transport endpoints (taking the transport demuxer lock), but we also hold
	// the demuxer lock when delivering packets/errors to endpoints.
	//
	// Writes must be performed through setEndpointState.
	state atomicbitops.Uint32

	// Callers should not attempt to obtain sendBufferSizeInUseMu while holding
	// another lock on Endpoint.
	sendBufferSizeInUseMu sync.RWMutex `state:"nosave"`
	// sendBufferSizeInUse keeps track of the bytes in use by in-flight packets.
	// It grows when a packet buffer is created by a WriteContext and shrinks
	// again from that buffer's OnRelease callback.
	//
	// +checklocks:sendBufferSizeInUseMu
	sendBufferSizeInUse int64 `state:"nosave"`
}
   111  
// multicastMembership identifies a multicast group joined on a specific NIC.
// It is the key type of Endpoint.multicastMemberships; Close passes both
// fields to the stack's LeaveGroup.
//
// +stateify savable
type multicastMembership struct {
	// nicID is the interface the group membership applies to.
	nicID tcpip.NICID
	// multicastAddr is the address of the multicast group.
	multicastAddr tcpip.Address
}
   117  
   118  // Init initializes the endpoint.
   119  func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions, waiterQueue *waiter.Queue) {
   120  	e.mu.Lock()
   121  	defer e.mu.Unlock()
   122  	if e.multicastMemberships != nil {
   123  		panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships))
   124  	}
   125  
   126  	switch netProto {
   127  	case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
   128  	default:
   129  		panic(fmt.Sprintf("invalid protocol number = %d", netProto))
   130  	}
   131  
   132  	e.stack = s
   133  	e.ops = ops
   134  	e.netProto = netProto
   135  	e.transProto = transProto
   136  	e.waiterQueue = waiterQueue
   137  	e.infoMu.Lock()
   138  	e.info = stack.TransportEndpointInfo{
   139  		NetProto:   netProto,
   140  		TransProto: transProto,
   141  	}
   142  	e.infoMu.Unlock()
   143  	e.effectiveNetProto = netProto
   144  	e.ipv4TTL = tcpip.UseDefaultIPv4TTL
   145  	e.ipv6HopLimit = tcpip.UseDefaultIPv6HopLimit
   146  
   147  	// Linux defaults to TTL=1.
   148  	e.multicastTTL = 1
   149  	e.multicastMemberships = make(map[multicastMembership]struct{})
   150  	e.setEndpointState(transport.DatagramEndpointStateInitial)
   151  }
   152  
   153  // NetProto returns the network protocol the endpoint was initialized with.
   154  func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
   155  	return e.netProto
   156  }
   157  
   158  // setEndpointState sets the state of the endpoint.
   159  //
   160  // e.mu must be held to synchronize changes to state with the rest of the
   161  // endpoint.
   162  //
   163  // +checklocks:e.mu
   164  func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
   165  	e.state.Store(uint32(state))
   166  }
   167  
   168  // State returns the state of the endpoint.
   169  func (e *Endpoint) State() transport.DatagramEndpointState {
   170  	return transport.DatagramEndpointState(e.state.Load())
   171  }
   172  
   173  // Close cleans the endpoint's resources and leaves the endpoint in a closed
   174  // state.
   175  func (e *Endpoint) Close() {
   176  	e.mu.Lock()
   177  	defer e.mu.Unlock()
   178  
   179  	if e.State() == transport.DatagramEndpointStateClosed {
   180  		return
   181  	}
   182  
   183  	for mem := range e.multicastMemberships {
   184  		e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
   185  	}
   186  	e.multicastMemberships = nil
   187  
   188  	if e.connectedRoute != nil {
   189  		e.connectedRoute.Release()
   190  		e.connectedRoute = nil
   191  	}
   192  
   193  	e.setEndpointState(transport.DatagramEndpointStateClosed)
   194  }
   195  
   196  // SetOwner sets the owner of transmitted packets.
   197  func (e *Endpoint) SetOwner(owner tcpip.PacketOwner) {
   198  	e.mu.Lock()
   199  	defer e.mu.Unlock()
   200  	e.owner = owner
   201  }
   202  
   203  // +checklocksread:e.mu
   204  func (e *Endpoint) calculateTTL(route *stack.Route) uint8 {
   205  	remoteAddress := route.RemoteAddress()
   206  	if header.IsV4MulticastAddress(remoteAddress) || header.IsV6MulticastAddress(remoteAddress) {
   207  		return e.multicastTTL
   208  	}
   209  
   210  	switch netProto := route.NetProto(); netProto {
   211  	case header.IPv4ProtocolNumber:
   212  		if e.ipv4TTL == 0 {
   213  			return route.DefaultTTL()
   214  		}
   215  		return e.ipv4TTL
   216  	case header.IPv6ProtocolNumber:
   217  		if e.ipv6HopLimit == -1 {
   218  			return route.DefaultTTL()
   219  		}
   220  		return uint8(e.ipv6HopLimit)
   221  	default:
   222  		panic(fmt.Sprintf("invalid protocol number = %d", netProto))
   223  	}
   224  }
   225  
// WriteContext holds the context for a write.
//
// A WriteContext is produced by Endpoint.AcquireContextForWrite and holds a
// reference to a route until Release is called.
type WriteContext struct {
	// e is the endpoint the context was acquired from.
	e *Endpoint
	// route is the route packets are written to; Release releases it.
	route *stack.Route
	// ttl is the TTL/hop limit used by WritePacket when it builds the network
	// header.
	ttl uint8
	// tos is the TOS/traffic class used by WritePacket when it builds the
	// network header.
	tos uint8
}
   233  
   234  func (c *WriteContext) MTU() uint32 {
   235  	return c.route.MTU()
   236  }
   237  
   238  // Release releases held resources.
   239  func (c *WriteContext) Release() {
   240  	c.route.Release()
   241  	*c = WriteContext{}
   242  }
   243  
// WritePacketInfo is the properties of a packet that may be written.
//
// WriteContext.PacketInfo fills every field from the context's route.
type WritePacketInfo struct {
	// NetProto is the network protocol of the route.
	NetProto tcpip.NetworkProtocolNumber
	// LocalAddress and RemoteAddress are the route's local and remote
	// addresses.
	LocalAddress, RemoteAddress tcpip.Address
	// MaxHeaderLength is the route's maximum header length.
	MaxHeaderLength uint16
	// RequiresTXTransportChecksum is the route's answer to whether a transport
	// checksum is required on transmit.
	RequiresTXTransportChecksum bool
}
   251  
   252  // PacketInfo returns the properties of a packet that will be written.
   253  func (c *WriteContext) PacketInfo() WritePacketInfo {
   254  	return WritePacketInfo{
   255  		NetProto:                    c.route.NetProto(),
   256  		LocalAddress:                c.route.LocalAddress(),
   257  		RemoteAddress:               c.route.RemoteAddress(),
   258  		MaxHeaderLength:             c.route.MaxHeaderLength(),
   259  		RequiresTXTransportChecksum: c.route.RequiresTXTransportChecksum(),
   260  	}
   261  }
   262  
   263  // TryNewPacketBuffer returns a new packet buffer iff the endpoint's send buffer
   264  // is not full.
   265  //
   266  // If this method returns nil, the caller should wait for the endpoint to become
   267  // writable.
   268  func (c *WriteContext) TryNewPacketBuffer(reserveHdrBytes int, data buffer.Buffer) *stack.PacketBuffer {
   269  	e := c.e
   270  
   271  	e.sendBufferSizeInUseMu.Lock()
   272  	defer e.sendBufferSizeInUseMu.Unlock()
   273  
   274  	if !e.hasSendSpaceRLocked() {
   275  		return nil
   276  	}
   277  	return c.newPacketBufferLocked(reserveHdrBytes, data)
   278  }
   279  
   280  // TryNewPacketBufferFromPayloader returns a new packet buffer iff the endpoint's send buffer
   281  // is not full. Otherwise, data from `payloader` isn't read.
   282  //
   283  // If this method returns nil, the caller should wait for the endpoint to become
   284  // writable.
   285  func (c *WriteContext) TryNewPacketBufferFromPayloader(reserveHdrBytes int, payloader tcpip.Payloader) *stack.PacketBuffer {
   286  	e := c.e
   287  
   288  	e.sendBufferSizeInUseMu.Lock()
   289  	defer e.sendBufferSizeInUseMu.Unlock()
   290  
   291  	if !e.hasSendSpaceRLocked() {
   292  		return nil
   293  	}
   294  	var data buffer.Buffer
   295  	if _, err := data.WriteFromReader(payloader, int64(payloader.Len())); err != nil {
   296  		data.Release()
   297  		return nil
   298  	}
   299  	return c.newPacketBufferLocked(reserveHdrBytes, data)
   300  }
   301  
   302  // +checklocks:c.e.sendBufferSizeInUseMu
   303  func (c *WriteContext) newPacketBufferLocked(reserveHdrBytes int, data buffer.Buffer) *stack.PacketBuffer {
   304  	e := c.e
   305  	// Note that we allow oversubscription - if there is any space at all in the
   306  	// send buffer, we accept the full packet which may be larger than the space
   307  	// available. This is because if the endpoint reports that it is writable,
   308  	// a write operation should succeed.
   309  	//
   310  	// This matches Linux behaviour:
   311  	// https://github.com/torvalds/linux/blob/38d741cb70b/include/net/sock.h#L2519
   312  	// https://github.com/torvalds/linux/blob/38d741cb70b/net/core/sock.c#L2588
   313  	pktSize := int64(reserveHdrBytes) + int64(data.Size())
   314  	e.sendBufferSizeInUse += pktSize
   315  
   316  	return stack.NewPacketBuffer(stack.PacketBufferOptions{
   317  		ReserveHeaderBytes: reserveHdrBytes,
   318  		Payload:            data,
   319  		OnRelease: func() {
   320  			e.sendBufferSizeInUseMu.Lock()
   321  			if got := e.sendBufferSizeInUse; got < pktSize {
   322  				e.sendBufferSizeInUseMu.Unlock()
   323  				panic(fmt.Sprintf("e.sendBufferSizeInUse=(%d) < pktSize(=%d)", got, pktSize))
   324  			}
   325  			e.sendBufferSizeInUse -= pktSize
   326  			signal := e.hasSendSpaceRLocked()
   327  			e.sendBufferSizeInUseMu.Unlock()
   328  
   329  			// Let waiters know if we now have space in the send buffer.
   330  			if signal {
   331  				e.waiterQueue.Notify(waiter.WritableEvents)
   332  			}
   333  		},
   334  	})
   335  }
   336  
   337  // WritePacket attempts to write the packet.
   338  func (c *WriteContext) WritePacket(pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
   339  	c.e.mu.RLock()
   340  	pkt.Owner = c.e.owner
   341  	c.e.mu.RUnlock()
   342  
   343  	if headerIncluded {
   344  		return c.route.WriteHeaderIncludedPacket(pkt)
   345  	}
   346  
   347  	err := c.route.WritePacket(stack.NetworkHeaderParams{
   348  		Protocol: c.e.transProto,
   349  		TTL:      c.ttl,
   350  		TOS:      c.tos,
   351  	}, pkt)
   352  
   353  	if _, ok := err.(*tcpip.ErrNoBufferSpace); ok {
   354  		var recvErr bool
   355  		switch netProto := c.route.NetProto(); netProto {
   356  		case header.IPv4ProtocolNumber:
   357  			recvErr = c.e.ops.GetIPv4RecvError()
   358  		case header.IPv6ProtocolNumber:
   359  			recvErr = c.e.ops.GetIPv6RecvError()
   360  		default:
   361  			panic(fmt.Sprintf("unhandled network protocol number = %d", netProto))
   362  		}
   363  
   364  		// Linux only returns ENOBUFS to the caller if IP{,V6}_RECVERR is set.
   365  		//
   366  		// https://github.com/torvalds/linux/blob/3e71713c9e75c/net/ipv4/udp.c#L969
   367  		// https://github.com/torvalds/linux/blob/3e71713c9e75c/net/ipv6/udp.c#L1260
   368  		if !recvErr {
   369  			err = nil
   370  		}
   371  	}
   372  
   373  	return err
   374  }
   375  
   376  // MaybeSignalWritable signals waiters with writable events if the send buffer
   377  // has space.
   378  func (e *Endpoint) MaybeSignalWritable() {
   379  	e.sendBufferSizeInUseMu.RLock()
   380  	signal := e.hasSendSpaceRLocked()
   381  	e.sendBufferSizeInUseMu.RUnlock()
   382  
   383  	if signal {
   384  		e.waiterQueue.Notify(waiter.WritableEvents)
   385  	}
   386  }
   387  
   388  // HasSendSpace returns whether or not the send buffer has space.
   389  func (e *Endpoint) HasSendSpace() bool {
   390  	e.sendBufferSizeInUseMu.RLock()
   391  	defer e.sendBufferSizeInUseMu.RUnlock()
   392  	return e.hasSendSpaceRLocked()
   393  }
   394  
   395  // +checklocksread:e.sendBufferSizeInUseMu
   396  func (e *Endpoint) hasSendSpaceRLocked() bool {
   397  	return e.ops.GetSendBufferSize() > e.sendBufferSizeInUse
   398  }
   399  
// AcquireContextForWrite acquires a WriteContext.
//
// The returned context holds a reference to a route: either the connected
// route (when no destination is given and no IPv6 packet info applies) or a
// route freshly looked up for the requested destination. The caller releases
// it through WriteContext.Release.
//
// Errors returned directly by this method:
//   - ErrInvalidOptionValue if opts.More is set.
//   - ErrInvalidEndpointState if the endpoint is closed.
//   - ErrClosedForSend if the endpoint was shut down for writing.
//   - ErrDestinationRequired if no destination is given and the endpoint is
//     not connected.
//   - ErrHostUnreachable if the requested, bound and packet-info interfaces
//     disagree.
//   - ErrBadLocalAddress if the local address does not pass the stack's
//     local address check for the selected interface.
//   - ErrBroadcastDisabled if the route is an outbound broadcast and the
//     broadcast option is off.
//
// Errors from checkV4Mapped and connectRouteRLocked are passed through.
func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext, tcpip.Error) {
	e.mu.RLock()
	defer e.mu.RUnlock()

	// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
	if opts.More {
		return WriteContext{}, &tcpip.ErrInvalidOptionValue{}
	}

	if e.State() == transport.DatagramEndpointStateClosed {
		return WriteContext{}, &tcpip.ErrInvalidEndpointState{}
	}

	if e.writeShutdown {
		return WriteContext{}, &tcpip.ErrClosedForSend{}
	}

	// IPv6 packet info is only honoured when the endpoint is effectively
	// operating over IPv6.
	ipv6PktInfoValid := e.effectiveNetProto == header.IPv6ProtocolNumber && opts.ControlMessages.HasIPv6PacketInfo

	route := e.connectedRoute
	to := opts.To
	info := e.Info()
	switch {
	case to == nil:
		// If the user doesn't specify a destination, they should have
		// connected to another address.
		if e.State() != transport.DatagramEndpointStateConnected {
			return WriteContext{}, &tcpip.ErrDestinationRequired{}
		}

		if !ipv6PktInfoValid {
			// Reuse the connected route; take a reference for the context.
			route.Acquire()
			break
		}

		// We are connected and the caller did not specify the destination but
		// we have an IPv6 packet info structure which may change our local
		// interface/address used to send the packet so we need to construct
		// a new route instead of using the connected route.
		//
		// Construct a destination matching the remote the endpoint is connected
		// to.
		to = &tcpip.FullAddress{
			// RegisterNICID is set when the endpoint is connected. It is usually
			// only set for link-local addresses or multicast addresses if the
			// multicast interface was specified (see e.multicastNICID,
			// e.connectRouteRLocked and e.ConnectAndThen).
			NIC:  info.RegisterNICID,
			Addr: info.ID.RemoteAddress,
		}
		fallthrough
	default:
		// Reject destination address if it goes through a different
		// NIC than the endpoint was bound to.
		nicID := to.NIC
		if nicID == 0 {
			nicID = tcpip.NICID(e.ops.GetBindToDevice())
		}

		var localAddr tcpip.Address
		if ipv6PktInfoValid {
			// Uphold strong-host semantics since (as of writing) the stack follows
			// the strong host model.

			pktInfoNICID := opts.ControlMessages.IPv6PacketInfo.NIC
			pktInfoAddr := opts.ControlMessages.IPv6PacketInfo.Addr

			if pktInfoNICID != 0 {
				// If we are bound to an interface or specified the destination
				// interface (usually when using link-local addresses), make sure the
				// interface matches the specified local interface.
				if nicID != 0 && nicID != pktInfoNICID {
					return WriteContext{}, &tcpip.ErrHostUnreachable{}
				}

				// If a local address is not specified, then we need to make sure the
				// bound address belongs to the specified local interface.
				if pktInfoAddr.BitLen() == 0 {
					// If the bound interface is different from the specified local
					// interface, the bound address obviously does not belong to the
					// specified local interface.
					//
					// The bound interface is usually only set for link-local addresses.
					if info.BindNICID != 0 && info.BindNICID != pktInfoNICID {
						return WriteContext{}, &tcpip.ErrHostUnreachable{}
					}
					if info.ID.LocalAddress.BitLen() != 0 && e.stack.CheckLocalAddress(pktInfoNICID, header.IPv6ProtocolNumber, info.ID.LocalAddress) == 0 {
						return WriteContext{}, &tcpip.ErrBadLocalAddress{}
					}
				}

				nicID = pktInfoNICID
			}

			if pktInfoAddr.BitLen() != 0 {
				// The local address must belong to the stack. If an outgoing interface
				// is specified as a result of binding the endpoint to a device, or
				// specifying the outgoing interface in the destination address/pkt info
				// structure, the address must belong to that interface.
				if e.stack.CheckLocalAddress(nicID, header.IPv6ProtocolNumber, pktInfoAddr) == 0 {
					return WriteContext{}, &tcpip.ErrBadLocalAddress{}
				}

				localAddr = pktInfoAddr
			}
		} else {
			// Without packet info, the bound interface (if any) must agree with
			// the requested one and takes precedence; failing both, fall back to
			// the interface the endpoint is registered on.
			if info.BindNICID != 0 {
				if nicID != 0 && nicID != info.BindNICID {
					return WriteContext{}, &tcpip.ErrHostUnreachable{}
				}

				nicID = info.BindNICID
			}
			if nicID == 0 {
				nicID = info.RegisterNICID
			}
		}

		dst, netProto, err := e.checkV4Mapped(*to)
		if err != nil {
			return WriteContext{}, err
		}

		// This replaces the connected route picked above with a newly looked up
		// one.
		route, _, err = e.connectRouteRLocked(nicID, localAddr, dst, netProto)
		if err != nil {
			return WriteContext{}, err
		}
	}

	// From here on the route reference is ours, so it must be released on
	// every error path.
	if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
		route.Release()
		return WriteContext{}, &tcpip.ErrBroadcastDisabled{}
	}

	// A TTL/hop limit supplied through control messages takes precedence over
	// the endpoint's own settings.
	var tos uint8
	var ttl uint8
	switch netProto := route.NetProto(); netProto {
	case header.IPv4ProtocolNumber:
		tos = e.ipv4TOS
		if opts.ControlMessages.HasTTL {
			ttl = opts.ControlMessages.TTL
		} else {
			ttl = e.calculateTTL(route)
		}
	case header.IPv6ProtocolNumber:
		tos = e.ipv6TClass
		if opts.ControlMessages.HasHopLimit {
			ttl = opts.ControlMessages.HopLimit
		} else {
			ttl = e.calculateTTL(route)
		}
	default:
		panic(fmt.Sprintf("invalid protocol number = %d", netProto))
	}

	return WriteContext{
		e:     e,
		route: route,
		ttl:   ttl,
		tos:   tos,
	}, nil
}
   563  
   564  // Disconnect disconnects the endpoint from its peer.
   565  func (e *Endpoint) Disconnect() {
   566  	e.mu.Lock()
   567  	defer e.mu.Unlock()
   568  
   569  	if e.State() != transport.DatagramEndpointStateConnected {
   570  		return
   571  	}
   572  
   573  	info := e.Info()
   574  	// Exclude ephemerally bound endpoints.
   575  	if e.wasBound {
   576  		info.ID = stack.TransportEndpointID{
   577  			LocalAddress: info.BindAddr,
   578  		}
   579  		e.setEndpointState(transport.DatagramEndpointStateBound)
   580  	} else {
   581  		info.ID = stack.TransportEndpointID{}
   582  		e.setEndpointState(transport.DatagramEndpointStateInitial)
   583  	}
   584  	e.setInfo(info)
   585  
   586  	e.connectedRoute.Release()
   587  	e.connectedRoute = nil
   588  }
   589  
   590  // connectRouteRLocked establishes a route to the specified interface or the
   591  // configured multicast interface if no interface is specified and the
   592  // specified address is a multicast address.
   593  //
   594  // +checklocksread:e.mu
   595  func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, localAddr tcpip.Address, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
   596  	if localAddr.BitLen() == 0 {
   597  		localAddr = e.Info().ID.LocalAddress
   598  		if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
   599  			// A packet can only originate from a unicast address (i.e., an interface).
   600  			localAddr = tcpip.Address{}
   601  		}
   602  
   603  		if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
   604  			if nicID == 0 {
   605  				nicID = e.multicastNICID
   606  			}
   607  			if localAddr == (tcpip.Address{}) && nicID == 0 {
   608  				localAddr = e.multicastAddr
   609  			}
   610  		}
   611  	}
   612  
   613  	// Find a route to the desired destination.
   614  	r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
   615  	if err != nil {
   616  		return nil, 0, err
   617  	}
   618  	return r, nicID, nil
   619  }
   620  
   621  // Connect connects the endpoint to the address.
   622  func (e *Endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
   623  	return e.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
   624  		return nil
   625  	})
   626  }
   627  
   628  // ConnectAndThen connects the endpoint to the address and then calls the
   629  // provided function.
   630  //
   631  // If the function returns an error, the endpoint's state does not change. The
   632  // function will be called with the network protocol used to connect to the peer
   633  // and the source and destination addresses that will be used to send traffic to
   634  // the peer.
   635  func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error) tcpip.Error {
   636  	addr.Port = 0
   637  
   638  	e.mu.Lock()
   639  	defer e.mu.Unlock()
   640  
   641  	info := e.Info()
   642  	nicID := addr.NIC
   643  	switch e.State() {
   644  	case transport.DatagramEndpointStateInitial:
   645  	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
   646  		if info.BindNICID == 0 {
   647  			break
   648  		}
   649  
   650  		if nicID != 0 && nicID != info.BindNICID {
   651  			return &tcpip.ErrInvalidEndpointState{}
   652  		}
   653  
   654  		nicID = info.BindNICID
   655  	default:
   656  		return &tcpip.ErrInvalidEndpointState{}
   657  	}
   658  
   659  	addr, netProto, err := e.checkV4Mapped(addr)
   660  	if err != nil {
   661  		return err
   662  	}
   663  
   664  	r, nicID, err := e.connectRouteRLocked(nicID, tcpip.Address{}, addr, netProto)
   665  	if err != nil {
   666  		return err
   667  	}
   668  
   669  	id := stack.TransportEndpointID{
   670  		LocalAddress:  info.ID.LocalAddress,
   671  		RemoteAddress: r.RemoteAddress(),
   672  	}
   673  	if e.State() == transport.DatagramEndpointStateInitial {
   674  		id.LocalAddress = r.LocalAddress()
   675  	}
   676  
   677  	if err := f(r.NetProto(), info.ID, id); err != nil {
   678  		r.Release()
   679  		return err
   680  	}
   681  
   682  	if e.connectedRoute != nil {
   683  		// If the endpoint was previously connected then release any previous route.
   684  		e.connectedRoute.Release()
   685  	}
   686  	e.connectedRoute = r
   687  	info.ID = id
   688  	info.RegisterNICID = nicID
   689  	e.setInfo(info)
   690  	e.effectiveNetProto = netProto
   691  	e.setEndpointState(transport.DatagramEndpointStateConnected)
   692  	return nil
   693  }
   694  
   695  // Shutdown shutsdown the endpoint.
   696  func (e *Endpoint) Shutdown() tcpip.Error {
   697  	e.mu.Lock()
   698  	defer e.mu.Unlock()
   699  
   700  	switch state := e.State(); state {
   701  	case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
   702  		return &tcpip.ErrNotConnected{}
   703  	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
   704  		e.writeShutdown = true
   705  		return nil
   706  	default:
   707  		panic(fmt.Sprintf("unhandled state = %s", state))
   708  	}
   709  }
   710  
   711  // checkV4MappedRLocked determines the effective network protocol and converts
   712  // addr to its canonical form.
   713  func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
   714  	info := e.Info()
   715  	unwrapped, netProto, err := info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
   716  	if err != nil {
   717  		return tcpip.FullAddress{}, 0, err
   718  	}
   719  	return unwrapped, netProto, nil
   720  }
   721  
   722  func (e *Endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
   723  	return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
   724  }
   725  
   726  // Bind binds the endpoint to the address.
   727  func (e *Endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
   728  	return e.BindAndThen(addr, func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error {
   729  		return nil
   730  	})
   731  }
   732  
   733  // BindAndThen binds the endpoint to the address and then calls the provided
   734  // function.
   735  //
   736  // If the function returns an error, the endpoint's state does not change. The
   737  // function will be called with the bound network protocol and address.
   738  func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error) tcpip.Error {
   739  	addr.Port = 0
   740  
   741  	e.mu.Lock()
   742  	defer e.mu.Unlock()
   743  
   744  	// Don't allow binding once endpoint is not in the initial state
   745  	// anymore.
   746  	if e.State() != transport.DatagramEndpointStateInitial {
   747  		return &tcpip.ErrInvalidEndpointState{}
   748  	}
   749  
   750  	addr, netProto, err := e.checkV4Mapped(addr)
   751  	if err != nil {
   752  		return err
   753  	}
   754  
   755  	nicID := addr.NIC
   756  	if addr.Addr.BitLen() != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
   757  		nicID = e.stack.CheckLocalAddress(nicID, netProto, addr.Addr)
   758  		if nicID == 0 {
   759  			return &tcpip.ErrBadLocalAddress{}
   760  		}
   761  	}
   762  
   763  	if err := f(netProto, addr.Addr); err != nil {
   764  		return err
   765  	}
   766  
   767  	e.wasBound = true
   768  
   769  	info := e.Info()
   770  	info.ID = stack.TransportEndpointID{
   771  		LocalAddress: addr.Addr,
   772  	}
   773  	info.BindNICID = addr.NIC
   774  	info.RegisterNICID = nicID
   775  	info.BindAddr = addr.Addr
   776  	e.setInfo(info)
   777  	e.effectiveNetProto = netProto
   778  	e.setEndpointState(transport.DatagramEndpointStateBound)
   779  	return nil
   780  }
   781  
   782  // WasBound returns true iff the endpoint was ever bound.
   783  func (e *Endpoint) WasBound() bool {
   784  	e.mu.RLock()
   785  	defer e.mu.RUnlock()
   786  	return e.wasBound
   787  }
   788  
   789  // GetLocalAddress returns the address that the endpoint is bound to.
   790  func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
   791  	e.mu.RLock()
   792  	defer e.mu.RUnlock()
   793  
   794  	info := e.Info()
   795  	addr := info.BindAddr
   796  	if e.State() == transport.DatagramEndpointStateConnected {
   797  		addr = e.connectedRoute.LocalAddress()
   798  	}
   799  
   800  	return tcpip.FullAddress{
   801  		NIC:  info.RegisterNICID,
   802  		Addr: addr,
   803  	}
   804  }
   805  
   806  // GetRemoteAddress returns the address that the endpoint is connected to.
   807  func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
   808  	e.mu.RLock()
   809  	defer e.mu.RUnlock()
   810  
   811  	if e.State() != transport.DatagramEndpointStateConnected {
   812  		return tcpip.FullAddress{}, false
   813  	}
   814  
   815  	return tcpip.FullAddress{
   816  		Addr: e.connectedRoute.RemoteAddress(),
   817  		NIC:  e.Info().RegisterNICID,
   818  	}, true
   819  }
   820  
   821  // SetSockOptInt sets the socket option.
   822  func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
   823  	switch opt {
   824  	case tcpip.MTUDiscoverOption:
   825  		// Return not supported if the value is not disabling path
   826  		// MTU discovery.
   827  		if tcpip.PMTUDStrategy(v) != tcpip.PMTUDiscoveryDont {
   828  			return &tcpip.ErrNotSupported{}
   829  		}
   830  
   831  	case tcpip.MulticastTTLOption:
   832  		e.mu.Lock()
   833  		e.multicastTTL = uint8(v)
   834  		e.mu.Unlock()
   835  
   836  	case tcpip.IPv4TTLOption:
   837  		e.mu.Lock()
   838  		e.ipv4TTL = uint8(v)
   839  		e.mu.Unlock()
   840  
   841  	case tcpip.IPv6HopLimitOption:
   842  		e.mu.Lock()
   843  		e.ipv6HopLimit = int16(v)
   844  		e.mu.Unlock()
   845  
   846  	case tcpip.IPv4TOSOption:
   847  		e.mu.Lock()
   848  		e.ipv4TOS = uint8(v)
   849  		e.mu.Unlock()
   850  
   851  	case tcpip.IPv6TrafficClassOption:
   852  		e.mu.Lock()
   853  		e.ipv6TClass = uint8(v)
   854  		e.mu.Unlock()
   855  	}
   856  
   857  	return nil
   858  }
   859  
   860  // GetSockOptInt returns the socket option.
   861  func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
   862  	switch opt {
   863  	case tcpip.MTUDiscoverOption:
   864  		// The only supported setting is path MTU discovery disabled.
   865  		return int(tcpip.PMTUDiscoveryDont), nil
   866  
   867  	case tcpip.MulticastTTLOption:
   868  		e.mu.Lock()
   869  		v := int(e.multicastTTL)
   870  		e.mu.Unlock()
   871  		return v, nil
   872  
   873  	case tcpip.IPv4TTLOption:
   874  		e.mu.Lock()
   875  		v := int(e.ipv4TTL)
   876  		e.mu.Unlock()
   877  		return v, nil
   878  
   879  	case tcpip.IPv6HopLimitOption:
   880  		e.mu.Lock()
   881  		v := int(e.ipv6HopLimit)
   882  		e.mu.Unlock()
   883  		return v, nil
   884  
   885  	case tcpip.IPv4TOSOption:
   886  		e.mu.RLock()
   887  		v := int(e.ipv4TOS)
   888  		e.mu.RUnlock()
   889  		return v, nil
   890  
   891  	case tcpip.IPv6TrafficClassOption:
   892  		e.mu.RLock()
   893  		v := int(e.ipv6TClass)
   894  		e.mu.RUnlock()
   895  		return v, nil
   896  
   897  	default:
   898  		return -1, &tcpip.ErrUnknownProtocolOption{}
   899  	}
   900  }
   901  
   902  // SetSockOpt sets the socket option.
   903  func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
   904  	switch v := opt.(type) {
   905  	case *tcpip.MulticastInterfaceOption:
   906  		e.mu.Lock()
   907  		defer e.mu.Unlock()
   908  
   909  		fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
   910  		fa, netProto, err := e.checkV4Mapped(fa)
   911  		if err != nil {
   912  			return err
   913  		}
   914  		nic := v.NIC
   915  		addr := fa.Addr
   916  
   917  		if nic == 0 && addr == (tcpip.Address{}) {
   918  			e.multicastAddr = tcpip.Address{}
   919  			e.multicastNICID = 0
   920  			break
   921  		}
   922  
   923  		if nic != 0 {
   924  			if !e.stack.CheckNIC(nic) {
   925  				return &tcpip.ErrBadLocalAddress{}
   926  			}
   927  		} else {
   928  			nic = e.stack.CheckLocalAddress(0, netProto, addr)
   929  			if nic == 0 {
   930  				return &tcpip.ErrBadLocalAddress{}
   931  			}
   932  		}
   933  
   934  		if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic {
   935  			return &tcpip.ErrInvalidEndpointState{}
   936  		}
   937  
   938  		e.multicastNICID = nic
   939  		e.multicastAddr = addr
   940  
   941  	case *tcpip.AddMembershipOption:
   942  		if !(header.IsV4MulticastAddress(v.MulticastAddr) && e.netProto == header.IPv4ProtocolNumber) && !(header.IsV6MulticastAddress(v.MulticastAddr) && e.netProto == header.IPv6ProtocolNumber) {
   943  			return &tcpip.ErrInvalidOptionValue{}
   944  		}
   945  
   946  		nicID := v.NIC
   947  
   948  		if v.InterfaceAddr.Unspecified() {
   949  			if nicID == 0 {
   950  				if r, err := e.stack.FindRoute(0, tcpip.Address{}, v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
   951  					nicID = r.NICID()
   952  					r.Release()
   953  				}
   954  			}
   955  		} else {
   956  			nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
   957  		}
   958  		if nicID == 0 {
   959  			return &tcpip.ErrUnknownDevice{}
   960  		}
   961  
   962  		memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
   963  
   964  		e.mu.Lock()
   965  		defer e.mu.Unlock()
   966  
   967  		if _, ok := e.multicastMemberships[memToInsert]; ok {
   968  			return &tcpip.ErrPortInUse{}
   969  		}
   970  
   971  		if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
   972  			return err
   973  		}
   974  
   975  		e.multicastMemberships[memToInsert] = struct{}{}
   976  
   977  	case *tcpip.RemoveMembershipOption:
   978  		if !(header.IsV4MulticastAddress(v.MulticastAddr) && e.netProto == header.IPv4ProtocolNumber) && !(header.IsV6MulticastAddress(v.MulticastAddr) && e.netProto == header.IPv6ProtocolNumber) {
   979  			return &tcpip.ErrInvalidOptionValue{}
   980  		}
   981  
   982  		nicID := v.NIC
   983  		if v.InterfaceAddr.Unspecified() {
   984  			if nicID == 0 {
   985  				if r, err := e.stack.FindRoute(0, tcpip.Address{}, v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
   986  					nicID = r.NICID()
   987  					r.Release()
   988  				}
   989  			}
   990  		} else {
   991  			nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
   992  		}
   993  		if nicID == 0 {
   994  			return &tcpip.ErrUnknownDevice{}
   995  		}
   996  
   997  		memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
   998  
   999  		e.mu.Lock()
  1000  		defer e.mu.Unlock()
  1001  
  1002  		if _, ok := e.multicastMemberships[memToRemove]; !ok {
  1003  			return &tcpip.ErrBadLocalAddress{}
  1004  		}
  1005  
  1006  		if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
  1007  			return err
  1008  		}
  1009  
  1010  		delete(e.multicastMemberships, memToRemove)
  1011  
  1012  	case *tcpip.SocketDetachFilterOption:
  1013  		return nil
  1014  	}
  1015  	return nil
  1016  }
  1017  
  1018  // GetSockOpt returns the socket option.
  1019  func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
  1020  	switch o := opt.(type) {
  1021  	case *tcpip.MulticastInterfaceOption:
  1022  		e.mu.Lock()
  1023  		*o = tcpip.MulticastInterfaceOption{
  1024  			NIC:           e.multicastNICID,
  1025  			InterfaceAddr: e.multicastAddr,
  1026  		}
  1027  		e.mu.Unlock()
  1028  
  1029  	default:
  1030  		return &tcpip.ErrUnknownProtocolOption{}
  1031  	}
  1032  	return nil
  1033  }
  1034  
  1035  // Info returns a copy of the endpoint info.
  1036  func (e *Endpoint) Info() stack.TransportEndpointInfo {
  1037  	e.infoMu.RLock()
  1038  	defer e.infoMu.RUnlock()
  1039  	return e.info
  1040  }
  1041  
  1042  // setInfo sets the endpoint's info.
  1043  //
  1044  // e.mu must be held to synchronize changes to info with the rest of the
  1045  // endpoint.
  1046  //
  1047  // +checklocks:e.mu
  1048  func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) {
  1049  	e.infoMu.Lock()
  1050  	defer e.infoMu.Unlock()
  1051  	e.info = info
  1052  }