
     1  package mongodb
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net/url"
     8  	"os"
     9  	"strconv"
    10  	"time"
    12  	""
    13  	""
    14  	""
    15  	""
    16  	""
    17  	""
    18  	""
    19  	""
    20  )
    22  func init() {
    23  	db := Mongo{}
    24  	database.Register("mongodb", &db)
    25  	database.Register("mongodb+srv", &db)
    26  }
// DefaultMigrationsCollection is the collection name WithInstance falls back
// to when Config.MigrationsCollection is empty.
var DefaultMigrationsCollection = "schema_migrations"
    30  const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default.
    31  const lockKeyUniqueValue = 0                             // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked).
    32  const DefaultLockTimeout = 15                            // the default maximum time to wait for a lock to be released.
    33  const DefaultLockTimeoutInterval = 10                    // the default maximum intervals time for the locking timout.
    34  const DefaultAdvisoryLockingFlag = true                  // the default value for the advisory locking feature flag. Default is true.
    35  const LockIndexName = "lock_unique_key"                  // the name of the index which adds unique constraint to the locking_key field.
    36  const contextWaitTimeout = 5 * time.Second               // how long to wait for the request to mongo to block/wait for.
// Sentinel errors returned while configuring or opening the driver:
//   - ErrNoDatabaseName: no database name was given (empty Config.DatabaseName
//     in WithInstance, or no database in the connection string passed to Open).
//   - ErrNilConfig: WithInstance was called with a nil *Config.
//   - ErrLockTimeoutConfigConflict: the connection string passed to Open sets
//     both the correctly spelled and the legacy misspelled lock interval option.
var (
	ErrNoDatabaseName            = fmt.Errorf("no database name")
	ErrNilConfig                 = fmt.Errorf("no config")
	ErrLockTimeoutConfigConflict = fmt.Errorf("both x-advisory-lock-timeout-interval and x-advisory-lock-timout-interval were specified")
)
// Mongo is the MongoDB migration driver. A zero value is registered in init
// for the "mongodb" and "mongodb+srv" schemes; usable instances are built by
// WithInstance (directly or through Open).
type Mongo struct {
	client   *mongo.Client   // client passed to WithInstance; disconnected by Close
	db       *mongo.Database // handle for config.DatabaseName on client
	config   *Config         // configuration passed to WithInstance (defaults filled in)
	isLocked atomic.Bool     // in-process lock flag toggled by Lock and Unlock
}
// Locking holds the advisory-locking settings of the driver.
type Locking struct {
	CollectionName string // collection holding the lock document; WithInstance defaults it to DefaultLockingCollection
	Timeout        int    // seconds; upper bound on the total time Lock keeps retrying (defaults to DefaultLockTimeout when <= 0)
	Enabled        bool   // when false, Lock and Unlock skip all database operations
	Interval       int    // seconds; upper bound on the backoff interval between Lock retries (defaults to DefaultLockTimeoutInterval when <= 0)
}
// Config is the driver configuration accepted by WithInstance.
type Config struct {
	DatabaseName         string  // required; WithInstance returns ErrNoDatabaseName when empty
	MigrationsCollection string  // collection storing the version document; defaults to DefaultMigrationsCollection
	TransactionMode      bool    // when true, Run executes a migration's commands inside a transaction
	Locking              Locking // advisory-locking settings
}
// versionInfo is the shape Version decodes the migrations-collection document
// into; SetVersion writes the same "version" and "dirty" keys.
type versionInfo struct {
	Version int  `bson:"version"`
	Dirty   bool `bson:"dirty"`
}
// lockObj is the document Lock inserts into the locking collection.
type lockObj struct {
	Key       int       `bson:"locking_key"` // always lockKeyUniqueValue; covered by the unique index
	Pid       int       `bson:"pid"`         // os.Getpid() of the locking process
	Hostname  string    `bson:"hostname"`    // os.Hostname(), or an error description when it cannot be determined
	CreatedAt time.Time `bson:"created_at"`  // time.Now() taken when the lock document is built
}
// findFilter addresses documents by their locking_key field. Unlock uses it
// as the delete filter, and ensureLockTable reuses it (with Key: -1) as the
// key specification of the unique index.
type findFilter struct {
	Key int `bson:"locking_key"`
}
    78  func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
    79  	if config == nil {
    80  		return nil, ErrNilConfig
    81  	}
    82  	if len(config.DatabaseName) == 0 {
    83  		return nil, ErrNoDatabaseName
    84  	}
    85  	if len(config.MigrationsCollection) == 0 {
    86  		config.MigrationsCollection = DefaultMigrationsCollection
    87  	}
    88  	if len(config.Locking.CollectionName) == 0 {
    89  		config.Locking.CollectionName = DefaultLockingCollection
    90  	}
    91  	if config.Locking.Timeout <= 0 {
    92  		config.Locking.Timeout = DefaultLockTimeout
    93  	}
    94  	if config.Locking.Interval <= 0 {
    95  		config.Locking.Interval = DefaultLockTimeoutInterval
    96  	}
    98  	mc := &Mongo{
    99  		client: instance,
   100  		db:     instance.Database(config.DatabaseName),
   101  		config: config,
   102  	}
   104  	if mc.config.Locking.Enabled {
   105  		if err := mc.ensureLockTable(); err != nil {
   106  			return nil, err
   107  		}
   108  	}
   109  	if err := mc.ensureVersionTable(); err != nil {
   110  		return nil, err
   111  	}
   113  	return mc, nil
   114  }
   116  func (m *Mongo) Open(dsn string) (database.Driver, error) {
   117  	// connstring is experimental package, but it used for parse connection string in mongo.Connect function
   118  	uri, err := connstring.Parse(dsn)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	if len(uri.Database) == 0 {
   123  		return nil, ErrNoDatabaseName
   124  	}
   125  	unknown := url.Values(uri.UnknownOptions)
   127  	migrationsCollection := unknown.Get("x-migrations-collection")
   128  	lockCollection := unknown.Get("x-advisory-lock-collection")
   129  	transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  	lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout)
   138  	if err != nil {
   139  		return nil, err
   140  	}
   142  	lockTimeoutIntervalValue := unknown.Get("x-advisory-lock-timeout-interval")
   143  	// The initial release had a typo for this argument but for backwards compatibility sake, we will keep supporting it
   144  	// and we will error out if both values are set.
   145  	lockTimeoutIntervalValueFromTypo := unknown.Get("x-advisory-lock-timout-interval")
   147  	lockTimeout := lockTimeoutIntervalValue
   149  	if lockTimeoutIntervalValue != "" && lockTimeoutIntervalValueFromTypo != "" {
   150  		return nil, ErrLockTimeoutConfigConflict
   151  	} else if lockTimeoutIntervalValueFromTypo != "" {
   152  		lockTimeout = lockTimeoutIntervalValueFromTypo
   153  	}
   155  	maxLockCheckInterval, err := parseInt(lockTimeout, DefaultLockTimeoutInterval)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  	client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn))
   161  	if err != nil {
   162  		return nil, err
   163  	}
   165  	if err = client.Ping(context.TODO(), nil); err != nil {
   166  		return nil, err
   167  	}
   168  	mc, err := WithInstance(client, &Config{
   169  		DatabaseName:         uri.Database,
   170  		MigrationsCollection: migrationsCollection,
   171  		TransactionMode:      transactionMode,
   172  		Locking: Locking{
   173  			CollectionName: lockCollection,
   174  			Timeout:        lockingTimout,
   175  			Enabled:        advisoryLockingFlag,
   176  			Interval:       maxLockCheckInterval,
   177  		},
   178  	})
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  	return mc, nil
   183  }
   185  // Parse the url param, convert it to boolean
   186  // returns error if param invalid. returns defaultValue if param not present
   187  func parseBoolean(urlParam string, defaultValue bool) (bool, error) {
   189  	// if parameter passed, parse it (otherwise return default value)
   190  	if urlParam != "" {
   191  		result, err := strconv.ParseBool(urlParam)
   192  		if err != nil {
   193  			return false, err
   194  		}
   195  		return result, nil
   196  	}
   198  	// if no url Param passed, return default value
   199  	return defaultValue, nil
   200  }
   202  // Parse the url param, convert it to int
   203  // returns error if param invalid. returns defaultValue if param not present
   204  func parseInt(urlParam string, defaultValue int) (int, error) {
   206  	// if parameter passed, parse it (otherwise return default value)
   207  	if urlParam != "" {
   208  		result, err := strconv.Atoi(urlParam)
   209  		if err != nil {
   210  			return -1, err
   211  		}
   212  		return result, nil
   213  	}
   215  	// if no url Param passed, return default value
   216  	return defaultValue, nil
   217  }
   218  func (m *Mongo) SetVersion(version int, dirty bool) error {
   219  	migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
   220  	if err := migrationsCollection.Drop(context.TODO()); err != nil {
   221  		return &database.Error{OrigErr: err, Err: "drop migrations collection failed"}
   222  	}
   223  	_, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty})
   224  	if err != nil {
   225  		return &database.Error{OrigErr: err, Err: "save version failed"}
   226  	}
   227  	return nil
   228  }
   230  func (m *Mongo) Version() (version int, dirty bool, err error) {
   231  	var versionInfo versionInfo
   232  	err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo)
   233  	switch {
   234  	case err == mongo.ErrNoDocuments:
   235  		return database.NilVersion, false, nil
   236  	case err != nil:
   237  		return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"}
   238  	default:
   239  		return versionInfo.Version, versionInfo.Dirty, nil
   240  	}
   241  }
   243  func (m *Mongo) Run(migration io.Reader) error {
   244  	migr, err := io.ReadAll(migration)
   245  	if err != nil {
   246  		return err
   247  	}
   248  	var cmds []bson.D
   249  	err = bson.UnmarshalExtJSON(migr, true, &cmds)
   250  	if err != nil {
   251  		return fmt.Errorf("unmarshaling json error: %s", err)
   252  	}
   253  	if m.config.TransactionMode {
   254  		if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil {
   255  			return err
   256  		}
   257  	} else {
   258  		if err := m.executeCommands(context.TODO(), cmds); err != nil {
   259  			return err
   260  		}
   261  	}
   262  	return nil
   263  }
   265  func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error {
   266  	err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error {
   267  		if err := sessionContext.StartTransaction(); err != nil {
   268  			return &database.Error{OrigErr: err, Err: "failed to start transaction"}
   269  		}
   270  		if err := m.executeCommands(sessionContext, cmds); err != nil {
   271  			// When command execution is failed, it's aborting transaction
   272  			// If you tried to call abortTransaction, it`s return error that transaction already aborted
   273  			return err
   274  		}
   275  		if err := sessionContext.CommitTransaction(sessionContext); err != nil {
   276  			return &database.Error{OrigErr: err, Err: "failed to commit transaction"}
   277  		}
   278  		return nil
   279  	})
   280  	if err != nil {
   281  		return err
   282  	}
   283  	return nil
   284  }
   286  func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error {
   287  	for _, cmd := range cmds {
   288  		err := m.db.RunCommand(ctx, cmd).Err()
   289  		if err != nil {
   290  			return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)}
   291  		}
   292  	}
   293  	return nil
   294  }
   296  func (m *Mongo) Close() error {
   297  	return m.client.Disconnect(context.TODO())
   298  }
   300  func (m *Mongo) Drop() error {
   301  	return m.db.Drop(context.TODO())
   302  }
   304  func (m *Mongo) ensureLockTable() error {
   305  	indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes()
   307  	indexOptions := options.Index().SetUnique(true).SetName(LockIndexName)
   308  	_, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{
   309  		Options: indexOptions,
   310  		Keys:    findFilter{Key: -1},
   311  	})
   312  	if err != nil {
   313  		return err
   314  	}
   315  	return nil
   316  }
   318  // ensureVersionTable checks if versions table exists and, if not, creates it.
   319  // Note that this function locks the database, which deviates from the usual
   320  // convention of "caller locks" in the MongoDb type.
   321  func (m *Mongo) ensureVersionTable() (err error) {
   322  	if err = m.Lock(); err != nil {
   323  		return err
   324  	}
   326  	defer func() {
   327  		if e := m.Unlock(); e != nil {
   328  			if err == nil {
   329  				err = e
   330  			} else {
   331  				err = multierror.Append(err, e)
   332  			}
   333  		}
   334  	}()
   336  	if err != nil {
   337  		return err
   338  	}
   339  	if _, _, err = m.Version(); err != nil {
   340  		return err
   341  	}
   342  	return nil
   343  }
   345  // Utilizes advisory locking on the config.LockingCollection collection
   346  // This uses a unique index on the `locking_key` field.
   347  func (m *Mongo) Lock() error {
   348  	return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
   349  		if !m.config.Locking.Enabled {
   350  			return nil
   351  		}
   353  		pid := os.Getpid()
   354  		hostname, err := os.Hostname()
   355  		if err != nil {
   356  			hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
   357  		}
   359  		newLockObj := lockObj{
   360  			Key:       lockKeyUniqueValue,
   361  			Pid:       pid,
   362  			Hostname:  hostname,
   363  			CreatedAt: time.Now(),
   364  		}
   365  		operation := func() error {
   366  			timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
   367  			_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
   368  			defer cancelFunc()
   369  			return err
   370  		}
   371  		exponentialBackOff := backoff.NewExponentialBackOff()
   372  		duration := time.Duration(m.config.Locking.Timeout) * time.Second
   373  		exponentialBackOff.MaxElapsedTime = duration
   374  		exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
   376  		err = backoff.Retry(operation, exponentialBackOff)
   377  		if err != nil {
   378  			return database.ErrLocked
   379  		}
   381  		return nil
   382  	})
   383  }
   385  func (m *Mongo) Unlock() error {
   386  	return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
   387  		if !m.config.Locking.Enabled {
   388  			return nil
   389  		}
   391  		filter := findFilter{
   392  			Key: lockKeyUniqueValue,
   393  		}
   395  		ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
   396  		_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
   397  		defer cancel()
   399  		if err != nil {
   400  			return err
   401  		}
   402  		return nil
   403  	})
   404  }