github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/fdchannel/fdchannel_unsafe.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
    16  
    17  // Package fdchannel implements passing file descriptors between processes over
    18  // Unix domain sockets.
    19  package fdchannel
    20  
    21  import (
    22  	"fmt"
    23  	"unsafe"
    24  
    25  	"golang.org/x/sys/unix"
    26  	"github.com/SagerNet/gvisor/pkg/gohacks"
    27  )
    28  
    29  // int32 is the real type of a file descriptor.
    30  const sizeofInt32 = int(unsafe.Sizeof(int32(0)))
    31  
    32  // NewConnectedSockets returns a pair of file descriptors, owned by the caller,
    33  // representing connected sockets that may be passed to separate calls to
    34  // NewEndpoint to create connected Endpoints.
    35  func NewConnectedSockets() ([2]int, error) {
    36  	return unix.Socketpair(unix.AF_UNIX, unix.SOCK_SEQPACKET|unix.SOCK_CLOEXEC, 0)
    37  }
    38  
    39  // Endpoint sends file descriptors to, and receives them from, another
    40  // connected Endpoint.
    41  //
    42  // Endpoint is not copyable or movable by value.
    43  type Endpoint struct {
    44  	sockfd int32
    45  	msghdr unix.Msghdr
    46  	cmsg   *unix.Cmsghdr // followed by sizeofInt32 bytes of data
    47  }
    48  
    49  // Init must be called on zero-value Endpoints before first use. sockfd must be
    50  // a blocking AF_UNIX SOCK_SEQPACKET socket.
    51  func (ep *Endpoint) Init(sockfd int) {
    52  	// "Datagram sockets in various domains (e.g., the UNIX and Internet
    53  	// domains) permit zero-length datagrams." - recv(2). Experimentally,
    54  	// sendmsg+recvmsg for a zero-length datagram is slightly faster than
    55  	// sendmsg+recvmsg for a single byte over a stream socket.
    56  	cmsgSlice := make([]byte, unix.CmsgSpace(sizeofInt32))
    57  	cmsgSliceHdr := (*gohacks.SliceHeader)(unsafe.Pointer(&cmsgSlice))
    58  	ep.sockfd = int32(sockfd)
    59  	ep.msghdr.Control = (*byte)(cmsgSliceHdr.Data)
    60  	ep.cmsg = (*unix.Cmsghdr)(cmsgSliceHdr.Data)
    61  	// ep.msghdr.Controllen and ep.cmsg.* are mutated by recvmsg(2), so they're
    62  	// set before calling sendmsg/recvmsg.
    63  }
    64  
    65  // NewEndpoint is a convenience function that returns an initialized Endpoint
    66  // allocated on the heap.
    67  func NewEndpoint(sockfd int) *Endpoint {
    68  	ep := &Endpoint{}
    69  	ep.Init(sockfd)
    70  	return ep
    71  }
    72  
    73  // Destroy releases resources owned by ep. No other Endpoint methods may be
    74  // called after Destroy.
    75  func (ep *Endpoint) Destroy() {
    76  	unix.Close(int(ep.sockfd))
    77  	ep.sockfd = -1
    78  }
    79  
    80  // Shutdown causes concurrent and future calls to ep.SendFD(), ep.RecvFD(), and
    81  // ep.RecvFDNonblock(), as well as the same calls in the connected Endpoint, to
    82  // unblock and return errors. It does not wait for concurrent calls to return.
    83  //
    84  // Shutdown is the only Endpoint method that may be called concurrently with
    85  // other methods.
    86  func (ep *Endpoint) Shutdown() {
    87  	unix.Shutdown(int(ep.sockfd), unix.SHUT_RDWR)
    88  }
    89  
    90  // SendFD sends the open file description represented by the given file
    91  // descriptor to the connected Endpoint.
    92  func (ep *Endpoint) SendFD(fd int) error {
    93  	cmsgLen := unix.CmsgLen(sizeofInt32)
    94  	ep.cmsg.Level = unix.SOL_SOCKET
    95  	ep.cmsg.Type = unix.SCM_RIGHTS
    96  	ep.cmsg.SetLen(cmsgLen)
    97  	*ep.cmsgData() = int32(fd)
    98  	ep.msghdr.SetControllen(cmsgLen)
    99  	_, _, e := unix.Syscall(unix.SYS_SENDMSG, uintptr(ep.sockfd), uintptr(unsafe.Pointer(&ep.msghdr)), 0)
   100  	if e != 0 {
   101  		return e
   102  	}
   103  	return nil
   104  }
   105  
   106  // RecvFD receives an open file description from the connected Endpoint and
   107  // returns a file descriptor representing it, owned by the caller.
   108  func (ep *Endpoint) RecvFD() (int, error) {
   109  	return ep.recvFD(false)
   110  }
   111  
   112  // RecvFDNonblock receives an open file description from the connected Endpoint
   113  // and returns a file descriptor representing it, owned by the caller. If there
   114  // are no pending receivable open file descriptions, RecvFDNonblock returns
   115  // (<unspecified>, EAGAIN or EWOULDBLOCK).
   116  func (ep *Endpoint) RecvFDNonblock() (int, error) {
   117  	return ep.recvFD(true)
   118  }
   119  
   120  func (ep *Endpoint) recvFD(nonblock bool) (int, error) {
   121  	cmsgLen := unix.CmsgLen(sizeofInt32)
   122  	ep.msghdr.SetControllen(cmsgLen)
   123  	var e unix.Errno
   124  	if nonblock {
   125  		_, _, e = unix.RawSyscall(unix.SYS_RECVMSG, uintptr(ep.sockfd), uintptr(unsafe.Pointer(&ep.msghdr)), unix.MSG_TRUNC|unix.MSG_DONTWAIT)
   126  	} else {
   127  		_, _, e = unix.Syscall(unix.SYS_RECVMSG, uintptr(ep.sockfd), uintptr(unsafe.Pointer(&ep.msghdr)), unix.MSG_TRUNC)
   128  	}
   129  	if e != 0 {
   130  		return -1, e
   131  	}
   132  	if int(ep.msghdr.Controllen) != cmsgLen {
   133  		return -1, fmt.Errorf("received control message has incorrect length: got %d, wanted %d", ep.msghdr.Controllen, cmsgLen)
   134  	}
   135  	if ep.cmsg.Level != unix.SOL_SOCKET || ep.cmsg.Type != unix.SCM_RIGHTS {
   136  		return -1, fmt.Errorf("received control message has incorrect (level, type): got (%v, %v), wanted (%v, %v)", ep.cmsg.Level, ep.cmsg.Type, unix.SOL_SOCKET, unix.SCM_RIGHTS)
   137  	}
   138  	return int(*ep.cmsgData()), nil
   139  }
   140  
   141  func (ep *Endpoint) cmsgData() *int32 {
   142  	// unix.CmsgLen(0) == unix.cmsgAlignOf(unix.SizeofCmsghdr)
   143  	return (*int32)(unsafe.Pointer(uintptr(unsafe.Pointer(ep.cmsg)) + uintptr(unix.CmsgLen(0))))
   144  }