github.com/dynastymasra/migrate/v4@v4.11.0/database/neo4j/neo4j.go (about) 1 package neo4j 2 3 import ( 4 "C" // import C so that we can't compile with CGO_ENABLED=0 5 "bytes" 6 "fmt" 7 "io" 8 "io/ioutil" 9 neturl "net/url" 10 "strconv" 11 "sync/atomic" 12 13 "github.com/golang-migrate/migrate/v4/database" 14 "github.com/hashicorp/go-multierror" 15 "github.com/neo4j/neo4j-go-driver/neo4j" 16 ) 17 18 func init() { 19 db := Neo4j{} 20 database.Register("neo4j", &db) 21 } 22 23 const DefaultMigrationsLabel = "SchemaMigration" 24 25 var StatementSeparator = []byte(";") 26 27 var ( 28 ErrNilConfig = fmt.Errorf("no config") 29 ) 30 31 type Config struct { 32 AuthToken neo4j.AuthToken 33 URL string // if using WithInstance, don't provide auth in the URL, it will be ignored 34 MigrationsLabel string 35 MultiStatement bool 36 } 37 38 type Neo4j struct { 39 driver neo4j.Driver 40 lock uint32 41 42 // Open and WithInstance need to guarantee that config is never nil 43 config *Config 44 } 45 46 func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) { 47 if config == nil { 48 return nil, ErrNilConfig 49 } 50 51 nDriver := &Neo4j{ 52 driver: driver, 53 config: config, 54 } 55 56 if err := nDriver.ensureVersionConstraint(); err != nil { 57 return nil, err 58 } 59 60 return nDriver, nil 61 } 62 63 func (n *Neo4j) Open(url string) (database.Driver, error) { 64 uri, err := neturl.Parse(url) 65 if err != nil { 66 return nil, err 67 } 68 password, _ := uri.User.Password() 69 authToken := neo4j.BasicAuth(uri.User.Username(), password, "") 70 uri.User = nil 71 uri.Scheme = "bolt" 72 msQuery := uri.Query().Get("x-multi-statement") 73 multi := false 74 if msQuery != "" { 75 multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement")) 76 if err != nil { 77 return nil, err 78 } 79 } 80 uri.RawQuery = "" 81 82 driver, err := neo4j.NewDriver(uri.String(), authToken) 83 if err != nil { 84 return nil, err 85 } 86 87 return WithInstance(driver, &Config{ 88 URL: uri.String(), 89 AuthToken: authToken, 90 MigrationsLabel: DefaultMigrationsLabel, 91 MultiStatement: multi, 92 }) 93 } 94 95 func (n *Neo4j) Close() error { 96 return n.driver.Close() 97 } 98 99 // local locking in order to pass tests, Neo doesn't support database locking 100 func (n *Neo4j) Lock() error { 101 if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) { 102 return database.ErrLocked 103 } 104 105 return nil 106 } 107 108 func (n *Neo4j) Unlock() error { 109 if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) { 110 return database.ErrNotLocked 111 } 112 return nil 113 } 114 115 func (n *Neo4j) Run(migration io.Reader) (err error) { 116 body, err := ioutil.ReadAll(migration) 117 if err != nil { 118 return err 119 } 120 121 session, err := n.driver.Session(neo4j.AccessModeWrite) 122 if err != nil { 123 return err 124 } 125 defer func() { 126 if cerr := session.Close(); cerr != nil { 127 err = multierror.Append(err, cerr) 128 } 129 }() 130 131 if n.config.MultiStatement { 132 statements := bytes.Split(body, StatementSeparator) 133 _, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) { 134 for _, stmt := range statements { 135 trimStmt := bytes.TrimSpace(stmt) 136 if len(trimStmt) == 0 { 137 continue 138 } 139 result, err := transaction.Run(string(trimStmt[:]), nil) 140 if _, err := neo4j.Collect(result, err); err != nil { 141 return nil, err 142 } 143 } 144 return nil, nil 145 }) 146 if err != nil { 147 return err 148 } 149 } else { 150 if _, err := neo4j.Collect(session.Run(string(body[:]), nil)); err != nil { 151 return err 152 } 153 } 154 155 return nil 156 } 157 158 func (n *Neo4j) SetVersion(version int, dirty bool) (err error) { 159 session, err := n.driver.Session(neo4j.AccessModeWrite) 160 if err != nil { 161 return err 162 } 163 defer func() { 164 if cerr := session.Close(); cerr != nil { 165 err = multierror.Append(err, cerr) 166 } 167 }() 168 169 query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()", 170 n.config.MigrationsLabel) 171 _, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty})) 172 if err != nil { 173 return err 174 } 175 return nil 176 } 177 178 type MigrationRecord struct { 179 Version int 180 Dirty bool 181 } 182 183 func (n *Neo4j) Version() (version int, dirty bool, err error) { 184 session, err := n.driver.Session(neo4j.AccessModeRead) 185 if err != nil { 186 return database.NilVersion, false, err 187 } 188 defer func() { 189 if cerr := session.Close(); cerr != nil { 190 err = multierror.Append(err, cerr) 191 } 192 }() 193 194 query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty 195 ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`, 196 n.config.MigrationsLabel) 197 result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) { 198 result, err := transaction.Run(query, nil) 199 if err != nil { 200 return nil, err 201 } 202 if result.Next() { 203 record := result.Record() 204 mr := MigrationRecord{} 205 versionResult, ok := record.Get("version") 206 if !ok { 207 mr.Version = database.NilVersion 208 } else { 209 mr.Version = int(versionResult.(int64)) 210 } 211 212 dirtyResult, ok := record.Get("dirty") 213 if ok { 214 mr.Dirty = dirtyResult.(bool) 215 } 216 217 return mr, nil 218 } 219 return nil, result.Err() 220 }) 221 if err != nil { 222 return database.NilVersion, false, err 223 } 224 if result == nil { 225 return database.NilVersion, false, err 226 } 227 mr := result.(MigrationRecord) 228 return mr.Version, mr.Dirty, err 229 } 230 231 func (n *Neo4j) Drop() (err error) { 232 session, err := n.driver.Session(neo4j.AccessModeWrite) 233 if err != nil { 234 return err 235 } 236 defer func() { 237 if cerr := session.Close(); cerr != nil { 238 err = multierror.Append(err, cerr) 239 } 240 }() 241 242 if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil { 243 return err 244 } 245 return nil 246 } 247 248 func (n *Neo4j) ensureVersionConstraint() (err error) { 249 session, err := n.driver.Session(neo4j.AccessModeWrite) 250 if err != nil { 251 return err 252 } 253 defer func() { 254 if cerr := session.Close(); cerr != nil { 255 err = multierror.Append(err, cerr) 256 } 257 }() 258 259 query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel) 260 if _, err := neo4j.Collect(session.Run(query, nil)); err != nil { 261 return err 262 } 263 return nil 264 }