github.com/hanks177/podman/v4@v4.1.3-0.20220613032544-16d90015bc83/pkg/bindings/connection.go (about)

     1  package bindings
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"net/url"
    10  	"os"
    11  	"strconv"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/blang/semver"
    16  	"github.com/hanks177/podman/v4/pkg/terminal"
    17  	"github.com/hanks177/podman/v4/version"
    18  	"github.com/pkg/errors"
    19  	"github.com/sirupsen/logrus"
    20  	"golang.org/x/crypto/ssh"
    21  	"golang.org/x/crypto/ssh/agent"
    22  )
    23  
// APIResponse pairs the HTTP response returned by the service with the
// request that produced it. *http.Response is embedded, so its fields
// (StatusCode, Header, Body, ...) are reachable directly on an APIResponse.
//
// NOTE: DoRequest constructs this type with a positional composite literal,
// so the order of the fields below must not change.
type APIResponse struct {
	*http.Response
	// Request is the request that was sent to obtain Response.
	Request *http.Request
}
    28  
// Connection describes a connection to the Podman service: the URI it was
// created from and the HTTP client whose transport dials that endpoint
// (see tcpClient, unixClient and sshClient).
type Connection struct {
	// URI is the parsed service address (tcp://, unix:// or ssh://).
	URI *url.URL
	// Client is the HTTP client used by DoRequest for every API call.
	Client *http.Client
}
    33  
    34  type valueKey string
    35  
    36  const (
    37  	clientKey  = valueKey("Client")
    38  	versionKey = valueKey("ServiceVersion")
    39  )
    40  
    41  // GetClient from context build by NewConnection()
    42  func GetClient(ctx context.Context) (*Connection, error) {
    43  	if c, ok := ctx.Value(clientKey).(*Connection); ok {
    44  		return c, nil
    45  	}
    46  	return nil, errors.Errorf("%s not set in context", clientKey)
    47  }
    48  
    49  // ServiceVersion from context build by NewConnection()
    50  func ServiceVersion(ctx context.Context) *semver.Version {
    51  	if v, ok := ctx.Value(versionKey).(*semver.Version); ok {
    52  		return v
    53  	}
    54  	return new(semver.Version)
    55  }
    56  
    57  // JoinURL elements with '/'
    58  func JoinURL(elements ...string) string {
    59  	return "/" + strings.Join(elements, "/")
    60  }
    61  
    62  // NewConnection creates a new service connection without an identity
    63  func NewConnection(ctx context.Context, uri string) (context.Context, error) {
    64  	return NewConnectionWithIdentity(ctx, uri, "")
    65  }
    66  
    67  // NewConnectionWithIdentity takes a URI as a string and returns a context with the
    68  // Connection embedded as a value.  This context needs to be passed to each
    69  // endpoint to work correctly.
    70  //
    71  // A valid URI connection should be scheme://
    72  // For example tcp://localhost:<port>
    73  // or unix:///run/podman/podman.sock
    74  // or ssh://<user>@<host>[:port]/run/podman/podman.sock?secure=True
    75  func NewConnectionWithIdentity(ctx context.Context, uri string, identity string) (context.Context, error) {
    76  	var (
    77  		err    error
    78  		secure bool
    79  	)
    80  	if v, found := os.LookupEnv("CONTAINER_HOST"); found && uri == "" {
    81  		uri = v
    82  	}
    83  
    84  	if v, found := os.LookupEnv("CONTAINER_SSHKEY"); found && len(identity) == 0 {
    85  		identity = v
    86  	}
    87  
    88  	passPhrase := ""
    89  	if v, found := os.LookupEnv("CONTAINER_PASSPHRASE"); found {
    90  		passPhrase = v
    91  	}
    92  
    93  	_url, err := url.Parse(uri)
    94  	if err != nil {
    95  		return nil, errors.Wrapf(err, "Value of CONTAINER_HOST is not a valid url: %s", uri)
    96  	}
    97  
    98  	// Now we setup the http Client to use the connection above
    99  	var connection Connection
   100  	switch _url.Scheme {
   101  	case "ssh":
   102  		secure, err = strconv.ParseBool(_url.Query().Get("secure"))
   103  		if err != nil {
   104  			secure = false
   105  		}
   106  		connection, err = sshClient(_url, secure, passPhrase, identity)
   107  	case "unix":
   108  		if !strings.HasPrefix(uri, "unix:///") {
   109  			// autofix unix://path_element vs unix:///path_element
   110  			_url.Path = JoinURL(_url.Host, _url.Path)
   111  			_url.Host = ""
   112  		}
   113  		connection = unixClient(_url)
   114  	case "tcp":
   115  		if !strings.HasPrefix(uri, "tcp://") {
   116  			return nil, errors.New("tcp URIs should begin with tcp://")
   117  		}
   118  		connection = tcpClient(_url)
   119  	default:
   120  		return nil, errors.Errorf("unable to create connection. %q is not a supported schema", _url.Scheme)
   121  	}
   122  	if err != nil {
   123  		return nil, errors.Wrapf(err, "unable to connect to Podman. failed to create %sClient", _url.Scheme)
   124  	}
   125  
   126  	ctx = context.WithValue(ctx, clientKey, &connection)
   127  	serviceVersion, err := pingNewConnection(ctx)
   128  	if err != nil {
   129  		return nil, errors.Wrap(err, "unable to connect to Podman socket")
   130  	}
   131  	ctx = context.WithValue(ctx, versionKey, serviceVersion)
   132  	return ctx, nil
   133  }
   134  
   135  func tcpClient(_url *url.URL) Connection {
   136  	connection := Connection{
   137  		URI: _url,
   138  	}
   139  	connection.Client = &http.Client{
   140  		Transport: &http.Transport{
   141  			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
   142  				return net.Dial("tcp", _url.Host)
   143  			},
   144  			DisableCompression: true,
   145  		},
   146  	}
   147  	return connection
   148  }
   149  
   150  // pingNewConnection pings to make sure the RESTFUL service is up
   151  // and running. it should only be used when initializing a connection
   152  func pingNewConnection(ctx context.Context) (*semver.Version, error) {
   153  	client, err := GetClient(ctx)
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  	// the ping endpoint sits at / in this case
   158  	response, err := client.DoRequest(ctx, nil, http.MethodGet, "/_ping", nil, nil)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	defer response.Body.Close()
   163  
   164  	if response.StatusCode == http.StatusOK {
   165  		versionHdr := response.Header.Get("Libpod-API-Version")
   166  		if versionHdr == "" {
   167  			logrus.Info("Service did not provide Libpod-API-Version Header")
   168  			return new(semver.Version), nil
   169  		}
   170  		versionSrv, err := semver.ParseTolerant(versionHdr)
   171  		if err != nil {
   172  			return nil, err
   173  		}
   174  
   175  		switch version.APIVersion[version.Libpod][version.MinimalAPI].Compare(versionSrv) {
   176  		case -1, 0:
   177  			// Server's job when Client version is equal or older
   178  			return &versionSrv, nil
   179  		case 1:
   180  			return nil, errors.Errorf("server API version is too old. Client %q server %q",
   181  				version.APIVersion[version.Libpod][version.MinimalAPI].String(), versionSrv.String())
   182  		}
   183  	}
   184  	return nil, errors.Errorf("ping response was %d", response.StatusCode)
   185  }
   186  
   187  func sshClient(_url *url.URL, secure bool, passPhrase string, identity string) (Connection, error) {
   188  	// if you modify the authmethods or their conditionals, you will also need to make similar
   189  	// changes in the client (currently cmd/podman/system/connection/add getUDS).
   190  
   191  	var signers []ssh.Signer // order Signers are appended to this list determines which key is presented to server
   192  
   193  	if len(identity) > 0 {
   194  		s, err := terminal.PublicKey(identity, []byte(passPhrase))
   195  		if err != nil {
   196  			return Connection{}, errors.Wrapf(err, "failed to parse identity %q", identity)
   197  		}
   198  
   199  		signers = append(signers, s)
   200  		logrus.Debugf("SSH Ident Key %q %s %s", identity, ssh.FingerprintSHA256(s.PublicKey()), s.PublicKey().Type())
   201  	}
   202  
   203  	if sock, found := os.LookupEnv("SSH_AUTH_SOCK"); found {
   204  		logrus.Debugf("Found SSH_AUTH_SOCK %q, ssh-agent signer(s) enabled", sock)
   205  
   206  		c, err := net.Dial("unix", sock)
   207  		if err != nil {
   208  			return Connection{}, err
   209  		}
   210  
   211  		agentSigners, err := agent.NewClient(c).Signers()
   212  		if err != nil {
   213  			return Connection{}, err
   214  		}
   215  		signers = append(signers, agentSigners...)
   216  
   217  		if logrus.IsLevelEnabled(logrus.DebugLevel) {
   218  			for _, s := range agentSigners {
   219  				logrus.Debugf("SSH Agent Key %s %s", ssh.FingerprintSHA256(s.PublicKey()), s.PublicKey().Type())
   220  			}
   221  		}
   222  	}
   223  
   224  	var authMethods []ssh.AuthMethod
   225  	if len(signers) > 0 {
   226  		var dedup = make(map[string]ssh.Signer)
   227  		// Dedup signers based on fingerprint, ssh-agent keys override CONTAINER_SSHKEY
   228  		for _, s := range signers {
   229  			fp := ssh.FingerprintSHA256(s.PublicKey())
   230  			if _, found := dedup[fp]; found {
   231  				logrus.Debugf("Dedup SSH Key %s %s", ssh.FingerprintSHA256(s.PublicKey()), s.PublicKey().Type())
   232  			}
   233  			dedup[fp] = s
   234  		}
   235  
   236  		var uniq []ssh.Signer
   237  		for _, s := range dedup {
   238  			uniq = append(uniq, s)
   239  		}
   240  		authMethods = append(authMethods, ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
   241  			return uniq, nil
   242  		}))
   243  	}
   244  
   245  	if pw, found := _url.User.Password(); found {
   246  		authMethods = append(authMethods, ssh.Password(pw))
   247  	}
   248  
   249  	if len(authMethods) == 0 {
   250  		callback := func() (string, error) {
   251  			pass, err := terminal.ReadPassword("Login password:")
   252  			return string(pass), err
   253  		}
   254  		authMethods = append(authMethods, ssh.PasswordCallback(callback))
   255  	}
   256  
   257  	port := _url.Port()
   258  	if port == "" {
   259  		port = "22"
   260  	}
   261  
   262  	callback := ssh.InsecureIgnoreHostKey()
   263  	if secure {
   264  		host := _url.Hostname()
   265  		if port != "22" {
   266  			host = fmt.Sprintf("[%s]:%s", host, port)
   267  		}
   268  		key := terminal.HostKey(host)
   269  		if key != nil {
   270  			callback = ssh.FixedHostKey(key)
   271  		}
   272  	}
   273  
   274  	bastion, err := ssh.Dial("tcp",
   275  		net.JoinHostPort(_url.Hostname(), port),
   276  		&ssh.ClientConfig{
   277  			User:            _url.User.Username(),
   278  			Auth:            authMethods,
   279  			HostKeyCallback: callback,
   280  			HostKeyAlgorithms: []string{
   281  				ssh.KeyAlgoRSA,
   282  				ssh.KeyAlgoDSA,
   283  				ssh.KeyAlgoECDSA256,
   284  				ssh.KeyAlgoECDSA384,
   285  				ssh.KeyAlgoECDSA521,
   286  				ssh.KeyAlgoED25519,
   287  			},
   288  			Timeout: 5 * time.Second,
   289  		},
   290  	)
   291  	if err != nil {
   292  		return Connection{}, errors.Wrapf(err, "connection to bastion host (%s) failed", _url.String())
   293  	}
   294  
   295  	connection := Connection{URI: _url}
   296  	connection.Client = &http.Client{
   297  		Transport: &http.Transport{
   298  			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
   299  				return bastion.Dial("unix", _url.Path)
   300  			},
   301  		}}
   302  	return connection, nil
   303  }
   304  
   305  func unixClient(_url *url.URL) Connection {
   306  	connection := Connection{URI: _url}
   307  	connection.Client = &http.Client{
   308  		Transport: &http.Transport{
   309  			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
   310  				return (&net.Dialer{}).DialContext(ctx, "unix", _url.Path)
   311  			},
   312  			DisableCompression: true,
   313  		},
   314  	}
   315  	return connection
   316  }
   317  
   318  // DoRequest assembles the http request and returns the response
   319  func (c *Connection) DoRequest(ctx context.Context, httpBody io.Reader, httpMethod, endpoint string, queryParams url.Values, headers http.Header, pathValues ...string) (*APIResponse, error) {
   320  	var (
   321  		err      error
   322  		response *http.Response
   323  	)
   324  
   325  	params := make([]interface{}, len(pathValues)+1)
   326  
   327  	if v := headers.Values("API-Version"); len(v) > 0 {
   328  		params[0] = v[0]
   329  	} else {
   330  		// Including the semver suffices breaks older services... so do not include them
   331  		v := version.APIVersion[version.Libpod][version.CurrentAPI]
   332  		params[0] = fmt.Sprintf("%d.%d.%d", v.Major, v.Minor, v.Patch)
   333  	}
   334  
   335  	for i, pv := range pathValues {
   336  		// url.URL lacks the semantics for escaping embedded path parameters... so we manually
   337  		//   escape each one and assume the caller included the correct formatting in "endpoint"
   338  		params[i+1] = url.PathEscape(pv)
   339  	}
   340  
   341  	uri := fmt.Sprintf("http://d/v%s/libpod"+endpoint, params...)
   342  	logrus.Debugf("DoRequest Method: %s URI: %v", httpMethod, uri)
   343  
   344  	req, err := http.NewRequestWithContext(ctx, httpMethod, uri, httpBody)
   345  	if err != nil {
   346  		return nil, err
   347  	}
   348  	if len(queryParams) > 0 {
   349  		req.URL.RawQuery = queryParams.Encode()
   350  	}
   351  
   352  	for key, val := range headers {
   353  		if key == "API-Version" {
   354  			continue
   355  		}
   356  
   357  		for _, v := range val {
   358  			req.Header.Add(key, v)
   359  		}
   360  	}
   361  
   362  	// Give the Do three chances in the case of a comm/service hiccup
   363  	for i := 1; i <= 3; i++ {
   364  		response, err = c.Client.Do(req) // nolint
   365  		if err == nil {
   366  			break
   367  		}
   368  		time.Sleep(time.Duration(i*100) * time.Millisecond)
   369  	}
   370  	return &APIResponse{response, req}, err
   371  }
   372  
   373  // GetDialer returns raw Transport.DialContext from client
   374  func (c *Connection) GetDialer(ctx context.Context) (net.Conn, error) {
   375  	client := c.Client
   376  	transport := client.Transport.(*http.Transport)
   377  	if transport.DialContext != nil && transport.TLSClientConfig == nil {
   378  		return transport.DialContext(ctx, c.URI.Scheme, c.URI.String())
   379  	}
   380  
   381  	return nil, errors.New("Unable to get dial context")
   382  }
   383  
   384  // IsInformational returns true if the response code is 1xx
   385  func (h *APIResponse) IsInformational() bool {
   386  	return h.Response.StatusCode/100 == 1
   387  }
   388  
   389  // IsSuccess returns true if the response code is 2xx
   390  func (h *APIResponse) IsSuccess() bool {
   391  	return h.Response.StatusCode/100 == 2
   392  }
   393  
   394  // IsRedirection returns true if the response code is 3xx
   395  func (h *APIResponse) IsRedirection() bool {
   396  	return h.Response.StatusCode/100 == 3
   397  }
   398  
   399  // IsClientError returns true if the response code is 4xx
   400  func (h *APIResponse) IsClientError() bool {
   401  	return h.Response.StatusCode/100 == 4
   402  }
   403  
   404  // IsConflictError returns true if the response code is 409
   405  func (h *APIResponse) IsConflictError() bool {
   406  	return h.Response.StatusCode == 409
   407  }
   408  
   409  // IsServerError returns true if the response code is 5xx
   410  func (h *APIResponse) IsServerError() bool {
   411  	return h.Response.StatusCode/100 == 5
   412  }