github.com/SigNoz/golang-migrate/v4@v4.0.0-20231005133642-7493dbaf5f5b/migrate.go (about)

     1  // Package migrate reads migrations from sources and runs them against databases.
     2  // Sources are defined by the `source.Driver` and databases by the `database.Driver`
     3  // interface. The driver interfaces are kept "dumb", all migration logic is kept
     4  // in this package.
     5  package migrate
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"os"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/hashicorp/go-multierror"
    16  
    17  	"github.com/golang-migrate/migrate/v4/database"
    18  	iurl "github.com/golang-migrate/migrate/v4/internal/url"
    19  	"github.com/golang-migrate/migrate/v4/source"
    20  )
    21  
    22  // DefaultPrefetchMigrations sets the number of migrations to pre-read
    23  // from the source. This is helpful if the source is remote, but has little
    24  // effect for a local source (i.e. file system).
    25  // Please note that this setting has a major impact on the memory usage,
    26  // since each pre-read migration is buffered in memory. See DefaultBufferSize.
    27  var DefaultPrefetchMigrations = uint(10)
    28  
    29  // DefaultLockTimeout sets the max time a database driver has to acquire a lock.
    30  var DefaultLockTimeout = 15 * time.Second
    31  
    32  var (
    33  	ErrNoChange       = errors.New("no change")
    34  	ErrNilVersion     = errors.New("no migration")
    35  	ErrInvalidVersion = errors.New("version must be >= -1")
    36  	ErrLocked         = errors.New("database locked")
    37  	ErrLockTimeout    = errors.New("timeout: can't acquire database lock")
    38  )
    39  
    40  // ErrShortLimit is an error returned when not enough migrations
    41  // can be returned by a source for a given limit.
    42  type ErrShortLimit struct {
    43  	Short uint
    44  }
    45  
    46  // Error implements the error interface.
    47  func (e ErrShortLimit) Error() string {
    48  	return fmt.Sprintf("limit %v short", e.Short)
    49  }
    50  
    51  type ErrDirty struct {
    52  	Version int
    53  }
    54  
    55  func (e ErrDirty) Error() string {
    56  	return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version)
    57  }
    58  
