github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/syz-cluster/pkg/db/session_test_repo.go (about) 1 // Copyright 2024 syzkaller project authors. All rights reserved. 2 // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4 package db 5 6 import ( 7 "context" 8 9 "cloud.google.com/go/spanner" 10 ) 11 12 type SessionTestRepository struct { 13 client *spanner.Client 14 } 15 16 func NewSessionTestRepository(client *spanner.Client) *SessionTestRepository { 17 return &SessionTestRepository{ 18 client: client, 19 } 20 } 21 22 // If the beforeSave callback is specified, it will be called before saving the entity. 23 func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *SessionTest, 24 beforeSave func(*SessionTest)) error { 25 _, err := repo.client.ReadWriteTransaction(ctx, 26 func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { 27 // Check if the test already exists. 28 dbTest, err := readEntity[SessionTest](ctx, txn, spanner.Statement{ 29 SQL: "SELECT * from `SessionTests` WHERE `SessionID`=@sessionID AND `TestName` = @testName", 30 Params: map[string]interface{}{ 31 "sessionID": test.SessionID, 32 "testName": test.TestName, 33 }, 34 }) 35 var stmts []*spanner.Mutation 36 if err != nil { 37 return err 38 } else if dbTest != nil { 39 if beforeSave != nil { 40 beforeSave(test) 41 } 42 m, err := spanner.UpdateStruct("SessionTests", test) 43 if err != nil { 44 return err 45 } 46 stmts = append(stmts, m) 47 } else { 48 if beforeSave != nil { 49 beforeSave(test) 50 } 51 m, err := spanner.InsertStruct("SessionTests", test) 52 if err != nil { 53 return err 54 } 55 stmts = append(stmts, m) 56 } 57 return txn.BufferWrite(stmts) 58 }) 59 return err 60 } 61 62 func (repo *SessionTestRepository) Get(ctx context.Context, sessionID, testName string) (*SessionTest, error) { 63 return readEntity[SessionTest](ctx, repo.client.Single(), spanner.Statement{ 64 SQL: "SELECT * FROM `SessionTests` WHERE `SessionID` = @session AND `TestName` = @name", 65 Params: map[string]interface{}{ 66 "session": sessionID, 67 "name": testName, 68 }, 69 }) 70 } 71 72 type FullSessionTest struct { 73 *SessionTest 74 BaseBuild *Build 75 PatchedBuild *Build 76 } 77 78 func (repo *SessionTestRepository) BySession(ctx context.Context, sessionID string) ([]*FullSessionTest, error) { 79 list, err := repo.BySessionRaw(ctx, sessionID) 80 if err != nil { 81 return nil, err 82 } 83 var ret []*FullSessionTest 84 needBuilds := map[string][]**Build{} 85 for _, obj := range list { 86 full := &FullSessionTest{SessionTest: obj} 87 ret = append(ret, full) 88 if id := obj.BaseBuildID.String(); !obj.BaseBuildID.IsNull() { 89 needBuilds[id] = append(needBuilds[id], &full.BaseBuild) 90 } 91 if id := obj.PatchedBuildID.String(); !obj.PatchedBuildID.IsNull() { 92 needBuilds[id] = append(needBuilds[id], &full.PatchedBuild) 93 } 94 } 95 if len(needBuilds) > 0 { 96 var keys []string 97 for key := range needBuilds { 98 keys = append(keys, key) 99 } 100 builds, err := readEntities[Build](ctx, repo.client.Single(), spanner.Statement{ 101 SQL: "SELECT * FROM `Builds` WHERE `ID` IN UNNEST(@ids)", 102 Params: map[string]interface{}{"ids": keys}, 103 }) 104 if err != nil { 105 return nil, err 106 } 107 for _, build := range builds { 108 for _, patch := range needBuilds[build.ID] { 109 *patch = build 110 } 111 } 112 } 113 return ret, nil 114 } 115 116 func (repo *SessionTestRepository) BySessionRaw(ctx context.Context, sessionID string) ([]*SessionTest, error) { 117 return readEntities[SessionTest](ctx, repo.client.Single(), spanner.Statement{ 118 SQL: "SELECT * FROM `SessionTests` WHERE `SessionID` = @session" + 119 " ORDER BY `UpdatedAt`", 120 Params: map[string]interface{}{"session": sessionID}, 121 }) 122 }