github.com/dynastymasra/migrate/v4@v4.11.0/database/neo4j/neo4j.go (about)

     1  package neo4j
     2  
     3  import (
     4  	"C" // import C so that we can't compile with CGO_ENABLED=0
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	neturl "net/url"
    10  	"strconv"
    11  	"sync/atomic"
    12  
    13  	"github.com/golang-migrate/migrate/v4/database"
    14  	"github.com/hashicorp/go-multierror"
    15  	"github.com/neo4j/neo4j-go-driver/neo4j"
    16  )
    17  
    18  func init() {
    19  	db := Neo4j{}
    20  	database.Register("neo4j", &db)
    21  }
    22  
    23  const DefaultMigrationsLabel = "SchemaMigration"
    24  
    25  var StatementSeparator = []byte(";")
    26  
    27  var (
    28  	ErrNilConfig = fmt.Errorf("no config")
    29  )
    30  
    31  type Config struct {
    32  	AuthToken       neo4j.AuthToken
    33  	URL             string // if using WithInstance, don't provide auth in the URL, it will be ignored
    34  	MigrationsLabel string
    35  	MultiStatement  bool
    36  }
    37  
    38  type Neo4j struct {
    39  	driver neo4j.Driver
    40  	lock   uint32
    41  
    42  	// Open and WithInstance need to guarantee that config is never nil
    43  	config *Config
    44  }
    45  
    46  func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) {
    47  	if config == nil {
    48  		return nil, ErrNilConfig
    49  	}
    50  
    51  	nDriver := &Neo4j{
    52  		driver: driver,
    53  		config: config,
    54  	}
    55  
    56  	if err := nDriver.ensureVersionConstraint(); err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	return nDriver, nil
    61  }
    62  
    63  func (n *Neo4j) Open(url string) (database.Driver, error) {
    64  	uri, err := neturl.Parse(url)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  	password, _ := uri.User.Password()
    69  	authToken := neo4j.BasicAuth(uri.User.Username(), password, "")
    70  	uri.User = nil
    71  	uri.Scheme = "bolt"
    72  	msQuery := uri.Query().Get("x-multi-statement")
    73  	multi := false
    74  	if msQuery != "" {
    75  		multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement"))
    76  		if err != nil {
    77  			return nil, err
    78  		}
    79  	}
    80  	uri.RawQuery = ""
    81  
    82  	driver, err := neo4j.NewDriver(uri.String(), authToken)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  
    87  	return WithInstance(driver, &Config{
    88  		URL:             uri.String(),
    89  		AuthToken:       authToken,
    90  		MigrationsLabel: DefaultMigrationsLabel,
    91  		MultiStatement:  multi,
    92  	})
    93  }
    94  
    95  func (n *Neo4j) Close() error {
    96  	return n.driver.Close()
    97  }
    98  
    99  // local locking in order to pass tests, Neo doesn't support database locking
   100  func (n *Neo4j) Lock() error {
   101  	if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) {
   102  		return database.ErrLocked
   103  	}
   104  
   105  	return nil
   106  }
   107  
   108  func (n *Neo4j) Unlock() error {
   109  	if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) {
   110  		return database.ErrNotLocked
   111  	}
   112  	return nil
   113  }
   114  
   115  func (n *Neo4j) Run(migration io.Reader) (err error) {
   116  	body, err := ioutil.ReadAll(migration)
   117  	if err != nil {
   118  		return err
   119  	}
   120  
   121  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   122  	if err != nil {
   123  		return err
   124  	}
   125  	defer func() {
   126  		if cerr := session.Close(); cerr != nil {
   127  			err = multierror.Append(err, cerr)
   128  		}
   129  	}()
   130  
   131  	if n.config.MultiStatement {
   132  		statements := bytes.Split(body, StatementSeparator)
   133  		_, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
   134  			for _, stmt := range statements {
   135  				trimStmt := bytes.TrimSpace(stmt)
   136  				if len(trimStmt) == 0 {
   137  					continue
   138  				}
   139  				result, err := transaction.Run(string(trimStmt[:]), nil)
   140  				if _, err := neo4j.Collect(result, err); err != nil {
   141  					return nil, err
   142  				}
   143  			}
   144  			return nil, nil
   145  		})
   146  		if err != nil {
   147  			return err
   148  		}
   149  	} else {
   150  		if _, err := neo4j.Collect(session.Run(string(body[:]), nil)); err != nil {
   151  			return err
   152  		}
   153  	}
   154  
   155  	return nil
   156  }
   157  
   158  func (n *Neo4j) SetVersion(version int, dirty bool) (err error) {
   159  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   160  	if err != nil {
   161  		return err
   162  	}
   163  	defer func() {
   164  		if cerr := session.Close(); cerr != nil {
   165  			err = multierror.Append(err, cerr)
   166  		}
   167  	}()
   168  
   169  	query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()",
   170  		n.config.MigrationsLabel)
   171  	_, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty}))
   172  	if err != nil {
   173  		return err
   174  	}
   175  	return nil
   176  }
   177  
   178  type MigrationRecord struct {
   179  	Version int
   180  	Dirty   bool
   181  }
   182  
   183  func (n *Neo4j) Version() (version int, dirty bool, err error) {
   184  	session, err := n.driver.Session(neo4j.AccessModeRead)
   185  	if err != nil {
   186  		return database.NilVersion, false, err
   187  	}
   188  	defer func() {
   189  		if cerr := session.Close(); cerr != nil {
   190  			err = multierror.Append(err, cerr)
   191  		}
   192  	}()
   193  
   194  	query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty
   195  ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`,
   196  		n.config.MigrationsLabel)
   197  	result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
   198  		result, err := transaction.Run(query, nil)
   199  		if err != nil {
   200  			return nil, err
   201  		}
   202  		if result.Next() {
   203  			record := result.Record()
   204  			mr := MigrationRecord{}
   205  			versionResult, ok := record.Get("version")
   206  			if !ok {
   207  				mr.Version = database.NilVersion
   208  			} else {
   209  				mr.Version = int(versionResult.(int64))
   210  			}
   211  
   212  			dirtyResult, ok := record.Get("dirty")
   213  			if ok {
   214  				mr.Dirty = dirtyResult.(bool)
   215  			}
   216  
   217  			return mr, nil
   218  		}
   219  		return nil, result.Err()
   220  	})
   221  	if err != nil {
   222  		return database.NilVersion, false, err
   223  	}
   224  	if result == nil {
   225  		return database.NilVersion, false, err
   226  	}
   227  	mr := result.(MigrationRecord)
   228  	return mr.Version, mr.Dirty, err
   229  }
   230  
   231  func (n *Neo4j) Drop() (err error) {
   232  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   233  	if err != nil {
   234  		return err
   235  	}
   236  	defer func() {
   237  		if cerr := session.Close(); cerr != nil {
   238  			err = multierror.Append(err, cerr)
   239  		}
   240  	}()
   241  
   242  	if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil {
   243  		return err
   244  	}
   245  	return nil
   246  }
   247  
   248  func (n *Neo4j) ensureVersionConstraint() (err error) {
   249  	session, err := n.driver.Session(neo4j.AccessModeWrite)
   250  	if err != nil {
   251  		return err
   252  	}
   253  	defer func() {
   254  		if cerr := session.Close(); cerr != nil {
   255  			err = multierror.Append(err, cerr)
   256  		}
   257  	}()
   258  
   259  	query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel)
   260  	if _, err := neo4j.Collect(session.Run(query, nil)); err != nil {
   261  		return err
   262  	}
   263  	return nil
   264  }