github.com/Elate-DevOps/migrate/v4@v4.0.12/database/mongodb/mongodb.go (about)

     1  package mongodb
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net/url"
     8  	"os"
     9  	"strconv"
    10  	"time"
    11  
    12  	"github.com/Elate-DevOps/migrate/v4/database"
    13  	"github.com/cenkalti/backoff/v4"
    14  	"github.com/hashicorp/go-multierror"
    15  	"go.mongodb.org/mongo-driver/bson"
    16  	"go.mongodb.org/mongo-driver/mongo"
    17  	"go.mongodb.org/mongo-driver/mongo/options"
    18  	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
    19  	"go.uber.org/atomic"
    20  )
    21  
    22  func init() {
    23  	db := Mongo{}
    24  	database.Register("mongodb", &db)
    25  	database.Register("mongodb+srv", &db)
    26  }
    27  
// DefaultMigrationsCollection is the collection that stores the migration
// version document when Config.MigrationsCollection is empty.
var DefaultMigrationsCollection = "schema_migrations"

// Defaults and fixed values for the advisory lock. The timeout and interval
// values are interpreted as seconds (Lock multiplies them by time.Second).
const (
	DefaultLockingCollection   = "migrate_advisory_lock" // the collection to use for advisory locking by default.
	lockKeyUniqueValue         = 0                       // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked).
	DefaultLockTimeout         = 15                      // the default maximum time, in seconds, to keep retrying lock acquisition.
	DefaultLockTimeoutInterval = 10                      // the default maximum interval, in seconds, between lock acquisition attempts.
	DefaultAdvisoryLockingFlag = true                    // the default value for the advisory locking feature flag. Default is true.
	LockIndexName              = "lock_unique_key"       // the name of the index which adds unique constraint to the locking_key field.
	contextWaitTimeout         = 5 * time.Second         // how long a single lock/unlock request to mongo may block before it is cancelled.
)

var (
	// ErrNoDatabaseName is returned when no database name is given (empty
	// Config.DatabaseName, or a DSN without a database).
	ErrNoDatabaseName            = fmt.Errorf("no database name")
	// ErrNilConfig is returned by WithInstance when config is nil.
	ErrNilConfig                 = fmt.Errorf("no config")
	// ErrLockTimeoutConfigConflict is returned by Open when both the correctly
	// spelled lock interval parameter and its legacy misspelled form are set.
	ErrLockTimeoutConfigConflict = fmt.Errorf("both x-advisory-lock-timeout-interval and x-advisory-lock-timout-interval were specified")
)
    45  
// Mongo is the MongoDB database.Driver implementation returned by
// WithInstance and Open.
type Mongo struct {
	client   *mongo.Client   // connection used for sessions and Disconnect
	db       *mongo.Database // database named by Config.DatabaseName
	config   *Config         // settings, with defaults filled in by WithInstance
	isLocked atomic.Bool     // in-memory lock state flipped by Lock/Unlock
}

// Locking configures the advisory lock taken by Lock and released by Unlock.
type Locking struct {
	CollectionName string // collection holding the lock document; DefaultLockingCollection when empty
	Timeout        int    // seconds to keep retrying acquisition; DefaultLockTimeout when <= 0
	Enabled        bool   // when false, Lock/Unlock touch no collection
	Interval       int    // max seconds between acquisition attempts; DefaultLockTimeoutInterval when <= 0
}

// Config holds the settings for a Mongo driver instance.
type Config struct {
	DatabaseName         string  // required; database holding the migrations and lock collections
	MigrationsCollection string  // collection holding the version document; DefaultMigrationsCollection when empty
	TransactionMode      bool    // run each migration's commands inside one transaction
	Locking              Locking // advisory locking settings
}
// versionInfo decodes the document kept in the migrations collection; its
// fields match the "version"/"dirty" keys that SetVersion writes.
type versionInfo struct {
	Version int  `bson:"version"`
	Dirty   bool `bson:"dirty"`
}

// lockObj is the document Lock inserts into the lock collection. Key is
// always lockKeyUniqueValue, so the unique index on locking_key lets only one
// such document exist at a time; the other fields record who took the lock
// and when.
type lockObj struct {
	Key       int       `bson:"locking_key"`
	Pid       int       `bson:"pid"`
	Hostname  string    `bson:"hostname"`
	CreatedAt time.Time `bson:"created_at"`
}

