github.com/dynastymasra/migrate/v4@v4.11.0/database/mongodb/mongodb.go (about)

     1  package mongodb
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"net/url"
     9  	"strconv"
    10  
    11  	"github.com/golang-migrate/migrate/v4/database"
    12  	"go.mongodb.org/mongo-driver/bson"
    13  	"go.mongodb.org/mongo-driver/mongo"
    14  	"go.mongodb.org/mongo-driver/mongo/options"
    15  	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
    16  )
    17  
    18  func init() {
    19  	database.Register("mongodb", &Mongo{})
    20  }
    21  
    22  var DefaultMigrationsCollection = "schema_migrations"
    23  
    24  var (
    25  	ErrNoDatabaseName = fmt.Errorf("no database name")
    26  	ErrNilConfig      = fmt.Errorf("no config")
    27  )
    28  
    29  type Mongo struct {
    30  	client *mongo.Client
    31  	db     *mongo.Database
    32  
    33  	config *Config
    34  }
    35  
    36  type Config struct {
    37  	DatabaseName         string
    38  	MigrationsCollection string
    39  	TransactionMode      bool
    40  }
    41  
    42  type versionInfo struct {
    43  	Version int  `bson:"version"`
    44  	Dirty   bool `bson:"dirty"`
    45  }
    46  
    47  func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
    48  	if config == nil {
    49  		return nil, ErrNilConfig
    50  	}
    51  	if len(config.DatabaseName) == 0 {
    52  		return nil, ErrNoDatabaseName
    53  	}
    54  	if len(config.MigrationsCollection) == 0 {
    55  		config.MigrationsCollection = DefaultMigrationsCollection
    56  	}
    57  	mc := &Mongo{
    58  		client: instance,
    59  		db:     instance.Database(config.DatabaseName),
    60  		config: config,
    61  	}
    62  
    63  	return mc, nil
    64  }
    65  
    66  func (m *Mongo) Open(dsn string) (database.Driver, error) {
    67  	//connsting is experimental package, but it used for parse connection string in mongo.Connect function
    68  	uri, err := connstring.Parse(dsn)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	if len(uri.Database) == 0 {
    73  		return nil, ErrNoDatabaseName
    74  	}
    75  
    76  	unknown := url.Values(uri.UnknownOptions)
    77  
    78  	migrationsCollection := unknown.Get("x-migrations-collection")
    79  	transactionMode, _ := strconv.ParseBool(unknown.Get("x-transaction-mode"))
    80  
    81  	client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn))
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	if err = client.Ping(context.TODO(), nil); err != nil {
    86  		return nil, err
    87  	}
    88  	mc, err := WithInstance(client, &Config{
    89  		DatabaseName:         uri.Database,
    90  		MigrationsCollection: migrationsCollection,
    91  		TransactionMode:      transactionMode,
    92  	})
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	return mc, nil
    97  }
    98  
    99  func (m *Mongo) SetVersion(version int, dirty bool) error {
   100  	migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
   101  	if err := migrationsCollection.Drop(context.TODO()); err != nil {
   102  		return &database.Error{OrigErr: err, Err: "drop migrations collection failed"}
   103  	}
   104  	_, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty})
   105  	if err != nil {
   106  		return &database.Error{OrigErr: err, Err: "save version failed"}
   107  	}
   108  	return nil
   109  }
   110  
   111  func (m *Mongo) Version() (version int, dirty bool, err error) {
   112  	var versionInfo versionInfo
   113  	err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo)
   114  	switch {
   115  	case err == mongo.ErrNoDocuments:
   116  		return database.NilVersion, false, nil
   117  	case err != nil:
   118  		return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"}
   119  	default:
   120  		return versionInfo.Version, versionInfo.Dirty, nil
   121  	}
   122  }
   123  
   124  func (m *Mongo) Run(migration io.Reader) error {
   125  	migr, err := ioutil.ReadAll(migration)
   126  	if err != nil {
   127  		return err
   128  	}
   129  	var cmds []bson.D
   130  	err = bson.UnmarshalExtJSON(migr, true, &cmds)
   131  	if err != nil {
   132  		return fmt.Errorf("unmarshaling json error: %s", err)
   133  	}
   134  	if m.config.TransactionMode {
   135  		if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil {
   136  			return err
   137  		}
   138  	} else {
   139  		if err := m.executeCommands(context.TODO(), cmds); err != nil {
   140  			return err
   141  		}
   142  	}
   143  	return nil
   144  }
   145  
   146  func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error {
   147  	err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error {
   148  		if err := sessionContext.StartTransaction(); err != nil {
   149  			return &database.Error{OrigErr: err, Err: "failed to start transaction"}
   150  		}
   151  		if err := m.executeCommands(sessionContext, cmds); err != nil {
   152  			//When command execution is failed, it's aborting transaction
   153  			//If you tried to call abortTransaction, it`s return error that transaction already aborted
   154  			return err
   155  		}
   156  		if err := sessionContext.CommitTransaction(sessionContext); err != nil {
   157  			return &database.Error{OrigErr: err, Err: "failed to commit transaction"}
   158  		}
   159  		return nil
   160  	})
   161  	if err != nil {
   162  		return err
   163  	}
   164  	return nil
   165  }
   166  
   167  func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error {
   168  	for _, cmd := range cmds {
   169  		err := m.db.RunCommand(ctx, cmd).Err()
   170  		if err != nil {
   171  			return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)}
   172  		}
   173  	}
   174  	return nil
   175  }
   176  
   177  func (m *Mongo) Close() error {
   178  	return m.client.Disconnect(context.TODO())
   179  }
   180  
   181  func (m *Mongo) Drop() error {
   182  	return m.db.Drop(context.TODO())
   183  }
   184  
   185  func (m *Mongo) Lock() error {
   186  	return nil
   187  }
   188  
   189  func (m *Mongo) Unlock() error {
   190  	return nil
   191  }