github.com/scraniel/migrate@v0.0.0-20230320185700-339088f36cee/database/neo4j/neo4j.go (about) 1 package neo4j 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 neturl "net/url" 8 "strconv" 9 "sync/atomic" 10 11 "github.com/golang-migrate/migrate/v4/database" 12 "github.com/golang-migrate/migrate/v4/database/multistmt" 13 "github.com/hashicorp/go-multierror" 14 "github.com/neo4j/neo4j-go-driver/neo4j" 15 ) 16 17 func init() { 18 db := Neo4j{} 19 database.Register("neo4j", &db) 20 } 21 22 const DefaultMigrationsLabel = "SchemaMigration" 23 24 var ( 25 StatementSeparator = []byte(";") 26 DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB 27 ) 28 29 var ( 30 ErrNilConfig = fmt.Errorf("no config") 31 ) 32 33 type Config struct { 34 MigrationsLabel string 35 MultiStatement bool 36 MultiStatementMaxSize int 37 } 38 39 type Neo4j struct { 40 driver neo4j.Driver 41 lock uint32 42 43 // Open and WithInstance need to guarantee that config is never nil 44 config *Config 45 } 46 47 func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) { 48 if config == nil { 49 return nil, ErrNilConfig 50 } 51 52 nDriver := &Neo4j{ 53 driver: driver, 54 config: config, 55 } 56 57 if err := nDriver.ensureVersionConstraint(); err != nil { 58 return nil, err 59 } 60 61 return nDriver, nil 62 } 63 64 func (n *Neo4j) Open(url string) (database.Driver, error) { 65 uri, err := neturl.Parse(url) 66 if err != nil { 67 return nil, err 68 } 69 password, _ := uri.User.Password() 70 authToken := neo4j.BasicAuth(uri.User.Username(), password, "") 71 uri.User = nil 72 uri.Scheme = "bolt" 73 msQuery := uri.Query().Get("x-multi-statement") 74 75 // Whether to turn on/off TLS encryption. 76 tlsEncrypted := uri.Query().Get("x-tls-encrypted") 77 multi := false 78 encrypted := false 79 if msQuery != "" { 80 multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement")) 81 if err != nil { 82 return nil, err 83 } 84 } 85 86 if tlsEncrypted != "" { 87 encrypted, err = strconv.ParseBool(tlsEncrypted) 88 if err != nil { 89 return nil, err 90 } 91 } 92 93 multiStatementMaxSize := DefaultMultiStatementMaxSize 94 if s := uri.Query().Get("x-multi-statement-max-size"); s != "" { 95 multiStatementMaxSize, err = strconv.Atoi(s) 96 if err != nil { 97 return nil, err 98 } 99 } 100 101 uri.RawQuery = "" 102 103 driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) { 104 config.Encrypted = encrypted 105 }) 106 if err != nil { 107 return nil, err 108 } 109 110 return WithInstance(driver, &Config{ 111 MigrationsLabel: DefaultMigrationsLabel, 112 MultiStatement: multi, 113 MultiStatementMaxSize: multiStatementMaxSize, 114 }) 115 } 116 117 func (n *Neo4j) Close() error { 118 return n.driver.Close() 119 } 120 121 // local locking in order to pass tests, Neo doesn't support database locking 122 func (n *Neo4j) Lock() error { 123 if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) { 124 return database.ErrLocked 125 } 126 127 return nil 128 } 129 130 func (n *Neo4j) Unlock() error { 131 if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) { 132 return database.ErrNotLocked 133 } 134 return nil 135 } 136 137 func (n *Neo4j) Run(migration io.Reader) (err error) { 138 session, err := n.driver.Session(neo4j.AccessModeWrite) 139 if err != nil { 140 return err 141 } 142 defer func() { 143 if cerr := session.Close(); cerr != nil { 144 err = multierror.Append(err, cerr) 145 } 146 }() 147 148 if n.config.MultiStatement { 149 _, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) { 150 var stmtRunErr error 151 if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool { 152 trimStmt := bytes.TrimSpace(stmt) 153 if len(trimStmt) == 0 { 154 return true 155 } 156 trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator) 157 if len(trimStmt) == 0 { 158 return true 159 } 160 161 result, err := transaction.Run(string(trimStmt), nil) 162 if _, err := neo4j.Collect(result, err); err != nil { 163 stmtRunErr = err 164 return false 165 } 166 return true 167 }); err != nil { 168 return nil, err 169 } 170 return nil, stmtRunErr 171 }) 172 return err 173 } 174 175 body, err := io.ReadAll(migration) 176 if err != nil { 177 return err 178 } 179 180 _, err = neo4j.Collect(session.Run(string(body[:]), nil)) 181 return err 182 } 183 184 func (n *Neo4j) SetVersion(version int, dirty bool) (err error) { 185 session, err := n.driver.Session(neo4j.AccessModeWrite) 186 if err != nil { 187 return err 188 } 189 defer func() { 190 if cerr := session.Close(); cerr != nil { 191 err = multierror.Append(err, cerr) 192 } 193 }() 194 195 query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()", 196 n.config.MigrationsLabel) 197 _, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty})) 198 if err != nil { 199 return err 200 } 201 return nil 202 } 203 204 type MigrationRecord struct { 205 Version int 206 Dirty bool 207 } 208 209 func (n *Neo4j) Version() (version int, dirty bool, err error) { 210 session, err := n.driver.Session(neo4j.AccessModeRead) 211 if err != nil { 212 return database.NilVersion, false, err 213 } 214 defer func() { 215 if cerr := session.Close(); cerr != nil { 216 err = multierror.Append(err, cerr) 217 } 218 }() 219 220 query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty 221 ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`, 222 n.config.MigrationsLabel) 223 result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) { 224 result, err := transaction.Run(query, nil) 225 if err != nil { 226 return nil, err 227 } 228 if result.Next() { 229 record := result.Record() 230 mr := MigrationRecord{} 231 versionResult, ok := record.Get("version") 232 if !ok { 233 mr.Version = database.NilVersion 234 } else { 235 mr.Version = int(versionResult.(int64)) 236 } 237 238 dirtyResult, ok := record.Get("dirty") 239 if ok { 240 mr.Dirty = dirtyResult.(bool) 241 } 242 243 return mr, nil 244 } 245 return nil, result.Err() 246 }) 247 if err != nil { 248 return database.NilVersion, false, err 249 } 250 if result == nil { 251 return database.NilVersion, false, err 252 } 253 mr := result.(MigrationRecord) 254 return mr.Version, mr.Dirty, err 255 } 256 257 func (n *Neo4j) Drop() (err error) { 258 session, err := n.driver.Session(neo4j.AccessModeWrite) 259 if err != nil { 260 return err 261 } 262 defer func() { 263 if cerr := session.Close(); cerr != nil { 264 err = multierror.Append(err, cerr) 265 } 266 }() 267 268 if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil { 269 return err 270 } 271 return nil 272 } 273 274 func (n *Neo4j) ensureVersionConstraint() (err error) { 275 session, err := n.driver.Session(neo4j.AccessModeWrite) 276 if err != nil { 277 return err 278 } 279 defer func() { 280 if cerr := session.Close(); cerr != nil { 281 err = multierror.Append(err, cerr) 282 } 283 }() 284 285 /** 286 Get constraint and check to avoid error duplicate 287 using db.labels() to support Neo4j 3 and 4. 288 Neo4J 3 doesn't support db.constraints() YIELD name 289 */ 290 res, err := neo4j.Collect(session.Run(fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil)) 291 if err != nil { 292 return err 293 } 294 if len(res) == 1 { 295 return nil 296 } 297 298 query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel) 299 if _, err := neo4j.Collect(session.Run(query, nil)); err != nil { 300 return err 301 } 302 return nil 303 }