github.com/khulnasoft/cli@v0.0.0-20240402070845-01bcad7beefa/cli/connhelper/commandconn/commandconn.go (about)

     1  // Package commandconn provides a net.Conn implementation that can be used for
     2  // proxying (or emulating) a stream via a custom command.
     3  //
     4  // For example, to provide an http.Client that can connect to a Docker daemon
     5  // running in a Docker container ("DIND"):
     6  //
     7  //	httpClient := &http.Client{
     8  //		Transport: &http.Transport{
     9  //			DialContext: func(ctx context.Context, _network, _addr string) (net.Conn, error) {
    10  //				return commandconn.New(ctx, "docker", "exec", "-i", containerID, "docker", "system", "dial-stdio")
    11  //			},
    12  //		},
    13  //	}
    14  package commandconn
    15  
    16  import (
    17  	"bytes"
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"net"
    22  	"os"
    23  	"os/exec"
    24  	"runtime"
    25  	"strings"
    26  	"sync"
    27  	"sync/atomic"
    28  	"syscall"
    29  	"time"
    30  
    31  	"github.com/pkg/errors"
    32  	"github.com/sirupsen/logrus"
    33  )
    34  
    35  // New returns net.Conn
    36  func New(_ context.Context, cmd string, args ...string) (net.Conn, error) {
    37  	var (
    38  		c   commandConn
    39  		err error
    40  	)
    41  	c.cmd = exec.Command(cmd, args...)
    42  	// we assume that args never contains sensitive information
    43  	logrus.Debugf("commandconn: starting %s with %v", cmd, args)
    44  	c.cmd.Env = os.Environ()
    45  	c.cmd.SysProcAttr = &syscall.SysProcAttr{}
    46  	setPdeathsig(c.cmd)
    47  	createSession(c.cmd)
    48  	c.stdin, err = c.cmd.StdinPipe()
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  	c.stdout, err = c.cmd.StdoutPipe()
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  	c.cmd.Stderr = &stderrWriter{
    57  		stderrMu:    &c.stderrMu,
    58  		stderr:      &c.stderr,
    59  		debugPrefix: fmt.Sprintf("commandconn (%s):", cmd),
    60  	}
    61  	c.localAddr = dummyAddr{network: "dummy", s: "dummy-0"}
    62  	c.remoteAddr = dummyAddr{network: "dummy", s: "dummy-1"}
    63  	return &c, c.cmd.Start()
    64  }
    65  
    66  // commandConn implements net.Conn
    67  type commandConn struct {
    68  	cmdMutex     sync.Mutex // for cmd, cmdWaitErr
    69  	cmd          *exec.Cmd
    70  	cmdWaitErr   error
    71  	cmdExited    atomic.Bool
    72  	stdin        io.WriteCloser
    73  	stdout       io.ReadCloser
    74  	stderrMu     sync.Mutex // for stderr
    75  	stderr       bytes.Buffer
    76  	stdinClosed  atomic.Bool
    77  	stdoutClosed atomic.Bool
    78  	closing      atomic.Bool
    79  	localAddr    net.Addr
    80  	remoteAddr   net.Addr
    81  }
    82  
    83  // kill terminates the process. On Windows it kills the process directly,
    84  // whereas on other platforms, a SIGTERM is sent, before forcefully terminating
    85  // the process after 3 seconds.
    86  func (c *commandConn) kill() {
    87  	if c.cmdExited.Load() {
    88  		return
    89  	}
    90  	c.cmdMutex.Lock()
    91  	var werr error
    92  	if runtime.GOOS != "windows" {
    93  		werrCh := make(chan error)
    94  		go func() { werrCh <- c.cmd.Wait() }()
    95  		_ = c.cmd.Process.Signal(syscall.SIGTERM)
    96  		select {
    97  		case werr = <-werrCh:
    98  		case <-time.After(3 * time.Second):
    99  			_ = c.cmd.Process.Kill()
   100  			werr = <-werrCh
   101  		}
   102  	} else {
   103  		_ = c.cmd.Process.Kill()
   104  		werr = c.cmd.Wait()
   105  	}
   106  	c.cmdWaitErr = werr
   107  	c.cmdMutex.Unlock()
   108  	c.cmdExited.Store(true)
   109  }
   110  
   111  // handleEOF handles io.EOF errors while reading or writing from the underlying
   112  // command pipes.
   113  //
   114  // When we've received an EOF we expect that the command will
   115  // be terminated soon. As such, we call Wait() on the command
   116  // and return EOF or the error depending on whether the command
   117  // exited with an error.
   118  //
   119  // If Wait() does not return within 10s, an error is returned
   120  func (c *commandConn) handleEOF(err error) error {
   121  	if err != io.EOF {
   122  		return err
   123  	}
   124  
   125  	c.cmdMutex.Lock()
   126  	defer c.cmdMutex.Unlock()
   127  
   128  	var werr error
   129  	if c.cmdExited.Load() {
   130  		werr = c.cmdWaitErr
   131  	} else {
   132  		werrCh := make(chan error)
   133  		go func() { werrCh <- c.cmd.Wait() }()
   134  		select {
   135  		case werr = <-werrCh:
   136  			c.cmdWaitErr = werr
   137  			c.cmdExited.Store(true)
   138  		case <-time.After(10 * time.Second):
   139  			c.stderrMu.Lock()
   140  			stderr := c.stderr.String()
   141  			c.stderrMu.Unlock()
   142  			return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, err, stderr)
   143  		}
   144  	}
   145  
   146  	if werr == nil {
   147  		return err
   148  	}
   149  	c.stderrMu.Lock()
   150  	stderr := c.stderr.String()
   151  	c.stderrMu.Unlock()
   152  	return errors.Errorf("command %v has exited with %v, please make sure the URL is valid, and Docker 18.09 or later is installed on the remote host: stderr=%s", c.cmd.Args, werr, stderr)
   153  }
   154  
   155  func ignorableCloseError(err error) bool {
   156  	return strings.Contains(err.Error(), os.ErrClosed.Error())
   157  }
   158  
   159  func (c *commandConn) Read(p []byte) (int, error) {
   160  	n, err := c.stdout.Read(p)
   161  	// check after the call to Read, since
   162  	// it is blocking, and while waiting on it
   163  	// Close might get called
   164  	if c.closing.Load() {
   165  		// If we're currently closing the connection
   166  		// we don't want to call onEOF
   167  		return n, err
   168  	}
   169  
   170  	return n, c.handleEOF(err)
   171  }
   172  
   173  func (c *commandConn) Write(p []byte) (int, error) {
   174  	n, err := c.stdin.Write(p)
   175  	// check after the call to Write, since
   176  	// it is blocking, and while waiting on it
   177  	// Close might get called
   178  	if c.closing.Load() {
   179  		// If we're currently closing the connection
   180  		// we don't want to call onEOF
   181  		return n, err
   182  	}
   183  
   184  	return n, c.handleEOF(err)
   185  }
   186  
   187  // CloseRead allows commandConn to implement halfCloser
   188  func (c *commandConn) CloseRead() error {
   189  	// NOTE: maybe already closed here
   190  	if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) {
   191  		return err
   192  	}
   193  	c.stdoutClosed.Store(true)
   194  
   195  	if c.stdinClosed.Load() {
   196  		c.kill()
   197  	}
   198  
   199  	return nil
   200  }
   201  
   202  // CloseWrite allows commandConn to implement halfCloser
   203  func (c *commandConn) CloseWrite() error {
   204  	// NOTE: maybe already closed here
   205  	if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) {
   206  		return err
   207  	}
   208  	c.stdinClosed.Store(true)
   209  
   210  	if c.stdoutClosed.Load() {
   211  		c.kill()
   212  	}
   213  	return nil
   214  }
   215  
   216  // Close is the net.Conn func that gets called
   217  // by the transport when a dial is cancelled
   218  // due to it's context timing out. Any blocked
   219  // Read or Write calls will be unblocked and
   220  // return errors. It will block until the underlying
   221  // command has terminated.
   222  func (c *commandConn) Close() error {
   223  	c.closing.Store(true)
   224  	defer c.closing.Store(false)
   225  
   226  	if err := c.CloseRead(); err != nil {
   227  		logrus.Warnf("commandConn.Close: CloseRead: %v", err)
   228  		return err
   229  	}
   230  	if err := c.CloseWrite(); err != nil {
   231  		logrus.Warnf("commandConn.Close: CloseWrite: %v", err)
   232  		return err
   233  	}
   234  
   235  	return nil
   236  }
   237  
   238  func (c *commandConn) LocalAddr() net.Addr {
   239  	return c.localAddr
   240  }
   241  
   242  func (c *commandConn) RemoteAddr() net.Addr {
   243  	return c.remoteAddr
   244  }
   245  
   246  func (c *commandConn) SetDeadline(t time.Time) error {
   247  	logrus.Debugf("unimplemented call: SetDeadline(%v)", t)
   248  	return nil
   249  }
   250  
   251  func (c *commandConn) SetReadDeadline(t time.Time) error {
   252  	logrus.Debugf("unimplemented call: SetReadDeadline(%v)", t)
   253  	return nil
   254  }
   255  
   256  func (c *commandConn) SetWriteDeadline(t time.Time) error {
   257  	logrus.Debugf("unimplemented call: SetWriteDeadline(%v)", t)
   258  	return nil
   259  }
   260  
   261  type dummyAddr struct {
   262  	network string
   263  	s       string
   264  }
   265  
   266  func (d dummyAddr) Network() string {
   267  	return d.network
   268  }
   269  
   270  func (d dummyAddr) String() string {
   271  	return d.s
   272  }
   273  
   274  type stderrWriter struct {
   275  	stderrMu    *sync.Mutex
   276  	stderr      *bytes.Buffer
   277  	debugPrefix string
   278  }
   279  
   280  func (w *stderrWriter) Write(p []byte) (int, error) {
   281  	logrus.Debugf("%s%s", w.debugPrefix, string(p))
   282  	w.stderrMu.Lock()
   283  	if w.stderr.Len() > 4096 {
   284  		w.stderr.Reset()
   285  	}
   286  	n, err := w.stderr.Write(p)
   287  	w.stderrMu.Unlock()
   288  	return n, err
   289  }