github.com/getsynq/migrate/v4@v4.15.3-0.20220615182648-8e72daaa5ed9/database/mongodb/mongodb.go (about) 1 package mongodb 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "net/url" 9 os "os" 10 "strconv" 11 "time" 12 13 "github.com/cenkalti/backoff/v4" 14 "github.com/getsynq/migrate/v4/database" 15 "github.com/hashicorp/go-multierror" 16 "go.mongodb.org/mongo-driver/bson" 17 "go.mongodb.org/mongo-driver/mongo" 18 "go.mongodb.org/mongo-driver/mongo/options" 19 "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" 20 "go.uber.org/atomic" 21 ) 22 23 func init() { 24 db := Mongo{} 25 database.Register("mongodb", &db) 26 database.Register("mongodb+srv", &db) 27 } 28 29 var DefaultMigrationsCollection = "schema_migrations" 30 31 const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default. 32 const lockKeyUniqueValue = 0 // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked). 33 const DefaultLockTimeout = 15 // the default maximum time to wait for a lock to be released. 34 const DefaultLockTimeoutInterval = 10 // the default maximum intervals time for the locking timout. 35 const DefaultAdvisoryLockingFlag = true // the default value for the advisory locking feature flag. Default is true. 36 const LockIndexName = "lock_unique_key" // the name of the index which adds unique constraint to the locking_key field. 37 const contextWaitTimeout = 5 * time.Second // how long to wait for the request to mongo to block/wait for. 38 39 var ( 40 ErrNoDatabaseName = fmt.Errorf("no database name") 41 ErrNilConfig = fmt.Errorf("no config") 42 ErrLockTimeoutConfigConflict = fmt.Errorf("both x-advisory-lock-timeout-interval and x-advisory-lock-timout-interval were specified") 43 ) 44 45 type Mongo struct { 46 client *mongo.Client 47 db *mongo.Database 48 config *Config 49 isLocked atomic.Bool 50 } 51 52 type Locking struct { 53 CollectionName string 54 Timeout int 55 Enabled bool 56 Interval int 57 } 58 type Config struct { 59 DatabaseName string 60 MigrationsCollection string 61 TransactionMode bool 62 Locking Locking 63 } 64 type versionInfo struct { 65 Version int `bson:"version"` 66 Dirty bool `bson:"dirty"` 67 } 68 69 type lockObj struct { 70 Key int `bson:"locking_key"` 71 Pid int `bson:"pid"` 72 Hostname string `bson:"hostname"` 73 CreatedAt time.Time `bson:"created_at"` 74 } 75 type findFilter struct { 76 Key int `bson:"locking_key"` 77 } 78 79 func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) { 80 if config == nil { 81 return nil, ErrNilConfig 82 } 83 if len(config.DatabaseName) == 0 { 84 return nil, ErrNoDatabaseName 85 } 86 if len(config.MigrationsCollection) == 0 { 87 config.MigrationsCollection = DefaultMigrationsCollection 88 } 89 if len(config.Locking.CollectionName) == 0 { 90 config.Locking.CollectionName = DefaultLockingCollection 91 } 92 if config.Locking.Timeout <= 0 { 93 config.Locking.Timeout = DefaultLockTimeout 94 } 95 if config.Locking.Interval <= 0 { 96 config.Locking.Interval = DefaultLockTimeoutInterval 97 } 98 99 mc := &Mongo{ 100 client: instance, 101 db: instance.Database(config.DatabaseName), 102 config: config, 103 } 104 105 if mc.config.Locking.Enabled { 106 if err := mc.ensureLockTable(); err != nil { 107 return nil, err 108 } 109 } 110 if err := mc.ensureVersionTable(); err != nil { 111 return nil, err 112 } 113 114 return mc, nil 115 } 116 117 func (m *Mongo) Open(dsn string) (database.Driver, error) { 118 //connstring is experimental package, but it used for parse connection string in mongo.Connect function 119 uri, err := connstring.Parse(dsn) 120 if err != nil { 121 return nil, err 122 } 123 if len(uri.Database) == 0 { 124 return nil, ErrNoDatabaseName 125 } 126 unknown := url.Values(uri.UnknownOptions) 127 128 migrationsCollection := unknown.Get("x-migrations-collection") 129 lockCollection := unknown.Get("x-advisory-lock-collection") 130 transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false) 131 if err != nil { 132 return nil, err 133 } 134 advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag) 135 if err != nil { 136 return nil, err 137 } 138 lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout) 139 if err != nil { 140 return nil, err 141 } 142 143 lockTimeoutIntervalValue := unknown.Get("x-advisory-lock-timeout-interval") 144 // The initial release had a typo for this argument but for backwards compatibility sake, we will keep supporting it 145 // and we will error out if both values are set. 146 lockTimeoutIntervalValueFromTypo := unknown.Get("x-advisory-lock-timout-interval") 147 148 lockTimeout := lockTimeoutIntervalValue 149 150 if lockTimeoutIntervalValue != "" && lockTimeoutIntervalValueFromTypo != "" { 151 return nil, ErrLockTimeoutConfigConflict 152 } else if lockTimeoutIntervalValueFromTypo != "" { 153 lockTimeout = lockTimeoutIntervalValueFromTypo 154 } 155 156 maxLockCheckInterval, err := parseInt(lockTimeout, DefaultLockTimeoutInterval) 157 158 if err != nil { 159 return nil, err 160 } 161 client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn)) 162 if err != nil { 163 return nil, err 164 } 165 166 if err = client.Ping(context.TODO(), nil); err != nil { 167 return nil, err 168 } 169 mc, err := WithInstance(client, &Config{ 170 DatabaseName: uri.Database, 171 MigrationsCollection: migrationsCollection, 172 TransactionMode: transactionMode, 173 Locking: Locking{ 174 CollectionName: lockCollection, 175 Timeout: lockingTimout, 176 Enabled: advisoryLockingFlag, 177 Interval: maxLockCheckInterval, 178 }, 179 }) 180 if err != nil { 181 return nil, err 182 } 183 return mc, nil 184 } 185 186 //Parse the url param, convert it to boolean 187 // returns error if param invalid. returns defaultValue if param not present 188 func parseBoolean(urlParam string, defaultValue bool) (bool, error) { 189 190 // if parameter passed, parse it (otherwise return default value) 191 if urlParam != "" { 192 result, err := strconv.ParseBool(urlParam) 193 if err != nil { 194 return false, err 195 } 196 return result, nil 197 } 198 199 // if no url Param passed, return default value 200 return defaultValue, nil 201 } 202 203 //Parse the url param, convert it to int 204 // returns error if param invalid. returns defaultValue if param not present 205 func parseInt(urlParam string, defaultValue int) (int, error) { 206 207 // if parameter passed, parse it (otherwise return default value) 208 if urlParam != "" { 209 result, err := strconv.Atoi(urlParam) 210 if err != nil { 211 return -1, err 212 } 213 return result, nil 214 } 215 216 // if no url Param passed, return default value 217 return defaultValue, nil 218 } 219 func (m *Mongo) SetVersion(version int, dirty bool) error { 220 migrationsCollection := m.db.Collection(m.config.MigrationsCollection) 221 if err := migrationsCollection.Drop(context.TODO()); err != nil { 222 return &database.Error{OrigErr: err, Err: "drop migrations collection failed"} 223 } 224 _, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty}) 225 if err != nil { 226 return &database.Error{OrigErr: err, Err: "save version failed"} 227 } 228 return nil 229 } 230 231 func (m *Mongo) Version() (version int, dirty bool, err error) { 232 var versionInfo versionInfo 233 err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo) 234 switch { 235 case err == mongo.ErrNoDocuments: 236 return database.NilVersion, false, nil 237 case err != nil: 238 return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"} 239 default: 240 return versionInfo.Version, versionInfo.Dirty, nil 241 } 242 } 243 244 func (m *Mongo) Run(migration io.Reader) error { 245 migr, err := ioutil.ReadAll(migration) 246 if err != nil { 247 return err 248 } 249 var cmds []bson.D 250 err = bson.UnmarshalExtJSON(migr, true, &cmds) 251 if err != nil { 252 return fmt.Errorf("unmarshaling json error: %s", err) 253 } 254 if m.config.TransactionMode { 255 if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil { 256 return err 257 } 258 } else { 259 if err := m.executeCommands(context.TODO(), cmds); err != nil { 260 return err 261 } 262 } 263 return nil 264 } 265 266 func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error { 267 err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error { 268 if err := sessionContext.StartTransaction(); err != nil { 269 return &database.Error{OrigErr: err, Err: "failed to start transaction"} 270 } 271 if err := m.executeCommands(sessionContext, cmds); err != nil { 272 //When command execution is failed, it's aborting transaction 273 //If you tried to call abortTransaction, it`s return error that transaction already aborted 274 return err 275 } 276 if err := sessionContext.CommitTransaction(sessionContext); err != nil { 277 return &database.Error{OrigErr: err, Err: "failed to commit transaction"} 278 } 279 return nil 280 }) 281 if err != nil { 282 return err 283 } 284 return nil 285 } 286 287 func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error { 288 for _, cmd := range cmds { 289 err := m.db.RunCommand(ctx, cmd).Err() 290 if err != nil { 291 return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)} 292 } 293 } 294 return nil 295 } 296 297 func (m *Mongo) Close() error { 298 return m.client.Disconnect(context.TODO()) 299 } 300 301 func (m *Mongo) Drop() error { 302 return m.db.Drop(context.TODO()) 303 } 304 305 func (m *Mongo) ensureLockTable() error { 306 indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes() 307 308 indexOptions := options.Index().SetUnique(true).SetName(LockIndexName) 309 _, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{ 310 Options: indexOptions, 311 Keys: findFilter{Key: -1}, 312 }) 313 if err != nil { 314 return err 315 } 316 return nil 317 } 318 319 // ensureVersionTable checks if versions table exists and, if not, creates it. 320 // Note that this function locks the database, which deviates from the usual 321 // convention of "caller locks" in the MongoDb type. 322 func (m *Mongo) ensureVersionTable() (err error) { 323 if err = m.Lock(); err != nil { 324 return err 325 } 326 327 defer func() { 328 if e := m.Unlock(); e != nil { 329 if err == nil { 330 err = e 331 } else { 332 err = multierror.Append(err, e) 333 } 334 } 335 }() 336 337 if err != nil { 338 return err 339 } 340 if _, _, err = m.Version(); err != nil { 341 return err 342 } 343 return nil 344 } 345 346 // Utilizes advisory locking on the config.LockingCollection collection 347 // This uses a unique index on the `locking_key` field. 348 func (m *Mongo) Lock() error { 349 return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error { 350 if !m.config.Locking.Enabled { 351 return nil 352 } 353 354 pid := os.Getpid() 355 hostname, err := os.Hostname() 356 if err != nil { 357 hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error()) 358 } 359 360 newLockObj := lockObj{ 361 Key: lockKeyUniqueValue, 362 Pid: pid, 363 Hostname: hostname, 364 CreatedAt: time.Now(), 365 } 366 operation := func() error { 367 timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout) 368 _, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj) 369 defer cancelFunc() 370 return err 371 } 372 exponentialBackOff := backoff.NewExponentialBackOff() 373 duration := time.Duration(m.config.Locking.Timeout) * time.Second 374 exponentialBackOff.MaxElapsedTime = duration 375 exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second 376 377 err = backoff.Retry(operation, exponentialBackOff) 378 if err != nil { 379 return database.ErrLocked 380 } 381 382 return nil 383 }) 384 } 385 386 func (m *Mongo) Unlock() error { 387 return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error { 388 if !m.config.Locking.Enabled { 389 return nil 390 } 391 392 filter := findFilter{ 393 Key: lockKeyUniqueValue, 394 } 395 396 ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout) 397 _, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter) 398 defer cancel() 399 400 if err != nil { 401 return err 402 } 403 return nil 404 }) 405 }