github.com/nagyist/migrate/v4@v4.14.6/database/mongodb/mongodb.go (about)

     1  package mongodb
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"github.com/cenkalti/backoff/v4"
     7  	"github.com/golang-migrate/migrate/v4/database"
     8  	"github.com/hashicorp/go-multierror"
     9  	"go.mongodb.org/mongo-driver/bson"
    10  	"go.mongodb.org/mongo-driver/mongo"
    11  	"go.mongodb.org/mongo-driver/mongo/options"
    12  	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
    13  	"io"
    14  	"io/ioutil"
    15  	"net/url"
    16  	os "os"
    17  	"strconv"
    18  	"time"
    19  )
    20  
    21  func init() {
    22  	db := Mongo{}
    23  	database.Register("mongodb", &db)
    24  	database.Register("mongodb+srv", &db)
    25  }
    26  
    27  var DefaultMigrationsCollection = "schema_migrations"
    28  
    29  const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default.
    30  const lockKeyUniqueValue = 0                             // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked).
    31  const DefaultLockTimeout = 15                            // the default maximum time to wait for a lock to be released.
    32  const DefaultLockTimeoutInterval = 10                    // the default maximum intervals time for the locking timout.
    33  const DefaultAdvisoryLockingFlag = true                  // the default value for the advisory locking feature flag. Default is true.
    34  const LockIndexName = "lock_unique_key"                  // the name of the index which adds unique constraint to the locking_key field.
    35  const contextWaitTimeout = 5 * time.Second               // how long to wait for the request to mongo to block/wait for.
    36  
    37  var (
    38  	ErrNoDatabaseName = fmt.Errorf("no database name")
    39  	ErrNilConfig      = fmt.Errorf("no config")
    40  )
    41  
    42  type Mongo struct {
    43  	client *mongo.Client
    44  	db     *mongo.Database
    45  	config *Config
    46  }
    47  
    48  type Locking struct {
    49  	CollectionName string
    50  	Timeout        int
    51  	Enabled        bool
    52  	Interval       int
    53  }
    54  type Config struct {
    55  	DatabaseName         string
    56  	MigrationsCollection string
    57  	TransactionMode      bool
    58  	Locking              Locking
    59  }
    60  type versionInfo struct {
    61  	Version int  `bson:"version"`
    62  	Dirty   bool `bson:"dirty"`
    63  }
    64  
    65  type lockObj struct {
    66  	Key       int       `bson:"locking_key"`
    67  	Pid       int       `bson:"pid"`
    68  	Hostname  string    `bson:"hostname"`
    69  	CreatedAt time.Time `bson:"created_at"`
    70  }
    71  type findFilter struct {
    72  	Key int `bson:"locking_key"`
    73  }
    74  
    75  func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
    76  	if config == nil {
    77  		return nil, ErrNilConfig
    78  	}
    79  	if len(config.DatabaseName) == 0 {
    80  		return nil, ErrNoDatabaseName
    81  	}
    82  	if len(config.MigrationsCollection) == 0 {
    83  		config.MigrationsCollection = DefaultMigrationsCollection
    84  	}
    85  	if len(config.Locking.CollectionName) == 0 {
    86  		config.Locking.CollectionName = DefaultLockingCollection
    87  	}
    88  	if config.Locking.Timeout <= 0 {
    89  		config.Locking.Timeout = DefaultLockTimeout
    90  	}
    91  	if config.Locking.Interval <= 0 {
    92  		config.Locking.Interval = DefaultLockTimeoutInterval
    93  	}
    94  
    95  	mc := &Mongo{
    96  		client: instance,
    97  		db:     instance.Database(config.DatabaseName),
    98  		config: config,
    99  	}
   100  
   101  	if mc.config.Locking.Enabled {
   102  		if err := mc.ensureLockTable(); err != nil {
   103  			return nil, err
   104  		}
   105  	}
   106  	if err := mc.ensureVersionTable(); err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	return mc, nil
   111  }
   112  
   113  func (m *Mongo) Open(dsn string) (database.Driver, error) {
   114  	//connstring is experimental package, but it used for parse connection string in mongo.Connect function
   115  	uri, err := connstring.Parse(dsn)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  	if len(uri.Database) == 0 {
   120  		return nil, ErrNoDatabaseName
   121  	}
   122  	unknown := url.Values(uri.UnknownOptions)
   123  
   124  	migrationsCollection := unknown.Get("x-migrations-collection")
   125  	lockCollection := unknown.Get("x-advisory-lock-collection")
   126  	transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  	advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag)
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  	lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout)
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  	maxLockingIntervals, err := parseInt(unknown.Get("x-advisory-lock-timout-interval"), DefaultLockTimeoutInterval)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  	client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn))
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	if err = client.Ping(context.TODO(), nil); err != nil {
   148  		return nil, err
   149  	}
   150  	mc, err := WithInstance(client, &Config{
   151  		DatabaseName:         uri.Database,
   152  		MigrationsCollection: migrationsCollection,
   153  		TransactionMode:      transactionMode,
   154  		Locking: Locking{
   155  			CollectionName: lockCollection,
   156  			Timeout:        lockingTimout,
   157  			Enabled:        advisoryLockingFlag,
   158  			Interval:       maxLockingIntervals,
   159  		},
   160  	})
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  	return mc, nil
   165  }
   166  
   167  //Parse the url param, convert it to boolean
   168  // returns error if param invalid. returns defaultValue if param not present
   169  func parseBoolean(urlParam string, defaultValue bool) (bool, error) {
   170  
   171  	// if parameter passed, parse it (otherwise return default value)
   172  	if urlParam != "" {
   173  		result, err := strconv.ParseBool(urlParam)
   174  		if err != nil {
   175  			return false, err
   176  		}
   177  		return result, nil
   178  	}
   179  
   180  	// if no url Param passed, return default value
   181  	return defaultValue, nil
   182  }
   183  
   184  //Parse the url param, convert it to int
   185  // returns error if param invalid. returns defaultValue if param not present
   186  func parseInt(urlParam string, defaultValue int) (int, error) {
   187  
   188  	// if parameter passed, parse it (otherwise return default value)
   189  	if urlParam != "" {
   190  		result, err := strconv.Atoi(urlParam)
   191  		if err != nil {
   192  			return -1, err
   193  		}
   194  		return result, nil
   195  	}
   196  
   197  	// if no url Param passed, return default value
   198  	return defaultValue, nil
   199  }
   200  func (m *Mongo) SetVersion(version int, dirty bool) error {
   201  	migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
   202  	if err := migrationsCollection.Drop(context.TODO()); err != nil {
   203  		return &database.Error{OrigErr: err, Err: "drop migrations collection failed"}
   204  	}
   205  	_, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty})
   206  	if err != nil {
   207  		return &database.Error{OrigErr: err, Err: "save version failed"}
   208  	}
   209  	return nil
   210  }
   211  
   212  func (m *Mongo) Version() (version int, dirty bool, err error) {
   213  	var versionInfo versionInfo
   214  	err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo)
   215  	switch {
   216  	case err == mongo.ErrNoDocuments:
   217  		return database.NilVersion, false, nil
   218  	case err != nil:
   219  		return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"}
   220  	default:
   221  		return versionInfo.Version, versionInfo.Dirty, nil
   222  	}
   223  }
   224  
   225  func (m *Mongo) Run(migration io.Reader) error {
   226  	migr, err := ioutil.ReadAll(migration)
   227  	if err != nil {
   228  		return err
   229  	}
   230  	var cmds []bson.D
   231  	err = bson.UnmarshalExtJSON(migr, true, &cmds)
   232  	if err != nil {
   233  		return fmt.Errorf("unmarshaling json error: %s", err)
   234  	}
   235  	if m.config.TransactionMode {
   236  		if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil {
   237  			return err
   238  		}
   239  	} else {
   240  		if err := m.executeCommands(context.TODO(), cmds); err != nil {
   241  			return err
   242  		}
   243  	}
   244  	return nil
   245  }
   246  
   247  func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error {
   248  	err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error {
   249  		if err := sessionContext.StartTransaction(); err != nil {
   250  			return &database.Error{OrigErr: err, Err: "failed to start transaction"}
   251  		}
   252  		if err := m.executeCommands(sessionContext, cmds); err != nil {
   253  			//When command execution is failed, it's aborting transaction
   254  			//If you tried to call abortTransaction, it`s return error that transaction already aborted
   255  			return err
   256  		}
   257  		if err := sessionContext.CommitTransaction(sessionContext); err != nil {
   258  			return &database.Error{OrigErr: err, Err: "failed to commit transaction"}
   259  		}
   260  		return nil
   261  	})
   262  	if err != nil {
   263  		return err
   264  	}
   265  	return nil
   266  }
   267  
   268  func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error {
   269  	for _, cmd := range cmds {
   270  		err := m.db.RunCommand(ctx, cmd).Err()
   271  		if err != nil {
   272  			return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)}
   273  		}
   274  	}
   275  	return nil
   276  }
   277  
   278  func (m *Mongo) Close() error {
   279  	return m.client.Disconnect(context.TODO())
   280  }
   281  
   282  func (m *Mongo) Drop() error {
   283  	return m.db.Drop(context.TODO())
   284  }
   285  
   286  func (m *Mongo) ensureLockTable() error {
   287  	indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes()
   288  
   289  	indexOptions := options.Index().SetUnique(true).SetName(LockIndexName)
   290  	_, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{
   291  		Options: indexOptions,
   292  		Keys:    findFilter{Key: -1},
   293  	})
   294  	if err != nil {
   295  		return err
   296  	}
   297  	return nil
   298  }
   299  
   300  // ensureVersionTable checks if versions table exists and, if not, creates it.
   301  // Note that this function locks the database, which deviates from the usual
   302  // convention of "caller locks" in the MongoDb type.
   303  func (m *Mongo) ensureVersionTable() (err error) {
   304  	if err = m.Lock(); err != nil {
   305  		return err
   306  	}
   307  
   308  	defer func() {
   309  		if e := m.Unlock(); e != nil {
   310  			if err == nil {
   311  				err = e
   312  			} else {
   313  				err = multierror.Append(err, e)
   314  			}
   315  		}
   316  	}()
   317  
   318  	if err != nil {
   319  		return err
   320  	}
   321  	if _, _, err = m.Version(); err != nil {
   322  		return err
   323  	}
   324  	return nil
   325  }
   326  
   327  // Utilizes advisory locking on the config.LockingCollection collection
   328  // This uses a unique index on the `locking_key` field.
   329  func (m *Mongo) Lock() error {
   330  	if !m.config.Locking.Enabled {
   331  		return nil
   332  	}
   333  	pid := os.Getpid()
   334  	hostname, err := os.Hostname()
   335  	if err != nil {
   336  		hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
   337  	}
   338  
   339  	newLockObj := lockObj{
   340  		Key:       lockKeyUniqueValue,
   341  		Pid:       pid,
   342  		Hostname:  hostname,
   343  		CreatedAt: time.Now(),
   344  	}
   345  	operation := func() error {
   346  		timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
   347  		_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
   348  		defer cancelFunc()
   349  		return err
   350  	}
   351  	exponentialBackOff := backoff.NewExponentialBackOff()
   352  	duration := time.Duration(m.config.Locking.Timeout) * time.Second
   353  	exponentialBackOff.MaxElapsedTime = duration
   354  	exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
   355  
   356  	err = backoff.Retry(operation, exponentialBackOff)
   357  	if err != nil {
   358  		return database.ErrLocked
   359  	}
   360  
   361  	return nil
   362  
   363  }
   364  func (m *Mongo) Unlock() error {
   365  	if !m.config.Locking.Enabled {
   366  		return nil
   367  	}
   368  
   369  	filter := findFilter{
   370  		Key: lockKeyUniqueValue,
   371  	}
   372  
   373  	ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
   374  	_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
   375  	defer cancel()
   376  
   377  	if err != nil {
   378  		return err
   379  	}
   380  	return nil
   381  }