github.com/juju/juju@v0.0.0-20240327075706-a90865de2538/container/lxd/connection.go (about)

     1  // Copyright 2018 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package lxd
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"net/http"
    10  	"net/url"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  
    15  	lxd "github.com/canonical/lxd/client"
    16  	"github.com/juju/errors"
    17  )
    18  
    19  type Protocol string
    20  
    21  const (
    22  	LXDProtocol           Protocol = "lxd"
    23  	SimpleStreamsProtocol Protocol = "simplestreams"
    24  )
    25  
    26  // ServerSpec describes the location and connection details for a
    27  // server utilized in LXD workflows.
    28  type ServerSpec struct {
    29  	Name           string
    30  	Host           string
    31  	Protocol       Protocol
    32  	connectionArgs *lxd.ConnectionArgs
    33  }
    34  
    35  // ProxyFunc defines a function that can act as a proxy for requests
    36  type ProxyFunc func(*http.Request) (*url.URL, error)
    37  
    38  // NewServerSpec creates a ServerSpec with default values where needed.
    39  // It also ensures the HTTPS for the host implicitly
    40  func NewServerSpec(host, serverCert string, clientCert *Certificate) ServerSpec {
    41  	return ServerSpec{
    42  		Host: EnsureHTTPS(host),
    43  		connectionArgs: &lxd.ConnectionArgs{
    44  			TLSServerCert: serverCert,
    45  			TLSClientCert: string(clientCert.CertPEM),
    46  			TLSClientKey:  string(clientCert.KeyPEM),
    47  		},
    48  	}
    49  }
    50  
    51  // WithProxy adds the optional proxy to the server spec.
    52  // Returns the ServerSpec to enable chaining of optional values
    53  func (s ServerSpec) WithProxy(proxy ProxyFunc) ServerSpec {
    54  	s.connectionArgs.Proxy = proxy
    55  	return s
    56  }
    57  
    58  // WithClientCertificate adds the optional client Certificate to the server
    59  // spec.
    60  // Returns the ServerSpec to enable chaining of optional values
    61  func (s ServerSpec) WithClientCertificate(clientCert *Certificate) ServerSpec {
    62  	s.connectionArgs.TLSClientCert = string(clientCert.CertPEM)
    63  	s.connectionArgs.TLSClientKey = string(clientCert.KeyPEM)
    64  	return s
    65  }
    66  
    67  // WithSkipGetServer adds the option skipping of the get server verification to
    68  // the server spec.
    69  func (s ServerSpec) WithSkipGetServer(b bool) ServerSpec {
    70  	s.connectionArgs.SkipGetServer = b
    71  	return s
    72  }
    73  
    74  // WithHTTPClient adds the option of passing the http client to the server spec.
    75  func (s ServerSpec) WithHTTPClient(client *http.Client) ServerSpec {
    76  	s.connectionArgs.HTTPClient = client
    77  	return s
    78  }
    79  
    80  // NewInsecureServerSpec creates a ServerSpec without certificate requirements,
    81  // which also bypasses the TLS verification.
    82  // It also ensures the HTTPS for the host implicitly
    83  func NewInsecureServerSpec(host string) ServerSpec {
    84  	return ServerSpec{
    85  		Host: EnsureHTTPS(host),
    86  		connectionArgs: &lxd.ConnectionArgs{
    87  			InsecureSkipVerify: true,
    88  		},
    89  	}
    90  }
    91  
    92  // MakeSimpleStreamsServerSpec creates a ServerSpec for the SimpleStreams
    93  // protocol, ensuring that the host is HTTPS
    94  func MakeSimpleStreamsServerSpec(name, host string) ServerSpec {
    95  	return ServerSpec{
    96  		Name:     name,
    97  		Host:     EnsureHTTPS(host),
    98  		Protocol: SimpleStreamsProtocol,
    99  	}
   100  }
   101  
   102  // Validate ensures that the ServerSpec is valid.
   103  func (s *ServerSpec) Validate() error {
   104  	return nil
   105  }
   106  
   107  // CloudImagesRemote hosts releases blessed by the Canonical team.
   108  var CloudImagesRemote = ServerSpec{
   109  	Name:     "cloud-images.ubuntu.com",
   110  	Host:     "https://cloud-images.ubuntu.com/releases",
   111  	Protocol: SimpleStreamsProtocol,
   112  }
   113  
   114  // CloudImagesDailyRemote hosts images from daily package builds.
   115  // These images have not been independently tested, but should be sound for
   116  // use, being build from packages in the released archive.
   117  var CloudImagesDailyRemote = ServerSpec{
   118  	Name:     "cloud-images.ubuntu.com",
   119  	Host:     "https://cloud-images.ubuntu.com/daily",
   120  	Protocol: SimpleStreamsProtocol,
   121  }
   122  
   123  // CloudImagesLinuxContainersRemote hosts images for other distributions.
   124  // These will be used for pulling CentOS images.
   125  var CloudImagesLinuxContainersRemote = ServerSpec{
   126  	Name:     "images.linuxcontainers.org",
   127  	Host:     "https://images.linuxcontainers.org",
   128  	Protocol: SimpleStreamsProtocol,
   129  }
   130  
   131  // ConnectImageRemote connects to a remote ImageServer using specified protocol.
   132  var ConnectImageRemote = connectImageRemote
   133  
   134  func connectImageRemote(ctx context.Context, remote ServerSpec) (lxd.ImageServer, error) {
   135  	switch remote.Protocol {
   136  	case LXDProtocol:
   137  		return lxd.ConnectPublicLXDWithContext(ctx, remote.Host, remote.connectionArgs)
   138  	case SimpleStreamsProtocol:
   139  		return lxd.ConnectSimpleStreams(remote.Host, remote.connectionArgs)
   140  	}
   141  	return nil, fmt.Errorf("bad protocol supplied for connection: %v", remote.Protocol)
   142  }
   143  
   144  func connectLocal() (lxd.InstanceServer, error) {
   145  	client, err := lxd.ConnectLXDUnix(SocketPath(IsUnixSocket), nil)
   146  	return client, errors.Trace(err)
   147  }
   148  
   149  // ConnectRemote connects to LXD on a remote socket.
   150  func ConnectRemote(spec ServerSpec) (lxd.InstanceServer, error) {
   151  	// Ensure the Port on the Host, if we get an error it is reasonable to
   152  	// assume that the address in the spec is invalid.
   153  	uri, err := EnsureHostPort(spec.Host)
   154  	if err != nil {
   155  		return nil, errors.Trace(err)
   156  	}
   157  	client, err := lxd.ConnectLXD(uri, spec.connectionArgs)
   158  	return client, errors.Trace(err)
   159  }
   160  
   161  // SocketPath returns the path to the local LXD socket.
   162  // The following are tried in order of preference:
   163  //   - LXD_DIR environment variable.
   164  //   - Snap socket.
   165  //   - Debian socket.
   166  //
   167  // An empty string is returned if no socket path can be determined.
   168  func SocketPath(isSocket func(path string) bool) string {
   169  	for _, maybePath := range []string{
   170  		os.Getenv("LXD_DIR"),
   171  		filepath.FromSlash("/var/snap/lxd/common/lxd"),
   172  		filepath.FromSlash("/var/lib/lxd"),
   173  	} {
   174  		if maybePath == "" {
   175  			continue
   176  		}
   177  
   178  		maybePath = filepath.Join(maybePath, "unix.socket")
   179  		if isSocket(maybePath) {
   180  			logger.Debugf("using LXD socket at path: %q", maybePath)
   181  			return maybePath
   182  		}
   183  	}
   184  
   185  	logger.Debugf("unable to detect LXD socket path")
   186  	return ""
   187  }
   188  
   189  // EnsureHTTPS takes a URI and ensures that it is a HTTPS URL.
   190  // LXD Requires HTTPS.
   191  func EnsureHTTPS(address string) string {
   192  	if strings.HasPrefix(address, "https://") {
   193  		return address
   194  	}
   195  	if strings.HasPrefix(address, "http://") {
   196  		addr := strings.Replace(address, "http://", "https://", 1)
   197  		logger.Debugf("LXD requires https://, using: %s", addr)
   198  		return addr
   199  	}
   200  	return "https://" + address
   201  }
   202  
   203  const defaultPort = 8443
   204  
   205  // EnsureHostPort takes a URI and ensures that it has a port set, if it doesn't
   206  // then it will ensure that port if added.
   207  // The address supplied for the Host will be validated when parsed and if the
   208  // address is not valid, then it will return an error.
   209  func EnsureHostPort(address string) (string, error) {
   210  	// make sure we ensure a schema, otherwise somewhere:8443 can return a
   211  	// the following //:8443/somewhere
   212  	uri, err := url.Parse(EnsureHTTPS(address))
   213  	if err != nil {
   214  		return "", errors.Trace(err)
   215  	}
   216  	if uri.Port() == "" {
   217  		uri.Host = fmt.Sprintf("%s:%d", uri.Host, defaultPort)
   218  	}
   219  	return strings.TrimRight(uri.String(), "/"), nil
   220  }