github.com/Elate-DevOps/migrate/v4@v4.0.12/database/neo4j/neo4j.go (about)

     1  package neo4j
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	neturl "net/url"
     8  	"strconv"
     9  	"sync/atomic"
    10  
    11  	"github.com/Elate-DevOps/migrate/v4/database"
    12  	"github.com/Elate-DevOps/migrate/v4/database/multistmt"
    13  	"github.com/hashicorp/go-multierror"
    14  	"github.com/neo4j/neo4j-go-driver/neo4j"
    15  )
    16  
    17  func init() {
    18  	db := Neo4j{}
    19  	database.Register("neo4j", &db)
    20  }
    21  
    22  const DefaultMigrationsLabel = "SchemaMigration"
    23  
    24  var (
    25  	StatementSeparator           = []byte(";")
    26  	DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
    27  )
    28  
    29  var ErrNilConfig = fmt.Errorf("no config")
    30  
// Config controls how the Neo4j driver records and applies migrations.
type Config struct {
	// MigrationsLabel is the node label used for migration version records.
	// It is interpolated directly into Cypher queries, so it must be a valid
	// Neo4j label.
	MigrationsLabel       string
	// MultiStatement, when true, makes Run split a migration on
	// StatementSeparator and execute the statements in one write transaction.
	MultiStatement        bool
	// MultiStatementMaxSize is the size limit handed to the multi-statement
	// parser; Open defaults it to DefaultMultiStatementMaxSize.
	MultiStatementMaxSize int
}
    36  
// Neo4j is the migrate database driver backed by a neo4j.Driver.
type Neo4j struct {
	// driver is the underlying Neo4j connection used to open sessions.
	driver neo4j.Driver
	// lock is 0 when unlocked and 1 when locked; it is only changed through
	// atomic compare-and-swap in Lock and Unlock.
	lock   uint32

	// Open and WithInstance need to guarantee that config is never nil
	config *Config
}
    44  
    45  func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) {
    46  	if config == nil {
    47  		return nil, ErrNilConfig
    48  	}
    49  
    50  	nDriver := &Neo4j{
    51  		driver: driver,
    52  		config: config,
    53  	}
    54  
    55  	if err := nDriver.ensureVersionConstraint(); err != nil {
    56  		return nil, err
    57  	}
    58  
    59  	return nDriver, nil
    60  }
    61  
    62  func (n *Neo4j) Open(url string) (database.Driver, error) {
    63  	uri, err := neturl.Parse(url)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	password, _ := uri.User.Password()
    68  	authToken := neo4j.BasicAuth(uri.User.Username(), password, "")
    69  	uri.User = nil
    70  	uri.Scheme = "bolt"
    71  	msQuery := uri.Query().Get("x-multi-statement")
    72  
    73  	// Whether to turn on/off TLS encryption.
    74  	tlsEncrypted := uri.Query().Get("x-tls-encrypted")
    75  	multi := false
    76  	encrypted := false
    77  	if msQuery != "" {
    78  		multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement"))
    79  		if err != nil {
    80  			return nil, err
    81  		}
    82  	}
    83  
    84  	if tlsEncrypted != "" {
    85  		encrypted, err = strconv.ParseBool(tlsEncrypted)
    86  		if err != nil {
    87  			return nil, err
    88  		}
    89  	}
    90  
    91  	multiStatementMaxSize := DefaultMultiStatementMaxSize
    92  	if s := uri.Query().Get("x-multi-statement-max-size"); s != "" {
    93  		multiStatementMaxSize, err = strconv.Atoi(s)
    94  		if err != nil {
    95  			return nil, err
    96  		}
    97  	}
    98  
    99  	uri.RawQuery = ""
   100  
   101  	driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) {
   102  		config.Encrypted = encrypted
   103  	})
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	return WithInstance(driver, &Config{
   109  		MigrationsLabel:       DefaultMigrationsLabel,
   110  		MultiStatement:        multi,
   111  		MultiStatementMaxSize: multiStatementMaxSize,
   112  	})
   113  }
   114  
   115  func (n *Neo4j) Close() error {
   116  	return n.driver.Close()
   117  }
   118  
   119  // local locking in order to pass tests, Neo doesn't support database locking
   120  func (n *Neo4j) Lock() error {
   121  	if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) {
   122  		return database.ErrLocked
   123  	}
   124  
   125  	return nil
   126  }
   127  
   128  func (n *Neo4j) Unlock() error {
   129  	if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) {
   130  		return database.ErrNotLocked
   131  	}
   132  	return nil
   133  }
   134  
   135  func (n *Neo4j) Run(migration io.Reader) (err error) {
   136  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   137  	if err != nil {
   138  		return err
   139  	}
   140  	defer func() {
   141  		if cerr := session.Close(); cerr != nil {
   142  			err = multierror.Append(err, cerr)
   143  		}
   144  	}()
   145  
   146  	if n.config.MultiStatement {
   147  		_, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
   148  			var stmtRunErr error
   149  			if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool {
   150  				trimStmt := bytes.TrimSpace(stmt)
   151  				if len(trimStmt) == 0 {
   152  					return true
   153  				}
   154  				trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator)
   155  				if len(trimStmt) == 0 {
   156  					return true
   157  				}
   158  
   159  				result, err := transaction.Run(string(trimStmt), nil)
   160  				if _, err := neo4j.Collect(result, err); err != nil {
   161  					stmtRunErr = err
   162  					return false
   163  				}
   164  				return true
   165  			}); err != nil {
   166  				return nil, err
   167  			}
   168  			return nil, stmtRunErr
   169  		})
   170  		return err
   171  	}
   172  
   173  	body, err := io.ReadAll(migration)
   174  	if err != nil {
   175  		return err
   176  	}
   177  
   178  	_, err = neo4j.Collect(session.Run(string(body[:]), nil))
   179  	return err
   180  }
   181  
   182  func (n *Neo4j) SetVersion(version int, dirty bool) (err error) {
   183  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   184  	if err != nil {
   185  		return err
   186  	}
   187  	defer func() {
   188  		if cerr := session.Close(); cerr != nil {
   189  			err = multierror.Append(err, cerr)
   190  		}
   191  	}()
   192  
   193  	query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()",
   194  		n.config.MigrationsLabel)
   195  	_, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty}))
   196  	if err != nil {
   197  		return err
   198  	}
   199  	return nil
   200  }
   201  
