github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/fdchannel/fdchannel_unsafe.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
    16  // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
    17  
    18  // Package fdchannel implements passing file descriptors between processes over
    19  // Unix domain sockets.
    20  package fdchannel
    21  
    22  import (
    23  	"fmt"
    24  	"unsafe"
    25  
    26  	"golang.org/x/sys/unix"
    27  )
    28  
    29  // int32 is the real type of a file descriptor.
    30  const sizeofInt32 = int(unsafe.Sizeof(int32(0)))
    31  
    32  // NewConnectedSockets returns a pair of file descriptors, owned by the caller,
    33  // representing connected sockets that may be passed to separate calls to
    34  // NewEndpoint to create connected Endpoints.
    35  func NewConnectedSockets() ([2]int, error) {
    36  	return unix.Socketpair(unix.AF_UNIX, unix.SOCK_SEQPACKET|unix.SOCK_CLOEXEC, 0)
    37  }
    38  
// Endpoint sends file descriptors to, and receives them from, another
// connected Endpoint over an AF_UNIX SOCK_SEQPACKET socket.
//
// An Endpoint must be initialized with Init (or obtained from NewEndpoint)
// before use. Destroy closes its socket.
//
// Endpoint is not copyable or movable by value: a copy would share the socket
// and the Init-allocated control buffer with the original.
type Endpoint struct {
	sockfd int32         // socket passed to Init; set to -1 by Destroy
	msghdr unix.Msghdr   // header for sendmsg/recvmsg; has no iovecs, Control points at the Init buffer
	cmsg   *unix.Cmsghdr // start of the control buffer; followed by sizeofInt32 bytes of data
}
    48  
    49  // Init must be called on zero-value Endpoints before first use. sockfd must be
    50  // a blocking AF_UNIX SOCK_SEQPACKET socket.
    51  func (ep *Endpoint) Init(sockfd int) {
    52  	// "Datagram sockets in various domains (e.g., the UNIX and Internet
    53  	// domains) permit zero-length datagrams." - recv(2). Experimentally,
    54  	// sendmsg+recvmsg for a zero-length datagram is slightly faster than
    55  	// sendmsg+recvmsg for a single byte over a stream socket.
    56  	cmsgSlice := make([]byte, unix.CmsgSpace(sizeofInt32))
    57  	ep.sockfd = int32(sockfd)
    58  	ep.msghdr.Control = (*byte)(unsafe.Pointer(&cmsgSlice[0]))
    59  	ep.cmsg = (*unix.Cmsghdr)(unsafe.Pointer(&cmsgSlice[0]))
    60  	// ep.msghdr.Controllen and ep.cmsg.* are mutated by recvmsg(2), so they're
    61  	// set before calling sendmsg/recvmsg.
    62  }
    63  
    64  // NewEndpoint is a convenience function that returns an initialized Endpoint
    65  // allocated on the heap.
    66  func NewEndpoint(sockfd int) *Endpoint {
    67  	ep := &Endpoint{}
    68  	ep.Init(sockfd)
    69  	return ep
    70  }
    71  
    72  // Destroy releases resources owned by ep. No other Endpoint methods may be
    73  // called after Destroy.
    74  func (ep *Endpoint) Destroy() {
    75  	unix.Close(int(ep.sockfd))
    76  	ep.sockfd = -1
    77  }
    78  
    79  // Shutdown causes concurrent and future calls to ep.SendFD(), ep.RecvFD(), and
    80  // ep.RecvFDNonblock(), as well as the same calls in the connected Endpoint, to
    81  // unblock and return errors. It does not wait for concurrent calls to return.
    82  //
    83  // Shutdown is the only Endpoint method that may be called concurrently with
    84  // other methods.
    85  func (ep *Endpoint) Shutdown() {
    86  	unix.Shutdown(int(ep.sockfd), unix.SHUT_RDWR)
    87  }
    88  
    89  // SendFD sends the open file description represented by the given file
    90  // descriptor to the connected Endpoint.
    91  func (ep *Endpoint) SendFD(fd int) error {
    92  	cmsgLen := unix.CmsgLen(sizeofInt32)
    93  	ep.cmsg.Level = unix.SOL_SOCKET
    94  	ep.cmsg.Type = unix.SCM_RIGHTS
    95  	ep.cmsg.SetLen(cmsgLen)
    96  	*ep.cmsgData() = int32(fd)
    97  	ep.msghdr.SetControllen(cmsgLen)
    98  	_, _, e := unix.Syscall(unix.SYS_SENDMSG, uintptr(ep.sockfd), uintptr(unsafe.Pointer(&ep.msghdr)), 0)
    99  	if e != 0 {
   100  		return e
   101  	}
   102  	return nil
   103  }
   104  
   105  // RecvFD receives an open file description from the connected Endpoint and
   106  // returns a file descriptor representing it, owned by the caller.
   107  func (ep *Endpoint) RecvFD() (int, error) {
   108  	return ep.recvFD(false)
   109  }
   110  
   111  // RecvFDNonblock receives an open file description from the connected Endpoint
   112  // and returns a file descriptor representing it, owned by the caller. If there
   113  // are no pending receivable open file descriptions, RecvFDNonblock returns
   114  // (<unspecified>, EAGAIN or EWOULDBLOCK).
   115  func (ep *Endpoint) RecvFDNonblock() (int, error) {
   116  	return ep.recvFD(true)
   117  }
   118  
   119  func (ep *Endpoint) recvFD(nonblock bool) (int, error) {
   120  	cmsgLen := unix.CmsgLen(sizeofInt32)
   121  	ep.msghdr.SetControllen(cmsgLen)
   122  	var e unix.Errno
   123  	if nonblock {
   124  		_, _, e = unix.RawSyscall(unix.SYS_RECVMSG, uintptr(ep.sockfd), uintptr(unsafe.Pointer(&ep.msghdr)), unix.MSG_TRUNC|unix.MSG_DONTWAIT)
   125  	} else {
   126  		_, _, e = unix.Syscall(unix.SYS_RECVMSG, uintptr(ep.sockfd), uintptr(unsafe.Pointer(&ep.msghdr)), unix.MSG_TRUNC)
   127  	}
   128  	if e != 0 {
   129  		return -1, e
   130  	}
   131  	if int(ep.msghdr.Controllen) != cmsgLen {
   132  		return -1, fmt.Errorf("received control message has incorrect length: got %d, wanted %d", ep.msghdr.Controllen, cmsgLen)
   133  	}
   134  	if ep.cmsg.Level != unix.SOL_SOCKET || ep.cmsg.Type != unix.SCM_RIGHTS {
   135  		return -1, fmt.Errorf("received control message has incorrect (level, type): got (%v, %v), wanted (%v, %v)", ep.cmsg.Level, ep.cmsg.Type, unix.SOL_SOCKET, unix.SCM_RIGHTS)
   136  	}
   137  	return int(*ep.cmsgData()), nil
   138  }
   139  
   140  func (ep *Endpoint) cmsgData() *int32 {
   141  	// unix.CmsgLen(0) == unix.cmsgAlignOf(unix.SizeofCmsghdr)
   142  	return (*int32)(unsafe.Pointer(uintptr(unsafe.Pointer(ep.cmsg)) + uintptr(unix.CmsgLen(0))))
   143  }