// Origin: github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/fdsrv/fdsrv.go

// Copyright 2022 the u-root Authors. All rights reserved
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package fdsrv serves a file descriptor over an AF_UNIX socket when presented
// with a nonce.
//
// You must pass the socket path and nonce to the client via some out-of-band
// mechanism, such as gRPC or a bash script.
//
// Notes:
//   - Uses the unix domain socket abstract namespace.
//   - Picks its own path in the abstract namespace for the socket.
//   - Shared FDs are essentially duped, and they point to the same struct file:
//     they share offsets and whatnot.
//
// Options:
//   - WithServeOnce: Serve returns after the first successful hand-off
//     (default is to serve forever)
//   - WithTimeout: Serve gives up once the timeout has elapsed (default none)
//
// Usage Server:
//
//	fds, err := NewServer(fd_to_share, "some_nonce", WithServeOnce())
//	path := fds.UDSPath()
//
//	// Pass path and some_nonce to the client via an out of band mechanism
//
//	fds.Serve() // Blocks until the server is done
//	fds.Close()
//
// Usage Client:
//
//	sfd, err := GetSharedFD("uds_path", "some_nonce")
package fdsrv

import (
	"errors"
	"io"
	"net"
	"os"
	"syscall"
	"time"
)

var (
	// ErrTruncatedWrite reports that a socket write transferred fewer bytes
	// than were requested.
	ErrTruncatedWrite = errors.New("truncated write")
	// ErrEmptyNonce is returned by NewServer and GetSharedFD when the nonce
	// is the empty string.
	ErrEmptyNonce = errors.New("nonce must not be empty")
	// ErrMissingSCM is returned by GetSharedFD when the reply did not contain
	// exactly one socket control message.
	ErrMissingSCM = errors.New("missing socket control message")
	// ErrNotOneUnixRights is returned by GetSharedFD when the control message
	// did not carry exactly one file descriptor.
	ErrNotOneUnixRights = errors.New("expected exactly one unix rights")
)

// Server hands a file descriptor to clients that connect to its unix domain
// socket and present the expected nonce. Create one with NewServer.
type Server struct {
	dupedFD   int               // the server's own duplicate of the fd being shared
	nonce     string            // secret a client must send to receive the fd
	listener  *net.UnixListener // socket on which clients are accepted
	timeout   time.Duration     // overall Serve timeout; zero means no timeout
	serveOnce bool              // when true, Serve returns after the first success
}

// handleConnection serves the fd to a single client. It returns true if the
// fd was handed over, and a non-nil err for a server error.
// "false, nil" means the client was wrong, not the server.
61 func (fds *Server) handleConnection(uc *net.UnixConn) (bool, error) { 62 defer uc.Close() 63 64 buf := make([]byte, 4096) 65 n, err := uc.Read(buf) 66 if err != nil { 67 return false, err 68 } 69 query := string(buf[:n]) 70 if query != fds.nonce { 71 io.WriteString(uc, "BAD NONCE") 72 return false, nil 73 } 74 oob := syscall.UnixRights(fds.dupedFD) 75 good := []byte("GOOD NONCE") 76 goodn, oobn, err := uc.WriteMsgUnix(good, oob, nil) 77 if err != nil { 78 return false, err 79 } 80 if goodn != len(good) || oobn != len(oob) { 81 return false, ErrTruncatedWrite 82 } 83 return true, nil 84 } 85 86 // NewServer creates a server. Close() it when you're done. 87 func NewServer(fd int, nonce string, options ...func(*Server) error) (*Server, error) { 88 var err error 89 fds := &Server{} 90 91 if len(nonce) == 0 { 92 return nil, ErrEmptyNonce 93 } 94 fds.nonce = nonce 95 96 for _, op := range options { 97 if err := op(fds); err != nil { 98 return nil, err 99 } 100 } 101 102 // An empty addr tells Linux to "autobind" to an available path in the 103 // abstract unix domain socket namespace 104 ua, err := net.ResolveUnixAddr("unix", "") 105 if err != nil { 106 return nil, err 107 } 108 fds.listener, err = net.ListenUnix("unix", ua) 109 if err != nil { 110 return nil, err 111 } 112 113 // Caller could close the file while we are running. Keep our own copy. 
114 fds.dupedFD, err = syscall.Dup(int(fd)) 115 if err != nil { 116 fds.listener.Close() 117 return nil, err 118 } 119 120 return fds, nil 121 } 122 123 // WithTimeOut adds a timeout option to NewServer 124 func WithTimeout(timeout time.Duration) func(*Server) error { 125 return func(fds *Server) error { 126 fds.timeout = timeout 127 return nil 128 } 129 } 130 131 // WithServeOnce sets the "serve once and exit" option to NewServer 132 func WithServeOnce() func(*Server) error { 133 return func(fds *Server) error { 134 fds.serveOnce = true 135 return nil 136 } 137 } 138 139 // UDSPath returns the Unix Domain Socket path the server is listening on 140 func (fds *Server) UDSPath() string { 141 return fds.listener.Addr().String() 142 } 143 144 // Close closes the server 145 func (fds *Server) Close() { 146 fds.listener.Close() 147 syscall.Close(fds.dupedFD) 148 } 149 150 // Serve serves the FD 151 func (fds *Server) Serve() error { 152 var deadline time.Time 153 if fds.timeout != 0 { 154 deadline = time.Now().Add(fds.timeout) 155 } 156 fds.listener.SetDeadline(deadline) 157 for { 158 conn, err := fds.listener.AcceptUnix() 159 // Clean up after ourselves, since we are initiating our own 160 // closure through the timeout. 161 if os.IsTimeout(err) { 162 fds.Close() 163 return err 164 } else if errors.Is(err, net.ErrClosed) { 165 return nil 166 } else if err != nil { 167 return err 168 } 169 conn.SetDeadline(deadline) 170 succeeded, err := fds.handleConnection(conn) 171 if err != nil { 172 return err 173 } 174 if succeeded && fds.serveOnce { 175 break 176 } 177 } 178 return nil 179 } 180 181 // GetSharedFD gets an FD served at udsPath with nonce 182 func GetSharedFD(udsPath, nonce string) (int, error) { 183 // If you don't send at least a byte, the server won't recvmsg. This 184 // is a Linux UDS SOCK_STREAM thing. 
185 if len(nonce) == 0 { 186 return 0, ErrEmptyNonce 187 } 188 189 ua, err := net.ResolveUnixAddr("unix", udsPath) 190 if err != nil { 191 return 0, err 192 } 193 uc, err := net.DialUnix("unix", nil, ua) 194 if err != nil { 195 return 0, err 196 } 197 198 n, err := uc.Write([]byte(nonce)) 199 if err != nil { 200 return 0, err 201 } 202 if n != len(nonce) { 203 return 0, ErrTruncatedWrite 204 } 205 206 oob := make([]byte, 1024) 207 _, oobn, _, _, err := uc.ReadMsgUnix(nil, oob) 208 if err != nil { 209 return 0, err 210 } 211 scm, err := syscall.ParseSocketControlMessage(oob[:oobn]) 212 if err != nil { 213 return 0, err 214 } 215 if len(scm) != 1 { 216 return 0, ErrMissingSCM 217 } 218 urs, err := syscall.ParseUnixRights(&scm[0]) 219 if err != nil { 220 return 0, err 221 } 222 if len(urs) != 1 { 223 return 0, ErrNotOneUnixRights 224 } 225 return urs[0], nil 226 }