// Migrate reads migrations from a source driver and applies them to a
// database driver. Create instances with New, NewWithDatabaseInstance,
// NewWithSourceInstance or NewWithInstance. Each of these calls newCommon,
// which initialises GracefulStop, the lock mutex and the PrefetchMigrations /
// LockTimeout defaults. The zero value is not usable: isLockedMu would be nil.
type Migrate struct {
	// sourceName and databaseName identify the drivers in logs; New derives
	// them from the URL schemes, the *Instance constructors take them verbatim.
	sourceName   string
	sourceDrv    source.Driver
	databaseName string
	databaseDrv  database.Driver

	// Log accepts a Logger interface. When nil, all logging is disabled.
	Log Logger

	// GracefulStop accepts `true` and will stop executing migrations
	// as soon as possible at a safe break point, so that the database
	// is not corrupted. Any value received (true or false) triggers the
	// stop; newCommon creates the channel with a buffer of 1.
	GracefulStop chan bool
	// isLockedMu guards isLocked; it is held by lock and unlock.
	isLockedMu *sync.Mutex

	// isGracefulStop latches once a value has been received on GracefulStop.
	isGracefulStop bool
	// isLocked reports whether this instance currently holds the database lock.
	isLocked bool

	// PrefetchMigrations defaults to DefaultPrefetchMigrations,
	// but can be set per Migrate instance. It is used as the buffer size of
	// the channel between the reader goroutine and runMigrations.
	PrefetchMigrations uint

	// LockTimeout defaults to DefaultLockTimeout,
	// but can be set per Migrate instance. It bounds how long lock waits for
	// the database driver to acquire its lock.
	LockTimeout time.Duration

	// EnableTemplating treats migration files as go text templates; making environment variables accessible in your
	// migration files.
	// i.e. If you set the LOCAL_WAREHOUSE environment variable to MY_DB and have a migration file with the
	// following contents:
	//   INSERT INTO {{.LOCAL_WAREHOUSE}}.INVENTORY.RECORDS ('foo') VALUES ('bar');
	// it will be transformed into the following before being executed:
	//   INSERT INTO MY_DB.INVENTORY.RECORDS ('foo') VALUES ('bar');
	//
	// See https://pkg.go.dev/text/template for more information on template formats
	EnableTemplating bool
}
    96  
    97  // New returns a new Migrate instance from a source URL and a database URL.
    98  // The URL scheme is defined by each driver.
    99  func New(sourceURL, databaseURL string) (*Migrate, error) {
   100  	m := newCommon()
   101  
   102  	sourceName, err := iurl.SchemeFromURL(sourceURL)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  	m.sourceName = sourceName
   107  
   108  	databaseName, err := iurl.SchemeFromURL(databaseURL)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	m.databaseName = databaseName
   113  
   114  	sourceDrv, err := source.Open(sourceURL)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	m.sourceDrv = sourceDrv
   119  
   120  	databaseDrv, err := database.Open(databaseURL)
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  	m.databaseDrv = databaseDrv
   125  
   126  	return m, nil
   127  }
   128  
   129  // NewWithDatabaseInstance returns a new Migrate instance from a source URL
   130  // and an existing database instance. The source URL scheme is defined by each driver.
   131  // Use any string that can serve as an identifier during logging as databaseName.
   132  // You are responsible for closing the underlying database client if necessary.
   133  func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
   134  	m := newCommon()
   135  
   136  	sourceName, err := iurl.SchemeFromURL(sourceURL)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	m.sourceName = sourceName
   141  
   142  	m.databaseName = databaseName
   143  
   144  	sourceDrv, err := source.Open(sourceURL)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	m.sourceDrv = sourceDrv
   149  
   150  	m.databaseDrv = databaseInstance
   151  
   152  	return m, nil
   153  }
   154  
   155  // NewWithSourceInstance returns a new Migrate instance from an existing source instance
   156  // and a database URL. The database URL scheme is defined by each driver.
   157  // Use any string that can serve as an identifier during logging as sourceName.
   158  // You are responsible for closing the underlying source client if necessary.
   159  func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) {
   160  	m := newCommon()
   161  
   162  	databaseName, err := iurl.SchemeFromURL(databaseURL)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  	m.databaseName = databaseName
   167  
   168  	m.sourceName = sourceName
   169  
   170  	databaseDrv, err := database.Open(databaseURL)
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  	m.databaseDrv = databaseDrv
   175  
   176  	m.sourceDrv = sourceInstance
   177  
   178  	return m, nil
   179  }
   180  
   181  // NewWithInstance returns a new Migrate instance from an existing source and
   182  // database instance. Use any string that can serve as an identifier during logging
   183  // as sourceName and databaseName. You are responsible for closing down
   184  // the underlying source and database client if necessary.
   185  func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
   186  	m := newCommon()
   187  
   188  	m.sourceName = sourceName
   189  	m.databaseName = databaseName
   190  
   191  	m.sourceDrv = sourceInstance
   192  	m.databaseDrv = databaseInstance
   193  
   194  	return m, nil
   195  }
   196  
   197  func newCommon() *Migrate {
   198  	return &Migrate{
   199  		GracefulStop:       make(chan bool, 1),
   200  		PrefetchMigrations: DefaultPrefetchMigrations,
   201  		LockTimeout:        DefaultLockTimeout,
   202  		isLockedMu:         &sync.Mutex{},
   203  	}
   204  }
   205  
   206  // Close closes the source and the database.
   207  func (m *Migrate) Close() (source error, database error) {
   208  	databaseSrvClose := make(chan error)
   209  	sourceSrvClose := make(chan error)
   210  
   211  	m.logVerbosePrintf("Closing source and database\n")
   212  
   213  	go func() {
   214  		databaseSrvClose <- m.databaseDrv.Close()
   215  	}()
   216  
   217  	go func() {
   218  		sourceSrvClose <- m.sourceDrv.Close()
   219  	}()
   220  
   221  	return <-sourceSrvClose, <-databaseSrvClose
   222  }
   223  
   224  // Migrate looks at the currently active migration version,
   225  // then migrates either up or down to the specified version.
   226  func (m *Migrate) Migrate(version uint) error {
   227  	if err := m.lock(); err != nil {
   228  		return err
   229  	}
   230  
   231  	curVersion, dirty, err := m.databaseDrv.Version()
   232  	if err != nil {
   233  		return m.unlockErr(err)
   234  	}
   235  
   236  	if dirty {
   237  		return m.unlockErr(ErrDirty{curVersion})
   238  	}
   239  
   240  	ret := make(chan interface{}, m.PrefetchMigrations)
   241  	go m.read(curVersion, int(version), ret)
   242  
   243  	return m.unlockErr(m.runMigrations(ret))
   244  }
   245  
   246  // Steps looks at the currently active migration version.
   247  // It will migrate up if n > 0, and down if n < 0.
   248  func (m *Migrate) Steps(n int) error {
   249  	if n == 0 {
   250  		return ErrNoChange
   251  	}
   252  
   253  	if err := m.lock(); err != nil {
   254  		return err
   255  	}
   256  
   257  	curVersion, dirty, err := m.databaseDrv.Version()
   258  	if err != nil {
   259  		return m.unlockErr(err)
   260  	}
   261  
   262  	if dirty {
   263  		return m.unlockErr(ErrDirty{curVersion})
   264  	}
   265  
   266  	ret := make(chan interface{}, m.PrefetchMigrations)
   267  
   268  	if n > 0 {
   269  		go m.readUp(curVersion, n, ret)
   270  	} else {
   271  		go m.readDown(curVersion, -n, ret)
   272  	}
   273  
   274  	return m.unlockErr(m.runMigrations(ret))
   275  }
   276  
   277  // Up looks at the currently active migration version
   278  // and will migrate all the way up (applying all up migrations).
   279  func (m *Migrate) Up() error {
   280  	if err := m.lock(); err != nil {
   281  		return err
   282  	}
   283  
   284  	curVersion, dirty, err := m.databaseDrv.Version()
   285  	if err != nil {
   286  		return m.unlockErr(err)
   287  	}
   288  
   289  	if dirty {
   290  		return m.unlockErr(ErrDirty{curVersion})
   291  	}
   292  
   293  	ret := make(chan interface{}, m.PrefetchMigrations)
   294  
   295  	go m.readUp(curVersion, -1, ret)
   296  	return m.unlockErr(m.runMigrations(ret))
   297  }
   298  
   299  // Down looks at the currently active migration version
   300  // and will migrate all the way down (applying all down migrations).
   301  func (m *Migrate) Down() error {
   302  	if err := m.lock(); err != nil {
   303  		return err
   304  	}
   305  
   306  	curVersion, dirty, err := m.databaseDrv.Version()
   307  	if err != nil {
   308  		return m.unlockErr(err)
   309  	}
   310  
   311  	if dirty {
   312  		return m.unlockErr(ErrDirty{curVersion})
   313  	}
   314  
   315  	ret := make(chan interface{}, m.PrefetchMigrations)
   316  	go m.readDown(curVersion, -1, ret)
   317  	return m.unlockErr(m.runMigrations(ret))
   318  }
   319  
   320  // Drop deletes everything in the database.
   321  func (m *Migrate) Drop() error {
   322  	if err := m.lock(); err != nil {
   323  		return err
   324  	}
   325  	if err := m.databaseDrv.Drop(); err != nil {
   326  		return m.unlockErr(err)
   327  	}
   328  	return m.unlock()
   329  }
   330  
   331  // Run runs any migration provided by you against the database.
   332  // It does not check any currently active version in database.
   333  // Usually you don't need this function at all. Use Migrate,
   334  // Steps, Up or Down instead.
   335  func (m *Migrate) Run(migration ...*Migration) error {
   336  	if len(migration) == 0 {
   337  		return ErrNoChange
   338  	}
   339  
   340  	if err := m.lock(); err != nil {
   341  		return err
   342  	}
   343  
   344  	curVersion, dirty, err := m.databaseDrv.Version()
   345  	if err != nil {
   346  		return m.unlockErr(err)
   347  	}
   348  
   349  	if dirty {
   350  		return m.unlockErr(ErrDirty{curVersion})
   351  	}
   352  
   353  	ret := make(chan interface{}, m.PrefetchMigrations)
   354  
   355  	go func() {
   356  		defer close(ret)
   357  		for _, migr := range migration {
   358  			if m.PrefetchMigrations > 0 && migr.Body != nil {
   359  				m.logVerbosePrintf("Start buffering %v\n", migr.LogString())
   360  			} else {
   361  				m.logVerbosePrintf("Scheduled %v\n", migr.LogString())
   362  			}
   363  
   364  			ret <- migr
   365  			go func(migr *Migration) {
   366  				if err := migr.Buffer(); err != nil {
   367  					m.logErr(err)
   368  				}
   369  			}(migr)
   370  		}
   371  	}()
   372  
   373  	return m.unlockErr(m.runMigrations(ret))
   374  }
   375  
   376  // Force sets a migration version.
   377  // It does not check any currently active version in database.
   378  // It resets the dirty state to false.
   379  func (m *Migrate) Force(version int) error {
   380  	if version < -1 {
   381  		return ErrInvalidVersion
   382  	}
   383  
   384  	if err := m.lock(); err != nil {
   385  		return err
   386  	}
   387  
   388  	if err := m.databaseDrv.SetVersion(version, false); err != nil {
   389  		return m.unlockErr(err)
   390  	}
   391  
   392  	return m.unlock()
   393  }
   394  
   395  // Version returns the currently active migration version.
   396  // If no migration has been applied, yet, it will return ErrNilVersion.
   397  func (m *Migrate) Version() (version uint, dirty bool, err error) {
   398  	v, d, err := m.databaseDrv.Version()
   399  	if err != nil {
   400  		return 0, false, err
   401  	}
   402  
   403  	if v == database.NilVersion {
   404  		return 0, false, ErrNilVersion
   405  	}
   406  
   407  	return suint(v), d, nil
   408  }
   409  