// MigrationRecord holds the version and dirty flag read back from the
// newest migration node by Version.
type MigrationRecord struct {
	Version int
	Dirty   bool
}
   206  
   207  func (n *Neo4j) Version() (version int, dirty bool, err error) {
   208  	session, err := n.driver.Session(neo4j.AccessModeRead)
   209  	if err != nil {
   210  		return database.NilVersion, false, err
   211  	}
   212  	defer func() {
   213  		if cerr := session.Close(); cerr != nil {
   214  			err = multierror.Append(err, cerr)
   215  		}
   216  	}()
   217  
   218  	query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty
   219  ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`,
   220  		n.config.MigrationsLabel)
   221  	result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
   222  		result, err := transaction.Run(query, nil)
   223  		if err != nil {
   224  			return nil, err
   225  		}
   226  		if result.Next() {
   227  			record := result.Record()
   228  			mr := MigrationRecord{}
   229  			versionResult, ok := record.Get("version")
   230  			if !ok {
   231  				mr.Version = database.NilVersion
   232  			} else {
   233  				mr.Version = int(versionResult.(int64))
   234  			}
   235  
   236  			dirtyResult, ok := record.Get("dirty")
   237  			if ok {
   238  				mr.Dirty = dirtyResult.(bool)
   239  			}
   240  
   241  			return mr, nil
   242  		}
   243  		return nil, result.Err()
   244  	})
   245  	if err != nil {
   246  		return database.NilVersion, false, err
   247  	}
   248  	if result == nil {
   249  		return database.NilVersion, false, err
   250  	}
   251  	mr := result.(MigrationRecord)
   252  	return mr.Version, mr.Dirty, err
   253  }
   254  
   255  func (n *Neo4j) Drop() (err error) {
   256  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   257  	if err != nil {
   258  		return err
   259  	}
   260  	defer func() {
   261  		if cerr := session.Close(); cerr != nil {
   262  			err = multierror.Append(err, cerr)
   263  		}
   264  	}()
   265  
   266  	if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil {
   267  		return err
   268  	}
   269  	return nil
   270  }
   271  
   272  func (n *Neo4j) ensureVersionConstraint() (err error) {
   273  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   274  	if err != nil {
   275  		return err
   276  	}
   277  	defer func() {
   278  		if cerr := session.Close(); cerr != nil {
   279  			err = multierror.Append(err, cerr)
   280  		}
   281  	}()
   282  
   283  	/**
   284  	Get constraint and check to avoid error duplicate
   285  	using db.labels() to support Neo4j 3 and 4.
   286  	Neo4J 3 doesn't support db.constraints() YIELD name
   287  	*/
   288  	res, err := neo4j.Collect(session.Run(fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil))
   289  	if err != nil {
   290  		return err
   291  	}
   292  	if len(res) == 1 {
   293  		return nil
   294  	}
   295  
   296  	query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel)
   297  	if _, err := neo4j.Collect(session.Run(query, nil)); err != nil {
   298  		return err
   299  	}
   300  	return nil
   301  }