github.com/MerlinKodo/gvisor@v0.0.0-20231110090155-957f62ecf90e/pkg/sentry/socket/control/control.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package control provides internal representations of socket control
    16  // messages.
    17  package control
    18  
    19  import (
    20  	"math"
    21  	"time"
    22  
    23  	"github.com/MerlinKodo/gvisor/pkg/abi/linux"
    24  	"github.com/MerlinKodo/gvisor/pkg/bits"
    25  	"github.com/MerlinKodo/gvisor/pkg/context"
    26  	"github.com/MerlinKodo/gvisor/pkg/errors/linuxerr"
    27  	"github.com/MerlinKodo/gvisor/pkg/hostarch"
    28  	"github.com/MerlinKodo/gvisor/pkg/marshal"
    29  	"github.com/MerlinKodo/gvisor/pkg/marshal/primitive"
    30  	"github.com/MerlinKodo/gvisor/pkg/sentry/kernel"
    31  	"github.com/MerlinKodo/gvisor/pkg/sentry/kernel/auth"
    32  	"github.com/MerlinKodo/gvisor/pkg/sentry/socket"
    33  	"github.com/MerlinKodo/gvisor/pkg/sentry/socket/unix/transport"
    34  	"github.com/MerlinKodo/gvisor/pkg/sentry/vfs"
    35  )
    36  
// SCMCredentials represents a SCM_CREDENTIALS socket control message.
//
// The embedded transport.CredentialsControlMessage lets the unix socket
// transport carry and compare these values (see scmCredentials.Equals).
type SCMCredentials interface {
	transport.CredentialsControlMessage

	// Credentials returns properly namespaced values for the sender's pid, uid
	// and gid. t is the receiving task: the values are expressed in t's PID
	// and user namespaces.
	Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID)
}
    45  
