github.com/goern/docker@v1.9.0-rc1/api/client/hijack.go (about)

     1  package client
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"net/http/httputil"
    11  	"os"
    12  	"runtime"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/Sirupsen/logrus"
    17  	"github.com/docker/docker/api"
    18  	"github.com/docker/docker/autogen/dockerversion"
    19  	"github.com/docker/docker/pkg/stdcopy"
    20  	"github.com/docker/docker/pkg/term"
    21  )
    22  
    23  type tlsClientCon struct {
    24  	*tls.Conn
    25  	rawConn net.Conn
    26  }
    27  
    28  func (c *tlsClientCon) CloseWrite() error {
    29  	// Go standard tls.Conn doesn't provide the CloseWrite() method so we do it
    30  	// on its underlying connection.
    31  	if cwc, ok := c.rawConn.(interface {
    32  		CloseWrite() error
    33  	}); ok {
    34  		return cwc.CloseWrite()
    35  	}
    36  	return nil
    37  }
    38  
    39  func tlsDial(network, addr string, config *tls.Config) (net.Conn, error) {
    40  	return tlsDialWithDialer(new(net.Dialer), network, addr, config)
    41  }
    42  
    43  // We need to copy Go's implementation of tls.Dial (pkg/cryptor/tls/tls.go) in
    44  // order to return our custom tlsClientCon struct which holds both the tls.Conn
    45  // object _and_ its underlying raw connection. The rationale for this is that
    46  // we need to be able to close the write end of the connection when attaching,
    47  // which tls.Conn does not provide.
    48  func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Config) (net.Conn, error) {
    49  	// We want the Timeout and Deadline values from dialer to cover the
    50  	// whole process: TCP connection and TLS handshake. This means that we
    51  	// also need to start our own timers now.
    52  	timeout := dialer.Timeout
    53  
    54  	if !dialer.Deadline.IsZero() {
    55  		deadlineTimeout := dialer.Deadline.Sub(time.Now())
    56  		if timeout == 0 || deadlineTimeout < timeout {
    57  			timeout = deadlineTimeout
    58  		}
    59  	}
    60  
    61  	var errChannel chan error
    62  
    63  	if timeout != 0 {
    64  		errChannel = make(chan error, 2)
    65  		time.AfterFunc(timeout, func() {
    66  			errChannel <- errors.New("")
    67  		})
    68  	}
    69  
    70  	rawConn, err := dialer.Dial(network, addr)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  	// When we set up a TCP connection for hijack, there could be long periods
    75  	// of inactivity (a long running command with no output) that in certain
    76  	// network setups may cause ECONNTIMEOUT, leaving the client in an unknown
    77  	// state. Setting TCP KeepAlive on the socket connection will prohibit
    78  	// ECONNTIMEOUT unless the socket connection truly is broken
    79  	if tcpConn, ok := rawConn.(*net.TCPConn); ok {
    80  		tcpConn.SetKeepAlive(true)
    81  		tcpConn.SetKeepAlivePeriod(30 * time.Second)
    82  	}
    83  
    84  	colonPos := strings.LastIndex(addr, ":")
    85  	if colonPos == -1 {
    86  		colonPos = len(addr)
    87  	}
    88  	hostname := addr[:colonPos]
    89  
    90  	// If no ServerName is set, infer the ServerName
    91  	// from the hostname we're connecting to.
    92  	if config.ServerName == "" {
    93  		// Make a copy to avoid polluting argument or default.
    94  		c := *config
    95  		c.ServerName = hostname
    96  		config = &c
    97  	}
    98  
    99  	conn := tls.Client(rawConn, config)
   100  
   101  	if timeout == 0 {
   102  		err = conn.Handshake()
   103  	} else {
   104  		go func() {
   105  			errChannel <- conn.Handshake()
   106  		}()
   107  
   108  		err = <-errChannel
   109  	}
   110  
   111  	if err != nil {
   112  		rawConn.Close()
   113  		return nil, err
   114  	}
   115  
   116  	// This is Docker difference with standard's crypto/tls package: returned a
   117  	// wrapper which holds both the TLS and raw connections.
   118  	return &tlsClientCon{conn, rawConn}, nil
   119  }
   120  
   121  func (cli *DockerCli) dial() (net.Conn, error) {
   122  	if cli.tlsConfig != nil && cli.proto != "unix" {
   123  		// Notice this isn't Go standard's tls.Dial function
   124  		return tlsDial(cli.proto, cli.addr, cli.tlsConfig)
   125  	}
   126  	return net.Dial(cli.proto, cli.addr)
   127  }
   128  
   129  func (cli *DockerCli) hijack(method, path string, setRawTerminal bool, in io.ReadCloser, stdout, stderr io.Writer, started chan io.Closer, data interface{}) error {
   130  	defer func() {
   131  		if started != nil {
   132  			close(started)
   133  		}
   134  	}()
   135  
   136  	params, err := cli.encodeData(data)
   137  	if err != nil {
   138  		return err
   139  	}
   140  	req, err := http.NewRequest(method, fmt.Sprintf("%s/v%s%s", cli.basePath, api.Version, path), params)
   141  	if err != nil {
   142  		return err
   143  	}
   144  
   145  	// Add CLI Config's HTTP Headers BEFORE we set the Docker headers
   146  	// then the user can't change OUR headers
   147  	for k, v := range cli.configFile.HTTPHeaders {
   148  		req.Header.Set(k, v)
   149  	}
   150  
   151  	req.Header.Set("User-Agent", "Docker-Client/"+dockerversion.VERSION+" ("+runtime.GOOS+")")
   152  	req.Header.Set("Content-Type", "text/plain")
   153  	req.Header.Set("Connection", "Upgrade")
   154  	req.Header.Set("Upgrade", "tcp")
   155  	req.Host = cli.addr
   156  
   157  	dial, err := cli.dial()
   158  	if err != nil {
   159  		if strings.Contains(err.Error(), "connection refused") {
   160  			return fmt.Errorf("Cannot connect to the Docker daemon. Is 'docker daemon' running on this host?")
   161  		}
   162  		return err
   163  	}
   164  
   165  	// When we set up a TCP connection for hijack, there could be long periods
   166  	// of inactivity (a long running command with no output) that in certain
   167  	// network setups may cause ECONNTIMEOUT, leaving the client in an unknown
   168  	// state. Setting TCP KeepAlive on the socket connection will prohibit
   169  	// ECONNTIMEOUT unless the socket connection truly is broken
   170  	if tcpConn, ok := dial.(*net.TCPConn); ok {
   171  		tcpConn.SetKeepAlive(true)
   172  		tcpConn.SetKeepAlivePeriod(30 * time.Second)
   173  	}
   174  
   175  	clientconn := httputil.NewClientConn(dial, nil)
   176  	defer clientconn.Close()
   177  
   178  	// Server hijacks the connection, error 'connection closed' expected
   179  	clientconn.Do(req)
   180  
   181  	rwc, br := clientconn.Hijack()
   182  	defer rwc.Close()
   183  
   184  	if started != nil {
   185  		started <- rwc
   186  	}
   187  
   188  	var oldState *term.State
   189  
   190  	if in != nil && setRawTerminal && cli.isTerminalIn && os.Getenv("NORAW") == "" {
   191  		oldState, err = term.SetRawTerminal(cli.inFd)
   192  		if err != nil {
   193  			return err
   194  		}
   195  		defer term.RestoreTerminal(cli.inFd, oldState)
   196  	}
   197  
   198  	receiveStdout := make(chan error, 1)
   199  	if stdout != nil || stderr != nil {
   200  		go func() {
   201  			defer func() {
   202  				if in != nil {
   203  					if setRawTerminal && cli.isTerminalIn {
   204  						term.RestoreTerminal(cli.inFd, oldState)
   205  					}
   206  					in.Close()
   207  				}
   208  			}()
   209  
   210  			// When TTY is ON, use regular copy
   211  			if setRawTerminal && stdout != nil {
   212  				_, err = io.Copy(stdout, br)
   213  			} else {
   214  				_, err = stdcopy.StdCopy(stdout, stderr, br)
   215  			}
   216  			logrus.Debugf("[hijack] End of stdout")
   217  			receiveStdout <- err
   218  		}()
   219  	}
   220  
   221  	stdinDone := make(chan struct{})
   222  	go func() {
   223  		if in != nil {
   224  			io.Copy(rwc, in)
   225  			logrus.Debugf("[hijack] End of stdin")
   226  		}
   227  
   228  		if conn, ok := rwc.(interface {
   229  			CloseWrite() error
   230  		}); ok {
   231  			if err := conn.CloseWrite(); err != nil {
   232  				logrus.Debugf("Couldn't send EOF: %s", err)
   233  			}
   234  		}
   235  		close(stdinDone)
   236  	}()
   237  
   238  	select {
   239  	case err := <-receiveStdout:
   240  		if err != nil {
   241  			logrus.Debugf("Error receiveStdout: %s", err)
   242  			return err
   243  		}
   244  	case <-stdinDone:
   245  		if stdout != nil || stderr != nil {
   246  			if err := <-receiveStdout; err != nil {
   247  				logrus.Debugf("Error receiveStdout: %s", err)
   248  				return err
   249  			}
   250  		}
   251  	}
   252  
   253  	return nil
   254  }