github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/syz-cluster/pkg/db/series_repo.go (about)

     1  // Copyright 2024 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package db
     5  
     6  // TODO: split off some SeriesPatchesRepository.
     7  
     8  import (
     9  	"context"
    10  	"errors"
    11  	"fmt"
    12  	"sync"
    13  	"time"
    14  
    15  	"cloud.google.com/go/spanner"
    16  	"github.com/google/uuid"
    17  	"google.golang.org/api/iterator"
    18  )
    19  
// SeriesRepository provides access to the Series table and, for now, to the related Patches table.
// The generic per-entity operations are embedded via genericEntityOps, the rest of the queries
// are issued directly through client.
type SeriesRepository struct {
	client *spanner.Client
	*genericEntityOps[Series, string]
}
    24  
    25  func NewSeriesRepository(client *spanner.Client) *SeriesRepository {
    26  	return &SeriesRepository{
    27  		client: client,
    28  		genericEntityOps: &genericEntityOps[Series, string]{
    29  			client:   client,
    30  			keyField: "ID",
    31  			table:    "Series",
    32  		},
    33  	}
    34  }
    35  
    36  // TODO: move to SeriesPatchesRepository?
    37  // nolint:dupl
    38  func (repo *SeriesRepository) PatchByID(ctx context.Context, id string) (*Patch, error) {
    39  	return readEntity[Patch](ctx, repo.client.Single(), spanner.Statement{
    40  		SQL:    "SELECT * FROM Patches WHERE ID=@id",
    41  		Params: map[string]interface{}{"id": id},
    42  	})
    43  }
    44  
    45  // nolint:dupl
    46  func (repo *SeriesRepository) GetByExtID(ctx context.Context, extID string) (*Series, error) {
    47  	return readEntity[Series](ctx, repo.client.Single(), spanner.Statement{
    48  		SQL:    "SELECT * FROM Series WHERE ExtID=@extID",
    49  		Params: map[string]interface{}{"extID": extID},
    50  	})
    51  }
    52  
// ErrSeriesExists is returned by Insert() when a series with the same ExtID is already stored.
var ErrSeriesExists = errors.New("the series already exists")
    54  
    55  // Insert() checks whether there already exists a series with the same ExtID.
    56  // Since Patch content is stored elsewhere, we do not demand it be filled out before calling Insert().
    57  // Instead, Insert() obtains this data via a callback.
    58  func (repo *SeriesRepository) Insert(ctx context.Context, series *Series,
    59  	queryPatches func() ([]*Patch, error)) error {
    60  	var patches []*Patch
    61  	var patchesErr error
    62  	var patchesOnce sync.Once
    63  	doQueryPatches := func() {
    64  		if queryPatches == nil {
    65  			return
    66  		}
    67  		patches, patchesErr = queryPatches()
    68  	}
    69  	if series.ID == "" {
    70  		series.ID = uuid.NewString()
    71  	}
    72  	_, err := repo.client.ReadWriteTransaction(ctx,
    73  		func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
    74  			// Check if the series already exists.
    75  			stmt := spanner.Statement{
    76  				SQL:    "SELECT 1 from `Series` WHERE `ExtID`=@extID",
    77  				Params: map[string]interface{}{"ExtID": series.ExtID},
    78  			}
    79  			iter := txn.Query(ctx, stmt)
    80  			defer iter.Stop()
    81  
    82  			_, iterErr := iter.Next()
    83  			if iterErr == nil {
    84  				return ErrSeriesExists
    85  			} else if iterErr != iterator.Done {
    86  				return iterErr
    87  			}
    88  			// Query patches (once).
    89  			patchesOnce.Do(doQueryPatches)
    90  			if patchesErr != nil {
    91  				return patchesErr
    92  			}
    93  			// Save the objects.
    94  			var stmts []*spanner.Mutation
    95  			seriesStmt, err := spanner.InsertStruct("Series", series)
    96  			if err != nil {
    97  				return err
    98  			}
    99  			stmts = append(stmts, seriesStmt)
   100  			for _, patch := range patches {
   101  				patch.ID = uuid.NewString()
   102  				patch.SeriesID = series.ID
   103  				stmt, err := spanner.InsertStruct("Patches", patch)
   104  				if err != nil {
   105  					return err
   106  				}
   107  				stmts = append(stmts, stmt)
   108  			}
   109  			return txn.BufferWrite(stmts)
   110  		})
   111  	return err
   112  }
   113  
   114  func (repo *SeriesRepository) Count(ctx context.Context) (int, error) {
   115  	stmt := spanner.Statement{SQL: "SELECT COUNT(*) FROM `Series`"}
   116  	var count int64
   117  	err := repo.client.Single().Query(ctx, stmt).Do(func(row *spanner.Row) error {
   118  		return row.Column(0, &count)
   119  	})
   120  	return int(count), err
   121  }
   122  
