github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/localdb/shared/genericSQL.go (about)

     1  package shared
     2  
     3  import (
     4  	"context"
     5  	"embed"
     6  	"encoding/json"
     7  	"fmt"
     8  	"strings"
     9  	"time"
    10  
    11  	"database/sql"
    12  
    13  	sync "github.com/bacalhau-project/golang-mutex-tracer"
    14  	"github.com/filecoin-project/bacalhau/pkg/bacerrors"
    15  	"github.com/filecoin-project/bacalhau/pkg/localdb"
    16  	model "github.com/filecoin-project/bacalhau/pkg/model/v1beta1"
    17  	"github.com/golang-migrate/migrate/v4"
    18  	"github.com/golang-migrate/migrate/v4/source/iofs"
    19  )
    20  
    21  // SQLClient is so we can pass *sql.DB and *sql.Tx to the same functions
    22  type SQLClient interface {
    23  	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
    24  	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
    25  	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
    26  }
    27  
// GenericSQLDatastore stores jobs, job events and local events in a SQL
// database reached through database/sql.
type GenericSQLDatastore struct {
	// mtx is taken by the datastore's methods around their database work:
	// read lock for queries, write lock for inserts and updates.
	mtx sync.RWMutex
	// connectionString is handed to the migration library in GetMigrations.
	connectionString string
	// db is the handle all queries and transactions are issued on.
	db *sql.DB
}
    33  
    34  func NewGenericSQLDatastore(
    35  	db *sql.DB,
    36  	name string,
    37  	connectionString string,
    38  ) (*GenericSQLDatastore, error) {
    39  	datastore := &GenericSQLDatastore{
    40  		connectionString: connectionString,
    41  		db:               db,
    42  	}
    43  	datastore.mtx.EnableTracerWithOpts(sync.Opts{
    44  		Threshold: 10 * time.Millisecond,
    45  		Id:        fmt.Sprintf("GenericSQLDatastore[%s].mtx", name),
    46  	})
    47  	return datastore, nil
    48  }
    49  
// GetDB returns the underlying *sql.DB handle. The datastore's mutex is not
// taken, so anything done with the handle is not coordinated with the
// datastore's own methods.
func (d *GenericSQLDatastore) GetDB() *sql.DB {
	return d.db
}
    53  
    54  func getJob(db SQLClient, ctx context.Context, id string) (*model.Job, error) {
    55  	var apiversion string
    56  	var jobdata string
    57  	var statedata string
    58  	row := db.QueryRowContext(ctx, `select apiversion, jobdata, statedata from job where id like $1 || '%'`, strings.ToLower(id))
    59  	err := row.Scan(&apiversion, &jobdata, &statedata)
    60  	if err != nil {
    61  		if err == sql.ErrNoRows {
    62  			return nil, bacerrors.NewJobNotFound(id)
    63  		} else {
    64  			return nil, err
    65  		}
    66  	}
    67  	job, err := model.APIVersionParseJob(apiversion, jobdata)
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  	state, err := model.APIVersionParseJobState(apiversion, statedata)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	job.Status.State = state
    76  	return &job, nil
    77  }
    78  
