github.com/Redstoneguy129/cli@v0.0.0-20230211220159-15dca4e91917/internal/db/diff/migra.go (about)

     1  package diff
     2  
     3  import (
     4  	"context"
     5  	_ "embed"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net/url"
    10  	"os"
    11  	"path/filepath"
    12  	"strconv"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/docker/docker/api/types/container"
    17  	"github.com/docker/go-connections/nat"
    18  	"github.com/jackc/pgconn"
    19  	"github.com/jackc/pgx/v4"
    20  	"github.com/spf13/afero"
    21  	"github.com/Redstoneguy129/cli/internal/migration/list"
    22  	"github.com/Redstoneguy129/cli/internal/utils"
    23  	"github.com/Redstoneguy129/cli/internal/utils/parser"
    24  )
    25  
    26  const LIST_SCHEMAS = "SELECT schema_name FROM information_schema.schemata WHERE NOT schema_name = ANY($1) ORDER BY schema_name"
    27  
    28  var (
    29  	//go:embed templates/migra.sh
    30  	diffSchemaScript string
    31  )
    32  
    33  func RunMigra(ctx context.Context, schema []string, file, password string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
    34  	// Sanity checks.
    35  	if err := utils.LoadConfigFS(fsys); err != nil {
    36  		return err
    37  	}
    38  	// 1. Determine local or remote target
    39  	target, err := buildTargetUrl(password, fsys)
    40  	if err != nil {
    41  		return err
    42  	}
    43  	// 2. Load all user defined schemas
    44  	if len(schema) == 0 {
    45  		var conn *pgx.Conn
    46  		if len(password) > 0 {
    47  			options = append(options, func(cc *pgx.ConnConfig) {
    48  				cc.PreferSimpleProtocol = true
    49  			})
    50  			conn, err = utils.ConnectByUrl(ctx, target, options...)
    51  		} else {
    52  			conn, err = utils.ConnectLocalPostgres(ctx, utils.Config.Hostname, utils.Config.Db.Port, "postgres", options...)
    53  		}
    54  		if err != nil {
    55  			return err
    56  		}
    57  		defer conn.Close(context.Background())
    58  		schema, err = LoadUserSchemas(ctx, conn)
    59  		if err != nil {
    60  			return err
    61  		}
    62  	}
    63  	// 3. Run migra to diff schema
    64  	out, err := DiffDatabase(ctx, schema, target, os.Stderr, fsys, options...)
    65  	if err != nil {
    66  		return err
    67  	}
    68  	branch, err := utils.GetCurrentBranchFS(fsys)
    69  	if err != nil {
    70  		branch = "main"
    71  	}
    72  	fmt.Fprintln(os.Stderr, "Finished "+utils.Aqua("supabase db diff")+" on branch "+utils.Aqua(branch)+".\n")
    73  	return SaveDiff(out, file, fsys)
    74  }
    75  
    76  // Builds a postgres connection string for local or remote database
    77  func buildTargetUrl(password string, fsys afero.Fs) (target string, err error) {
    78  	if len(password) > 0 {
    79  		ref, err := utils.LoadProjectRef(fsys)
    80  		if err != nil {
    81  			return target, err
    82  		}
    83  		target = fmt.Sprintf(
    84  			"postgresql://%s@%s:6543/postgres",
    85  			url.UserPassword("postgres", password),
    86  			utils.GetSupabaseDbHost(ref),
    87  		)
    88  		fmt.Fprintln(os.Stderr, "Connecting to linked project...")
    89  	} else {
    90  		if err := utils.AssertSupabaseDbIsRunning(); err != nil {
    91  			return target, err
    92  		}
    93  		target = "postgresql://postgres:postgres@" + utils.DbId + ":5432/postgres"
    94  		fmt.Fprintln(os.Stderr, "Connecting to local database...")
    95  	}
    96  	return target, err
    97  }
    98  
    99  func LoadUserSchemas(ctx context.Context, conn *pgx.Conn, exclude ...string) ([]string, error) {
   100  	// Include auth,storage,extensions by default for RLS policies
   101  	if len(exclude) == 0 {
   102  		exclude = append([]string{
   103  			"pgbouncer",
   104  			"realtime",
   105  			"_realtime",
   106  			// Exclude functions because Webhooks support is early alpha
   107  			"supabase_functions",
   108  			"supabase_migrations",
   109  		}, utils.SystemSchemas...)
   110  	}
   111  	rows, err := conn.Query(ctx, LIST_SCHEMAS, exclude)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  	schemas := []string{}
   116  	for rows.Next() {
   117  		var name string
   118  		if err := rows.Scan(&name); err != nil {
   119  			return nil, err
   120  		}
   121  		schemas = append(schemas, name)
   122  	}
   123  	return schemas, nil
   124  }
   125  
   126  func CreateShadowDatabase(ctx context.Context) (string, error) {
   127  	config := container.Config{
   128  		Image: utils.DbImage,
   129  		Env:   []string{"POSTGRES_PASSWORD=postgres"},
   130  	}
   131  	if utils.Config.Db.MajorVersion >= 14 {
   132  		config.Cmd = []string{"postgres",
   133  			"-c", "config_file=/etc/postgresql/postgresql.conf",
   134  			// Ref: https://postgrespro.com/list/thread-id/2448092
   135  			"-c", `search_path="$user",public,extensions`,
   136  		}
   137  	}
   138  	hostPort := strconv.FormatUint(uint64(utils.Config.Db.ShadowPort), 10)
   139  	hostConfig := container.HostConfig{
   140  		PortBindings: nat.PortMap{"5432/tcp": []nat.PortBinding{{HostPort: hostPort}}},
   141  		Binds:        []string{"/dev/null:/docker-entrypoint-initdb.d/migrate.sh:ro"},
   142  		AutoRemove:   true,
   143  	}
   144  	return utils.DockerStart(ctx, config, hostConfig, "")
   145  }
   146  
   147  func connectShadowDatabase(ctx context.Context, timeout time.Duration, options ...func(*pgx.ConnConfig)) (conn *pgx.Conn, err error) {
   148  	now := time.Now()
   149  	expiry := now.Add(timeout)
   150  	ticker := time.NewTicker(time.Second)
   151  	defer ticker.Stop()
   152  	// Retry until connected, cancelled, or timeout
   153  	for t := now; t.Before(expiry); t = <-ticker.C {
   154  		conn, err = utils.ConnectLocalPostgres(ctx, utils.Config.Hostname, utils.Config.Db.ShadowPort, "postgres", options...)
   155  		if err == nil || errors.Is(ctx.Err(), context.Canceled) {
   156  			break
   157  		}
   158  	}
   159  	return conn, err
   160  }
   161  
   162  func MigrateShadowDatabase(ctx context.Context, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
   163  	conn, err := connectShadowDatabase(ctx, 10*time.Second, options...)
   164  	if err != nil {
   165  		return err
   166  	}
   167  	defer conn.Close(context.Background())
   168  	if err := BatchExecDDL(ctx, conn, strings.NewReader(utils.GlobalsSql)); err != nil {
   169  		return err
   170  	}
   171  	if err := BatchExecDDL(ctx, conn, strings.NewReader(utils.InitialSchemaSql)); err != nil {
   172  		return err
   173  	}
   174  	return MigrateDatabase(ctx, conn, fsys)
   175  }
   176  
   177  // Applies local migration scripts to a database.
   178  func ApplyMigrations(ctx context.Context, url string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
   179  	// Parse connection url
   180  	config, err := pgx.ParseConfig(url)
   181  	if err != nil {
   182  		return err
   183  	}
   184  	// Apply config overrides
   185  	for _, op := range options {
   186  		op(config)
   187  	}
   188  	// Connect to database
   189  	conn, err := pgx.ConnectConfig(ctx, config)
   190  	if err != nil {
   191  		return err
   192  	}
   193  	defer conn.Close(context.Background())
   194  	return MigrateDatabase(ctx, conn, fsys)
   195  }
   196  
   197  func MigrateDatabase(ctx context.Context, conn *pgx.Conn, fsys afero.Fs) error {
   198  	migrations, err := list.LoadLocalMigrations(fsys)
   199  	if err != nil {
   200  		return err
   201  	}
   202  	// Apply migrations
   203  	for _, filename := range migrations {
   204  		if err := migrateUp(ctx, conn, filename, fsys); err != nil {
   205  			return err
   206  		}
   207  	}
   208  	return nil
   209  }
   210  
   211  func migrateUp(ctx context.Context, conn *pgx.Conn, filename string, fsys afero.Fs) error {
   212  	fmt.Fprintln(os.Stderr, "Applying migration "+utils.Bold(filename)+"...")
   213  	sql, err := fsys.Open(filepath.Join(utils.MigrationsDir, filename))
   214  	if err != nil {
   215  		return err
   216  	}
   217  	defer sql.Close()
   218  	return BatchExecDDL(ctx, conn, sql)
   219  }
   220  
   221  func BatchExecDDL(ctx context.Context, conn *pgx.Conn, sql io.Reader) error {
   222  	lines, err := parser.SplitAndTrim(sql)
   223  	if err != nil {
   224  		return err
   225  	}
   226  	// Batch migration commands, without using statement cache
   227  	batch := pgconn.Batch{}
   228  	for _, line := range lines {
   229  		batch.ExecParams(line, nil, nil, nil, nil)
   230  	}
   231  	if result, err := conn.PgConn().ExecBatch(ctx, &batch).ReadAll(); err != nil {
   232  		i := len(result)
   233  		var stat string
   234  		if i < len(lines) {
   235  			stat = lines[i]
   236  		}
   237  		return fmt.Errorf("%v\nAt statement %d: %s", err, i, utils.Aqua(stat))
   238  	}
   239  	return nil
   240  }
   241  
   242  // Diffs local database schema against shadow, dumps output to stdout.
   243  func DiffSchemaMigra(ctx context.Context, source, target string, schema []string) (string, error) {
   244  	env := []string{"SOURCE=" + source, "TARGET=" + target}
   245  	// Passing in script string means command line args must be set manually, ie. "$@"
   246  	args := "set -- " + strings.Join(schema, " ") + ";"
   247  	cmd := []string{"/bin/sh", "-c", args + diffSchemaScript}
   248  	out, err := utils.DockerRunOnce(ctx, utils.MigraImage, env, cmd)
   249  	if err != nil {
   250  		return "", errors.New("error diffing schema: " + err.Error())
   251  	}
   252  	return out, nil
   253  }
   254  
   255  func DiffDatabase(ctx context.Context, schema []string, target string, w io.Writer, fsys afero.Fs, options ...func(*pgx.ConnConfig)) (string, error) {
   256  	fmt.Fprintln(w, "Creating shadow database...")
   257  	shadow, err := CreateShadowDatabase(ctx)
   258  	if err != nil {
   259  		return "", err
   260  	}
   261  	defer utils.DockerRemove(shadow)
   262  	if err := MigrateShadowDatabase(ctx, fsys, options...); err != nil {
   263  		return "", err
   264  	}
   265  	fmt.Fprintln(w, "Diffing schemas:", strings.Join(schema, ","))
   266  	source := "postgresql://postgres:postgres@" + shadow[:12] + ":5432/postgres"
   267  	return DiffSchemaMigra(ctx, source, target, schema)
   268  }