github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/syz-cluster/pkg/db/spanner.go (about)

     1  // Copyright 2024 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package db
     5  
     6  import (
     7  	"bufio"
     8  	"context"
     9  	"embed"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"os"
    14  	"os/exec"
    15  	"regexp"
    16  	"strings"
    17  	"sync"
    18  	"testing"
    19  	"time"
    20  
    21  	"cloud.google.com/go/spanner"
    22  	database "cloud.google.com/go/spanner/admin/database/apiv1"
    23  	"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
    24  	instance "cloud.google.com/go/spanner/admin/instance/apiv1"
    25  	"cloud.google.com/go/spanner/admin/instance/apiv1/instancepb"
    26  	"github.com/golang-migrate/migrate/v4"
    27  	migrate_spanner "github.com/golang-migrate/migrate/v4/database/spanner"
    28  	"github.com/golang-migrate/migrate/v4/source/iofs"
    29  	"google.golang.org/api/iterator"
    30  	"google.golang.org/grpc/codes"
    31  	"google.golang.org/grpc/status"
    32  )
    33  
// ParsedURI holds the components of a Spanner database URI of the form
// projects/<project>/instances/<instance>/databases/<database>.
type ParsedURI struct {
	ProjectPrefix  string // projects/<project>
	InstancePrefix string // projects/<project>/instances/<instance>
	Instance       string // The <instance> component alone.
	Database       string // The <database> component alone.
	Full           string // The URI exactly as it was passed to ParseURI.
}
    41  
    42  func ParseURI(uri string) (ParsedURI, error) {
    43  	ret := ParsedURI{Full: uri}
    44  	matches := regexp.MustCompile(`projects/(.*)/instances/(.*)/databases/(.*)`).FindStringSubmatch(uri)
    45  	if matches == nil || len(matches) != 4 {
    46  		return ret, fmt.Errorf("failed to parse %q", uri)
    47  	}
    48  	ret.ProjectPrefix = "projects/" + matches[1]
    49  	ret.InstancePrefix = ret.ProjectPrefix + "/instances/" + matches[2]
    50  	ret.Instance = matches[2]
    51  	ret.Database = matches[3]
    52  	return ret, nil
    53  }
    54  
    55  func CreateSpannerInstance(ctx context.Context, uri ParsedURI) error {
    56  	client, err := instance.NewInstanceAdminClient(ctx)
    57  	if err != nil {
    58  		return err
    59  	}
    60  	defer client.Close()
    61  	_, err = client.GetInstance(ctx, &instancepb.GetInstanceRequest{
    62  		Name: uri.InstancePrefix,
    63  	})
    64  	if err != nil && spanner.ErrCode(err) == codes.NotFound {
    65  		_, err = client.CreateInstance(ctx, &instancepb.CreateInstanceRequest{
    66  			Parent:     uri.ProjectPrefix,
    67  			InstanceId: uri.Instance,
    68  		})
    69  		return err
    70  	}
    71  	return err
    72  }
    73  
    74  func CreateSpannerDB(ctx context.Context, uri ParsedURI) error {
    75  	client, err := database.NewDatabaseAdminClient(ctx)
    76  	if err != nil {
    77  		return err
    78  	}
    79  	defer client.Close()
    80  	_, err = client.GetDatabase(ctx, &databasepb.GetDatabaseRequest{Name: uri.Full})
    81  	if err != nil && spanner.ErrCode(err) == codes.NotFound {
    82  		op, err := client.CreateDatabase(ctx, &databasepb.CreateDatabaseRequest{
    83  			Parent:          uri.InstancePrefix,
    84  			CreateStatement: `CREATE DATABASE ` + uri.Database,
    85  			ExtraStatements: []string{},
    86  		})
    87  		if err != nil {
    88  			return err
    89  		}
    90  		_, err = op.Wait(ctx)
    91  		return err
    92  	}
    93  	return err
    94  }
    95  
    96  func dropSpannerDB(ctx context.Context, uri ParsedURI) error {
    97  	client, err := database.NewDatabaseAdminClient(ctx)
    98  	if err != nil {
    99  		return err
   100  	}
   101  	defer client.Close()
   102  	return client.DropDatabase(ctx, &databasepb.DropDatabaseRequest{Database: uri.Full})
   103  }
   104  