// scmCredentials represents an SCM_CREDENTIALS socket control message.
//
// t is the sending task. kuid and kgid are the sender's IDs as kernel
// (KUID/KGID) values; Credentials maps them into a receiver's user namespace
// and looks up t's ID in the receiver's PID namespace.
//
// Field order matters: values are built with positional composite literals
// in NewSCMCredentials and MakeCreds.
//
// +stateify savable
type scmCredentials struct {
	t    *kernel.Task
	kuid auth.KUID
	kgid auth.KGID
}
    54  
    55  // NewSCMCredentials creates a new SCM_CREDENTIALS socket control message
    56  // representation.
    57  func NewSCMCredentials(t *kernel.Task, cred linux.ControlMessageCredentials) (SCMCredentials, error) {
    58  	tcred := t.Credentials()
    59  	kuid, err := tcred.UseUID(auth.UID(cred.UID))
    60  	if err != nil {
    61  		return nil, err
    62  	}
    63  	kgid, err := tcred.UseGID(auth.GID(cred.GID))
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	if kernel.ThreadID(cred.PID) != t.ThreadGroup().ID() && !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.PIDNamespace().UserNamespace()) {
    68  		return nil, linuxerr.EPERM
    69  	}
    70  	return &scmCredentials{t, kuid, kgid}, nil
    71  }
    72  
    73  // Equals implements transport.CredentialsControlMessage.Equals.
    74  func (c *scmCredentials) Equals(oc transport.CredentialsControlMessage) bool {
    75  	if oc, _ := oc.(*scmCredentials); oc != nil && *c == *oc {
    76  		return true
    77  	}
    78  	return false
    79  }
    80  
    81  func putUint64(buf []byte, n uint64) []byte {
    82  	hostarch.ByteOrder.PutUint64(buf[len(buf):len(buf)+8], n)
    83  	return buf[:len(buf)+8]
    84  }
    85  
    86  func putUint32(buf []byte, n uint32) []byte {
    87  	hostarch.ByteOrder.PutUint32(buf[len(buf):len(buf)+4], n)
    88  	return buf[:len(buf)+4]
    89  }
    90  
    91  // putCmsg writes a control message header and as much data as will fit into
    92  // the unused capacity of a buffer.
    93  func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
    94  	space := bits.AlignDown(cap(buf)-len(buf), 4)
    95  
    96  	// We can't write to space that doesn't exist, so if we are going to align
    97  	// the available space, we must align down.
    98  	//
    99  	// align must be >= 4 and each data int32 is 4 bytes. The length of the
   100  	// header is already aligned, so if we align to the width of the data there
   101  	// are two cases:
   102  	// 1. The aligned length is less than the length of the header. The
   103  	// unaligned length was also less than the length of the header, so we
   104  	// can't write anything.
   105  	// 2. The aligned length is greater than or equal to the length of the
   106  	// header. We can write the header plus zero or more bytes of data. We can't
   107  	// write a partial int32, so the length of the message will be
   108  	// min(aligned length, header + data).
   109  	if space < linux.SizeOfControlMessageHeader {
   110  		flags |= linux.MSG_CTRUNC
   111  		return buf, flags
   112  	}
   113  
   114  	length := 4*len(data) + linux.SizeOfControlMessageHeader
   115  	if length > space {
   116  		length = space
   117  	}
   118  	buf = putUint64(buf, uint64(length))
   119  	buf = putUint32(buf, linux.SOL_SOCKET)
   120  	buf = putUint32(buf, msgType)
   121  	for _, d := range data {
   122  		if len(buf)+4 > cap(buf) {
   123  			flags |= linux.MSG_CTRUNC
   124  			break
   125  		}
   126  		buf = putUint32(buf, uint32(d))
   127  	}
   128  	return alignSlice(buf, align), flags
   129  }
   130  
   131  func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data marshal.Marshallable) []byte {
   132  	if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader {
   133  		return buf
   134  	}
   135  	ob := buf
   136  
   137  	buf = putUint64(buf, uint64(linux.SizeOfControlMessageHeader))
   138  	buf = putUint32(buf, msgLevel)
   139  	buf = putUint32(buf, msgType)
   140  
   141  	hdrBuf := buf
   142  	buf = append(buf, marshal.Marshal(data)...)
   143  
   144  	// If the control message data brought us over capacity, omit it.
   145  	if cap(buf) != cap(ob) {
   146  		return hdrBuf
   147  	}
   148  
   149  	// Update control message length to include data.
   150  	putUint64(ob, uint64(len(buf)-len(ob)))
   151  
   152  	return alignSlice(buf, align)
   153  }
   154  
   155  // Credentials implements SCMCredentials.Credentials.
   156  func (c *scmCredentials) Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
   157  	// "When a process's user and group IDs are passed over a UNIX domain
   158  	// socket to a process in a different user namespace (see the description
   159  	// of SCM_CREDENTIALS in unix(7)), they are translated into the
   160  	// corresponding values as per the receiving process's user and group ID
   161  	// mappings." - user_namespaces(7)
   162  	pid := t.PIDNamespace().IDOfTask(c.t)
   163  	uid := c.kuid.In(t.UserNamespace()).OrOverflow()
   164  	gid := c.kgid.In(t.UserNamespace()).OrOverflow()
   165  
   166  	return pid, uid, gid
   167  }
   168  
   169  // PackCredentials packs the credentials in the control message (or default
   170  // credentials if none) into a buffer.
   171  func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int) ([]byte, int) {
   172  	align := t.Arch().Width()
   173  
   174  	// Default credentials if none are available.
   175  	pid := kernel.ThreadID(0)
   176  	uid := auth.UID(auth.NobodyKUID)
   177  	gid := auth.GID(auth.NobodyKGID)
   178  
   179  	if creds != nil {
   180  		pid, uid, gid = creds.Credentials(t)
   181  	}
   182  	c := []int32{int32(pid), int32(uid), int32(gid)}
   183  	return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c)
   184  }
   185  
   186  // alignSlice extends a slice's length (up to the capacity) to align it.
   187  func alignSlice(buf []byte, align uint) []byte {
   188  	aligned := bits.AlignUp(len(buf), align)
   189  	if aligned > cap(buf) {
   190  		// Linux allows unaligned data if there isn't room for alignment.
   191  		// Since there isn't room for alignment, there isn't room for any
   192  		// additional messages either.
   193  		return buf
   194  	}
   195  	return buf[:aligned]
   196  }
   197  
   198  // PackTimestamp packs a SO_TIMESTAMP socket control message.
   199  func PackTimestamp(t *kernel.Task, timestamp time.Time, buf []byte) []byte {
   200  	timestampP := linux.NsecToTimeval(timestamp.UnixNano())
   201  	return putCmsgStruct(
   202  		buf,
   203  		linux.SOL_SOCKET,
   204  		linux.SO_TIMESTAMP,
   205  		t.Arch().Width(),
   206  		&timestampP,
   207  	)
   208  }
   209  
   210  // PackInq packs a TCP_INQ socket control message.
   211  func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
   212  	return putCmsgStruct(
   213  		buf,
   214  		linux.SOL_TCP,
   215  		linux.TCP_INQ,
   216  		t.Arch().Width(),
   217  		primitive.AllocateInt32(inq),
   218  	)
   219  }
   220  
   221  // PackTOS packs an IP_TOS socket control message.
   222  func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
   223  	return putCmsgStruct(
   224  		buf,
   225  		linux.SOL_IP,
   226  		linux.IP_TOS,
   227  		t.Arch().Width(),
   228  		primitive.AllocateUint8(tos),
   229  	)
   230  }
   231  
   232  // PackTClass packs an IPV6_TCLASS socket control message.
   233  func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
   234  	return putCmsgStruct(
   235  		buf,
   236  		linux.SOL_IPV6,
   237  		linux.IPV6_TCLASS,
   238  		t.Arch().Width(),
   239  		primitive.AllocateUint32(tClass),
   240  	)
   241  }
   242  
   243  // PackTTL packs an IP_TTL socket control message.
   244  func PackTTL(t *kernel.Task, ttl uint32, buf []byte) []byte {
   245  	return putCmsgStruct(
   246  		buf,
   247  		linux.SOL_IP,
   248  		linux.IP_TTL,
   249  		t.Arch().Width(),
   250  		primitive.AllocateUint32(ttl),
   251  	)
   252  }
   253  
   254  // PackHopLimit packs an IPV6_HOPLIMIT socket control message.
   255  func PackHopLimit(t *kernel.Task, hoplimit uint32, buf []byte) []byte {
   256  	return putCmsgStruct(
   257  		buf,
   258  		linux.SOL_IPV6,
   259  		linux.IPV6_HOPLIMIT,
   260  		t.Arch().Width(),
   261  		primitive.AllocateUint32(hoplimit),
   262  	)
   263  }
   264  
   265  // PackIPPacketInfo packs an IP_PKTINFO socket control message.
   266  func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte {
   267  	return putCmsgStruct(
   268  		buf,
   269  		linux.SOL_IP,
   270  		linux.IP_PKTINFO,
   271  		t.Arch().Width(),
   272  		packetInfo,
   273  	)
   274  }
   275  
   276  // PackIPv6PacketInfo packs an IPV6_PKTINFO socket control message.
   277  func PackIPv6PacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPv6PacketInfo, buf []byte) []byte {
   278  	return putCmsgStruct(
   279  		buf,
   280  		linux.SOL_IPV6,
   281  		linux.IPV6_PKTINFO,
   282  		t.Arch().Width(),
   283  		packetInfo,
   284  	)
   285  }
   286  
   287  // PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
   288  func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte {
   289  	var level uint32
   290  	var optType uint32
   291  	switch originalDstAddress.(type) {
   292  	case *linux.SockAddrInet:
   293  		level = linux.SOL_IP
   294  		optType = linux.IP_RECVORIGDSTADDR
   295  	case *linux.SockAddrInet6:
   296  		level = linux.SOL_IPV6
   297  		optType = linux.IPV6_RECVORIGDSTADDR
   298  	default:
   299  		panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg")
   300  	}
   301  	return putCmsgStruct(
   302  		buf, level, optType, t.Arch().Width(), originalDstAddress)
   303  }
   304  
   305  // PackSockExtendedErr packs an IP*_RECVERR socket control message.
   306  func PackSockExtendedErr(t *kernel.Task, sockErr linux.SockErrCMsg, buf []byte) []byte {
   307  	return putCmsgStruct(
   308  		buf,
   309  		sockErr.CMsgLevel(),
   310  		sockErr.CMsgType(),
   311  		t.Arch().Width(),
   312  		sockErr,
   313  	)
   314  }
   315  
   316  // PackControlMessages packs control messages into the given buffer.
   317  //
   318  // We skip control messages specific to Unix domain sockets.
   319  //
   320  // Note that some control messages may be truncated if they do not fit under
   321  // the capacity of buf.
   322  func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
   323  	if cmsgs.IP.HasTimestamp {
   324  		buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
   325  	}
   326  
   327  	if cmsgs.IP.HasInq {
   328  		// In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
   329  		buf = PackInq(t, cmsgs.IP.Inq, buf)
   330  	}
   331  
   332  	if cmsgs.IP.HasTOS {
   333  		buf = PackTOS(t, cmsgs.IP.TOS, buf)
   334  	}
   335  
   336  	if cmsgs.IP.HasTTL {
   337  		buf = PackTTL(t, cmsgs.IP.TTL, buf)
   338  	}
   339  
   340  	if cmsgs.IP.HasTClass {
   341  		buf = PackTClass(t, cmsgs.IP.TClass, buf)
   342  	}
   343  
   344  	if cmsgs.IP.HasHopLimit {
   345  		buf = PackHopLimit(t, cmsgs.IP.HopLimit, buf)
   346  	}
   347  
   348  	if cmsgs.IP.HasIPPacketInfo {
   349  		buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf)
   350  	}
   351  
   352  	if cmsgs.IP.HasIPv6PacketInfo {
   353  		buf = PackIPv6PacketInfo(t, &cmsgs.IP.IPv6PacketInfo, buf)
   354  	}
   355  
   356  	if cmsgs.IP.OriginalDstAddress != nil {
   357  		buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
   358  	}
   359  
   360  	if cmsgs.IP.SockErr != nil {
   361  		buf = PackSockExtendedErr(t, cmsgs.IP.SockErr, buf)
   362  	}
   363  
   364  	return buf
   365  }
   366  
   367  // cmsgSpace is equivalent to CMSG_SPACE in Linux.
   368  func cmsgSpace(t *kernel.Task, dataLen int) int {
   369  	return linux.SizeOfControlMessageHeader + bits.AlignUp(dataLen, t.Arch().Width())
   370  }
   371  
   372  // CmsgsSpace returns the number of bytes needed to fit the control messages
   373  // represented in cmsgs.
   374  func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
   375  	space := 0
   376  
   377  	if cmsgs.IP.HasTimestamp {
   378  		space += cmsgSpace(t, linux.SizeOfTimeval)
   379  	}
   380  
   381  	if cmsgs.IP.HasInq {
   382  		space += cmsgSpace(t, linux.SizeOfControlMessageInq)
   383  	}
   384  
   385  	if cmsgs.IP.HasTOS {
   386  		space += cmsgSpace(t, linux.SizeOfControlMessageTOS)
   387  	}
   388  
   389  	if cmsgs.IP.HasTTL {
   390  		space += cmsgSpace(t, linux.SizeOfControlMessageTTL)
   391  	}
   392  
   393  	if cmsgs.IP.HasTClass {
   394  		space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
   395  	}
   396  
   397  	if cmsgs.IP.HasHopLimit {
   398  		space += cmsgSpace(t, linux.SizeOfControlMessageHopLimit)
   399  	}
   400  
   401  	if cmsgs.IP.HasIPPacketInfo {
   402  		space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo)
   403  	}
   404  
   405  	if cmsgs.IP.HasIPv6PacketInfo {
   406  		space += cmsgSpace(t, linux.SizeOfControlMessageIPv6PacketInfo)
   407  	}
   408  
   409  	if cmsgs.IP.OriginalDstAddress != nil {
   410  		space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
   411  	}
   412  
   413  	if cmsgs.IP.SockErr != nil {
   414  		space += cmsgSpace(t, cmsgs.IP.SockErr.SizeBytes())
   415  	}
   416  
   417  	return space
   418  }
   419  
   420  // Parse parses a raw socket control message into portable objects.
   421  // TODO(https://gvisor.dev/issue/7188): Parse is only called on raw cmsg that
   422  // are used when sending a messages. We should fail with EINVAL when we find a
   423  // non-sendable control messages (such as IP_RECVERR). And the function should
   424  // be renamed to reflect that.
   425  func Parse(t *kernel.Task, socketOrEndpoint any, buf []byte, width uint) (socket.ControlMessages, error) {
   426  	var (
   427  		cmsgs socket.ControlMessages
   428  		fds   []primitive.Int32
   429  	)
   430  
   431  	for len(buf) > 0 {
   432  		if linux.SizeOfControlMessageHeader > len(buf) {
   433  			return cmsgs, linuxerr.EINVAL
   434  		}
   435  
   436  		var h linux.ControlMessageHeader
   437  		buf = h.UnmarshalUnsafe(buf)
   438  
   439  		if h.Length < uint64(linux.SizeOfControlMessageHeader) {
   440  			return socket.ControlMessages{}, linuxerr.EINVAL
   441  		}
   442  
   443  		length := int(h.Length) - linux.SizeOfControlMessageHeader
   444  		if length < 0 || length > len(buf) {
   445  			return socket.ControlMessages{}, linuxerr.EINVAL
   446  		}
   447  
   448  		switch h.Level {
   449  		case linux.SOL_SOCKET:
   450  			switch h.Type {
   451  			case linux.SCM_RIGHTS:
   452  				rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight)
   453  				numRights := rightsSize / linux.SizeOfControlMessageRight
   454  
   455  				if len(fds)+numRights > linux.SCM_MAX_FD {
   456  					return socket.ControlMessages{}, linuxerr.EINVAL
   457  				}
   458  
   459  				curFDs := make([]primitive.Int32, numRights)
   460  				primitive.UnmarshalUnsafeInt32Slice(curFDs, buf[:rightsSize])
   461  				fds = append(fds, curFDs...)
   462  
   463  			case linux.SCM_CREDENTIALS:
   464  				if length < linux.SizeOfControlMessageCredentials {
   465  					return socket.ControlMessages{}, linuxerr.EINVAL
   466  				}
   467  
   468  				var creds linux.ControlMessageCredentials
   469  				creds.UnmarshalUnsafe(buf)
   470  				scmCreds, err := NewSCMCredentials(t, creds)
   471  				if err != nil {
   472  					return socket.ControlMessages{}, err
   473  				}
   474  				cmsgs.Unix.Credentials = scmCreds
   475  
   476  			case linux.SO_TIMESTAMP:
   477  				if length < linux.SizeOfTimeval {
   478  					return socket.ControlMessages{}, linuxerr.EINVAL
   479  				}
   480  				var ts linux.Timeval
   481  				ts.UnmarshalUnsafe(buf)
   482  				cmsgs.IP.Timestamp = ts.ToTime()
   483  				cmsgs.IP.HasTimestamp = true
   484  
   485  			default:
   486  				// Unknown message type.
   487  				return socket.ControlMessages{}, linuxerr.EINVAL
   488  			}
   489  		case linux.SOL_IP:
   490  			switch h.Type {
   491  			case linux.IP_TOS:
   492  				if length < linux.SizeOfControlMessageTOS {
   493  					return socket.ControlMessages{}, linuxerr.EINVAL
   494  				}
   495  				cmsgs.IP.HasTOS = true
   496  				var tos primitive.Uint8
   497  				tos.UnmarshalUnsafe(buf)
   498  				cmsgs.IP.TOS = uint8(tos)
   499  
   500  			case linux.IP_TTL:
   501  				if length < linux.SizeOfControlMessageTTL {
   502  					return socket.ControlMessages{}, linuxerr.EINVAL
   503  				}
   504  				var ttl primitive.Uint32
   505  				ttl.UnmarshalUnsafe(buf)
   506  				if ttl == 0 || ttl > math.MaxUint8 {
   507  					return socket.ControlMessages{}, linuxerr.EINVAL
   508  				}
   509  				cmsgs.IP.TTL = uint32(ttl)
   510  				cmsgs.IP.HasTTL = true
   511  
   512  			case linux.IP_PKTINFO:
   513  				if length < linux.SizeOfControlMessageIPPacketInfo {
   514  					return socket.ControlMessages{}, linuxerr.EINVAL
   515  				}
   516  
   517  				cmsgs.IP.HasIPPacketInfo = true
   518  				var packetInfo linux.ControlMessageIPPacketInfo
   519  				packetInfo.UnmarshalUnsafe(buf)
   520  				cmsgs.IP.PacketInfo = packetInfo
   521  
   522  			case linux.IP_RECVORIGDSTADDR:
   523  				var addr linux.SockAddrInet
   524  				if length < addr.SizeBytes() {
   525  					return socket.ControlMessages{}, linuxerr.EINVAL
   526  				}
   527  				addr.UnmarshalUnsafe(buf)
   528  				cmsgs.IP.OriginalDstAddress = &addr
   529  
   530  			case linux.IP_RECVERR:
   531  				var errCmsg linux.SockErrCMsgIPv4
   532  				if length < errCmsg.SizeBytes() {
   533  					return socket.ControlMessages{}, linuxerr.EINVAL
   534  				}
   535  
   536  				errCmsg.UnmarshalBytes(buf)
   537  				cmsgs.IP.SockErr = &errCmsg
   538  
   539  			default:
   540  				return socket.ControlMessages{}, linuxerr.EINVAL
   541  			}
   542  		case linux.SOL_IPV6:
   543  			switch h.Type {
   544  			case linux.IPV6_TCLASS:
   545  				if length < linux.SizeOfControlMessageTClass {
   546  					return socket.ControlMessages{}, linuxerr.EINVAL
   547  				}
   548  				cmsgs.IP.HasTClass = true
   549  				var tclass primitive.Uint32
   550  				tclass.UnmarshalUnsafe(buf)
   551  				cmsgs.IP.TClass = uint32(tclass)
   552  
   553  			case linux.IPV6_PKTINFO:
   554  				if length < linux.SizeOfControlMessageIPv6PacketInfo {
   555  					return socket.ControlMessages{}, linuxerr.EINVAL
   556  				}
   557  
   558  				cmsgs.IP.HasIPv6PacketInfo = true
   559  				var packetInfo linux.ControlMessageIPv6PacketInfo
   560  				packetInfo.UnmarshalUnsafe(buf)
   561  				cmsgs.IP.IPv6PacketInfo = packetInfo
   562  
   563  			case linux.IPV6_HOPLIMIT:
   564  				if length < linux.SizeOfControlMessageHopLimit {
   565  					return socket.ControlMessages{}, linuxerr.EINVAL
   566  				}
   567  				var hoplimit primitive.Uint32
   568  				hoplimit.UnmarshalUnsafe(buf)
   569  				if hoplimit > math.MaxUint8 {
   570  					return socket.ControlMessages{}, linuxerr.EINVAL
   571  				}
   572  				cmsgs.IP.HasHopLimit = true
   573  				cmsgs.IP.HopLimit = uint32(hoplimit)
   574  
   575  			case linux.IPV6_RECVORIGDSTADDR:
   576  				var addr linux.SockAddrInet6
   577  				if length < addr.SizeBytes() {
   578  					return socket.ControlMessages{}, linuxerr.EINVAL
   579  				}
   580  				addr.UnmarshalUnsafe(buf)
   581  				cmsgs.IP.OriginalDstAddress = &addr
   582  
   583  			case linux.IPV6_RECVERR:
   584  				var errCmsg linux.SockErrCMsgIPv6
   585  				if length < errCmsg.SizeBytes() {
   586  					return socket.ControlMessages{}, linuxerr.EINVAL
   587  				}
   588  
   589  				errCmsg.UnmarshalBytes(buf)
   590  				cmsgs.IP.SockErr = &errCmsg
   591  
   592  			default:
   593  				return socket.ControlMessages{}, linuxerr.EINVAL
   594  			}
   595  		default:
   596  			return socket.ControlMessages{}, linuxerr.EINVAL
   597  		}
   598  		if shift := bits.AlignUp(length, width); shift > len(buf) {
   599  			buf = buf[:0]
   600  		} else {
   601  			buf = buf[shift:]
   602  		}
   603  	}
   604  
   605  	if cmsgs.Unix.Credentials == nil {
   606  		cmsgs.Unix.Credentials = makeCreds(t, socketOrEndpoint)
   607  	}
   608  
   609  	if len(fds) > 0 {
   610  		rights, err := NewSCMRights(t, fds)
   611  		if err != nil {
   612  			return socket.ControlMessages{}, err
   613  		}
   614  		cmsgs.Unix.Rights = rights
   615  	}
   616  
   617  	return cmsgs, nil
   618  }
   619  
   620  func makeCreds(t *kernel.Task, socketOrEndpoint any) SCMCredentials {
   621  	if t == nil || socketOrEndpoint == nil {
   622  		return nil
   623  	}
   624  	if cr, ok := socketOrEndpoint.(transport.Credentialer); ok && (cr.Passcred() || cr.ConnectedPasscred()) {
   625  		return MakeCreds(t)
   626  	}
   627  	return nil
   628  }
   629  
   630  // MakeCreds creates default SCMCredentials.
   631  func MakeCreds(t *kernel.Task) SCMCredentials {
   632  	if t == nil {
   633  		return nil
   634  	}
   635  	tcred := t.Credentials()
   636  	return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID}
   637  }
   638  
   639  // New creates default control messages if needed.
   640  func New(t *kernel.Task, socketOrEndpoint any) transport.ControlMessages {
   641  	return transport.ControlMessages{
   642  		Credentials: makeCreds(t, socketOrEndpoint),
   643  	}
   644  }
   645  
