github.com/nagyist/migrate/v4@v4.14.6/database/neo4j/neo4j.go (about)

     1  package neo4j
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	neturl "net/url"
     9  	"strconv"
    10  	"sync/atomic"
    11  
    12  	"github.com/golang-migrate/migrate/v4/database"
    13  	"github.com/golang-migrate/migrate/v4/database/multistmt"
    14  	"github.com/hashicorp/go-multierror"
    15  	"github.com/neo4j/neo4j-go-driver/neo4j"
    16  )
    17  
    18  func init() {
    19  	db := Neo4j{}
    20  	database.Register("neo4j", &db)
    21  }
    22  
    23  const DefaultMigrationsLabel = "SchemaMigration"
    24  
    25  var (
    26  	StatementSeparator           = []byte(";")
    27  	DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
    28  )
    29  
    30  var (
    31  	ErrNilConfig = fmt.Errorf("no config")
    32  )
    33  
// Config holds the settings a Neo4j driver instance works with.
type Config struct {
	// MigrationsLabel is interpolated into the Cypher queries as the label of
	// the nodes that store the schema version.
	MigrationsLabel       string
	// MultiStatement makes Run split a migration on StatementSeparator and
	// execute the pieces inside a single write transaction.
	MultiStatement        bool
	// MultiStatementMaxSize is handed to multistmt.Parse as its size limit
	// (presumably the maximum size of a single statement in bytes; confirm
	// against the multistmt package).
	MultiStatementMaxSize int
}
    39  
// Neo4j is the migrate database driver backed by a neo4j.Driver connection.
type Neo4j struct {
	// driver is the connection handle used to open sessions.
	driver neo4j.Driver
	// lock is the in-process lock flag toggled atomically by Lock and Unlock
	// (0 = unlocked, 1 = locked).
	lock   uint32

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    47  
    48  func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) {
    49  	if config == nil {
    50  		return nil, ErrNilConfig
    51  	}
    52  
    53  	nDriver := &Neo4j{
    54  		driver: driver,
    55  		config: config,
    56  	}
    57  
    58  	if err := nDriver.ensureVersionConstraint(); err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	return nDriver, nil
    63  }
    64  
    65  func (n *Neo4j) Open(url string) (database.Driver, error) {
    66  	uri, err := neturl.Parse(url)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  	password, _ := uri.User.Password()
    71  	authToken := neo4j.BasicAuth(uri.User.Username(), password, "")
    72  	uri.User = nil
    73  	uri.Scheme = "bolt"
    74  	msQuery := uri.Query().Get("x-multi-statement")
    75  
    76  	// Whether to turn on/off TLS encryption.
    77  	tlsEncrypted := uri.Query().Get("x-tls-encrypted")
    78  	multi := false
    79  	encrypted := false
    80  	if msQuery != "" {
    81  		multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement"))
    82  		if err != nil {
    83  			return nil, err
    84  		}
    85  	}
    86  
    87  	if tlsEncrypted != "" {
    88  		encrypted, err = strconv.ParseBool(tlsEncrypted)
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  	}
    93  
    94  	multiStatementMaxSize := DefaultMultiStatementMaxSize
    95  	if s := uri.Query().Get("x-multi-statement-max-size"); s != "" {
    96  		multiStatementMaxSize, err = strconv.Atoi(s)
    97  		if err != nil {
    98  			return nil, err
    99  		}
   100  	}
   101  
   102  	uri.RawQuery = ""
   103  
   104  	driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) {
   105  		config.Encrypted = encrypted
   106  	})
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	return WithInstance(driver, &Config{
   112  		MigrationsLabel:       DefaultMigrationsLabel,
   113  		MultiStatement:        multi,
   114  		MultiStatementMaxSize: multiStatementMaxSize,
   115  	})
   116  }
   117  
