github.com/noisysockets/netstack@v0.6.0/pkg/tcpip/transport/internal/network/endpoint.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package network provides facilities to support tcpip.Endpoints that operate
    16  // at the network layer or above.
    17  package network
    18  
    19  import (
    20  	"fmt"
    21  
    22  	"github.com/noisysockets/netstack/pkg/atomicbitops"
    23  	"github.com/noisysockets/netstack/pkg/buffer"
    24  	"github.com/noisysockets/netstack/pkg/sync"
    25  	"github.com/noisysockets/netstack/pkg/tcpip"
    26  	"github.com/noisysockets/netstack/pkg/tcpip/header"
    27  	"github.com/noisysockets/netstack/pkg/tcpip/stack"
    28  	"github.com/noisysockets/netstack/pkg/tcpip/transport"
    29  	"github.com/noisysockets/netstack/pkg/waiter"
    30  )
    31  
// Endpoint is a datagram-based endpoint. It only supports sending datagrams to
// a peer.
//
// +stateify savable
type Endpoint struct {
	// The following fields must only be set once then never changed.
	stack       *stack.Stack `state:"manual"`
	ops         *tcpip.SocketOptions
	netProto    tcpip.NetworkProtocolNumber
	transProto  tcpip.TransportProtocolNumber
	waiterQueue *waiter.Queue

	// mu protects the fields annotated with +checklocks:mu below.
	mu sync.RWMutex `state:"nosave"`
	// wasBound is set once BindAndThen succeeds. Disconnect uses it to decide
	// whether the endpoint returns to the bound or the initial state.
	//
	// +checklocks:mu
	wasBound bool
	// owner is the owner of transmitted packets.
	//
	// +checklocks:mu
	owner tcpip.PacketOwner
	// writeShutdown is set by Shutdown; once set, AcquireContextForWrite fails
	// with ErrClosedForSend.
	//
	// +checklocks:mu
	writeShutdown bool
	// effectiveNetProto is the network protocol in effect for the endpoint, as
	// determined by checkV4Mapped when binding or connecting. Init sets it to
	// netProto.
	//
	// +checklocks:mu
	effectiveNetProto tcpip.NetworkProtocolNumber
	// connectedRoute is the route to the peer while connected. It is set by
	// ConnectAndThen and released by Disconnect and Close.
	//
	// +checklocks:mu
	connectedRoute *stack.Route `state:"manual"`
	// multicastMemberships holds the multicast groups joined through this
	// endpoint; Close leaves each of them. Init allocates the map, and a
	// non-nil map is how Init detects a repeated initialization.
	//
	// +checklocks:mu
	multicastMemberships map[multicastMembership]struct{}
	// ipv4TTL is the IPv4 TTL socket option. calculateTTL treats 0 as "use the
	// route's default TTL"; Init sets it to tcpip.UseDefaultIPv4TTL.
	//
	// +checklocks:mu
	ipv4TTL uint8
	// ipv6HopLimit is the IPv6 hop limit socket option. calculateTTL treats -1
	// as "use the route's default"; Init sets it to
	// tcpip.UseDefaultIPv6HopLimit.
	//
	// +checklocks:mu
	ipv6HopLimit int16
	// multicastTTL is the TTL/hop limit used for packets sent to multicast
	// destinations (see calculateTTL). Init sets it to 1, matching Linux.
	//
	// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
	// +checklocks:mu
	multicastTTL uint8
	// multicastAddr is the local address configured with
	// tcpip.MulticastInterfaceOption. connectRouteRLocked uses it as the source
	// of multicast traffic when neither a local address nor a NIC is otherwise
	// known.
	//
	// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
	// +checklocks:mu
	multicastAddr tcpip.Address
	// multicastNICID is the interface configured with
	// tcpip.MulticastInterfaceOption. connectRouteRLocked uses it for multicast
	// destinations when no NIC is otherwise specified.
	//
	// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
	// +checklocks:mu
	multicastNICID tcpip.NICID
	// ipv4TOS is the TOS value written into outgoing IPv4 headers.
	//
	// +checklocks:mu
	ipv4TOS uint8
	// ipv6TClass is the traffic class written into outgoing IPv6 headers.
	//
	// +checklocks:mu
	ipv6TClass uint8

	// Lock ordering: mu > infoMu.
	infoMu sync.RWMutex `state:"nosave"`
	// info has a dedicated mutex so that we can avoid lock ordering violations
	// when reading the endpoint's info. If we used mu, we would need to
	// guarantee that any lock taken while mu is held is not held when calling
	// Info(), which is not true as of writing: we hold mu while registering
	// transport endpoints (taking the transport demuxer lock), but we also hold
	// the demuxer lock when delivering packets/errors to endpoints.
	//
	// Writes must be performed through setInfo.
	//
	// +checklocks:infoMu
	info stack.TransportEndpointInfo

	// state holds a transport.DatagramEndpointState.
	//
	// state must be accessed with atomics so that we can avoid lock ordering
	// violations when reading the state. If we used mu, we would need to
	// guarantee that any lock taken while mu is held is not held when calling
	// State(), which is not true as of writing: we hold mu while registering
	// transport endpoints (taking the transport demuxer lock), but we also hold
	// the demuxer lock when delivering packets/errors to endpoints.
	//
	// Writes must be performed through setEndpointState.
	state atomicbitops.Uint32

	// Callers should not attempt to obtain sendBufferSizeInUseMu while holding
	// another lock on Endpoint.
	sendBufferSizeInUseMu sync.RWMutex `state:"nosave"`
	// sendBufferSizeInUse keeps track of the bytes in use by in-flight packets.
	//
	// +checklocks:sendBufferSizeInUseMu
	sendBufferSizeInUse int64 `state:"nosave"`
}
   111  