// read reads either up or down migrations from source `from` to `to`.
// Each migration is then written to the ret channel.
// If an error occurs during reading, that error is written to the ret channel, too.
// Once read is done reading it will close the ret channel.
//
// A value of -1 for from or to denotes the nil version (no migration applied).
// Only versions >= 0 are checked against the source. When from == to,
// ErrNoChange is sent. If a graceful stop is requested (see stop), read
// returns early without sending anything further.
//
// NOTE(review): ret is drained by runMigrations, which can return before the
// channel is empty (on an error or a graceful stop). Once ret's buffer
// (PrefetchMigrations) is full, a pending send here would block indefinitely
// and this goroutine would never exit. Confirm whether that is acceptable.
func (m *Migrate) read(from int, to int, ret chan<- interface{}) {
	defer close(ret)

	// check if from version exists
	if from >= 0 {
		if err := m.versionExists(suint(from)); err != nil {
			ret <- err
			return
		}
	}

	// check if to version exists
	if to >= 0 {
		if err := m.versionExists(suint(to)); err != nil {
			ret <- err
			return
		}
	}

	// no change?
	if from == to {
		ret <- ErrNoChange
		return
	}

	if from < to {
		// it's going up
		// apply first migration if from is nil version
		if from == -1 {
			firstVersion, err := m.sourceDrv.First()
			if err != nil {
				ret <- err
				return
			}

			migr, err := m.newMigration(firstVersion, int(firstVersion))
			if err != nil {
				ret <- err
				return
			}

			ret <- migr
			// Buffer the body in the background; failures are only logged.
			go func() {
				if err := migr.Buffer(); err != nil {
					m.logErr(err)
				}
			}()

			from = int(firstVersion)
		}

		// run until we reach target ...
		// Each step applies the up migration of the next version after `from`.
		for from < to {
			if m.stop() {
				return
			}

			next, err := m.sourceDrv.Next(suint(from))
			if err != nil {
				ret <- err
				return
			}

			migr, err := m.newMigration(next, int(next))
			if err != nil {
				ret <- err
				return
			}

			ret <- migr
			go func() {
				if err := migr.Buffer(); err != nil {
					m.logErr(err)
				}
			}()

			from = int(next)
		}

	} else {
		// it's going down
		// run until we reach target ...
		// Each step applies the down migration of `from`, targeting its predecessor.
		for from > to && from >= 0 {
			if m.stop() {
				return
			}

			prev, err := m.sourceDrv.Prev(suint(from))
			// os.ErrNotExist from Prev is treated as "no earlier version". That is
			// only acceptable when the target is the nil version.
			if errors.Is(err, os.ErrNotExist) && to == -1 {
				// apply nil migration
				migr, err := m.newMigration(suint(from), -1)
				if err != nil {
					ret <- err
					return
				}
				ret <- migr
				go func() {
					if err := migr.Buffer(); err != nil {
						m.logErr(err)
					}
				}()

				return

			} else if err != nil {
				ret <- err
				return
			}

			migr, err := m.newMigration(suint(from), int(prev))
			if err != nil {
				ret <- err
				return
			}

			ret <- migr
			go func() {
				if err := migr.Buffer(); err != nil {
					m.logErr(err)
				}
			}()

			from = int(prev)
		}
	}
}
   540  
