github.com/nagyist/migrate/v4@v4.14.6/database/mongodb/mongodb.go (about) 1 package mongodb 2 3 import ( 4 "context" 5 "fmt" 6 "github.com/cenkalti/backoff/v4" 7 "github.com/golang-migrate/migrate/v4/database" 8 "github.com/hashicorp/go-multierror" 9 "go.mongodb.org/mongo-driver/bson" 10 "go.mongodb.org/mongo-driver/mongo" 11 "go.mongodb.org/mongo-driver/mongo/options" 12 "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" 13 "io" 14 "io/ioutil" 15 "net/url" 16 os "os" 17 "strconv" 18 "time" 19 ) 20 21 func init() { 22 db := Mongo{} 23 database.Register("mongodb", &db) 24 database.Register("mongodb+srv", &db) 25 } 26 27 var DefaultMigrationsCollection = "schema_migrations" 28 29 const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default. 30 const lockKeyUniqueValue = 0 // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked). 31 const DefaultLockTimeout = 15 // the default maximum time to wait for a lock to be released. 32 const DefaultLockTimeoutInterval = 10 // the default maximum intervals time for the locking timout. 33 const DefaultAdvisoryLockingFlag = true // the default value for the advisory locking feature flag. Default is true. 34 const LockIndexName = "lock_unique_key" // the name of the index which adds unique constraint to the locking_key field. 35 const contextWaitTimeout = 5 * time.Second // how long to wait for the request to mongo to block/wait for. 36 37 var ( 38 ErrNoDatabaseName = fmt.Errorf("no database name") 39 ErrNilConfig = fmt.Errorf("no config") 40 ) 41 42 type Mongo struct { 43 client *mongo.Client 44 db *mongo.Database 45 config *Config 46 } 47 48 type Locking struct { 49 CollectionName string 50 Timeout int 51 Enabled bool 52 Interval int 53 } 54 type Config struct { 55 DatabaseName string 56 MigrationsCollection string 57 TransactionMode bool 58 Locking Locking 59 } 60 type versionInfo struct { 61 Version int `bson:"version"` 62 Dirty bool `bson:"dirty"` 63 } 64 65 type lockObj struct { 66 Key int `bson:"locking_key"` 67 Pid int `bson:"pid"` 68 Hostname string `bson:"hostname"` 69 CreatedAt time.Time `bson:"created_at"` 70 } 71 type findFilter struct { 72 Key int `bson:"locking_key"` 73 } 74 75 func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) { 76 if config == nil { 77 return nil, ErrNilConfig 78 } 79 if len(config.DatabaseName) == 0 { 80 return nil, ErrNoDatabaseName 81 } 82 if len(config.MigrationsCollection) == 0 { 83 config.MigrationsCollection = DefaultMigrationsCollection 84 } 85 if len(config.Locking.CollectionName) == 0 { 86 config.Locking.CollectionName = DefaultLockingCollection 87 } 88 if config.Locking.Timeout <= 0 { 89 config.Locking.Timeout = DefaultLockTimeout 90 } 91 if config.Locking.Interval <= 0 { 92 config.Locking.Interval = DefaultLockTimeoutInterval 93 } 94 95 mc := &Mongo{ 96 client: instance, 97 db: instance.Database(config.DatabaseName), 98 config: config, 99 } 100 101 if mc.config.Locking.Enabled { 102 if err := mc.ensureLockTable(); err != nil { 103 return nil, err 104 } 105 } 106 if err := mc.ensureVersionTable(); err != nil { 107 return nil, err 108 } 109 110 return mc, nil 111 } 112 113 func (m *Mongo) Open(dsn string) (database.Driver, error) { 114 //connstring is experimental package, but it used for parse connection string in mongo.Connect function 115 uri, err := connstring.Parse(dsn) 116 if err != nil { 117 return nil, err 118 } 119 if len(uri.Database) == 0 { 120 return nil, ErrNoDatabaseName 121 } 122 unknown := url.Values(uri.UnknownOptions) 123 124 migrationsCollection := unknown.Get("x-migrations-collection") 125 lockCollection := unknown.Get("x-advisory-lock-collection") 126 transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false) 127 if err != nil { 128 return nil, err 129 } 130 advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag) 131 if err != nil { 132 return nil, err 133 } 134 lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout) 135 if err != nil { 136 return nil, err 137 } 138 maxLockingIntervals, err := parseInt(unknown.Get("x-advisory-lock-timout-interval"), DefaultLockTimeoutInterval) 139 if err != nil { 140 return nil, err 141 } 142 client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn)) 143 if err != nil { 144 return nil, err 145 } 146 147 if err = client.Ping(context.TODO(), nil); err != nil { 148 return nil, err 149 } 150 mc, err := WithInstance(client, &Config{ 151 DatabaseName: uri.Database, 152 MigrationsCollection: migrationsCollection, 153 TransactionMode: transactionMode, 154 Locking: Locking{ 155 CollectionName: lockCollection, 156 Timeout: lockingTimout, 157 Enabled: advisoryLockingFlag, 158 Interval: maxLockingIntervals, 159 }, 160 }) 161 if err != nil { 162 return nil, err 163 } 164 return mc, nil 165 } 166 167 //Parse the url param, convert it to boolean 168 // returns error if param invalid. returns defaultValue if param not present 169 func parseBoolean(urlParam string, defaultValue bool) (bool, error) { 170 171 // if parameter passed, parse it (otherwise return default value) 172 if urlParam != "" { 173 result, err := strconv.ParseBool(urlParam) 174 if err != nil { 175 return false, err 176 } 177 return result, nil 178 } 179 180 // if no url Param passed, return default value 181 return defaultValue, nil 182 } 183 184 //Parse the url param, convert it to int 185 // returns error if param invalid. returns defaultValue if param not present 186 func parseInt(urlParam string, defaultValue int) (int, error) { 187 188 // if parameter passed, parse it (otherwise return default value) 189 if urlParam != "" { 190 result, err := strconv.Atoi(urlParam) 191 if err != nil { 192 return -1, err 193 } 194 return result, nil 195 } 196 197 // if no url Param passed, return default value 198 return defaultValue, nil 199 } 200 func (m *Mongo) SetVersion(version int, dirty bool) error { 201 migrationsCollection := m.db.Collection(m.config.MigrationsCollection) 202 if err := migrationsCollection.Drop(context.TODO()); err != nil { 203 return &database.Error{OrigErr: err, Err: "drop migrations collection failed"} 204 } 205 _, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty}) 206 if err != nil { 207 return &database.Error{OrigErr: err, Err: "save version failed"} 208 } 209 return nil 210 } 211 212 func (m *Mongo) Version() (version int, dirty bool, err error) { 213 var versionInfo versionInfo 214 err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo) 215 switch { 216 case err == mongo.ErrNoDocuments: 217 return database.NilVersion, false, nil 218 case err != nil: 219 return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"} 220 default: 221 return versionInfo.Version, versionInfo.Dirty, nil 222 } 223 } 224 225 func (m *Mongo) Run(migration io.Reader) error { 226 migr, err := ioutil.ReadAll(migration) 227 if err != nil { 228 return err 229 } 230 var cmds []bson.D 231 err = bson.UnmarshalExtJSON(migr, true, &cmds) 232 if err != nil { 233 return fmt.Errorf("unmarshaling json error: %s", err) 234 } 235 if m.config.TransactionMode { 236 if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil { 237 return err 238 } 239 } else { 240 if err := m.executeCommands(context.TODO(), cmds); err != nil { 241 return err 242 } 243 } 244 return nil 245 } 246 247 func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error { 248 err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error { 249 if err := sessionContext.StartTransaction(); err != nil { 250 return &database.Error{OrigErr: err, Err: "failed to start transaction"} 251 } 252 if err := m.executeCommands(sessionContext, cmds); err != nil { 253 //When command execution is failed, it's aborting transaction 254 //If you tried to call abortTransaction, it`s return error that transaction already aborted 255 return err 256 } 257 if err := sessionContext.CommitTransaction(sessionContext); err != nil { 258 return &database.Error{OrigErr: err, Err: "failed to commit transaction"} 259 } 260 return nil 261 }) 262 if err != nil { 263 return err 264 } 265 return nil 266 } 267 268 func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error { 269 for _, cmd := range cmds { 270 err := m.db.RunCommand(ctx, cmd).Err() 271 if err != nil { 272 return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)} 273 } 274 } 275 return nil 276 } 277 278 func (m *Mongo) Close() error { 279 return m.client.Disconnect(context.TODO()) 280 } 281 282 func (m *Mongo) Drop() error { 283 return m.db.Drop(context.TODO()) 284 } 285 286 func (m *Mongo) ensureLockTable() error { 287 indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes() 288 289 indexOptions := options.Index().SetUnique(true).SetName(LockIndexName) 290 _, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{ 291 Options: indexOptions, 292 Keys: findFilter{Key: -1}, 293 }) 294 if err != nil { 295 return err 296 } 297 return nil 298 } 299 300 // ensureVersionTable checks if versions table exists and, if not, creates it. 301 // Note that this function locks the database, which deviates from the usual 302 // convention of "caller locks" in the MongoDb type. 303 func (m *Mongo) ensureVersionTable() (err error) { 304 if err = m.Lock(); err != nil { 305 return err 306 } 307 308 defer func() { 309 if e := m.Unlock(); e != nil { 310 if err == nil { 311 err = e 312 } else { 313 err = multierror.Append(err, e) 314 } 315 } 316 }() 317 318 if err != nil { 319 return err 320 } 321 if _, _, err = m.Version(); err != nil { 322 return err 323 } 324 return nil 325 } 326 327 // Utilizes advisory locking on the config.LockingCollection collection 328 // This uses a unique index on the `locking_key` field. 329 func (m *Mongo) Lock() error { 330 if !m.config.Locking.Enabled { 331 return nil 332 } 333 pid := os.Getpid() 334 hostname, err := os.Hostname() 335 if err != nil { 336 hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error()) 337 } 338 339 newLockObj := lockObj{ 340 Key: lockKeyUniqueValue, 341 Pid: pid, 342 Hostname: hostname, 343 CreatedAt: time.Now(), 344 } 345 operation := func() error { 346 timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout) 347 _, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj) 348 defer cancelFunc() 349 return err 350 } 351 exponentialBackOff := backoff.NewExponentialBackOff() 352 duration := time.Duration(m.config.Locking.Timeout) * time.Second 353 exponentialBackOff.MaxElapsedTime = duration 354 exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second 355 356 err = backoff.Retry(operation, exponentialBackOff) 357 if err != nil { 358 return database.ErrLocked 359 } 360 361 return nil 362 363 } 364 func (m *Mongo) Unlock() error { 365 if !m.config.Locking.Enabled { 366 return nil 367 } 368 369 filter := findFilter{ 370 Key: lockKeyUniqueValue, 371 } 372 373 ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout) 374 _, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter) 375 defer cancel() 376 377 if err != nil { 378 return err 379 } 380 return nil 381 }