// multicastMembership identifies one multicast group joined through an
// Endpoint: the NIC it was joined on and the group address. Close passes both
// to stack.LeaveGroup.
//
// +stateify savable
type multicastMembership struct {
	nicID         tcpip.NICID
	multicastAddr tcpip.Address
}
   117  
   118  // Init initializes the endpoint.
   119  func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ops *tcpip.SocketOptions, waiterQueue *waiter.Queue) {
   120  	e.mu.Lock()
   121  	defer e.mu.Unlock()
   122  	if e.multicastMemberships != nil {
   123  		panic(fmt.Sprintf("endpoint is already initialized; got e.multicastMemberships = %#v, want = nil", e.multicastMemberships))
   124  	}
   125  
   126  	switch netProto {
   127  	case header.IPv4ProtocolNumber, header.IPv6ProtocolNumber:
   128  	default:
   129  		panic(fmt.Sprintf("invalid protocol number = %d", netProto))
   130  	}
   131  
   132  	e.stack = s
   133  	e.ops = ops
   134  	e.netProto = netProto
   135  	e.transProto = transProto
   136  	e.waiterQueue = waiterQueue
   137  	e.infoMu.Lock()
   138  	e.info = stack.TransportEndpointInfo{
   139  		NetProto:   netProto,
   140  		TransProto: transProto,
   141  	}
   142  	e.infoMu.Unlock()
   143  	e.effectiveNetProto = netProto
   144  	e.ipv4TTL = tcpip.UseDefaultIPv4TTL
   145  	e.ipv6HopLimit = tcpip.UseDefaultIPv6HopLimit
   146  
   147  	// Linux defaults to TTL=1.
   148  	e.multicastTTL = 1
   149  	e.multicastMemberships = make(map[multicastMembership]struct{})
   150  	e.setEndpointState(transport.DatagramEndpointStateInitial)
   151  }
   152  
   153  // NetProto returns the network protocol the endpoint was initialized with.
   154  func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber {
   155  	return e.netProto
   156  }
   157  
   158  // setEndpointState sets the state of the endpoint.
   159  //
   160  // e.mu must be held to synchronize changes to state with the rest of the
   161  // endpoint.
   162  //
   163  // +checklocks:e.mu
   164  func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) {
   165  	e.state.Store(uint32(state))
   166  }
   167  
   168  // State returns the state of the endpoint.
   169  func (e *Endpoint) State() transport.DatagramEndpointState {
   170  	return transport.DatagramEndpointState(e.state.Load())
   171  }
   172  
   173  // Close cleans the endpoint's resources and leaves the endpoint in a closed
   174  // state.
   175  func (e *Endpoint) Close() {
   176  	e.mu.Lock()
   177  	defer e.mu.Unlock()
   178  
   179  	if e.State() == transport.DatagramEndpointStateClosed {
   180  		return
   181  	}
   182  
   183  	for mem := range e.multicastMemberships {
   184  		e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
   185  	}
   186  	e.multicastMemberships = nil
   187  
   188  	if e.connectedRoute != nil {
   189  		e.connectedRoute.Release()
   190  		e.connectedRoute = nil
   191  	}
   192  
   193  	e.setEndpointState(transport.DatagramEndpointStateClosed)
   194  }
   195  
   196  // SetOwner sets the owner of transmitted packets.
   197  func (e *Endpoint) SetOwner(owner tcpip.PacketOwner) {
   198  	e.mu.Lock()
   199  	defer e.mu.Unlock()
   200  	e.owner = owner
   201  }
   202  
   203  // +checklocksread:e.mu
   204  func (e *Endpoint) calculateTTL(route *stack.Route) uint8 {
   205  	remoteAddress := route.RemoteAddress()
   206  	if header.IsV4MulticastAddress(remoteAddress) || header.IsV6MulticastAddress(remoteAddress) {
   207  		return e.multicastTTL
   208  	}
   209  
   210  	switch netProto := route.NetProto(); netProto {
   211  	case header.IPv4ProtocolNumber:
   212  		if e.ipv4TTL == 0 {
   213  			return route.DefaultTTL()
   214  		}
   215  		return e.ipv4TTL
   216  	case header.IPv6ProtocolNumber:
   217  		if e.ipv6HopLimit == -1 {
   218  			return route.DefaultTTL()
   219  		}
   220  		return uint8(e.ipv6HopLimit)
   221  	default:
   222  		panic(fmt.Sprintf("invalid protocol number = %d", netProto))
   223  	}
   224  }
   225  
