github.com/Redstoneguy129/cli@v0.0.0-20230211220159-15dca4e91917/internal/db/push/push.go (about)

     1  package push
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"os"
     8  	"path/filepath"
     9  
    10  	"github.com/jackc/pgconn"
    11  	"github.com/jackc/pgx/v4"
    12  	"github.com/spf13/afero"
    13  	"github.com/Redstoneguy129/cli/internal/migration/list"
    14  	"github.com/Redstoneguy129/cli/internal/migration/repair"
    15  	"github.com/Redstoneguy129/cli/internal/utils"
    16  	"github.com/Redstoneguy129/cli/internal/utils/parser"
    17  )
    18  
    19  var (
    20  	errConflict = errors.New("supabase_migrations.schema_migrations table conflicts with the contents of " + utils.Bold(utils.MigrationsDir) + ".")
    21  )
    22  
    23  func Run(ctx context.Context, dryRun bool, username, password, database, host string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
    24  	if dryRun {
    25  		fmt.Fprintln(os.Stderr, "DRY RUN: migrations will *not* be pushed to the database.")
    26  	}
    27  	conn, err := utils.ConnectRemotePostgres(ctx, username, password, database, host, options...)
    28  	if err != nil {
    29  		return err
    30  	}
    31  	defer conn.Close(context.Background())
    32  	pending, err := getPendingMigrations(ctx, conn, fsys)
    33  	if err != nil {
    34  		return err
    35  	}
    36  	if len(pending) == 0 {
    37  		fmt.Println("Linked project is up to date.")
    38  		return nil
    39  	}
    40  	// Push pending migrations
    41  	for _, filename := range pending {
    42  		if dryRun {
    43  			fmt.Fprintln(os.Stderr, "Would push migration "+utils.Bold(filename)+"...")
    44  			continue
    45  		}
    46  		if err := pushMigration(ctx, conn, filename, fsys); err != nil {
    47  			return err
    48  		}
    49  	}
    50  	fmt.Println("Finished " + utils.Aqua("supabase db push") + ".")
    51  	return nil
    52  }
    53  
    54  func getPendingMigrations(ctx context.Context, conn *pgx.Conn, fsys afero.Fs) ([]string, error) {
    55  	remoteMigrations, err := list.LoadRemoteMigrations(ctx, conn)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  	localMigrations, err := list.LoadLocalMigrations(fsys)
    60  	if err != nil {
    61  		return nil, err
    62  	}
    63  	// Check remote is in-sync or behind local
    64  	if len(remoteMigrations) > len(localMigrations) {
    65  		return nil, fmt.Errorf("%w; Found %d versions and %d migrations.", errConflict, len(remoteMigrations), len(localMigrations))
    66  	}
    67  	for i, remote := range remoteMigrations {
    68  		filename := localMigrations[i]
    69  		// LoadLocalMigrations guarantees we always have a match
    70  		local := utils.MigrateFilePattern.FindStringSubmatch(filename)[1]
    71  		if remote != local {
    72  			return nil, fmt.Errorf("%w; Expected version %s but found migration %s at index %d.", errConflict, remote, filename, i)
    73  		}
    74  	}
    75  	return localMigrations[len(remoteMigrations):], nil
    76  }
    77  
    78  func pushMigration(ctx context.Context, conn *pgx.Conn, filename string, fsys afero.Fs) error {
    79  	fmt.Fprintln(os.Stderr, "Pushing migration "+utils.Bold(filename)+"...")
    80  	sql, err := fsys.Open(filepath.Join(utils.MigrationsDir, filename))
    81  	if err != nil {
    82  		return err
    83  	}
    84  	lines, err := parser.SplitAndTrim(sql)
    85  	if err != nil {
    86  		return err
    87  	}
    88  	batch := pgconn.Batch{}
    89  	for _, line := range lines {
    90  		batch.ExecParams(line, nil, nil, nil, nil)
    91  	}
    92  	// Insert into migration history
    93  	lines = append(lines, repair.INSERT_MIGRATION_VERSION)
    94  	version := utils.MigrateFilePattern.FindStringSubmatch(filename)[1]
    95  	repair.InsertVersionSQL(&batch, version)
    96  	// ExecBatch is implicitly transactional
    97  	if result, err := conn.PgConn().ExecBatch(ctx, &batch).ReadAll(); err != nil {
    98  		i := len(result)
    99  		var stat string
   100  		if i < len(lines) {
   101  			stat = lines[i]
   102  		}
   103  		return fmt.Errorf("%v\nAt statement %d: %s", err, i, utils.Aqua(stat))
   104  	}
   105  	return nil
   106  }