// readUp reads up migrations from `from` limited by `limit`.
// limit can be -1, implying no limit and reading until there are no more migrations.
// Each migration is then written to the ret channel.
// If an error occurs during reading, that error is written to the ret channel, too.
// Once readUp is done reading it will close the ret channel.
//
// Terminal values sent when the source runs out of migrations:
//   - ErrNoChange when limit is 0, or limit is -1 and nothing was applied;
//   - os.ErrNotExist when limit > 0 and nothing could be applied;
//   - ErrShortLimit when fewer than limit migrations were available.
//
// from == -1 denotes the nil version, in which case the source's first
// migration is applied first.
//
// NOTE(review): like read, a send on ret can block forever if runMigrations
// stops draining the channel early. Confirm whether that is acceptable.
func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) {
	defer close(ret)

	// check if from version exists
	if from >= 0 {
		if err := m.versionExists(suint(from)); err != nil {
			ret <- err
			return
		}
	}

	if limit == 0 {
		ret <- ErrNoChange
		return
	}

	count := 0
	for count < limit || limit == -1 {
		if m.stop() {
			return
		}

		// apply first migration if from is nil version
		if from == -1 {
			firstVersion, err := m.sourceDrv.First()
			if err != nil {
				ret <- err
				return
			}

			migr, err := m.newMigration(firstVersion, int(firstVersion))
			if err != nil {
				ret <- err
				return
			}

			ret <- migr
			// Buffer the body in the background; failures are only logged.
			go func() {
				if err := migr.Buffer(); err != nil {
					m.logErr(err)
				}
			}()
			from = int(firstVersion)
			count++
			continue
		}

		// apply next migration
		next, err := m.sourceDrv.Next(suint(from))
		if errors.Is(err, os.ErrNotExist) {
			// no limit, but no migrations applied?
			if limit == -1 && count == 0 {
				ret <- ErrNoChange
				return
			}

			// no limit, reached end
			if limit == -1 {
				return
			}

			// reached end, and didn't apply any migrations
			if limit > 0 && count == 0 {
				ret <- os.ErrNotExist
				return
			}

			// applied less migrations than limit?
			// (Always true here for limit > 0, given the loop condition.)
			if count < limit {
				ret <- ErrShortLimit{suint(limit - count)}
				return
			}
		}
		if err != nil {
			ret <- err
			return
		}

		migr, err := m.newMigration(next, int(next))
		if err != nil {
			ret <- err
			return
		}

		ret <- migr
		go func() {
			if err := migr.Buffer(); err != nil {
				m.logErr(err)
			}
		}()
		from = int(next)
		count++
	}
}
   640  
