github.com/fr-nvriep/migrate/v4@v4.3.2/database/mongodb/mongodb.go (about)

     1  package mongodb
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"net/url"
     9  	"strconv"
    10  
    11  	"github.com/fr-nvriep/migrate/v4"
    12  	"github.com/fr-nvriep/migrate/v4/database"
    13  	"github.com/mongodb/mongo-go-driver/bson"
    14  	"github.com/mongodb/mongo-go-driver/mongo"
    15  	"github.com/mongodb/mongo-go-driver/x/network/connstring"
    16  )
    17  
    18  func init() {
    19  	database.Register("mongodb", &Mongo{})
    20  }
    21  
    // Package-level defaults and sentinel errors for the MongoDB driver.
var (
	// DefaultMigrationsCollection is the collection that stores the
	// migration version when Config.MigrationsCollection is left empty.
	DefaultMigrationsCollection = "schema_migrations"

	// ErrNoDatabaseName is returned by Open and WithInstance when no
	// database name is available.
	ErrNoDatabaseName = fmt.Errorf("no database name")

	// ErrNilConfig is returned by WithInstance when config is nil.
	ErrNilConfig = fmt.Errorf("no config")
)
    28  
    29  type Mongo struct {
    30  	client *mongo.Client
    31  	db     *mongo.Database
    32  
    33  	config *Config
    34  }
    35  
    36  type Config struct {
    37  	DatabaseName         string
    38  	MigrationsCollection string
    39  	TransactionMode      bool
    40  }
    41  
    42  type versionInfo struct {
    43  	Version int  `bson:"version"`
    44  	Dirty   bool `bson:"dirty"`
    45  }
    46  
    47  func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
    48  	if config == nil {
    49  		return nil, ErrNilConfig
    50  	}
    51  	if len(config.DatabaseName) == 0 {
    52  		return nil, ErrNoDatabaseName
    53  	}
    54  	if len(config.MigrationsCollection) == 0 {
    55  		config.MigrationsCollection = DefaultMigrationsCollection
    56  	}
    57  	mc := &Mongo{
    58  		client: instance,
    59  		db:     instance.Database(config.DatabaseName),
    60  		config: config,
    61  	}
    62  
    63  	return mc, nil
    64  }
    65  
    66  func (m *Mongo) Open(dsn string) (database.Driver, error) {
    67  	//connsting is experimental package, but it used for parse connection string in mongo.Connect function
    68  	uri, err := connstring.Parse(dsn)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	if len(uri.Database) == 0 {
    73  		return nil, ErrNoDatabaseName
    74  	}
    75  
    76  	purl, err := url.Parse(dsn)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	migrationsCollection := purl.Query().Get("x-migrations-collection")
    81  
    82  	transactionMode, _ := strconv.ParseBool(purl.Query().Get("x-transaction-mode"))
    83  
    84  	q := migrate.FilterCustomQuery(purl)
    85  	q.Scheme = "mongodb"
    86  
    87  	client, err := mongo.Connect(context.TODO(), q.String())
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	if err = client.Ping(context.TODO(), nil); err != nil {
    92  		return nil, err
    93  	}
    94  	mc, err := WithInstance(client, &Config{
    95  		DatabaseName:         uri.Database,
    96  		MigrationsCollection: migrationsCollection,
    97  		TransactionMode:      transactionMode,
    98  	})
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  	return mc, nil
   103  }
   104  
   105  func (m *Mongo) SetVersion(version int, dirty bool) error {
   106  	migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
   107  	if err := migrationsCollection.Drop(context.TODO()); err != nil {
   108  		return &database.Error{OrigErr: err, Err: "drop migrations collection failed"}
   109  	}
   110  	_, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty})
   111  	if err != nil {
   112  		return &database.Error{OrigErr: err, Err: "save version failed"}
   113  	}
   114  	return nil
   115  }
   116  
   117  func (m *Mongo) Version() (version int, dirty bool, err error) {
   118  	var versionInfo versionInfo
   119  	err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo)
   120  	switch {
   121  	case err == mongo.ErrNoDocuments:
   122  		return database.NilVersion, false, nil
   123  	case err != nil:
   124  		return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"}
   125  	default:
   126  		return versionInfo.Version, versionInfo.Dirty, nil
   127  	}
   128  }
   129  
   130  func (m *Mongo) Run(migration io.Reader) error {
   131  	migr, err := ioutil.ReadAll(migration)
   132  	if err != nil {
   133  		return err
   134  	}
   135  	var cmds []bson.D
   136  	err = bson.UnmarshalExtJSON(migr, true, &cmds)
   137  	if err != nil {
   138  		return fmt.Errorf("unmarshaling json error: %s", err)
   139  	}
   140  	if m.config.TransactionMode {
   141  		if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil {
   142  			return err
   143  		}
   144  	} else {
   145  		if err := m.executeCommands(context.TODO(), cmds); err != nil {
   146  			return err
   147  		}
   148  	}
   149  	return nil
   150  }
   151  
   152  func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error {
   153  	err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error {
   154  		if err := sessionContext.StartTransaction(); err != nil {
   155  			return &database.Error{OrigErr: err, Err: "failed to start transaction"}
   156  		}
   157  		if err := m.executeCommands(sessionContext, cmds); err != nil {
   158  			//When command execution is failed, it's aborting transaction
   159  			//If you tried to call abortTransaction, it`s return error that transaction already aborted
   160  			return err
   161  		}
   162  		if err := sessionContext.CommitTransaction(sessionContext); err != nil {
   163  			return &database.Error{OrigErr: err, Err: "failed to commit transaction"}
   164  		}
   165  		return nil
   166  	})
   167  	if err != nil {
   168  		return err
   169  	}
   170  	return nil
   171  }
   172  
   173  func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error {
   174  	for _, cmd := range cmds {
   175  		err := m.db.RunCommand(ctx, cmd).Err()
   176  		if err != nil {
   177  			return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)}
   178  		}
   179  	}
   180  	return nil
   181  }
   182  
   183  func (m *Mongo) Close() error {
   184  	return m.client.Disconnect(context.TODO())
   185  }
   186  
   187  func (m *Mongo) Drop() error {
   188  	return m.db.Drop(context.TODO())
   189  }
   190  
   191  func (m *Mongo) Lock() error {
   192  	return nil
   193  }
   194  
   195  func (m *Mongo) Unlock() error {
   196  	return nil
   197  }