// SCMRights represents a SCM_RIGHTS socket control message.
//
// +stateify savable
type SCMRights interface {
	transport.RightsControlMessage

	// Files returns up to max RightsFiles.
	//
	// Returned files are consumed and ownership is transferred to the caller,
	// which therefore becomes responsible for dropping each file's reference.
	// Subsequent calls to Files will return the next files. In the
	// RightsFiles implementation, truncated is true when more than max files
	// were held.
	Files(ctx context.Context, max int) (rf RightsFiles, truncated bool)
}
   658  
// RightsFiles represents a SCM_RIGHTS socket control message. A reference
// is maintained for each vfs.FileDescription and is released either when an
// FD is created or when the Release method is called.
//
// +stateify savable
type RightsFiles []*vfs.FileDescription
   665  
   666  // NewSCMRights creates a new SCM_RIGHTS socket control message
   667  // representation using local sentry FDs.
   668  func NewSCMRights(t *kernel.Task, fds []primitive.Int32) (SCMRights, error) {
   669  	files := make(RightsFiles, 0, len(fds))
   670  	for _, fd := range fds {
   671  		file := t.GetFile(int32(fd))
   672  		if file == nil {
   673  			files.Release(t)
   674  			return nil, linuxerr.EBADF
   675  		}
   676  		files = append(files, file)
   677  	}
   678  	return &files, nil
   679  }
   680  
   681  // Files implements SCMRights.Files.
   682  func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) {
   683  	n := max
   684  	var trunc bool
   685  	if l := len(*fs); n > l {
   686  		n = l
   687  	} else if n < l {
   688  		trunc = true
   689  	}
   690  	rf := (*fs)[:n]
   691  	*fs = (*fs)[n:]
   692  	return rf, trunc
   693  }
   694  
   695  // Clone implements transport.RightsControlMessage.Clone.
   696  func (fs *RightsFiles) Clone() transport.RightsControlMessage {
   697  	nfs := append(RightsFiles(nil), *fs...)
   698  	for _, nf := range nfs {
   699  		nf.IncRef()
   700  	}
   701  	return &nfs
   702  }
   703  
   704  // Release implements transport.RightsControlMessage.Release.
   705  func (fs *RightsFiles) Release(ctx context.Context) {
   706  	for _, f := range *fs {
   707  		f.DecRef(ctx)
   708  	}
   709  	*fs = nil
   710  }
   711  
   712  // rightsFDs gets up to the specified maximum number of FDs.
   713  func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32, bool) {
   714  	files, trunc := rights.Files(t, max)
   715  	fds := make([]int32, 0, len(files))
   716  	for i := 0; i < max && len(files) > 0; i++ {
   717  		fd, err := t.NewFDFrom(0, files[0], kernel.FDFlags{
   718  			CloseOnExec: cloexec,
   719  		})
   720  		files[0].DecRef(t)
   721  		files = files[1:]
   722  		if err != nil {
   723  			t.Warningf("Error inserting FD: %v", err)
   724  			// This is what Linux does.
   725  			break
   726  		}
   727  
   728  		fds = append(fds, int32(fd))
   729  	}
   730  	return fds, trunc
   731  }
   732  
   733  // PackRights packs as many FDs as will fit into the unused capacity of buf.
   734  func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte, flags int) ([]byte, int) {
   735  	maxFDs := (cap(buf) - len(buf) - linux.SizeOfControlMessageHeader) / 4
   736  	// Linux does not return any FDs if none fit.
   737  	if maxFDs <= 0 {
   738  		flags |= linux.MSG_CTRUNC
   739  		return buf, flags
   740  	}
   741  	fds, trunc := rightsFDs(t, rights, cloexec, maxFDs)
   742  	if trunc {
   743  		flags |= linux.MSG_CTRUNC
   744  	}
   745  	align := t.Arch().Width()
   746  	return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds)
   747  }