go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/resultdb/internal/testutil/spantest.go (about)

     1  // Copyright 2019 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package testutil
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"os"
    21  	"path/filepath"
    22  	"testing"
    23  	"time"
    24  
    25  	"cloud.google.com/go/spanner"
    26  	"google.golang.org/grpc/codes"
    27  
    28  	"go.chromium.org/luci/common/clock"
    29  	"go.chromium.org/luci/common/errors"
    30  	"go.chromium.org/luci/common/spantest"
    31  	"go.chromium.org/luci/server/redisconn"
    32  	"go.chromium.org/luci/server/span"
    33  
    34  	"go.chromium.org/luci/resultdb/internal/spanutil"
    35  
    36  	. "github.com/smartystreets/goconvey/convey"
    37  )
    38  
    39  const (
    40  	// IntegrationTestEnvVar is the name of the environment variable which controls
    41  	// whether spanner tests are executed.
    42  	// The value must be "1" for integration tests to run.
    43  	IntegrationTestEnvVar = "INTEGRATION_TESTS"
    44  
    45  	// RedisTestEnvVar is the name of the environment variable which controls
    46  	// whether tests will attempt to connect to *local* Redis at port 6379.
    47  	// The value must be "1" to connect to Redis.
    48  	//
    49  	// Note that this mode does not support running multiple test binaries in
    50  	// parallel, e.g. `go test ./...`.
    51  	// This could be mitigated by using different Redis databases in different
    52  	// test binaries, but the default limit is only 16.
    53  	RedisTestEnvVar = "INTEGRATION_TESTS_REDIS"
    54  
    55  	// EmulatorEnvVar is the name of the environment variable which controls
    56  	// whether to run spanner tests using Cloud Spanner Emulator.
    57  	// The value must be "1" to use emulator.
    58  	EmulatorEnvVar = "SPANNER_EMULATOR"
    59  )
    60  
    61  // runIntegrationTests returns true if integration tests should run.
    62  func runIntegrationTests() bool {
    63  	return os.Getenv(IntegrationTestEnvVar) == "1"
    64  }
    65  
    66  func runIntegrationTestsWithEmulator() bool {
    67  	return runIntegrationTests() && os.Getenv(EmulatorEnvVar) == "1"
    68  }
    69  
    70  // ConnectToRedis returns true if tests should connect to Redis.
    71  func ConnectToRedis() bool {
    72  	return os.Getenv(RedisTestEnvVar) == "1"
    73  }
    74  
    75  var spannerClient *spanner.Client
    76  
    77  // SpannerTestContext returns a context for testing code that talks to Spanner.
    78  // Skips the test if integration tests are not enabled.
    79  //
    80  // Tests that use Spanner must not call t.Parallel().
    81  func SpannerTestContext(tb testing.TB) context.Context {
    82  	switch {
    83  	case !runIntegrationTests():
    84  		tb.Skipf("env var %s=1 is missing", IntegrationTestEnvVar)
    85  	case spannerClient == nil:
    86  		tb.Fatalf("spanner client is not initialized; forgot to call SpannerTestMain?")
    87  	}
    88  
    89  	// Do not mock clock in integration tests because we cannot mock Spanner's
    90  	// clock.
    91  	ctx := testingContext(false)
    92  	err := cleanupDatabase(ctx, spannerClient)
    93  	if err != nil {
    94  		tb.Fatal(err)
    95  	}
    96  
    97  	ctx = span.UseClient(ctx, spannerClient)
    98  
    99  	if ConnectToRedis() {
   100  		ctx = redisconn.UsePool(ctx, redisconn.NewPool("localhost:6379", 0))
   101  		if err := cleanupRedis(ctx); err != nil {
   102  			tb.Fatal(err)
   103  		}
   104  	}
   105  
   106  	return ctx
   107  }
   108  
   109  // findInitScript returns path //resultdb/internal/spanutil/init_db.sql.
   110  func findInitScript() (string, error) {
   111  	ancestor, err := filepath.Abs(".")
   112  	if err != nil {
   113  		return "", err
   114  	}
   115  
   116  	for {
   117  		scriptPath := filepath.Join(ancestor, "internal", "spanutil", "init_db.sql")
   118  		_, err := os.Stat(scriptPath)
   119  		if os.IsNotExist(err) {
   120  			parent := filepath.Dir(ancestor)
   121  			if parent == ancestor {
   122  				return "", errors.Reason("init_db.sql not found").Err()
   123  			}
   124  			ancestor = parent
   125  			continue
   126  		}
   127  
   128  		return scriptPath, err
   129  	}
   130  }
   131  
   132  // SpannerTestMain is a test main function for packages that have tests that
   133  // talk to spanner. It creates/destroys a temporary spanner database
   134  // before/after running tests.
   135  //
   136  // This function never returns. Instead it calls os.Exit with the value returned
   137  // by m.Run().
   138  func SpannerTestMain(m *testing.M) {
   139  	exitCode, err := spannerTestMain(m)
   140  	if err != nil {
   141  		fmt.Fprintln(os.Stderr, err)
   142  		os.Exit(1)
   143  	}
   144  
   145  	os.Exit(exitCode)
   146  }
   147  
   148  func spannerTestMain(m *testing.M) (exitCode int, err error) {
   149  	testing.Init()
   150  
   151  	if runIntegrationTests() {
   152  		ctx := context.Background()
   153  		start := clock.Now(ctx)
   154  		var instanceName string
   155  		var emulator *spantest.Emulator
   156  
   157  		if runIntegrationTestsWithEmulator() {
   158  			var err error
   159  			// Start Cloud Spanner Emulator.
   160  			if emulator, err = spantest.StartEmulator(ctx); err != nil {
   161  				return 0, err
   162  			}
   163  			defer func() {
   164  				switch stopErr := emulator.Stop(); {
   165  				case stopErr == nil:
   166  
   167  				case err == nil:
   168  					err = stopErr
   169  
   170  				default:
   171  					fmt.Fprintf(os.Stderr, "failed to stop the emulator: %s\n", stopErr)
   172  				}
   173  			}()
   174  
   175  			// Create a Spanner instance.
   176  			if instanceName, err = emulator.NewInstance(ctx, ""); err != nil {
   177  				return 0, err
   178  			}
   179  			fmt.Printf("started cloud emulatorlator in and created a temporary Spanner instance %s in %s\n", instanceName, time.Since(start))
   180  			start = clock.Now(ctx)
   181  		}
   182  
   183  		// Find init_db.sql
   184  		initScriptPath, err := findInitScript()
   185  		if err != nil {
   186  			return 0, err
   187  		}
   188  
   189  		// Create a Spanner database.
   190  		db, err := spantest.NewTempDB(ctx, spantest.TempDBConfig{InitScriptPath: initScriptPath, InstanceName: instanceName}, emulator)
   191  		if err != nil {
   192  			return 0, errors.Annotate(err, "failed to create a temporary Spanner database").Err()
   193  		}
   194  		fmt.Printf("created a temporary Spanner database %s in %s\n", db.Name, time.Since(start))
   195  
   196  		defer func() {
   197  			switch dropErr := db.Drop(ctx); {
   198  			case dropErr == nil:
   199  
   200  			case err == nil:
   201  				err = dropErr
   202  
   203  			default:
   204  				fmt.Fprintf(os.Stderr, "failed to drop the database: %s\n", dropErr)
   205  			}
   206  		}()
   207  
   208  		// Create a global Spanner client.
   209  		spannerClient, err = db.Client(ctx)
   210  		if err != nil {
   211  			return 0, err
   212  		}
   213  	}
   214  
   215  	return m.Run(), nil
   216  }
   217  
   218  // cleanupDatabase deletes all data from all tables.
   219  func cleanupDatabase(ctx context.Context, client *spanner.Client) error {
   220  	_, err := client.Apply(ctx, []*spanner.Mutation{
   221  		// Tables that are not interleaved in Invocations table.
   222  		spanner.Delete("TQReminders", spanner.AllKeys()),
   223  		spanner.Delete("TestMetadata", spanner.AllKeys()),
   224  
   225  		// All other tables are interleaved in Invocations table.
   226  		spanner.Delete("Invocations", spanner.AllKeys()),
   227  
   228  		spanner.Delete("Baselines", spanner.AllKeys()),
   229  		spanner.Delete("BaselineTestVariants", spanner.AllKeys()),
   230  	})
   231  	return err
   232  }
   233  
   234  // cleanupRedis deletes all data from the selected Redis database.
   235  func cleanupRedis(ctx context.Context) error {
   236  	conn, err := redisconn.Get(ctx)
   237  	if err != nil {
   238  		return err
   239  	}
   240  
   241  	_, err = conn.Do("FLUSHDB")
   242  	return err
   243  }
   244  
   245  // MustApply applies the mutations to the spanner client in the context.
   246  // Asserts that application succeeds.
   247  // Returns the commit timestamp.
   248  func MustApply(ctx context.Context, ms ...*spanner.Mutation) time.Time {
   249  	ct, err := span.Apply(ctx, ms)
   250  	So(err, ShouldBeNil)
   251  	return ct
   252  }
   253  
   254  // CombineMutations concatenates mutations
   255  func CombineMutations(msSlice ...[]*spanner.Mutation) []*spanner.Mutation {
   256  	totalLen := 0
   257  	for _, ms := range msSlice {
   258  		totalLen += len(ms)
   259  	}
   260  	ret := make([]*spanner.Mutation, 0, totalLen)
   261  	for _, ms := range msSlice {
   262  		ret = append(ret, ms...)
   263  	}
   264  	return ret
   265  }
   266  
   267  // MustReadRow is a shortcut to do a single row read in a single transaction
   268  // using the current client, and assert success.
   269  func MustReadRow(ctx context.Context, table string, key spanner.Key, ptrMap map[string]any) {
   270  	err := spanutil.ReadRow(span.Single(span.WithoutTxn(ctx)), table, key, ptrMap)
   271  	So(err, ShouldBeNil)
   272  }
   273  
   274  // MustNotFindRow is a shortcut to do a single row read in a single transaction
   275  // using the current client, and assert the row was not found.
   276  func MustNotFindRow(ctx context.Context, table string, key spanner.Key, ptrMap map[string]any) {
   277  	err := spanutil.ReadRow(span.Single(span.WithoutTxn(ctx)), table, key, ptrMap)
   278  	So(spanner.ErrCode(err), ShouldEqual, codes.NotFound)
   279  }