// SeriesWithSession is an element of the list returned by ListLatest():
// a series together with the extra data queried for it.
type SeriesWithSession struct {
	Series   *Series
	Session  *Session // The session Series.LatestSessionID refers to; nil if there is none.
	Findings int      // Non-invalidated findings of Session; only queried for finished sessions.
}
   128  
// SeriesFilter restricts the set of series returned by ListLatest().
type SeriesFilter struct {
	Cc           string        // If not empty, only the series that have this value in Cc.
	Status       SessionStatus // If not SessionStatusAny, only the series with a session in this state.
	WithFindings bool          // Only the series whose latest session has non-invalidated findings.
	Limit        int           // If positive, the maximum number of series to return.
	Offset       int           // If positive, the number of matching series to skip.
}
   136  
   137  // ListLatest() returns the list of series ordered by the decreasing PublishedAt value.
   138  func (repo *SeriesRepository) ListLatest(ctx context.Context, filter SeriesFilter,
   139  	maxPublishedAt time.Time) ([]*SeriesWithSession, error) {
   140  	ro := repo.client.ReadOnlyTransaction()
   141  	defer ro.Close()
   142  
   143  	stmt := spanner.Statement{
   144  		SQL:    "SELECT Series.* FROM Series WHERE 1=1",
   145  		Params: map[string]interface{}{},
   146  	}
   147  	if !maxPublishedAt.IsZero() {
   148  		stmt.SQL += " AND PublishedAt < @toTime"
   149  		stmt.Params["toTime"] = maxPublishedAt
   150  	}
   151  	if filter.Cc != "" {
   152  		stmt.SQL += " AND @cc IN UNNEST(Cc)"
   153  		stmt.Params["cc"] = filter.Cc
   154  	}
   155  	if filter.Status != SessionStatusAny {
   156  		// It could have been an INNER JOIN in the main query, but let's favor the simpler code
   157  		// in this function.
   158  		// The optimizer should transform the query to a JOIN anyway.
   159  		stmt.SQL += " AND EXISTS(SELECT 1 FROM Sessions WHERE"
   160  		switch filter.Status {
   161  		case SessionStatusWaiting:
   162  			stmt.SQL += " Sessions.SeriesID = Series.ID AND Sessions.StartedAt IS NULL"
   163  		case SessionStatusInProgress:
   164  			stmt.SQL += " Sessions.ID = Series.LatestSessionID AND Sessions.FinishedAt IS NULL"
   165  		case SessionStatusFinished:
   166  			stmt.SQL += " Sessions.ID = Series.LatestSessionID AND Sessions.FinishedAt IS NOT NULL" +
   167  				" AND Sessions.SkipReason IS NULL"
   168  		case SessionStatusSkipped:
   169  			stmt.SQL += " Sessions.ID = Series.LatestSessionID AND Sessions.SkipReason IS NOT NULL"
   170  		default:
   171  			return nil, fmt.Errorf("unknown status value: %q", filter.Status)
   172  		}
   173  		stmt.SQL += ")"
   174  	}
   175  	if filter.WithFindings {
   176  		stmt.SQL += " AND Series.LatestSessionID IS NOT NULL AND EXISTS(" +
   177  			"SELECT 1 FROM Findings WHERE " +
   178  			"Findings.SessionID = Series.LatestSessionID AND Findings.InvalidatedAt IS NULL)"
   179  	}
   180  	stmt.SQL += " ORDER BY PublishedAt DESC, ID"
   181  	if filter.Limit > 0 {
   182  		stmt.SQL += " LIMIT @limit"
   183  		stmt.Params["limit"] = filter.Limit
   184  	}
   185  	if filter.Offset > 0 {
   186  		stmt.SQL += " OFFSET @offset"
   187  		stmt.Params["offset"] = filter.Offset
   188  	}
   189  	seriesList, err := readEntities[Series](ctx, ro, stmt)
   190  	if err != nil {
   191  		return nil, err
   192  	}
   193  
   194  	// Now query Sessions.
   195  	var ret []*SeriesWithSession
   196  	for _, series := range seriesList {
   197  		obj := &SeriesWithSession{Series: series}
   198  		ret = append(ret, obj)
   199  	}
   200  
   201  	// And the rest of the data.
   202  	err = repo.querySessions(ctx, ro, ret)
   203  	if err != nil {
   204  		return nil, fmt.Errorf("failed to query sessions: %w", err)
   205  	}
   206  	err = repo.queryFindingCounts(ctx, ro, ret)
   207  	if err != nil {
   208  		return nil, fmt.Errorf("failed to query finding counts: %w", err)
   209  	}
   210  	return ret, nil
   211  }
   212  
   213  func (repo *SeriesRepository) querySessions(ctx context.Context, ro *spanner.ReadOnlyTransaction,
   214  	seriesList []*SeriesWithSession) error {
   215  	idToSeries := map[string]*SeriesWithSession{}
   216  	var keys []string
   217  	for _, item := range seriesList {
   218  		series := item.Series
   219  		idToSeries[series.ID] = item
   220  		if !series.LatestSessionID.IsNull() {
   221  			keys = append(keys, series.LatestSessionID.String())
   222  		}
   223  	}
   224  	if len(keys) == 0 {
   225  		return nil
   226  	}
   227  	sessions, err := readEntities[Session](ctx, ro, spanner.Statement{
   228  		SQL: "SELECT * FROM Sessions WHERE ID IN UNNEST(@ids)",
   229  		Params: map[string]interface{}{
   230  			"ids": keys,
   231  		},
   232  	})
   233  	if err != nil {
   234  		return err
   235  	}
   236  	for _, session := range sessions {
   237  		obj := idToSeries[session.SeriesID]
   238  		if obj != nil {
   239  			obj.Session = session
   240  		}
   241  	}
   242  	return nil
   243  }
   244  
   245  func (repo *SeriesRepository) queryFindingCounts(ctx context.Context, ro *spanner.ReadOnlyTransaction,
   246  	seriesList []*SeriesWithSession) error {
   247  	var keys []string
   248  	sessionToSeries := map[string]*SeriesWithSession{}
   249  	for _, series := range seriesList {
   250  		if series.Session == nil || series.Session.Status() != SessionStatusFinished {
   251  			continue
   252  		}
   253  		keys = append(keys, series.Session.ID)
   254  		sessionToSeries[series.Session.ID] = series
   255  	}
   256  	if len(keys) == 0 {
   257  		return nil
   258  	}
   259  
   260  	type findingCount struct {
   261  		SessionID string `spanner:"SessionID"`
   262  		Count     int64  `spanner:"Count"`
   263  	}
   264  	list, err := readEntities[findingCount](ctx, repo.client.Single(), spanner.Statement{
   265  		SQL: "SELECT `SessionID`, COUNT(`ID`) as `Count` FROM `Findings` " +
   266  			"WHERE `SessionID` IN UNNEST(@ids) AND `Findings`.`InvalidatedAt` IS NULL " +
   267  			"GROUP BY `SessionID`",
   268  		Params: map[string]interface{}{
   269  			"ids": keys,
   270  		},
   271  	})
   272  	if err != nil {
   273  		return err
   274  	}
   275  	for _, item := range list {
   276  		sessionToSeries[item.SessionID].Findings = int(item.Count)
   277  	}
   278  	return nil
   279  }
   280  
   281  // golint sees too much similarity with SessionRepository's ListForSeries, but in reality there's not.
   282  // nolint:dupl
   283  func (repo *SeriesRepository) ListPatches(ctx context.Context, series *Series) ([]*Patch, error) {
   284  	return readEntities[Patch](ctx, repo.client.Single(), spanner.Statement{
   285  		SQL: "SELECT * FROM `Patches` WHERE `SeriesID` = @seriesID ORDER BY `Seq`",
   286  		Params: map[string]interface{}{
   287  			"seriesID": series.ID,
   288  		},
   289  	})
   290  }