github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/localdb/shared/genericSQL.go (about) 1 package shared 2 3 import ( 4 "context" 5 "embed" 6 "encoding/json" 7 "fmt" 8 "strings" 9 "time" 10 11 "database/sql" 12 13 sync "github.com/bacalhau-project/golang-mutex-tracer" 14 "github.com/filecoin-project/bacalhau/pkg/bacerrors" 15 "github.com/filecoin-project/bacalhau/pkg/localdb" 16 model "github.com/filecoin-project/bacalhau/pkg/model/v1beta1" 17 "github.com/golang-migrate/migrate/v4" 18 "github.com/golang-migrate/migrate/v4/source/iofs" 19 ) 20 21 // SQLClient is so we can pass *sql.DB and *sql.Tx to the same functions 22 type SQLClient interface { 23 ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 24 QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row 25 QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) 26 } 27 28 type GenericSQLDatastore struct { 29 mtx sync.RWMutex 30 connectionString string 31 db *sql.DB 32 } 33 34 func NewGenericSQLDatastore( 35 db *sql.DB, 36 name string, 37 connectionString string, 38 ) (*GenericSQLDatastore, error) { 39 datastore := &GenericSQLDatastore{ 40 connectionString: connectionString, 41 db: db, 42 } 43 datastore.mtx.EnableTracerWithOpts(sync.Opts{ 44 Threshold: 10 * time.Millisecond, 45 Id: fmt.Sprintf("GenericSQLDatastore[%s].mtx", name), 46 }) 47 return datastore, nil 48 } 49 50 func (d *GenericSQLDatastore) GetDB() *sql.DB { 51 return d.db 52 } 53 54 func getJob(db SQLClient, ctx context.Context, id string) (*model.Job, error) { 55 var apiversion string 56 var jobdata string 57 var statedata string 58 row := db.QueryRowContext(ctx, `select apiversion, jobdata, statedata from job where id like $1 || '%'`, strings.ToLower(id)) 59 err := row.Scan(&apiversion, &jobdata, &statedata) 60 if err != nil { 61 if err == sql.ErrNoRows { 62 return nil, bacerrors.NewJobNotFound(id) 63 } else { 64 return nil, err 65 } 66 } 67 job, err := model.APIVersionParseJob(apiversion, jobdata) 68 if err != nil { 69 return nil, err 70 } 71 state, err := model.APIVersionParseJobState(apiversion, statedata) 72 if err != nil { 73 return nil, err 74 } 75 job.Status.State = state 76 return &job, nil 77 } 78 79 func (d *GenericSQLDatastore) GetJob(ctx context.Context, id string) (*model.Job, error) { 80 d.mtx.RLock() 81 defer d.mtx.RUnlock() 82 return getJob(d.db, ctx, id) 83 } 84 85 func getJobsSQL( 86 query localdb.JobQuery, 87 countMode bool, 88 ) (string, []interface{}, error) { 89 var args []interface{} 90 clauses := []string{} 91 92 queryCounter := 0 93 getQueryCounter := func() string { 94 queryCounter++ 95 return fmt.Sprintf("$%d", queryCounter) 96 } 97 98 handleTag := func(annotation string, include bool) { 99 appendQuery := " < 1" 100 if include { 101 appendQuery = " > 0" 102 } 103 clauses = append(clauses, fmt.Sprintf(` 104 ( 105 select count(*) from job_annotation 106 where job_annotation.annotation = %s 107 and job_annotation.job_id = job.id 108 ) %s 109 `, getQueryCounter(), appendQuery)) 110 args = append(args, annotation) 111 } 112 113 for _, annotation := range query.IncludeTags { 114 handleTag(string(annotation), true) 115 } 116 117 for _, annotation := range query.ExcludeTags { 118 handleTag(string(annotation), false) 119 } 120 121 if query.ClientID != "" { 122 clauses = append(clauses, fmt.Sprintf("job.clientid = %s", getQueryCounter())) 123 args = append(args, query.ClientID) 124 } 125 126 after := "" 127 128 applyOrdering := func(field string) { 129 order := "asc" 130 if query.SortReverse { 131 order = "desc" 132 } 133 after = after + " order by " + field + " " + order 134 } 135 136 if query.SortBy == "created_at" { 137 applyOrdering("created") 138 } else if query.SortBy == "id" { 139 applyOrdering("id") 140 } else if query.SortBy != "" { 141 return "", nil, fmt.Errorf("invalid sort_by: %s", query.SortBy) 142 } 143 144 if query.Limit > 0 { 145 after = after + fmt.Sprintf(" limit %d", query.Limit) 146 } 147 148 if query.Offset > 0 { 149 after = after + fmt.Sprintf(" offset %d", query.Offset) 150 } 151 152 where := strings.Join(clauses, " and ") 153 154 if where != "" { 155 where = "where " + where 156 } 157 158 sql := fmt.Sprintf(` 159 select 160 apiversion, 161 jobdata, 162 statedata 163 from 164 job 165 %s 166 %s 167 `, where, after) 168 169 if countMode { 170 sql = fmt.Sprintf(` 171 select 172 count(job.id) as count 173 from 174 job 175 %s 176 %s 177 `, where, after) 178 } 179 180 return sql, args, nil 181 } 182 183 func getJobs(db SQLClient, ctx context.Context, query localdb.JobQuery) ([]*model.Job, error) { 184 if query.ID != "" { 185 job, err := getJob(db, ctx, query.ID) 186 if err != nil { 187 return nil, err 188 } 189 return []*model.Job{job}, nil 190 } 191 192 sql, args, err := getJobsSQL(query, false) 193 if err != nil { 194 return nil, err 195 } 196 197 rows, err := db.QueryContext(ctx, sql, args...) 198 if err != nil { 199 return nil, err 200 } 201 defer rows.Close() 202 jobs := []*model.Job{} 203 for rows.Next() { 204 var innerErr error 205 var apiversion string 206 var jobdata string 207 var statedata string 208 var job model.Job 209 if innerErr = rows.Scan(&apiversion, &jobdata, &statedata); err != nil { 210 return jobs, innerErr 211 } 212 job, innerErr = model.APIVersionParseJob(apiversion, jobdata) 213 if innerErr != nil { 214 return nil, err 215 } 216 state, innerErr := model.APIVersionParseJobState(apiversion, statedata) 217 if innerErr != nil { 218 return nil, err 219 } 220 job.Status.State = state 221 jobs = append(jobs, &job) 222 } 223 if err = rows.Err(); err != nil { 224 return jobs, err 225 } 226 return jobs, nil 227 } 228 229 func (d *GenericSQLDatastore) GetJobs(ctx context.Context, query localdb.JobQuery) ([]*model.Job, error) { 230 d.mtx.RLock() 231 defer d.mtx.RUnlock() 232 return getJobs(d.db, ctx, query) 233 } 234 235 func (d *GenericSQLDatastore) GetJobsCount(ctx context.Context, query localdb.JobQuery) (int, error) { 236 if query.ID != "" { 237 _, err := getJob(d.db, ctx, query.ID) 238 if err != nil { 239 return 0, err 240 } 241 return 1, nil 242 } 243 244 useQuery := query 245 useQuery.Limit = 0 246 useQuery.Offset = 0 247 useQuery.SortBy = "" 248 249 sqlQuery, args, err := getJobsSQL(useQuery, true) 250 if err != nil { 251 return 0, err 252 } 253 254 var count int 255 row := d.db.QueryRow(sqlQuery, args...) 256 err = row.Scan(&count) 257 if err != nil { 258 return 0, err 259 } 260 return count, nil 261 } 262 263 func getJobEvents(db SQLClient, ctx context.Context, id string) ([]model.JobEvent, error) { 264 var args []interface{} 265 args = append(args, id) 266 267 rows, err := db.QueryContext(ctx, ` 268 select 269 apiversion, 270 eventdata 271 from 272 job_event 273 where 274 job_id = $1 275 order by 276 created asc 277 `, args...) 278 if err != nil { 279 return nil, err 280 } 281 defer rows.Close() 282 var events []model.JobEvent 283 for rows.Next() { 284 var apiversion string 285 var eventdata string 286 var ev model.JobEvent 287 if err = rows.Scan(&apiversion, &eventdata); err != nil { 288 return events, err 289 } 290 ev, err = model.APIVersionParseJobEvent(apiversion, eventdata) 291 if err != nil { 292 return nil, err 293 } 294 events = append(events, ev) 295 } 296 if err = rows.Err(); err != nil { 297 return events, err 298 } 299 return events, nil 300 } 301 302 func (d *GenericSQLDatastore) GetJobEvents(ctx context.Context, id string) ([]model.JobEvent, error) { 303 d.mtx.RLock() 304 defer d.mtx.RUnlock() 305 return getJobEvents(d.db, ctx, id) 306 } 307 308 func getJobLocalEvents(db SQLClient, ctx context.Context, id string) ([]model.JobLocalEvent, error) { 309 var args []interface{} 310 args = append(args, id) 311 312 rows, err := db.QueryContext(ctx, ` 313 select 314 apiversion, 315 eventdata 316 from 317 local_event 318 where 319 job_id = $1 320 order by 321 created asc 322 `, args...) 323 if err != nil { 324 return nil, err 325 } 326 defer rows.Close() 327 var events []model.JobLocalEvent 328 for rows.Next() { 329 var apiversion string 330 var eventdata string 331 var ev model.JobLocalEvent 332 if err = rows.Scan(&apiversion, &eventdata); err != nil { 333 return events, err 334 } 335 ev, err = model.APIVersionParseJobLocalEvent(apiversion, eventdata) 336 if err != nil { 337 return nil, err 338 } 339 events = append(events, ev) 340 } 341 if err = rows.Err(); err != nil { 342 return events, err 343 } 344 return events, nil 345 } 346 347 func (d *GenericSQLDatastore) GetJobLocalEvents(ctx context.Context, id string) ([]model.JobLocalEvent, error) { 348 d.mtx.RLock() 349 defer d.mtx.RUnlock() 350 return getJobLocalEvents(d.db, ctx, id) 351 } 352 353 func (d *GenericSQLDatastore) HasLocalEvent(ctx context.Context, jobID string, eventFilter localdb.LocalEventFilter) (bool, error) { 354 jobLocalEvents, err := d.GetJobLocalEvents(ctx, jobID) 355 if err != nil { 356 return false, err 357 } 358 hasEvent := false 359 for _, localEvent := range jobLocalEvents { 360 if eventFilter(localEvent) { 361 hasEvent = true 362 break 363 } 364 } 365 return hasEvent, nil 366 } 367 368 func (d *GenericSQLDatastore) AddJob(ctx context.Context, j *model.Job) error { 369 d.mtx.Lock() 370 defer d.mtx.Unlock() 371 372 tx, err := d.db.Begin() 373 if err != nil { 374 return err 375 } 376 //nolint:errcheck 377 defer tx.Rollback() 378 379 sqlStatement := ` 380 INSERT INTO job (id, created, executor, clientid, apiversion, jobdata) 381 VALUES ($1, $2, $3, $4, $5, $6)` 382 jobData, err := json.Marshal(j) 383 if err != nil { 384 return err 385 } 386 _, err = tx.ExecContext( 387 ctx, 388 sqlStatement, 389 j.Metadata.ID, 390 j.Metadata.CreatedAt.UTC().Format(time.RFC3339), 391 j.Spec.Engine.String(), 392 j.Metadata.ClientID, 393 model.APIVersionLatest().String(), 394 string(jobData), 395 ) 396 if err != nil { 397 return err 398 } 399 for _, annotation := range j.Spec.Annotations { 400 sqlStatement := ` 401 INSERT INTO job_annotation (job_id, annotation) 402 VALUES ($1, $2)` 403 _, err = tx.ExecContext( 404 ctx, 405 sqlStatement, 406 j.Metadata.ID, 407 annotation, 408 ) 409 if err != nil { 410 return err 411 } 412 } 413 return tx.Commit() 414 } 415 416 func (d *GenericSQLDatastore) AddEvent(ctx context.Context, jobID string, ev model.JobEvent) error { 417 d.mtx.Lock() 418 defer d.mtx.Unlock() 419 //nolint:ineffassign,staticcheck 420 sqlStatement := ` 421 INSERT INTO job_event (job_id, created, apiversion, eventdata) 422 VALUES ($1, $2, $3, $4)` 423 eventData, err := json.Marshal(ev) 424 if err != nil { 425 return err 426 } 427 _, err = d.db.ExecContext( 428 ctx, 429 sqlStatement, 430 jobID, 431 ev.EventTime.UTC().Format(time.RFC3339), 432 model.APIVersionLatest().String(), 433 string(eventData), 434 ) 435 if err != nil { 436 return err 437 } 438 return nil 439 } 440 441 func (d *GenericSQLDatastore) AddLocalEvent(ctx context.Context, jobID string, ev model.JobLocalEvent) error { 442 d.mtx.Lock() 443 defer d.mtx.Unlock() 444 //nolint:ineffassign,staticcheck 445 sqlStatement := ` 446 INSERT INTO local_event (job_id, created, apiversion, eventdata) 447 VALUES ($1, $2, $3, $4)` 448 eventData, err := json.Marshal(ev) 449 if err != nil { 450 return err 451 } 452 _, err = d.db.ExecContext( 453 ctx, 454 sqlStatement, 455 jobID, 456 time.Now().UTC().Format(time.RFC3339), 457 model.APIVersionLatest().String(), 458 string(eventData), 459 ) 460 if err != nil { 461 return err 462 } 463 return nil 464 } 465 466 func (d *GenericSQLDatastore) UpdateJobDeal(ctx context.Context, jobID string, deal model.Deal) error { 467 d.mtx.Lock() 468 defer d.mtx.Unlock() 469 //nolint:ineffassign,staticcheck 470 tx, err := d.db.Begin() 471 if err != nil { 472 return err 473 } 474 //nolint:errcheck 475 defer tx.Rollback() 476 477 job, err := getJob(tx, ctx, jobID) 478 if err != nil { 479 return err 480 } 481 job.Spec.Deal = deal 482 sqlStatement := `UPDATE JOB SET jobdata = $1, apiversion = $2 WHERE id = $3` 483 jobData, err := json.Marshal(job) 484 if err != nil { 485 return err 486 } 487 _, err = tx.ExecContext( 488 ctx, 489 sqlStatement, 490 string(jobData), 491 model.APIVersionLatest().String(), 492 jobID, 493 ) 494 if err != nil { 495 return err 496 } 497 return tx.Commit() 498 } 499 500 func getJobState(db SQLClient, ctx context.Context, jobID string) (model.JobState, error) { 501 var apiversion string 502 var statedata string 503 row := db.QueryRowContext(ctx, "select apiversion, statedata from job where id = $1 limit 1", jobID) 504 505 err := row.Scan(&apiversion, &statedata) 506 if err != nil { 507 if err == sql.ErrNoRows { 508 return model.JobState{}, fmt.Errorf("job not found: %s %s", jobID, err.Error()) 509 } else { 510 return model.JobState{}, err 511 } 512 } 513 if statedata == "" { 514 return model.JobState{ 515 Nodes: map[string]model.JobNodeState{}, 516 }, nil 517 } else { 518 return model.APIVersionParseJobState(apiversion, statedata) 519 } 520 } 521 522 func (d *GenericSQLDatastore) GetJobState(ctx context.Context, jobID string) (model.JobState, error) { 523 d.mtx.RLock() 524 defer d.mtx.RUnlock() 525 return getJobState(d.db, ctx, jobID) 526 } 527 528 func (d *GenericSQLDatastore) UpdateShardState( 529 ctx context.Context, 530 jobID, nodeID string, 531 shardIndex int, 532 update model.JobShardState, 533 ) error { 534 d.mtx.Lock() 535 defer d.mtx.Unlock() 536 tx, err := d.db.Begin() 537 if err != nil { 538 return err 539 } 540 //nolint:errcheck 541 defer tx.Rollback() 542 543 state, err := getJobState(tx, ctx, jobID) 544 if err != nil { 545 return err 546 } 547 err = UpdateShardState(nodeID, shardIndex, &state, update) 548 if err != nil { 549 return err 550 } 551 sqlStatement := `UPDATE JOB SET statedata = $1, apiversion = $2 WHERE id = $3` 552 stateData, err := json.Marshal(state) 553 if err != nil { 554 return err 555 } 556 _, err = tx.ExecContext( 557 ctx, 558 sqlStatement, 559 string(stateData), 560 model.APIVersionLatest().String(), 561 jobID, 562 ) 563 if err != nil { 564 return err 565 } 566 return tx.Commit() 567 } 568 569 //go:embed migrations/*.sql 570 var fs embed.FS 571 572 func (d *GenericSQLDatastore) GetMigrations() (*migrate.Migrate, error) { 573 files, err := iofs.New(fs, "migrations") 574 if err != nil { 575 return nil, err 576 } 577 migrations, err := migrate.NewWithSourceInstance("iofs", files, d.connectionString) 578 if err != nil { 579 return nil, err 580 } 581 return migrations, nil 582 } 583 584 func (d *GenericSQLDatastore) MigrateUp() error { 585 migrations, err := d.GetMigrations() 586 if err != nil { 587 return err 588 } 589 err = migrations.Up() 590 if err != migrate.ErrNoChange { 591 return err 592 } 593 return nil 594 } 595 596 func (d *GenericSQLDatastore) MigrateDown() error { 597 migrations, err := d.GetMigrations() 598 if err != nil { 599 return err 600 } 601 err = migrations.Down() 602 if err != migrate.ErrNoChange { 603 return err 604 } 605 return nil 606 } 607 608 // Static check to ensure that Transport implements Transport: 609 var _ localdb.LocalDB = (*GenericSQLDatastore)(nil)