github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/fdsrv/fdsrv.go (about)

     1  // Copyright 2022 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Serves a file descriptor over an AF_UNIX socket when presented with a nonce.
     6  //
     7  // You must pass the socket path and nonce to the client via some out-of-band
     8  // mechanism, such as gRPC or a bash script.
     9  //
    10  // Notes:
    11  // - Uses the unix domain socket abstract namespace
    12  // - Picks its own path in the abstract namespace for the socket.
    13  // - Shared FDs are essentially duped, and they point to the same struct file:
    14  // they share offsets and whatnot.
    15  //
    16  // Options:
// - WithServeOnce: serve once, then shut down (default: serve forever)
// - WithTimeout: cancel after a timeout (default: no timeout)
    19  //
    20  // Usage Server:
    21  //
//	fds, err := NewServer(fdToShare, "some_nonce", WithServeOnce())
//	path := fds.UDSPath()
//
//	// Pass path and some_nonce to the client via an out-of-band mechanism.
//
//	fds.Serve() // Blocks until the server is done.
//	fds.Close()
//
// Usage Client:
//
//	sfd, err := GetSharedFD(path, "some_nonce")
    33  package fdsrv
    34  
import (
	"crypto/subtle"
	"errors"
	"io"
	"net"
	"os"
	"syscall"
	"time"
)
    43  
    44  var (
    45  	ErrTruncatedWrite   = errors.New("truncated write")
    46  	ErrEmptyNonce       = errors.New("nonce must not be empty")
    47  	ErrMissingSCM       = errors.New("missing socket control message")
    48  	ErrNotOneUnixRights = errors.New("expected exactly one unix rights")
    49  )
    50  
    51  type Server struct {
    52  	dupedFD   int
    53  	nonce     string
    54  	listener  *net.UnixListener
    55  	timeout   time.Duration
    56  	serveOnce bool
    57  }
    58  
    59  // Serves the fd, returns true if successful, err for a server error.
    60  // "false, nil" means the client was wrong, not the server.
    61  func (fds *Server) handleConnection(uc *net.UnixConn) (bool, error) {
    62  	defer uc.Close()
    63  
    64  	buf := make([]byte, 4096)
    65  	n, err := uc.Read(buf)
    66  	if err != nil {
    67  		return false, err
    68  	}
    69  	query := string(buf[:n])
    70  	if query != fds.nonce {
    71  		io.WriteString(uc, "BAD NONCE")
    72  		return false, nil
    73  	}
    74  	oob := syscall.UnixRights(fds.dupedFD)
    75  	good := []byte("GOOD NONCE")
    76  	goodn, oobn, err := uc.WriteMsgUnix(good, oob, nil)
    77  	if err != nil {
    78  		return false, err
    79  	}
    80  	if goodn != len(good) || oobn != len(oob) {
    81  		return false, ErrTruncatedWrite
    82  	}
    83  	return true, nil
    84  }
    85  
    86  // NewServer creates a server.  Close() it when you're done.
    87  func NewServer(fd int, nonce string, options ...func(*Server) error) (*Server, error) {
    88  	var err error
    89  	fds := &Server{}
    90  
    91  	if len(nonce) == 0 {
    92  		return nil, ErrEmptyNonce
    93  	}
    94  	fds.nonce = nonce
    95  
    96  	for _, op := range options {
    97  		if err := op(fds); err != nil {
    98  			return nil, err
    99  		}
   100  	}
   101  
   102  	// An empty addr tells Linux to "autobind" to an available path in the
   103  	// abstract unix domain socket namespace
   104  	ua, err := net.ResolveUnixAddr("unix", "")
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	fds.listener, err = net.ListenUnix("unix", ua)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  
   113  	// Caller could close the file while we are running.  Keep our own copy.
   114  	fds.dupedFD, err = syscall.Dup(int(fd))
   115  	if err != nil {
   116  		fds.listener.Close()
   117  		return nil, err
   118  	}
   119  
   120  	return fds, nil
   121  }
   122  
   123  // WithTimeOut adds a timeout option to NewServer
   124  func WithTimeout(timeout time.Duration) func(*Server) error {
   125  	return func(fds *Server) error {
   126  		fds.timeout = timeout
   127  		return nil
   128  	}
   129  }
   130  
   131  // WithServeOnce sets the "serve once and exit" option to NewServer
   132  func WithServeOnce() func(*Server) error {
   133  	return func(fds *Server) error {
   134  		fds.serveOnce = true
   135  		return nil
   136  	}
   137  }
   138  
   139  // UDSPath returns the Unix Domain Socket path the server is listening on
   140  func (fds *Server) UDSPath() string {
   141  	return fds.listener.Addr().String()
   142  }
   143  
   144  // Close closes the server
   145  func (fds *Server) Close() {
   146  	fds.listener.Close()
   147  	syscall.Close(fds.dupedFD)
   148  }
   149  
   150  // Serve serves the FD
   151  func (fds *Server) Serve() error {
   152  	var deadline time.Time
   153  	if fds.timeout != 0 {
   154  		deadline = time.Now().Add(fds.timeout)
   155  	}
   156  	fds.listener.SetDeadline(deadline)
   157  	for {
   158  		conn, err := fds.listener.AcceptUnix()
   159  		// Clean up after ourselves, since we are initiating our own
   160  		// closure through the timeout.
   161  		if os.IsTimeout(err) {
   162  			fds.Close()
   163  			return err
   164  		} else if errors.Is(err, net.ErrClosed) {
   165  			return nil
   166  		} else if err != nil {
   167  			return err
   168  		}
   169  		conn.SetDeadline(deadline)
   170  		succeeded, err := fds.handleConnection(conn)
   171  		if err != nil {
   172  			return err
   173  		}
   174  		if succeeded && fds.serveOnce {
   175  			break
   176  		}
   177  	}
   178  	return nil
   179  }
   180  
   181  // GetSharedFD gets an FD served at udsPath with nonce
   182  func GetSharedFD(udsPath, nonce string) (int, error) {
   183  	// If you don't send at least a byte, the server won't recvmsg.  This
   184  	// is a Linux UDS SOCK_STREAM thing.
   185  	if len(nonce) == 0 {
   186  		return 0, ErrEmptyNonce
   187  	}
   188  
   189  	ua, err := net.ResolveUnixAddr("unix", udsPath)
   190  	if err != nil {
   191  		return 0, err
   192  	}
   193  	uc, err := net.DialUnix("unix", nil, ua)
   194  	if err != nil {
   195  		return 0, err
   196  	}
   197  
   198  	n, err := uc.Write([]byte(nonce))
   199  	if err != nil {
   200  		return 0, err
   201  	}
   202  	if n != len(nonce) {
   203  		return 0, ErrTruncatedWrite
   204  	}
   205  
   206  	oob := make([]byte, 1024)
   207  	_, oobn, _, _, err := uc.ReadMsgUnix(nil, oob)
   208  	if err != nil {
   209  		return 0, err
   210  	}
   211  	scm, err := syscall.ParseSocketControlMessage(oob[:oobn])
   212  	if err != nil {
   213  		return 0, err
   214  	}
   215  	if len(scm) != 1 {
   216  		return 0, ErrMissingSCM
   217  	}
   218  	urs, err := syscall.ParseUnixRights(&scm[0])
   219  	if err != nil {
   220  		return 0, err
   221  	}
   222  	if len(urs) != 1 {
   223  		return 0, ErrNotOneUnixRights
   224  	}
   225  	return urs[0], nil
   226  }