github.com/Elate-DevOps/migrate/v4@v4.0.12/database/mongodb/mongodb.go (about) 1 package mongodb 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "net/url" 8 "os" 9 "strconv" 10 "time" 11 12 "github.com/Elate-DevOps/migrate/v4/database" 13 "github.com/cenkalti/backoff/v4" 14 "github.com/hashicorp/go-multierror" 15 "go.mongodb.org/mongo-driver/bson" 16 "go.mongodb.org/mongo-driver/mongo" 17 "go.mongodb.org/mongo-driver/mongo/options" 18 "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" 19 "go.uber.org/atomic" 20 ) 21 22 func init() { 23 db := Mongo{} 24 database.Register("mongodb", &db) 25 database.Register("mongodb+srv", &db) 26 } 27 28 var DefaultMigrationsCollection = "schema_migrations" 29 30 const ( 31 DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default. 32 lockKeyUniqueValue = 0 // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked). 33 DefaultLockTimeout = 15 // the default maximum time to wait for a lock to be released. 34 DefaultLockTimeoutInterval = 10 // the default maximum intervals time for the locking timout. 35 DefaultAdvisoryLockingFlag = true // the default value for the advisory locking feature flag. Default is true. 36 LockIndexName = "lock_unique_key" // the name of the index which adds unique constraint to the locking_key field. 37 contextWaitTimeout = 5 * time.Second // how long to wait for the request to mongo to block/wait for. 38 ) 39 40 var ( 41 ErrNoDatabaseName = fmt.Errorf("no database name") 42 ErrNilConfig = fmt.Errorf("no config") 43 ErrLockTimeoutConfigConflict = fmt.Errorf("both x-advisory-lock-timeout-interval and x-advisory-lock-timout-interval were specified") 44 ) 45 46 type Mongo struct { 47 client *mongo.Client 48 db *mongo.Database 49 config *Config 50 isLocked atomic.Bool 51 } 52 53 type Locking struct { 54 CollectionName string 55 Timeout int 56 Enabled bool 57 Interval int 58 } 59 type Config struct { 60 DatabaseName string 61 MigrationsCollection string 62 TransactionMode bool 63 Locking Locking 64 } 65 type versionInfo struct { 66 Version int `bson:"version"` 67 Dirty bool `bson:"dirty"` 68 } 69 70 type lockObj struct { 71 Key int `bson:"locking_key"` 72 Pid int `bson:"pid"` 73 Hostname string `bson:"hostname"` 74 CreatedAt time.Time `bson:"created_at"` 75 } 76 type findFilter struct { 77 Key int `bson:"locking_key"` 78 } 79 80 func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) { 81 if config == nil { 82 return nil, ErrNilConfig 83 } 84 if len(config.DatabaseName) == 0 { 85 return nil, ErrNoDatabaseName 86 } 87 if len(config.MigrationsCollection) == 0 { 88 config.MigrationsCollection = DefaultMigrationsCollection 89 } 90 if len(config.Locking.CollectionName) == 0 { 91 config.Locking.CollectionName = DefaultLockingCollection 92 } 93 if config.Locking.Timeout <= 0 { 94 config.Locking.Timeout = DefaultLockTimeout 95 } 96 if config.Locking.Interval <= 0 { 97 config.Locking.Interval = DefaultLockTimeoutInterval 98 } 99 100 mc := &Mongo{ 101 client: instance, 102 db: instance.Database(config.DatabaseName), 103 config: config, 104 } 105 106 if mc.config.Locking.Enabled { 107 if err := mc.ensureLockTable(); err != nil { 108 return nil, err 109 } 110 } 111 if err := mc.ensureVersionTable(); err != nil { 112 return nil, err 113 } 114 115 return mc, nil 116 } 117 118 func (m *Mongo) Open(dsn string) (database.Driver, error) { 119 // connstring is experimental package, but it used for parse connection string in mongo.Connect function 120 uri, err := connstring.Parse(dsn) 121 if err != nil { 122 return nil, err 123 } 124 if len(uri.Database) == 0 { 125 return nil, ErrNoDatabaseName 126 } 127 unknown := url.Values(uri.UnknownOptions) 128 129 migrationsCollection := unknown.Get("x-migrations-collection") 130 lockCollection := unknown.Get("x-advisory-lock-collection") 131 transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false) 132 if err != nil { 133 return nil, err 134 } 135 advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag) 136 if err != nil { 137 return nil, err 138 } 139 lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout) 140 if err != nil { 141 return nil, err 142 } 143 144 lockTimeoutIntervalValue := unknown.Get("x-advisory-lock-timeout-interval") 145 // The initial release had a typo for this argument but for backwards compatibility sake, we will keep supporting it 146 // and we will error out if both values are set. 147 lockTimeoutIntervalValueFromTypo := unknown.Get("x-advisory-lock-timout-interval") 148 149 lockTimeout := lockTimeoutIntervalValue 150 151 if lockTimeoutIntervalValue != "" && lockTimeoutIntervalValueFromTypo != "" { 152 return nil, ErrLockTimeoutConfigConflict 153 } else if lockTimeoutIntervalValueFromTypo != "" { 154 lockTimeout = lockTimeoutIntervalValueFromTypo 155 } 156 157 maxLockCheckInterval, err := parseInt(lockTimeout, DefaultLockTimeoutInterval) 158 if err != nil { 159 return nil, err 160 } 161 client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn)) 162 if err != nil { 163 return nil, err 164 } 165 166 if err = client.Ping(context.TODO(), nil); err != nil { 167 return nil, err 168 } 169 mc, err := WithInstance(client, &Config{ 170 DatabaseName: uri.Database, 171 MigrationsCollection: migrationsCollection, 172 TransactionMode: transactionMode, 173 Locking: Locking{ 174 CollectionName: lockCollection, 175 Timeout: lockingTimout, 176 Enabled: advisoryLockingFlag, 177 Interval: maxLockCheckInterval, 178 }, 179 }) 180 if err != nil { 181 return nil, err 182 } 183 return mc, nil 184 } 185 186 // Parse the url param, convert it to boolean 187 // returns error if param invalid. returns defaultValue if param not present 188 func parseBoolean(urlParam string, defaultValue bool) (bool, error) { 189 // if parameter passed, parse it (otherwise return default value) 190 if urlParam != "" { 191 result, err := strconv.ParseBool(urlParam) 192 if err != nil { 193 return false, err 194 } 195 return result, nil 196 } 197 198 // if no url Param passed, return default value 199 return defaultValue, nil 200 } 201 202 // Parse the url param, convert it to int 203 // returns error if param invalid. returns defaultValue if param not present 204 func parseInt(urlParam string, defaultValue int) (int, error) { 205 // if parameter passed, parse it (otherwise return default value) 206 if urlParam != "" { 207 result, err := strconv.Atoi(urlParam) 208 if err != nil { 209 return -1, err 210 } 211 return result, nil 212 } 213 214 // if no url Param passed, return default value 215 return defaultValue, nil 216 } 217 218 func (m *Mongo) SetVersion(version int, dirty bool) error { 219 migrationsCollection := m.db.Collection(m.config.MigrationsCollection) 220 if err := migrationsCollection.Drop(context.TODO()); err != nil { 221 return &database.Error{OrigErr: err, Err: "drop migrations collection failed"} 222 } 223 _, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty}) 224 if err != nil { 225 return &database.Error{OrigErr: err, Err: "save version failed"} 226 } 227 return nil 228 } 229 230 func (m *Mongo) Version() (version int, dirty bool, err error) { 231 var versionInfo versionInfo 232 err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo) 233 switch { 234 case err == mongo.ErrNoDocuments: 235 return database.NilVersion, false, nil 236 case err != nil: 237 return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"} 238 default: 239 return versionInfo.Version, versionInfo.Dirty, nil 240 } 241 } 242 243 func (m *Mongo) Run(migration io.Reader) error { 244 migr, err := io.ReadAll(migration) 245 if err != nil { 246 return err 247 } 248 var cmds []bson.D 249 err = bson.UnmarshalExtJSON(migr, true, &cmds) 250 if err != nil { 251 return fmt.Errorf("unmarshaling json error: %s", err) 252 } 253 if m.config.TransactionMode { 254 if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil { 255 return err 256 } 257 } else { 258 if err := m.executeCommands(context.TODO(), cmds); err != nil { 259 return err 260 } 261 } 262 return nil 263 } 264 265 func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error { 266 err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error { 267 if err := sessionContext.StartTransaction(); err != nil { 268 return &database.Error{OrigErr: err, Err: "failed to start transaction"} 269 } 270 if err := m.executeCommands(sessionContext, cmds); err != nil { 271 // When command execution is failed, it's aborting transaction 272 // If you tried to call abortTransaction, it`s return error that transaction already aborted 273 return err 274 } 275 if err := sessionContext.CommitTransaction(sessionContext); err != nil { 276 return &database.Error{OrigErr: err, Err: "failed to commit transaction"} 277 } 278 return nil 279 }) 280 if err != nil { 281 return err 282 } 283 return nil 284 } 285 286 func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error { 287 for _, cmd := range cmds { 288 err := m.db.RunCommand(ctx, cmd).Err() 289 if err != nil { 290 return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)} 291 } 292 } 293 return nil 294 } 295 296 func (m *Mongo) Close() error { 297 return m.client.Disconnect(context.TODO()) 298 } 299 300 func (m *Mongo) Drop() error { 301 return m.db.Drop(context.TODO()) 302 } 303 304 func (m *Mongo) ensureLockTable() error { 305 indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes() 306 307 indexOptions := options.Index().SetUnique(true).SetName(LockIndexName) 308 _, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{ 309 Options: indexOptions, 310 Keys: findFilter{Key: -1}, 311 }) 312 if err != nil { 313 return err 314 } 315 return nil 316 } 317 318 // ensureVersionTable checks if versions table exists and, if not, creates it. 319 // Note that this function locks the database, which deviates from the usual 320 // convention of "caller locks" in the MongoDb type. 321 func (m *Mongo) ensureVersionTable() (err error) { 322 if err = m.Lock(); err != nil { 323 return err 324 } 325 326 defer func() { 327 if e := m.Unlock(); e != nil { 328 if err == nil { 329 err = e 330 } else { 331 err = multierror.Append(err, e) 332 } 333 } 334 }() 335 336 if err != nil { 337 return err 338 } 339 if _, _, err = m.Version(); err != nil { 340 return err 341 } 342 return nil 343 } 344 345 // Utilizes advisory locking on the config.LockingCollection collection 346 // This uses a unique index on the `locking_key` field. 347 func (m *Mongo) Lock() error { 348 return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error { 349 if !m.config.Locking.Enabled { 350 return nil 351 } 352 353 pid := os.Getpid() 354 hostname, err := os.Hostname() 355 if err != nil { 356 hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error()) 357 } 358 359 newLockObj := lockObj{ 360 Key: lockKeyUniqueValue, 361 Pid: pid, 362 Hostname: hostname, 363 CreatedAt: time.Now(), 364 } 365 operation := func() error { 366 timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout) 367 _, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj) 368 defer cancelFunc() 369 return err 370 } 371 exponentialBackOff := backoff.NewExponentialBackOff() 372 duration := time.Duration(m.config.Locking.Timeout) * time.Second 373 exponentialBackOff.MaxElapsedTime = duration 374 exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second 375 376 err = backoff.Retry(operation, exponentialBackOff) 377 if err != nil { 378 return database.ErrLocked 379 } 380 381 return nil 382 }) 383 } 384 385 func (m *Mongo) Unlock() error { 386 return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error { 387 if !m.config.Locking.Enabled { 388 return nil 389 } 390 391 filter := findFilter{ 392 Key: lockKeyUniqueValue, 393 } 394 395 ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout) 396 _, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter) 397 defer cancel() 398 399 if err != nil { 400 return err 401 } 402 return nil 403 }) 404 }