github.com/panekj/cli@v0.0.0-20230304125325-467dd2f3797e/cli/connhelper/commandconn/commandconn.go (about)

     1  // Package commandconn provides a net.Conn implementation that can be used for
     2  // proxying (or emulating) stream via a custom command.
     3  //
     4  // For example, to provide an http.Client that can connect to a Docker daemon
     5  // running in a Docker container ("DIND"):
     6  //
     7  //	httpClient := &http.Client{
     8  //		Transport: &http.Transport{
     9  //			DialContext: func(ctx context.Context, _network, _addr string) (net.Conn, error) {
    10  //				return commandconn.New(ctx, "docker", "exec", "-it", containerID, "docker", "system", "dial-stdio")
    11  //			},
    12  //		},
    13  //	}
    14  package commandconn
    15  
    16  import (
    17  	"bytes"
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"net"
    22  	"os"
    23  	"runtime"
    24  	"strings"
    25  	"sync"
    26  	"syscall"
    27  	"time"
    28  
    29  	"github.com/pkg/errors"
    30  	"github.com/sirupsen/logrus"
    31  	exec "golang.org/x/sys/execabs"
    32  )
    33  
    34  // New returns net.Conn
    35  func New(ctx context.Context, cmd string, args ...string) (net.Conn, error) {
    36  	var (
    37  		c   commandConn
    38  		err error
    39  	)
    40  	c.cmd = exec.Command(cmd, args...)
    41  	// we assume that args never contains sensitive information
    42  	logrus.Debugf("commandconn: starting %s with %v", cmd, args)
    43  	c.cmd.Env = os.Environ()
    44  	c.cmd.SysProcAttr = &syscall.SysProcAttr{}
    45  	setPdeathsig(c.cmd)
    46  	createSession(c.cmd)
    47  	c.stdin, err = c.cmd.StdinPipe()
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  	c.stdout, err = c.cmd.StdoutPipe()
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	c.cmd.Stderr = &stderrWriter{
    56  		stderrMu:    &c.stderrMu,
    57  		stderr:      &c.stderr,
    58  		debugPrefix: fmt.Sprintf("commandconn (%s):", cmd),
    59  	}
    60  	c.localAddr = dummyAddr{network: "dummy", s: "dummy-0"}
    61  	c.remoteAddr = dummyAddr{network: "dummy", s: "dummy-1"}
    62  	return &c, c.cmd.Start()
    63  }
    64  
    65  // commandConn implements net.Conn
    66  type commandConn struct {
    67  	cmd           *exec.Cmd
    68  	cmdExited     bool
    69  	cmdWaitErr    error
    70  	cmdMutex      sync.Mutex
    71  	stdin         io.WriteCloser
    72  	stdout        io.ReadCloser
    73  	stderrMu      sync.Mutex
    74  	stderr        bytes.Buffer
    75  	stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed
    76  	stdinClosed   bool
    77  	stdoutClosed  bool
    78  	localAddr     net.Addr
    79  	remoteAddr    net.Addr
    80  }
    81  
    82  // killIfStdioClosed kills the cmd if both stdin and stdout are closed.
    83  func (c *commandConn) killIfStdioClosed() error {
    84  	c.stdioClosedMu.Lock()
    85  	stdioClosed := c.stdoutClosed && c.stdinClosed
    86  	c.stdioClosedMu.Unlock()
    87  	if !stdioClosed {
    88  		return nil
    89  	}
    90  	return c.kill()
    91  }
    92  
    93  // killAndWait tries sending SIGTERM to the process before sending SIGKILL.
    94  func killAndWait(cmd *exec.Cmd) error {
    95  	var werr error
    96  	if runtime.GOOS != "windows" {
    97  		werrCh := make(chan error)
    98  		go func() { werrCh <- cmd.Wait() }()
    99  		cmd.Process.Signal(syscall.SIGTERM)
   100  		select {
   101  		case werr = <-werrCh:
   102  		case <-time.After(3 * time.Second):
   103  			cmd.Process.Kill()
   104  			werr = <-werrCh
   105  		}
   106  	} else {
   107  		cmd.Process.Kill()
   108  		werr = cmd.Wait()
   109  	}
   110  	return werr
   111  }
   112  
   113  // kill returns nil if the command terminated, regardless to the exit status.
   114  func (c *commandConn) kill() error {
   115  	var werr error
   116  	c.cmdMutex.Lock()
   117  	if c.cmdExited {
   118  		werr = c.cmdWaitErr
   119  	} else {
   120  		werr = killAndWait(c.cmd)
   121  		c.cmdWaitErr = werr
   122  		c.cmdExited = true
   123  	}
   124  	c.cmdMutex.Unlock()
   125  	if werr == nil {
   126  		return nil
   127  	}
   128  	wExitErr, ok := werr.(*exec.ExitError)
   129  	if ok {
   130  		if wExitErr.ProcessState.Exited() {
   131  			return nil
   132  		}
   133  	}
   134  	return errors.Wrapf(werr, "commandconn: failed to wait")
   135  }
   136  
   137  func (c *commandConn) onEOF(eof error) error {
   138  	// when we got EOF, the command is going to be terminated
   139  	var werr error
   140  	c.cmdMutex.Lock()
   141  	if c.cmdExited {
   142  		werr = c.cmdWaitErr
   143  	} else {
   144  		werrCh := make(chan error)
   145  		go func() { werrCh <- c.cmd.Wait() }()
   146  		select {
   147  		case werr = <-werrCh:
   148  			c.cmdWaitErr = werr
   149  			c.cmdExited = true
   150  		case <-time.After(10 * time.Second):
   151  			c.cmdMutex.Unlock()
   152  			c.stderrMu.Lock()
   153  			stderr := c.stderr.String()
   154  			c.stderrMu.Unlock()
   155  			return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr)
   156  		}
   157  	}
   158  	c.cmdMutex.Unlock()
   159  	if werr == nil {
   160  		return eof
   161  	}
   162  	c.stderrMu.Lock()
   163  	stderr := c.stderr.String()
   164  	c.stderrMu.Unlock()
   165  	return errors.Errorf("command %v has exited with %v, please make sure the URL is valid, and Docker 18.09 or later is installed on the remote host: stderr=%s", c.cmd.Args, werr, stderr)
   166  }
   167  
   168  func ignorableCloseError(err error) bool {
   169  	errS := err.Error()
   170  	ss := []string{
   171  		os.ErrClosed.Error(),
   172  	}
   173  	for _, s := range ss {
   174  		if strings.Contains(errS, s) {
   175  			return true
   176  		}
   177  	}
   178  	return false
   179  }
   180  
   181  func (c *commandConn) CloseRead() error {
   182  	// NOTE: maybe already closed here
   183  	if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) {
   184  		logrus.Warnf("commandConn.CloseRead: %v", err)
   185  	}
   186  	c.stdioClosedMu.Lock()
   187  	c.stdoutClosed = true
   188  	c.stdioClosedMu.Unlock()
   189  	if err := c.killIfStdioClosed(); err != nil {
   190  		logrus.Warnf("commandConn.CloseRead: %v", err)
   191  	}
   192  	return nil
   193  }
   194  
   195  func (c *commandConn) Read(p []byte) (int, error) {
   196  	n, err := c.stdout.Read(p)
   197  	if err == io.EOF {
   198  		err = c.onEOF(err)
   199  	}
   200  	return n, err
   201  }
   202  
   203  func (c *commandConn) CloseWrite() error {
   204  	// NOTE: maybe already closed here
   205  	if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) {
   206  		logrus.Warnf("commandConn.CloseWrite: %v", err)
   207  	}
   208  	c.stdioClosedMu.Lock()
   209  	c.stdinClosed = true
   210  	c.stdioClosedMu.Unlock()
   211  	if err := c.killIfStdioClosed(); err != nil {
   212  		logrus.Warnf("commandConn.CloseWrite: %v", err)
   213  	}
   214  	return nil
   215  }
   216  
   217  func (c *commandConn) Write(p []byte) (int, error) {
   218  	n, err := c.stdin.Write(p)
   219  	if err == io.EOF {
   220  		err = c.onEOF(err)
   221  	}
   222  	return n, err
   223  }
   224  
   225  func (c *commandConn) Close() error {
   226  	var err error
   227  	if err = c.CloseRead(); err != nil {
   228  		logrus.Warnf("commandConn.Close: CloseRead: %v", err)
   229  	}
   230  	if err = c.CloseWrite(); err != nil {
   231  		logrus.Warnf("commandConn.Close: CloseWrite: %v", err)
   232  	}
   233  	return err
   234  }
   235  
   236  func (c *commandConn) LocalAddr() net.Addr {
   237  	return c.localAddr
   238  }
   239  
   240  func (c *commandConn) RemoteAddr() net.Addr {
   241  	return c.remoteAddr
   242  }
   243  
   244  func (c *commandConn) SetDeadline(t time.Time) error {
   245  	logrus.Debugf("unimplemented call: SetDeadline(%v)", t)
   246  	return nil
   247  }
   248  
   249  func (c *commandConn) SetReadDeadline(t time.Time) error {
   250  	logrus.Debugf("unimplemented call: SetReadDeadline(%v)", t)
   251  	return nil
   252  }
   253  
   254  func (c *commandConn) SetWriteDeadline(t time.Time) error {
   255  	logrus.Debugf("unimplemented call: SetWriteDeadline(%v)", t)
   256  	return nil
   257  }
   258  
   259  type dummyAddr struct {
   260  	network string
   261  	s       string
   262  }
   263  
   264  func (d dummyAddr) Network() string {
   265  	return d.network
   266  }
   267  
   268  func (d dummyAddr) String() string {
   269  	return d.s
   270  }
   271  
   272  type stderrWriter struct {
   273  	stderrMu    *sync.Mutex
   274  	stderr      *bytes.Buffer
   275  	debugPrefix string
   276  }
   277  
   278  func (w *stderrWriter) Write(p []byte) (int, error) {
   279  	logrus.Debugf("%s%s", w.debugPrefix, string(p))
   280  	w.stderrMu.Lock()
   281  	if w.stderr.Len() > 4096 {
   282  		w.stderr.Reset()
   283  	}
   284  	n, err := w.stderr.Write(p)
   285  	w.stderrMu.Unlock()
   286  	return n, err
   287  }