github.com/mrqzzz/migrate@v5.1.7+incompatible/database/mongodb/mongodb.go (about)

     1  package mongodb
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"net/url"
     9  	"strconv"
    10  
    11  	"github.com/golang-migrate/migrate/v4"
    12  	"github.com/golang-migrate/migrate/v4/database"
    13  	"github.com/mongodb/mongo-go-driver/bson"
    14  	"github.com/mongodb/mongo-go-driver/mongo"
    15  	"github.com/mongodb/mongo-go-driver/mongo/options"
    16  	"github.com/mongodb/mongo-go-driver/x/network/connstring"
    17  )
    18  
    19  func init() {
    20  	database.Register("mongodb", &Mongo{})
    21  }
    22  
    23  var DefaultMigrationsCollection = "schema_migrations"
    24  
    25  var (
    26  	ErrNoDatabaseName = fmt.Errorf("no database name")
    27  	ErrNilConfig      = fmt.Errorf("no config")
    28  )
    29  
    30  type Mongo struct {
    31  	client *mongo.Client
    32  	db     *mongo.Database
    33  
    34  	config *Config
    35  }
    36  
    37  type Config struct {
    38  	DatabaseName         string
    39  	MigrationsCollection string
    40  	TransactionMode      bool
    41  }
    42  
    43  type versionInfo struct {
    44  	Version int  `bson:"version"`
    45  	Dirty   bool `bson:"dirty"`
    46  }
    47  
    48  func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
    49  	if config == nil {
    50  		return nil, ErrNilConfig
    51  	}
    52  	if len(config.DatabaseName) == 0 {
    53  		return nil, ErrNoDatabaseName
    54  	}
    55  	if len(config.MigrationsCollection) == 0 {
    56  		config.MigrationsCollection = DefaultMigrationsCollection
    57  	}
    58  	mc := &Mongo{
    59  		client: instance,
    60  		db:     instance.Database(config.DatabaseName),
    61  		config: config,
    62  	}
    63  	return mc, nil
    64  }
    65  
    66  func (m *Mongo) Open(dsn string) (database.Driver, error) {
    67  	//connsting is experimental package, but it used for parse connection string in mongo.Connect function
    68  	uri, err := connstring.Parse(dsn)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	if len(uri.Database) == 0 {
    73  		return nil, ErrNoDatabaseName
    74  	}
    75  
    76  	purl, err := url.Parse(dsn)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	migrationsCollection := purl.Query().Get("x-migrations-collection")
    81  	if len(migrationsCollection) == 0 {
    82  		migrationsCollection = DefaultMigrationsCollection
    83  	}
    84  
    85  	transactionMode, _ := strconv.ParseBool(purl.Query().Get("x-transaction-mode"))
    86  
    87  	q := migrate.FilterCustomQuery(purl)
    88  	q.Scheme = "mongodb"
    89  
    90  	client, err := mongo.Connect(context.TODO(), q.String())
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  	if err = client.Ping(context.TODO(), nil); err != nil {
    95  		return nil, err
    96  	}
    97  	mc, err := WithInstance(client, &Config{
    98  		DatabaseName:         uri.Database,
    99  		MigrationsCollection: migrationsCollection,
   100  		TransactionMode:      transactionMode,
   101  	})
   102  	if err != nil {
   103  		return nil, err
   104  	}
   105  	return mc, nil
   106  }
   107  
   108  func (m *Mongo) SetVersion(version int, dirty bool) error {
   109  	migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
   110  	var tr = true
   111  	filt := bson.D{{"version", bson.D{{"$exists", false}}}}
   112  	if res := migrationsCollection.FindOneAndUpdate(context.TODO(), filt, bson.M{"version": version, "dirty": dirty}, &options.FindOneAndUpdateOptions{Upsert: &tr}); res.Err() != nil {
   113  		return &database.Error{OrigErr: res.Err(), Err: "drop migrations collection failed"}
   114  	}
   115  	//if err := migrationsCollection.Drop(context.TODO()); err != nil {
   116  	//	return &database.Error{OrigErr: err, Err: "drop migrations collection failed"}
   117  	//}
   118  	//_, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty})
   119  	//if err != nil {
   120  	//	return &database.Error{OrigErr: err, Err: "save version failed"}
   121  	//}
   122  	return nil
   123  }
   124  
   125  func (m *Mongo) Version() (version int, dirty bool, err error) {
   126  	var versionInfo versionInfo
   127  	err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo)
   128  	switch {
   129  	case err == mongo.ErrNoDocuments:
   130  		return database.NilVersion, false, nil
   131  	case err != nil:
   132  		return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"}
   133  	default:
   134  		return versionInfo.Version, versionInfo.Dirty, nil
   135  	}
   136  }
   137  
   138  func (m *Mongo) Run(migration io.Reader) error {
   139  	migr, err := ioutil.ReadAll(migration)
   140  	if err != nil {
   141  		return err
   142  	}
   143  	var cmds []bson.D
   144  	err = bson.UnmarshalExtJSON(migr, true, &cmds)
   145  	if err != nil {
   146  		return fmt.Errorf("unmarshaling json error: %s", err)
   147  	}
   148  	if m.config.TransactionMode {
   149  		if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil {
   150  			return err
   151  		}
   152  	} else {
   153  		if err := m.executeCommands(context.TODO(), cmds); err != nil {
   154  			return err
   155  		}
   156  	}
   157  	return nil
   158  }
   159  
   160  func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error {
   161  	err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error {
   162  		if err := sessionContext.StartTransaction(); err != nil {
   163  			return &database.Error{OrigErr: err, Err: "failed to start transaction"}
   164  		}
   165  		if err := m.executeCommands(sessionContext, cmds); err != nil {
   166  			//When command execution is failed, it's aborting transaction
   167  			//If you tried to call abortTransaction, it`s return error that transaction already aborted
   168  			return err
   169  		}
   170  		if err := sessionContext.CommitTransaction(sessionContext); err != nil {
   171  			return &database.Error{OrigErr: err, Err: "failed to commit transaction"}
   172  		}
   173  		return nil
   174  	})
   175  	if err != nil {
   176  		return err
   177  	}
   178  	return nil
   179  }
   180  
   181  func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error {
   182  	for _, cmd := range cmds {
   183  		err := m.db.RunCommand(ctx, cmd).Err()
   184  		if err != nil {
   185  			return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)}
   186  		}
   187  	}
   188  	return nil
   189  }
   190  
   191  func (m *Mongo) Close() error {
   192  	return m.client.Disconnect(context.TODO())
   193  }
   194  
   195  func (m *Mongo) Drop() error {
   196  	return m.db.Drop(context.TODO())
   197  }
   198  
   199  func (m *Mongo) Lock() error {
   200  	return nil
   201  }
   202  
   203  func (m *Mongo) Unlock() error {
   204  	return nil
   205  }