github.com/bdollma-te/migrate/v4@v4.17.0-clickv2/migrate.go (about)

     1  // Package migrate reads migrations from sources and runs them against databases.
     2  // Sources are defined by the `source.Driver` and databases by the `database.Driver`
     3  // interface. The driver interfaces are kept "dumb", all migration logic is kept
     4  // in this package.
     5  package migrate
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"os"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/hashicorp/go-multierror"
    15  
    16  	"github.com/bdollma-te/migrate/v4/database"
    17  	iurl "github.com/bdollma-te/migrate/v4/internal/url"
    18  	"github.com/bdollma-te/migrate/v4/source"
    19  )
    20  
    21  // DefaultPrefetchMigrations sets the number of migrations to pre-read
    22  // from the source. This is helpful if the source is remote, but has little
    23  // effect for a local source (i.e. file system).
    24  // Please note that this setting has a major impact on the memory usage,
    25  // since each pre-read migration is buffered in memory. See DefaultBufferSize.
    26  var DefaultPrefetchMigrations = uint(10)
    27  
    28  // DefaultLockTimeout sets the max time a database driver has to acquire a lock.
    29  var DefaultLockTimeout = 15 * time.Second
    30  
    31  var (
    32  	ErrNoChange       = errors.New("no change")
    33  	ErrNilVersion     = errors.New("no migration")
    34  	ErrInvalidVersion = errors.New("version must be >= -1")
    35  	ErrLocked         = errors.New("database locked")
    36  	ErrLockTimeout    = errors.New("timeout: can't acquire database lock")
    37  )
    38  
    39  // ErrShortLimit is an error returned when not enough migrations
    40  // can be returned by a source for a given limit.
    41  type ErrShortLimit struct {
    42  	Short uint
    43  }
    44  
    45  // Error implements the error interface.
    46  func (e ErrShortLimit) Error() string {
    47  	return fmt.Sprintf("limit %v short", e.Short)
    48  }
    49  
    50  type ErrDirty struct {
    51  	Version int
    52  }
    53  
    54  func (e ErrDirty) Error() string {
    55  	return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version)
    56  }
    57  
