github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/mongodb/mongodb.go (about) 1 package mongodb 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "net/url" 8 "os" 9 "strconv" 10 "time" 11 12 "github.com/cenkalti/backoff/v4" 13 "github.com/golang-migrate/migrate/v4/database" 14 "github.com/hashicorp/go-multierror" 15 "go.mongodb.org/mongo-driver/bson" 16 "go.mongodb.org/mongo-driver/mongo" 17 "go.mongodb.org/mongo-driver/mongo/options" 18 "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" 19 "go.uber.org/atomic" 20 ) 21 22 func init() { 23 db := Mongo{} 24 database.Register("mongodb", &db) 25 database.Register("mongodb+srv", &db) 26 } 27 28 var DefaultMigrationsCollection = "schema_migrations" 29 30 const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default. 31 const lockKeyUniqueValue = 0 // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked). 32 const DefaultLockTimeout = 15 // the default maximum time to wait for a lock to be released. 33 const DefaultLockTimeoutInterval = 10 // the default maximum intervals time for the locking timout. 34 const DefaultAdvisoryLockingFlag = true // the default value for the advisory locking feature flag. Default is true. 35 const LockIndexName = "lock_unique_key" // the name of the index which adds unique constraint to the locking_key field. 36 const contextWaitTimeout = 5 * time.Second // how long to wait for the request to mongo to block/wait for. 37 38 var ( 39 ErrNoDatabaseName = fmt.Errorf("no database name") 40 ErrNilConfig = fmt.Errorf("no config") 41 ErrLockTimeoutConfigConflict = fmt.Errorf("both x-advisory-lock-timeout-interval and x-advisory-lock-timout-interval were specified") 42 ) 43 44 type Mongo struct { 45 client *mongo.Client 46 db *mongo.Database 47 config *Config 48 isLocked atomic.Bool 49 } 50 51 type Locking struct { 52 CollectionName string 53 Timeout int 54 Enabled bool 55 Interval int 56 } 57 type Config struct { 58 DatabaseName string 59 MigrationsCollection string 60 TransactionMode bool 61 Locking Locking 62 } 63 type versionInfo struct { 64 Version int `bson:"version"` 65 Dirty bool `bson:"dirty"` 66 } 67 68 type lockObj struct { 69 Key int `bson:"locking_key"` 70 Pid int `bson:"pid"` 71 Hostname string `bson:"hostname"` 72 CreatedAt time.Time `bson:"created_at"` 73 } 74 type findFilter struct { 75 Key int `bson:"locking_key"` 76 } 77 78 func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) { 79 if config == nil { 80 return nil, ErrNilConfig 81 } 82 if len(config.DatabaseName) == 0 { 83 return nil, ErrNoDatabaseName 84 } 85 if len(config.MigrationsCollection) == 0 { 86 config.MigrationsCollection = DefaultMigrationsCollection 87 } 88 if len(config.Locking.CollectionName) == 0 { 89 config.Locking.CollectionName = DefaultLockingCollection 90 } 91 if config.Locking.Timeout <= 0 { 92 config.Locking.Timeout = DefaultLockTimeout 93 } 94 if config.Locking.Interval <= 0 { 95 config.Locking.Interval = DefaultLockTimeoutInterval 96 } 97 98 mc := &Mongo{ 99 client: instance, 100 db: instance.Database(config.DatabaseName), 101 config: config, 102 } 103 104 if mc.config.Locking.Enabled { 105 if err := mc.ensureLockTable(); err != nil { 106 return nil, err 107 } 108 } 109 if err := mc.ensureVersionTable(); err != nil { 110 return nil, err 111 } 112 113 return mc, nil 114 } 115 116 func (m *Mongo) Open(dsn string) (database.Driver, error) { 117 // connstring is experimental package, but it used for parse connection string in mongo.Connect function 118 uri, err := connstring.Parse(dsn) 119 if err != nil { 120 return nil, err 121 } 122 if len(uri.Database) == 0 { 123 return nil, ErrNoDatabaseName 124 } 125 unknown := url.Values(uri.UnknownOptions) 126 127 migrationsCollection := unknown.Get("x-migrations-collection") 128 lockCollection := unknown.Get("x-advisory-lock-collection") 129 transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false) 130 if err != nil { 131 return nil, err 132 } 133 advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag) 134 if err != nil { 135 return nil, err 136 } 137 lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout) 138 if err != nil { 139 return nil, err 140 } 141 142 lockTimeoutIntervalValue := unknown.Get("x-advisory-lock-timeout-interval") 143 // The initial release had a typo for this argument but for backwards compatibility sake, we will keep supporting it 144 // and we will error out if both values are set. 145 lockTimeoutIntervalValueFromTypo := unknown.Get("x-advisory-lock-timout-interval") 146 147 lockTimeout := lockTimeoutIntervalValue 148 149 if lockTimeoutIntervalValue != "" && lockTimeoutIntervalValueFromTypo != "" { 150 return nil, ErrLockTimeoutConfigConflict 151 } else if lockTimeoutIntervalValueFromTypo != "" { 152 lockTimeout = lockTimeoutIntervalValueFromTypo 153 } 154 155 maxLockCheckInterval, err := parseInt(lockTimeout, DefaultLockTimeoutInterval) 156 157 if err != nil { 158 return nil, err 159 } 160 client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn)) 161 if err != nil { 162 return nil, err 163 } 164 165 if err = client.Ping(context.TODO(), nil); err != nil { 166 return nil, err 167 } 168 mc, err := WithInstance(client, &Config{ 169 DatabaseName: uri.Database, 170 MigrationsCollection: migrationsCollection, 171 TransactionMode: transactionMode, 172 Locking: Locking{ 173 CollectionName: lockCollection, 174 Timeout: lockingTimout, 175 Enabled: advisoryLockingFlag, 176 Interval: maxLockCheckInterval, 177 }, 178 }) 179 if err != nil { 180 return nil, err 181 } 182 return mc, nil 183 } 184 185 // Parse the url param, convert it to boolean 186 // returns error if param invalid. returns defaultValue if param not present 187 func parseBoolean(urlParam string, defaultValue bool) (bool, error) { 188 189 // if parameter passed, parse it (otherwise return default value) 190 if urlParam != "" { 191 result, err := strconv.ParseBool(urlParam) 192 if err != nil { 193 return false, err 194 } 195 return result, nil 196 } 197 198 // if no url Param passed, return default value 199 return defaultValue, nil 200 } 201 202 // Parse the url param, convert it to int 203 // returns error if param invalid. returns defaultValue if param not present 204 func parseInt(urlParam string, defaultValue int) (int, error) { 205 206 // if parameter passed, parse it (otherwise return default value) 207 if urlParam != "" { 208 result, err := strconv.Atoi(urlParam) 209 if err != nil { 210 return -1, err 211 } 212 return result, nil 213 } 214 215 // if no url Param passed, return default value 216 return defaultValue, nil 217 } 218 func (m *Mongo) SetVersion(version int, dirty bool) error { 219 migrationsCollection := m.db.Collection(m.config.MigrationsCollection) 220 if err := migrationsCollection.Drop(context.TODO()); err != nil { 221 return &database.Error{OrigErr: err, Err: "drop migrations collection failed"} 222 } 223 _, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty}) 224 if err != nil { 225 return &database.Error{OrigErr: err, Err: "save version failed"} 226 } 227 return nil 228 } 229 230 func (m *Mongo) Version() (version int, dirty bool, err error) { 231 var versionInfo versionInfo 232 err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo) 233 switch { 234 case err == mongo.ErrNoDocuments: 235 return database.NilVersion, false, nil 236 case err != nil: 237 return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"} 238 default: 239 return versionInfo.Version, versionInfo.Dirty, nil 240 } 241 } 242 243 func (m *Mongo) Run(migration io.Reader) error { 244 migr, err := io.ReadAll(migration) 245 if err != nil { 246 return err 247 } 248 var cmds []bson.D 249 err = bson.UnmarshalExtJSON(migr, true, &cmds) 250 if err != nil { 251 return fmt.Errorf("unmarshaling json error: %s", err) 252 } 253 if m.config.TransactionMode { 254 if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil { 255 return err 256 } 257 } else { 258 if err := m.executeCommands(context.TODO(), cmds); err != nil { 259 return err 260 } 261 } 262 return nil 263 } 264 265 func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error { 266 err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error { 267 if err := sessionContext.StartTransaction(); err != nil { 268 return &database.Error{OrigErr: err, Err: "failed to start transaction"} 269 } 270 if err := m.executeCommands(sessionContext, cmds); err != nil { 271 // When command execution is failed, it's aborting transaction 272 // If you tried to call abortTransaction, it`s return error that transaction already aborted 273 return err 274 } 275 if err := sessionContext.CommitTransaction(sessionContext); err != nil { 276 return &database.Error{OrigErr: err, Err: "failed to commit transaction"} 277 } 278 return nil 279 }) 280 if err != nil { 281 return err 282 } 283 return nil 284 } 285 286 func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error { 287 for _, cmd := range cmds { 288 err := m.db.RunCommand(ctx, cmd).Err() 289 if err != nil { 290 return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)} 291 } 292 } 293 return nil 294 } 295 296 func (m *Mongo) Close() error { 297 return m.client.Disconnect(context.TODO()) 298 } 299 300 func (m *Mongo) Drop() error { 301 return m.db.Drop(context.TODO()) 302 } 303 304 func (m *Mongo) ensureLockTable() error { 305 indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes() 306 307 indexOptions := options.Index().SetUnique(true).SetName(LockIndexName) 308 _, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{ 309 Options: indexOptions, 310 Keys: findFilter{Key: -1}, 311 }) 312 if err != nil { 313 return err 314 } 315 return nil 316 } 317 318 // ensureVersionTable checks if versions table exists and, if not, creates it. 319 // Note that this function locks the database, which deviates from the usual 320 // convention of "caller locks" in the MongoDb type. 321 func (m *Mongo) ensureVersionTable() (err error) { 322 if err = m.Lock(); err != nil { 323 return err 324 } 325 326 defer func() { 327 if e := m.Unlock(); e != nil { 328 if err == nil { 329 err = e 330 } else { 331 err = multierror.Append(err, e) 332 } 333 } 334 }() 335 336 if err != nil { 337 return err 338 } 339 if _, _, err = m.Version(); err != nil { 340 return err 341 } 342 return nil 343 } 344 345 // Utilizes advisory locking on the config.LockingCollection collection 346 // This uses a unique index on the `locking_key` field. 347 func (m *Mongo) Lock() error { 348 return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error { 349 if !m.config.Locking.Enabled { 350 return nil 351 } 352 353 pid := os.Getpid() 354 hostname, err := os.Hostname() 355 if err != nil { 356 hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error()) 357 } 358 359 newLockObj := lockObj{ 360 Key: lockKeyUniqueValue, 361 Pid: pid, 362 Hostname: hostname, 363 CreatedAt: time.Now(), 364 } 365 operation := func() error { 366 timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout) 367 _, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj) 368 defer cancelFunc() 369 return err 370 } 371 exponentialBackOff := backoff.NewExponentialBackOff() 372 duration := time.Duration(m.config.Locking.Timeout) * time.Second 373 exponentialBackOff.MaxElapsedTime = duration 374 exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second 375 376 err = backoff.Retry(operation, exponentialBackOff) 377 if err != nil { 378 return database.ErrLocked 379 } 380 381 return nil 382 }) 383 } 384 385 func (m *Mongo) Unlock() error { 386 return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error { 387 if !m.config.Locking.Enabled { 388 return nil 389 } 390 391 filter := findFilter{ 392 Key: lockKeyUniqueValue, 393 } 394 395 ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout) 396 _, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter) 397 defer cancel() 398 399 if err != nil { 400 return err 401 } 402 return nil 403 }) 404 }