github.com/nokia/migrate/v4@v4.16.0/database/neo4j/neo4j.go (about)

     1  package neo4j
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	neturl "net/url"
     9  	"strconv"
    10  	"sync/atomic"
    11  
    12  	"github.com/hashicorp/go-multierror"
    13  	"github.com/neo4j/neo4j-go-driver/neo4j"
    14  	"github.com/nokia/migrate/v4/database"
    15  	"github.com/nokia/migrate/v4/database/multistmt"
    16  	"github.com/nokia/migrate/v4/source"
    17  )
    18  
    19  func init() {
    20  	db := Neo4j{}
    21  	database.Register("neo4j", &db)
    22  }
    23  
    24  const DefaultMigrationsLabel = "SchemaMigration"
    25  
    26  var (
    27  	StatementSeparator           = []byte(";")
    28  	DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
    29  )
    30  
    31  var ErrNilConfig = fmt.Errorf("no config")
    32  
// Config controls where the Neo4j driver stores migration versions and how
// Run executes a migration.
type Config struct {
	// MigrationsLabel is the node label under which SetVersion writes and
	// Version reads migration version records.
	MigrationsLabel       string
	// MultiStatement, when true, makes Run split each migration on
	// StatementSeparator and execute the statements inside one write
	// transaction instead of sending the whole body as a single query.
	MultiStatement        bool
	// MultiStatementMaxSize is the size limit, in bytes, that Run passes to
	// multistmt.Parse in multi-statement mode.
	MultiStatementMaxSize int
}
    38  
// Neo4j is the migrate database.Driver implementation backed by a
// neo4j.Driver.
type Neo4j struct {
	// driver is the underlying Neo4j driver; sessions are opened from it for
	// every operation and it is closed by Close.
	driver neo4j.Driver
	// lock is 0 when unlocked and 1 when locked. It is only accessed through
	// sync/atomic compare-and-swap in Lock and Unlock.
	lock   uint32

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    46  
    47  func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) {
    48  	if config == nil {
    49  		return nil, ErrNilConfig
    50  	}
    51  
    52  	nDriver := &Neo4j{
    53  		driver: driver,
    54  		config: config,
    55  	}
    56  
    57  	if err := nDriver.ensureVersionConstraint(); err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	return nDriver, nil
    62  }
    63  
    64  func (n *Neo4j) Open(url string) (database.Driver, error) {
    65  	uri, err := neturl.Parse(url)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	password, _ := uri.User.Password()
    70  	authToken := neo4j.BasicAuth(uri.User.Username(), password, "")
    71  	uri.User = nil
    72  	uri.Scheme = "bolt"
    73  	msQuery := uri.Query().Get("x-multi-statement")
    74  
    75  	// Whether to turn on/off TLS encryption.
    76  	tlsEncrypted := uri.Query().Get("x-tls-encrypted")
    77  	multi := false
    78  	encrypted := false
    79  	if msQuery != "" {
    80  		multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement"))
    81  		if err != nil {
    82  			return nil, err
    83  		}
    84  	}
    85  
    86  	if tlsEncrypted != "" {
    87  		encrypted, err = strconv.ParseBool(tlsEncrypted)
    88  		if err != nil {
    89  			return nil, err
    90  		}
    91  	}
    92  
    93  	multiStatementMaxSize := DefaultMultiStatementMaxSize
    94  	if s := uri.Query().Get("x-multi-statement-max-size"); s != "" {
    95  		multiStatementMaxSize, err = strconv.Atoi(s)
    96  		if err != nil {
    97  			return nil, err
    98  		}
    99  	}
   100  
   101  	uri.RawQuery = ""
   102  
   103  	driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) {
   104  		config.Encrypted = encrypted
   105  	})
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	return WithInstance(driver, &Config{
   111  		MigrationsLabel:       DefaultMigrationsLabel,
   112  		MultiStatement:        multi,
   113  		MultiStatementMaxSize: multiStatementMaxSize,
   114  	})
   115  }
   116  
   117  func (n *Neo4j) Close() error {
   118  	return n.driver.Close()
   119  }
   120  
   121  // local locking in order to pass tests, Neo doesn't support database locking
   122  func (n *Neo4j) Lock() error {
   123  	if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) {
   124  		return database.ErrLocked
   125  	}
   126  
   127  	return nil
   128  }
   129  
   130  func (n *Neo4j) Unlock() error {
   131  	if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) {
   132  		return database.ErrNotLocked
   133  	}
   134  	return nil
   135  }
   136  
   137  func (n *Neo4j) Run(migration io.Reader) (err error) {
   138  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   139  	if err != nil {
   140  		return err
   141  	}
   142  	defer func() {
   143  		if cerr := session.Close(); cerr != nil {
   144  			err = multierror.Append(err, cerr)
   145  		}
   146  	}()
   147  
   148  	if n.config.MultiStatement {
   149  		_, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
   150  			var stmtRunErr error
   151  			if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool {
   152  				trimStmt := bytes.TrimSpace(stmt)
   153  				if len(trimStmt) == 0 {
   154  					return true
   155  				}
   156  				trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator)
   157  				if len(trimStmt) == 0 {
   158  					return true
   159  				}
   160  
   161  				result, err := transaction.Run(string(trimStmt), nil)
   162  				if _, err := neo4j.Collect(result, err); err != nil {
   163  					stmtRunErr = err
   164  					return false
   165  				}
   166  				return true
   167  			}); err != nil {
   168  				return nil, err
   169  			}
   170  			return nil, stmtRunErr
   171  		})
   172  		return err
   173  	}
   174  
   175  	body, err := ioutil.ReadAll(migration)
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	_, err = neo4j.Collect(session.Run(string(body[:]), nil))
   181  	return err
   182  }
   183  
   184  func (n *Neo4j) RunFunctionMigration(fn source.MigrationFunc) error {
   185  	return database.ErrNotImpl
   186  }
   187  
   188  func (n *Neo4j) SetVersion(version int, dirty bool) (err error) {
   189  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   190  	if err != nil {
   191  		return err
   192  	}
   193  	defer func() {
   194  		if cerr := session.Close(); cerr != nil {
   195  			err = multierror.Append(err, cerr)
   196  		}
   197  	}()
   198  
   199  	query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()",
   200  		n.config.MigrationsLabel)
   201  	_, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty}))
   202  	if err != nil {
   203  		return err
   204  	}
   205  	return nil
   206  }
   207  
