github.com/khulnasoft/cli@v0.0.0-20240402070845-01bcad7beefa/cli/connhelper/commandconn/commandconn.go (about) 1 // Package commandconn provides a net.Conn implementation that can be used for 2 // proxying (or emulating) stream via a custom command. 3 // 4 // For example, to provide an http.Client that can connect to a Docker daemon 5 // running in a Docker container ("DIND"): 6 // 7 // httpClient := &http.Client{ 8 // Transport: &http.Transport{ 9 // DialContext: func(ctx context.Context, _network, _addr string) (net.Conn, error) { 10 // return commandconn.New(ctx, "docker", "exec", "-it", containerID, "docker", "system", "dial-stdio") 11 // }, 12 // }, 13 // } 14 package commandconn 15 16 import ( 17 "bytes" 18 "context" 19 "fmt" 20 "io" 21 "net" 22 "os" 23 "os/exec" 24 "runtime" 25 "strings" 26 "sync" 27 "sync/atomic" 28 "syscall" 29 "time" 30 31 "github.com/pkg/errors" 32 "github.com/sirupsen/logrus" 33 ) 34 35 // New returns net.Conn 36 func New(_ context.Context, cmd string, args ...string) (net.Conn, error) { 37 var ( 38 c commandConn 39 err error 40 ) 41 c.cmd = exec.Command(cmd, args...) 42 // we assume that args never contains sensitive information 43 logrus.Debugf("commandconn: starting %s with %v", cmd, args) 44 c.cmd.Env = os.Environ() 45 c.cmd.SysProcAttr = &syscall.SysProcAttr{} 46 setPdeathsig(c.cmd) 47 createSession(c.cmd) 48 c.stdin, err = c.cmd.StdinPipe() 49 if err != nil { 50 return nil, err 51 } 52 c.stdout, err = c.cmd.StdoutPipe() 53 if err != nil { 54 return nil, err 55 } 56 c.cmd.Stderr = &stderrWriter{ 57 stderrMu: &c.stderrMu, 58 stderr: &c.stderr, 59 debugPrefix: fmt.Sprintf("commandconn (%s):", cmd), 60 } 61 c.localAddr = dummyAddr{network: "dummy", s: "dummy-0"} 62 c.remoteAddr = dummyAddr{network: "dummy", s: "dummy-1"} 63 return &c, c.cmd.Start() 64 } 65 66 // commandConn implements net.Conn 67 type commandConn struct { 68 cmdMutex sync.Mutex // for cmd, cmdWaitErr 69 cmd *exec.Cmd 70 cmdWaitErr error 71 cmdExited atomic.Bool 72 stdin io.WriteCloser 73 stdout io.ReadCloser 74 stderrMu sync.Mutex // for stderr 75 stderr bytes.Buffer 76 stdinClosed atomic.Bool 77 stdoutClosed atomic.Bool 78 closing atomic.Bool 79 localAddr net.Addr 80 remoteAddr net.Addr 81 } 82 83 // kill terminates the process. On Windows it kills the process directly, 84 // whereas on other platforms, a SIGTERM is sent, before forcefully terminating 85 // the process after 3 seconds. 86 func (c *commandConn) kill() { 87 if c.cmdExited.Load() { 88 return 89 } 90 c.cmdMutex.Lock() 91 var werr error 92 if runtime.GOOS != "windows" { 93 werrCh := make(chan error) 94 go func() { werrCh <- c.cmd.Wait() }() 95 _ = c.cmd.Process.Signal(syscall.SIGTERM) 96 select { 97 case werr = <-werrCh: 98 case <-time.After(3 * time.Second): 99 _ = c.cmd.Process.Kill() 100 werr = <-werrCh 101 } 102 } else { 103 _ = c.cmd.Process.Kill() 104 werr = c.cmd.Wait() 105 } 106 c.cmdWaitErr = werr 107 c.cmdMutex.Unlock() 108 c.cmdExited.Store(true) 109 } 110 111 // handleEOF handles io.EOF errors while reading or writing from the underlying 112 // command pipes. 113 // 114 // When we've received an EOF we expect that the command will 115 // be terminated soon. As such, we call Wait() on the command 116 // and return EOF or the error depending on whether the command 117 // exited with an error. 118 // 119 // If Wait() does not return within 10s, an error is returned 120 func (c *commandConn) handleEOF(err error) error { 121 if err != io.EOF { 122 return err 123 } 124 125 c.cmdMutex.Lock() 126 defer c.cmdMutex.Unlock() 127 128 var werr error 129 if c.cmdExited.Load() { 130 werr = c.cmdWaitErr 131 } else { 132 werrCh := make(chan error) 133 go func() { werrCh <- c.cmd.Wait() }() 134 select { 135 case werr = <-werrCh: 136 c.cmdWaitErr = werr 137 c.cmdExited.Store(true) 138 case <-time.After(10 * time.Second): 139 c.stderrMu.Lock() 140 stderr := c.stderr.String() 141 c.stderrMu.Unlock() 142 return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, err, stderr) 143 } 144 } 145 146 if werr == nil { 147 return err 148 } 149 c.stderrMu.Lock() 150 stderr := c.stderr.String() 151 c.stderrMu.Unlock() 152 return errors.Errorf("command %v has exited with %v, please make sure the URL is valid, and Docker 18.09 or later is installed on the remote host: stderr=%s", c.cmd.Args, werr, stderr) 153 } 154 155 func ignorableCloseError(err error) bool { 156 return strings.Contains(err.Error(), os.ErrClosed.Error()) 157 } 158 159 func (c *commandConn) Read(p []byte) (int, error) { 160 n, err := c.stdout.Read(p) 161 // check after the call to Read, since 162 // it is blocking, and while waiting on it 163 // Close might get called 164 if c.closing.Load() { 165 // If we're currently closing the connection 166 // we don't want to call onEOF 167 return n, err 168 } 169 170 return n, c.handleEOF(err) 171 } 172 173 func (c *commandConn) Write(p []byte) (int, error) { 174 n, err := c.stdin.Write(p) 175 // check after the call to Write, since 176 // it is blocking, and while waiting on it 177 // Close might get called 178 if c.closing.Load() { 179 // If we're currently closing the connection 180 // we don't want to call onEOF 181 return n, err 182 } 183 184 return n, c.handleEOF(err) 185 } 186 187 // CloseRead allows commandConn to implement halfCloser 188 func (c *commandConn) CloseRead() error { 189 // NOTE: maybe already closed here 190 if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) { 191 return err 192 } 193 c.stdoutClosed.Store(true) 194 195 if c.stdinClosed.Load() { 196 c.kill() 197 } 198 199 return nil 200 } 201 202 // CloseWrite allows commandConn to implement halfCloser 203 func (c *commandConn) CloseWrite() error { 204 // NOTE: maybe already closed here 205 if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) { 206 return err 207 } 208 c.stdinClosed.Store(true) 209 210 if c.stdoutClosed.Load() { 211 c.kill() 212 } 213 return nil 214 } 215 216 // Close is the net.Conn func that gets called 217 // by the transport when a dial is cancelled 218 // due to it's context timing out. Any blocked 219 // Read or Write calls will be unblocked and 220 // return errors. It will block until the underlying 221 // command has terminated. 222 func (c *commandConn) Close() error { 223 c.closing.Store(true) 224 defer c.closing.Store(false) 225 226 if err := c.CloseRead(); err != nil { 227 logrus.Warnf("commandConn.Close: CloseRead: %v", err) 228 return err 229 } 230 if err := c.CloseWrite(); err != nil { 231 logrus.Warnf("commandConn.Close: CloseWrite: %v", err) 232 return err 233 } 234 235 return nil 236 } 237 238 func (c *commandConn) LocalAddr() net.Addr { 239 return c.localAddr 240 } 241 242 func (c *commandConn) RemoteAddr() net.Addr { 243 return c.remoteAddr 244 } 245 246 func (c *commandConn) SetDeadline(t time.Time) error { 247 logrus.Debugf("unimplemented call: SetDeadline(%v)", t) 248 return nil 249 } 250 251 func (c *commandConn) SetReadDeadline(t time.Time) error { 252 logrus.Debugf("unimplemented call: SetReadDeadline(%v)", t) 253 return nil 254 } 255 256 func (c *commandConn) SetWriteDeadline(t time.Time) error { 257 logrus.Debugf("unimplemented call: SetWriteDeadline(%v)", t) 258 return nil 259 } 260 261 type dummyAddr struct { 262 network string 263 s string 264 } 265 266 func (d dummyAddr) Network() string { 267 return d.network 268 } 269 270 func (d dummyAddr) String() string { 271 return d.s 272 } 273 274 type stderrWriter struct { 275 stderrMu *sync.Mutex 276 stderr *bytes.Buffer 277 debugPrefix string 278 } 279 280 func (w *stderrWriter) Write(p []byte) (int, error) { 281 logrus.Debugf("%s%s", w.debugPrefix, string(p)) 282 w.stderrMu.Lock() 283 if w.stderr.Len() > 4096 { 284 w.stderr.Reset() 285 } 286 n, err := w.stderr.Write(p) 287 w.stderrMu.Unlock() 288 return n, err 289 }