// GetJob returns the job found by getJob for id, holding the datastore's read
// lock for the duration of the lookup.
func (d *GenericSQLDatastore) GetJob(ctx context.Context, id string) (*model.Job, error) {
	d.mtx.RLock()
	defer d.mtx.RUnlock()
	return getJob(d.db, ctx, id)
}
    84  
    85  func getJobsSQL(
    86  	query localdb.JobQuery,
    87  	countMode bool,
    88  ) (string, []interface{}, error) {
    89  	var args []interface{}
    90  	clauses := []string{}
    91  
    92  	queryCounter := 0
    93  	getQueryCounter := func() string {
    94  		queryCounter++
    95  		return fmt.Sprintf("$%d", queryCounter)
    96  	}
    97  
    98  	handleTag := func(annotation string, include bool) {
    99  		appendQuery := " < 1"
   100  		if include {
   101  			appendQuery = " > 0"
   102  		}
   103  		clauses = append(clauses, fmt.Sprintf(`
   104  		(
   105  			select count(*) from job_annotation
   106  			where job_annotation.annotation = %s
   107  			and job_annotation.job_id = job.id
   108  		) %s
   109  		`, getQueryCounter(), appendQuery))
   110  		args = append(args, annotation)
   111  	}
   112  
   113  	for _, annotation := range query.IncludeTags {
   114  		handleTag(string(annotation), true)
   115  	}
   116  
   117  	for _, annotation := range query.ExcludeTags {
   118  		handleTag(string(annotation), false)
   119  	}
   120  
   121  	if query.ClientID != "" {
   122  		clauses = append(clauses, fmt.Sprintf("job.clientid = %s", getQueryCounter()))
   123  		args = append(args, query.ClientID)
   124  	}
   125  
   126  	after := ""
   127  
   128  	applyOrdering := func(field string) {
   129  		order := "asc"
   130  		if query.SortReverse {
   131  			order = "desc"
   132  		}
   133  		after = after + " order by " + field + " " + order
   134  	}
   135  
   136  	if query.SortBy == "created_at" {
   137  		applyOrdering("created")
   138  	} else if query.SortBy == "id" {
   139  		applyOrdering("id")
   140  	} else if query.SortBy != "" {
   141  		return "", nil, fmt.Errorf("invalid sort_by: %s", query.SortBy)
   142  	}
   143  
   144  	if query.Limit > 0 {
   145  		after = after + fmt.Sprintf(" limit %d", query.Limit)
   146  	}
   147  
   148  	if query.Offset > 0 {
   149  		after = after + fmt.Sprintf(" offset %d", query.Offset)
   150  	}
   151  
   152  	where := strings.Join(clauses, " and ")
   153  
   154  	if where != "" {
   155  		where = "where " + where
   156  	}
   157  
   158  	sql := fmt.Sprintf(`
   159  select
   160  	apiversion,
   161  	jobdata,
   162  	statedata
   163  from
   164  	job
   165  %s
   166  %s
   167  `, where, after)
   168  
   169  	if countMode {
   170  		sql = fmt.Sprintf(`
   171  select
   172  	count(job.id) as count
   173  from
   174  	job
   175  %s
   176  %s
   177  `, where, after)
   178  	}
   179  
   180  	return sql, args, nil
   181  }
   182  
   183  func getJobs(db SQLClient, ctx context.Context, query localdb.JobQuery) ([]*model.Job, error) {
   184  	if query.ID != "" {
   185  		job, err := getJob(db, ctx, query.ID)
   186  		if err != nil {
   187  			return nil, err
   188  		}
   189  		return []*model.Job{job}, nil
   190  	}
   191  
   192  	sql, args, err := getJobsSQL(query, false)
   193  	if err != nil {
   194  		return nil, err
   195  	}
   196  
   197  	rows, err := db.QueryContext(ctx, sql, args...)
   198  	if err != nil {
   199  		return nil, err
   200  	}
   201  	defer rows.Close()
   202  	jobs := []*model.Job{}
   203  	for rows.Next() {
   204  		var innerErr error
   205  		var apiversion string
   206  		var jobdata string
   207  		var statedata string
   208  		var job model.Job
   209  		if innerErr = rows.Scan(&apiversion, &jobdata, &statedata); err != nil {
   210  			return jobs, innerErr
   211  		}
   212  		job, innerErr = model.APIVersionParseJob(apiversion, jobdata)
   213  		if innerErr != nil {
   214  			return nil, err
   215  		}
   216  		state, innerErr := model.APIVersionParseJobState(apiversion, statedata)
   217  		if innerErr != nil {
   218  			return nil, err
   219  		}
   220  		job.Status.State = state
   221  		jobs = append(jobs, &job)
   222  	}
   223  	if err = rows.Err(); err != nil {
   224  		return jobs, err
   225  	}
   226  	return jobs, nil
   227  }
   228  
