github.com/Redstoneguy129/cli@v0.0.0-20230211220159-15dca4e91917/internal/db/reset/reset.go (about)

     1  package reset
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"os"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/jackc/pgconn"
    12  	"github.com/jackc/pgerrcode"
    13  	"github.com/jackc/pgx/v4"
    14  	"github.com/spf13/afero"
    15  	"github.com/Redstoneguy129/cli/internal/db/diff"
    16  	"github.com/Redstoneguy129/cli/internal/status"
    17  	"github.com/Redstoneguy129/cli/internal/utils"
    18  	"github.com/Redstoneguy129/cli/internal/utils/parser"
    19  )
    20  
    21  var (
    22  	healthTimeout = 5 * time.Second
    23  )
    24  
    25  func Run(ctx context.Context, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
    26  	// Sanity checks.
    27  	{
    28  		if err := utils.LoadConfigFS(fsys); err != nil {
    29  			return err
    30  		}
    31  		if err := utils.AssertSupabaseDbIsRunning(); err != nil {
    32  			return err
    33  		}
    34  	}
    35  
    36  	// Reset postgres database because extensions (pg_cron, pg_net) require postgres
    37  	{
    38  		fmt.Fprintln(os.Stderr, "Resetting database...")
    39  		if err := RecreateDatabase(ctx, options...); err != nil {
    40  			return err
    41  		}
    42  		defer RestartDatabase(context.Background())
    43  		if err := resetDatabase(ctx, fsys, options...); err != nil {
    44  			return err
    45  		}
    46  	}
    47  
    48  	branch, err := utils.GetCurrentBranchFS(fsys)
    49  	if err != nil {
    50  		// Assume we are on main branch
    51  		branch = "main"
    52  	}
    53  	fmt.Fprintln(os.Stderr, "Finished "+utils.Aqua("supabase db reset")+" on branch "+utils.Aqua(branch)+".")
    54  
    55  	return nil
    56  }
    57  
    58  func resetDatabase(ctx context.Context, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
    59  	conn, err := utils.ConnectLocalPostgres(ctx, utils.Config.Hostname, utils.Config.Db.Port, "postgres", options...)
    60  	if err != nil {
    61  		return err
    62  	}
    63  	defer conn.Close(context.Background())
    64  	fmt.Fprintln(os.Stderr, "Initialising schema...")
    65  	return InitialiseDatabase(ctx, conn, fsys)
    66  }
    67  
    68  func InitialiseDatabase(ctx context.Context, conn *pgx.Conn, fsys afero.Fs) error {
    69  	if err := diff.BatchExecDDL(ctx, conn, strings.NewReader(utils.InitialSchemaSql)); err != nil {
    70  		return err
    71  	}
    72  	if err := diff.MigrateDatabase(ctx, conn, fsys); err != nil {
    73  		return err
    74  	}
    75  	return SeedDatabase(ctx, conn, fsys)
    76  }
    77  
    78  // Recreate postgres database by connecting to template1
    79  func RecreateDatabase(ctx context.Context, options ...func(*pgx.ConnConfig)) error {
    80  	conn, err := utils.ConnectLocalPostgres(ctx, utils.Config.Hostname, utils.Config.Db.Port, "template1", options...)
    81  	if err != nil {
    82  		return err
    83  	}
    84  	defer conn.Close(context.Background())
    85  	if err := DisconnectClients(ctx, conn); err != nil {
    86  		return err
    87  	}
    88  	drop := "DROP DATABASE IF EXISTS postgres WITH (FORCE);"
    89  	if _, err := conn.Exec(ctx, drop); err != nil {
    90  		return err
    91  	}
    92  	_, err = conn.Exec(ctx, "CREATE DATABASE postgres;")
    93  	return err
    94  }
    95  
    96  func SeedDatabase(ctx context.Context, conn *pgx.Conn, fsys afero.Fs) error {
    97  	sql, err := fsys.Open(utils.SeedDataPath)
    98  	if errors.Is(err, os.ErrNotExist) {
    99  		return nil
   100  	} else if err != nil {
   101  		return err
   102  	}
   103  	defer sql.Close()
   104  	fmt.Fprintln(os.Stderr, "Seeding data "+utils.Bold(utils.SeedDataPath)+"...")
   105  	lines, err := parser.SplitAndTrim(sql)
   106  	if err != nil {
   107  		return err
   108  	}
   109  	// Batch seed commands, safe to use statement cache
   110  	batch := pgx.Batch{}
   111  	for _, line := range lines {
   112  		batch.Queue(line)
   113  	}
   114  	return conn.SendBatch(ctx, &batch).Close()
   115  }
   116  
   117  func DisconnectClients(ctx context.Context, conn *pgx.Conn) error {
   118  	// Must be executed separately because running in transaction is unsupported
   119  	disconn := "ALTER DATABASE postgres ALLOW_CONNECTIONS false;"
   120  	if _, err := conn.Exec(ctx, disconn); err != nil {
   121  		var pgErr *pgconn.PgError
   122  		if errors.As(err, &pgErr) && pgErr.Code != pgerrcode.InvalidCatalogName {
   123  			return err
   124  		}
   125  	}
   126  	term := fmt.Sprintf(utils.TerminateDbSqlFmt, "postgres")
   127  	if _, err := conn.Exec(ctx, term); err != nil {
   128  		return err
   129  	}
   130  	return nil
   131  }
   132  
   133  func RestartDatabase(ctx context.Context) {
   134  	// Some extensions must be manually restarted after pg_terminate_backend
   135  	// Ref: https://github.com/citusdata/pg_cron/issues/99
   136  	if err := utils.Docker.ContainerRestart(ctx, utils.DbId, nil); err != nil {
   137  		fmt.Fprintln(os.Stderr, "Failed to restart database:", err)
   138  		return
   139  	}
   140  	if !WaitForHealthyService(ctx, utils.DbId, healthTimeout) {
   141  		fmt.Fprintln(os.Stderr, "Database is not healthy.")
   142  		return
   143  	}
   144  	// TODO: update storage-api to handle postgres restarts
   145  	if err := utils.Docker.ContainerRestart(ctx, utils.StorageId, nil); err != nil {
   146  		fmt.Fprintln(os.Stderr, "Failed to restart storage-api:", err)
   147  	}
   148  	// Reload PostgREST schema cache.
   149  	if err := utils.Docker.ContainerKill(ctx, utils.RestId, "SIGUSR1"); err != nil {
   150  		fmt.Fprintln(os.Stderr, "Error reloading PostgREST schema cache:", err)
   151  	}
   152  }
   153  
   154  func RetryEverySecond(ctx context.Context, callback func() bool, timeout time.Duration) bool {
   155  	now := time.Now()
   156  	expiry := now.Add(timeout)
   157  	ticker := time.NewTicker(time.Second)
   158  	defer ticker.Stop()
   159  	for t := now; t.Before(expiry) && ctx.Err() == nil; t = <-ticker.C {
   160  		if callback() {
   161  			return true
   162  		}
   163  	}
   164  	return false
   165  }
   166  
   167  func WaitForHealthyService(ctx context.Context, container string, timeout time.Duration) bool {
   168  	probe := func() bool {
   169  		return status.AssertContainerHealthy(ctx, container) == nil
   170  	}
   171  	return RetryEverySecond(ctx, probe, timeout)
   172  }