// readDown reads down migrations from `from` limited by `limit`.
// limit can be -1, implying no limit and reading until there are no more migrations.
// Each migration is then written to the ret channel.
// If an error occurs during reading, that error is written to the ret channel, too.
// Once readDown is done reading it will close the ret channel.
//
// Terminal values sent:
//   - ErrNoChange when limit is 0, or when already at the nil version (-1)
//     with no limit;
//   - os.ErrNotExist when already at the nil version with limit > 0;
//   - ErrShortLimit when the nil version was reached before limit steps.
//
// NOTE(review): like read, a send on ret can block forever if runMigrations
// stops draining the channel early. Confirm whether that is acceptable.
func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) {
	defer close(ret)

	// check if from version exists
	if from >= 0 {
		if err := m.versionExists(suint(from)); err != nil {
			ret <- err
			return
		}
	}

	if limit == 0 {
		ret <- ErrNoChange
		return
	}

	// no change if already at nil version
	if from == -1 && limit == -1 {
		ret <- ErrNoChange
		return
	}

	// can't go over limit if already at nil version
	if from == -1 && limit > 0 {
		ret <- os.ErrNotExist
		return
	}

	count := 0
	for count < limit || limit == -1 {
		if m.stop() {
			return
		}

		prev, err := m.sourceDrv.Prev(suint(from))
		// os.ErrNotExist from Prev is treated as "no earlier version". The last
		// step then runs the down migration of the source's first version with
		// target -1. That version presumably equals `from` here. TODO confirm.
		if errors.Is(err, os.ErrNotExist) {
			// no limit or haven't reached limit, apply "first" migration
			if limit == -1 || limit-count > 0 {
				firstVersion, err := m.sourceDrv.First()
				if err != nil {
					ret <- err
					return
				}

				migr, err := m.newMigration(firstVersion, -1)
				if err != nil {
					ret <- err
					return
				}
				ret <- migr
				// Buffer the body in the background; failures are only logged.
				go func() {
					if err := migr.Buffer(); err != nil {
						m.logErr(err)
					}
				}()
				count++
			}

			// With limit == -1 this is never true, so an unlimited run just ends.
			if count < limit {
				ret <- ErrShortLimit{suint(limit - count)}
			}
			return
		}
		if err != nil {
			ret <- err
			return
		}

		migr, err := m.newMigration(suint(from), int(prev))
		if err != nil {
			ret <- err
			return
		}

		ret <- migr
		go func() {
			if err := migr.Buffer(); err != nil {
				m.logErr(err)
			}
		}()
		from = int(prev)
		count++
	}
}
   730  
   731  // runMigrations reads *Migration and error from a channel. Any other type
   732  // sent on this channel will result in a panic. Each migration is then
   733  // proxied to the database driver and run against the database.
   734  // Before running a newly received migration it will check if it's supposed
   735  // to stop execution because it might have received a stop signal on the
   736  // GracefulStop channel.
   737  func (m *Migrate) runMigrations(ret <-chan interface{}) error {
   738  	for r := range ret {
   739  
   740  		if m.stop() {
   741  			return nil
   742  		}
   743  
   744  		switch r := r.(type) {
   745  		case error:
   746  			return r
   747  
   748  		case *Migration:
   749  			migr := r
   750  
   751  			// set version with dirty state
   752  			if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
   753  				return err
   754  			}
   755  
   756  			if migr.Body != nil {
   757  				m.logVerbosePrintf("Read and execute %v\n", migr.LogString())
   758  				if err := m.databaseDrv.Run(migr.BufferedBody); err != nil {
   759  					return err
   760  				}
   761  			}
   762  
   763  			// set clean state
   764  			if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil {
   765  				return err
   766  			}
   767  
   768  			endTime := time.Now()
   769  			readTime := migr.FinishedReading.Sub(migr.StartedBuffering)
   770  			runTime := endTime.Sub(migr.FinishedReading)
   771  
   772  			// log either verbose or normal
   773  			if m.Log != nil {
   774  				if m.Log.Verbose() {
   775  					m.logPrintf("Finished %v (read %v, ran %v)\n", migr.LogString(), readTime, runTime)
   776  				} else {
   777  					m.logPrintf("%v (%v)\n", migr.LogString(), readTime+runTime)
   778  				}
   779  			}
   780  
   781  		default:
   782  			return fmt.Errorf("unknown type: %T with value: %+v", r, r)
   783  		}
   784  	}
   785  	return nil
   786  }
   787  
   788  // versionExists checks the source if either the up or down migration for
   789  // the specified migration version exists.
   790  func (m *Migrate) versionExists(version uint) (result error) {
   791  	// try up migration first
   792  	up, _, err := m.sourceDrv.ReadUp(version)
   793  	if err == nil {
   794  		defer func() {
   795  			if errClose := up.Close(); errClose != nil {
   796  				result = multierror.Append(result, errClose)
   797  			}
   798  		}()
   799  	}
   800  	if errors.Is(err, os.ErrExist) {
   801  		return nil
   802  	} else if !errors.Is(err, os.ErrNotExist) {
   803  		return err
   804  	}
   805  
   806  	// then try down migration
   807  	down, _, err := m.sourceDrv.ReadDown(version)
   808  	if err == nil {
   809  		defer func() {
   810  			if errClose := down.Close(); errClose != nil {
   811  				result = multierror.Append(result, errClose)
   812  			}
   813  		}()
   814  	}
   815  	if errors.Is(err, os.ErrExist) {
   816  		return nil
   817  	} else if !errors.Is(err, os.ErrNotExist) {
   818  		return err
   819  	}
   820  
   821  	err = fmt.Errorf("no migration found for version %d: %w", version, err)
   822  	m.logErr(err)
   823  	return err
   824  }
   825  