// GetJobs returns the jobs matching query by calling getJobs against the
// datastore's database while holding the read lock.
func (d *GenericSQLDatastore) GetJobs(ctx context.Context, query localdb.JobQuery) ([]*model.Job, error) {
	d.mtx.RLock()
	defer d.mtx.RUnlock()
	return getJobs(d.db, ctx, query)
}
   234  
   235  func (d *GenericSQLDatastore) GetJobsCount(ctx context.Context, query localdb.JobQuery) (int, error) {
   236  	if query.ID != "" {
   237  		_, err := getJob(d.db, ctx, query.ID)
   238  		if err != nil {
   239  			return 0, err
   240  		}
   241  		return 1, nil
   242  	}
   243  
   244  	useQuery := query
   245  	useQuery.Limit = 0
   246  	useQuery.Offset = 0
   247  	useQuery.SortBy = ""
   248  
   249  	sqlQuery, args, err := getJobsSQL(useQuery, true)
   250  	if err != nil {
   251  		return 0, err
   252  	}
   253  
   254  	var count int
   255  	row := d.db.QueryRow(sqlQuery, args...)
   256  	err = row.Scan(&count)
   257  	if err != nil {
   258  		return 0, err
   259  	}
   260  	return count, nil
   261  }
   262  
   263  func getJobEvents(db SQLClient, ctx context.Context, id string) ([]model.JobEvent, error) {
   264  	var args []interface{}
   265  	args = append(args, id)
   266  
   267  	rows, err := db.QueryContext(ctx, `
   268  select
   269  	apiversion,
   270  	eventdata
   271  from
   272  	job_event
   273  where
   274  	job_id = $1
   275  order by
   276  	created asc
   277  `, args...)
   278  	if err != nil {
   279  		return nil, err
   280  	}
   281  	defer rows.Close()
   282  	var events []model.JobEvent
   283  	for rows.Next() {
   284  		var apiversion string
   285  		var eventdata string
   286  		var ev model.JobEvent
   287  		if err = rows.Scan(&apiversion, &eventdata); err != nil {
   288  			return events, err
   289  		}
   290  		ev, err = model.APIVersionParseJobEvent(apiversion, eventdata)
   291  		if err != nil {
   292  			return nil, err
   293  		}
   294  		events = append(events, ev)
   295  	}
   296  	if err = rows.Err(); err != nil {
   297  		return events, err
   298  	}
   299  	return events, nil
   300  }
   301  
// GetJobEvents returns the events recorded for the job with the given id by
// calling getJobEvents while holding the read lock.
func (d *GenericSQLDatastore) GetJobEvents(ctx context.Context, id string) ([]model.JobEvent, error) {
	d.mtx.RLock()
	defer d.mtx.RUnlock()
	return getJobEvents(d.db, ctx, id)
}
   307  
   308  func getJobLocalEvents(db SQLClient, ctx context.Context, id string) ([]model.JobLocalEvent, error) {
   309  	var args []interface{}
   310  	args = append(args, id)
   311  
   312  	rows, err := db.QueryContext(ctx, `
   313  select
   314  	apiversion,
   315  	eventdata
   316  from
   317  	local_event
   318  where
   319  	job_id = $1
   320  order by
   321  	created asc
   322  `, args...)
   323  	if err != nil {
   324  		return nil, err
   325  	}
   326  	defer rows.Close()
   327  	var events []model.JobLocalEvent
   328  	for rows.Next() {
   329  		var apiversion string
   330  		var eventdata string
   331  		var ev model.JobLocalEvent
   332  		if err = rows.Scan(&apiversion, &eventdata); err != nil {
   333  			return events, err
   334  		}
   335  		ev, err = model.APIVersionParseJobLocalEvent(apiversion, eventdata)
   336  		if err != nil {
   337  			return nil, err
   338  		}
   339  		events = append(events, ev)
   340  	}
   341  	if err = rows.Err(); err != nil {
   342  		return events, err
   343  	}
   344  	return events, nil
   345  }
   346  
