github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/sentry/socket/control/control.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package control provides internal representations of socket control 16 // messages. 17 package control 18 19 import ( 20 "math" 21 "time" 22 23 "github.com/nicocha30/gvisor-ligolo/pkg/abi/linux" 24 "github.com/nicocha30/gvisor-ligolo/pkg/bits" 25 "github.com/nicocha30/gvisor-ligolo/pkg/context" 26 "github.com/nicocha30/gvisor-ligolo/pkg/errors/linuxerr" 27 "github.com/nicocha30/gvisor-ligolo/pkg/hostarch" 28 "github.com/nicocha30/gvisor-ligolo/pkg/marshal" 29 "github.com/nicocha30/gvisor-ligolo/pkg/marshal/primitive" 30 "github.com/nicocha30/gvisor-ligolo/pkg/sentry/kernel" 31 "github.com/nicocha30/gvisor-ligolo/pkg/sentry/kernel/auth" 32 "github.com/nicocha30/gvisor-ligolo/pkg/sentry/socket" 33 "github.com/nicocha30/gvisor-ligolo/pkg/sentry/socket/unix/transport" 34 "github.com/nicocha30/gvisor-ligolo/pkg/sentry/vfs" 35 ) 36 37 // SCMCredentials represents a SCM_CREDENTIALS socket control message. 38 type SCMCredentials interface { 39 transport.CredentialsControlMessage 40 41 // Credentials returns properly namespaced values for the sender's pid, uid 42 // and gid. 43 Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) 44 } 45 46 // scmCredentials represents an SCM_CREDENTIALS socket control message. 47 // 48 // +stateify savable 49 type scmCredentials struct { 50 t *kernel.Task 51 kuid auth.KUID 52 kgid auth.KGID 53 } 54 55 // NewSCMCredentials creates a new SCM_CREDENTIALS socket control message 56 // representation. 57 func NewSCMCredentials(t *kernel.Task, cred linux.ControlMessageCredentials) (SCMCredentials, error) { 58 tcred := t.Credentials() 59 kuid, err := tcred.UseUID(auth.UID(cred.UID)) 60 if err != nil { 61 return nil, err 62 } 63 kgid, err := tcred.UseGID(auth.GID(cred.GID)) 64 if err != nil { 65 return nil, err 66 } 67 if kernel.ThreadID(cred.PID) != t.ThreadGroup().ID() && !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.PIDNamespace().UserNamespace()) { 68 return nil, linuxerr.EPERM 69 } 70 return &scmCredentials{t, kuid, kgid}, nil 71 } 72 73 // Equals implements transport.CredentialsControlMessage.Equals. 74 func (c *scmCredentials) Equals(oc transport.CredentialsControlMessage) bool { 75 if oc, _ := oc.(*scmCredentials); oc != nil && *c == *oc { 76 return true 77 } 78 return false 79 } 80 81 func putUint64(buf []byte, n uint64) []byte { 82 hostarch.ByteOrder.PutUint64(buf[len(buf):len(buf)+8], n) 83 return buf[:len(buf)+8] 84 } 85 86 func putUint32(buf []byte, n uint32) []byte { 87 hostarch.ByteOrder.PutUint32(buf[len(buf):len(buf)+4], n) 88 return buf[:len(buf)+4] 89 } 90 91 // putCmsg writes a control message header and as much data as will fit into 92 // the unused capacity of a buffer. 93 func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) { 94 space := bits.AlignDown(cap(buf)-len(buf), 4) 95 96 // We can't write to space that doesn't exist, so if we are going to align 97 // the available space, we must align down. 98 // 99 // align must be >= 4 and each data int32 is 4 bytes. The length of the 100 // header is already aligned, so if we align to the width of the data there 101 // are two cases: 102 // 1. The aligned length is less than the length of the header. The 103 // unaligned length was also less than the length of the header, so we 104 // can't write anything. 105 // 2. The aligned length is greater than or equal to the length of the 106 // header. We can write the header plus zero or more bytes of data. We can't 107 // write a partial int32, so the length of the message will be 108 // min(aligned length, header + data). 109 if space < linux.SizeOfControlMessageHeader { 110 flags |= linux.MSG_CTRUNC 111 return buf, flags 112 } 113 114 length := 4*len(data) + linux.SizeOfControlMessageHeader 115 if length > space { 116 length = space 117 } 118 buf = putUint64(buf, uint64(length)) 119 buf = putUint32(buf, linux.SOL_SOCKET) 120 buf = putUint32(buf, msgType) 121 for _, d := range data { 122 if len(buf)+4 > cap(buf) { 123 flags |= linux.MSG_CTRUNC 124 break 125 } 126 buf = putUint32(buf, uint32(d)) 127 } 128 return alignSlice(buf, align), flags 129 } 130 131 func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data marshal.Marshallable) []byte { 132 if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader { 133 return buf 134 } 135 ob := buf 136 137 buf = putUint64(buf, uint64(linux.SizeOfControlMessageHeader)) 138 buf = putUint32(buf, msgLevel) 139 buf = putUint32(buf, msgType) 140 141 hdrBuf := buf 142 buf = append(buf, marshal.Marshal(data)...) 143 144 // If the control message data brought us over capacity, omit it. 145 if cap(buf) != cap(ob) { 146 return hdrBuf 147 } 148 149 // Update control message length to include data. 150 putUint64(ob, uint64(len(buf)-len(ob))) 151 152 return alignSlice(buf, align) 153 } 154 155 // Credentials implements SCMCredentials.Credentials. 156 func (c *scmCredentials) Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) { 157 // "When a process's user and group IDs are passed over a UNIX domain 158 // socket to a process in a different user namespace (see the description 159 // of SCM_CREDENTIALS in unix(7)), they are translated into the 160 // corresponding values as per the receiving process's user and group ID 161 // mappings." - user_namespaces(7) 162 pid := t.PIDNamespace().IDOfTask(c.t) 163 uid := c.kuid.In(t.UserNamespace()).OrOverflow() 164 gid := c.kgid.In(t.UserNamespace()).OrOverflow() 165 166 return pid, uid, gid 167 } 168 169 // PackCredentials packs the credentials in the control message (or default 170 // credentials if none) into a buffer. 171 func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int) ([]byte, int) { 172 align := t.Arch().Width() 173 174 // Default credentials if none are available. 175 pid := kernel.ThreadID(0) 176 uid := auth.UID(auth.NobodyKUID) 177 gid := auth.GID(auth.NobodyKGID) 178 179 if creds != nil { 180 pid, uid, gid = creds.Credentials(t) 181 } 182 c := []int32{int32(pid), int32(uid), int32(gid)} 183 return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c) 184 } 185 186 // alignSlice extends a slice's length (up to the capacity) to align it. 187 func alignSlice(buf []byte, align uint) []byte { 188 aligned := bits.AlignUp(len(buf), align) 189 if aligned > cap(buf) { 190 // Linux allows unaligned data if there isn't room for alignment. 191 // Since there isn't room for alignment, there isn't room for any 192 // additional messages either. 193 return buf 194 } 195 return buf[:aligned] 196 } 197 198 // PackTimestamp packs a SO_TIMESTAMP socket control message. 199 func PackTimestamp(t *kernel.Task, timestamp time.Time, buf []byte) []byte { 200 timestampP := linux.NsecToTimeval(timestamp.UnixNano()) 201 return putCmsgStruct( 202 buf, 203 linux.SOL_SOCKET, 204 linux.SO_TIMESTAMP, 205 t.Arch().Width(), 206 ×tampP, 207 ) 208 } 209 210 // PackInq packs a TCP_INQ socket control message. 211 func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { 212 return putCmsgStruct( 213 buf, 214 linux.SOL_TCP, 215 linux.TCP_INQ, 216 t.Arch().Width(), 217 primitive.AllocateInt32(inq), 218 ) 219 } 220 221 // PackTOS packs an IP_TOS socket control message. 222 func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { 223 return putCmsgStruct( 224 buf, 225 linux.SOL_IP, 226 linux.IP_TOS, 227 t.Arch().Width(), 228 primitive.AllocateUint8(tos), 229 ) 230 } 231 232 // PackTClass packs an IPV6_TCLASS socket control message. 233 func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { 234 return putCmsgStruct( 235 buf, 236 linux.SOL_IPV6, 237 linux.IPV6_TCLASS, 238 t.Arch().Width(), 239 primitive.AllocateUint32(tClass), 240 ) 241 } 242 243 // PackTTL packs an IP_TTL socket control message. 244 func PackTTL(t *kernel.Task, ttl uint32, buf []byte) []byte { 245 return putCmsgStruct( 246 buf, 247 linux.SOL_IP, 248 linux.IP_TTL, 249 t.Arch().Width(), 250 primitive.AllocateUint32(ttl), 251 ) 252 } 253 254 // PackHopLimit packs an IPV6_HOPLIMIT socket control message. 255 func PackHopLimit(t *kernel.Task, hoplimit uint32, buf []byte) []byte { 256 return putCmsgStruct( 257 buf, 258 linux.SOL_IPV6, 259 linux.IPV6_HOPLIMIT, 260 t.Arch().Width(), 261 primitive.AllocateUint32(hoplimit), 262 ) 263 } 264 265 // PackIPPacketInfo packs an IP_PKTINFO socket control message. 266 func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte { 267 return putCmsgStruct( 268 buf, 269 linux.SOL_IP, 270 linux.IP_PKTINFO, 271 t.Arch().Width(), 272 packetInfo, 273 ) 274 } 275 276 // PackIPv6PacketInfo packs an IPV6_PKTINFO socket control message. 277 func PackIPv6PacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPv6PacketInfo, buf []byte) []byte { 278 return putCmsgStruct( 279 buf, 280 linux.SOL_IPV6, 281 linux.IPV6_PKTINFO, 282 t.Arch().Width(), 283 packetInfo, 284 ) 285 } 286 287 // PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. 288 func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte { 289 var level uint32 290 var optType uint32 291 switch originalDstAddress.(type) { 292 case *linux.SockAddrInet: 293 level = linux.SOL_IP 294 optType = linux.IP_RECVORIGDSTADDR 295 case *linux.SockAddrInet6: 296 level = linux.SOL_IPV6 297 optType = linux.IPV6_RECVORIGDSTADDR 298 default: 299 panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg") 300 } 301 return putCmsgStruct( 302 buf, level, optType, t.Arch().Width(), originalDstAddress) 303 } 304 305 // PackSockExtendedErr packs an IP*_RECVERR socket control message. 306 func PackSockExtendedErr(t *kernel.Task, sockErr linux.SockErrCMsg, buf []byte) []byte { 307 return putCmsgStruct( 308 buf, 309 sockErr.CMsgLevel(), 310 sockErr.CMsgType(), 311 t.Arch().Width(), 312 sockErr, 313 ) 314 } 315 316 // PackControlMessages packs control messages into the given buffer. 317 // 318 // We skip control messages specific to Unix domain sockets. 319 // 320 // Note that some control messages may be truncated if they do not fit under 321 // the capacity of buf. 322 func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte { 323 if cmsgs.IP.HasTimestamp { 324 buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf) 325 } 326 327 if cmsgs.IP.HasInq { 328 // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP. 329 buf = PackInq(t, cmsgs.IP.Inq, buf) 330 } 331 332 if cmsgs.IP.HasTOS { 333 buf = PackTOS(t, cmsgs.IP.TOS, buf) 334 } 335 336 if cmsgs.IP.HasTTL { 337 buf = PackTTL(t, cmsgs.IP.TTL, buf) 338 } 339 340 if cmsgs.IP.HasTClass { 341 buf = PackTClass(t, cmsgs.IP.TClass, buf) 342 } 343 344 if cmsgs.IP.HasHopLimit { 345 buf = PackHopLimit(t, cmsgs.IP.HopLimit, buf) 346 } 347 348 if cmsgs.IP.HasIPPacketInfo { 349 buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf) 350 } 351 352 if cmsgs.IP.HasIPv6PacketInfo { 353 buf = PackIPv6PacketInfo(t, &cmsgs.IP.IPv6PacketInfo, buf) 354 } 355 356 if cmsgs.IP.OriginalDstAddress != nil { 357 buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf) 358 } 359 360 if cmsgs.IP.SockErr != nil { 361 buf = PackSockExtendedErr(t, cmsgs.IP.SockErr, buf) 362 } 363 364 return buf 365 } 366 367 // cmsgSpace is equivalent to CMSG_SPACE in Linux. 368 func cmsgSpace(t *kernel.Task, dataLen int) int { 369 return linux.SizeOfControlMessageHeader + bits.AlignUp(dataLen, t.Arch().Width()) 370 } 371 372 // CmsgsSpace returns the number of bytes needed to fit the control messages 373 // represented in cmsgs. 374 func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { 375 space := 0 376 377 if cmsgs.IP.HasTimestamp { 378 space += cmsgSpace(t, linux.SizeOfTimeval) 379 } 380 381 if cmsgs.IP.HasInq { 382 space += cmsgSpace(t, linux.SizeOfControlMessageInq) 383 } 384 385 if cmsgs.IP.HasTOS { 386 space += cmsgSpace(t, linux.SizeOfControlMessageTOS) 387 } 388 389 if cmsgs.IP.HasTTL { 390 space += cmsgSpace(t, linux.SizeOfControlMessageTTL) 391 } 392 393 if cmsgs.IP.HasTClass { 394 space += cmsgSpace(t, linux.SizeOfControlMessageTClass) 395 } 396 397 if cmsgs.IP.HasHopLimit { 398 space += cmsgSpace(t, linux.SizeOfControlMessageHopLimit) 399 } 400 401 if cmsgs.IP.HasIPPacketInfo { 402 space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo) 403 } 404 405 if cmsgs.IP.HasIPv6PacketInfo { 406 space += cmsgSpace(t, linux.SizeOfControlMessageIPv6PacketInfo) 407 } 408 409 if cmsgs.IP.OriginalDstAddress != nil { 410 space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes()) 411 } 412 413 if cmsgs.IP.SockErr != nil { 414 space += cmsgSpace(t, cmsgs.IP.SockErr.SizeBytes()) 415 } 416 417 return space 418 } 419 420 // Parse parses a raw socket control message into portable objects. 421 // TODO(https://gvisor.dev/issue/7188): Parse is only called on raw cmsg that 422 // are used when sending a messages. We should fail with EINVAL when we find a 423 // non-sendable control messages (such as IP_RECVERR). And the function should 424 // be renamed to reflect that. 425 func Parse(t *kernel.Task, socketOrEndpoint any, buf []byte, width uint) (socket.ControlMessages, error) { 426 var ( 427 cmsgs socket.ControlMessages 428 fds []primitive.Int32 429 ) 430 431 for len(buf) > 0 { 432 if linux.SizeOfControlMessageHeader > len(buf) { 433 return cmsgs, linuxerr.EINVAL 434 } 435 436 var h linux.ControlMessageHeader 437 buf = h.UnmarshalUnsafe(buf) 438 439 if h.Length < uint64(linux.SizeOfControlMessageHeader) { 440 return socket.ControlMessages{}, linuxerr.EINVAL 441 } 442 443 length := int(h.Length) - linux.SizeOfControlMessageHeader 444 if length < 0 || length > len(buf) { 445 return socket.ControlMessages{}, linuxerr.EINVAL 446 } 447 448 switch h.Level { 449 case linux.SOL_SOCKET: 450 switch h.Type { 451 case linux.SCM_RIGHTS: 452 rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) 453 numRights := rightsSize / linux.SizeOfControlMessageRight 454 455 if len(fds)+numRights > linux.SCM_MAX_FD { 456 return socket.ControlMessages{}, linuxerr.EINVAL 457 } 458 459 curFDs := make([]primitive.Int32, numRights) 460 primitive.UnmarshalUnsafeInt32Slice(curFDs, buf[:rightsSize]) 461 fds = append(fds, curFDs...) 462 463 case linux.SCM_CREDENTIALS: 464 if length < linux.SizeOfControlMessageCredentials { 465 return socket.ControlMessages{}, linuxerr.EINVAL 466 } 467 468 var creds linux.ControlMessageCredentials 469 creds.UnmarshalUnsafe(buf) 470 scmCreds, err := NewSCMCredentials(t, creds) 471 if err != nil { 472 return socket.ControlMessages{}, err 473 } 474 cmsgs.Unix.Credentials = scmCreds 475 476 case linux.SO_TIMESTAMP: 477 if length < linux.SizeOfTimeval { 478 return socket.ControlMessages{}, linuxerr.EINVAL 479 } 480 var ts linux.Timeval 481 ts.UnmarshalUnsafe(buf) 482 cmsgs.IP.Timestamp = ts.ToTime() 483 cmsgs.IP.HasTimestamp = true 484 485 default: 486 // Unknown message type. 487 return socket.ControlMessages{}, linuxerr.EINVAL 488 } 489 case linux.SOL_IP: 490 switch h.Type { 491 case linux.IP_TOS: 492 if length < linux.SizeOfControlMessageTOS { 493 return socket.ControlMessages{}, linuxerr.EINVAL 494 } 495 cmsgs.IP.HasTOS = true 496 var tos primitive.Uint8 497 tos.UnmarshalUnsafe(buf) 498 cmsgs.IP.TOS = uint8(tos) 499 500 case linux.IP_TTL: 501 if length < linux.SizeOfControlMessageTTL { 502 return socket.ControlMessages{}, linuxerr.EINVAL 503 } 504 var ttl primitive.Uint32 505 ttl.UnmarshalUnsafe(buf) 506 if ttl == 0 || ttl > math.MaxUint8 { 507 return socket.ControlMessages{}, linuxerr.EINVAL 508 } 509 cmsgs.IP.TTL = uint32(ttl) 510 cmsgs.IP.HasTTL = true 511 512 case linux.IP_PKTINFO: 513 if length < linux.SizeOfControlMessageIPPacketInfo { 514 return socket.ControlMessages{}, linuxerr.EINVAL 515 } 516 517 cmsgs.IP.HasIPPacketInfo = true 518 var packetInfo linux.ControlMessageIPPacketInfo 519 packetInfo.UnmarshalUnsafe(buf) 520 cmsgs.IP.PacketInfo = packetInfo 521 522 case linux.IP_RECVORIGDSTADDR: 523 var addr linux.SockAddrInet 524 if length < addr.SizeBytes() { 525 return socket.ControlMessages{}, linuxerr.EINVAL 526 } 527 addr.UnmarshalUnsafe(buf) 528 cmsgs.IP.OriginalDstAddress = &addr 529 530 case linux.IP_RECVERR: 531 var errCmsg linux.SockErrCMsgIPv4 532 if length < errCmsg.SizeBytes() { 533 return socket.ControlMessages{}, linuxerr.EINVAL 534 } 535 536 errCmsg.UnmarshalBytes(buf) 537 cmsgs.IP.SockErr = &errCmsg 538 539 default: 540 return socket.ControlMessages{}, linuxerr.EINVAL 541 } 542 case linux.SOL_IPV6: 543 switch h.Type { 544 case linux.IPV6_TCLASS: 545 if length < linux.SizeOfControlMessageTClass { 546 return socket.ControlMessages{}, linuxerr.EINVAL 547 } 548 cmsgs.IP.HasTClass = true 549 var tclass primitive.Uint32 550 tclass.UnmarshalUnsafe(buf) 551 cmsgs.IP.TClass = uint32(tclass) 552 553 case linux.IPV6_PKTINFO: 554 if length < linux.SizeOfControlMessageIPv6PacketInfo { 555 return socket.ControlMessages{}, linuxerr.EINVAL 556 } 557 558 cmsgs.IP.HasIPv6PacketInfo = true 559 var packetInfo linux.ControlMessageIPv6PacketInfo 560 packetInfo.UnmarshalUnsafe(buf) 561 cmsgs.IP.IPv6PacketInfo = packetInfo 562 563 case linux.IPV6_HOPLIMIT: 564 if length < linux.SizeOfControlMessageHopLimit { 565 return socket.ControlMessages{}, linuxerr.EINVAL 566 } 567 var hoplimit primitive.Uint32 568 hoplimit.UnmarshalUnsafe(buf) 569 if hoplimit > math.MaxUint8 { 570 return socket.ControlMessages{}, linuxerr.EINVAL 571 } 572 cmsgs.IP.HasHopLimit = true 573 cmsgs.IP.HopLimit = uint32(hoplimit) 574 575 case linux.IPV6_RECVORIGDSTADDR: 576 var addr linux.SockAddrInet6 577 if length < addr.SizeBytes() { 578 return socket.ControlMessages{}, linuxerr.EINVAL 579 } 580 addr.UnmarshalUnsafe(buf) 581 cmsgs.IP.OriginalDstAddress = &addr 582 583 case linux.IPV6_RECVERR: 584 var errCmsg linux.SockErrCMsgIPv6 585 if length < errCmsg.SizeBytes() { 586 return socket.ControlMessages{}, linuxerr.EINVAL 587 } 588 589 errCmsg.UnmarshalBytes(buf) 590 cmsgs.IP.SockErr = &errCmsg 591 592 default: 593 return socket.ControlMessages{}, linuxerr.EINVAL 594 } 595 default: 596 return socket.ControlMessages{}, linuxerr.EINVAL 597 } 598 if shift := bits.AlignUp(length, width); shift > len(buf) { 599 buf = buf[:0] 600 } else { 601 buf = buf[shift:] 602 } 603 } 604 605 if cmsgs.Unix.Credentials == nil { 606 cmsgs.Unix.Credentials = makeCreds(t, socketOrEndpoint) 607 } 608 609 if len(fds) > 0 { 610 rights, err := NewSCMRights(t, fds) 611 if err != nil { 612 return socket.ControlMessages{}, err 613 } 614 cmsgs.Unix.Rights = rights 615 } 616 617 return cmsgs, nil 618 } 619 620 func makeCreds(t *kernel.Task, socketOrEndpoint any) SCMCredentials { 621 if t == nil || socketOrEndpoint == nil { 622 return nil 623 } 624 if cr, ok := socketOrEndpoint.(transport.Credentialer); ok && (cr.Passcred() || cr.ConnectedPasscred()) { 625 return MakeCreds(t) 626 } 627 return nil 628 } 629 630 // MakeCreds creates default SCMCredentials. 631 func MakeCreds(t *kernel.Task) SCMCredentials { 632 if t == nil { 633 return nil 634 } 635 tcred := t.Credentials() 636 return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID} 637 } 638 639 // New creates default control messages if needed. 640 func New(t *kernel.Task, socketOrEndpoint any) transport.ControlMessages { 641 return transport.ControlMessages{ 642 Credentials: makeCreds(t, socketOrEndpoint), 643 } 644 } 645 646 // SCMRights represents a SCM_RIGHTS socket control message. 647 // 648 // +stateify savable 649 type SCMRights interface { 650 transport.RightsControlMessage 651 652 // Files returns up to max RightsFiles. 653 // 654 // Returned files are consumed and ownership is transferred to the caller. 655 // Subsequent calls to Files will return the next files. 656 Files(ctx context.Context, max int) (rf RightsFiles, truncated bool) 657 } 658 659 // RightsFiles represents a SCM_RIGHTS socket control message. A reference 660 // is maintained for each vfs.FileDescription and is release either when an FD 661 // is created or when the Release method is called. 662 // 663 // +stateify savable 664 type RightsFiles []*vfs.FileDescription 665 666 // NewSCMRights creates a new SCM_RIGHTS socket control message 667 // representation using local sentry FDs. 668 func NewSCMRights(t *kernel.Task, fds []primitive.Int32) (SCMRights, error) { 669 files := make(RightsFiles, 0, len(fds)) 670 for _, fd := range fds { 671 file := t.GetFile(int32(fd)) 672 if file == nil { 673 files.Release(t) 674 return nil, linuxerr.EBADF 675 } 676 files = append(files, file) 677 } 678 return &files, nil 679 } 680 681 // Files implements SCMRights.Files. 682 func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) { 683 n := max 684 var trunc bool 685 if l := len(*fs); n > l { 686 n = l 687 } else if n < l { 688 trunc = true 689 } 690 rf := (*fs)[:n] 691 *fs = (*fs)[n:] 692 return rf, trunc 693 } 694 695 // Clone implements transport.RightsControlMessage.Clone. 696 func (fs *RightsFiles) Clone() transport.RightsControlMessage { 697 nfs := append(RightsFiles(nil), *fs...) 698 for _, nf := range nfs { 699 nf.IncRef() 700 } 701 return &nfs 702 } 703 704 // Release implements transport.RightsControlMessage.Release. 705 func (fs *RightsFiles) Release(ctx context.Context) { 706 for _, f := range *fs { 707 f.DecRef(ctx) 708 } 709 *fs = nil 710 } 711 712 // rightsFDs gets up to the specified maximum number of FDs. 713 func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32, bool) { 714 files, trunc := rights.Files(t, max) 715 fds := make([]int32, 0, len(files)) 716 for i := 0; i < max && len(files) > 0; i++ { 717 fd, err := t.NewFDFrom(0, files[0], kernel.FDFlags{ 718 CloseOnExec: cloexec, 719 }) 720 files[0].DecRef(t) 721 files = files[1:] 722 if err != nil { 723 t.Warningf("Error inserting FD: %v", err) 724 // This is what Linux does. 725 break 726 } 727 728 fds = append(fds, int32(fd)) 729 } 730 return fds, trunc 731 } 732 733 // PackRights packs as many FDs as will fit into the unused capacity of buf. 734 func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte, flags int) ([]byte, int) { 735 maxFDs := (cap(buf) - len(buf) - linux.SizeOfControlMessageHeader) / 4 736 // Linux does not return any FDs if none fit. 737 if maxFDs <= 0 { 738 flags |= linux.MSG_CTRUNC 739 return buf, flags 740 } 741 fds, trunc := rightsFDs(t, rights, cloexec, maxFDs) 742 if trunc { 743 flags |= linux.MSG_CTRUNC 744 } 745 align := t.Arch().Width() 746 return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds) 747 }