github.com/containers/libpod@v1.9.4-0.20220419124438-4284fd425507/pkg/bindings/connection.go (about)

     1  package bindings
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net"
    10  	"net/http"
    11  	"net/url"
    12  	"os"
    13  	"path/filepath"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/containers/libpod/pkg/api/handlers"
    19  	jsoniter "github.com/json-iterator/go"
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	"golang.org/x/crypto/ssh"
    23  	"k8s.io/client-go/util/homedir"
    24  )
    25  
    26  var (
    27  	basePath = &url.URL{
    28  		Scheme: "http",
    29  		Host:   "d",
    30  		Path:   "/v" + handlers.MinimalApiVersion + "/libpod",
    31  	}
    32  )
    33  
    34  type APIResponse struct {
    35  	*http.Response
    36  	Request *http.Request
    37  }
    38  
    39  type Connection struct {
    40  	_url   *url.URL
    41  	client *http.Client
    42  }
    43  
    44  type valueKey string
    45  
    46  const (
    47  	clientKey = valueKey("client")
    48  )
    49  
    50  // GetClient from context build by NewConnection()
    51  func GetClient(ctx context.Context) (*Connection, error) {
    52  	c, ok := ctx.Value(clientKey).(*Connection)
    53  	if !ok {
    54  		return nil, errors.Errorf("ClientKey not set in context")
    55  	}
    56  	return c, nil
    57  }
    58  
    59  // JoinURL elements with '/'
    60  func JoinURL(elements ...string) string {
    61  	return strings.Join(elements, "/")
    62  }
    63  
    64  // NewConnection takes a URI as a string and returns a context with the
    65  // Connection embedded as a value.  This context needs to be passed to each
    66  // endpoint to work correctly.
    67  //
    68  // A valid URI connection should be scheme://
    69  // For example tcp://localhost:<port>
    70  // or unix:///run/podman/podman.sock
    71  // or ssh://<user>@<host>[:port]/run/podman/podman.sock?secure=True
    72  func NewConnection(ctx context.Context, uri string, identity ...string) (context.Context, error) {
    73  	var (
    74  		err    error
    75  		secure bool
    76  	)
    77  	if v, found := os.LookupEnv("PODMAN_HOST"); found {
    78  		uri = v
    79  	}
    80  
    81  	if v, found := os.LookupEnv("PODMAN_SSHKEY"); found {
    82  		identity = []string{v}
    83  	}
    84  
    85  	_url, err := url.Parse(uri)
    86  	if err != nil {
    87  		return nil, errors.Wrapf(err, "Value of PODMAN_HOST is not a valid url: %s", uri)
    88  	}
    89  
    90  	// Now we setup the http client to use the connection above
    91  	var client *http.Client
    92  	switch _url.Scheme {
    93  	case "ssh":
    94  		secure, err = strconv.ParseBool(_url.Query().Get("secure"))
    95  		if err != nil {
    96  			secure = false
    97  		}
    98  		client, err = sshClient(_url, identity[0], secure)
    99  	case "unix":
   100  		if !strings.HasPrefix(uri, "unix:///") {
   101  			// autofix unix://path_element vs unix:///path_element
   102  			_url.Path = JoinURL(_url.Host, _url.Path)
   103  			_url.Host = ""
   104  		}
   105  		client, err = unixClient(_url)
   106  	case "tcp":
   107  		if !strings.HasPrefix(uri, "tcp://") {
   108  			return nil, errors.New("tcp URIs should begin with tcp://")
   109  		}
   110  		client, err = tcpClient(_url)
   111  	default:
   112  		return nil, errors.Errorf("'%s' is not a supported schema", _url.Scheme)
   113  	}
   114  	if err != nil {
   115  		return nil, errors.Wrapf(err, "Failed to create %sClient", _url.Scheme)
   116  	}
   117  
   118  	ctx = context.WithValue(ctx, clientKey, &Connection{_url, client})
   119  	if err := pingNewConnection(ctx); err != nil {
   120  		return nil, err
   121  	}
   122  	return ctx, nil
   123  }
   124  
   125  func tcpClient(_url *url.URL) (*http.Client, error) {
   126  	return &http.Client{
   127  		Transport: &http.Transport{
   128  			DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
   129  				return net.Dial("tcp", _url.Path)
   130  			},
   131  			DisableCompression: true,
   132  		},
   133  	}, nil
   134  }
   135  
   136  // pingNewConnection pings to make sure the RESTFUL service is up
   137  // and running. it should only be used where initializing a connection
   138  func pingNewConnection(ctx context.Context) error {
   139  	client, err := GetClient(ctx)
   140  	if err != nil {
   141  		return err
   142  	}
   143  	// the ping endpoint sits at / in this case
   144  	response, err := client.DoRequest(nil, http.MethodGet, "../../../_ping", nil)
   145  	if err != nil {
   146  		return err
   147  	}
   148  	if response.StatusCode == http.StatusOK {
   149  		return nil
   150  	}
   151  	return errors.Errorf("ping response was %q", response.StatusCode)
   152  }
   153  
   154  func sshClient(_url *url.URL, identity string, secure bool) (*http.Client, error) {
   155  	auth, err := publicKey(identity)
   156  	if err != nil {
   157  		return nil, errors.Wrapf(err, "Failed to parse identity %s: %v\n", _url.String(), identity)
   158  	}
   159  
   160  	callback := ssh.InsecureIgnoreHostKey()
   161  	if secure {
   162  		key := hostKey(_url.Hostname())
   163  		if key != nil {
   164  			callback = ssh.FixedHostKey(key)
   165  		}
   166  	}
   167  
   168  	port := _url.Port()
   169  	if port == "" {
   170  		port = "22"
   171  	}
   172  
   173  	bastion, err := ssh.Dial("tcp",
   174  		net.JoinHostPort(_url.Hostname(), port),
   175  		&ssh.ClientConfig{
   176  			User:            _url.User.Username(),
   177  			Auth:            []ssh.AuthMethod{auth},
   178  			HostKeyCallback: callback,
   179  			HostKeyAlgorithms: []string{
   180  				ssh.KeyAlgoRSA,
   181  				ssh.KeyAlgoDSA,
   182  				ssh.KeyAlgoECDSA256,
   183  				ssh.KeyAlgoECDSA384,
   184  				ssh.KeyAlgoECDSA521,
   185  				ssh.KeyAlgoED25519,
   186  			},
   187  			Timeout: 5 * time.Second,
   188  		},
   189  	)
   190  	if err != nil {
   191  		return nil, errors.Wrapf(err, "Connection to bastion host (%s) failed.", _url.String())
   192  	}
   193  	return &http.Client{
   194  		Transport: &http.Transport{
   195  			DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
   196  				return bastion.Dial("unix", _url.Path)
   197  			},
   198  		}}, nil
   199  }
   200  
   201  func unixClient(_url *url.URL) (*http.Client, error) {
   202  	return &http.Client{
   203  		Transport: &http.Transport{
   204  			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
   205  				d := net.Dialer{}
   206  				return d.DialContext(ctx, "unix", _url.Path)
   207  			},
   208  			DisableCompression: true,
   209  		},
   210  	}, nil
   211  }
   212  
   213  // DoRequest assembles the http request and returns the response
   214  func (c *Connection) DoRequest(httpBody io.Reader, httpMethod, endpoint string, queryParams url.Values, pathValues ...string) (*APIResponse, error) {
   215  	var (
   216  		err      error
   217  		response *http.Response
   218  	)
   219  	safePathValues := make([]interface{}, len(pathValues))
   220  	// Make sure path values are http url safe
   221  	for i, pv := range pathValues {
   222  		safePathValues[i] = url.PathEscape(pv)
   223  	}
   224  	// Lets eventually use URL for this which might lead to safer
   225  	// usage
   226  	safeEndpoint := fmt.Sprintf(endpoint, safePathValues...)
   227  	e := basePath.String() + safeEndpoint
   228  	req, err := http.NewRequest(httpMethod, e, httpBody)
   229  	if err != nil {
   230  		return nil, err
   231  	}
   232  	if len(queryParams) > 0 {
   233  		req.URL.RawQuery = queryParams.Encode()
   234  	}
   235  	// Give the Do three chances in the case of a comm/service hiccup
   236  	for i := 0; i < 3; i++ {
   237  		response, err = c.client.Do(req) // nolint
   238  		if err == nil {
   239  			break
   240  		}
   241  		time.Sleep(time.Duration(i*100) * time.Millisecond)
   242  	}
   243  	return &APIResponse{response, req}, err
   244  }
   245  
   246  // FiltersToString converts our typical filter format of a
   247  // map[string][]string to a query/html safe string.
   248  func FiltersToString(filters map[string][]string) (string, error) {
   249  	lowerCaseKeys := make(map[string][]string)
   250  	for k, v := range filters {
   251  		lowerCaseKeys[strings.ToLower(k)] = v
   252  	}
   253  	return jsoniter.MarshalToString(lowerCaseKeys)
   254  }
   255  
   256  // IsInformation returns true if the response code is 1xx
   257  func (h *APIResponse) IsInformational() bool {
   258  	return h.Response.StatusCode/100 == 1
   259  }
   260  
   261  // IsSuccess returns true if the response code is 2xx
   262  func (h *APIResponse) IsSuccess() bool {
   263  	return h.Response.StatusCode/100 == 2
   264  }
   265  
   266  // IsRedirection returns true if the response code is 3xx
   267  func (h *APIResponse) IsRedirection() bool {
   268  	return h.Response.StatusCode/100 == 3
   269  }
   270  
   271  // IsClientError returns true if the response code is 4xx
   272  func (h *APIResponse) IsClientError() bool {
   273  	return h.Response.StatusCode/100 == 4
   274  }
   275  
   276  // IsServerError returns true if the response code is 5xx
   277  func (h *APIResponse) IsServerError() bool {
   278  	return h.Response.StatusCode/100 == 5
   279  }
   280  
   281  func publicKey(path string) (ssh.AuthMethod, error) {
   282  	key, err := ioutil.ReadFile(path)
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  
   287  	signer, err := ssh.ParsePrivateKey(key)
   288  	if err != nil {
   289  		return nil, err
   290  	}
   291  
   292  	return ssh.PublicKeys(signer), nil
   293  }
   294  
   295  func hostKey(host string) ssh.PublicKey {
   296  	// parse OpenSSH known_hosts file
   297  	// ssh or use ssh-keyscan to get initial key
   298  	known_hosts := filepath.Join(homedir.HomeDir(), ".ssh", "known_hosts")
   299  	fd, err := os.Open(known_hosts)
   300  	if err != nil {
   301  		logrus.Error(err)
   302  		return nil
   303  	}
   304  
   305  	scanner := bufio.NewScanner(fd)
   306  	for scanner.Scan() {
   307  		_, hosts, key, _, _, err := ssh.ParseKnownHosts(scanner.Bytes())
   308  		if err != nil {
   309  			logrus.Errorf("Failed to parse known_hosts: %s", scanner.Text())
   310  			continue
   311  		}
   312  
   313  		for _, h := range hosts {
   314  			if h == host {
   315  				return key
   316  			}
   317  		}
   318  	}
   319  
   320  	return nil
   321  }