github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/syz-cluster/pkg/db/series_repo.go (about) 1 // Copyright 2024 syzkaller project authors. All rights reserved. 2 // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4 package db 5 6 // TODO: split off some SeriesPatchesRepository. 7 8 import ( 9 "context" 10 "errors" 11 "fmt" 12 "sync" 13 "time" 14 15 "cloud.google.com/go/spanner" 16 "github.com/google/uuid" 17 "google.golang.org/api/iterator" 18 ) 19 20 type SeriesRepository struct { 21 client *spanner.Client 22 *genericEntityOps[Series, string] 23 } 24 25 func NewSeriesRepository(client *spanner.Client) *SeriesRepository { 26 return &SeriesRepository{ 27 client: client, 28 genericEntityOps: &genericEntityOps[Series, string]{ 29 client: client, 30 keyField: "ID", 31 table: "Series", 32 }, 33 } 34 } 35 36 // TODO: move to SeriesPatchesRepository? 37 // nolint:dupl 38 func (repo *SeriesRepository) PatchByID(ctx context.Context, id string) (*Patch, error) { 39 return readEntity[Patch](ctx, repo.client.Single(), spanner.Statement{ 40 SQL: "SELECT * FROM Patches WHERE ID=@id", 41 Params: map[string]interface{}{"id": id}, 42 }) 43 } 44 45 // nolint:dupl 46 func (repo *SeriesRepository) GetByExtID(ctx context.Context, extID string) (*Series, error) { 47 return readEntity[Series](ctx, repo.client.Single(), spanner.Statement{ 48 SQL: "SELECT * FROM Series WHERE ExtID=@extID", 49 Params: map[string]interface{}{"extID": extID}, 50 }) 51 } 52 53 var ErrSeriesExists = errors.New("the series already exists") 54 55 // Insert() checks whether there already exists a series with the same ExtID. 56 // Since Patch content is stored elsewhere, we do not demand it be filled out before calling Insert(). 57 // Instead, Insert() obtains this data via a callback. 58 func (repo *SeriesRepository) Insert(ctx context.Context, series *Series, 59 queryPatches func() ([]*Patch, error)) error { 60 var patches []*Patch 61 var patchesErr error 62 var patchesOnce sync.Once 63 doQueryPatches := func() { 64 if queryPatches == nil { 65 return 66 } 67 patches, patchesErr = queryPatches() 68 } 69 if series.ID == "" { 70 series.ID = uuid.NewString() 71 } 72 _, err := repo.client.ReadWriteTransaction(ctx, 73 func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { 74 // Check if the series already exists. 75 stmt := spanner.Statement{ 76 SQL: "SELECT 1 from `Series` WHERE `ExtID`=@extID", 77 Params: map[string]interface{}{"ExtID": series.ExtID}, 78 } 79 iter := txn.Query(ctx, stmt) 80 defer iter.Stop() 81 82 _, iterErr := iter.Next() 83 if iterErr == nil { 84 return ErrSeriesExists 85 } else if iterErr != iterator.Done { 86 return iterErr 87 } 88 // Query patches (once). 89 patchesOnce.Do(doQueryPatches) 90 if patchesErr != nil { 91 return patchesErr 92 } 93 // Save the objects. 94 var stmts []*spanner.Mutation 95 seriesStmt, err := spanner.InsertStruct("Series", series) 96 if err != nil { 97 return err 98 } 99 stmts = append(stmts, seriesStmt) 100 for _, patch := range patches { 101 patch.ID = uuid.NewString() 102 patch.SeriesID = series.ID 103 stmt, err := spanner.InsertStruct("Patches", patch) 104 if err != nil { 105 return err 106 } 107 stmts = append(stmts, stmt) 108 } 109 return txn.BufferWrite(stmts) 110 }) 111 return err 112 } 113 114 func (repo *SeriesRepository) Count(ctx context.Context) (int, error) { 115 stmt := spanner.Statement{SQL: "SELECT COUNT(*) FROM `Series`"} 116 var count int64 117 err := repo.client.Single().Query(ctx, stmt).Do(func(row *spanner.Row) error { 118 return row.Column(0, &count) 119 }) 120 return int(count), err 121 } 122 123 type SeriesWithSession struct { 124 Series *Series 125 Session *Session 126 Findings int 127 } 128 129 type SeriesFilter struct { 130 Cc string 131 Status SessionStatus 132 WithFindings bool 133 Limit int 134 Offset int 135 } 136 137 // ListLatest() returns the list of series ordered by the decreasing PublishedAt value. 138 func (repo *SeriesRepository) ListLatest(ctx context.Context, filter SeriesFilter, 139 maxPublishedAt time.Time) ([]*SeriesWithSession, error) { 140 ro := repo.client.ReadOnlyTransaction() 141 defer ro.Close() 142 143 stmt := spanner.Statement{ 144 SQL: "SELECT Series.* FROM Series WHERE 1=1", 145 Params: map[string]interface{}{}, 146 } 147 if !maxPublishedAt.IsZero() { 148 stmt.SQL += " AND PublishedAt < @toTime" 149 stmt.Params["toTime"] = maxPublishedAt 150 } 151 if filter.Cc != "" { 152 stmt.SQL += " AND @cc IN UNNEST(Cc)" 153 stmt.Params["cc"] = filter.Cc 154 } 155 if filter.Status != SessionStatusAny { 156 // It could have been an INNER JOIN in the main query, but let's favor the simpler code 157 // in this function. 158 // The optimizer should transform the query to a JOIN anyway. 159 stmt.SQL += " AND EXISTS(SELECT 1 FROM Sessions WHERE" 160 switch filter.Status { 161 case SessionStatusWaiting: 162 stmt.SQL += " Sessions.SeriesID = Series.ID AND Sessions.StartedAt IS NULL" 163 case SessionStatusInProgress: 164 stmt.SQL += " Sessions.ID = Series.LatestSessionID AND Sessions.FinishedAt IS NULL" 165 case SessionStatusFinished: 166 stmt.SQL += " Sessions.ID = Series.LatestSessionID AND Sessions.FinishedAt IS NOT NULL" + 167 " AND Sessions.SkipReason IS NULL" 168 case SessionStatusSkipped: 169 stmt.SQL += " Sessions.ID = Series.LatestSessionID AND Sessions.SkipReason IS NOT NULL" 170 default: 171 return nil, fmt.Errorf("unknown status value: %q", filter.Status) 172 } 173 stmt.SQL += ")" 174 } 175 if filter.WithFindings { 176 stmt.SQL += " AND Series.LatestSessionID IS NOT NULL AND EXISTS(" + 177 "SELECT 1 FROM Findings WHERE " + 178 "Findings.SessionID = Series.LatestSessionID AND Findings.InvalidatedAt IS NULL)" 179 } 180 stmt.SQL += " ORDER BY PublishedAt DESC, ID" 181 if filter.Limit > 0 { 182 stmt.SQL += " LIMIT @limit" 183 stmt.Params["limit"] = filter.Limit 184 } 185 if filter.Offset > 0 { 186 stmt.SQL += " OFFSET @offset" 187 stmt.Params["offset"] = filter.Offset 188 } 189 seriesList, err := readEntities[Series](ctx, ro, stmt) 190 if err != nil { 191 return nil, err 192 } 193 194 // Now query Sessions. 195 var ret []*SeriesWithSession 196 for _, series := range seriesList { 197 obj := &SeriesWithSession{Series: series} 198 ret = append(ret, obj) 199 } 200 201 // And the rest of the data. 202 err = repo.querySessions(ctx, ro, ret) 203 if err != nil { 204 return nil, fmt.Errorf("failed to query sessions: %w", err) 205 } 206 err = repo.queryFindingCounts(ctx, ro, ret) 207 if err != nil { 208 return nil, fmt.Errorf("failed to query finding counts: %w", err) 209 } 210 return ret, nil 211 } 212 213 func (repo *SeriesRepository) querySessions(ctx context.Context, ro *spanner.ReadOnlyTransaction, 214 seriesList []*SeriesWithSession) error { 215 idToSeries := map[string]*SeriesWithSession{} 216 var keys []string 217 for _, item := range seriesList { 218 series := item.Series 219 idToSeries[series.ID] = item 220 if !series.LatestSessionID.IsNull() { 221 keys = append(keys, series.LatestSessionID.String()) 222 } 223 } 224 if len(keys) == 0 { 225 return nil 226 } 227 sessions, err := readEntities[Session](ctx, ro, spanner.Statement{ 228 SQL: "SELECT * FROM Sessions WHERE ID IN UNNEST(@ids)", 229 Params: map[string]interface{}{ 230 "ids": keys, 231 }, 232 }) 233 if err != nil { 234 return err 235 } 236 for _, session := range sessions { 237 obj := idToSeries[session.SeriesID] 238 if obj != nil { 239 obj.Session = session 240 } 241 } 242 return nil 243 } 244 245 func (repo *SeriesRepository) queryFindingCounts(ctx context.Context, ro *spanner.ReadOnlyTransaction, 246 seriesList []*SeriesWithSession) error { 247 var keys []string 248 sessionToSeries := map[string]*SeriesWithSession{} 249 for _, series := range seriesList { 250 if series.Session == nil || series.Session.Status() != SessionStatusFinished { 251 continue 252 } 253 keys = append(keys, series.Session.ID) 254 sessionToSeries[series.Session.ID] = series 255 } 256 if len(keys) == 0 { 257 return nil 258 } 259 260 type findingCount struct { 261 SessionID string `spanner:"SessionID"` 262 Count int64 `spanner:"Count"` 263 } 264 list, err := readEntities[findingCount](ctx, repo.client.Single(), spanner.Statement{ 265 SQL: "SELECT `SessionID`, COUNT(`ID`) as `Count` FROM `Findings` " + 266 "WHERE `SessionID` IN UNNEST(@ids) AND `Findings`.`InvalidatedAt` IS NULL " + 267 "GROUP BY `SessionID`", 268 Params: map[string]interface{}{ 269 "ids": keys, 270 }, 271 }) 272 if err != nil { 273 return err 274 } 275 for _, item := range list { 276 sessionToSeries[item.SessionID].Findings = int(item.Count) 277 } 278 return nil 279 } 280 281 // golint sees too much similarity with SessionRepository's ListForSeries, but in reality there's not. 282 // nolint:dupl 283 func (repo *SeriesRepository) ListPatches(ctx context.Context, series *Series) ([]*Patch, error) { 284 return readEntities[Patch](ctx, repo.client.Single(), spanner.Statement{ 285 SQL: "SELECT * FROM `Patches` WHERE `SeriesID` = @seriesID ORDER BY `Seq`", 286 Params: map[string]interface{}{ 287 "seriesID": series.ID, 288 }, 289 }) 290 }