// Close closes the underlying Neo4j driver and returns whatever error that
// call reports.
func (n *Neo4j) Close() error {
	return n.driver.Close()
}
   121  
   122  // local locking in order to pass tests, Neo doesn't support database locking
   123  func (n *Neo4j) Lock() error {
   124  	if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) {
   125  		return database.ErrLocked
   126  	}
   127  
   128  	return nil
   129  }
   130  
   131  func (n *Neo4j) Unlock() error {
   132  	if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) {
   133  		return database.ErrNotLocked
   134  	}
   135  	return nil
   136  }
   137  
   138  func (n *Neo4j) Run(migration io.Reader) (err error) {
   139  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   140  	if err != nil {
   141  		return err
   142  	}
   143  	defer func() {
   144  		if cerr := session.Close(); cerr != nil {
   145  			err = multierror.Append(err, cerr)
   146  		}
   147  	}()
   148  
   149  	if n.config.MultiStatement {
   150  		_, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
   151  			var stmtRunErr error
   152  			if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool {
   153  				trimStmt := bytes.TrimSpace(stmt)
   154  				if len(trimStmt) == 0 {
   155  					return true
   156  				}
   157  				trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator)
   158  				if len(trimStmt) == 0 {
   159  					return true
   160  				}
   161  
   162  				result, err := transaction.Run(string(trimStmt), nil)
   163  				if _, err := neo4j.Collect(result, err); err != nil {
   164  					stmtRunErr = err
   165  					return false
   166  				}
   167  				return true
   168  			}); err != nil {
   169  				return nil, err
   170  			}
   171  			return nil, stmtRunErr
   172  		})
   173  		return err
   174  	}
   175  
   176  	body, err := ioutil.ReadAll(migration)
   177  	if err != nil {
   178  		return err
   179  	}
   180  
   181  	_, err = neo4j.Collect(session.Run(string(body[:]), nil))
   182  	return err
   183  }
   184  
   185  func (n *Neo4j) SetVersion(version int, dirty bool) (err error) {
   186  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   187  	if err != nil {
   188  		return err
   189  	}
   190  	defer func() {
   191  		if cerr := session.Close(); cerr != nil {
   192  			err = multierror.Append(err, cerr)
   193  		}
   194  	}()
   195  
   196  	query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()",
   197  		n.config.MigrationsLabel)
   198  	_, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty}))
   199  	if err != nil {
   200  		return err
   201  	}
   202  	return nil
   203  }
   204  
// MigrationRecord carries the version and dirty flag read back from a
// migration node by Version.
type MigrationRecord struct {
	// Version is the stored schema version, or database.NilVersion when the
	// record has no version column.
	Version int
	// Dirty is the stored dirty flag.
	Dirty   bool
}
   209  
   210  func (n *Neo4j) Version() (version int, dirty bool, err error) {
   211  	session, err := n.driver.Session(neo4j.AccessModeRead)
   212  	if err != nil {
   213  		return database.NilVersion, false, err
   214  	}
   215  	defer func() {
   216  		if cerr := session.Close(); cerr != nil {
   217  			err = multierror.Append(err, cerr)
   218  		}
   219  	}()
   220  
   221  	query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty
   222  ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`,
   223  		n.config.MigrationsLabel)
   224  	result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
   225  		result, err := transaction.Run(query, nil)
   226  		if err != nil {
   227  			return nil, err
   228  		}
   229  		if result.Next() {
   230  			record := result.Record()
   231  			mr := MigrationRecord{}
   232  			versionResult, ok := record.Get("version")
   233  			if !ok {
   234  				mr.Version = database.NilVersion
   235  			} else {
   236  				mr.Version = int(versionResult.(int64))
   237  			}
   238  
   239  			dirtyResult, ok := record.Get("dirty")
   240  			if ok {
   241  				mr.Dirty = dirtyResult.(bool)
   242  			}
   243  
   244  			return mr, nil
   245  		}
   246  		return nil, result.Err()
   247  	})
   248  	if err != nil {
   249  		return database.NilVersion, false, err
   250  	}
   251  	if result == nil {
   252  		return database.NilVersion, false, err
   253  	}
   254  	mr := result.(MigrationRecord)
   255  	return mr.Version, mr.Dirty, err
   256  }
   257  
   258  func (n *Neo4j) Drop() (err error) {
   259  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   260  	if err != nil {
   261  		return err
   262  	}
   263  	defer func() {
   264  		if cerr := session.Close(); cerr != nil {
   265  			err = multierror.Append(err, cerr)
   266  		}
   267  	}()
   268  
   269  	if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil {
   270  		return err
   271  	}
   272  	return nil
   273  }
   274  
   275  func (n *Neo4j) ensureVersionConstraint() (err error) {
   276  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   277  	if err != nil {
   278  		return err
   279  	}
   280  	defer func() {
   281  		if cerr := session.Close(); cerr != nil {
   282  			err = multierror.Append(err, cerr)
   283  		}
   284  	}()
   285  
   286  	/**
   287  	Get constraint and check to avoid error duplicate
   288  	using db.labels() to support Neo4j 3 and 4.
   289  	Neo4J 3 doesn't support db.constraints() YIELD name
   290  	*/
   291  	res, err := neo4j.Collect(session.Run(fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil))
   292  	if err != nil {
   293  		return err
   294  	}
   295  	if len(res) == 1 {
   296  		return nil
   297  	}
   298  
   299  	query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel)
   300  	if _, err := neo4j.Collect(session.Run(query, nil)); err != nil {
   301  		return err
   302  	}
   303  	return nil
   304  }