// Migrate ties a source driver to a database driver and applies migrations
// read from the source to the database. Create instances with New,
// NewWithDatabaseInstance, NewWithSourceInstance or NewWithInstance.
type Migrate struct {
	// sourceName and databaseName identify the drivers; New derives them
	// from the URL schemes, the NewWith*Instance constructors take them
	// from the caller.
	sourceName   string
	sourceDrv    source.Driver
	databaseName string
	databaseDrv  database.Driver

	// Log accepts a Logger interface
	// When nil, nothing is logged (see logPrintf and logVerbosePrintf).
	Log Logger

	// GracefulStop accepts `true` and will stop executing migrations
	// as soon as possible at a safe break point, so that the database
	// is not corrupted.
	// Note that stop only checks for a receive, so any value sent
	// (including false) triggers the stop. newCommon gives the channel a
	// buffer of one.
	GracefulStop chan bool
	// isLockedMu guards isLocked (see lock and unlock).
	isLockedMu   *sync.Mutex

	// isGracefulStop latches a signal received on GracefulStop (see stop).
	// isLocked reports whether this instance currently holds the database lock.
	isGracefulStop bool
	isLocked       bool

	// PrefetchMigrations defaults to DefaultPrefetchMigrations,
	// but can be set per Migrate instance.
	// It is also the buffer size of the channel between the reader
	// goroutine and runMigrations.
	PrefetchMigrations uint

	// LockTimeout defaults to DefaultLockTimeout,
	// but can be set per Migrate instance.
	LockTimeout time.Duration
}
    84  
    85  // New returns a new Migrate instance from a source URL and a database URL.
    86  // The URL scheme is defined by each driver.
    87  func New(sourceURL, databaseURL string) (*Migrate, error) {
    88  	m := newCommon()
    89  
    90  	sourceName, err := iurl.SchemeFromURL(sourceURL)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  	m.sourceName = sourceName
    95  
    96  	databaseName, err := iurl.SchemeFromURL(databaseURL)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	m.databaseName = databaseName
   101  
   102  	sourceDrv, err := source.Open(sourceURL)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  	m.sourceDrv = sourceDrv
   107  
   108  	databaseDrv, err := database.Open(databaseURL)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	m.databaseDrv = databaseDrv
   113  
   114  	return m, nil
   115  }
   116  
   117  // NewWithDatabaseInstance returns a new Migrate instance from a source URL
   118  // and an existing database instance. The source URL scheme is defined by each driver.
   119  // Use any string that can serve as an identifier during logging as databaseName.
   120  // You are responsible for closing the underlying database client if necessary.
   121  func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
   122  	m := newCommon()
   123  
   124  	sourceName, err := iurl.SchemeFromURL(sourceURL)
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  	m.sourceName = sourceName
   129  
   130  	m.databaseName = databaseName
   131  
   132  	sourceDrv, err := source.Open(sourceURL)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  	m.sourceDrv = sourceDrv
   137  
   138  	m.databaseDrv = databaseInstance
   139  
   140  	return m, nil
   141  }
   142  
   143  // NewWithSourceInstance returns a new Migrate instance from an existing source instance
   144  // and a database URL. The database URL scheme is defined by each driver.
   145  // Use any string that can serve as an identifier during logging as sourceName.
   146  // You are responsible for closing the underlying source client if necessary.
   147  func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) {
   148  	m := newCommon()
   149  
   150  	databaseName, err := iurl.SchemeFromURL(databaseURL)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  	m.databaseName = databaseName
   155  
   156  	m.sourceName = sourceName
   157  
   158  	databaseDrv, err := database.Open(databaseURL)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	m.databaseDrv = databaseDrv
   163  
   164  	m.sourceDrv = sourceInstance
   165  
   166  	return m, nil
   167  }
   168  
   169  // NewWithInstance returns a new Migrate instance from an existing source and
   170  // database instance. Use any string that can serve as an identifier during logging
   171  // as sourceName and databaseName. You are responsible for closing down
   172  // the underlying source and database client if necessary.
   173  func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
   174  	m := newCommon()
   175  
   176  	m.sourceName = sourceName
   177  	m.databaseName = databaseName
   178  
   179  	m.sourceDrv = sourceInstance
   180  	m.databaseDrv = databaseInstance
   181  
   182  	return m, nil
   183  }
   184  
   185  func newCommon() *Migrate {
   186  	return &Migrate{
   187  		GracefulStop:       make(chan bool, 1),
   188  		PrefetchMigrations: DefaultPrefetchMigrations,
   189  		LockTimeout:        DefaultLockTimeout,
   190  		isLockedMu:         &sync.Mutex{},
   191  	}
   192  }
   193  
   194  // Close closes the source and the database.
   195  func (m *Migrate) Close() (source error, database error) {
   196  	databaseSrvClose := make(chan error)
   197  	sourceSrvClose := make(chan error)
   198  
   199  	m.logVerbosePrintf("Closing source and database\n")
   200  
   201  	go func() {
   202  		databaseSrvClose <- m.databaseDrv.Close()
   203  	}()
   204  
   205  	go func() {
   206  		sourceSrvClose <- m.sourceDrv.Close()
   207  	}()
   208  
   209  	return <-sourceSrvClose, <-databaseSrvClose
   210  }
   211  
   212  // Migrate looks at the currently active migration version,
   213  // then migrates either up or down to the specified version.
   214  func (m *Migrate) Migrate(version uint) error {
   215  	if err := m.lock(); err != nil {
   216  		return err
   217  	}
   218  
   219  	curVersion, dirty, err := m.databaseDrv.Version()
   220  	if err != nil {
   221  		return m.unlockErr(err)
   222  	}
   223  
   224  	if dirty {
   225  		return m.unlockErr(ErrDirty{curVersion})
   226  	}
   227  
   228  	ret := make(chan interface{}, m.PrefetchMigrations)
   229  	go m.read(curVersion, int(version), ret)
   230  
   231  	return m.unlockErr(m.runMigrations(ret))
   232  }
   233  
   234  // Steps looks at the currently active migration version.
   235  // It will migrate up if n > 0, and down if n < 0.
   236  func (m *Migrate) Steps(n int) error {
   237  	if n == 0 {
   238  		return ErrNoChange
   239  	}
   240  
   241  	if err := m.lock(); err != nil {
   242  		return err
   243  	}
   244  
   245  	curVersion, dirty, err := m.databaseDrv.Version()
   246  	if err != nil {
   247  		return m.unlockErr(err)
   248  	}
   249  
   250  	if dirty {
   251  		return m.unlockErr(ErrDirty{curVersion})
   252  	}
   253  
   254  	ret := make(chan interface{}, m.PrefetchMigrations)
   255  
   256  	if n > 0 {
   257  		go m.readUp(curVersion, n, ret)
   258  	} else {
   259  		go m.readDown(curVersion, -n, ret)
   260  	}
   261  
   262  	return m.unlockErr(m.runMigrations(ret))
   263  }
   264  
   265  // Up looks at the currently active migration version
   266  // and will migrate all the way up (applying all up migrations).
   267  func (m *Migrate) Up() error {
   268  	if err := m.lock(); err != nil {
   269  		return err
   270  	}
   271  
   272  	curVersion, dirty, err := m.databaseDrv.Version()
   273  	if err != nil {
   274  		return m.unlockErr(err)
   275  	}
   276  
   277  	if dirty {
   278  		return m.unlockErr(ErrDirty{curVersion})
   279  	}
   280  
   281  	ret := make(chan interface{}, m.PrefetchMigrations)
   282  
   283  	go m.readUp(curVersion, -1, ret)
   284  	return m.unlockErr(m.runMigrations(ret))
   285  }
   286  
   287  // Down looks at the currently active migration version
   288  // and will migrate all the way down (applying all down migrations).
   289  func (m *Migrate) Down() error {
   290  	if err := m.lock(); err != nil {
   291  		return err
   292  	}
   293  
   294  	curVersion, dirty, err := m.databaseDrv.Version()
   295  	if err != nil {
   296  		return m.unlockErr(err)
   297  	}
   298  
   299  	if dirty {
   300  		return m.unlockErr(ErrDirty{curVersion})
   301  	}
   302  
   303  	ret := make(chan interface{}, m.PrefetchMigrations)
   304  	go m.readDown(curVersion, -1, ret)
   305  	return m.unlockErr(m.runMigrations(ret))
   306  }
   307  
   308  // Drop deletes everything in the database.
   309  func (m *Migrate) Drop() error {
   310  	if err := m.lock(); err != nil {
   311  		return err
   312  	}
   313  	if err := m.databaseDrv.Drop(); err != nil {
   314  		return m.unlockErr(err)
   315  	}
   316  	return m.unlock()
   317  }
   318  
   319  // Run runs any migration provided by you against the database.
   320  // It does not check any currently active version in database.
   321  // Usually you don't need this function at all. Use Migrate,
   322  // Steps, Up or Down instead.
   323  func (m *Migrate) Run(migration ...*Migration) error {
   324  	if len(migration) == 0 {
   325  		return ErrNoChange
   326  	}
   327  
   328  	if err := m.lock(); err != nil {
   329  		return err
   330  	}
   331  
   332  	curVersion, dirty, err := m.databaseDrv.Version()
   333  	if err != nil {
   334  		return m.unlockErr(err)
   335  	}
   336  
   337  	if dirty {
   338  		return m.unlockErr(ErrDirty{curVersion})
   339  	}
   340  
   341  	ret := make(chan interface{}, m.PrefetchMigrations)
   342  
   343  	go func() {
   344  		defer close(ret)
   345  		for _, migr := range migration {
   346  			if m.PrefetchMigrations > 0 && migr.Body != nil {
   347  				m.logVerbosePrintf("Start buffering %v\n", migr.LogString())
   348  			} else {
   349  				m.logVerbosePrintf("Scheduled %v\n", migr.LogString())
   350  			}
   351  
   352  			ret <- migr
   353  			go func(migr *Migration) {
   354  				if err := migr.Buffer(); err != nil {
   355  					m.logErr(err)
   356  				}
   357  			}(migr)
   358  		}
   359  	}()
   360  
   361  	return m.unlockErr(m.runMigrations(ret))
   362  }
   363  
   364  // Force sets a migration version.
   365  // It does not check any currently active version in database.
   366  // It resets the dirty state to false.
   367  func (m *Migrate) Force(version int) error {
   368  	if version < -1 {
   369  		return ErrInvalidVersion
   370  	}
   371  
   372  	if err := m.lock(); err != nil {
   373  		return err
   374  	}
   375  
   376  	if err := m.databaseDrv.SetVersion(version, false); err != nil {
   377  		return m.unlockErr(err)
   378  	}
   379  
   380  	return m.unlock()
   381  }
   382  
   383  // Version returns the currently active migration version.
   384  // If no migration has been applied, yet, it will return ErrNilVersion.
   385  func (m *Migrate) Version() (version uint, dirty bool, err error) {
   386  	v, d, err := m.databaseDrv.Version()
   387  	if err != nil {
   388  		return 0, false, err
   389  	}
   390  
   391  	if v == database.NilVersion {
   392  		return 0, false, ErrNilVersion
   393  	}
   394  
   395  	return suint(v), d, nil
   396  }
   397  