// findFilter selects lock documents by locking_key. It also serves as the
// key specification of the unique lock index built in ensureLockTable.
type findFilter struct {
	Key int `bson:"locking_key"`
}
    79  
    80  func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
    81  	if config == nil {
    82  		return nil, ErrNilConfig
    83  	}
    84  	if len(config.DatabaseName) == 0 {
    85  		return nil, ErrNoDatabaseName
    86  	}
    87  	if len(config.MigrationsCollection) == 0 {
    88  		config.MigrationsCollection = DefaultMigrationsCollection
    89  	}
    90  	if len(config.Locking.CollectionName) == 0 {
    91  		config.Locking.CollectionName = DefaultLockingCollection
    92  	}
    93  	if config.Locking.Timeout <= 0 {
    94  		config.Locking.Timeout = DefaultLockTimeout
    95  	}
    96  	if config.Locking.Interval <= 0 {
    97  		config.Locking.Interval = DefaultLockTimeoutInterval
    98  	}
    99  
   100  	mc := &Mongo{
   101  		client: instance,
   102  		db:     instance.Database(config.DatabaseName),
   103  		config: config,
   104  	}
   105  
   106  	if mc.config.Locking.Enabled {
   107  		if err := mc.ensureLockTable(); err != nil {
   108  			return nil, err
   109  		}
   110  	}
   111  	if err := mc.ensureVersionTable(); err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	return mc, nil
   116  }
   117  
   118  func (m *Mongo) Open(dsn string) (database.Driver, error) {
   119  	// connstring is experimental package, but it used for parse connection string in mongo.Connect function
   120  	uri, err := connstring.Parse(dsn)
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  	if len(uri.Database) == 0 {
   125  		return nil, ErrNoDatabaseName
   126  	}
   127  	unknown := url.Values(uri.UnknownOptions)
   128  
   129  	migrationsCollection := unknown.Get("x-migrations-collection")
   130  	lockCollection := unknown.Get("x-advisory-lock-collection")
   131  	transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout)
   140  	if err != nil {
   141  		return nil, err
   142  	}
   143  
   144  	lockTimeoutIntervalValue := unknown.Get("x-advisory-lock-timeout-interval")
   145  	// The initial release had a typo for this argument but for backwards compatibility sake, we will keep supporting it
   146  	// and we will error out if both values are set.
   147  	lockTimeoutIntervalValueFromTypo := unknown.Get("x-advisory-lock-timout-interval")
   148  
   149  	lockTimeout := lockTimeoutIntervalValue
   150  
   151  	if lockTimeoutIntervalValue != "" && lockTimeoutIntervalValueFromTypo != "" {
   152  		return nil, ErrLockTimeoutConfigConflict
   153  	} else if lockTimeoutIntervalValueFromTypo != "" {
   154  		lockTimeout = lockTimeoutIntervalValueFromTypo
   155  	}
   156  
   157  	maxLockCheckInterval, err := parseInt(lockTimeout, DefaultLockTimeoutInterval)
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  	client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn))
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  
   166  	if err = client.Ping(context.TODO(), nil); err != nil {
   167  		return nil, err
   168  	}
   169  	mc, err := WithInstance(client, &Config{
   170  		DatabaseName:         uri.Database,
   171  		MigrationsCollection: migrationsCollection,
   172  		TransactionMode:      transactionMode,
   173  		Locking: Locking{
   174  			CollectionName: lockCollection,
   175  			Timeout:        lockingTimout,
   176  			Enabled:        advisoryLockingFlag,
   177  			Interval:       maxLockCheckInterval,
   178  		},
   179  	})
   180  	if err != nil {
   181  		return nil, err
   182  	}
   183  	return mc, nil
   184  }
   185  
   186  // Parse the url param, convert it to boolean
   187  // returns error if param invalid. returns defaultValue if param not present
   188  func parseBoolean(urlParam string, defaultValue bool) (bool, error) {
   189  	// if parameter passed, parse it (otherwise return default value)
   190  	if urlParam != "" {
   191  		result, err := strconv.ParseBool(urlParam)
   192  		if err != nil {
   193  			return false, err
   194  		}
   195  		return result, nil
   196  	}
   197  
   198  	// if no url Param passed, return default value
   199  	return defaultValue, nil
   200  }
   201  
   202  // Parse the url param, convert it to int
   203  // returns error if param invalid. returns defaultValue if param not present
   204  func parseInt(urlParam string, defaultValue int) (int, error) {
   205  	// if parameter passed, parse it (otherwise return default value)
   206  	if urlParam != "" {
   207  		result, err := strconv.Atoi(urlParam)
   208  		if err != nil {
   209  			return -1, err
   210  		}
   211  		return result, nil
   212  	}
   213  
   214  	// if no url Param passed, return default value
   215  	return defaultValue, nil
   216  }
   217  
   218  func (m *Mongo) SetVersion(version int, dirty bool) error {
   219  	migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
   220  	if err := migrationsCollection.Drop(context.TODO()); err != nil {
   221  		return &database.Error{OrigErr: err, Err: "drop migrations collection failed"}
   222  	}
   223  	_, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty})
   224  	if err != nil {
   225  		return &database.Error{OrigErr: err, Err: "save version failed"}
   226  	}
   227  	return nil
   228  }
   229  
   230  func (m *Mongo) Version() (version int, dirty bool, err error) {
   231  	var versionInfo versionInfo
   232  	err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo)
   233  	switch {
   234  	case err == mongo.ErrNoDocuments:
   235  		return database.NilVersion, false, nil
   236  	case err != nil:
   237  		return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"}
   238  	default:
   239  		return versionInfo.Version, versionInfo.Dirty, nil
   240  	}
   241  }
   242  
   243  func (m *Mongo) Run(migration io.Reader) error {
   244  	migr, err := io.ReadAll(migration)
   245  	if err != nil {
   246  		return err
   247  	}
   248  	var cmds []bson.D
   249  	err = bson.UnmarshalExtJSON(migr, true, &cmds)
   250  	if err != nil {
   251  		return fmt.Errorf("unmarshaling json error: %s", err)
   252  	}
   253  	if m.config.TransactionMode {
   254  		if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil {
   255  			return err
   256  		}
   257  	} else {
   258  		if err := m.executeCommands(context.TODO(), cmds); err != nil {
   259  			return err
   260  		}
   261  	}
   262  	return nil
   263  }
   264  
   265  func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error {
   266  	err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error {
   267  		if err := sessionContext.StartTransaction(); err != nil {
   268  			return &database.Error{OrigErr: err, Err: "failed to start transaction"}
   269  		}
   270  		if err := m.executeCommands(sessionContext, cmds); err != nil {
   271  			// When command execution is failed, it's aborting transaction
   272  			// If you tried to call abortTransaction, it`s return error that transaction already aborted
   273  			return err
   274  		}
   275  		if err := sessionContext.CommitTransaction(sessionContext); err != nil {
   276  			return &database.Error{OrigErr: err, Err: "failed to commit transaction"}
   277  		}
   278  		return nil
   279  	})
   280  	if err != nil {
   281  		return err
   282  	}
   283  	return nil
   284  }
   285  
   286  func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error {
   287  	for _, cmd := range cmds {
   288  		err := m.db.RunCommand(ctx, cmd).Err()
   289  		if err != nil {
   290  			return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)}
   291  		}
   292  	}
   293  	return nil
   294  }
   295  
   296  func (m *Mongo) Close() error {
   297  	return m.client.Disconnect(context.TODO())
   298  }
   299  
   300  func (m *Mongo) Drop() error {
   301  	return m.db.Drop(context.TODO())
   302  }
   303  
   304  func (m *Mongo) ensureLockTable() error {
   305  	indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes()
   306  
   307  	indexOptions := options.Index().SetUnique(true).SetName(LockIndexName)
   308  	_, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{
   309  		Options: indexOptions,
   310  		Keys:    findFilter{Key: -1},
   311  	})
   312  	if err != nil {
   313  		return err
   314  	}
   315  	return nil
   316  }
   317  
   318  // ensureVersionTable checks if versions table exists and, if not, creates it.
   319  // Note that this function locks the database, which deviates from the usual
   320  // convention of "caller locks" in the MongoDb type.
   321  func (m *Mongo) ensureVersionTable() (err error) {
   322  	if err = m.Lock(); err != nil {
   323  		return err
   324  	}
   325  
   326  	defer func() {
   327  		if e := m.Unlock(); e != nil {
   328  			if err == nil {
   329  				err = e
   330  			} else {
   331  				err = multierror.Append(err, e)
   332  			}
   333  		}
   334  	}()
   335  
   336  	if err != nil {
   337  		return err
   338  	}
   339  	if _, _, err = m.Version(); err != nil {
   340  		return err
   341  	}
   342  	return nil
   343  }
   344  
   345  // Utilizes advisory locking on the config.LockingCollection collection
   346  // This uses a unique index on the `locking_key` field.
   347  func (m *Mongo) Lock() error {
   348  	return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
   349  		if !m.config.Locking.Enabled {
   350  			return nil
   351  		}
   352  
   353  		pid := os.Getpid()
   354  		hostname, err := os.Hostname()
   355  		if err != nil {
   356  			hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
   357  		}
   358  
   359  		newLockObj := lockObj{
   360  			Key:       lockKeyUniqueValue,
   361  			Pid:       pid,
   362  			Hostname:  hostname,
   363  			CreatedAt: time.Now(),
   364  		}
   365  		operation := func() error {
   366  			timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
   367  			_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
   368  			defer cancelFunc()
   369  			return err
   370  		}
   371  		exponentialBackOff := backoff.NewExponentialBackOff()
   372  		duration := time.Duration(m.config.Locking.Timeout) * time.Second
   373  		exponentialBackOff.MaxElapsedTime = duration
   374  		exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
   375  
   376  		err = backoff.Retry(operation, exponentialBackOff)
   377  		if err != nil {
   378  			return database.ErrLocked
   379  		}
   380  
   381  		return nil
   382  	})
   383  }
   384  
   385  func (m *Mongo) Unlock() error {
   386  	return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
   387  		if !m.config.Locking.Enabled {
   388  			return nil
   389  		}
   390  
   391  		filter := findFilter{
   392  			Key: lockKeyUniqueValue,
   393  		}
   394  
   395  		ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
   396  		_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
   397  		defer cancel()
   398  
   399  		if err != nil {
   400  			return err
   401  		}
   402  		return nil
   403  	})
   404  }