github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/neo4j/neo4j.go (about)

     1  package neo4j
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	neturl "net/url"
     8  	"strconv"
     9  	"sync/atomic"
    10  
    11  	"github.com/golang-migrate/migrate/v4/database"
    12  	"github.com/golang-migrate/migrate/v4/database/multistmt"
    13  	"github.com/hashicorp/go-multierror"
    14  	"github.com/neo4j/neo4j-go-driver/neo4j"
    15  )
    16  
    17  func init() {
    18  	db := Neo4j{}
    19  	database.Register("neo4j", &db)
    20  }
    21  
    22  const DefaultMigrationsLabel = "SchemaMigration"
    23  
    24  var (
    25  	StatementSeparator           = []byte(";")
    26  	DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
    27  )
    28  
    29  var (
    30  	ErrNilConfig = fmt.Errorf("no config")
    31  )
    32  
// Config holds the settings used by the Neo4j migration driver.
type Config struct {
	// MigrationsLabel is the node label under which migration version
	// records are stored and looked up (Open sets DefaultMigrationsLabel).
	MigrationsLabel string
	// MultiStatement, when true, makes Run split a migration on
	// StatementSeparator and execute each piece inside one write
	// transaction; when false the whole migration is sent as one query.
	MultiStatement bool
	// MultiStatementMaxSize is the maximum size passed to multistmt.Parse
	// when MultiStatement is enabled.
	MultiStatementMaxSize int
}
    38  
// Neo4j is the migrate database.Driver implementation backed by a
// neo4j.Driver.
type Neo4j struct {
	// driver is the underlying Neo4j client used to open sessions.
	driver neo4j.Driver
	// lock is a process-local flag (0 = unlocked, 1 = locked) flipped
	// atomically by Lock and Unlock; nothing is locked in the database.
	lock uint32

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    46  
    47  func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) {
    48  	if config == nil {
    49  		return nil, ErrNilConfig
    50  	}
    51  
    52  	nDriver := &Neo4j{
    53  		driver: driver,
    54  		config: config,
    55  	}
    56  
    57  	if err := nDriver.ensureVersionConstraint(); err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	return nDriver, nil
    62  }
    63  
    64  func (n *Neo4j) Open(url string) (database.Driver, error) {
    65  	uri, err := neturl.Parse(url)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	password, _ := uri.User.Password()
    70  	authToken := neo4j.BasicAuth(uri.User.Username(), password, "")
    71  	uri.User = nil
    72  	uri.Scheme = "bolt"
    73  	msQuery := uri.Query().Get("x-multi-statement")
    74  
    75  	// Whether to turn on/off TLS encryption.
    76  	tlsEncrypted := uri.Query().Get("x-tls-encrypted")
    77  	multi := false
    78  	encrypted := false
    79  	if msQuery != "" {
    80  		multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement"))
    81  		if err != nil {
    82  			return nil, err
    83  		}
    84  	}
    85  
    86  	if tlsEncrypted != "" {
    87  		encrypted, err = strconv.ParseBool(tlsEncrypted)
    88  		if err != nil {
    89  			return nil, err
    90  		}
    91  	}
    92  
    93  	multiStatementMaxSize := DefaultMultiStatementMaxSize
    94  	if s := uri.Query().Get("x-multi-statement-max-size"); s != "" {
    95  		multiStatementMaxSize, err = strconv.Atoi(s)
    96  		if err != nil {
    97  			return nil, err
    98  		}
    99  	}
   100  
   101  	uri.RawQuery = ""
   102  
   103  	driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) {
   104  		config.Encrypted = encrypted
   105  	})
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	return WithInstance(driver, &Config{
   111  		MigrationsLabel:       DefaultMigrationsLabel,
   112  		MultiStatement:        multi,
   113  		MultiStatementMaxSize: multiStatementMaxSize,
   114  	})
   115  }
   116  
   117  func (n *Neo4j) Close() error {
   118  	return n.driver.Close()
   119  }
   120  
   121  // local locking in order to pass tests, Neo doesn't support database locking
   122  func (n *Neo4j) Lock() error {
   123  	if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) {
   124  		return database.ErrLocked
   125  	}
   126  
   127  	return nil
   128  }
   129  
   130  func (n *Neo4j) Unlock() error {
   131  	if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) {
   132  		return database.ErrNotLocked
   133  	}
   134  	return nil
   135  }
   136  
   137  func (n *Neo4j) Run(migration io.Reader) (err error) {
   138  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   139  	if err != nil {
   140  		return err
   141  	}
   142  	defer func() {
   143  		if cerr := session.Close(); cerr != nil {
   144  			err = multierror.Append(err, cerr)
   145  		}
   146  	}()
   147  
   148  	if n.config.MultiStatement {
   149  		_, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
   150  			var stmtRunErr error
   151  			if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool {
   152  				trimStmt := bytes.TrimSpace(stmt)
   153  				if len(trimStmt) == 0 {
   154  					return true
   155  				}
   156  				trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator)
   157  				if len(trimStmt) == 0 {
   158  					return true
   159  				}
   160  
   161  				result, err := transaction.Run(string(trimStmt), nil)
   162  				if _, err := neo4j.Collect(result, err); err != nil {
   163  					stmtRunErr = err
   164  					return false
   165  				}
   166  				return true
   167  			}); err != nil {
   168  				return nil, err
   169  			}
   170  			return nil, stmtRunErr
   171  		})
   172  		return err
   173  	}
   174  
   175  	body, err := io.ReadAll(migration)
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	_, err = neo4j.Collect(session.Run(string(body[:]), nil))
   181  	return err
   182  }
   183  
   184  func (n *Neo4j) SetVersion(version int, dirty bool) (err error) {
   185  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   186  	if err != nil {
   187  		return err
   188  	}
   189  	defer func() {
   190  		if cerr := session.Close(); cerr != nil {
   191  			err = multierror.Append(err, cerr)
   192  		}
   193  	}()
   194  
   195  	query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()",
   196  		n.config.MigrationsLabel)
   197  	_, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty}))
   198  	if err != nil {
   199  		return err
   200  	}
   201  	return nil
   202  }
   203  
