github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/migrate/migrate.go (about)

     1  package migrate
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  
     8  	log "github.com/authzed/spicedb/internal/logging"
     9  )
    10  
    11  const (
    12  	Head = "head"
    13  	None = ""
    14  )
    15  
    16  type RunType bool
    17  
    18  var (
    19  	DryRun  RunType = true
    20  	LiveRun RunType = false
    21  )
    22  
    23  // Driver represents the common interface for enabling the orchestration of migrations
    24  // for a specific type of datastore. The driver is parameterized with a type representing
    25  // a connection handler that will be forwarded by the Manager to the MigrationFunc to execute.
    26  type Driver[C any, T any] interface {
    27  	// Version returns the current version of the schema in the backing datastore.
    28  	// If the datastore is brand new, version should return the empty string without
    29  	// an error.
    30  	Version(ctx context.Context) (string, error)
    31  
    32  	// WriteVersion stores the migration version being run
    33  	WriteVersion(ctx context.Context, tx T, version string, replaced string) error
    34  
    35  	// Conn returns the drivers underlying connection handler to be used by one or more MigrationFunc
    36  	Conn() C
    37  
    38  	// RunTx returns a transaction for to be used by one or more TxMigrationFunc
    39  	RunTx(context.Context, TxMigrationFunc[T]) error
    40  
    41  	// Close frees up any resources in use by the driver.
    42  	Close(ctx context.Context) error
    43  }
    44  
    45  // MigrationFunc is a function that executes in the context of a specific database connection handler.
    46  type MigrationFunc[C any] func(ctx context.Context, conn C) error
    47  
    48  // TxMigrationFunc is a function that executes in the context of a specific database transaction.
    49  type TxMigrationFunc[T any] func(ctx context.Context, tx T) error
    50  
    51  type migration[C any, T any] struct {
    52  	version  string
    53  	replaces string
    54  	up       MigrationFunc[C]
    55  	upTx     TxMigrationFunc[T]
    56  }
    57  
    58  // Manager is used to manage a self-contained set of migrations. Standard usage
    59  // would be to instantiate one at the package level for a particular application
    60  // and then statically register migrations to the single instantiation in init
    61  // functions.
    62  // The manager is parameterized using the Driver interface along the concrete type of
    63  // a database connection handler. This makes it possible for MigrationFunc to run without
    64  // having to abstract each connection handler behind a common interface.
    65  type Manager[D Driver[C, T], C any, T any] struct {
    66  	migrations map[string]migration[C, T]
    67  }
    68  
    69  // NewManager creates a new empty instance of a migration manager.
    70  func NewManager[D Driver[C, T], C any, T any]() *Manager[D, C, T] {
    71  	return &Manager[D, C, T]{migrations: make(map[string]migration[C, T])}
    72  }
    73  
    74  // Register is used to associate a single migration with the migration engine.
    75  // The up parameter should be a function that performs the actual upgrade logic
    76  // and which takes a pointer to a concrete implementation of the Driver
    77  // interface as its only parameters, which will be passed directly from the Run
    78  // method into the upgrade function. If not extra fields or data are required
    79  // the function can alternatively take a Driver interface param.
    80  func (m *Manager[D, C, T]) Register(version, replaces string, up MigrationFunc[C], upTx TxMigrationFunc[T]) error {
    81  	if strings.ToLower(version) == Head {
    82  		return fmt.Errorf("unable to register version called head")
    83  	}
    84  
    85  	if _, ok := m.migrations[version]; ok {
    86  		return fmt.Errorf("revision already exists: %s", version)
    87  	}
    88  
    89  	m.migrations[version] = migration[C, T]{
    90  		version:  version,
    91  		replaces: replaces,
    92  		up:       up,
    93  		upTx:     upTx,
    94  	}
    95  
    96  	return nil
    97  }
    98  
    99  // Run will actually perform the necessary migrations to bring the backing datastore
   100  // from its current revision to the specified revision.
   101  func (m *Manager[D, C, T]) Run(ctx context.Context, driver D, throughRevision string, dryRun RunType) error {
   102  	requestedRevision := throughRevision
   103  	starting, err := driver.Version(ctx)
   104  	if err != nil {
   105  		return fmt.Errorf("unable to get current revision: %w", err)
   106  	}
   107  
   108  	if strings.ToLower(throughRevision) == Head {
   109  		throughRevision, err = m.HeadRevision()
   110  		if err != nil {
   111  			return fmt.Errorf("unable to compute head revision: %w", err)
   112  		}
   113  	}
   114  
   115  	toRun, err := collectMigrationsInRange(starting, throughRevision, m.migrations)
   116  	if err != nil {
   117  		return fmt.Errorf("unable to compute migration list: %w", err)
   118  	}
   119  	if len(toRun) == 0 {
   120  		log.Ctx(ctx).Info().Str("targetRevision", requestedRevision).Msg("server already at requested revision")
   121  	}
   122  
   123  	if !dryRun {
   124  		for _, migrationToRun := range toRun {
   125  			// Double check that the current version reported is the one we expect
   126  			currentVersion, err := driver.Version(ctx)
   127  			if err != nil {
   128  				return fmt.Errorf("unable to load version from driver: %w", err)
   129  			}
   130  
   131  			if migrationToRun.replaces != currentVersion {
   132  				return fmt.Errorf("migration attempting to run out of order: %s != %s", currentVersion, migrationToRun.replaces)
   133  			}
   134  
   135  			log.Ctx(ctx).Info().Str("from", migrationToRun.replaces).Str("to", migrationToRun.version).Msg("migrating")
   136  			if migrationToRun.up != nil {
   137  				if err = migrationToRun.up(ctx, driver.Conn()); err != nil {
   138  					return fmt.Errorf("error executing migration function: %w", err)
   139  				}
   140  			}
   141  
   142  			migrationToRun := migrationToRun
   143  			if err := driver.RunTx(ctx, func(ctx context.Context, tx T) error {
   144  				if migrationToRun.upTx != nil {
   145  					if err := migrationToRun.upTx(ctx, tx); err != nil {
   146  						return err
   147  					}
   148  				}
   149  				return driver.WriteVersion(ctx, tx, migrationToRun.version, migrationToRun.replaces)
   150  			}); err != nil {
   151  				return fmt.Errorf("error executing migration `%s`: %w", migrationToRun.version, err)
   152  			}
   153  
   154  			currentVersion, err = driver.Version(ctx)
   155  			if err != nil {
   156  				return fmt.Errorf("unable to load version from driver: %w", err)
   157  			}
   158  			if migrationToRun.version != currentVersion {
   159  				return fmt.Errorf("the migration function succeeded, but the driver did not report the expected version: %s", migrationToRun.version)
   160  			}
   161  		}
   162  	}
   163  
   164  	return nil
   165  }
   166  
   167  func (m *Manager[D, C, T]) HeadRevision() (string, error) {
   168  	candidates := make(map[string]struct{}, len(m.migrations))
   169  	for candidate := range m.migrations {
   170  		candidates[candidate] = struct{}{}
   171  	}
   172  
   173  	for _, eliminateReplaces := range m.migrations {
   174  		delete(candidates, eliminateReplaces.replaces)
   175  	}
   176  
   177  	allHeads := make([]string, 0, len(candidates))
   178  	for headRevision := range candidates {
   179  		allHeads = append(allHeads, headRevision)
   180  	}
   181  
   182  	if len(allHeads) != 1 {
   183  		return "", fmt.Errorf("multiple or zero head revisions found: %v", allHeads)
   184  	}
   185  
   186  	return allHeads[0], nil
   187  }
   188  
   189  func (m *Manager[D, C, T]) IsHeadCompatible(revision string) (bool, error) {
   190  	headRevision, err := m.HeadRevision()
   191  	if err != nil {
   192  		return false, err
   193  	}
   194  	headMigration := m.migrations[headRevision]
   195  	return revision == headMigration.version || revision == headMigration.replaces, nil
   196  }
   197  
   198  func collectMigrationsInRange[C any, T any](starting, through string, all map[string]migration[C, T]) ([]migration[C, T], error) {
   199  	var found []migration[C, T]
   200  
   201  	lookingForRevision := through
   202  	for lookingForRevision != starting {
   203  		foundMigration, ok := all[lookingForRevision]
   204  		if !ok {
   205  			return []migration[C, T]{}, fmt.Errorf("unable to find migration for revision: %s", lookingForRevision)
   206  		}
   207  
   208  		found = append([]migration[C, T]{foundMigration}, found...)
   209  		lookingForRevision = foundMigration.replaces
   210  	}
   211  
   212  	return found, nil
   213  }