github.com/supabase/cli@v1.168.1/internal/migration/repair/repair.go (about)

     1  package repair
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"path/filepath"
     9  	"strconv"
    10  
    11  	"github.com/go-errors/errors"
    12  	"github.com/jackc/pgconn"
    13  	"github.com/jackc/pgtype"
    14  	"github.com/jackc/pgx/v4"
    15  	"github.com/spf13/afero"
    16  	"github.com/spf13/viper"
    17  	"github.com/supabase/cli/internal/migration/history"
    18  	"github.com/supabase/cli/internal/migration/list"
    19  	"github.com/supabase/cli/internal/utils"
    20  	"github.com/supabase/cli/internal/utils/parser"
    21  )
    22  
    23  const (
    24  	Applied  = "applied"
    25  	Reverted = "reverted"
    26  )
    27  
    28  var ErrInvalidVersion = errors.New("invalid version number")
    29  
    30  func Run(ctx context.Context, config pgconn.Config, version []string, status string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
    31  	for _, v := range version {
    32  		if _, err := strconv.Atoi(v); err != nil {
    33  			return errors.Errorf("failed to parse %s: %w", v, ErrInvalidVersion)
    34  		}
    35  	}
    36  	repairAll := len(version) == 0
    37  	if repairAll {
    38  		msg := "Do you want to repair the entire migration history table to match local migration files?"
    39  		if shouldRepair := utils.NewConsole().PromptYesNo(msg, false); !shouldRepair {
    40  			return errors.New(context.Canceled)
    41  		}
    42  		local, err := list.LoadLocalVersions(fsys)
    43  		if err != nil {
    44  			return err
    45  		}
    46  		version = append(version, local...)
    47  	}
    48  	conn, err := utils.ConnectByConfig(ctx, config, options...)
    49  	if err != nil {
    50  		return err
    51  	}
    52  	defer conn.Close(context.Background())
    53  	// Update migration history
    54  	if err = UpdateMigrationTable(ctx, conn, version, status, repairAll, fsys); err == nil {
    55  		utils.CmdSuggestion = fmt.Sprintf("Run %s to show the updated migration history.", utils.Aqua("supabase migration list"))
    56  	}
    57  	return err
    58  }
    59  
    60  func UpdateMigrationTable(ctx context.Context, conn *pgx.Conn, version []string, status string, repairAll bool, fsys afero.Fs) error {
    61  	if err := history.CreateMigrationTable(ctx, conn); err != nil {
    62  		return err
    63  	}
    64  	// Data statements don't mutate schemas, safe to use statement cache
    65  	batch := &pgx.Batch{}
    66  	if repairAll {
    67  		batch.Queue(history.TRUNCATE_VERSION_TABLE)
    68  	}
    69  	switch status {
    70  	case Applied:
    71  		for _, v := range version {
    72  			f, err := NewMigrationFromVersion(v, fsys)
    73  			if err != nil {
    74  				return err
    75  			}
    76  			batch.Queue(history.INSERT_MIGRATION_VERSION, f.Version, f.Name, f.Lines)
    77  		}
    78  	case Reverted:
    79  		if !repairAll {
    80  			batch.Queue(history.DELETE_MIGRATION_VERSION, version)
    81  		}
    82  	}
    83  	if err := conn.SendBatch(ctx, batch).Close(); err != nil {
    84  		return errors.Errorf("failed to update migration table: %w", err)
    85  	}
    86  	if !repairAll {
    87  		fmt.Fprintf(os.Stderr, "Repaired migration history: %v => %s\n", version, status)
    88  	}
    89  	return nil
    90  }
    91  
    92  func GetMigrationFile(version string, fsys afero.Fs) (string, error) {
    93  	path := filepath.Join(utils.MigrationsDir, version+"_*.sql")
    94  	matches, err := afero.Glob(fsys, path)
    95  	if err != nil {
    96  		return "", errors.Errorf("failed to glob migration files: %w", err)
    97  	}
    98  	if len(matches) == 0 {
    99  		return "", errors.Errorf("glob %s: %w", path, os.ErrNotExist)
   100  	}
   101  	return matches[0], nil
   102  }
   103  
   104  type MigrationFile struct {
   105  	Lines   []string
   106  	Version string
   107  	Name    string
   108  }
   109  
   110  func NewMigrationFromVersion(version string, fsys afero.Fs) (*MigrationFile, error) {
   111  	name, err := GetMigrationFile(version, fsys)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  	return NewMigrationFromFile(name, fsys)
   116  }
   117  
   118  func NewMigrationFromFile(path string, fsys afero.Fs) (*MigrationFile, error) {
   119  	sql, err := fsys.Open(path)
   120  	if err != nil {
   121  		return nil, errors.Errorf("failed to open migration file: %w", err)
   122  	}
   123  	defer sql.Close()
   124  	// Unless explicitly specified, Use file length as max buffer size
   125  	if !viper.IsSet("SCANNER_BUFFER_SIZE") {
   126  		if fi, err := sql.Stat(); err == nil {
   127  			if size := int(fi.Size()); size > parser.MaxScannerCapacity {
   128  				parser.MaxScannerCapacity = size
   129  			}
   130  		}
   131  	}
   132  	file, err := NewMigrationFromReader(sql)
   133  	if err == nil {
   134  		// Parse version from file name
   135  		filename := filepath.Base(path)
   136  		matches := utils.MigrateFilePattern.FindStringSubmatch(filename)
   137  		if len(matches) > 2 {
   138  			file.Version = matches[1]
   139  			file.Name = matches[2]
   140  		}
   141  	}
   142  	return file, err
   143  }
   144  
   145  func NewMigrationFromReader(sql io.Reader) (*MigrationFile, error) {
   146  	lines, err := parser.SplitAndTrim(sql)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  	return &MigrationFile{Lines: lines}, nil
   151  }
   152  
   153  func (m *MigrationFile) ExecBatch(ctx context.Context, conn *pgx.Conn) error {
   154  	// Batch migration commands, without using statement cache
   155  	batch := &pgconn.Batch{}
   156  	for _, line := range m.Lines {
   157  		batch.ExecParams(line, nil, nil, nil, nil)
   158  	}
   159  	// Insert into migration history
   160  	if len(m.Version) > 0 {
   161  		if err := m.insertVersionSQL(conn, batch); err != nil {
   162  			return err
   163  		}
   164  	}
   165  	// ExecBatch is implicitly transactional
   166  	if result, err := conn.PgConn().ExecBatch(ctx, batch).ReadAll(); err != nil {
   167  		// Defaults to printing the last statement on error
   168  		stat := history.INSERT_MIGRATION_VERSION
   169  		i := len(result)
   170  		if i < len(m.Lines) {
   171  			stat = m.Lines[i]
   172  		}
   173  		return errors.Errorf("%w\nAt statement %d: %s", err, i, stat)
   174  	}
   175  	return nil
   176  }
   177  
   178  func (m *MigrationFile) insertVersionSQL(conn *pgx.Conn, batch *pgconn.Batch) error {
   179  	value := pgtype.TextArray{}
   180  	if err := value.Set(m.Lines); err != nil {
   181  		return errors.Errorf("failed to set text array: %w", err)
   182  	}
   183  	ci := conn.ConnInfo()
   184  	var err error
   185  	var encoded []byte
   186  	var valueFormat int16
   187  	if conn.Config().PreferSimpleProtocol {
   188  		encoded, err = value.EncodeText(ci, encoded)
   189  		valueFormat = pgtype.TextFormatCode
   190  	} else {
   191  		encoded, err = value.EncodeBinary(ci, encoded)
   192  		valueFormat = pgtype.BinaryFormatCode
   193  	}
   194  	if err != nil {
   195  		return errors.Errorf("failed to encode binary: %w", err)
   196  	}
   197  	batch.ExecParams(
   198  		history.INSERT_MIGRATION_VERSION,
   199  		[][]byte{[]byte(m.Version), []byte(m.Name), encoded},
   200  		[]uint32{pgtype.TextOID, pgtype.TextOID, pgtype.TextArrayOID},
   201  		[]int16{pgtype.TextFormatCode, pgtype.TextFormatCode, valueFormat},
   202  		nil,
   203  	)
   204  	return nil
   205  }
   206  
   207  func (m *MigrationFile) ExecBatchWithCache(ctx context.Context, conn *pgx.Conn) error {
   208  	// Data statements don't mutate schemas, safe to use statement cache
   209  	batch := pgx.Batch{}
   210  	for _, line := range m.Lines {
   211  		batch.Queue(line)
   212  	}
   213  	// No need to track version here because there are no schema changes
   214  	if err := conn.SendBatch(ctx, &batch).Close(); err != nil {
   215  		return errors.Errorf("failed to send batch: %w", err)
   216  	}
   217  	return nil
   218  }