github.com/supabase/cli@v1.168.1/internal/db/reset/reset.go (about)

     1  package reset
     2  
     3  import (
     4  	"context"
     5  	_ "embed"
     6  	"fmt"
     7  	"io"
     8  	"os"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/docker/docker/api/types/container"
    14  	"github.com/docker/docker/api/types/network"
    15  	"github.com/docker/docker/errdefs"
    16  	"github.com/docker/docker/pkg/stdcopy"
    17  	"github.com/go-errors/errors"
    18  	"github.com/jackc/pgconn"
    19  	"github.com/jackc/pgerrcode"
    20  	"github.com/jackc/pgx/v4"
    21  	"github.com/spf13/afero"
    22  	"github.com/supabase/cli/internal/db/start"
    23  	"github.com/supabase/cli/internal/gen/keys"
    24  	"github.com/supabase/cli/internal/migration/apply"
    25  	"github.com/supabase/cli/internal/migration/repair"
    26  	"github.com/supabase/cli/internal/status"
    27  	"github.com/supabase/cli/internal/utils"
    28  	"github.com/supabase/cli/internal/utils/pgxv5"
    29  )
    30  
    31  var (
    32  	ErrUnhealthy   = errors.New("service not healthy")
    33  	serviceTimeout = 30 * time.Second
    34  	//go:embed templates/drop.sql
    35  	dropObjects string
    36  	//go:embed templates/list.sql
    37  	ListSchemas string
    38  )
    39  
    40  func Run(ctx context.Context, version string, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
    41  	if len(version) > 0 {
    42  		if _, err := strconv.Atoi(version); err != nil {
    43  			return errors.New(repair.ErrInvalidVersion)
    44  		}
    45  		if _, err := repair.GetMigrationFile(version, fsys); err != nil {
    46  			return err
    47  		}
    48  	}
    49  	if !utils.IsLocalDatabase(config) {
    50  		msg := "Do you want to reset the remote database?"
    51  		if shouldReset := utils.NewConsole().PromptYesNo(msg, false); !shouldReset {
    52  			return errors.New(context.Canceled)
    53  		}
    54  		return resetRemote(ctx, version, config, fsys, options...)
    55  	}
    56  
    57  	// Config file is loaded before parsing --linked or --local flags
    58  	if err := utils.AssertSupabaseDbIsRunning(); err != nil {
    59  		return err
    60  	}
    61  
    62  	// Reset postgres database because extensions (pg_cron, pg_net) require postgres
    63  	if err := resetDatabase(ctx, version, fsys, options...); err != nil {
    64  		return err
    65  	}
    66  
    67  	branch := keys.GetGitBranch(fsys)
    68  	fmt.Fprintln(os.Stderr, "Finished "+utils.Aqua("supabase db reset")+" on branch "+utils.Aqua(branch)+".")
    69  	return nil
    70  }
    71  
    72  func resetDatabase(ctx context.Context, version string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
    73  	fmt.Fprintln(os.Stderr, "Resetting local database"+toLogMessage(version))
    74  	if utils.Config.Db.MajorVersion <= 14 {
    75  		return resetDatabase14(ctx, version, fsys, options...)
    76  	}
    77  	return resetDatabase15(ctx, version, fsys, options...)
    78  }
    79  
    80  func toLogMessage(version string) string {
    81  	if len(version) > 0 {
    82  		return " to version: " + version
    83  	}
    84  	return "..."
    85  }
    86  
    87  func resetDatabase14(ctx context.Context, version string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
    88  	if err := recreateDatabase(ctx, options...); err != nil {
    89  		return err
    90  	}
    91  	if err := initDatabase(ctx, options...); err != nil {
    92  		return err
    93  	}
    94  	if err := RestartDatabase(ctx, os.Stderr); err != nil {
    95  		return err
    96  	}
    97  	conn, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{}, options...)
    98  	if err != nil {
    99  		return err
   100  	}
   101  	defer conn.Close(context.Background())
   102  	if utils.Config.Db.MajorVersion > 14 {
   103  		if err := start.SetupDatabase(ctx, conn, utils.DbId, os.Stderr, fsys); err != nil {
   104  			return err
   105  		}
   106  	}
   107  	return apply.MigrateAndSeed(ctx, version, conn, fsys)
   108  }
   109  
   110  func resetDatabase15(ctx context.Context, version string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
   111  	if err := utils.Docker.ContainerRemove(ctx, utils.DbId, container.RemoveOptions{Force: true}); err != nil {
   112  		return errors.Errorf("failed to remove container: %w", err)
   113  	}
   114  	if err := utils.Docker.VolumeRemove(ctx, utils.DbId, true); err != nil {
   115  		return errors.Errorf("failed to remove volume: %w", err)
   116  	}
   117  	// Skip syslog if vector container is not started
   118  	if _, err := utils.Docker.ContainerInspect(ctx, utils.VectorId); err != nil {
   119  		utils.Config.Analytics.Enabled = false
   120  	}
   121  	config := start.NewContainerConfig()
   122  	hostConfig := start.NewHostConfig()
   123  	networkingConfig := network.NetworkingConfig{
   124  		EndpointsConfig: map[string]*network.EndpointSettings{
   125  			utils.NetId: {
   126  				Aliases: utils.DbAliases,
   127  			},
   128  		},
   129  	}
   130  	fmt.Fprintln(os.Stderr, "Recreating database...")
   131  	if _, err := utils.DockerStart(ctx, config, hostConfig, networkingConfig, utils.DbId); err != nil {
   132  		return err
   133  	}
   134  	if !start.WaitForHealthyService(ctx, utils.DbId, start.HealthTimeout) {
   135  		return errors.New(start.ErrDatabase)
   136  	}
   137  	conn, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{}, options...)
   138  	if err != nil {
   139  		return err
   140  	}
   141  	defer conn.Close(context.Background())
   142  	if err := start.SetupDatabase(ctx, conn, utils.DbId, os.Stderr, fsys); err != nil {
   143  		return err
   144  	}
   145  	if err := apply.MigrateAndSeed(ctx, version, conn, fsys); err != nil {
   146  		return err
   147  	}
   148  	fmt.Fprintln(os.Stderr, "Restarting containers...")
   149  	return restartServices(ctx)
   150  }
   151  
   152  func initDatabase(ctx context.Context, options ...func(*pgx.ConnConfig)) error {
   153  	conn, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{User: "supabase_admin"}, options...)
   154  	if err != nil {
   155  		return err
   156  	}
   157  	defer conn.Close(context.Background())
   158  	return apply.BatchExecDDL(ctx, conn, strings.NewReader(utils.InitialSchemaSql))
   159  }
   160  
   161  // Recreate postgres database by connecting to template1
   162  func recreateDatabase(ctx context.Context, options ...func(*pgx.ConnConfig)) error {
   163  	conn, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{User: "supabase_admin", Database: "template1"}, options...)
   164  	if err != nil {
   165  		return err
   166  	}
   167  	defer conn.Close(context.Background())
   168  	if err := DisconnectClients(ctx, conn); err != nil {
   169  		return err
   170  	}
   171  	// We are not dropping roles here because they are cluster level entities. Use stop && start instead.
   172  	sql := repair.MigrationFile{
   173  		Lines: []string{
   174  			"DROP DATABASE IF EXISTS postgres WITH (FORCE)",
   175  			"CREATE DATABASE postgres WITH OWNER postgres",
   176  		},
   177  	}
   178  	return sql.ExecBatch(ctx, conn)
   179  }
   180  
   181  func DisconnectClients(ctx context.Context, conn *pgx.Conn) error {
   182  	// Must be executed separately because running in transaction is unsupported
   183  	disconn := "ALTER DATABASE postgres ALLOW_CONNECTIONS false;"
   184  	if _, err := conn.Exec(ctx, disconn); err != nil {
   185  		var pgErr *pgconn.PgError
   186  		if errors.As(err, &pgErr) && pgErr.Code != pgerrcode.InvalidCatalogName {
   187  			return errors.Errorf("failed to disconnect clients: %w", err)
   188  		}
   189  	}
   190  	term := fmt.Sprintf(utils.TerminateDbSqlFmt, "postgres")
   191  	if _, err := conn.Exec(ctx, term); err != nil {
   192  		return errors.Errorf("failed to terminate backend: %w", err)
   193  	}
   194  	return nil
   195  }
   196  
   197  func RestartDatabase(ctx context.Context, w io.Writer) error {
   198  	fmt.Fprintln(w, "Restarting containers...")
   199  	// Some extensions must be manually restarted after pg_terminate_backend
   200  	// Ref: https://github.com/citusdata/pg_cron/issues/99
   201  	if err := utils.Docker.ContainerRestart(ctx, utils.DbId, container.StopOptions{}); err != nil {
   202  		return errors.Errorf("failed to restart container: %w", err)
   203  	}
   204  	if !start.WaitForHealthyService(ctx, utils.DbId, start.HealthTimeout) {
   205  		return errors.New(start.ErrDatabase)
   206  	}
   207  	return restartServices(ctx)
   208  }
   209  
   210  func restartServices(ctx context.Context) error {
   211  	// No need to restart PostgREST because it automatically reconnects and listens for schema changes
   212  	services := []string{utils.StorageId, utils.GotrueId, utils.RealtimeId}
   213  	result := utils.WaitAll(services, func(id string) error {
   214  		if err := utils.Docker.ContainerRestart(ctx, id, container.StopOptions{}); err != nil && !errdefs.IsNotFound(err) {
   215  			return errors.Errorf("Failed to restart %s: %w", id, err)
   216  		}
   217  		return nil
   218  	})
   219  	// Do not wait for service healthy as those services may be excluded from starting
   220  	return errors.Join(result...)
   221  }
   222  
   223  func WaitForServiceReady(ctx context.Context, started []string) error {
   224  	probe := func() bool {
   225  		var unhealthy []string
   226  		for _, container := range started {
   227  			if !status.IsServiceReady(ctx, container) {
   228  				unhealthy = append(unhealthy, container)
   229  			}
   230  		}
   231  		started = unhealthy
   232  		return len(started) == 0
   233  	}
   234  	if !start.RetryEverySecond(ctx, probe, serviceTimeout) {
   235  		// Print container logs for easier debugging
   236  		for _, containerId := range started {
   237  			logs, err := utils.Docker.ContainerLogs(ctx, containerId, container.LogsOptions{
   238  				ShowStdout: true,
   239  				ShowStderr: true,
   240  			})
   241  			if err != nil {
   242  				fmt.Fprintln(os.Stderr, err)
   243  				continue
   244  			}
   245  			fmt.Fprintln(os.Stderr, containerId, "container logs:")
   246  			if _, err := stdcopy.StdCopy(os.Stderr, os.Stderr, logs); err != nil {
   247  				fmt.Fprintln(os.Stderr, err)
   248  			}
   249  			logs.Close()
   250  		}
   251  		return errors.Errorf("%w: %v", ErrUnhealthy, started)
   252  	}
   253  	return nil
   254  }
   255  
   256  func resetRemote(ctx context.Context, version string, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
   257  	fmt.Fprintln(os.Stderr, "Resetting remote database"+toLogMessage(version))
   258  	conn, err := utils.ConnectByConfigStream(ctx, config, io.Discard, options...)
   259  	if err != nil {
   260  		return err
   261  	}
   262  	defer conn.Close(context.Background())
   263  	// Only drop objects in extensions and public schema
   264  	excludes := append([]string{
   265  		"extensions",
   266  		"public",
   267  	}, utils.ManagedSchemas...)
   268  	userSchemas, err := LoadUserSchemas(ctx, conn, excludes...)
   269  	if err != nil {
   270  		return err
   271  	}
   272  	// Drop all user defined schemas
   273  	migration := repair.MigrationFile{}
   274  	for _, schema := range userSchemas {
   275  		sql := fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schema)
   276  		migration.Lines = append(migration.Lines, sql)
   277  	}
   278  	// If an extension uses a schema it doesn't create, dropping the schema will cascade to also
   279  	// drop the extension. But if an extension creates its own schema, dropping the schema will
   280  	// throw an error. Hence, we drop the extension instead so it cascades to its own schema.
   281  	migration.Lines = append(migration.Lines, dropObjects)
   282  	if err := migration.ExecBatch(ctx, conn); err != nil {
   283  		return err
   284  	}
   285  	return apply.MigrateAndSeed(ctx, version, conn, fsys)
   286  }
   287  
   288  func LoadUserSchemas(ctx context.Context, conn *pgx.Conn, exclude ...string) ([]string, error) {
   289  	if len(exclude) == 0 {
   290  		exclude = utils.ManagedSchemas
   291  	}
   292  	exclude = LikeEscapeSchema(exclude)
   293  	rows, err := conn.Query(ctx, ListSchemas, exclude)
   294  	if err != nil {
   295  		return nil, errors.Errorf("failed to list schemas: %w", err)
   296  	}
   297  	// TODO: show detail and hint from pgconn.PgError
   298  	return pgxv5.CollectStrings(rows)
   299  }
   300  
   301  func LikeEscapeSchema(schemas []string) (result []string) {
   302  	// Treat _ as literal, * as any character
   303  	replacer := strings.NewReplacer("_", `\_`, "*", "%")
   304  	for _, sch := range schemas {
   305  		result = append(result, replacer.Replace(sch))
   306  	}
   307  	return result
   308  }