github.com/omnigres/cli@v0.1.4/orb/docker.go (about)

     1  package orb
     2  
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"os/signal"
	"os/user"
	"path/filepath"
	"strings"
	"time"

	"github.com/charmbracelet/log"
	"github.com/docker/docker/api/types"
	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/image"
	"github.com/docker/docker/api/types/mount"
	"github.com/docker/docker/api/types/network"
	"github.com/docker/docker/client"
	"github.com/docker/docker/errdefs"
	_ "github.com/lib/pq"
	"github.com/omnigres/cli/internal/fileutils"
	"github.com/omnigres/cli/tui"
	"github.com/spf13/viper"
	"golang.org/x/term"
)
    30  
    31  const default_directory_mount = "/mnt/host"
    32  
// DockerOrbCluster is the OrbCluster implementation that runs the orb's
// database server in a Docker container and talks to it through the Docker
// Engine API.
type DockerOrbCluster struct {
	// client is the Docker API client created by NewDockerOrbCluster and
	// released by Close.
	client             *client.Client
	// currentContainerId caches the id of the container used by this
	// instance; when empty, containerId() falls back to the runfile.
	currentContainerId string
	// OrbOptions (orb directory Path, Config, ...) are installed by Configure.
	OrbOptions
}
    38  
    39  func (d *DockerOrbCluster) Config() *Config {
    40  	return d.OrbOptions.Config
    41  }
    42  
    43  func NewDockerOrbCluster() (orb OrbCluster, err error) {
    44  	log.Debugf(
    45  		"Creating docker client from env."+
    46  			"\n %s: %s"+
    47  			"\n %s: %s"+
    48  			"\n %s: %s"+
    49  			"\n %s: %s",
    50  		client.EnvOverrideHost,
    51  		os.Getenv(client.EnvOverrideHost),
    52  		client.EnvOverrideAPIVersion,
    53  		os.Getenv(client.EnvOverrideAPIVersion),
    54  		client.EnvOverrideCertPath,
    55  		os.Getenv(client.EnvOverrideCertPath),
    56  		client.EnvTLSVerify,
    57  		os.Getenv(client.EnvTLSVerify),
    58  	)
    59  	cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
    60  	if err != nil {
    61  		return
    62  	}
    63  	orb = &DockerOrbCluster{client: cli, OrbOptions: OrbOptions{}}
    64  	return
    65  }
    66  
    67  func (d *DockerOrbCluster) Configure(options OrbOptions) error {
    68  	d.OrbOptions = options
    69  	return nil
    70  }
    71  
    72  func (d *DockerOrbCluster) prepareImage(ctx context.Context) (digest string, err error) {
    73  	cli := d.client
    74  	imageName := d.Config().Image.Name
    75  
    76  	var img types.ImageInspect
    77  
    78  	// Try getting the image locally
    79  	img, _, err = cli.ImageInspectWithRaw(ctx, imageName)
    80  
    81  	notFound := errdefs.IsNotFound(err)
    82  
    83  	// If there's an error and if it is not "not found" error, propagate it
    84  	if err != nil && !notFound {
    85  		return
    86  	}
    87  
    88  	digest = imageName
    89  
    90  	if !notFound {
    91  		// Get the digest (if found)
    92  		if len(img.RepoDigests) > 0 {
    93  			digest = img.RepoDigests[0]
    94  
    95  			// If digest does not match, it is as good as if was not found
    96  			if d.Config().Image.Digest != "" && d.Config().Image.Digest != digest {
    97  				imageName = digest
    98  				notFound = true
    99  			}
   100  		}
   101  	}
   102  
   103  	if notFound {
   104  		// Pull the image
   105  		var reader io.ReadCloser
   106  		reader, err = cli.ImagePull(ctx, imageName, image.PullOptions{})
   107  		if err != nil {
   108  			return
   109  		}
   110  		defer reader.Close()
   111  
   112  		progress := tui.NewDownloadProgress("Downloading docker image "+imageName, reader)
   113  		progressModel, progressError := progress.Run()
   114  		if progressError != nil {
   115  			fmt.Println("Download failed to run for image", imageName, progressError.Error())
   116  			os.Exit(1)
   117  		}
   118  
   119  		m := progressModel.(tui.Model)
   120  		if m.Err != nil {
   121  			fmt.Println("Oh no! Could not download image", imageName)
   122  			os.Exit(1)
   123  		}
   124  
   125  		// Getting the image locally again to get the digest
   126  		img, _, err = cli.ImageInspectWithRaw(ctx, imageName)
   127  		if err != nil {
   128  			return
   129  		}
   130  
   131  		// Fetch the digest
   132  		if len(img.RepoDigests) > 0 {
   133  			digest = img.RepoDigests[0]
   134  		}
   135  	}
   136  
   137  	// Ensure the config has been updated
   138  	if d.Config().Image.Name != digest {
   139  		d.Config().Image.Digest = digest
   140  	}
   141  
   142  	return
   143  
   144  }
   145  
   146  func (d *DockerOrbCluster) runfile() (v *viper.Viper) {
   147  	v = viper.New()
   148  	v.SetConfigFile(d.Path + "/omnigres.run.yaml")
   149  	return
   150  }
   151  
   152  func (d *DockerOrbCluster) waitUntilClusterIsReady(ctx context.Context, listeners []OrbStartEventListener, cancel context.CancelFunc) {
   153  
   154  	log.Debug("Waiting for is_omnigres_ready...")
   155  	deadline := time.Now().Add(1 * time.Minute)
   156  
   157  	ready := false
   158  
   159  checkPg:
   160  	for time.Now().Before(deadline) {
   161  		c, err := d.Connect(ctx)
   162  		if err == nil {
   163  			if err = c.Ping(); err != nil {
   164  				continue checkPg
   165  			}
   166  		checkOmnigres:
   167  			for time.Now().Before(deadline) {
   168  				if err = c.QueryRowContext(ctx, "select is_omnigres_ready()").Scan(&ready); err != nil {
   169  					time.Sleep(1 * time.Second)
   170  					log.Debugf("Error trying is_omnigres_ready: %s", err)
   171  					continue checkOmnigres
   172  				}
   173  				_ = c.Close()
   174  				log.Debugf("is_omnigres_ready: %t", ready)
   175  				if ready {
   176  					for _, listener := range listeners {
   177  						if listener.Ready != nil {
   178  							go listener.Ready(d)
   179  						}
   180  					}
   181  					return
   182  				}
   183  				time.Sleep(1 * time.Second)
   184  			}
   185  		}
   186  		time.Sleep(1 * time.Second)
   187  	}
   188  
   189  	fmt.Println("Can't get a healthy cluster, terminating...")
   190  	cancel()
   191  }
   192  
   193  func (d *DockerOrbCluster) StartWithCurrentUser(ctx context.Context, options OrbClusterStartOptions) (err error) {
   194  	ctx, cancel := context.WithCancel(ctx)
   195  	defer cancel()
   196  
   197  	// Get the current user
   198  	var currentUser *user.User
   199  	currentUser, err = user.Current()
   200  	if err != nil {
   201  		log.Fatalf("Could not get current user: %s", err)
   202  	}
   203  
   204  	err = d.Start(
   205  		ctx,
   206  		options,
   207  		&currentUser.Uid,
   208  		nil,
   209  	)
   210  	if err != nil {
   211  		log.Fatal("Fail starting Orb", "err", err)
   212  	}
   213  	return
   214  }
   215  
   216  func (d *DockerOrbCluster) Start(ctx context.Context, options OrbClusterStartOptions, runAs *string, entryPoint []string) (err error) {
   217  	cli := d.client
   218  	ctx, cancel := context.WithCancel(ctx)
   219  	defer cancel()
   220  
   221  	var imageDigest string
   222  
   223  	var run *viper.Viper
   224  	var containerId string
   225  
   226  	if options.Runfile {
   227  		run = d.runfile()
   228  		err = fileutils.CreateIfNotExists(run.ConfigFileUsed(), false)
   229  		if err != nil {
   230  			return
   231  		}
   232  
   233  		err = run.ReadInConfig()
   234  		if err != nil {
   235  			return
   236  		}
   237  
   238  		containerId, err = d.containerId()
   239  	}
   240  
   241  	// Prepare image
   242  	imageDigest, err = d.prepareImage(ctx)
   243  	if err != nil {
   244  		return
   245  	}
   246  
   247  checkContainer:
   248  	if containerId != "" {
   249  		log.Debugf("Found a container id %s", containerId)
   250  		var cnt types.ContainerJSON
   251  		cnt, err = cli.ContainerInspect(ctx, containerId)
   252  		if errdefs.IsNotFound(err) {
   253  			log.Warn("Container not found, starting new one", "container", containerId)
   254  			containerId = ""
   255  			goto checkContainer
   256  		}
   257  		if err != nil {
   258  			return
   259  		}
   260  		// Check the container
   261  		if cnt.State.Running {
   262  			err = errors.New("Container already running")
   263  			return
   264  		}
   265  
   266  		// Check the image
   267  		var image types.ImageInspect
   268  		image, _, err = cli.ImageInspectWithRaw(ctx, cnt.Image)
   269  		if err != nil {
   270  			return
   271  		}
   272  		if len(image.RepoDigests) > 0 && image.RepoDigests[0] != imageDigest {
   273  			err = fmt.Errorf("Container's image %s does not match expected %s", image.RepoDigests[0], imageDigest)
   274  			return
   275  		}
   276  
   277  	} else {
   278  
   279  		networkName := "omnigres"
   280  
   281  		_, err = cli.NetworkCreate(ctx, networkName, network.CreateOptions{
   282  			Driver: "bridge",
   283  		})
   284  
   285  		if err != nil {
   286  			// If it is a conflict, this is normal flow – network already exists
   287  			if !errdefs.IsConflict(err) {
   288  				// otherwise, it's an error
   289  				return
   290  			}
   291  		}
   292  
   293  		// Bindings
   294  		hostconfig := container.HostConfig{
   295  			AutoRemove: options.AutoRemove,
   296  			Mounts: []mount.Mount{
   297  				{
   298  					Type:   mount.TypeBind,
   299  					Source: d.Path,
   300  					Target: default_directory_mount,
   301  				},
   302  			},
   303  			NetworkMode: container.NetworkMode(networkName),
   304  		}
   305  
   306  		// Prepare environment for every orb
   307  		env := make([]string, 0)
   308  		for _, orb := range d.Config().Orbs {
   309  			for _, e := range os.Environ() {
   310  				if strings.HasPrefix(e, strings.ToUpper(orb.Name+"_")) {
   311  					env = append(env, e)
   312  				}
   313  			}
   314  		}
   315  		env = append(env, "POSTGRES_HOST_AUTH_METHOD=password")
   316  		// Allows to prevent problems with initialization scripts failing due to
   317  		// be unable to chmod /var/lib/postgresql/data (since it already exists
   318  		// and not owned by user passed in `runAs`)
   319  		env = append(env, "PGDATA=/var/lib/postgresql/omnigres")
   320  
   321  		// Create container
   322  		log.Debugf("Creating container ...")
   323  		var containerResponse container.CreateResponse
   324  		var config *container.Config
   325  		config = &container.Config{Image: imageDigest, Env: env}
   326  		if runAs != nil {
   327  			log.Debugf("🪪 Starting cluster with current user id: %s", *runAs)
   328  			// Ensure we have the right user and group
   329  			config.User = fmt.Sprintf("%s:postgres", *runAs)
   330  		}
   331  		if entryPoint != nil {
   332  			log.Debugf("🛂 Starting cluster with custom entry point: %s", entryPoint)
   333  			config.Entrypoint = entryPoint
   334  		}
   335  		containerResponse, err = cli.ContainerCreate(
   336  			ctx,
   337  			config,
   338  			&hostconfig,
   339  			nil,
   340  			nil,
   341  			"",
   342  		)
   343  		if err != nil {
   344  			return
   345  		}
   346  		containerId = containerResponse.ID
   347  		d.currentContainerId = containerId
   348  	}
   349  
   350  	if options.Attachment.ShouldAttach {
   351  		var resp types.HijackedResponse
   352  		resp, err = cli.ContainerAttach(ctx, containerId, container.AttachOptions{
   353  			Stream: true,
   354  			Stdin:  true,
   355  			Stdout: true,
   356  			Stderr: true,
   357  		})
   358  		if err != nil {
   359  			fmt.Printf("Error attaching to attach instance: %v\n", err)
   360  			return
   361  		}
   362  		defer resp.Close()
   363  
   364  		d.currentContainerId = containerId
   365  
   366  		// Connect stdout/stderr to the consumer
   367  		for _, listener := range options.Attachment.Listeners {
   368  			if listener.OutputHandler != nil {
   369  				listener.OutputHandler(d, resp.Reader)
   370  			}
   371  		}
   372  	}
   373  
   374  	// Start container
   375  	err = cli.ContainerStart(ctx, containerId, container.StartOptions{})
   376  	if err != nil {
   377  		return err
   378  	}
   379  
   380  	for _, listener := range options.Listeners {
   381  		if listener.Started != nil {
   382  			go listener.Started(d)
   383  		}
   384  	}
   385  
   386  	// If we fail below, stop the container
   387  	defer func() {
   388  		if err != nil || options.Attachment.ShouldAttach {
   389  			timeout := 0 // forcibly terminate
   390  			newErr := cli.ContainerStop(ctx, containerId, container.StopOptions{Timeout: &timeout})
   391  
   392  			if newErr != nil {
   393  				err = errors.Join(err, newErr)
   394  			}
   395  			if options.Attachment.ShouldAttach {
   396  				for _, listener := range options.Attachment.Listeners {
   397  					if listener.Stopped != nil {
   398  						go listener.Stopped(d)
   399  					}
   400  				}
   401  			}
   402  
   403  		}
   404  	}()
   405  
   406  	if options.Runfile {
   407  		run.Set("containerid", containerId)
   408  
   409  		err = run.WriteConfig()
   410  		if err != nil {
   411  			return
   412  		}
   413  	}
   414  
   415  	// TODO: do this in the background?
   416  	// wait only when we have Listeners
   417  	if options.Listeners != nil {
   418  		d.waitUntilClusterIsReady(ctx, options.Listeners, cancel)
   419  	}
   420  
   421  	if options.Attachment.ShouldAttach {
   422  		statusCh, errCh := cli.ContainerWait(ctx, containerId, container.WaitConditionNotRunning)
   423  		sigCtx, stop := signal.NotifyContext(ctx, os.Interrupt)
   424  		defer stop()
   425  
   426  		select {
   427  		case <-sigCtx.Done():
   428  			fmt.Println("Terminating cluster")
   429  		case err = <-errCh:
   430  			if err != nil {
   431  				return
   432  			}
   433  		case status := <-statusCh:
   434  			if status.StatusCode == 0 {
   435  				fmt.Printf("Omnigres exited with status: %d\n", status.StatusCode)
   436  			}
   437  		}
   438  	}
   439  
   440  	return nil
   441  }
   442  
   443  func (d *DockerOrbCluster) containerId() (containerId string, err error) {
   444  	if d.currentContainerId != "" {
   445  		containerId = d.currentContainerId
   446  	} else {
   447  		v := d.runfile()
   448  		err = v.ReadInConfig()
   449  		if err != nil {
   450  			return
   451  		}
   452  
   453  		containerId = v.GetString("containerid")
   454  	}
   455  	return
   456  }
   457  
   458  func (d *DockerOrbCluster) Stop(ctx context.Context) (err error) {
   459  	cli := d.client
   460  
   461  	var id string
   462  	id, err = d.containerId()
   463  	if err != nil {
   464  		return
   465  	}
   466  
   467  	var cnt types.ContainerJSON
   468  	cnt, err = cli.ContainerInspect(ctx, id)
   469  	if err != nil {
   470  		return
   471  	}
   472  
   473  	if !cnt.State.Running {
   474  		err = errors.New("Container is not running")
   475  		return
   476  	}
   477  
   478  	err = cli.ContainerStop(ctx, id, container.StopOptions{})
   479  	if err != nil {
   480  		return
   481  	}
   482  	return
   483  }
   484  
   485  func (d *DockerOrbCluster) Close() (err error) {
   486  	err = d.client.Close()
   487  	return
   488  }
   489  
   490  func (d *DockerOrbCluster) ConnectPsql(ctx context.Context, database ...string) (err error) {
   491  	var id string
   492  	id, err = d.containerId()
   493  	if err != nil {
   494  		return
   495  	}
   496  
   497  	var db string
   498  	if len(database) == 0 {
   499  		db = "omnigres"
   500  	} else {
   501  		db = database[0]
   502  	}
   503  	if len(database) > 1 {
   504  		err = errors.New("orb: database name is ambiguous")
   505  		return
   506  	}
   507  	cli := d.client
   508  
   509  	var execResponse types.IDResponse
   510  	execResponse, err = cli.ContainerExecCreate(ctx, id, container.ExecOptions{
   511  		Cmd:          []string{"psql", "-Uomnigres", "--set", "HISTFILE=.psql_history", db},
   512  		WorkingDir:   default_directory_mount,
   513  		AttachStdin:  true,
   514  		AttachStdout: true,
   515  		AttachStderr: true,
   516  		Tty:          true,
   517  	})
   518  
   519  	if err != nil {
   520  		return
   521  	}
   522  
   523  	// Attach to the exec instance
   524  	resp, err := cli.ContainerExecAttach(ctx, execResponse.ID, container.ExecAttachOptions{
   525  		Tty: true,
   526  	})
   527  	if err != nil {
   528  		fmt.Printf("Error attaching to exec instance: %v\n", err)
   529  		return
   530  	}
   531  	defer resp.Close()
   532  
   533  	// Save the original terminal state
   534  	oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
   535  	if err != nil {
   536  		fmt.Printf("Error setting terminal to raw mode: %v\n", err)
   537  		return
   538  	}
   539  	defer term.Restore(int(os.Stdin.Fd()), oldState)
   540  
   541  	// Connect stdin to the terminal
   542  	go func() {
   543  		_, _ = io.Copy(resp.Conn, os.Stdin)
   544  	}()
   545  
   546  	// Connect stdout/stderr to the terminal
   547  	_, _ = io.Copy(os.Stdout, resp.Reader)
   548  
   549  	return
   550  }
   551  
   552  func (d *DockerOrbCluster) NetworkID(ctx context.Context) (network string, err error) {
   553  	cli := d.client
   554  
   555  	var id string
   556  	id, err = d.containerId()
   557  	if err != nil {
   558  		return
   559  	}
   560  
   561  	var cnt types.ContainerJSON
   562  	cnt, err = cli.ContainerInspect(ctx, id)
   563  	if err != nil {
   564  		return
   565  	}
   566  
   567  	if !cnt.State.Running {
   568  		err = errors.New("Container is not running")
   569  		return
   570  	}
   571  
   572  	network = cnt.HostConfig.NetworkMode.NetworkName()
   573  	return
   574  }
   575  
   576  func (d *DockerOrbCluster) NetworkIP(ctx context.Context) (ip string, err error) {
   577  	cli := d.client
   578  
   579  	var id string
   580  	id, err = d.containerId()
   581  	if err != nil {
   582  		return
   583  	}
   584  
   585  	var cnt types.ContainerJSON
   586  	cnt, err = cli.ContainerInspect(ctx, id)
   587  	if err != nil {
   588  		return
   589  	}
   590  
   591  	if !cnt.State.Running {
   592  		err = errors.New("Container is not running")
   593  		return
   594  	}
   595  
   596  	ip = cnt.NetworkSettings.Networks[cnt.HostConfig.NetworkMode.NetworkName()].IPAddress
   597  	return
   598  }
   599  
   600  func (d *DockerOrbCluster) Connect(ctx context.Context, database ...string) (conn *sql.DB, err error) {
   601  	var db string
   602  	if len(database) == 0 {
   603  		db = "omnigres"
   604  	} else {
   605  		db = database[0]
   606  	}
   607  	var ip string
   608  	ip, err = d.NetworkIP(ctx)
   609  	if err != nil {
   610  		return
   611  	}
   612  	port := 5432
   613  	conn, err = sql.Open("postgres", fmt.Sprintf("user=omnigres password=omnigres dbname=%s host=%s port=%d sslmode=disable", db, ip, port))
   614  	return
   615  }
   616  
   617  func (d *DockerOrbCluster) Endpoints(ctx context.Context) (endpoints []Endpoint, err error) {
   618  	var addr string
   619  	addr, err = d.NetworkIP(ctx)
   620  	if err != nil {
   621  		return
   622  	}
   623  	ipaddr := net.ParseIP(addr)
   624  	endpoints = make([]Endpoint, 0)
   625  	var conn *sql.DB
   626  	conn, err = d.Connect(ctx)
   627  	if err != nil {
   628  		return
   629  	}
   630  	defer conn.Close()
   631  
   632  	var rows *sql.Rows
   633  	// Search for all databases
   634  	rows, err = conn.QueryContext(ctx, `select datname from pg_database where not datistemplate and datname != 'postgres'`)
   635  	if err != nil {
   636  		return
   637  	}
   638  	defer rows.Close()
   639  nextDatabase:
   640  	for rows.Next() {
   641  		var datname string
   642  		if err = rows.Scan(&datname); err != nil {
   643  			return
   644  		}
   645  		// For every database
   646  		var dbconn *sql.DB
   647  		dbconn, err = d.Connect(ctx, datname)
   648  		if err != nil {
   649  			return
   650  		}
   651  		defer dbconn.Close()
   652  		// Add the Postgres service
   653  		endpoints = append(endpoints, Endpoint{Database: datname, IP: ipaddr, Port: 5432, Protocol: "Postgres"})
   654  		// Get the list of HTTP listeners.
   655  		// TODO: in the future, we expect this to be generialized through omni_service
   656  		var portRows *sql.Rows
   657  		portRows, err = dbconn.QueryContext(ctx, "select effective_port from omni_httpd.listeners")
   658  		if err != nil {
   659  			err = nil
   660  			continue nextDatabase
   661  		}
   662  		defer portRows.Close()
   663  		for portRows.Next() {
   664  			var port int
   665  			err = portRows.Scan(&port)
   666  			if err != nil {
   667  				return
   668  			}
   669  			endpoints = append(endpoints, Endpoint{Database: datname, IP: ipaddr, Port: port, Protocol: "HTTP"})
   670  		}
   671  
   672  	}
   673  	return
   674  }