// migrationsFs holds the *.sql files of the migrations directory, embedded
// into the binary at build time. It is the source of the schema migrations
// used by getMigrateInstance.
//
//go:embed migrations/*.sql
var migrationsFs embed.FS
   107  
   108  func RunMigrations(uri string) error {
   109  	m, err := getMigrateInstance(uri)
   110  	if err != nil {
   111  		return err
   112  	}
   113  	err = m.Up()
   114  	if err == migrate.ErrNoChange {
   115  		// Not really an error.
   116  		return nil
   117  	}
   118  	return err
   119  }
   120  
   121  func getMigrateInstance(uri string) (*migrate.Migrate, error) {
   122  	sourceDriver, err := iofs.New(migrationsFs, "migrations")
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  	s := &migrate_spanner.Spanner{}
   127  	dbDriver, err := s.Open("spanner://" + uri + "?x-clean-statements=true")
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  	m, err := migrate.NewWithInstance("iofs", sourceDriver, "spanner", dbDriver)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	return m, nil
   136  }
   137  
   138  func NewTransientDB(t *testing.T) (*spanner.Client, context.Context) {
   139  	// If the environment contains the emulator binary, start it.
   140  	if bin := os.Getenv("SPANNER_EMULATOR_BIN"); bin != "" {
   141  		host := spannerTestWrapper(t, bin)
   142  		os.Setenv("SPANNER_EMULATOR_HOST", host)
   143  	} else if os.Getenv("CI") != "" {
   144  		// We do want to always run these tests on CI.
   145  		t.Fatalf("CI is set, but SPANNER_EMULATOR_BIN is empty")
   146  	}
   147  	if os.Getenv("SPANNER_EMULATOR_HOST") == "" {
   148  		t.Skip("SPANNER_EMULATOR_HOST must be set")
   149  		return nil, nil
   150  	}
   151  	uri, err := ParseURI("projects/my-project/instances/test-instance/databases/" +
   152  		fmt.Sprintf("db%v", time.Now().UnixNano()))
   153  	if err != nil {
   154  		t.Fatal(err)
   155  	}
   156  	ctx := t.Context()
   157  	err = CreateSpannerInstance(ctx, uri)
   158  	if err != nil {
   159  		t.Fatal(err)
   160  	}
   161  	err = CreateSpannerDB(ctx, uri)
   162  	if err != nil {
   163  		t.Fatal(err)
   164  	}
   165  	t.Cleanup(func() {
   166  		err := dropSpannerDB(ctx, uri)
   167  		if err != nil {
   168  			t.Logf("failed to drop the test DB: %v", err)
   169  		}
   170  	})
   171  	client, err := spanner.NewClient(ctx, uri.Full)
   172  	if err != nil {
   173  		t.Fatal(err)
   174  	}
   175  	t.Cleanup(client.Close)
   176  	err = RunMigrations(uri.Full)
   177  	if err != nil {
   178  		t.Fatal(err)
   179  	}
   180  	return client, ctx
   181  }
   182  
   183  var setupSpannerOnce sync.Once
   184  var spannerHost string
   185  
   186  func spannerTestWrapper(t *testing.T, bin string) string {
   187  	setupSpannerOnce.Do(func() {
   188  		t.Logf("this could be the first test requiring a Spanner emulator, starting %s", bin)
   189  		cmd, host, err := runSpanner(bin)
   190  		if err != nil {
   191  			t.Fatal(err)
   192  		}
   193  		spannerHost = host
   194  		t.Cleanup(func() {
   195  			cmd.Process.Kill()
   196  			cmd.Wait()
   197  		})
   198  	})
   199  	return spannerHost
   200  }
   201  
   202  var portRe = regexp.MustCompile(`Server address: ([\w:]+)`)
   203  
   204  func runSpanner(bin string) (*exec.Cmd, string, error) {
   205  	cmd := exec.Command(bin, "--override_max_databases_per_instance=1000",
   206  		"--grpc_port=0", "--http_port=0")
   207  	stdout, err := cmd.StdoutPipe()
   208  	if err != nil {
   209  		return nil, "", err
   210  	}
   211  	cmd.Stderr = cmd.Stdout
   212  	if err := cmd.Start(); err != nil {
   213  		return nil, "", err
   214  	}
   215  	scanner := bufio.NewScanner(stdout)
   216  	started, host := false, ""
   217  	for scanner.Scan() {
   218  		line := scanner.Text()
   219  		if strings.Contains(line, "Cloud Spanner Emulator running") {
   220  			started = true
   221  		} else if parts := portRe.FindStringSubmatch(line); parts != nil {
   222  			host = parts[1]
   223  		}
   224  		if started && host != "" {
   225  			break
   226  		}
   227  	}
   228  	if err := scanner.Err(); err != nil {
   229  		return cmd, "", err
   230  	}
   231  	// The program may block if we don't read out all the remaining output.
   232  	go io.Copy(io.Discard, stdout)
   233  
   234  	if !started {
   235  		return cmd, "", fmt.Errorf("the emulator did not print that it started")
   236  	}
   237  	if host == "" {
   238  		return cmd, "", fmt.Errorf("did not detect the host")
   239  	}
   240  	return cmd, host, nil
   241  }
   242  
   243  func readRow[T any](iter *spanner.RowIterator) (*T, error) {
   244  	row, err := iter.Next()
   245  	if err == iterator.Done {
   246  		return nil, nil
   247  	}
   248  	if err != nil {
   249  		return nil, err
   250  	}
   251  	var obj T
   252  	err = row.ToStruct(&obj)
   253  	if err != nil {
   254  		return nil, err
   255  	}
   256  	return &obj, nil
   257  }
   258  
   259  type dbQuerier interface {
   260  	Query(context.Context, spanner.Statement) *spanner.RowIterator
   261  }
   262  
   263  func readEntity[T any](ctx context.Context, txn dbQuerier, stmt spanner.Statement) (*T, error) {
   264  	iter := txn.Query(ctx, stmt)
   265  	defer iter.Stop()
   266  	return readRow[T](iter)
   267  }
   268  
   269  func readRows[T any](iter *spanner.RowIterator) ([]*T, error) {
   270  	var ret []*T
   271  	for {
   272  		obj, err := readRow[T](iter)
   273  		if err != nil {
   274  			return nil, err
   275  		}
   276  		if obj == nil {
   277  			break
   278  		}
   279  		ret = append(ret, obj)
   280  	}
   281  	return ret, nil
   282  }
   283  
   284  func readEntities[T any](ctx context.Context, txn dbQuerier, stmt spanner.Statement) ([]*T, error) {
   285  	iter := txn.Query(ctx, stmt)
   286  	defer iter.Stop()
   287  	return readRows[T](iter)
   288  }
   289  
   290  const NoLimit = 0
   291  
   292  func addLimit(stmt *spanner.Statement, limit int) {
   293  	if limit != NoLimit {
   294  		stmt.SQL += " LIMIT @limit"
   295  		stmt.Params["limit"] = limit
   296  	}
   297  }
   298  