// stop returns true if no more migrations should be run against the database
// because a stop signal was received on the GracefulStop channel.
// Calls are cheap and this function is not blocking.
//
// Any value received on GracefulStop (true or false) latches isGracefulStop,
// so every later call returns true without touching the channel again.
//
// NOTE(review): stop is called from the reader goroutine (read, readUp,
// readDown) and from runMigrations in the caller's goroutine. isGracefulStop
// is read and written without synchronization, which looks like a data race.
// Confirm under `go test -race`.
func (m *Migrate) stop() bool {
	if m.isGracefulStop {
		return true
	}

	select {
	case <-m.GracefulStop:
		m.isGracefulStop = true
		return true

	default:
		return false
	}
}
   843  
   844  // newMigration is a helper func that returns a *Migration for the
   845  // specified version and targetVersion.
   846  func (m *Migrate) newMigration(version uint, targetVersion int) (*Migration, error) {
   847  	var migr *Migration
   848  	var r io.ReadCloser
   849  	var identifier string
   850  	var err error
   851  
   852  	if targetVersion >= int(version) {
   853  		r, identifier, err = m.sourceDrv.ReadUp(version)
   854  	} else {
   855  		r, identifier, err = m.sourceDrv.ReadDown(version)
   856  	}
   857  
   858  	if errors.Is(err, os.ErrNotExist) {
   859  		// create "empty" migration
   860  		migr, err = NewMigration(nil, "", version, targetVersion)
   861  		if err != nil {
   862  			return nil, err
   863  		}
   864  
   865  	} else if err != nil {
   866  		return nil, err
   867  
   868  	} else {
   869  		if m.EnableTemplating {
   870  			if r, err = applyEnvironmentTemplate(r); err != nil {
   871  				return nil, err
   872  			}
   873  		}
   874  		// create migration from source
   875  		migr, err = NewMigration(r, identifier, version, targetVersion)
   876  		if err != nil {
   877  			return nil, err
   878  		}
   879  	}
   880  
   881  	if m.PrefetchMigrations > 0 && migr.Body != nil {
   882  		m.logVerbosePrintf("Start buffering %v\n", migr.LogString())
   883  	} else {
   884  		m.logVerbosePrintf("Scheduled %v\n", migr.LogString())
   885  	}
   886  	return migr, nil
   887  }
   888  
   889  // lock is a thread safe helper function to lock the database.
   890  // It should be called as late as possible when running migrations.
   891  func (m *Migrate) lock() error {
   892  	m.isLockedMu.Lock()
   893  	defer m.isLockedMu.Unlock()
   894  
   895  	if m.isLocked {
   896  		return ErrLocked
   897  	}
   898  
   899  	// create done channel, used in the timeout goroutine
   900  	done := make(chan bool, 1)
   901  	defer func() {
   902  		done <- true
   903  	}()
   904  
   905  	// use errchan to signal error back to this context
   906  	errchan := make(chan error, 2)
   907  
   908  	// start timeout goroutine
   909  	timeout := time.After(m.LockTimeout)
   910  	go func() {
   911  		for {
   912  			select {
   913  			case <-done:
   914  				return
   915  			case <-timeout:
   916  				errchan <- ErrLockTimeout
   917  				return
   918  			}
   919  		}
   920  	}()
   921  
   922  	// now try to acquire the lock
   923  	go func() {
   924  		if err := m.databaseDrv.Lock(); err != nil {
   925  			errchan <- err
   926  		} else {
   927  			errchan <- nil
   928  		}
   929  	}()
   930  
   931  	// wait until we either receive ErrLockTimeout or error from Lock operation
   932  	err := <-errchan
   933  	if err == nil {
   934  		m.isLocked = true
   935  	}
   936  	return err
   937  }
   938  
   939  // unlock is a thread safe helper function to unlock the database.
   940  // It should be called as early as possible when no more migrations are
   941  // expected to be executed.
   942  func (m *Migrate) unlock() error {
   943  	m.isLockedMu.Lock()
   944  	defer m.isLockedMu.Unlock()
   945  
   946  	if err := m.databaseDrv.Unlock(); err != nil {
   947  		// BUG: Can potentially create a deadlock. Add a timeout.
   948  		return err
   949  	}
   950  
   951  	m.isLocked = false
   952  	return nil
   953  }
   954  
   955  // unlockErr calls unlock and returns a combined error
   956  // if a prevErr is not nil.
   957  func (m *Migrate) unlockErr(prevErr error) error {
   958  	if err := m.unlock(); err != nil {
   959  		return multierror.Append(prevErr, err)
   960  	}
   961  	return prevErr
   962  }
   963  
   964  // logPrintf writes to m.Log if not nil
   965  func (m *Migrate) logPrintf(format string, v ...interface{}) {
   966  	if m.Log != nil {
   967  		m.Log.Printf(format, v...)
   968  	}
   969  }
   970  
   971  // logVerbosePrintf writes to m.Log if not nil. Use for verbose logging output.
   972  func (m *Migrate) logVerbosePrintf(format string, v ...interface{}) {
   973  	if m.Log != nil && m.Log.Verbose() {
   974  		m.Log.Printf(format, v...)
   975  	}
   976  }
   977  
   978  // logErr writes error to m.Log if not nil
   979  func (m *Migrate) logErr(err error) {
   980  	if m.Log != nil {
   981  		m.Log.Printf("error: %v", err)
   982  	}
   983  }