github.com/getsynq/migrate/v4@v4.15.3-0.20220615182648-8e72daaa5ed9/database/mongodb/mongodb.go (about)

     1  package mongodb
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"net/url"
     9  	os "os"
    10  	"strconv"
    11  	"time"
    12  
    13  	"github.com/cenkalti/backoff/v4"
    14  	"github.com/getsynq/migrate/v4/database"
    15  	"github.com/hashicorp/go-multierror"
    16  	"go.mongodb.org/mongo-driver/bson"
    17  	"go.mongodb.org/mongo-driver/mongo"
    18  	"go.mongodb.org/mongo-driver/mongo/options"
    19  	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
    20  	"go.uber.org/atomic"
    21  )
    22  
    23  func init() {
    24  	db := Mongo{}
    25  	database.Register("mongodb", &db)
    26  	database.Register("mongodb+srv", &db)
    27  }
    28  
    29  var DefaultMigrationsCollection = "schema_migrations"
    30  
    31  const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default.
    32  const lockKeyUniqueValue = 0                             // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked).
    33  const DefaultLockTimeout = 15                            // the default maximum time to wait for a lock to be released.
    34  const DefaultLockTimeoutInterval = 10                    // the default maximum intervals time for the locking timout.
    35  const DefaultAdvisoryLockingFlag = true                  // the default value for the advisory locking feature flag. Default is true.
    36  const LockIndexName = "lock_unique_key"                  // the name of the index which adds unique constraint to the locking_key field.
    37  const contextWaitTimeout = 5 * time.Second               // how long to wait for the request to mongo to block/wait for.
    38  
    39  var (
    40  	ErrNoDatabaseName            = fmt.Errorf("no database name")
    41  	ErrNilConfig                 = fmt.Errorf("no config")
    42  	ErrLockTimeoutConfigConflict = fmt.Errorf("both x-advisory-lock-timeout-interval and x-advisory-lock-timout-interval were specified")
    43  )
    44  
// Mongo is the migrate database driver for MongoDB; init registers it for the
// "mongodb" and "mongodb+srv" schemes.
type Mongo struct {
	// client is the connection used by the driver; Close disconnects it.
	client   *mongo.Client
	// db is the migration target database (config.DatabaseName).
	db       *mongo.Database
	// config holds the resolved settings (defaults already applied by WithInstance).
	config   *Config
	// isLocked records whether this instance currently holds the advisory lock.
	// Lock/Unlock swap it via database.CasRestoreOnErr before consulting
	// config.Locking.Enabled, so it is tracked even when locking is disabled.
	isLocked atomic.Bool
}
    51  
// Locking configures the advisory lock taken by Lock and released by Unlock.
type Locking struct {
	// CollectionName is the collection holding the lock document;
	// WithInstance substitutes DefaultLockingCollection when it is empty.
	CollectionName string
	// Timeout is the maximum number of seconds Lock keeps retrying before it
	// returns database.ErrLocked; values <= 0 become DefaultLockTimeout.
	Timeout        int
	// Enabled turns advisory locking on; when false Lock/Unlock do not touch MongoDB.
	Enabled        bool
	// Interval caps, in seconds, the backoff wait between lock attempts;
	// values <= 0 become DefaultLockTimeoutInterval.
	Interval       int
}
// Config holds the settings of a Mongo driver.
type Config struct {
	// DatabaseName is the database migrations run against; it is required.
	DatabaseName         string
	// MigrationsCollection stores the version document; WithInstance
	// substitutes DefaultMigrationsCollection when it is empty.
	MigrationsCollection string
	// TransactionMode makes Run execute all commands of a migration inside
	// one transaction.
	TransactionMode      bool
	// Locking configures the advisory lock.
	Locking              Locking
}
// versionInfo is the document SetVersion writes to the migrations collection
// and Version decodes from it.
type versionInfo struct {
	Version int  `bson:"version"`
	Dirty   bool `bson:"dirty"`
}