// GetJobLocalEvents returns the local events recorded for the job with the
// given id by calling getJobLocalEvents while holding the read lock.
func (d *GenericSQLDatastore) GetJobLocalEvents(ctx context.Context, id string) ([]model.JobLocalEvent, error) {
	d.mtx.RLock()
	defer d.mtx.RUnlock()
	return getJobLocalEvents(d.db, ctx, id)
}
   352  
   353  func (d *GenericSQLDatastore) HasLocalEvent(ctx context.Context, jobID string, eventFilter localdb.LocalEventFilter) (bool, error) {
   354  	jobLocalEvents, err := d.GetJobLocalEvents(ctx, jobID)
   355  	if err != nil {
   356  		return false, err
   357  	}
   358  	hasEvent := false
   359  	for _, localEvent := range jobLocalEvents {
   360  		if eventFilter(localEvent) {
   361  			hasEvent = true
   362  			break
   363  		}
   364  	}
   365  	return hasEvent, nil
   366  }
   367  
   368  func (d *GenericSQLDatastore) AddJob(ctx context.Context, j *model.Job) error {
   369  	d.mtx.Lock()
   370  	defer d.mtx.Unlock()
   371  
   372  	tx, err := d.db.Begin()
   373  	if err != nil {
   374  		return err
   375  	}
   376  	//nolint:errcheck
   377  	defer tx.Rollback()
   378  
   379  	sqlStatement := `
   380  INSERT INTO job (id, created, executor, clientid, apiversion, jobdata)
   381  VALUES ($1, $2, $3, $4, $5, $6)`
   382  	jobData, err := json.Marshal(j)
   383  	if err != nil {
   384  		return err
   385  	}
   386  	_, err = tx.ExecContext(
   387  		ctx,
   388  		sqlStatement,
   389  		j.Metadata.ID,
   390  		j.Metadata.CreatedAt.UTC().Format(time.RFC3339),
   391  		j.Spec.Engine.String(),
   392  		j.Metadata.ClientID,
   393  		model.APIVersionLatest().String(),
   394  		string(jobData),
   395  	)
   396  	if err != nil {
   397  		return err
   398  	}
   399  	for _, annotation := range j.Spec.Annotations {
   400  		sqlStatement := `
   401  INSERT INTO job_annotation (job_id, annotation)
   402  VALUES ($1, $2)`
   403  		_, err = tx.ExecContext(
   404  			ctx,
   405  			sqlStatement,
   406  			j.Metadata.ID,
   407  			annotation,
   408  		)
   409  		if err != nil {
   410  			return err
   411  		}
   412  	}
   413  	return tx.Commit()
   414  }
   415  
   416  func (d *GenericSQLDatastore) AddEvent(ctx context.Context, jobID string, ev model.JobEvent) error {
   417  	d.mtx.Lock()
   418  	defer d.mtx.Unlock()
   419  	//nolint:ineffassign,staticcheck
   420  	sqlStatement := `
   421  INSERT INTO job_event (job_id, created, apiversion, eventdata)
   422  VALUES ($1, $2, $3, $4)`
   423  	eventData, err := json.Marshal(ev)
   424  	if err != nil {
   425  		return err
   426  	}
   427  	_, err = d.db.ExecContext(
   428  		ctx,
   429  		sqlStatement,
   430  		jobID,
   431  		ev.EventTime.UTC().Format(time.RFC3339),
   432  		model.APIVersionLatest().String(),
   433  		string(eventData),
   434  	)
   435  	if err != nil {
   436  		return err
   437  	}
   438  	return nil
   439  }
   440  
   441  func (d *GenericSQLDatastore) AddLocalEvent(ctx context.Context, jobID string, ev model.JobLocalEvent) error {
   442  	d.mtx.Lock()
   443  	defer d.mtx.Unlock()
   444  	//nolint:ineffassign,staticcheck
   445  	sqlStatement := `
   446  INSERT INTO local_event (job_id, created, apiversion, eventdata)
   447  VALUES ($1, $2, $3, $4)`
   448  	eventData, err := json.Marshal(ev)
   449  	if err != nil {
   450  		return err
   451  	}
   452  	_, err = d.db.ExecContext(
   453  		ctx,
   454  		sqlStatement,
   455  		jobID,
   456  		time.Now().UTC().Format(time.RFC3339),
   457  		model.APIVersionLatest().String(),
   458  		string(eventData),
   459  	)
   460  	if err != nil {
   461  		return err
   462  	}
   463  	return nil
   464  }
   465  
   466  func (d *GenericSQLDatastore) UpdateJobDeal(ctx context.Context, jobID string, deal model.Deal) error {
   467  	d.mtx.Lock()
   468  	defer d.mtx.Unlock()
   469  	//nolint:ineffassign,staticcheck
   470  	tx, err := d.db.Begin()
   471  	if err != nil {
   472  		return err
   473  	}
   474  	//nolint:errcheck
   475  	defer tx.Rollback()
   476  
   477  	job, err := getJob(tx, ctx, jobID)
   478  	if err != nil {
   479  		return err
   480  	}
   481  	job.Spec.Deal = deal
   482  	sqlStatement := `UPDATE JOB SET jobdata = $1, apiversion = $2 WHERE id = $3`
   483  	jobData, err := json.Marshal(job)
   484  	if err != nil {
   485  		return err
   486  	}
   487  	_, err = tx.ExecContext(
   488  		ctx,
   489  		sqlStatement,
   490  		string(jobData),
   491  		model.APIVersionLatest().String(),
   492  		jobID,
   493  	)
   494  	if err != nil {
   495  		return err
   496  	}
   497  	return tx.Commit()
   498  }
   499  
   500  func getJobState(db SQLClient, ctx context.Context, jobID string) (model.JobState, error) {
   501  	var apiversion string
   502  	var statedata string
   503  	row := db.QueryRowContext(ctx, "select apiversion, statedata from job where id = $1 limit 1", jobID)
   504  
   505  	err := row.Scan(&apiversion, &statedata)
   506  	if err != nil {
   507  		if err == sql.ErrNoRows {
   508  			return model.JobState{}, fmt.Errorf("job not found: %s %s", jobID, err.Error())
   509  		} else {
   510  			return model.JobState{}, err
   511  		}
   512  	}
   513  	if statedata == "" {
   514  		return model.JobState{
   515  			Nodes: map[string]model.JobNodeState{},
   516  		}, nil
   517  	} else {
   518  		return model.APIVersionParseJobState(apiversion, statedata)
   519  	}
   520  }
   521  
