github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/syz-cluster/pkg/db/session_repo.go (about)

     1  // Copyright 2024 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package db
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"time"
    10  
    11  	"cloud.google.com/go/spanner"
    12  	"github.com/google/uuid"
    13  )
    14  
// SessionRepository gives access to the `Sessions` table. In addition to the
// methods declared in this file, it exposes the operations of the embedded
// genericEntityOps[Session, string].
type SessionRepository struct {
	// client is used directly by Start to run its read-write transaction.
	client *spanner.Client
	*genericEntityOps[Session, string]
}
    19  
    20  func NewSessionRepository(client *spanner.Client) *SessionRepository {
    21  	return &SessionRepository{
    22  		client: client,
    23  		genericEntityOps: &genericEntityOps[Session, string]{
    24  			client:   client,
    25  			keyField: "ID",
    26  			table:    "Sessions",
    27  		},
    28  	}
    29  }
    30  
// ErrSessionAlreadyStarted is returned by SessionRepository.Start when the
// session's StartedAt field is already set.
var ErrSessionAlreadyStarted = errors.New("the session already started")
    32  
    33  func (repo *SessionRepository) Start(ctx context.Context, sessionID string) error {
    34  	_, err := repo.client.ReadWriteTransaction(ctx,
    35  		func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
    36  			session, err := readEntity[Session](ctx, txn, spanner.Statement{
    37  				SQL:    "SELECT * from `Sessions` WHERE `ID`=@id",
    38  				Params: map[string]interface{}{"id": sessionID},
    39  			})
    40  			if err != nil {
    41  				return err
    42  			}
    43  			if !session.StartedAt.IsNull() {
    44  				return ErrSessionAlreadyStarted
    45  			}
    46  			session.SetStartedAt(time.Now())
    47  			updateSession, err := spanner.UpdateStruct("Sessions", session)
    48  			if err != nil {
    49  				return err
    50  			}
    51  			series, err := readEntity[Series](ctx, txn, spanner.Statement{
    52  				SQL:    "SELECT * from `Series` WHERE `ID`=@id",
    53  				Params: map[string]interface{}{"id": session.SeriesID},
    54  			})
    55  			if err != nil {
    56  				return err
    57  			}
    58  			series.SetLatestSession(session)
    59  			updateSeries, err := spanner.UpdateStruct("Series", series)
    60  			if err != nil {
    61  				return err
    62  			}
    63  			return txn.BufferWrite([]*spanner.Mutation{updateSeries, updateSession})
    64  		})
    65  	return err
    66  }
    67  
    68  func (repo *SessionRepository) Insert(ctx context.Context, session *Session) error {
    69  	if session.ID == "" {
    70  		session.ID = uuid.NewString()
    71  	}
    72  	return repo.genericEntityOps.Insert(ctx, session)
    73  }
    74  
    75  func (repo *SessionRepository) ListRunning(ctx context.Context) ([]*Session, error) {
    76  	return repo.readEntities(ctx, spanner.Statement{
    77  		SQL: "SELECT * FROM `Sessions` WHERE `StartedAt` IS NOT NULL AND `FinishedAt` IS NULL",
    78  	})
    79  }
    80  
// NextSession is the pagination cursor produced by ListWaiting. It holds the ID
// and the creation time of the last session of a returned page; passing it back
// to ListWaiting continues the listing right after that session.
type NextSession struct {
	id        string
	createdAt time.Time
}
    85  
    86  func (repo *SessionRepository) ListWaiting(ctx context.Context, from *NextSession,
    87  	limit int) ([]*Session, *NextSession, error) {
    88  	stmt := spanner.Statement{
    89  		SQL:    "SELECT * FROM `Sessions` WHERE `StartedAt` IS NULL",
    90  		Params: map[string]interface{}{},
    91  	}
    92  	if from != nil {
    93  		stmt.SQL += " AND ((`CreatedAt` > @from) OR (`CreatedAt` = @from AND `ID` > @id))"
    94  		stmt.Params["from"] = from.createdAt
    95  		stmt.Params["id"] = from.id
    96  	}
    97  	stmt.SQL += " ORDER BY `CreatedAt`, `ID`"
    98  	addLimit(&stmt, limit)
    99  	list, err := repo.readEntities(ctx, stmt)
   100  
   101  	var next *NextSession
   102  	if err == nil && len(list) > 0 {
   103  		last := list[len(list)-1]
   104  		next = &NextSession{
   105  			id:        last.ID,
   106  			createdAt: last.CreatedAt,
   107  		}
   108  	}
   109  	return list, next, err
   110  }
   111  
   112  // golint sees too much similarity with SeriesRepository's ListPatches, but in reality there's not.
   113  // nolint:dupl
   114  func (repo *SessionRepository) ListForSeries(ctx context.Context, series *Series) ([]*Session, error) {
   115  	return repo.readEntities(ctx, spanner.Statement{
   116  		SQL:    "SELECT * FROM `Sessions` WHERE `SeriesID` = @series ORDER BY CreatedAt DESC",
   117  		Params: map[string]interface{}{"series": series.ID},
   118  	})
   119  }
   120  
   121  // MissingReportList lists the session objects that are missing any SessionReport objects,
   122  // but do have Findings.
   123  // Once the conditions for creating a SessionRepor object become more complex, it will
   124  // likely be not enough to have this simple method, but for now it should be fine.
   125  func (repo *SessionRepository) MissingReportList(ctx context.Context, from time.Time, limit int) ([]*Session, error) {
   126  	stmt := spanner.Statement{
   127  		SQL: "SELECT * FROM Sessions WHERE FinishedAt IS NOT NULL " +
   128  			" AND NOT EXISTS (" +
   129  			"SELECT 1 FROM SessionReports WHERE SessionReports.SessionID = Sessions.ID" +
   130  			") AND EXISTS (" +
   131  			"SELECT 1 FROM Findings WHERE Findings.SessionID = Sessions.ID)",
   132  		Params: map[string]interface{}{},
   133  	}
   134  	if !from.IsZero() {
   135  		stmt.SQL += " AND `FinishedAt` > @from"
   136  		stmt.Params["from"] = from
   137  	}
   138  	stmt.SQL += " ORDER BY `FinishedAt`"
   139  	addLimit(&stmt, limit)
   140  	return repo.readEntities(ctx, stmt)
   141  }