// read reads either up or down migrations from source `from` to `to`.
// Each migration is then written to the ret channel.
// If an error occurs during reading, that error is written to the ret channel, too.
// Once read is done reading it will close the ret channel.
//
// from and to are versions, with -1 standing for the nil version (nothing
// applied). Unless -1, both must exist in the source (versionExists); from == to
// sends ErrNoChange. A stop signal (see stop) ends reading without sending an
// error. Each migration sent is buffered in its own goroutine, and buffering
// errors are only logged.
//
// NOTE(review): sends on ret block once its buffer (PrefetchMigrations) is
// full; if the consumer returns early on an error, this goroutine may stay
// blocked — confirm whether that matters for long-lived processes.
func (m *Migrate) read(from int, to int, ret chan<- interface{}) {
	defer close(ret)

	// check if from version exists
	if from >= 0 {
		if err := m.versionExists(suint(from)); err != nil {
			ret <- err
			return
		}
	}

	// check if to version exists
	if to >= 0 {
		if err := m.versionExists(suint(to)); err != nil {
			ret <- err
			return
		}
	}

	// no change?
	if from == to {
		ret <- ErrNoChange
		return
	}

	if from < to {
		// it's going up
		// apply first migration if from is nil version
		// (this first migration is sent before any stop check)
		if from == -1 {
			firstVersion, err := m.sourceDrv.First()
			if err != nil {
				ret <- err
				return
			}

			migr, err := m.newMigration(firstVersion, int(firstVersion))
			if err != nil {
				ret <- err
				return
			}

			ret <- migr
			go func() {
				if err := migr.Buffer(); err != nil {
					m.logErr(err)
				}
			}()

			from = int(firstVersion)
		}

		// run until we reach target ...
		for from < to {
			if m.stop() {
				return
			}

			next, err := m.sourceDrv.Next(suint(from))
			if err != nil {
				ret <- err
				return
			}

			migr, err := m.newMigration(next, int(next))
			if err != nil {
				ret <- err
				return
			}

			ret <- migr
			go func() {
				if err := migr.Buffer(); err != nil {
					m.logErr(err)
				}
			}()

			from = int(next)
		}

	} else {
		// it's going down
		// run until we reach target ...
		for from > to && from >= 0 {
			if m.stop() {
				return
			}

			prev, err := m.sourceDrv.Prev(suint(from))
			if errors.Is(err, os.ErrNotExist) && to == -1 {
				// apply nil migration
				// from has no predecessor, so its down migration targets the
				// nil version (-1) and reading ends here.
				migr, err := m.newMigration(suint(from), -1)
				if err != nil {
					ret <- err
					return
				}
				ret <- migr
				go func() {
					if err := migr.Buffer(); err != nil {
						m.logErr(err)
					}
				}()

				return

			} else if err != nil {
				// includes os.ErrNotExist when to >= 0: the target was not
				// reached before running out of predecessors
				ret <- err
				return
			}

			migr, err := m.newMigration(suint(from), int(prev))
			if err != nil {
				ret <- err
				return
			}

			ret <- migr
			go func() {
				if err := migr.Buffer(); err != nil {
					m.logErr(err)
				}
			}()

			from = int(prev)
		}
	}
}
   528  
