github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/syz-cluster/pkg/db/spanner.go (about) 1 // Copyright 2024 syzkaller project authors. All rights reserved. 2 // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 3 4 package db 5 6 import ( 7 "bufio" 8 "context" 9 "embed" 10 "errors" 11 "fmt" 12 "io" 13 "os" 14 "os/exec" 15 "regexp" 16 "strings" 17 "sync" 18 "testing" 19 "time" 20 21 "cloud.google.com/go/spanner" 22 database "cloud.google.com/go/spanner/admin/database/apiv1" 23 "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" 24 instance "cloud.google.com/go/spanner/admin/instance/apiv1" 25 "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" 26 "github.com/golang-migrate/migrate/v4" 27 migrate_spanner "github.com/golang-migrate/migrate/v4/database/spanner" 28 "github.com/golang-migrate/migrate/v4/source/iofs" 29 "google.golang.org/api/iterator" 30 "google.golang.org/grpc/codes" 31 "google.golang.org/grpc/status" 32 ) 33 34 type ParsedURI struct { 35 ProjectPrefix string // projects/<project> 36 InstancePrefix string // projects/<project>/instances/<instance> 37 Instance string 38 Database string 39 Full string 40 } 41 42 func ParseURI(uri string) (ParsedURI, error) { 43 ret := ParsedURI{Full: uri} 44 matches := regexp.MustCompile(`projects/(.*)/instances/(.*)/databases/(.*)`).FindStringSubmatch(uri) 45 if matches == nil || len(matches) != 4 { 46 return ret, fmt.Errorf("failed to parse %q", uri) 47 } 48 ret.ProjectPrefix = "projects/" + matches[1] 49 ret.InstancePrefix = ret.ProjectPrefix + "/instances/" + matches[2] 50 ret.Instance = matches[2] 51 ret.Database = matches[3] 52 return ret, nil 53 } 54 55 func CreateSpannerInstance(ctx context.Context, uri ParsedURI) error { 56 client, err := instance.NewInstanceAdminClient(ctx) 57 if err != nil { 58 return err 59 } 60 defer client.Close() 61 _, err = client.GetInstance(ctx, &instancepb.GetInstanceRequest{ 62 Name: uri.InstancePrefix, 63 }) 64 if err != nil && spanner.ErrCode(err) == codes.NotFound { 65 _, err = client.CreateInstance(ctx, &instancepb.CreateInstanceRequest{ 66 Parent: uri.ProjectPrefix, 67 InstanceId: uri.Instance, 68 }) 69 return err 70 } 71 return err 72 } 73 74 func CreateSpannerDB(ctx context.Context, uri ParsedURI) error { 75 client, err := database.NewDatabaseAdminClient(ctx) 76 if err != nil { 77 return err 78 } 79 defer client.Close() 80 _, err = client.GetDatabase(ctx, &databasepb.GetDatabaseRequest{Name: uri.Full}) 81 if err != nil && spanner.ErrCode(err) == codes.NotFound { 82 op, err := client.CreateDatabase(ctx, &databasepb.CreateDatabaseRequest{ 83 Parent: uri.InstancePrefix, 84 CreateStatement: `CREATE DATABASE ` + uri.Database, 85 ExtraStatements: []string{}, 86 }) 87 if err != nil { 88 return err 89 } 90 _, err = op.Wait(ctx) 91 return err 92 } 93 return err 94 } 95 96 func dropSpannerDB(ctx context.Context, uri ParsedURI) error { 97 client, err := database.NewDatabaseAdminClient(ctx) 98 if err != nil { 99 return err 100 } 101 defer client.Close() 102 return client.DropDatabase(ctx, &databasepb.DropDatabaseRequest{Database: uri.Full}) 103 } 104 105 //go:embed migrations/*.sql 106 var migrationsFs embed.FS 107 108 func RunMigrations(uri string) error { 109 m, err := getMigrateInstance(uri) 110 if err != nil { 111 return err 112 } 113 err = m.Up() 114 if err == migrate.ErrNoChange { 115 // Not really an error. 116 return nil 117 } 118 return err 119 } 120 121 func getMigrateInstance(uri string) (*migrate.Migrate, error) { 122 sourceDriver, err := iofs.New(migrationsFs, "migrations") 123 if err != nil { 124 return nil, err 125 } 126 s := &migrate_spanner.Spanner{} 127 dbDriver, err := s.Open("spanner://" + uri + "?x-clean-statements=true") 128 if err != nil { 129 return nil, err 130 } 131 m, err := migrate.NewWithInstance("iofs", sourceDriver, "spanner", dbDriver) 132 if err != nil { 133 return nil, err 134 } 135 return m, nil 136 } 137 138 func NewTransientDB(t *testing.T) (*spanner.Client, context.Context) { 139 // If the environment contains the emulator binary, start it. 140 if bin := os.Getenv("SPANNER_EMULATOR_BIN"); bin != "" { 141 host := spannerTestWrapper(t, bin) 142 os.Setenv("SPANNER_EMULATOR_HOST", host) 143 } else if os.Getenv("CI") != "" { 144 // We do want to always run these tests on CI. 145 t.Fatalf("CI is set, but SPANNER_EMULATOR_BIN is empty") 146 } 147 if os.Getenv("SPANNER_EMULATOR_HOST") == "" { 148 t.Skip("SPANNER_EMULATOR_HOST must be set") 149 return nil, nil 150 } 151 uri, err := ParseURI("projects/my-project/instances/test-instance/databases/" + 152 fmt.Sprintf("db%v", time.Now().UnixNano())) 153 if err != nil { 154 t.Fatal(err) 155 } 156 ctx := t.Context() 157 err = CreateSpannerInstance(ctx, uri) 158 if err != nil { 159 t.Fatal(err) 160 } 161 err = CreateSpannerDB(ctx, uri) 162 if err != nil { 163 t.Fatal(err) 164 } 165 t.Cleanup(func() { 166 err := dropSpannerDB(ctx, uri) 167 if err != nil { 168 t.Logf("failed to drop the test DB: %v", err) 169 } 170 }) 171 client, err := spanner.NewClient(ctx, uri.Full) 172 if err != nil { 173 t.Fatal(err) 174 } 175 t.Cleanup(client.Close) 176 err = RunMigrations(uri.Full) 177 if err != nil { 178 t.Fatal(err) 179 } 180 return client, ctx 181 } 182 183 var setupSpannerOnce sync.Once 184 var spannerHost string 185 186 func spannerTestWrapper(t *testing.T, bin string) string { 187 setupSpannerOnce.Do(func() { 188 t.Logf("this could be the first test requiring a Spanner emulator, starting %s", bin) 189 cmd, host, err := runSpanner(bin) 190 if err != nil { 191 t.Fatal(err) 192 } 193 spannerHost = host 194 t.Cleanup(func() { 195 cmd.Process.Kill() 196 cmd.Wait() 197 }) 198 }) 199 return spannerHost 200 } 201 202 var portRe = regexp.MustCompile(`Server address: ([\w:]+)`) 203 204 func runSpanner(bin string) (*exec.Cmd, string, error) { 205 cmd := exec.Command(bin, "--override_max_databases_per_instance=1000", 206 "--grpc_port=0", "--http_port=0") 207 stdout, err := cmd.StdoutPipe() 208 if err != nil { 209 return nil, "", err 210 } 211 cmd.Stderr = cmd.Stdout 212 if err := cmd.Start(); err != nil { 213 return nil, "", err 214 } 215 scanner := bufio.NewScanner(stdout) 216 started, host := false, "" 217 for scanner.Scan() { 218 line := scanner.Text() 219 if strings.Contains(line, "Cloud Spanner Emulator running") { 220 started = true 221 } else if parts := portRe.FindStringSubmatch(line); parts != nil { 222 host = parts[1] 223 } 224 if started && host != "" { 225 break 226 } 227 } 228 if err := scanner.Err(); err != nil { 229 return cmd, "", err 230 } 231 // The program may block if we don't read out all the remaining output. 232 go io.Copy(io.Discard, stdout) 233 234 if !started { 235 return cmd, "", fmt.Errorf("the emulator did not print that it started") 236 } 237 if host == "" { 238 return cmd, "", fmt.Errorf("did not detect the host") 239 } 240 return cmd, host, nil 241 } 242 243 func readRow[T any](iter *spanner.RowIterator) (*T, error) { 244 row, err := iter.Next() 245 if err == iterator.Done { 246 return nil, nil 247 } 248 if err != nil { 249 return nil, err 250 } 251 var obj T 252 err = row.ToStruct(&obj) 253 if err != nil { 254 return nil, err 255 } 256 return &obj, nil 257 } 258 259 type dbQuerier interface { 260 Query(context.Context, spanner.Statement) *spanner.RowIterator 261 } 262 263 func readEntity[T any](ctx context.Context, txn dbQuerier, stmt spanner.Statement) (*T, error) { 264 iter := txn.Query(ctx, stmt) 265 defer iter.Stop() 266 return readRow[T](iter) 267 } 268 269 func readRows[T any](iter *spanner.RowIterator) ([]*T, error) { 270 var ret []*T 271 for { 272 obj, err := readRow[T](iter) 273 if err != nil { 274 return nil, err 275 } 276 if obj == nil { 277 break 278 } 279 ret = append(ret, obj) 280 } 281 return ret, nil 282 } 283 284 func readEntities[T any](ctx context.Context, txn dbQuerier, stmt spanner.Statement) ([]*T, error) { 285 iter := txn.Query(ctx, stmt) 286 defer iter.Stop() 287 return readRows[T](iter) 288 } 289 290 const NoLimit = 0 291 292 func addLimit(stmt *spanner.Statement, limit int) { 293 if limit != NoLimit { 294 stmt.SQL += " LIMIT @limit" 295 stmt.Params["limit"] = limit 296 } 297 } 298 299 type genericEntityOps[EntityType, KeyType any] struct { 300 client *spanner.Client 301 keyField string 302 table string 303 } 304 305 func (g *genericEntityOps[EntityType, KeyType]) GetByID(ctx context.Context, key KeyType) (*EntityType, error) { 306 stmt := spanner.Statement{ 307 SQL: "SELECT * FROM " + g.table + " WHERE " + g.keyField + "=@key", 308 Params: map[string]interface{}{"key": key}, 309 } 310 return readEntity[EntityType](ctx, g.client.Single(), stmt) 311 } 312 313 var ErrEntityNotFound = errors.New("entity not found") 314 315 func (g *genericEntityOps[EntityType, KeyType]) Update(ctx context.Context, key KeyType, 316 cb func(*EntityType) error) error { 317 _, err := g.client.ReadWriteTransaction(ctx, 318 func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { 319 entity, err := readEntity[EntityType](ctx, txn, spanner.Statement{ 320 SQL: "SELECT * from `" + g.table + "` WHERE `" + g.keyField + "`=@key", 321 Params: map[string]interface{}{"key": key}, 322 }) 323 if err != nil { 324 return err 325 } 326 if entity == nil { 327 return ErrEntityNotFound 328 } 329 err = cb(entity) 330 if err != nil { 331 return err 332 } 333 m, err := spanner.UpdateStruct(g.table, entity) 334 if err != nil { 335 return err 336 } 337 return txn.BufferWrite([]*spanner.Mutation{m}) 338 }) 339 return err 340 } 341 342 var errEntityExists = errors.New("entity already exists") 343 344 func (g *genericEntityOps[EntityType, KeyType]) Insert(ctx context.Context, obj *EntityType) error { 345 _, err := g.client.ReadWriteTransaction(ctx, 346 func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { 347 insert, err := spanner.InsertStruct(g.table, obj) 348 if err != nil { 349 return err 350 } 351 return txn.BufferWrite([]*spanner.Mutation{insert}) 352 }) 353 if status.Code(err) == codes.AlreadyExists { 354 return errEntityExists 355 } 356 return err 357 } 358 359 func (g *genericEntityOps[EntityType, KeyType]) readEntities(ctx context.Context, stmt spanner.Statement) ( 360 []*EntityType, error) { 361 return readEntities[EntityType](ctx, g.client.Single(), stmt) 362 }