github.com/nokia/migrate/v4@v4.16.0/database/neo4j/neo4j.go (about) 1 package neo4j 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "io/ioutil" 8 neturl "net/url" 9 "strconv" 10 "sync/atomic" 11 12 "github.com/hashicorp/go-multierror" 13 "github.com/neo4j/neo4j-go-driver/neo4j" 14 "github.com/nokia/migrate/v4/database" 15 "github.com/nokia/migrate/v4/database/multistmt" 16 "github.com/nokia/migrate/v4/source" 17 ) 18 19 func init() { 20 db := Neo4j{} 21 database.Register("neo4j", &db) 22 } 23 24 const DefaultMigrationsLabel = "SchemaMigration" 25 26 var ( 27 StatementSeparator = []byte(";") 28 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 29 ) 30 31 var ErrNilConfig = fmt.Errorf("no config") 32 33 type Config struct { 34 MigrationsLabel string 35 MultiStatement bool 36 MultiStatementMaxSize int 37 } 38 39 type Neo4j struct { 40 driver neo4j.Driver 41 lock uint32 42 43 // Open and WithInstance need to guarantee that config is never nil 44 config *Config 45 } 46 47 func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) { 48 if config == nil { 49 return nil, ErrNilConfig 50 } 51 52 nDriver := &Neo4j{ 53 driver: driver, 54 config: config, 55 } 56 57 if err := nDriver.ensureVersionConstraint(); err != nil { 58 return nil, err 59 } 60 61 return nDriver, nil 62 } 63 64 func (n *Neo4j) Open(url string) (database.Driver, error) { 65 uri, err := neturl.Parse(url) 66 if err != nil { 67 return nil, err 68 } 69 password, _ := uri.User.Password() 70 authToken := neo4j.BasicAuth(uri.User.Username(), password, "") 71 uri.User = nil 72 uri.Scheme = "bolt" 73 msQuery := uri.Query().Get("x-multi-statement") 74 75 // Whether to turn on/off TLS encryption. 76 tlsEncrypted := uri.Query().Get("x-tls-encrypted") 77 multi := false 78 encrypted := false 79 if msQuery != "" { 80 multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement")) 81 if err != nil { 82 return nil, err 83 } 84 } 85 86 if tlsEncrypted != "" { 87 encrypted, err = strconv.ParseBool(tlsEncrypted) 88 if err != nil { 89 return nil, err 90 } 91 } 92 93 multiStatementMaxSize := DefaultMultiStatementMaxSize 94 if s := uri.Query().Get("x-multi-statement-max-size"); s != "" { 95 multiStatementMaxSize, err = strconv.Atoi(s) 96 if err != nil { 97 return nil, err 98 } 99 } 100 101 uri.RawQuery = "" 102 103 driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) { 104 config.Encrypted = encrypted 105 }) 106 if err != nil { 107 return nil, err 108 } 109 110 return WithInstance(driver, &Config{ 111 MigrationsLabel: DefaultMigrationsLabel, 112 MultiStatement: multi, 113 MultiStatementMaxSize: multiStatementMaxSize, 114 }) 115 } 116 117 func (n *Neo4j) Close() error { 118 return n.driver.Close() 119 } 120 121 // local locking in order to pass tests, Neo doesn't support database locking 122 func (n *Neo4j) Lock() error { 123 if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) { 124 return database.ErrLocked 125 } 126 127 return nil 128 } 129 130 func (n *Neo4j) Unlock() error { 131 if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) { 132 return database.ErrNotLocked 133 } 134 return nil 135 } 136 137 func (n *Neo4j) Run(migration io.Reader) (err error) { 138 session, err := n.driver.Session(neo4j.AccessModeWrite) 139 if err != nil { 140 return err 141 } 142 defer func() { 143 if cerr := session.Close(); cerr != nil { 144 err = multierror.Append(err, cerr) 145 } 146 }() 147 148 if n.config.MultiStatement { 149 _, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) { 150 var stmtRunErr error 151 if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool { 152 trimStmt := bytes.TrimSpace(stmt) 153 if len(trimStmt) == 0 { 154 return true 155 } 156 trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator) 157 if len(trimStmt) == 0 { 158 return true 159 } 160 161 result, err := transaction.Run(string(trimStmt), nil) 162 if _, err := neo4j.Collect(result, err); err != nil { 163 stmtRunErr = err 164 return false 165 } 166 return true 167 }); err != nil { 168 return nil, err 169 } 170 return nil, stmtRunErr 171 }) 172 return err 173 } 174 175 body, err := ioutil.ReadAll(migration) 176 if err != nil { 177 return err 178 } 179 180 _, err = neo4j.Collect(session.Run(string(body[:]), nil)) 181 return err 182 } 183 184 func (n *Neo4j) RunFunctionMigration(fn source.MigrationFunc) error { 185 return database.ErrNotImpl 186 } 187 188 func (n *Neo4j) SetVersion(version int, dirty bool) (err error) { 189 session, err := n.driver.Session(neo4j.AccessModeWrite) 190 if err != nil { 191 return err 192 } 193 defer func() { 194 if cerr := session.Close(); cerr != nil { 195 err = multierror.Append(err, cerr) 196 } 197 }() 198 199 query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()", 200 n.config.MigrationsLabel) 201 _, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty})) 202 if err != nil { 203 return err 204 } 205 return nil 206 } 207 208 type MigrationRecord struct { 209 Version int 210 Dirty bool 211 } 212 213 func (n *Neo4j) Version() (version int, dirty bool, err error) { 214 session, err := n.driver.Session(neo4j.AccessModeRead) 215 if err != nil { 216 return database.NilVersion, false, err 217 } 218 defer func() { 219 if cerr := session.Close(); cerr != nil { 220 err = multierror.Append(err, cerr) 221 } 222 }() 223 224 query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty 225 ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`, 226 n.config.MigrationsLabel) 227 result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) { 228 result, err := transaction.Run(query, nil) 229 if err != nil { 230 return nil, err 231 } 232 if result.Next() { 233 record := result.Record() 234 mr := MigrationRecord{} 235 versionResult, ok := record.Get("version") 236 if !ok { 237 mr.Version = database.NilVersion 238 } else { 239 mr.Version = int(versionResult.(int64)) 240 } 241 242 dirtyResult, ok := record.Get("dirty") 243 if ok { 244 mr.Dirty = dirtyResult.(bool) 245 } 246 247 return mr, nil 248 } 249 return nil, result.Err() 250 }) 251 if err != nil { 252 return database.NilVersion, false, err 253 } 254 if result == nil { 255 return database.NilVersion, false, err 256 } 257 mr := result.(MigrationRecord) 258 return mr.Version, mr.Dirty, err 259 } 260 261 func (n *Neo4j) Drop() (err error) { 262 session, err := n.driver.Session(neo4j.AccessModeWrite) 263 if err != nil { 264 return err 265 } 266 defer func() { 267 if cerr := session.Close(); cerr != nil { 268 err = multierror.Append(err, cerr) 269 } 270 }() 271 272 if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil { 273 return err 274 } 275 return nil 276 } 277 278 func (n *Neo4j) ensureVersionConstraint() (err error) { 279 session, err := n.driver.Session(neo4j.AccessModeWrite) 280 if err != nil { 281 return err 282 } 283 defer func() { 284 if cerr := session.Close(); cerr != nil { 285 err = multierror.Append(err, cerr) 286 } 287 }() 288 289 /** 290 Get constraint and check to avoid error duplicate 291 using db.labels() to support Neo4j 3 and 4. 292 Neo4J 3 doesn't support db.constraints() YIELD name 293 */ 294 res, err := neo4j.Collect(session.Run(fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil)) 295 if err != nil { 296 return err 297 } 298 if len(res) == 1 { 299 return nil 300 } 301 302 query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel) 303 if _, err := neo4j.Collect(session.Run(query, nil)); err != nil { 304 return err 305 } 306 return nil 307 }