// WriteContext holds the context for a write.
//
// A WriteContext is produced by Endpoint.AcquireContextForWrite and must be
// released with Release when the caller is done with it.
type WriteContext struct {
	// e is the endpoint that produced this context.
	e *Endpoint
	// route is the route packets are written over. Release calls its Release
	// method.
	route *stack.Route
	// ttl is the TTL/hop limit WritePacket places in the network header.
	ttl uint8
	// tos is the TOS/traffic class WritePacket places in the network header.
	tos uint8
}
   233  
   234  func (c *WriteContext) MTU() uint32 {
   235  	return c.route.MTU()
   236  }
   237  
   238  // Release releases held resources.
   239  func (c *WriteContext) Release() {
   240  	c.route.Release()
   241  	*c = WriteContext{}
   242  }
   243  
// WritePacketInfo describes the properties of a packet that may be written.
// WriteContext.PacketInfo fills each field from the context's route.
type WritePacketInfo struct {
	// NetProto is the network protocol of the route.
	NetProto tcpip.NetworkProtocolNumber
	// LocalAddress and RemoteAddress are the route's local and remote
	// addresses.
	LocalAddress, RemoteAddress tcpip.Address
	// MaxHeaderLength is the route's MaxHeaderLength (presumably the header
	// space lower layers need to prepend — confirm against stack.Route).
	MaxHeaderLength uint16
	// RequiresTXTransportChecksum is the route's RequiresTXTransportChecksum.
	RequiresTXTransportChecksum bool
}
   251  
   252  // PacketInfo returns the properties of a packet that will be written.
   253  func (c *WriteContext) PacketInfo() WritePacketInfo {
   254  	return WritePacketInfo{
   255  		NetProto:                    c.route.NetProto(),
   256  		LocalAddress:                c.route.LocalAddress(),
   257  		RemoteAddress:               c.route.RemoteAddress(),
   258  		MaxHeaderLength:             c.route.MaxHeaderLength(),
   259  		RequiresTXTransportChecksum: c.route.RequiresTXTransportChecksum(),
   260  	}
   261  }
   262  
   263  // TryNewPacketBuffer returns a new packet buffer iff the endpoint's send buffer
   264  // is not full.
   265  //
   266  // If this method returns nil, the caller should wait for the endpoint to become
   267  // writable.
   268  func (c *WriteContext) TryNewPacketBuffer(reserveHdrBytes int, data buffer.Buffer) *stack.PacketBuffer {
   269  	e := c.e
   270  
   271  	e.sendBufferSizeInUseMu.Lock()
   272  	defer e.sendBufferSizeInUseMu.Unlock()
   273  
   274  	if !e.hasSendSpaceRLocked() {
   275  		return nil
   276  	}
   277  
   278  	// Note that we allow oversubscription - if there is any space at all in the
   279  	// send buffer, we accept the full packet which may be larger than the space
   280  	// available. This is because if the endpoint reports that it is writable,
   281  	// a write operation should succeed.
   282  	//
   283  	// This matches Linux behaviour:
   284  	// https://github.com/torvalds/linux/blob/38d741cb70b/include/net/sock.h#L2519
   285  	// https://github.com/torvalds/linux/blob/38d741cb70b/net/core/sock.c#L2588
   286  	pktSize := int64(reserveHdrBytes) + int64(data.Size())
   287  	e.sendBufferSizeInUse += pktSize
   288  
   289  	return stack.NewPacketBuffer(stack.PacketBufferOptions{
   290  		ReserveHeaderBytes: reserveHdrBytes,
   291  		Payload:            data,
   292  		OnRelease: func() {
   293  			e.sendBufferSizeInUseMu.Lock()
   294  			if got := e.sendBufferSizeInUse; got < pktSize {
   295  				e.sendBufferSizeInUseMu.Unlock()
   296  				panic(fmt.Sprintf("e.sendBufferSizeInUse=(%d) < pktSize(=%d)", got, pktSize))
   297  			}
   298  			e.sendBufferSizeInUse -= pktSize
   299  			signal := e.hasSendSpaceRLocked()
   300  			e.sendBufferSizeInUseMu.Unlock()
   301  
   302  			// Let waiters know if we now have space in the send buffer.
   303  			if signal {
   304  				e.waiterQueue.Notify(waiter.WritableEvents)
   305  			}
   306  		},
   307  	})
   308  }
   309  
   310  // WritePacket attempts to write the packet.
   311  func (c *WriteContext) WritePacket(pkt *stack.PacketBuffer, headerIncluded bool) tcpip.Error {
   312  	c.e.mu.RLock()
   313  	pkt.Owner = c.e.owner
   314  	c.e.mu.RUnlock()
   315  
   316  	if headerIncluded {
   317  		return c.route.WriteHeaderIncludedPacket(pkt)
   318  	}
   319  
   320  	err := c.route.WritePacket(stack.NetworkHeaderParams{
   321  		Protocol: c.e.transProto,
   322  		TTL:      c.ttl,
   323  		TOS:      c.tos,
   324  	}, pkt)
   325  
   326  	if _, ok := err.(*tcpip.ErrNoBufferSpace); ok {
   327  		var recvErr bool
   328  		switch netProto := c.route.NetProto(); netProto {
   329  		case header.IPv4ProtocolNumber:
   330  			recvErr = c.e.ops.GetIPv4RecvError()
   331  		case header.IPv6ProtocolNumber:
   332  			recvErr = c.e.ops.GetIPv6RecvError()
   333  		default:
   334  			panic(fmt.Sprintf("unhandled network protocol number = %d", netProto))
   335  		}
   336  
   337  		// Linux only returns ENOBUFS to the caller if IP{,V6}_RECVERR is set.
   338  		//
   339  		// https://github.com/torvalds/linux/blob/3e71713c9e75c/net/ipv4/udp.c#L969
   340  		// https://github.com/torvalds/linux/blob/3e71713c9e75c/net/ipv6/udp.c#L1260
   341  		if !recvErr {
   342  			err = nil
   343  		}
   344  	}
   345  
   346  	return err
   347  }
   348  
   349  // MaybeSignalWritable signals waiters with writable events if the send buffer
   350  // has space.
   351  func (e *Endpoint) MaybeSignalWritable() {
   352  	e.sendBufferSizeInUseMu.RLock()
   353  	signal := e.hasSendSpaceRLocked()
   354  	e.sendBufferSizeInUseMu.RUnlock()
   355  
   356  	if signal {
   357  		e.waiterQueue.Notify(waiter.WritableEvents)
   358  	}
   359  }
   360  
   361  // HasSendSpace returns whether or not the send buffer has space.
   362  func (e *Endpoint) HasSendSpace() bool {
   363  	e.sendBufferSizeInUseMu.RLock()
   364  	defer e.sendBufferSizeInUseMu.RUnlock()
   365  	return e.hasSendSpaceRLocked()
   366  }
   367  
   368  // +checklocksread:e.sendBufferSizeInUseMu
   369  func (e *Endpoint) hasSendSpaceRLocked() bool {
   370  	return e.ops.GetSendBufferSize() > e.sendBufferSizeInUse
   371  }
   372  