// MigrationRecord holds the version and dirty flag that Version reads back
// from the most recently written migration node.
type MigrationRecord struct {
	Version int
	Dirty   bool
}
   212  
   213  func (n *Neo4j) Version() (version int, dirty bool, err error) {
   214  	session, err := n.driver.Session(neo4j.AccessModeRead)
   215  	if err != nil {
   216  		return database.NilVersion, false, err
   217  	}
   218  	defer func() {
   219  		if cerr := session.Close(); cerr != nil {
   220  			err = multierror.Append(err, cerr)
   221  		}
   222  	}()
   223  
   224  	query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty
   225  ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`,
   226  		n.config.MigrationsLabel)
   227  	result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
   228  		result, err := transaction.Run(query, nil)
   229  		if err != nil {
   230  			return nil, err
   231  		}
   232  		if result.Next() {
   233  			record := result.Record()
   234  			mr := MigrationRecord{}
   235  			versionResult, ok := record.Get("version")
   236  			if !ok {
   237  				mr.Version = database.NilVersion
   238  			} else {
   239  				mr.Version = int(versionResult.(int64))
   240  			}
   241  
   242  			dirtyResult, ok := record.Get("dirty")
   243  			if ok {
   244  				mr.Dirty = dirtyResult.(bool)
   245  			}
   246  
   247  			return mr, nil
   248  		}
   249  		return nil, result.Err()
   250  	})
   251  	if err != nil {
   252  		return database.NilVersion, false, err
   253  	}
   254  	if result == nil {
   255  		return database.NilVersion, false, err
   256  	}
   257  	mr := result.(MigrationRecord)
   258  	return mr.Version, mr.Dirty, err
   259  }
   260  
   261  func (n *Neo4j) Drop() (err error) {
   262  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   263  	if err != nil {
   264  		return err
   265  	}
   266  	defer func() {
   267  		if cerr := session.Close(); cerr != nil {
   268  			err = multierror.Append(err, cerr)
   269  		}
   270  	}()
   271  
   272  	if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil {
   273  		return err
   274  	}
   275  	return nil
   276  }
   277  
   278  func (n *Neo4j) ensureVersionConstraint() (err error) {
   279  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   280  	if err != nil {
   281  		return err
   282  	}
   283  	defer func() {
   284  		if cerr := session.Close(); cerr != nil {
   285  			err = multierror.Append(err, cerr)
   286  		}
   287  	}()
   288  
   289  	/**
   290  	Get constraint and check to avoid error duplicate
   291  	using db.labels() to support Neo4j 3 and 4.
   292  	Neo4J 3 doesn't support db.constraints() YIELD name
   293  	*/
   294  	res, err := neo4j.Collect(session.Run(fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil))
   295  	if err != nil {
   296  		return err
   297  	}
   298  	if len(res) == 1 {
   299  		return nil
   300  	}
   301  
   302  	query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel)
   303  	if _, err := neo4j.Collect(session.Run(query, nil)); err != nil {
   304  		return err
   305  	}
   306  	return nil
   307  }