github.com/hashicorp/vault/sdk@v0.11.0/helper/docker/testhelpers.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package docker
     5  
import (
	"archive/tar"
	"bufio"
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/url"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/cenkalti/backoff/v3"
	"github.com/docker/docker/api/types"
	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/filters"
	"github.com/docker/docker/api/types/mount"
	"github.com/docker/docker/api/types/network"
	"github.com/docker/docker/api/types/strslice"
	"github.com/docker/docker/client"
	"github.com/docker/docker/pkg/archive"
	"github.com/docker/docker/pkg/stdcopy"
	"github.com/docker/go-connections/nat"
	"github.com/hashicorp/go-uuid"
)
    36  
    37  const DockerAPIVersion = "1.40"
    38  
    39  type Runner struct {
    40  	DockerAPI  *client.Client
    41  	RunOptions RunOptions
    42  }
    43  
    44  type RunOptions struct {
    45  	ImageRepo              string
    46  	ImageTag               string
    47  	ContainerName          string
    48  	Cmd                    []string
    49  	Entrypoint             []string
    50  	Env                    []string
    51  	NetworkName            string
    52  	NetworkID              string
    53  	CopyFromTo             map[string]string
    54  	Ports                  []string
    55  	DoNotAutoRemove        bool
    56  	AuthUsername           string
    57  	AuthPassword           string
    58  	OmitLogTimestamps      bool
    59  	LogConsumer            func(string)
    60  	Capabilities           []string
    61  	PreDelete              bool
    62  	PostStart              func(string, string) error
    63  	LogStderr              io.Writer
    64  	LogStdout              io.Writer
    65  	VolumeNameToMountPoint map[string]string
    66  }
    67  
    68  func NewDockerAPI() (*client.Client, error) {
    69  	return client.NewClientWithOpts(client.FromEnv, client.WithVersion(DockerAPIVersion))
    70  }
    71  
    72  func NewServiceRunner(opts RunOptions) (*Runner, error) {
    73  	dapi, err := NewDockerAPI()
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	if opts.NetworkName == "" {
    79  		opts.NetworkName = os.Getenv("TEST_DOCKER_NETWORK_NAME")
    80  	}
    81  	if opts.NetworkName != "" {
    82  		nets, err := dapi.NetworkList(context.TODO(), types.NetworkListOptions{
    83  			Filters: filters.NewArgs(filters.Arg("name", opts.NetworkName)),
    84  		})
    85  		if err != nil {
    86  			return nil, err
    87  		}
    88  		if len(nets) != 1 {
    89  			return nil, fmt.Errorf("expected exactly one docker network named %q, got %d", opts.NetworkName, len(nets))
    90  		}
    91  		opts.NetworkID = nets[0].ID
    92  	}
    93  	if opts.NetworkID == "" {
    94  		opts.NetworkID = os.Getenv("TEST_DOCKER_NETWORK_ID")
    95  	}
    96  	if opts.ContainerName == "" {
    97  		if strings.Contains(opts.ImageRepo, "/") {
    98  			return nil, fmt.Errorf("ContainerName is required for non-library images")
    99  		}
   100  		// If there's no slash in the repo it's almost certainly going to be
   101  		// a good container name.
   102  		opts.ContainerName = opts.ImageRepo
   103  	}
   104  	return &Runner{
   105  		DockerAPI:  dapi,
   106  		RunOptions: opts,
   107  	}, nil
   108  }
   109  
   110  type ServiceConfig interface {
   111  	Address() string
   112  	URL() *url.URL
   113  }
   114  
   115  func NewServiceHostPort(host string, port int) *ServiceHostPort {
   116  	return &ServiceHostPort{address: fmt.Sprintf("%s:%d", host, port)}
   117  }
   118  
   119  func NewServiceHostPortParse(s string) (*ServiceHostPort, error) {
   120  	pieces := strings.Split(s, ":")
   121  	if len(pieces) != 2 {
   122  		return nil, fmt.Errorf("address must be of the form host:port, got: %v", s)
   123  	}
   124  
   125  	port, err := strconv.Atoi(pieces[1])
   126  	if err != nil || port < 1 {
   127  		return nil, fmt.Errorf("address must be of the form host:port, got: %v", s)
   128  	}
   129  
   130  	return &ServiceHostPort{s}, nil
   131  }
   132  
   133  type ServiceHostPort struct {
   134  	address string
   135  }
   136  
   137  func (s ServiceHostPort) Address() string {
   138  	return s.address
   139  }
   140  
   141  func (s ServiceHostPort) URL() *url.URL {
   142  	return &url.URL{Host: s.address}
   143  }
   144  
   145  func NewServiceURLParse(s string) (*ServiceURL, error) {
   146  	u, err := url.Parse(s)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  	return &ServiceURL{u: *u}, nil
   151  }
   152  
   153  func NewServiceURL(u url.URL) *ServiceURL {
   154  	return &ServiceURL{u: u}
   155  }
   156  
   157  type ServiceURL struct {
   158  	u url.URL
   159  }
   160  
   161  func (s ServiceURL) Address() string {
   162  	return s.u.Host
   163  }
   164  
   165  func (s ServiceURL) URL() *url.URL {
   166  	return &s.u
   167  }
   168  
   169  // ServiceAdapter verifies connectivity to the service, then returns either the
   170  // connection string (typically a URL) and nil, or empty string and an error.
   171  type ServiceAdapter func(ctx context.Context, host string, port int) (ServiceConfig, error)
   172  
   173  // StartService will start the runner's configured docker container with a
   174  // random UUID suffix appended to the name to make it unique and will return
   175  // either a hostname or local address depending on if a Docker network was given.
   176  //
   177  // Most tests can default to using this.
   178  func (d *Runner) StartService(ctx context.Context, connect ServiceAdapter) (*Service, error) {
   179  	serv, _, err := d.StartNewService(ctx, true, false, connect)
   180  
   181  	return serv, err
   182  }
   183  
   184  type LogConsumerWriter struct {
   185  	consumer func(string)
   186  }
   187  
   188  func (l LogConsumerWriter) Write(p []byte) (n int, err error) {
   189  	// TODO this assumes that we're never passed partial log lines, which
   190  	// seems a safe assumption for now based on how docker looks to implement
   191  	// logging, but might change in the future.
   192  	scanner := bufio.NewScanner(bytes.NewReader(p))
   193  	scanner.Buffer(make([]byte, 64*1024), bufio.MaxScanTokenSize)
   194  	for scanner.Scan() {
   195  		l.consumer(scanner.Text())
   196  	}
   197  	return len(p), nil
   198  }
   199  
   200  var _ io.Writer = &LogConsumerWriter{}
   201  
   202  // StartNewService will start the runner's configured docker container but with the
   203  // ability to control adding a name suffix or forcing a local address to be returned.
   204  // 'addSuffix' will add a random UUID to the end of the container name.
   205  // 'forceLocalAddr' will force the container address returned to be in the
   206  // form of '127.0.0.1:1234' where 1234 is the mapped container port.
   207  func (d *Runner) StartNewService(ctx context.Context, addSuffix, forceLocalAddr bool, connect ServiceAdapter) (*Service, string, error) {
   208  	if d.RunOptions.PreDelete {
   209  		name := d.RunOptions.ContainerName
   210  		matches, err := d.DockerAPI.ContainerList(ctx, types.ContainerListOptions{
   211  			All: true,
   212  			// TODO use labels to ensure we don't delete anything we shouldn't
   213  			Filters: filters.NewArgs(
   214  				filters.Arg("name", name),
   215  			),
   216  		})
   217  		if err != nil {
   218  			return nil, "", fmt.Errorf("failed to list containers named %q", name)
   219  		}
   220  		for _, cont := range matches {
   221  			err = d.DockerAPI.ContainerRemove(ctx, cont.ID, types.ContainerRemoveOptions{Force: true})
   222  			if err != nil {
   223  				return nil, "", fmt.Errorf("failed to pre-delete container named %q", name)
   224  			}
   225  		}
   226  	}
   227  	result, err := d.Start(context.Background(), addSuffix, forceLocalAddr)
   228  	if err != nil {
   229  		return nil, "", err
   230  	}
   231  
   232  	// The waitgroup wg is used here to support some stuff in NewDockerCluster.
   233  	// We can't generate the PKI cert for the https listener until we know the
   234  	// container's address, meaning we must first start the container, then
   235  	// generate the cert, then copy it into the container, then signal Vault
   236  	// to reload its config/certs.  However, if we SIGHUP Vault before Vault
   237  	// has installed its signal handler, that will kill Vault, since the default
   238  	// behaviour for HUP is termination.  So the PostStart that NewDockerCluster
   239  	// passes in (which does all that PKI cert stuff) waits to see output from
   240  	// Vault on stdout/stderr before it sends the signal, and we don't want to
   241  	// run the PostStart until we've hooked into the docker logs.
   242  	var wg sync.WaitGroup
   243  	logConsumer := d.createLogConsumer(result.Container.ID, &wg)
   244  
   245  	if logConsumer != nil {
   246  		wg.Add(1)
   247  		go logConsumer()
   248  	}
   249  	wg.Wait()
   250  
   251  	if d.RunOptions.PostStart != nil {
   252  		if err := d.RunOptions.PostStart(result.Container.ID, result.RealIP); err != nil {
   253  			return nil, "", fmt.Errorf("poststart failed: %w", err)
   254  		}
   255  	}
   256  
   257  	cleanup := func() {
   258  		for i := 0; i < 10; i++ {
   259  			err := d.DockerAPI.ContainerRemove(ctx, result.Container.ID, types.ContainerRemoveOptions{Force: true})
   260  			if err == nil || client.IsErrNotFound(err) {
   261  				return
   262  			}
   263  			time.Sleep(1 * time.Second)
   264  		}
   265  	}
   266  
   267  	bo := backoff.NewExponentialBackOff()
   268  	bo.MaxInterval = time.Second * 5
   269  	bo.MaxElapsedTime = 2 * time.Minute
   270  
   271  	pieces := strings.Split(result.Addrs[0], ":")
   272  	portInt, err := strconv.Atoi(pieces[1])
   273  	if err != nil {
   274  		return nil, "", err
   275  	}
   276  
   277  	var config ServiceConfig
   278  	err = backoff.Retry(func() error {
   279  		container, err := d.DockerAPI.ContainerInspect(ctx, result.Container.ID)
   280  		if err != nil || !container.State.Running {
   281  			return backoff.Permanent(fmt.Errorf("failed inspect or container %q not running: %w", result.Container.ID, err))
   282  		}
   283  
   284  		c, err := connect(ctx, pieces[0], portInt)
   285  		if err != nil {
   286  			return err
   287  		}
   288  		if c == nil {
   289  			return fmt.Errorf("service adapter returned nil error and config")
   290  		}
   291  		config = c
   292  		return nil
   293  	}, bo)
   294  	if err != nil {
   295  		if !d.RunOptions.DoNotAutoRemove {
   296  			cleanup()
   297  		}
   298  		return nil, "", err
   299  	}
   300  
   301  	return &Service{
   302  		Config:      config,
   303  		Cleanup:     cleanup,
   304  		Container:   result.Container,
   305  		StartResult: result,
   306  	}, result.Container.ID, nil
   307  }
   308  
   309  // createLogConsumer returns a function to consume the logs of the container with the given ID.
   310  // If a wait group is given, `WaitGroup.Done()` will be called as soon as the call to the
   311  // ContainerLogs Docker API call is done.
   312  // The returned function will block, so it should be run on a goroutine.
   313  func (d *Runner) createLogConsumer(containerId string, wg *sync.WaitGroup) func() {
   314  	if d.RunOptions.LogStdout != nil && d.RunOptions.LogStderr != nil {
   315  		return func() {
   316  			d.consumeLogs(containerId, wg, d.RunOptions.LogStdout, d.RunOptions.LogStderr)
   317  		}
   318  	}
   319  	if d.RunOptions.LogConsumer != nil {
   320  		return func() {
   321  			d.consumeLogs(containerId, wg, &LogConsumerWriter{d.RunOptions.LogConsumer}, &LogConsumerWriter{d.RunOptions.LogConsumer})
   322  		}
   323  	}
   324  	return nil
   325  }
   326  
   327  // consumeLogs is the function called by the function returned by createLogConsumer.
   328  func (d *Runner) consumeLogs(containerId string, wg *sync.WaitGroup, logStdout, logStderr io.Writer) {
   329  	// We must run inside a goroutine because we're using Follow:true,
   330  	// and StdCopy will block until the log stream is closed.
   331  	stream, err := d.DockerAPI.ContainerLogs(context.Background(), containerId, types.ContainerLogsOptions{
   332  		ShowStdout: true,
   333  		ShowStderr: true,
   334  		Timestamps: !d.RunOptions.OmitLogTimestamps,
   335  		Details:    true,
   336  		Follow:     true,
   337  	})
   338  	wg.Done()
   339  	if err != nil {
   340  		d.RunOptions.LogConsumer(fmt.Sprintf("error reading container logs: %v", err))
   341  	} else {
   342  		_, err := stdcopy.StdCopy(logStdout, logStderr, stream)
   343  		if err != nil {
   344  			d.RunOptions.LogConsumer(fmt.Sprintf("error demultiplexing docker logs: %v", err))
   345  		}
   346  	}
   347  }
   348  
   349  type Service struct {
   350  	Config      ServiceConfig
   351  	Cleanup     func()
   352  	Container   *types.ContainerJSON
   353  	StartResult *StartResult
   354  }
   355  
   356  type StartResult struct {
   357  	Container *types.ContainerJSON
   358  	Addrs     []string
   359  	RealIP    string
   360  }
   361  
   362  func (d *Runner) Start(ctx context.Context, addSuffix, forceLocalAddr bool) (*StartResult, error) {
   363  	name := d.RunOptions.ContainerName
   364  	if addSuffix {
   365  		suffix, err := uuid.GenerateUUID()
   366  		if err != nil {
   367  			return nil, err
   368  		}
   369  		name += "-" + suffix
   370  	}
   371  
   372  	cfg := &container.Config{
   373  		Hostname: name,
   374  		Image:    fmt.Sprintf("%s:%s", d.RunOptions.ImageRepo, d.RunOptions.ImageTag),
   375  		Env:      d.RunOptions.Env,
   376  		Cmd:      d.RunOptions.Cmd,
   377  	}
   378  	if len(d.RunOptions.Ports) > 0 {
   379  		cfg.ExposedPorts = make(map[nat.Port]struct{})
   380  		for _, p := range d.RunOptions.Ports {
   381  			cfg.ExposedPorts[nat.Port(p)] = struct{}{}
   382  		}
   383  	}
   384  	if len(d.RunOptions.Entrypoint) > 0 {
   385  		cfg.Entrypoint = strslice.StrSlice(d.RunOptions.Entrypoint)
   386  	}
   387  
   388  	hostConfig := &container.HostConfig{
   389  		AutoRemove:      !d.RunOptions.DoNotAutoRemove,
   390  		PublishAllPorts: true,
   391  	}
   392  	if len(d.RunOptions.Capabilities) > 0 {
   393  		hostConfig.CapAdd = d.RunOptions.Capabilities
   394  	}
   395  
   396  	netConfig := &network.NetworkingConfig{}
   397  	if d.RunOptions.NetworkID != "" {
   398  		netConfig.EndpointsConfig = map[string]*network.EndpointSettings{
   399  			d.RunOptions.NetworkID: {},
   400  		}
   401  	}
   402  
   403  	// best-effort pull
   404  	var opts types.ImageCreateOptions
   405  	if d.RunOptions.AuthUsername != "" && d.RunOptions.AuthPassword != "" {
   406  		var buf bytes.Buffer
   407  		auth := map[string]string{
   408  			"username": d.RunOptions.AuthUsername,
   409  			"password": d.RunOptions.AuthPassword,
   410  		}
   411  		if err := json.NewEncoder(&buf).Encode(auth); err != nil {
   412  			return nil, err
   413  		}
   414  		opts.RegistryAuth = base64.URLEncoding.EncodeToString(buf.Bytes())
   415  	}
   416  	resp, _ := d.DockerAPI.ImageCreate(ctx, cfg.Image, opts)
   417  	if resp != nil {
   418  		_, _ = ioutil.ReadAll(resp)
   419  	}
   420  
   421  	for vol, mtpt := range d.RunOptions.VolumeNameToMountPoint {
   422  		hostConfig.Mounts = append(hostConfig.Mounts, mount.Mount{
   423  			Type:     "volume",
   424  			Source:   vol,
   425  			Target:   mtpt,
   426  			ReadOnly: false,
   427  		})
   428  	}
   429  
   430  	c, err := d.DockerAPI.ContainerCreate(ctx, cfg, hostConfig, netConfig, nil, cfg.Hostname)
   431  	if err != nil {
   432  		return nil, fmt.Errorf("container create failed: %v", err)
   433  	}
   434  
   435  	for from, to := range d.RunOptions.CopyFromTo {
   436  		if err := copyToContainer(ctx, d.DockerAPI, c.ID, from, to); err != nil {
   437  			_ = d.DockerAPI.ContainerRemove(ctx, c.ID, types.ContainerRemoveOptions{})
   438  			return nil, err
   439  		}
   440  	}
   441  
   442  	err = d.DockerAPI.ContainerStart(ctx, c.ID, types.ContainerStartOptions{})
   443  	if err != nil {
   444  		_ = d.DockerAPI.ContainerRemove(ctx, c.ID, types.ContainerRemoveOptions{})
   445  		return nil, fmt.Errorf("container start failed: %v", err)
   446  	}
   447  
   448  	inspect, err := d.DockerAPI.ContainerInspect(ctx, c.ID)
   449  	if err != nil {
   450  		_ = d.DockerAPI.ContainerRemove(ctx, c.ID, types.ContainerRemoveOptions{})
   451  		return nil, err
   452  	}
   453  
   454  	var addrs []string
   455  	for _, port := range d.RunOptions.Ports {
   456  		pieces := strings.Split(port, "/")
   457  		if len(pieces) < 2 {
   458  			return nil, fmt.Errorf("expected port of the form 1234/tcp, got: %s", port)
   459  		}
   460  		if d.RunOptions.NetworkID != "" && !forceLocalAddr {
   461  			addrs = append(addrs, fmt.Sprintf("%s:%s", cfg.Hostname, pieces[0]))
   462  		} else {
   463  			mapped, ok := inspect.NetworkSettings.Ports[nat.Port(port)]
   464  			if !ok || len(mapped) == 0 {
   465  				return nil, fmt.Errorf("no port mapping found for %s", port)
   466  			}
   467  			addrs = append(addrs, fmt.Sprintf("127.0.0.1:%s", mapped[0].HostPort))
   468  		}
   469  	}
   470  
   471  	var realIP string
   472  	if d.RunOptions.NetworkID == "" {
   473  		if len(inspect.NetworkSettings.Networks) > 1 {
   474  			return nil, fmt.Errorf("Set d.RunOptions.NetworkName instead for container with multiple networks: %v", inspect.NetworkSettings.Networks)
   475  		}
   476  		for _, network := range inspect.NetworkSettings.Networks {
   477  			realIP = network.IPAddress
   478  			break
   479  		}
   480  	} else {
   481  		realIP = inspect.NetworkSettings.Networks[d.RunOptions.NetworkName].IPAddress
   482  	}
   483  
   484  	return &StartResult{
   485  		Container: &inspect,
   486  		Addrs:     addrs,
   487  		RealIP:    realIP,
   488  	}, nil
   489  }
   490  
   491  func (d *Runner) RefreshFiles(ctx context.Context, containerID string) error {
   492  	for from, to := range d.RunOptions.CopyFromTo {
   493  		if err := copyToContainer(ctx, d.DockerAPI, containerID, from, to); err != nil {
   494  			// TODO too drastic?
   495  			_ = d.DockerAPI.ContainerRemove(ctx, containerID, types.ContainerRemoveOptions{})
   496  			return err
   497  		}
   498  	}
   499  	return d.DockerAPI.ContainerKill(ctx, containerID, "SIGHUP")
   500  }
   501  
   502  func (d *Runner) Stop(ctx context.Context, containerID string) error {
   503  	if d.RunOptions.NetworkID != "" {
   504  		if err := d.DockerAPI.NetworkDisconnect(ctx, d.RunOptions.NetworkID, containerID, true); err != nil {
   505  			return fmt.Errorf("error disconnecting network (%v): %v", d.RunOptions.NetworkID, err)
   506  		}
   507  	}
   508  
   509  	// timeout in seconds
   510  	timeout := 5
   511  	options := container.StopOptions{
   512  		Timeout: &timeout,
   513  	}
   514  	if err := d.DockerAPI.ContainerStop(ctx, containerID, options); err != nil {
   515  		return fmt.Errorf("error stopping container: %v", err)
   516  	}
   517  
   518  	return nil
   519  }
   520  
   521  func (d *Runner) RestartContainerWithTimeout(ctx context.Context, containerID string, timeout int) error {
   522  	err := d.DockerAPI.ContainerRestart(ctx, containerID, container.StopOptions{Timeout: &timeout})
   523  	if err != nil {
   524  		return fmt.Errorf("failed to restart container: %s", err)
   525  	}
   526  	var wg sync.WaitGroup
   527  	logConsumer := d.createLogConsumer(containerID, &wg)
   528  	if logConsumer != nil {
   529  		wg.Add(1)
   530  		go logConsumer()
   531  	}
   532  	// we don't really care about waiting for logs to start showing up, do we?
   533  	return nil
   534  }
   535  
   536  func (d *Runner) Restart(ctx context.Context, containerID string) error {
   537  	if err := d.DockerAPI.ContainerStart(ctx, containerID, types.ContainerStartOptions{}); err != nil {
   538  		return err
   539  	}
   540  
   541  	ends := &network.EndpointSettings{
   542  		NetworkID: d.RunOptions.NetworkID,
   543  	}
   544  
   545  	return d.DockerAPI.NetworkConnect(ctx, d.RunOptions.NetworkID, containerID, ends)
   546  }
   547  
   548  func copyToContainer(ctx context.Context, dapi *client.Client, containerID, from, to string) error {
   549  	srcInfo, err := archive.CopyInfoSourcePath(from, false)
   550  	if err != nil {
   551  		return fmt.Errorf("error copying from source %q: %v", from, err)
   552  	}
   553  
   554  	srcArchive, err := archive.TarResource(srcInfo)
   555  	if err != nil {
   556  		return fmt.Errorf("error creating tar from source %q: %v", from, err)
   557  	}
   558  	defer srcArchive.Close()
   559  
   560  	dstInfo := archive.CopyInfo{Path: to}
   561  
   562  	dstDir, content, err := archive.PrepareArchiveCopy(srcArchive, srcInfo, dstInfo)
   563  	if err != nil {
   564  		return fmt.Errorf("error preparing copy from %q -> %q: %v", from, to, err)
   565  	}
   566  	defer content.Close()
   567  	err = dapi.CopyToContainer(ctx, containerID, dstDir, content, types.CopyToContainerOptions{})
   568  	if err != nil {
   569  		return fmt.Errorf("error copying from %q -> %q: %v", from, to, err)
   570  	}
   571  
   572  	return nil
   573  }
   574  
   575  type RunCmdOpt interface {
   576  	Apply(cfg *types.ExecConfig) error
   577  }
   578  
   579  type RunCmdUser string
   580  
   581  var _ RunCmdOpt = (*RunCmdUser)(nil)
   582  
   583  func (u RunCmdUser) Apply(cfg *types.ExecConfig) error {
   584  	cfg.User = string(u)
   585  	return nil
   586  }
   587  
   588  func (d *Runner) RunCmdWithOutput(ctx context.Context, container string, cmd []string, opts ...RunCmdOpt) ([]byte, []byte, int, error) {
   589  	return RunCmdWithOutput(d.DockerAPI, ctx, container, cmd, opts...)
   590  }
   591  
   592  func RunCmdWithOutput(api *client.Client, ctx context.Context, container string, cmd []string, opts ...RunCmdOpt) ([]byte, []byte, int, error) {
   593  	runCfg := types.ExecConfig{
   594  		AttachStdout: true,
   595  		AttachStderr: true,
   596  		Cmd:          cmd,
   597  	}
   598  
   599  	for index, opt := range opts {
   600  		if err := opt.Apply(&runCfg); err != nil {
   601  			return nil, nil, -1, fmt.Errorf("error applying option (%d / %v): %w", index, opt, err)
   602  		}
   603  	}
   604  
   605  	ret, err := api.ContainerExecCreate(ctx, container, runCfg)
   606  	if err != nil {
   607  		return nil, nil, -1, fmt.Errorf("error creating execution environment: %v\ncfg: %v\n", err, runCfg)
   608  	}
   609  
   610  	resp, err := api.ContainerExecAttach(ctx, ret.ID, types.ExecStartCheck{})
   611  	if err != nil {
   612  		return nil, nil, -1, fmt.Errorf("error attaching to command execution: %v\ncfg: %v\nret: %v\n", err, runCfg, ret)
   613  	}
   614  	defer resp.Close()
   615  
   616  	var stdoutB bytes.Buffer
   617  	var stderrB bytes.Buffer
   618  	if _, err := stdcopy.StdCopy(&stdoutB, &stderrB, resp.Reader); err != nil {
   619  		return nil, nil, -1, fmt.Errorf("error reading command output: %v", err)
   620  	}
   621  
   622  	stdout := stdoutB.Bytes()
   623  	stderr := stderrB.Bytes()
   624  
   625  	// Fetch return code.
   626  	info, err := api.ContainerExecInspect(ctx, ret.ID)
   627  	if err != nil {
   628  		return stdout, stderr, -1, fmt.Errorf("error reading command exit code: %v", err)
   629  	}
   630  
   631  	return stdout, stderr, info.ExitCode, nil
   632  }
   633  
   634  func (d *Runner) RunCmdInBackground(ctx context.Context, container string, cmd []string, opts ...RunCmdOpt) (string, error) {
   635  	return RunCmdInBackground(d.DockerAPI, ctx, container, cmd, opts...)
   636  }
   637  
   638  func RunCmdInBackground(api *client.Client, ctx context.Context, container string, cmd []string, opts ...RunCmdOpt) (string, error) {
   639  	runCfg := types.ExecConfig{
   640  		AttachStdout: true,
   641  		AttachStderr: true,
   642  		Cmd:          cmd,
   643  	}
   644  
   645  	for index, opt := range opts {
   646  		if err := opt.Apply(&runCfg); err != nil {
   647  			return "", fmt.Errorf("error applying option (%d / %v): %w", index, opt, err)
   648  		}
   649  	}
   650  
   651  	ret, err := api.ContainerExecCreate(ctx, container, runCfg)
   652  	if err != nil {
   653  		return "", fmt.Errorf("error creating execution environment: %w\ncfg: %v\n", err, runCfg)
   654  	}
   655  
   656  	err = api.ContainerExecStart(ctx, ret.ID, types.ExecStartCheck{})
   657  	if err != nil {
   658  		return "", fmt.Errorf("error starting command execution: %w\ncfg: %v\nret: %v\n", err, runCfg, ret)
   659  	}
   660  
   661  	return ret.ID, nil
   662  }
   663  
   664  // Mapping of path->contents
   665  type PathContents interface {
   666  	UpdateHeader(header *tar.Header) error
   667  	Get() ([]byte, error)
   668  	SetMode(mode int64)
   669  	SetOwners(uid int, gid int)
   670  }
   671  
   672  type FileContents struct {
   673  	Data []byte
   674  	Mode int64
   675  	UID  int
   676  	GID  int
   677  }
   678  
   679  func (b FileContents) UpdateHeader(header *tar.Header) error {
   680  	header.Mode = b.Mode
   681  	header.Uid = b.UID
   682  	header.Gid = b.GID
   683  	return nil
   684  }
   685  
   686  func (b FileContents) Get() ([]byte, error) {
   687  	return b.Data, nil
   688  }
   689  
   690  func (b *FileContents) SetMode(mode int64) {
   691  	b.Mode = mode
   692  }
   693  
   694  func (b *FileContents) SetOwners(uid int, gid int) {
   695  	b.UID = uid
   696  	b.GID = gid
   697  }
   698  
   699  func PathContentsFromBytes(data []byte) PathContents {
   700  	return &FileContents{
   701  		Data: data,
   702  		Mode: 0o644,
   703  	}
   704  }
   705  
   706  func PathContentsFromString(data string) PathContents {
   707  	return PathContentsFromBytes([]byte(data))
   708  }
   709  
   710  type BuildContext map[string]PathContents
   711  
   712  func NewBuildContext() BuildContext {
   713  	return BuildContext{}
   714  }
   715  
   716  func BuildContextFromTarball(reader io.Reader) (BuildContext, error) {
   717  	archive := tar.NewReader(reader)
   718  	bCtx := NewBuildContext()
   719  
   720  	for true {
   721  		header, err := archive.Next()
   722  		if err != nil {
   723  			if err == io.EOF {
   724  				break
   725  			}
   726  
   727  			return nil, fmt.Errorf("failed to parse provided tarball: %v", err)
   728  		}
   729  
   730  		data := make([]byte, int(header.Size))
   731  		read, err := archive.Read(data)
   732  		if err != nil {
   733  			return nil, fmt.Errorf("failed to parse read from provided tarball: %v", err)
   734  		}
   735  
   736  		if read != int(header.Size) {
   737  			return nil, fmt.Errorf("unexpectedly short read on tarball: %v of %v", read, header.Size)
   738  		}
   739  
   740  		bCtx[header.Name] = &FileContents{
   741  			Data: data,
   742  			Mode: header.Mode,
   743  			UID:  header.Uid,
   744  			GID:  header.Gid,
   745  		}
   746  	}
   747  
   748  	return bCtx, nil
   749  }
   750  
   751  func (bCtx *BuildContext) ToTarball() (io.Reader, error) {
   752  	var err error
   753  	buffer := new(bytes.Buffer)
   754  	tarBuilder := tar.NewWriter(buffer)
   755  	defer tarBuilder.Close()
   756  
   757  	now := time.Now()
   758  	for filepath, contents := range *bCtx {
   759  		fileHeader := &tar.Header{
   760  			Name:       filepath,
   761  			ModTime:    now,
   762  			AccessTime: now,
   763  			ChangeTime: now,
   764  		}
   765  		if contents == nil && !strings.HasSuffix(filepath, "/") {
   766  			return nil, fmt.Errorf("expected file path (%v) to have trailing / due to nil contents, indicating directory", filepath)
   767  		}
   768  
   769  		if err := contents.UpdateHeader(fileHeader); err != nil {
   770  			return nil, fmt.Errorf("failed to update tar header entry for %v: %w", filepath, err)
   771  		}
   772  
   773  		var rawContents []byte
   774  		if contents != nil {
   775  			rawContents, err = contents.Get()
   776  			if err != nil {
   777  				return nil, fmt.Errorf("failed to get file contents for %v: %w", filepath, err)
   778  			}
   779  
   780  			fileHeader.Size = int64(len(rawContents))
   781  		}
   782  
   783  		if err := tarBuilder.WriteHeader(fileHeader); err != nil {
   784  			return nil, fmt.Errorf("failed to write tar header entry for %v: %w", filepath, err)
   785  		}
   786  
   787  		if contents != nil {
   788  			if _, err := tarBuilder.Write(rawContents); err != nil {
   789  				return nil, fmt.Errorf("failed to write tar file entry for %v: %w", filepath, err)
   790  			}
   791  		}
   792  	}
   793  
   794  	return bytes.NewReader(buffer.Bytes()), nil
   795  }
   796  
   797  type BuildOpt interface {
   798  	Apply(cfg *types.ImageBuildOptions) error
   799  }
   800  
   801  type BuildRemove bool
   802  
   803  var _ BuildOpt = (*BuildRemove)(nil)
   804  
   805  func (u BuildRemove) Apply(cfg *types.ImageBuildOptions) error {
   806  	cfg.Remove = bool(u)
   807  	return nil
   808  }
   809  
   810  type BuildForceRemove bool
   811  
   812  var _ BuildOpt = (*BuildForceRemove)(nil)
   813  
   814  func (u BuildForceRemove) Apply(cfg *types.ImageBuildOptions) error {
   815  	cfg.ForceRemove = bool(u)
   816  	return nil
   817  }
   818  
   819  type BuildPullParent bool
   820  
   821  var _ BuildOpt = (*BuildPullParent)(nil)
   822  
   823  func (u BuildPullParent) Apply(cfg *types.ImageBuildOptions) error {
   824  	cfg.PullParent = bool(u)
   825  	return nil
   826  }
   827  
   828  type BuildArgs map[string]*string
   829  
   830  var _ BuildOpt = (*BuildArgs)(nil)
   831  
   832  func (u BuildArgs) Apply(cfg *types.ImageBuildOptions) error {
   833  	cfg.BuildArgs = u
   834  	return nil
   835  }
   836  
   837  type BuildTags []string
   838  
   839  var _ BuildOpt = (*BuildTags)(nil)
   840  
   841  func (u BuildTags) Apply(cfg *types.ImageBuildOptions) error {
   842  	cfg.Tags = u
   843  	return nil
   844  }
   845  
   846  const containerfilePath = "_containerfile"
   847  
   848  func (d *Runner) BuildImage(ctx context.Context, containerfile string, containerContext BuildContext, opts ...BuildOpt) ([]byte, error) {
   849  	return BuildImage(ctx, d.DockerAPI, containerfile, containerContext, opts...)
   850  }
   851  
   852  func BuildImage(ctx context.Context, api *client.Client, containerfile string, containerContext BuildContext, opts ...BuildOpt) ([]byte, error) {
   853  	var cfg types.ImageBuildOptions
   854  
   855  	// Build container context tarball, provisioning containerfile in.
   856  	containerContext[containerfilePath] = PathContentsFromBytes([]byte(containerfile))
   857  	tar, err := containerContext.ToTarball()
   858  	if err != nil {
   859  		return nil, fmt.Errorf("failed to create build image context tarball: %w", err)
   860  	}
   861  	cfg.Dockerfile = "/" + containerfilePath
   862  
   863  	// Apply all given options
   864  	for index, opt := range opts {
   865  		if err := opt.Apply(&cfg); err != nil {
   866  			return nil, fmt.Errorf("failed to apply option (%d / %v): %w", index, opt, err)
   867  		}
   868  	}
   869  
   870  	resp, err := api.ImageBuild(ctx, tar, cfg)
   871  	if err != nil {
   872  		return nil, fmt.Errorf("failed to build image: %v", err)
   873  	}
   874  
   875  	output, err := io.ReadAll(resp.Body)
   876  	if err != nil {
   877  		return nil, fmt.Errorf("failed to read image build output: %w", err)
   878  	}
   879  
   880  	return output, nil
   881  }
   882  
   883  func (d *Runner) CopyTo(container string, destination string, contents BuildContext) error {
   884  	// XXX: currently we use the default options but we might want to allow
   885  	// modifying cfg.CopyUIDGID in the future.
   886  	var cfg types.CopyToContainerOptions
   887  
   888  	// Convert our provided contents to a tarball to ship up.
   889  	tar, err := contents.ToTarball()
   890  	if err != nil {
   891  		return fmt.Errorf("failed to build contents into tarball: %v", err)
   892  	}
   893  
   894  	return d.DockerAPI.CopyToContainer(context.Background(), container, destination, tar, cfg)
   895  }
   896  
   897  func (d *Runner) CopyFrom(container string, source string) (BuildContext, *types.ContainerPathStat, error) {
   898  	reader, stat, err := d.DockerAPI.CopyFromContainer(context.Background(), container, source)
   899  	if err != nil {
   900  		return nil, nil, fmt.Errorf("failed to read %v from container: %v", source, err)
   901  	}
   902  
   903  	result, err := BuildContextFromTarball(reader)
   904  	if err != nil {
   905  		return nil, nil, fmt.Errorf("failed to build archive from result: %v", err)
   906  	}
   907  
   908  	return result, &stat, nil
   909  }
   910  
   911  func (d *Runner) GetNetworkAndAddresses(container string) (map[string]string, error) {
   912  	response, err := d.DockerAPI.ContainerInspect(context.Background(), container)
   913  	if err != nil {
   914  		return nil, fmt.Errorf("failed to fetch container inspection data: %v", err)
   915  	}
   916  
   917  	if response.NetworkSettings == nil || len(response.NetworkSettings.Networks) == 0 {
   918  		return nil, fmt.Errorf("container (%v) had no associated network settings: %v", container, response)
   919  	}
   920  
   921  	ret := make(map[string]string)
   922  	ns := response.NetworkSettings.Networks
   923  	for network, data := range ns {
   924  		if data == nil {
   925  			continue
   926  		}
   927  
   928  		ret[network] = data.IPAddress
   929  	}
   930  
   931  	if len(ret) == 0 {
   932  		return nil, fmt.Errorf("no valid network data for container (%v): %v", container, response)
   933  	}
   934  
   935  	return ret, nil
   936  }