github.com/Elate-DevOps/migrate/v4@v4.0.12/database/neo4j/neo4j.go (about) 1 package neo4j 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 neturl "net/url" 8 "strconv" 9 "sync/atomic" 10 11 "github.com/Elate-DevOps/migrate/v4/database" 12 "github.com/Elate-DevOps/migrate/v4/database/multistmt" 13 "github.com/hashicorp/go-multierror" 14 "github.com/neo4j/neo4j-go-driver/neo4j" 15 ) 16 17 func init() { 18 db := Neo4j{} 19 database.Register("neo4j", &db) 20 } 21 22 const DefaultMigrationsLabel = "SchemaMigration" 23 24 var ( 25 StatementSeparator = []byte(";") 26 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 27 ) 28 29 var ErrNilConfig = fmt.Errorf("no config") 30 31 type Config struct { 32 MigrationsLabel string 33 MultiStatement bool 34 MultiStatementMaxSize int 35 } 36 37 type Neo4j struct { 38 driver neo4j.Driver 39 lock uint32 40 41 // Open and WithInstance need to guarantee that config is never nil 42 config *Config 43 } 44 45 func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) { 46 if config == nil { 47 return nil, ErrNilConfig 48 } 49 50 nDriver := &Neo4j{ 51 driver: driver, 52 config: config, 53 } 54 55 if err := nDriver.ensureVersionConstraint(); err != nil { 56 return nil, err 57 } 58 59 return nDriver, nil 60 } 61 62 func (n *Neo4j) Open(url string) (database.Driver, error) { 63 uri, err := neturl.Parse(url) 64 if err != nil { 65 return nil, err 66 } 67 password, _ := uri.User.Password() 68 authToken := neo4j.BasicAuth(uri.User.Username(), password, "") 69 uri.User = nil 70 uri.Scheme = "bolt" 71 msQuery := uri.Query().Get("x-multi-statement") 72 73 // Whether to turn on/off TLS encryption. 74 tlsEncrypted := uri.Query().Get("x-tls-encrypted") 75 multi := false 76 encrypted := false 77 if msQuery != "" { 78 multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement")) 79 if err != nil { 80 return nil, err 81 } 82 } 83 84 if tlsEncrypted != "" { 85 encrypted, err = strconv.ParseBool(tlsEncrypted) 86 if err != nil { 87 return nil, err 88 } 89 } 90 91 multiStatementMaxSize := DefaultMultiStatementMaxSize 92 if s := uri.Query().Get("x-multi-statement-max-size"); s != "" { 93 multiStatementMaxSize, err = strconv.Atoi(s) 94 if err != nil { 95 return nil, err 96 } 97 } 98 99 uri.RawQuery = "" 100 101 driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) { 102 config.Encrypted = encrypted 103 }) 104 if err != nil { 105 return nil, err 106 } 107 108 return WithInstance(driver, &Config{ 109 MigrationsLabel: DefaultMigrationsLabel, 110 MultiStatement: multi, 111 MultiStatementMaxSize: multiStatementMaxSize, 112 }) 113 } 114 115 func (n *Neo4j) Close() error { 116 return n.driver.Close() 117 } 118 119 // local locking in order to pass tests, Neo doesn't support database locking 120 func (n *Neo4j) Lock() error { 121 if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) { 122 return database.ErrLocked 123 } 124 125 return nil 126 } 127 128 func (n *Neo4j) Unlock() error { 129 if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) { 130 return database.ErrNotLocked 131 } 132 return nil 133 } 134 135 func (n *Neo4j) Run(migration io.Reader) (err error) { 136 session, err := n.driver.Session(neo4j.AccessModeWrite) 137 if err != nil { 138 return err 139 } 140 defer func() { 141 if cerr := session.Close(); cerr != nil { 142 err = multierror.Append(err, cerr) 143 } 144 }() 145 146 if n.config.MultiStatement { 147 _, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) { 148 var stmtRunErr error 149 if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool { 150 trimStmt := bytes.TrimSpace(stmt) 151 if len(trimStmt) == 0 { 152 return true 153 } 154 trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator) 155 if len(trimStmt) == 0 { 156 return true 157 } 158 159 result, err := transaction.Run(string(trimStmt), nil) 160 if _, err := neo4j.Collect(result, err); err != nil { 161 stmtRunErr = err 162 return false 163 } 164 return true 165 }); err != nil { 166 return nil, err 167 } 168 return nil, stmtRunErr 169 }) 170 return err 171 } 172 173 body, err := io.ReadAll(migration) 174 if err != nil { 175 return err 176 } 177 178 _, err = neo4j.Collect(session.Run(string(body[:]), nil)) 179 return err 180 } 181 182 func (n *Neo4j) SetVersion(version int, dirty bool) (err error) { 183 session, err := n.driver.Session(neo4j.AccessModeWrite) 184 if err != nil { 185 return err 186 } 187 defer func() { 188 if cerr := session.Close(); cerr != nil { 189 err = multierror.Append(err, cerr) 190 } 191 }() 192 193 query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()", 194 n.config.MigrationsLabel) 195 _, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty})) 196 if err != nil { 197 return err 198 } 199 return nil 200 } 201 202 type MigrationRecord struct { 203 Version int 204 Dirty bool 205 } 206 207 func (n *Neo4j) Version() (version int, dirty bool, err error) { 208 session, err := n.driver.Session(neo4j.AccessModeRead) 209 if err != nil { 210 return database.NilVersion, false, err 211 } 212 defer func() { 213 if cerr := session.Close(); cerr != nil { 214 err = multierror.Append(err, cerr) 215 } 216 }() 217 218 query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty 219 ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`, 220 n.config.MigrationsLabel) 221 result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) { 222 result, err := transaction.Run(query, nil) 223 if err != nil { 224 return nil, err 225 } 226 if result.Next() { 227 record := result.Record() 228 mr := MigrationRecord{} 229 versionResult, ok := record.Get("version") 230 if !ok { 231 mr.Version = database.NilVersion 232 } else { 233 mr.Version = int(versionResult.(int64)) 234 } 235 236 dirtyResult, ok := record.Get("dirty") 237 if ok { 238 mr.Dirty = dirtyResult.(bool) 239 } 240 241 return mr, nil 242 } 243 return nil, result.Err() 244 }) 245 if err != nil { 246 return database.NilVersion, false, err 247 } 248 if result == nil { 249 return database.NilVersion, false, err 250 } 251 mr := result.(MigrationRecord) 252 return mr.Version, mr.Dirty, err 253 } 254 255 func (n *Neo4j) Drop() (err error) { 256 session, err := n.driver.Session(neo4j.AccessModeWrite) 257 if err != nil { 258 return err 259 } 260 defer func() { 261 if cerr := session.Close(); cerr != nil { 262 err = multierror.Append(err, cerr) 263 } 264 }() 265 266 if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil { 267 return err 268 } 269 return nil 270 } 271 272 func (n *Neo4j) ensureVersionConstraint() (err error) { 273 session, err := n.driver.Session(neo4j.AccessModeWrite) 274 if err != nil { 275 return err 276 } 277 defer func() { 278 if cerr := session.Close(); cerr != nil { 279 err = multierror.Append(err, cerr) 280 } 281 }() 282 283 /** 284 Get constraint and check to avoid error duplicate 285 using db.labels() to support Neo4j 3 and 4. 286 Neo4J 3 doesn't support db.constraints() YIELD name 287 */ 288 res, err := neo4j.Collect(session.Run(fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil)) 289 if err != nil { 290 return err 291 } 292 if len(res) == 1 { 293 return nil 294 } 295 296 query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel) 297 if _, err := neo4j.Collect(session.Run(query, nil)); err != nil { 298 return err 299 } 300 return nil 301 }