// AcquireContextForWrite acquires a WriteContext.
//
// It validates a write described by opts and returns a context holding a
// route to the destination plus the TTL and TOS to use. The caller must call
// Release on the returned context.
//
// Errors returned directly by this method:
//   - ErrInvalidOptionValue if opts.More is set (MSG_MORE is unimplemented).
//   - ErrInvalidEndpointState if the endpoint is closed.
//   - ErrClosedForSend if the write side was shut down.
//   - ErrDestinationRequired if opts.To is nil and the endpoint is not
//     connected.
//   - ErrHostUnreachable / ErrBadLocalAddress if the requested interface or
//     local address conflicts with the binding or IPv6 packet info.
//   - ErrBroadcastDisabled if the route is an outbound broadcast and
//     SO_BROADCAST is off.
//
// Errors from checkV4Mapped and route lookup are propagated.
func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext, tcpip.Error) {
	// A read lock keeps the connected route and socket fields stable while
	// the context is built.
	e.mu.RLock()
	defer e.mu.RUnlock()

	// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
	if opts.More {
		return WriteContext{}, &tcpip.ErrInvalidOptionValue{}
	}

	if e.State() == transport.DatagramEndpointStateClosed {
		return WriteContext{}, &tcpip.ErrInvalidEndpointState{}
	}

	if e.writeShutdown {
		return WriteContext{}, &tcpip.ErrClosedForSend{}
	}

	// IPv6 packet info (IPV6_PKTINFO) only applies to endpoints whose
	// effective protocol is IPv6.
	ipv6PktInfoValid := e.effectiveNetProto == header.IPv6ProtocolNumber && opts.ControlMessages.HasIPv6PacketInfo

	route := e.connectedRoute
	to := opts.To
	info := e.Info()
	switch {
	case to == nil:
		// If the user doesn't specify a destination, they should have
		// connected to another address.
		if e.State() != transport.DatagramEndpointStateConnected {
			return WriteContext{}, &tcpip.ErrDestinationRequired{}
		}

		if !ipv6PktInfoValid {
			// Reuse the connected route. Take an extra reference because the
			// WriteContext's Release will release it.
			route.Acquire()
			break
		}

		// We are connected and the caller did not specify the destination but
		// we have an IPv6 packet info structure which may change our local
		// interface/address used to send the packet so we need to construct
		// a new route instead of using the connected route.
		//
		// Construct a destination matching the remote the endpoint is connected
		// to.
		to = &tcpip.FullAddress{
			// RegisterNICID is set when the endpoint is connected. It is usually
			// only set for link-local addresses or multicast addresses if the
			// multicast interface was specified (see e.multicastNICID,
			// e.connectRouteRLocked and e.ConnectAndThen).
			NIC:  info.RegisterNICID,
			Addr: info.ID.RemoteAddress,
		}
		fallthrough
	default:
		// Reject destination address if it goes through a different
		// NIC than the endpoint was bound to.
		nicID := to.NIC
		if nicID == 0 {
			nicID = tcpip.NICID(e.ops.GetBindToDevice())
		}

		var localAddr tcpip.Address
		if ipv6PktInfoValid {
			// Uphold strong-host semantics since (as of writing) the stack follows
			// the strong host model.

			pktInfoNICID := opts.ControlMessages.IPv6PacketInfo.NIC
			pktInfoAddr := opts.ControlMessages.IPv6PacketInfo.Addr

			if pktInfoNICID != 0 {
				// If we are bound to an interface or specified the destination
				// interface (usually when using link-local addresses), make sure the
				// interface matches the specified local interface.
				if nicID != 0 && nicID != pktInfoNICID {
					return WriteContext{}, &tcpip.ErrHostUnreachable{}
				}

				// If a local address is not specified, then we need to make sure the
				// bound address belongs to the specified local interface.
				if pktInfoAddr.BitLen() == 0 {
					// If the bound interface is different from the specified local
					// interface, the bound address obviously does not belong to the
					// specified local interface.
					//
					// The bound interface is usually only set for link-local addresses.
					if info.BindNICID != 0 && info.BindNICID != pktInfoNICID {
						return WriteContext{}, &tcpip.ErrHostUnreachable{}
					}
					if info.ID.LocalAddress.BitLen() != 0 && e.stack.CheckLocalAddress(pktInfoNICID, header.IPv6ProtocolNumber, info.ID.LocalAddress) == 0 {
						return WriteContext{}, &tcpip.ErrBadLocalAddress{}
					}
				}

				nicID = pktInfoNICID
			}

			if pktInfoAddr.BitLen() != 0 {
				// The local address must belong to the stack. If an outgoing interface
				// is specified as a result of binding the endpoint to a device, or
				// specifying the outgoing interface in the destination address/pkt info
				// structure, the address must belong to that interface.
				if e.stack.CheckLocalAddress(nicID, header.IPv6ProtocolNumber, pktInfoAddr) == 0 {
					return WriteContext{}, &tcpip.ErrBadLocalAddress{}
				}

				localAddr = pktInfoAddr
			}
		} else {
			// Without packet info, an interface bound via Bind constrains the
			// outgoing NIC.
			if info.BindNICID != 0 {
				if nicID != 0 && nicID != info.BindNICID {
					return WriteContext{}, &tcpip.ErrHostUnreachable{}
				}

				nicID = info.BindNICID
			}
			if nicID == 0 {
				nicID = info.RegisterNICID
			}
		}

		dst, netProto, err := e.checkV4Mapped(*to)
		if err != nil {
			return WriteContext{}, err
		}

		// The freshly found route is owned by the returned WriteContext.
		route, _, err = e.connectRouteRLocked(nicID, localAddr, dst, netProto)
		if err != nil {
			return WriteContext{}, err
		}
	}

	// On every path above, route now holds a reference owned by this call;
	// drop it if the write is refused here.
	if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() {
		route.Release()
		return WriteContext{}, &tcpip.ErrBroadcastDisabled{}
	}

	// Per-packet control messages (IP_TTL / IPV6_HOPLIMIT) take precedence
	// over the endpoint's configured values.
	var tos uint8
	var ttl uint8
	switch netProto := route.NetProto(); netProto {
	case header.IPv4ProtocolNumber:
		tos = e.ipv4TOS
		if opts.ControlMessages.HasTTL {
			ttl = opts.ControlMessages.TTL
		} else {
			ttl = e.calculateTTL(route)
		}
	case header.IPv6ProtocolNumber:
		tos = e.ipv6TClass
		if opts.ControlMessages.HasHopLimit {
			ttl = opts.ControlMessages.HopLimit
		} else {
			ttl = e.calculateTTL(route)
		}
	default:
		panic(fmt.Sprintf("invalid protocol number = %d", netProto))
	}

	return WriteContext{
		e:     e,
		route: route,
		ttl:   ttl,
		tos:   tos,
	}, nil
}
   536  
   537  // Disconnect disconnects the endpoint from its peer.
   538  func (e *Endpoint) Disconnect() {
   539  	e.mu.Lock()
   540  	defer e.mu.Unlock()
   541  
   542  	if e.State() != transport.DatagramEndpointStateConnected {
   543  		return
   544  	}
   545  
   546  	info := e.Info()
   547  	// Exclude ephemerally bound endpoints.
   548  	if e.wasBound {
   549  		info.ID = stack.TransportEndpointID{
   550  			LocalAddress: info.BindAddr,
   551  		}
   552  		e.setEndpointState(transport.DatagramEndpointStateBound)
   553  	} else {
   554  		info.ID = stack.TransportEndpointID{}
   555  		e.setEndpointState(transport.DatagramEndpointStateInitial)
   556  	}
   557  	e.setInfo(info)
   558  
   559  	e.connectedRoute.Release()
   560  	e.connectedRoute = nil
   561  }
   562  
   563  // connectRouteRLocked establishes a route to the specified interface or the
   564  // configured multicast interface if no interface is specified and the
   565  // specified address is a multicast address.
   566  //
   567  // +checklocksread:e.mu
   568  func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, localAddr tcpip.Address, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) {
   569  	if localAddr.BitLen() == 0 {
   570  		localAddr = e.Info().ID.LocalAddress
   571  		if e.isBroadcastOrMulticast(nicID, netProto, localAddr) {
   572  			// A packet can only originate from a unicast address (i.e., an interface).
   573  			localAddr = tcpip.Address{}
   574  		}
   575  
   576  		if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
   577  			if nicID == 0 {
   578  				nicID = e.multicastNICID
   579  			}
   580  			if localAddr == (tcpip.Address{}) && nicID == 0 {
   581  				localAddr = e.multicastAddr
   582  			}
   583  		}
   584  	}
   585  
   586  	// Find a route to the desired destination.
   587  	r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop())
   588  	if err != nil {
   589  		return nil, 0, err
   590  	}
   591  	return r, nicID, nil
   592  }
   593  
   594  // Connect connects the endpoint to the address.
   595  func (e *Endpoint) Connect(addr tcpip.FullAddress) tcpip.Error {
   596  	return e.ConnectAndThen(addr, func(_ tcpip.NetworkProtocolNumber, _, _ stack.TransportEndpointID) tcpip.Error {
   597  		return nil
   598  	})
   599  }
   600  
   601  // ConnectAndThen connects the endpoint to the address and then calls the
   602  // provided function.
   603  //
   604  // If the function returns an error, the endpoint's state does not change. The
   605  // function will be called with the network protocol used to connect to the peer
   606  // and the source and destination addresses that will be used to send traffic to
   607  // the peer.
   608  func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error) tcpip.Error {
   609  	addr.Port = 0
   610  
   611  	e.mu.Lock()
   612  	defer e.mu.Unlock()
   613  
   614  	info := e.Info()
   615  	nicID := addr.NIC
   616  	switch e.State() {
   617  	case transport.DatagramEndpointStateInitial:
   618  	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
   619  		if info.BindNICID == 0 {
   620  			break
   621  		}
   622  
   623  		if nicID != 0 && nicID != info.BindNICID {
   624  			return &tcpip.ErrInvalidEndpointState{}
   625  		}
   626  
   627  		nicID = info.BindNICID
   628  	default:
   629  		return &tcpip.ErrInvalidEndpointState{}
   630  	}
   631  
   632  	addr, netProto, err := e.checkV4Mapped(addr)
   633  	if err != nil {
   634  		return err
   635  	}
   636  
   637  	r, nicID, err := e.connectRouteRLocked(nicID, tcpip.Address{}, addr, netProto)
   638  	if err != nil {
   639  		return err
   640  	}
   641  
   642  	id := stack.TransportEndpointID{
   643  		LocalAddress:  info.ID.LocalAddress,
   644  		RemoteAddress: r.RemoteAddress(),
   645  	}
   646  	if e.State() == transport.DatagramEndpointStateInitial {
   647  		id.LocalAddress = r.LocalAddress()
   648  	}
   649  
   650  	if err := f(r.NetProto(), info.ID, id); err != nil {
   651  		r.Release()
   652  		return err
   653  	}
   654  
   655  	if e.connectedRoute != nil {
   656  		// If the endpoint was previously connected then release any previous route.
   657  		e.connectedRoute.Release()
   658  	}
   659  	e.connectedRoute = r
   660  	info.ID = id
   661  	info.RegisterNICID = nicID
   662  	e.setInfo(info)
   663  	e.effectiveNetProto = netProto
   664  	e.setEndpointState(transport.DatagramEndpointStateConnected)
   665  	return nil
   666  }
   667  
   668  // Shutdown shutsdown the endpoint.
   669  func (e *Endpoint) Shutdown() tcpip.Error {
   670  	e.mu.Lock()
   671  	defer e.mu.Unlock()
   672  
   673  	switch state := e.State(); state {
   674  	case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
   675  		return &tcpip.ErrNotConnected{}
   676  	case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
   677  		e.writeShutdown = true
   678  		return nil
   679  	default:
   680  		panic(fmt.Sprintf("unhandled state = %s", state))
   681  	}
   682  }
   683  
   684  // checkV4MappedRLocked determines the effective network protocol and converts
   685  // addr to its canonical form.
   686  func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) {
   687  	info := e.Info()
   688  	unwrapped, netProto, err := info.AddrNetProtoLocked(addr, e.ops.GetV6Only())
   689  	if err != nil {
   690  		return tcpip.FullAddress{}, 0, err
   691  	}
   692  	return unwrapped, netProto, nil
   693  }
   694  
   695  func (e *Endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
   696  	return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
   697  }
   698  
   699  // Bind binds the endpoint to the address.
   700  func (e *Endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
   701  	return e.BindAndThen(addr, func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error {
   702  		return nil
   703  	})
   704  }
   705  
   706  // BindAndThen binds the endpoint to the address and then calls the provided
   707  // function.
   708  //
   709  // If the function returns an error, the endpoint's state does not change. The
   710  // function will be called with the bound network protocol and address.
   711  func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProtocolNumber, tcpip.Address) tcpip.Error) tcpip.Error {
   712  	addr.Port = 0
   713  
   714  	e.mu.Lock()
   715  	defer e.mu.Unlock()
   716  
   717  	// Don't allow binding once endpoint is not in the initial state
   718  	// anymore.
   719  	if e.State() != transport.DatagramEndpointStateInitial {
   720  		return &tcpip.ErrInvalidEndpointState{}
   721  	}
   722  
   723  	addr, netProto, err := e.checkV4Mapped(addr)
   724  	if err != nil {
   725  		return err
   726  	}
   727  
   728  	nicID := addr.NIC
   729  	if addr.Addr.BitLen() != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) {
   730  		nicID = e.stack.CheckLocalAddress(nicID, netProto, addr.Addr)
   731  		if nicID == 0 {
   732  			return &tcpip.ErrBadLocalAddress{}
   733  		}
   734  	}
   735  
   736  	if err := f(netProto, addr.Addr); err != nil {
   737  		return err
   738  	}
   739  
   740  	e.wasBound = true
   741  
   742  	info := e.Info()
   743  	info.ID = stack.TransportEndpointID{
   744  		LocalAddress: addr.Addr,
   745  	}
   746  	info.BindNICID = addr.NIC
   747  	info.RegisterNICID = nicID
   748  	info.BindAddr = addr.Addr
   749  	e.setInfo(info)
   750  	e.effectiveNetProto = netProto
   751  	e.setEndpointState(transport.DatagramEndpointStateBound)
   752  	return nil
   753  }
   754  
   755  // WasBound returns true iff the endpoint was ever bound.
   756  func (e *Endpoint) WasBound() bool {
   757  	e.mu.RLock()
   758  	defer e.mu.RUnlock()
   759  	return e.wasBound
   760  }
   761  
   762  // GetLocalAddress returns the address that the endpoint is bound to.
   763  func (e *Endpoint) GetLocalAddress() tcpip.FullAddress {
   764  	e.mu.RLock()
   765  	defer e.mu.RUnlock()
   766  
   767  	info := e.Info()
   768  	addr := info.BindAddr
   769  	if e.State() == transport.DatagramEndpointStateConnected {
   770  		addr = e.connectedRoute.LocalAddress()
   771  	}
   772  
   773  	return tcpip.FullAddress{
   774  		NIC:  info.RegisterNICID,
   775  		Addr: addr,
   776  	}
   777  }
   778  
   779  // GetRemoteAddress returns the address that the endpoint is connected to.
   780  func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) {
   781  	e.mu.RLock()
   782  	defer e.mu.RUnlock()
   783  
   784  	if e.State() != transport.DatagramEndpointStateConnected {
   785  		return tcpip.FullAddress{}, false
   786  	}
   787  
   788  	return tcpip.FullAddress{
   789  		Addr: e.connectedRoute.RemoteAddress(),
   790  		NIC:  e.Info().RegisterNICID,
   791  	}, true
   792  }
   793  
   794  // SetSockOptInt sets the socket option.
   795  func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
   796  	switch opt {
   797  	case tcpip.MTUDiscoverOption:
   798  		// Return not supported if the value is not disabling path
   799  		// MTU discovery.
   800  		if v != tcpip.PMTUDiscoveryDont {
   801  			return &tcpip.ErrNotSupported{}
   802  		}
   803  
   804  	case tcpip.MulticastTTLOption:
   805  		e.mu.Lock()
   806  		e.multicastTTL = uint8(v)
   807  		e.mu.Unlock()
   808  
   809  	case tcpip.IPv4TTLOption:
   810  		e.mu.Lock()
   811  		e.ipv4TTL = uint8(v)
   812  		e.mu.Unlock()
   813  
   814  	case tcpip.IPv6HopLimitOption:
   815  		e.mu.Lock()
   816  		e.ipv6HopLimit = int16(v)
   817  		e.mu.Unlock()
   818  
   819  	case tcpip.IPv4TOSOption:
   820  		e.mu.Lock()
   821  		e.ipv4TOS = uint8(v)
   822  		e.mu.Unlock()
   823  
   824  	case tcpip.IPv6TrafficClassOption:
   825  		e.mu.Lock()
   826  		e.ipv6TClass = uint8(v)
   827  		e.mu.Unlock()
   828  	}
   829  
   830  	return nil
   831  }
   832  
   833  // GetSockOptInt returns the socket option.
   834  func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
   835  	switch opt {
   836  	case tcpip.MTUDiscoverOption:
   837  		// The only supported setting is path MTU discovery disabled.
   838  		return tcpip.PMTUDiscoveryDont, nil
   839  
   840  	case tcpip.MulticastTTLOption:
   841  		e.mu.Lock()
   842  		v := int(e.multicastTTL)
   843  		e.mu.Unlock()
   844  		return v, nil
   845  
   846  	case tcpip.IPv4TTLOption:
   847  		e.mu.Lock()
   848  		v := int(e.ipv4TTL)
   849  		e.mu.Unlock()
   850  		return v, nil
   851  
   852  	case tcpip.IPv6HopLimitOption:
   853  		e.mu.Lock()
   854  		v := int(e.ipv6HopLimit)
   855  		e.mu.Unlock()
   856  		return v, nil
   857  
   858  	case tcpip.IPv4TOSOption:
   859  		e.mu.RLock()
   860  		v := int(e.ipv4TOS)
   861  		e.mu.RUnlock()
   862  		return v, nil
   863  
   864  	case tcpip.IPv6TrafficClassOption:
   865  		e.mu.RLock()
   866  		v := int(e.ipv6TClass)
   867  		e.mu.RUnlock()
   868  		return v, nil
   869  
   870  	default:
   871  		return -1, &tcpip.ErrUnknownProtocolOption{}
   872  	}
   873  }
   874  
   875  // SetSockOpt sets the socket option.
   876  func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
   877  	switch v := opt.(type) {
   878  	case *tcpip.MulticastInterfaceOption:
   879  		e.mu.Lock()
   880  		defer e.mu.Unlock()
   881  
   882  		fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
   883  		fa, netProto, err := e.checkV4Mapped(fa)
   884  		if err != nil {
   885  			return err
   886  		}
   887  		nic := v.NIC
   888  		addr := fa.Addr
   889  
   890  		if nic == 0 && addr == (tcpip.Address{}) {
   891  			e.multicastAddr = tcpip.Address{}
   892  			e.multicastNICID = 0
   893  			break
   894  		}
   895  
   896  		if nic != 0 {
   897  			if !e.stack.CheckNIC(nic) {
   898  				return &tcpip.ErrBadLocalAddress{}
   899  			}
   900  		} else {
   901  			nic = e.stack.CheckLocalAddress(0, netProto, addr)
   902  			if nic == 0 {
   903  				return &tcpip.ErrBadLocalAddress{}
   904  			}
   905  		}
   906  
   907  		if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic {
   908  			return &tcpip.ErrInvalidEndpointState{}
   909  		}
   910  
   911  		e.multicastNICID = nic
   912  		e.multicastAddr = addr
   913  
   914  	case *tcpip.AddMembershipOption:
   915  		if !(header.IsV4MulticastAddress(v.MulticastAddr) && e.netProto == header.IPv4ProtocolNumber) && !(header.IsV6MulticastAddress(v.MulticastAddr) && e.netProto == header.IPv6ProtocolNumber) {
   916  			return &tcpip.ErrInvalidOptionValue{}
   917  		}
   918  
   919  		nicID := v.NIC
   920  
   921  		if v.InterfaceAddr.Unspecified() {
   922  			if nicID == 0 {
   923  				if r, err := e.stack.FindRoute(0, tcpip.Address{}, v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
   924  					nicID = r.NICID()
   925  					r.Release()
   926  				}
   927  			}
   928  		} else {
   929  			nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
   930  		}
   931  		if nicID == 0 {
   932  			return &tcpip.ErrUnknownDevice{}
   933  		}
   934  
   935  		memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
   936  
   937  		e.mu.Lock()
   938  		defer e.mu.Unlock()
   939  
   940  		if _, ok := e.multicastMemberships[memToInsert]; ok {
   941  			return &tcpip.ErrPortInUse{}
   942  		}
   943  
   944  		if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
   945  			return err
   946  		}
   947  
   948  		e.multicastMemberships[memToInsert] = struct{}{}
   949  
   950  	case *tcpip.RemoveMembershipOption:
   951  		if !(header.IsV4MulticastAddress(v.MulticastAddr) && e.netProto == header.IPv4ProtocolNumber) && !(header.IsV6MulticastAddress(v.MulticastAddr) && e.netProto == header.IPv6ProtocolNumber) {
   952  			return &tcpip.ErrInvalidOptionValue{}
   953  		}
   954  
   955  		nicID := v.NIC
   956  		if v.InterfaceAddr.Unspecified() {
   957  			if nicID == 0 {
   958  				if r, err := e.stack.FindRoute(0, tcpip.Address{}, v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
   959  					nicID = r.NICID()
   960  					r.Release()
   961  				}
   962  			}
   963  		} else {
   964  			nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
   965  		}
   966  		if nicID == 0 {
   967  			return &tcpip.ErrUnknownDevice{}
   968  		}
   969  
   970  		memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
   971  
   972  		e.mu.Lock()
   973  		defer e.mu.Unlock()
   974  
   975  		if _, ok := e.multicastMemberships[memToRemove]; !ok {
   976  			return &tcpip.ErrBadLocalAddress{}
   977  		}
   978  
   979  		if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
   980  			return err
   981  		}
   982  
   983  		delete(e.multicastMemberships, memToRemove)
   984  
   985  	case *tcpip.SocketDetachFilterOption:
   986  		return nil
   987  	}
   988  	return nil
   989  }
   990  
   991  // GetSockOpt returns the socket option.
   992  func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
   993  	switch o := opt.(type) {
   994  	case *tcpip.MulticastInterfaceOption:
   995  		e.mu.Lock()
   996  		*o = tcpip.MulticastInterfaceOption{
   997  			NIC:           e.multicastNICID,
   998  			InterfaceAddr: e.multicastAddr,
   999  		}
  1000  		e.mu.Unlock()
  1001  
  1002  	default:
  1003  		return &tcpip.ErrUnknownProtocolOption{}
  1004  	}
  1005  	return nil
  1006  }
  1007  
  1008  // Info returns a copy of the endpoint info.
  1009  func (e *Endpoint) Info() stack.TransportEndpointInfo {
  1010  	e.infoMu.RLock()
  1011  	defer e.infoMu.RUnlock()
  1012  	return e.info
  1013  }
  1014  
  1015  // setInfo sets the endpoint's info.
  1016  //
  1017  // e.mu must be held to synchronize changes to info with the rest of the
  1018  // endpoint.
  1019  //
  1020  // +checklocks:e.mu
  1021  func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) {
  1022  	e.infoMu.Lock()
  1023  	defer e.infoMu.Unlock()
  1024  	e.info = info
  1025  }