// MigrationRecord is the version/dirty pair read back by Version.
type MigrationRecord struct {
	// Version is the recorded migration version.
	Version int
	// Dirty reports whether the migration was left in a dirty state.
	Dirty bool
}
   208  
   209  func (n *Neo4j) Version() (version int, dirty bool, err error) {
   210  	session, err := n.driver.Session(neo4j.AccessModeRead)
   211  	if err != nil {
   212  		return database.NilVersion, false, err
   213  	}
   214  	defer func() {
   215  		if cerr := session.Close(); cerr != nil {
   216  			err = multierror.Append(err, cerr)
   217  		}
   218  	}()
   219  
   220  	query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty
   221  ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`,
   222  		n.config.MigrationsLabel)
   223  	result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
   224  		result, err := transaction.Run(query, nil)
   225  		if err != nil {
   226  			return nil, err
   227  		}
   228  		if result.Next() {
   229  			record := result.Record()
   230  			mr := MigrationRecord{}
   231  			versionResult, ok := record.Get("version")
   232  			if !ok {
   233  				mr.Version = database.NilVersion
   234  			} else {
   235  				mr.Version = int(versionResult.(int64))
   236  			}
   237  
   238  			dirtyResult, ok := record.Get("dirty")
   239  			if ok {
   240  				mr.Dirty = dirtyResult.(bool)
   241  			}
   242  
   243  			return mr, nil
   244  		}
   245  		return nil, result.Err()
   246  	})
   247  	if err != nil {
   248  		return database.NilVersion, false, err
   249  	}
   250  	if result == nil {
   251  		return database.NilVersion, false, err
   252  	}
   253  	mr := result.(MigrationRecord)
   254  	return mr.Version, mr.Dirty, err
   255  }
   256  
   257  func (n *Neo4j) Drop() (err error) {
   258  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   259  	if err != nil {
   260  		return err
   261  	}
   262  	defer func() {
   263  		if cerr := session.Close(); cerr != nil {
   264  			err = multierror.Append(err, cerr)
   265  		}
   266  	}()
   267  
   268  	if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil {
   269  		return err
   270  	}
   271  	return nil
   272  }
   273  
   274  func (n *Neo4j) ensureVersionConstraint() (err error) {
   275  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   276  	if err != nil {
   277  		return err
   278  	}
   279  	defer func() {
   280  		if cerr := session.Close(); cerr != nil {
   281  			err = multierror.Append(err, cerr)
   282  		}
   283  	}()
   284  
   285  	/**
   286  	Get constraint and check to avoid error duplicate
   287  	using db.labels() to support Neo4j 3 and 4.
   288  	Neo4J 3 doesn't support db.constraints() YIELD name
   289  	*/
   290  	res, err := neo4j.Collect(session.Run(fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil))
   291  	if err != nil {
   292  		return err
   293  	}
   294  	if len(res) == 1 {
   295  		return nil
   296  	}
   297  
   298  	query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel)
   299  	if _, err := neo4j.Collect(session.Run(query, nil)); err != nil {
   300  		return err
   301  	}
   302  	return nil
   303  }