// genericEntityOps implements common single-table operations for rows that
// are decoded into EntityType and looked up by a key of type KeyType.
type genericEntityOps[EntityType, KeyType any] struct {
	client   *spanner.Client // Used to run the queries and transactions.
	keyField string          // Name of the column that is compared against the key.
	table    string          // Name of the table that stores the entities.
}
   304  
   305  func (g *genericEntityOps[EntityType, KeyType]) GetByID(ctx context.Context, key KeyType) (*EntityType, error) {
   306  	stmt := spanner.Statement{
   307  		SQL:    "SELECT * FROM " + g.table + " WHERE " + g.keyField + "=@key",
   308  		Params: map[string]interface{}{"key": key},
   309  	}
   310  	return readEntity[EntityType](ctx, g.client.Single(), stmt)
   311  }
   312  
   313  var ErrEntityNotFound = errors.New("entity not found")
   314  
   315  func (g *genericEntityOps[EntityType, KeyType]) Update(ctx context.Context, key KeyType,
   316  	cb func(*EntityType) error) error {
   317  	_, err := g.client.ReadWriteTransaction(ctx,
   318  		func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
   319  			entity, err := readEntity[EntityType](ctx, txn, spanner.Statement{
   320  				SQL:    "SELECT * from `" + g.table + "` WHERE `" + g.keyField + "`=@key",
   321  				Params: map[string]interface{}{"key": key},
   322  			})
   323  			if err != nil {
   324  				return err
   325  			}
   326  			if entity == nil {
   327  				return ErrEntityNotFound
   328  			}
   329  			err = cb(entity)
   330  			if err != nil {
   331  				return err
   332  			}
   333  			m, err := spanner.UpdateStruct(g.table, entity)
   334  			if err != nil {
   335  				return err
   336  			}
   337  			return txn.BufferWrite([]*spanner.Mutation{m})
   338  		})
   339  	return err
   340  }
   341  
   342  var errEntityExists = errors.New("entity already exists")
   343  
   344  func (g *genericEntityOps[EntityType, KeyType]) Insert(ctx context.Context, obj *EntityType) error {
   345  	_, err := g.client.ReadWriteTransaction(ctx,
   346  		func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
   347  			insert, err := spanner.InsertStruct(g.table, obj)
   348  			if err != nil {
   349  				return err
   350  			}
   351  			return txn.BufferWrite([]*spanner.Mutation{insert})
   352  		})
   353  	if status.Code(err) == codes.AlreadyExists {
   354  		return errEntityExists
   355  	}
   356  	return err
   357  }
   358  
   359  func (g *genericEntityOps[EntityType, KeyType]) readEntities(ctx context.Context, stmt spanner.Statement) (
   360  	[]*EntityType, error) {
   361  	return readEntities[EntityType](ctx, g.client.Single(), stmt)
   362  }