github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/syz-cluster/pkg/db/session_test_repo.go (about)

     1  // Copyright 2024 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package db
     5  
     6  import (
     7  	"context"
     8  
     9  	"cloud.google.com/go/spanner"
    10  )
    11  
// SessionTestRepository provides access to the `SessionTests` table
// through a Spanner client.
type SessionTestRepository struct {
	client *spanner.Client
}
    15  
    16  func NewSessionTestRepository(client *spanner.Client) *SessionTestRepository {
    17  	return &SessionTestRepository{
    18  		client: client,
    19  	}
    20  }
    21  
    22  // If the beforeSave callback is specified, it will be called before saving the entity.
    23  func (repo *SessionTestRepository) InsertOrUpdate(ctx context.Context, test *SessionTest,
    24  	beforeSave func(*SessionTest)) error {
    25  	_, err := repo.client.ReadWriteTransaction(ctx,
    26  		func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
    27  			// Check if the test already exists.
    28  			dbTest, err := readEntity[SessionTest](ctx, txn, spanner.Statement{
    29  				SQL: "SELECT * from `SessionTests` WHERE `SessionID`=@sessionID AND `TestName` = @testName",
    30  				Params: map[string]interface{}{
    31  					"sessionID": test.SessionID,
    32  					"testName":  test.TestName,
    33  				},
    34  			})
    35  			var stmts []*spanner.Mutation
    36  			if err != nil {
    37  				return err
    38  			} else if dbTest != nil {
    39  				if beforeSave != nil {
    40  					beforeSave(test)
    41  				}
    42  				m, err := spanner.UpdateStruct("SessionTests", test)
    43  				if err != nil {
    44  					return err
    45  				}
    46  				stmts = append(stmts, m)
    47  			} else {
    48  				if beforeSave != nil {
    49  					beforeSave(test)
    50  				}
    51  				m, err := spanner.InsertStruct("SessionTests", test)
    52  				if err != nil {
    53  					return err
    54  				}
    55  				stmts = append(stmts, m)
    56  			}
    57  			return txn.BufferWrite(stmts)
    58  		})
    59  	return err
    60  }
    61  
    62  func (repo *SessionTestRepository) Get(ctx context.Context, sessionID, testName string) (*SessionTest, error) {
    63  	return readEntity[SessionTest](ctx, repo.client.Single(), spanner.Statement{
    64  		SQL: "SELECT * FROM `SessionTests` WHERE `SessionID` = @session AND `TestName` = @name",
    65  		Params: map[string]interface{}{
    66  			"session": sessionID,
    67  			"name":    testName,
    68  		},
    69  	})
    70  }
    71  
// FullSessionTest is a SessionTest together with the builds it references.
// BySession fills BaseBuild and PatchedBuild from BaseBuildID and
// PatchedBuildID; a field stays nil when its ID is NULL or no matching
// build row was returned.
type FullSessionTest struct {
	*SessionTest
	BaseBuild    *Build
	PatchedBuild *Build
}
    77  
    78  func (repo *SessionTestRepository) BySession(ctx context.Context, sessionID string) ([]*FullSessionTest, error) {
    79  	list, err := repo.BySessionRaw(ctx, sessionID)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	var ret []*FullSessionTest
    84  	needBuilds := map[string][]**Build{}
    85  	for _, obj := range list {
    86  		full := &FullSessionTest{SessionTest: obj}
    87  		ret = append(ret, full)
    88  		if id := obj.BaseBuildID.String(); !obj.BaseBuildID.IsNull() {
    89  			needBuilds[id] = append(needBuilds[id], &full.BaseBuild)
    90  		}
    91  		if id := obj.PatchedBuildID.String(); !obj.PatchedBuildID.IsNull() {
    92  			needBuilds[id] = append(needBuilds[id], &full.PatchedBuild)
    93  		}
    94  	}
    95  	if len(needBuilds) > 0 {
    96  		var keys []string
    97  		for key := range needBuilds {
    98  			keys = append(keys, key)
    99  		}
   100  		builds, err := readEntities[Build](ctx, repo.client.Single(), spanner.Statement{
   101  			SQL:    "SELECT * FROM `Builds` WHERE `ID` IN UNNEST(@ids)",
   102  			Params: map[string]interface{}{"ids": keys},
   103  		})
   104  		if err != nil {
   105  			return nil, err
   106  		}
   107  		for _, build := range builds {
   108  			for _, patch := range needBuilds[build.ID] {
   109  				*patch = build
   110  			}
   111  		}
   112  	}
   113  	return ret, nil
   114  }
   115  
   116  func (repo *SessionTestRepository) BySessionRaw(ctx context.Context, sessionID string) ([]*SessionTest, error) {
   117  	return readEntities[SessionTest](ctx, repo.client.Single(), spanner.Statement{
   118  		SQL: "SELECT * FROM `SessionTests` WHERE `SessionID` = @session" +
   119  			" ORDER BY `UpdatedAt`",
   120  		Params: map[string]interface{}{"session": sessionID},
   121  	})
   122  }