// readUp reads up migrations from `from` limited by `limit`.
// limit can be -1, implying no limit and reading until there are no more migrations.
// Each migration is then written to the ret channel.
// If an error occurs during reading, that error is written to the ret channel, too.
// Once readUp is done reading it will close the ret channel.
//
// from is the current version (-1 for the nil version). limit == 0 sends
// ErrNoChange. When Next reports os.ErrNotExist (no further migration):
//   - limit == -1 and nothing read yet: ErrNoChange is sent,
//   - limit == -1 otherwise: reading ends without error,
//   - limit > 0 and nothing read yet: os.ErrNotExist is sent,
//   - limit > 0 and fewer than limit read: ErrShortLimit with the shortfall.
//
// Errors from First (when from == -1) are sent unchanged. A stop signal
// (see stop) ends reading without sending an error.
//
// NOTE(review): sends on ret block once its buffer is full; if the consumer
// returns early, this goroutine may stay blocked — confirm.
func (m *Migrate) readUp(from int, limit int, ret chan<- interface{}) {
	defer close(ret)

	// check if from version exists
	if from >= 0 {
		if err := m.versionExists(suint(from)); err != nil {
			ret <- err
			return
		}
	}

	if limit == 0 {
		ret <- ErrNoChange
		return
	}

	count := 0
	// inside the loop, limit is either -1 or greater than count
	for count < limit || limit == -1 {
		if m.stop() {
			return
		}

		// apply first migration if from is nil version
		if from == -1 {
			firstVersion, err := m.sourceDrv.First()
			if err != nil {
				ret <- err
				return
			}

			migr, err := m.newMigration(firstVersion, int(firstVersion))
			if err != nil {
				ret <- err
				return
			}

			ret <- migr
			go func() {
				if err := migr.Buffer(); err != nil {
					m.logErr(err)
				}
			}()
			from = int(firstVersion)
			count++
			continue
		}

		// apply next migration
		next, err := m.sourceDrv.Next(suint(from))
		if errors.Is(err, os.ErrNotExist) {
			// no limit, but no migrations applied?
			if limit == -1 && count == 0 {
				ret <- ErrNoChange
				return
			}

			// no limit, reached end
			if limit == -1 {
				return
			}

			// reached end, and didn't apply any migrations
			if limit > 0 && count == 0 {
				ret <- os.ErrNotExist
				return
			}

			// applied less migrations than limit?
			if count < limit {
				ret <- ErrShortLimit{suint(limit - count)}
				return
			}
		}
		if err != nil {
			ret <- err
			return
		}

		migr, err := m.newMigration(next, int(next))
		if err != nil {
			ret <- err
			return
		}

		ret <- migr
		go func() {
			if err := migr.Buffer(); err != nil {
				m.logErr(err)
			}
		}()
		from = int(next)
		count++
	}
}
   628  
