github.com/omnigres/cli@v0.1.4/orb/docker.go (about) 1 package orb 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "io" 9 "net" 10 "os" 11 "os/signal" 12 "os/user" 13 "strings" 14 "time" 15 16 "github.com/charmbracelet/log" 17 "github.com/docker/docker/api/types" 18 "github.com/docker/docker/api/types/container" 19 "github.com/docker/docker/api/types/image" 20 "github.com/docker/docker/api/types/mount" 21 "github.com/docker/docker/api/types/network" 22 "github.com/docker/docker/client" 23 "github.com/docker/docker/errdefs" 24 _ "github.com/lib/pq" 25 "github.com/omnigres/cli/internal/fileutils" 26 "github.com/omnigres/cli/tui" 27 "github.com/spf13/viper" 28 "golang.org/x/term" 29 ) 30 31 const default_directory_mount = "/mnt/host" 32 33 type DockerOrbCluster struct { 34 client *client.Client 35 currentContainerId string 36 OrbOptions 37 } 38 39 func (d *DockerOrbCluster) Config() *Config { 40 return d.OrbOptions.Config 41 } 42 43 func NewDockerOrbCluster() (orb OrbCluster, err error) { 44 log.Debugf( 45 "Creating docker client from env."+ 46 "\n %s: %s"+ 47 "\n %s: %s"+ 48 "\n %s: %s"+ 49 "\n %s: %s", 50 client.EnvOverrideHost, 51 os.Getenv(client.EnvOverrideHost), 52 client.EnvOverrideAPIVersion, 53 os.Getenv(client.EnvOverrideAPIVersion), 54 client.EnvOverrideCertPath, 55 os.Getenv(client.EnvOverrideCertPath), 56 client.EnvTLSVerify, 57 os.Getenv(client.EnvTLSVerify), 58 ) 59 cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) 60 if err != nil { 61 return 62 } 63 orb = &DockerOrbCluster{client: cli, OrbOptions: OrbOptions{}} 64 return 65 } 66 67 func (d *DockerOrbCluster) Configure(options OrbOptions) error { 68 d.OrbOptions = options 69 return nil 70 } 71 72 func (d *DockerOrbCluster) prepareImage(ctx context.Context) (digest string, err error) { 73 cli := d.client 74 imageName := d.Config().Image.Name 75 76 var img types.ImageInspect 77 78 // Try getting the image locally 79 img, _, err = cli.ImageInspectWithRaw(ctx, imageName) 80 81 notFound := errdefs.IsNotFound(err) 82 83 // If there's an error and if it is not "not found" error, propagate it 84 if err != nil && !notFound { 85 return 86 } 87 88 digest = imageName 89 90 if !notFound { 91 // Get the digest (if found) 92 if len(img.RepoDigests) > 0 { 93 digest = img.RepoDigests[0] 94 95 // If digest does not match, it is as good as if was not found 96 if d.Config().Image.Digest != "" && d.Config().Image.Digest != digest { 97 imageName = digest 98 notFound = true 99 } 100 } 101 } 102 103 if notFound { 104 // Pull the image 105 var reader io.ReadCloser 106 reader, err = cli.ImagePull(ctx, imageName, image.PullOptions{}) 107 if err != nil { 108 return 109 } 110 defer reader.Close() 111 112 progress := tui.NewDownloadProgress("Downloading docker image "+imageName, reader) 113 progressModel, progressError := progress.Run() 114 if progressError != nil { 115 fmt.Println("Download failed to run for image", imageName, progressError.Error()) 116 os.Exit(1) 117 } 118 119 m := progressModel.(tui.Model) 120 if m.Err != nil { 121 fmt.Println("Oh no! Could not download image", imageName) 122 os.Exit(1) 123 } 124 125 // Getting the image locally again to get the digest 126 img, _, err = cli.ImageInspectWithRaw(ctx, imageName) 127 if err != nil { 128 return 129 } 130 131 // Fetch the digest 132 if len(img.RepoDigests) > 0 { 133 digest = img.RepoDigests[0] 134 } 135 } 136 137 // Ensure the config has been updated 138 if d.Config().Image.Name != digest { 139 d.Config().Image.Digest = digest 140 } 141 142 return 143 144 } 145 146 func (d *DockerOrbCluster) runfile() (v *viper.Viper) { 147 v = viper.New() 148 v.SetConfigFile(d.Path + "/omnigres.run.yaml") 149 return 150 } 151 152 func (d *DockerOrbCluster) waitUntilClusterIsReady(ctx context.Context, listeners []OrbStartEventListener, cancel context.CancelFunc) { 153 154 log.Debug("Waiting for is_omnigres_ready...") 155 deadline := time.Now().Add(1 * time.Minute) 156 157 ready := false 158 159 checkPg: 160 for time.Now().Before(deadline) { 161 c, err := d.Connect(ctx) 162 if err == nil { 163 if err = c.Ping(); err != nil { 164 continue checkPg 165 } 166 checkOmnigres: 167 for time.Now().Before(deadline) { 168 if err = c.QueryRowContext(ctx, "select is_omnigres_ready()").Scan(&ready); err != nil { 169 time.Sleep(1 * time.Second) 170 log.Debugf("Error trying is_omnigres_ready: %s", err) 171 continue checkOmnigres 172 } 173 _ = c.Close() 174 log.Debugf("is_omnigres_ready: %t", ready) 175 if ready { 176 for _, listener := range listeners { 177 if listener.Ready != nil { 178 go listener.Ready(d) 179 } 180 } 181 return 182 } 183 time.Sleep(1 * time.Second) 184 } 185 } 186 time.Sleep(1 * time.Second) 187 } 188 189 fmt.Println("Can't get a healthy cluster, terminating...") 190 cancel() 191 } 192 193 func (d *DockerOrbCluster) StartWithCurrentUser(ctx context.Context, options OrbClusterStartOptions) (err error) { 194 ctx, cancel := context.WithCancel(ctx) 195 defer cancel() 196 197 // Get the current user 198 var currentUser *user.User 199 currentUser, err = user.Current() 200 if err != nil { 201 log.Fatalf("Could not get current user: %s", err) 202 } 203 204 err = d.Start( 205 ctx, 206 options, 207 ¤tUser.Uid, 208 nil, 209 ) 210 if err != nil { 211 log.Fatal("Fail starting Orb", "err", err) 212 } 213 return 214 } 215 216 func (d *DockerOrbCluster) Start(ctx context.Context, options OrbClusterStartOptions, runAs *string, entryPoint []string) (err error) { 217 cli := d.client 218 ctx, cancel := context.WithCancel(ctx) 219 defer cancel() 220 221 var imageDigest string 222 223 var run *viper.Viper 224 var containerId string 225 226 if options.Runfile { 227 run = d.runfile() 228 err = fileutils.CreateIfNotExists(run.ConfigFileUsed(), false) 229 if err != nil { 230 return 231 } 232 233 err = run.ReadInConfig() 234 if err != nil { 235 return 236 } 237 238 containerId, err = d.containerId() 239 } 240 241 // Prepare image 242 imageDigest, err = d.prepareImage(ctx) 243 if err != nil { 244 return 245 } 246 247 checkContainer: 248 if containerId != "" { 249 log.Debugf("Found a container id %s", containerId) 250 var cnt types.ContainerJSON 251 cnt, err = cli.ContainerInspect(ctx, containerId) 252 if errdefs.IsNotFound(err) { 253 log.Warn("Container not found, starting new one", "container", containerId) 254 containerId = "" 255 goto checkContainer 256 } 257 if err != nil { 258 return 259 } 260 // Check the container 261 if cnt.State.Running { 262 err = errors.New("Container already running") 263 return 264 } 265 266 // Check the image 267 var image types.ImageInspect 268 image, _, err = cli.ImageInspectWithRaw(ctx, cnt.Image) 269 if err != nil { 270 return 271 } 272 if len(image.RepoDigests) > 0 && image.RepoDigests[0] != imageDigest { 273 err = fmt.Errorf("Container's image %s does not match expected %s", image.RepoDigests[0], imageDigest) 274 return 275 } 276 277 } else { 278 279 networkName := "omnigres" 280 281 _, err = cli.NetworkCreate(ctx, networkName, network.CreateOptions{ 282 Driver: "bridge", 283 }) 284 285 if err != nil { 286 // If it is a conflict, this is normal flow – network already exists 287 if !errdefs.IsConflict(err) { 288 // otherwise, it's an error 289 return 290 } 291 } 292 293 // Bindings 294 hostconfig := container.HostConfig{ 295 AutoRemove: options.AutoRemove, 296 Mounts: []mount.Mount{ 297 { 298 Type: mount.TypeBind, 299 Source: d.Path, 300 Target: default_directory_mount, 301 }, 302 }, 303 NetworkMode: container.NetworkMode(networkName), 304 } 305 306 // Prepare environment for every orb 307 env := make([]string, 0) 308 for _, orb := range d.Config().Orbs { 309 for _, e := range os.Environ() { 310 if strings.HasPrefix(e, strings.ToUpper(orb.Name+"_")) { 311 env = append(env, e) 312 } 313 } 314 } 315 env = append(env, "POSTGRES_HOST_AUTH_METHOD=password") 316 // Allows to prevent problems with initialization scripts failing due to 317 // be unable to chmod /var/lib/postgresql/data (since it already exists 318 // and not owned by user passed in `runAs`) 319 env = append(env, "PGDATA=/var/lib/postgresql/omnigres") 320 321 // Create container 322 log.Debugf("Creating container ...") 323 var containerResponse container.CreateResponse 324 var config *container.Config 325 config = &container.Config{Image: imageDigest, Env: env} 326 if runAs != nil { 327 log.Debugf("🪪 Starting cluster with current user id: %s", *runAs) 328 // Ensure we have the right user and group 329 config.User = fmt.Sprintf("%s:postgres", *runAs) 330 } 331 if entryPoint != nil { 332 log.Debugf("🛂 Starting cluster with custom entry point: %s", entryPoint) 333 config.Entrypoint = entryPoint 334 } 335 containerResponse, err = cli.ContainerCreate( 336 ctx, 337 config, 338 &hostconfig, 339 nil, 340 nil, 341 "", 342 ) 343 if err != nil { 344 return 345 } 346 containerId = containerResponse.ID 347 d.currentContainerId = containerId 348 } 349 350 if options.Attachment.ShouldAttach { 351 var resp types.HijackedResponse 352 resp, err = cli.ContainerAttach(ctx, containerId, container.AttachOptions{ 353 Stream: true, 354 Stdin: true, 355 Stdout: true, 356 Stderr: true, 357 }) 358 if err != nil { 359 fmt.Printf("Error attaching to attach instance: %v\n", err) 360 return 361 } 362 defer resp.Close() 363 364 d.currentContainerId = containerId 365 366 // Connect stdout/stderr to the consumer 367 for _, listener := range options.Attachment.Listeners { 368 if listener.OutputHandler != nil { 369 listener.OutputHandler(d, resp.Reader) 370 } 371 } 372 } 373 374 // Start container 375 err = cli.ContainerStart(ctx, containerId, container.StartOptions{}) 376 if err != nil { 377 return err 378 } 379 380 for _, listener := range options.Listeners { 381 if listener.Started != nil { 382 go listener.Started(d) 383 } 384 } 385 386 // If we fail below, stop the container 387 defer func() { 388 if err != nil || options.Attachment.ShouldAttach { 389 timeout := 0 // forcibly terminate 390 newErr := cli.ContainerStop(ctx, containerId, container.StopOptions{Timeout: &timeout}) 391 392 if newErr != nil { 393 err = errors.Join(err, newErr) 394 } 395 if options.Attachment.ShouldAttach { 396 for _, listener := range options.Attachment.Listeners { 397 if listener.Stopped != nil { 398 go listener.Stopped(d) 399 } 400 } 401 } 402 403 } 404 }() 405 406 if options.Runfile { 407 run.Set("containerid", containerId) 408 409 err = run.WriteConfig() 410 if err != nil { 411 return 412 } 413 } 414 415 // TODO: do this in the background? 416 // wait only when we have Listeners 417 if options.Listeners != nil { 418 d.waitUntilClusterIsReady(ctx, options.Listeners, cancel) 419 } 420 421 if options.Attachment.ShouldAttach { 422 statusCh, errCh := cli.ContainerWait(ctx, containerId, container.WaitConditionNotRunning) 423 sigCtx, stop := signal.NotifyContext(ctx, os.Interrupt) 424 defer stop() 425 426 select { 427 case <-sigCtx.Done(): 428 fmt.Println("Terminating cluster") 429 case err = <-errCh: 430 if err != nil { 431 return 432 } 433 case status := <-statusCh: 434 if status.StatusCode == 0 { 435 fmt.Printf("Omnigres exited with status: %d\n", status.StatusCode) 436 } 437 } 438 } 439 440 return nil 441 } 442 443 func (d *DockerOrbCluster) containerId() (containerId string, err error) { 444 if d.currentContainerId != "" { 445 containerId = d.currentContainerId 446 } else { 447 v := d.runfile() 448 err = v.ReadInConfig() 449 if err != nil { 450 return 451 } 452 453 containerId = v.GetString("containerid") 454 } 455 return 456 } 457 458 func (d *DockerOrbCluster) Stop(ctx context.Context) (err error) { 459 cli := d.client 460 461 var id string 462 id, err = d.containerId() 463 if err != nil { 464 return 465 } 466 467 var cnt types.ContainerJSON 468 cnt, err = cli.ContainerInspect(ctx, id) 469 if err != nil { 470 return 471 } 472 473 if !cnt.State.Running { 474 err = errors.New("Container is not running") 475 return 476 } 477 478 err = cli.ContainerStop(ctx, id, container.StopOptions{}) 479 if err != nil { 480 return 481 } 482 return 483 } 484 485 func (d *DockerOrbCluster) Close() (err error) { 486 err = d.client.Close() 487 return 488 } 489 490 func (d *DockerOrbCluster) ConnectPsql(ctx context.Context, database ...string) (err error) { 491 var id string 492 id, err = d.containerId() 493 if err != nil { 494 return 495 } 496 497 var db string 498 if len(database) == 0 { 499 db = "omnigres" 500 } else { 501 db = database[0] 502 } 503 if len(database) > 1 { 504 err = errors.New("orb: database name is ambiguous") 505 return 506 } 507 cli := d.client 508 509 var execResponse types.IDResponse 510 execResponse, err = cli.ContainerExecCreate(ctx, id, container.ExecOptions{ 511 Cmd: []string{"psql", "-Uomnigres", "--set", "HISTFILE=.psql_history", db}, 512 WorkingDir: default_directory_mount, 513 AttachStdin: true, 514 AttachStdout: true, 515 AttachStderr: true, 516 Tty: true, 517 }) 518 519 if err != nil { 520 return 521 } 522 523 // Attach to the exec instance 524 resp, err := cli.ContainerExecAttach(ctx, execResponse.ID, container.ExecAttachOptions{ 525 Tty: true, 526 }) 527 if err != nil { 528 fmt.Printf("Error attaching to exec instance: %v\n", err) 529 return 530 } 531 defer resp.Close() 532 533 // Save the original terminal state 534 oldState, err := term.MakeRaw(int(os.Stdin.Fd())) 535 if err != nil { 536 fmt.Printf("Error setting terminal to raw mode: %v\n", err) 537 return 538 } 539 defer term.Restore(int(os.Stdin.Fd()), oldState) 540 541 // Connect stdin to the terminal 542 go func() { 543 _, _ = io.Copy(resp.Conn, os.Stdin) 544 }() 545 546 // Connect stdout/stderr to the terminal 547 _, _ = io.Copy(os.Stdout, resp.Reader) 548 549 return 550 } 551 552 func (d *DockerOrbCluster) NetworkID(ctx context.Context) (network string, err error) { 553 cli := d.client 554 555 var id string 556 id, err = d.containerId() 557 if err != nil { 558 return 559 } 560 561 var cnt types.ContainerJSON 562 cnt, err = cli.ContainerInspect(ctx, id) 563 if err != nil { 564 return 565 } 566 567 if !cnt.State.Running { 568 err = errors.New("Container is not running") 569 return 570 } 571 572 network = cnt.HostConfig.NetworkMode.NetworkName() 573 return 574 } 575 576 func (d *DockerOrbCluster) NetworkIP(ctx context.Context) (ip string, err error) { 577 cli := d.client 578 579 var id string 580 id, err = d.containerId() 581 if err != nil { 582 return 583 } 584 585 var cnt types.ContainerJSON 586 cnt, err = cli.ContainerInspect(ctx, id) 587 if err != nil { 588 return 589 } 590 591 if !cnt.State.Running { 592 err = errors.New("Container is not running") 593 return 594 } 595 596 ip = cnt.NetworkSettings.Networks[cnt.HostConfig.NetworkMode.NetworkName()].IPAddress 597 return 598 } 599 600 func (d *DockerOrbCluster) Connect(ctx context.Context, database ...string) (conn *sql.DB, err error) { 601 var db string 602 if len(database) == 0 { 603 db = "omnigres" 604 } else { 605 db = database[0] 606 } 607 var ip string 608 ip, err = d.NetworkIP(ctx) 609 if err != nil { 610 return 611 } 612 port := 5432 613 conn, err = sql.Open("postgres", fmt.Sprintf("user=omnigres password=omnigres dbname=%s host=%s port=%d sslmode=disable", db, ip, port)) 614 return 615 } 616 617 func (d *DockerOrbCluster) Endpoints(ctx context.Context) (endpoints []Endpoint, err error) { 618 var addr string 619 addr, err = d.NetworkIP(ctx) 620 if err != nil { 621 return 622 } 623 ipaddr := net.ParseIP(addr) 624 endpoints = make([]Endpoint, 0) 625 var conn *sql.DB 626 conn, err = d.Connect(ctx) 627 if err != nil { 628 return 629 } 630 defer conn.Close() 631 632 var rows *sql.Rows 633 // Search for all databases 634 rows, err = conn.QueryContext(ctx, `select datname from pg_database where not datistemplate and datname != 'postgres'`) 635 if err != nil { 636 return 637 } 638 defer rows.Close() 639 nextDatabase: 640 for rows.Next() { 641 var datname string 642 if err = rows.Scan(&datname); err != nil { 643 return 644 } 645 // For every database 646 var dbconn *sql.DB 647 dbconn, err = d.Connect(ctx, datname) 648 if err != nil { 649 return 650 } 651 defer dbconn.Close() 652 // Add the Postgres service 653 endpoints = append(endpoints, Endpoint{Database: datname, IP: ipaddr, Port: 5432, Protocol: "Postgres"}) 654 // Get the list of HTTP listeners. 655 // TODO: in the future, we expect this to be generialized through omni_service 656 var portRows *sql.Rows 657 portRows, err = dbconn.QueryContext(ctx, "select effective_port from omni_httpd.listeners") 658 if err != nil { 659 err = nil 660 continue nextDatabase 661 } 662 defer portRows.Close() 663 for portRows.Next() { 664 var port int 665 err = portRows.Scan(&port) 666 if err != nil { 667 return 668 } 669 endpoints = append(endpoints, Endpoint{Database: datname, IP: ipaddr, Port: port, Protocol: "HTTP"}) 670 } 671 672 } 673 return 674 }