// lockObj is the document Lock inserts into the lock collection. Pid,
// Hostname and CreatedAt identify the holder; Key is always
// lockKeyUniqueValue, so the unique index on locking_key admits only one.
type lockObj struct {
	Key       int       `bson:"locking_key"`
	Pid       int       `bson:"pid"`
	Hostname  string    `bson:"hostname"`
	CreatedAt time.Time `bson:"created_at"`
}
// findFilter selects lock documents by locking_key. It is used both as the
// delete filter in Unlock and, with Key -1, as the index key specification in
// ensureLockTable.
type findFilter struct {
	Key int `bson:"locking_key"`
}
    78  
    79  func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
    80  	if config == nil {
    81  		return nil, ErrNilConfig
    82  	}
    83  	if len(config.DatabaseName) == 0 {
    84  		return nil, ErrNoDatabaseName
    85  	}
    86  	if len(config.MigrationsCollection) == 0 {
    87  		config.MigrationsCollection = DefaultMigrationsCollection
    88  	}
    89  	if len(config.Locking.CollectionName) == 0 {
    90  		config.Locking.CollectionName = DefaultLockingCollection
    91  	}
    92  	if config.Locking.Timeout <= 0 {
    93  		config.Locking.Timeout = DefaultLockTimeout
    94  	}
    95  	if config.Locking.Interval <= 0 {
    96  		config.Locking.Interval = DefaultLockTimeoutInterval
    97  	}
    98  
    99  	mc := &Mongo{
   100  		client: instance,
   101  		db:     instance.Database(config.DatabaseName),
   102  		config: config,
   103  	}
   104  
   105  	if mc.config.Locking.Enabled {
   106  		if err := mc.ensureLockTable(); err != nil {
   107  			return nil, err
   108  		}
   109  	}
   110  	if err := mc.ensureVersionTable(); err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	return mc, nil
   115  }
   116  
   117  func (m *Mongo) Open(dsn string) (database.Driver, error) {
   118  	//connstring is experimental package, but it used for parse connection string in mongo.Connect function
   119  	uri, err := connstring.Parse(dsn)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  	if len(uri.Database) == 0 {
   124  		return nil, ErrNoDatabaseName
   125  	}
   126  	unknown := url.Values(uri.UnknownOptions)
   127  
   128  	migrationsCollection := unknown.Get("x-migrations-collection")
   129  	lockCollection := unknown.Get("x-advisory-lock-collection")
   130  	transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false)
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  	advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag)
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  	lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	lockTimeoutIntervalValue := unknown.Get("x-advisory-lock-timeout-interval")
   144  	// The initial release had a typo for this argument but for backwards compatibility sake, we will keep supporting it
   145  	// and we will error out if both values are set.
   146  	lockTimeoutIntervalValueFromTypo := unknown.Get("x-advisory-lock-timout-interval")
   147  
   148  	lockTimeout := lockTimeoutIntervalValue
   149  
   150  	if lockTimeoutIntervalValue != "" && lockTimeoutIntervalValueFromTypo != "" {
   151  		return nil, ErrLockTimeoutConfigConflict
   152  	} else if lockTimeoutIntervalValueFromTypo != "" {
   153  		lockTimeout = lockTimeoutIntervalValueFromTypo
   154  	}
   155  
   156  	maxLockCheckInterval, err := parseInt(lockTimeout, DefaultLockTimeoutInterval)
   157  
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  	client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn))
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  
   166  	if err = client.Ping(context.TODO(), nil); err != nil {
   167  		return nil, err
   168  	}
   169  	mc, err := WithInstance(client, &Config{
   170  		DatabaseName:         uri.Database,
   171  		MigrationsCollection: migrationsCollection,
   172  		TransactionMode:      transactionMode,
   173  		Locking: Locking{
   174  			CollectionName: lockCollection,
   175  			Timeout:        lockingTimout,
   176  			Enabled:        advisoryLockingFlag,
   177  			Interval:       maxLockCheckInterval,
   178  		},
   179  	})
   180  	if err != nil {
   181  		return nil, err
   182  	}
   183  	return mc, nil
   184  }
   185  
   186  //Parse the url param, convert it to boolean
   187  // returns error if param invalid. returns defaultValue if param not present
   188  func parseBoolean(urlParam string, defaultValue bool) (bool, error) {
   189  
   190  	// if parameter passed, parse it (otherwise return default value)
   191  	if urlParam != "" {
   192  		result, err := strconv.ParseBool(urlParam)
   193  		if err != nil {
   194  			return false, err
   195  		}
   196  		return result, nil
   197  	}
   198  
   199  	// if no url Param passed, return default value
   200  	return defaultValue, nil
   201  }
   202  
   203  //Parse the url param, convert it to int
   204  // returns error if param invalid. returns defaultValue if param not present
   205  func parseInt(urlParam string, defaultValue int) (int, error) {
   206  
   207  	// if parameter passed, parse it (otherwise return default value)
   208  	if urlParam != "" {
   209  		result, err := strconv.Atoi(urlParam)
   210  		if err != nil {
   211  			return -1, err
   212  		}
   213  		return result, nil
   214  	}
   215  
   216  	// if no url Param passed, return default value
   217  	return defaultValue, nil
   218  }
   219  func (m *Mongo) SetVersion(version int, dirty bool) error {
   220  	migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
   221  	if err := migrationsCollection.Drop(context.TODO()); err != nil {
   222  		return &database.Error{OrigErr: err, Err: "drop migrations collection failed"}
   223  	}
   224  	_, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty})
   225  	if err != nil {
   226  		return &database.Error{OrigErr: err, Err: "save version failed"}
   227  	}
   228  	return nil
   229  }
   230  
   231  func (m *Mongo) Version() (version int, dirty bool, err error) {
   232  	var versionInfo versionInfo
   233  	err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo)
   234  	switch {
   235  	case err == mongo.ErrNoDocuments:
   236  		return database.NilVersion, false, nil
   237  	case err != nil:
   238  		return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"}
   239  	default:
   240  		return versionInfo.Version, versionInfo.Dirty, nil
   241  	}
   242  }
   243  
   244  func (m *Mongo) Run(migration io.Reader) error {
   245  	migr, err := ioutil.ReadAll(migration)
   246  	if err != nil {
   247  		return err
   248  	}
   249  	var cmds []bson.D
   250  	err = bson.UnmarshalExtJSON(migr, true, &cmds)
   251  	if err != nil {
   252  		return fmt.Errorf("unmarshaling json error: %s", err)
   253  	}
   254  	if m.config.TransactionMode {
   255  		if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil {
   256  			return err
   257  		}
   258  	} else {
   259  		if err := m.executeCommands(context.TODO(), cmds); err != nil {
   260  			return err
   261  		}
   262  	}
   263  	return nil
   264  }
   265  
   266  func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error {
   267  	err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error {
   268  		if err := sessionContext.StartTransaction(); err != nil {
   269  			return &database.Error{OrigErr: err, Err: "failed to start transaction"}
   270  		}
   271  		if err := m.executeCommands(sessionContext, cmds); err != nil {
   272  			//When command execution is failed, it's aborting transaction
   273  			//If you tried to call abortTransaction, it`s return error that transaction already aborted
   274  			return err
   275  		}
   276  		if err := sessionContext.CommitTransaction(sessionContext); err != nil {
   277  			return &database.Error{OrigErr: err, Err: "failed to commit transaction"}
   278  		}
   279  		return nil
   280  	})
   281  	if err != nil {
   282  		return err
   283  	}
   284  	return nil
   285  }
   286  
   287  func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error {
   288  	for _, cmd := range cmds {
   289  		err := m.db.RunCommand(ctx, cmd).Err()
   290  		if err != nil {
   291  			return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)}
   292  		}
   293  	}
   294  	return nil
   295  }
   296  
   297  func (m *Mongo) Close() error {
   298  	return m.client.Disconnect(context.TODO())
   299  }
   300  
   301  func (m *Mongo) Drop() error {
   302  	return m.db.Drop(context.TODO())
   303  }
   304  
   305  func (m *Mongo) ensureLockTable() error {
   306  	indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes()
   307  
   308  	indexOptions := options.Index().SetUnique(true).SetName(LockIndexName)
   309  	_, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{
   310  		Options: indexOptions,
   311  		Keys:    findFilter{Key: -1},
   312  	})
   313  	if err != nil {
   314  		return err
   315  	}
   316  	return nil
   317  }
   318  
   319  // ensureVersionTable checks if versions table exists and, if not, creates it.
   320  // Note that this function locks the database, which deviates from the usual
   321  // convention of "caller locks" in the MongoDb type.
   322  func (m *Mongo) ensureVersionTable() (err error) {
   323  	if err = m.Lock(); err != nil {
   324  		return err
   325  	}
   326  
   327  	defer func() {
   328  		if e := m.Unlock(); e != nil {
   329  			if err == nil {
   330  				err = e
   331  			} else {
   332  				err = multierror.Append(err, e)
   333  			}
   334  		}
   335  	}()
   336  
   337  	if err != nil {
   338  		return err
   339  	}
   340  	if _, _, err = m.Version(); err != nil {
   341  		return err
   342  	}
   343  	return nil
   344  }
   345  
   346  // Utilizes advisory locking on the config.LockingCollection collection
   347  // This uses a unique index on the `locking_key` field.
   348  func (m *Mongo) Lock() error {
   349  	return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
   350  		if !m.config.Locking.Enabled {
   351  			return nil
   352  		}
   353  
   354  		pid := os.Getpid()
   355  		hostname, err := os.Hostname()
   356  		if err != nil {
   357  			hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
   358  		}
   359  
   360  		newLockObj := lockObj{
   361  			Key:       lockKeyUniqueValue,
   362  			Pid:       pid,
   363  			Hostname:  hostname,
   364  			CreatedAt: time.Now(),
   365  		}
   366  		operation := func() error {
   367  			timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
   368  			_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
   369  			defer cancelFunc()
   370  			return err
   371  		}
   372  		exponentialBackOff := backoff.NewExponentialBackOff()
   373  		duration := time.Duration(m.config.Locking.Timeout) * time.Second
   374  		exponentialBackOff.MaxElapsedTime = duration
   375  		exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
   376  
   377  		err = backoff.Retry(operation, exponentialBackOff)
   378  		if err != nil {
   379  			return database.ErrLocked
   380  		}
   381  
   382  		return nil
   383  	})
   384  }
   385  
   386  func (m *Mongo) Unlock() error {
   387  	return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
   388  		if !m.config.Locking.Enabled {
   389  			return nil
   390  		}
   391  
   392  		filter := findFilter{
   393  			Key: lockKeyUniqueValue,
   394  		}
   395  
   396  		ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
   397  		_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
   398  		defer cancel()
   399  
   400  		if err != nil {
   401  			return err
   402  		}
   403  		return nil
   404  	})
   405  }