// readDown reads down migrations from `from` limited by `limit`.
// limit can be -1, implying no limit and reading until there are no more migrations.
// Each migration is then written to the ret channel.
// If an error occurs during reading, that error is written to the ret channel, too.
// Once readDown is done reading it will close the ret channel.
//
// from is the current version (-1 for the nil version). limit == 0, or
// from == -1 with limit == -1, sends ErrNoChange; from == -1 with limit > 0
// sends os.ErrNotExist. When Prev reports os.ErrNotExist, the migration from
// the source's First version down to the nil version is sent. If that still
// leaves fewer than limit migrations, ErrShortLimit carries the shortfall.
// A stop signal (see stop) ends reading without sending an error.
//
// NOTE(review): sends on ret block once its buffer is full; if the consumer
// returns early, this goroutine may stay blocked — confirm.
func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) {
	defer close(ret)

	// check if from version exists
	if from >= 0 {
		if err := m.versionExists(suint(from)); err != nil {
			ret <- err
			return
		}
	}

	if limit == 0 {
		ret <- ErrNoChange
		return
	}

	// no change if already at nil version
	if from == -1 && limit == -1 {
		ret <- ErrNoChange
		return
	}

	// can't go over limit if already at nil version
	if from == -1 && limit > 0 {
		ret <- os.ErrNotExist
		return
	}

	count := 0
	for count < limit || limit == -1 {
		if m.stop() {
			return
		}

		prev, err := m.sourceDrv.Prev(suint(from))
		if errors.Is(err, os.ErrNotExist) {
			// no limit or haven't reached limit, apply "first" migration
			// (given the loop condition this check is always true here)
			// NOTE(review): uses First() rather than from; presumably from is
			// the first version when it has no predecessor — confirm.
			if limit == -1 || limit-count > 0 {
				firstVersion, err := m.sourceDrv.First()
				if err != nil {
					ret <- err
					return
				}

				migr, err := m.newMigration(firstVersion, -1)
				if err != nil {
					ret <- err
					return
				}
				ret <- migr
				go func() {
					if err := migr.Buffer(); err != nil {
						m.logErr(err)
					}
				}()
				count++
			}

			if count < limit {
				ret <- ErrShortLimit{suint(limit - count)}
			}
			return
		}
		if err != nil {
			ret <- err
			return
		}

		migr, err := m.newMigration(suint(from), int(prev))
		if err != nil {
			ret <- err
			return
		}

		ret <- migr
		go func() {
			if err := migr.Buffer(); err != nil {
				m.logErr(err)
			}
		}()
		from = int(prev)
		count++
	}
}
   718  
   719  // runMigrations reads *Migration and error from a channel. Any other type
   720  // sent on this channel will result in a panic. Each migration is then
   721  // proxied to the database driver and run against the database.
   722  // Before running a newly received migration it will check if it's supposed
   723  // to stop execution because it might have received a stop signal on the
   724  // GracefulStop channel.
   725  func (m *Migrate) runMigrations(ret <-chan interface{}) error {
   726  	for r := range ret {
   727  
   728  		if m.stop() {
   729  			return nil
   730  		}
   731  
   732  		switch r := r.(type) {
   733  		case error:
   734  			return r
   735  
   736  		case *Migration:
   737  			migr := r
   738  
   739  			// set version with dirty state
   740  			if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
   741  				return err
   742  			}
   743  
   744  			if migr.Body != nil {
   745  				m.logVerbosePrintf("Read and execute %v\n", migr.LogString())
   746  				if err := m.databaseDrv.Run(migr.BufferedBody); err != nil {
   747  					return err
   748  				}
   749  			}
   750  
   751  			// set clean state
   752  			if err := m.databaseDrv.SetVersion(migr.TargetVersion, false); err != nil {
   753  				return err
   754  			}
   755  
   756  			endTime := time.Now()
   757  			readTime := migr.FinishedReading.Sub(migr.StartedBuffering)
   758  			runTime := endTime.Sub(migr.FinishedReading)
   759  
   760  			// log either verbose or normal
   761  			if m.Log != nil {
   762  				if m.Log.Verbose() {
   763  					m.logPrintf("Finished %v (read %v, ran %v)\n", migr.LogString(), readTime, runTime)
   764  				} else {
   765  					m.logPrintf("%v (%v)\n", migr.LogString(), readTime+runTime)
   766  				}
   767  			}
   768  
   769  		default:
   770  			return fmt.Errorf("unknown type: %T with value: %+v", r, r)
   771  		}
   772  	}
   773  	return nil
   774  }
   775  
   776  // versionExists checks the source if either the up or down migration for
   777  // the specified migration version exists.
   778  func (m *Migrate) versionExists(version uint) (result error) {
   779  	// try up migration first
   780  	up, _, err := m.sourceDrv.ReadUp(version)
   781  	if err == nil {
   782  		defer func() {
   783  			if errClose := up.Close(); errClose != nil {
   784  				result = multierror.Append(result, errClose)
   785  			}
   786  		}()
   787  	}
   788  	if errors.Is(err, os.ErrExist) {
   789  		return nil
   790  	} else if !errors.Is(err, os.ErrNotExist) {
   791  		return err
   792  	}
   793  
   794  	// then try down migration
   795  	down, _, err := m.sourceDrv.ReadDown(version)
   796  	if err == nil {
   797  		defer func() {
   798  			if errClose := down.Close(); errClose != nil {
   799  				result = multierror.Append(result, errClose)
   800  			}
   801  		}()
   802  	}
   803  	if errors.Is(err, os.ErrExist) {
   804  		return nil
   805  	} else if !errors.Is(err, os.ErrNotExist) {
   806  		return err
   807  	}
   808  
   809  	err = fmt.Errorf("no migration found for version %d: %w", version, err)
   810  	m.logErr(err)
   811  	return err
   812  }
   813  
   814  // stop returns true if no more migrations should be run against the database
   815  // because a stop signal was received on the GracefulStop channel.
   816  // Calls are cheap and this function is not blocking.
   817  func (m *Migrate) stop() bool {
   818  	if m.isGracefulStop {
   819  		return true
   820  	}
   821  
   822  	select {
   823  	case <-m.GracefulStop:
   824  		m.isGracefulStop = true
   825  		return true
   826  
   827  	default:
   828  		return false
   829  	}
   830  }
   831  
   832  // newMigration is a helper func that returns a *Migration for the
   833  // specified version and targetVersion.
   834  func (m *Migrate) newMigration(version uint, targetVersion int) (*Migration, error) {
   835  	var migr *Migration
   836  
   837  	if targetVersion >= int(version) {
   838  		r, identifier, err := m.sourceDrv.ReadUp(version)
   839  		if errors.Is(err, os.ErrNotExist) {
   840  			// create "empty" migration
   841  			migr, err = NewMigration(nil, "", version, targetVersion)
   842  			if err != nil {
   843  				return nil, err
   844  			}
   845  
   846  		} else if err != nil {
   847  			return nil, err
   848  
   849  		} else {
   850  			// create migration from up source
   851  			migr, err = NewMigration(r, identifier, version, targetVersion)
   852  			if err != nil {
   853  				return nil, err
   854  			}
   855  		}
   856  
   857  	} else {
   858  		r, identifier, err := m.sourceDrv.ReadDown(version)
   859  		if errors.Is(err, os.ErrNotExist) {
   860  			// create "empty" migration
   861  			migr, err = NewMigration(nil, "", version, targetVersion)
   862  			if err != nil {
   863  				return nil, err
   864  			}
   865  
   866  		} else if err != nil {
   867  			return nil, err
   868  
   869  		} else {
   870  			// create migration from down source
   871  			migr, err = NewMigration(r, identifier, version, targetVersion)
   872  			if err != nil {
   873  				return nil, err
   874  			}
   875  		}
   876  	}
   877  
   878  	if m.PrefetchMigrations > 0 && migr.Body != nil {
   879  		m.logVerbosePrintf("Start buffering %v\n", migr.LogString())
   880  	} else {
   881  		m.logVerbosePrintf("Scheduled %v\n", migr.LogString())
   882  	}
   883  
   884  	return migr, nil
   885  }
   886  
   887  // lock is a thread safe helper function to lock the database.
   888  // It should be called as late as possible when running migrations.
   889  func (m *Migrate) lock() error {
   890  	m.isLockedMu.Lock()
   891  	defer m.isLockedMu.Unlock()
   892  
   893  	if m.isLocked {
   894  		return ErrLocked
   895  	}
   896  
   897  	// create done channel, used in the timeout goroutine
   898  	done := make(chan bool, 1)
   899  	defer func() {
   900  		done <- true
   901  	}()
   902  
   903  	// use errchan to signal error back to this context
   904  	errchan := make(chan error, 2)
   905  
   906  	// start timeout goroutine
   907  	timeout := time.After(m.LockTimeout)
   908  	go func() {
   909  		for {
   910  			select {
   911  			case <-done:
   912  				return
   913  			case <-timeout:
   914  				errchan <- ErrLockTimeout
   915  				return
   916  			}
   917  		}
   918  	}()
   919  
   920  	// now try to acquire the lock
   921  	go func() {
   922  		if err := m.databaseDrv.Lock(); err != nil {
   923  			errchan <- err
   924  		} else {
   925  			errchan <- nil
   926  		}
   927  	}()
   928  
   929  	// wait until we either receive ErrLockTimeout or error from Lock operation
   930  	err := <-errchan
   931  	if err == nil {
   932  		m.isLocked = true
   933  	}
   934  	return err
   935  }
   936  
   937  // unlock is a thread safe helper function to unlock the database.
   938  // It should be called as early as possible when no more migrations are
   939  // expected to be executed.
   940  func (m *Migrate) unlock() error {
   941  	m.isLockedMu.Lock()
   942  	defer m.isLockedMu.Unlock()
   943  
   944  	if err := m.databaseDrv.Unlock(); err != nil {
   945  		// BUG: Can potentially create a deadlock. Add a timeout.
   946  		return err
   947  	}
   948  
   949  	m.isLocked = false
   950  	return nil
   951  }
   952  
   953  // unlockErr calls unlock and returns a combined error
   954  // if a prevErr is not nil.
   955  func (m *Migrate) unlockErr(prevErr error) error {
   956  	if err := m.unlock(); err != nil {
   957  		return multierror.Append(prevErr, err)
   958  	}
   959  	return prevErr
   960  }
   961  
   962  // logPrintf writes to m.Log if not nil
   963  func (m *Migrate) logPrintf(format string, v ...interface{}) {
   964  	if m.Log != nil {
   965  		m.Log.Printf(format, v...)
   966  	}
   967  }
   968  
   969  // logVerbosePrintf writes to m.Log if not nil. Use for verbose logging output.
   970  func (m *Migrate) logVerbosePrintf(format string, v ...interface{}) {
   971  	if m.Log != nil && m.Log.Verbose() {
   972  		m.Log.Printf(format, v...)
   973  	}
   974  }
   975  
   976  // logErr writes error to m.Log if not nil
   977  func (m *Migrate) logErr(err error) {
   978  	if m.Log != nil {
   979  		m.Log.Printf("error: %v", err)
   980  	}
   981  }