// GetJobState returns the state of the job with the given jobID by calling
// getJobState while holding the read lock.
func (d *GenericSQLDatastore) GetJobState(ctx context.Context, jobID string) (model.JobState, error) {
	d.mtx.RLock()
	defer d.mtx.RUnlock()
	return getJobState(d.db, ctx, jobID)
}
   527  
   528  func (d *GenericSQLDatastore) UpdateShardState(
   529  	ctx context.Context,
   530  	jobID, nodeID string,
   531  	shardIndex int,
   532  	update model.JobShardState,
   533  ) error {
   534  	d.mtx.Lock()
   535  	defer d.mtx.Unlock()
   536  	tx, err := d.db.Begin()
   537  	if err != nil {
   538  		return err
   539  	}
   540  	//nolint:errcheck
   541  	defer tx.Rollback()
   542  
   543  	state, err := getJobState(tx, ctx, jobID)
   544  	if err != nil {
   545  		return err
   546  	}
   547  	err = UpdateShardState(nodeID, shardIndex, &state, update)
   548  	if err != nil {
   549  		return err
   550  	}
   551  	sqlStatement := `UPDATE JOB SET statedata = $1, apiversion = $2 WHERE id = $3`
   552  	stateData, err := json.Marshal(state)
   553  	if err != nil {
   554  		return err
   555  	}
   556  	_, err = tx.ExecContext(
   557  		ctx,
   558  		sqlStatement,
   559  		string(stateData),
   560  		model.APIVersionLatest().String(),
   561  		jobID,
   562  	)
   563  	if err != nil {
   564  		return err
   565  	}
   566  	return tx.Commit()
   567  }
   568  
// fs holds the *.sql files of the migrations directory, embedded at build
// time; GetMigrations reads the migrations from it.
//
//go:embed migrations/*.sql
var fs embed.FS
   571  
   572  func (d *GenericSQLDatastore) GetMigrations() (*migrate.Migrate, error) {
   573  	files, err := iofs.New(fs, "migrations")
   574  	if err != nil {
   575  		return nil, err
   576  	}
   577  	migrations, err := migrate.NewWithSourceInstance("iofs", files, d.connectionString)
   578  	if err != nil {
   579  		return nil, err
   580  	}
   581  	return migrations, nil
   582  }
   583  
   584  func (d *GenericSQLDatastore) MigrateUp() error {
   585  	migrations, err := d.GetMigrations()
   586  	if err != nil {
   587  		return err
   588  	}
   589  	err = migrations.Up()
   590  	if err != migrate.ErrNoChange {
   591  		return err
   592  	}
   593  	return nil
   594  }
   595  
   596  func (d *GenericSQLDatastore) MigrateDown() error {
   597  	migrations, err := d.GetMigrations()
   598  	if err != nil {
   599  		return err
   600  	}
   601  	err = migrations.Down()
   602  	if err != migrate.ErrNoChange {
   603  		return err
   604  	}
   605  	return nil
   606  }
   607  
// Static check to ensure that GenericSQLDatastore implements localdb.LocalDB:
   609  var _ localdb.LocalDB = (*GenericSQLDatastore)(nil)