github.com/nagyist/migrate/v4@v4.14.6/database/neo4j/neo4j.go (about) 1 package neo4j 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "io/ioutil" 8 neturl "net/url" 9 "strconv" 10 "sync/atomic" 11 12 "github.com/golang-migrate/migrate/v4/database" 13 "github.com/golang-migrate/migrate/v4/database/multistmt" 14 "github.com/hashicorp/go-multierror" 15 "github.com/neo4j/neo4j-go-driver/neo4j" 16 ) 17 18 func init() { 19 db := Neo4j{} 20 database.Register("neo4j", &db) 21 } 22 23 const DefaultMigrationsLabel = "SchemaMigration" 24 25 var ( 26 StatementSeparator = []byte(";") 27 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 28 ) 29 30 var ( 31 ErrNilConfig = fmt.Errorf("no config") 32 ) 33 34 type Config struct { 35 MigrationsLabel string 36 MultiStatement bool 37 MultiStatementMaxSize int 38 } 39 40 type Neo4j struct { 41 driver neo4j.Driver 42 lock uint32 43 44 // Open and WithInstance need to guarantee that config is never nil 45 config *Config 46 } 47 48 func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) { 49 if config == nil { 50 return nil, ErrNilConfig 51 } 52 53 nDriver := &Neo4j{ 54 driver: driver, 55 config: config, 56 } 57 58 if err := nDriver.ensureVersionConstraint(); err != nil { 59 return nil, err 60 } 61 62 return nDriver, nil 63 } 64 65 func (n *Neo4j) Open(url string) (database.Driver, error) { 66 uri, err := neturl.Parse(url) 67 if err != nil { 68 return nil, err 69 } 70 password, _ := uri.User.Password() 71 authToken := neo4j.BasicAuth(uri.User.Username(), password, "") 72 uri.User = nil 73 uri.Scheme = "bolt" 74 msQuery := uri.Query().Get("x-multi-statement") 75 76 // Whether to turn on/off TLS encryption. 77 tlsEncrypted := uri.Query().Get("x-tls-encrypted") 78 multi := false 79 encrypted := false 80 if msQuery != "" { 81 multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement")) 82 if err != nil { 83 return nil, err 84 } 85 } 86 87 if tlsEncrypted != "" { 88 encrypted, err = strconv.ParseBool(tlsEncrypted) 89 if err != nil { 90 return nil, err 91 } 92 } 93 94 multiStatementMaxSize := DefaultMultiStatementMaxSize 95 if s := uri.Query().Get("x-multi-statement-max-size"); s != "" { 96 multiStatementMaxSize, err = strconv.Atoi(s) 97 if err != nil { 98 return nil, err 99 } 100 } 101 102 uri.RawQuery = "" 103 104 driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) { 105 config.Encrypted = encrypted 106 }) 107 if err != nil { 108 return nil, err 109 } 110 111 return WithInstance(driver, &Config{ 112 MigrationsLabel: DefaultMigrationsLabel, 113 MultiStatement: multi, 114 MultiStatementMaxSize: multiStatementMaxSize, 115 }) 116 } 117 118 func (n *Neo4j) Close() error { 119 return n.driver.Close() 120 } 121 122 // local locking in order to pass tests, Neo doesn't support database locking 123 func (n *Neo4j) Lock() error { 124 if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) { 125 return database.ErrLocked 126 } 127 128 return nil 129 } 130 131 func (n *Neo4j) Unlock() error { 132 if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) { 133 return database.ErrNotLocked 134 } 135 return nil 136 } 137 138 func (n *Neo4j) Run(migration io.Reader) (err error) { 139 session, err := n.driver.Session(neo4j.AccessModeWrite) 140 if err != nil { 141 return err 142 } 143 defer func() { 144 if cerr := session.Close(); cerr != nil { 145 err = multierror.Append(err, cerr) 146 } 147 }() 148 149 if n.config.MultiStatement { 150 _, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) { 151 var stmtRunErr error 152 if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool { 153 trimStmt := bytes.TrimSpace(stmt) 154 if len(trimStmt) == 0 { 155 return true 156 } 157 trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator) 158 if len(trimStmt) == 0 { 159 return true 160 } 161 162 result, err := transaction.Run(string(trimStmt), nil) 163 if _, err := neo4j.Collect(result, err); err != nil { 164 stmtRunErr = err 165 return false 166 } 167 return true 168 }); err != nil { 169 return nil, err 170 } 171 return nil, stmtRunErr 172 }) 173 return err 174 } 175 176 body, err := ioutil.ReadAll(migration) 177 if err != nil { 178 return err 179 } 180 181 _, err = neo4j.Collect(session.Run(string(body[:]), nil)) 182 return err 183 } 184 185 func (n *Neo4j) SetVersion(version int, dirty bool) (err error) { 186 session, err := n.driver.Session(neo4j.AccessModeWrite) 187 if err != nil { 188 return err 189 } 190 defer func() { 191 if cerr := session.Close(); cerr != nil { 192 err = multierror.Append(err, cerr) 193 } 194 }() 195 196 query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()", 197 n.config.MigrationsLabel) 198 _, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty})) 199 if err != nil { 200 return err 201 } 202 return nil 203 } 204 205 type MigrationRecord struct { 206 Version int 207 Dirty bool 208 } 209 210 func (n *Neo4j) Version() (version int, dirty bool, err error) { 211 session, err := n.driver.Session(neo4j.AccessModeRead) 212 if err != nil { 213 return database.NilVersion, false, err 214 } 215 defer func() { 216 if cerr := session.Close(); cerr != nil { 217 err = multierror.Append(err, cerr) 218 } 219 }() 220 221 query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty 222 ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`, 223 n.config.MigrationsLabel) 224 result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) { 225 result, err := transaction.Run(query, nil) 226 if err != nil { 227 return nil, err 228 } 229 if result.Next() { 230 record := result.Record() 231 mr := MigrationRecord{} 232 versionResult, ok := record.Get("version") 233 if !ok { 234 mr.Version = database.NilVersion 235 } else { 236 mr.Version = int(versionResult.(int64)) 237 } 238 239 dirtyResult, ok := record.Get("dirty") 240 if ok { 241 mr.Dirty = dirtyResult.(bool) 242 } 243 244 return mr, nil 245 } 246 return nil, result.Err() 247 }) 248 if err != nil { 249 return database.NilVersion, false, err 250 } 251 if result == nil { 252 return database.NilVersion, false, err 253 } 254 mr := result.(MigrationRecord) 255 return mr.Version, mr.Dirty, err 256 } 257 258 func (n *Neo4j) Drop() (err error) { 259 session, err := n.driver.Session(neo4j.AccessModeWrite) 260 if err != nil { 261 return err 262 } 263 defer func() { 264 if cerr := session.Close(); cerr != nil { 265 err = multierror.Append(err, cerr) 266 } 267 }() 268 269 if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil { 270 return err 271 } 272 return nil 273 } 274 275 func (n *Neo4j) ensureVersionConstraint() (err error) { 276 session, err := n.driver.Session(neo4j.AccessModeWrite) 277 if err != nil { 278 return err 279 } 280 defer func() { 281 if cerr := session.Close(); cerr != nil { 282 err = multierror.Append(err, cerr) 283 } 284 }() 285 286 /** 287 Get constraint and check to avoid error duplicate 288 using db.labels() to support Neo4j 3 and 4. 289 Neo4J 3 doesn't support db.constraints() YIELD name 290 */ 291 res, err := neo4j.Collect(session.Run(fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil)) 292 if err != nil { 293 return err 294 } 295 if len(res) == 1 { 296 return nil 297 } 298 299 query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel) 300 if _, err := neo4j.Collect(session.Run(query, nil)); err != nil { 301 return err 302 } 303 return nil 304 }