github.com/quay/claircore@v1.5.28/test/postgres/db_common.go (about)

     1  package postgres
     2  
     3  import (
     4  	"context"
     5  	"embed"
     6  	"fmt"
     7  	"io/fs"
     8  	"os"
     9  	"path"
    10  	"testing"
    11  
    12  	"github.com/jackc/pgx/v4"
    13  	"github.com/jackc/pgx/v4/log/testingadapter"
    14  	"github.com/jackc/pgx/v4/pgxpool"
    15  	"github.com/jackc/pgx/v4/stdlib"
    16  	"github.com/remind101/migrate"
    17  
    18  	"github.com/quay/claircore/datastore/postgres/migrations"
    19  	"github.com/quay/claircore/test/integration"
    20  )
    21  
    22  // TestMatcherDB returns a [pgxpool.Pool] connected to a started and configured
    23  // for a Matcher database.
    24  //
    25  // If any errors are encountered, the test is failed and exited.
    26  func TestMatcherDB(ctx context.Context, t testing.TB) *pgxpool.Pool {
    27  	return testDB(ctx, t, dbMatcher)
    28  }
    29  
    30  // TestIndexerDB returns a [pgxpool.Pool] connected to a started and configured
    31  // for a Indexer database.
    32  //
    33  // If any errors are encountered, the test is failed and exited.
    34  func TestIndexerDB(ctx context.Context, t testing.TB) *pgxpool.Pool {
    35  	return testDB(ctx, t, dbIndexer)
    36  }
    37  
    38  type dbFlavor uint
    39  
    40  const (
    41  	_ dbFlavor = iota
    42  	dbMatcher
    43  	dbIndexer
    44  )
    45  
    46  func testDB(ctx context.Context, t testing.TB, which dbFlavor) *pgxpool.Pool {
    47  	t.Helper()
    48  	db, err := integration.NewDB(ctx, t)
    49  	if err != nil {
    50  		t.Fatalf("unable to create test database: %v", err)
    51  	}
    52  	cfg := db.Config()
    53  	cfg.ConnConfig.LogLevel = pgx.LogLevelError
    54  	cfg.ConnConfig.Logger = testingadapter.NewLogger(t)
    55  	pool, err := pgxpool.ConnectConfig(ctx, cfg)
    56  	if err != nil {
    57  		t.Fatalf("failed to connect: %v", err)
    58  	}
    59  	mdb := stdlib.OpenDB(*cfg.ConnConfig)
    60  	defer mdb.Close()
    61  	// run migrations
    62  	migrator := migrate.NewPostgresMigrator(mdb)
    63  	switch which {
    64  	case dbMatcher:
    65  		migrator.Table = migrations.MatcherMigrationTable
    66  		err = migrator.Exec(migrate.Up, migrations.MatcherMigrations...)
    67  	case dbIndexer:
    68  		migrator.Table = migrations.IndexerMigrationTable
    69  		err = migrator.Exec(migrate.Up, migrations.IndexerMigrations...)
    70  	}
    71  	if err != nil {
    72  		t.Fatalf("failed to perform migrations: %v", err)
    73  	}
    74  	loadHelpers(ctx, t, pool, which)
    75  
    76  	// BUG(hank) The Test*DB functions close over the passed-in Context and use
    77  	// it for the Cleanup method. Because Cleanup functions are earlier in the
    78  	// stack than any defers inside the test, make sure the Context isn't one
    79  	// that's deferred to be canceled.
    80  	t.Cleanup(func() {
    81  		pool.Close()
    82  		db.Close(ctx, t)
    83  	})
    84  	return pool
    85  }
    86  
// extraSQL holds the contents of this package's "sql" directory, embedded at
// build time. loadHelpers reads the "global" helper SQL files from it.
//go:embed sql
var extraSQL embed.FS
    89  
    90  // LoadHelpers loads extra SQL from both the "sql" directory in this package and
    91  // the test package's "testdata" directory.
    92  //
    93  // The "flavor" argument selects which prefix is added onto the file glob.
    94  func loadHelpers(ctx context.Context, t testing.TB, pool *pgxpool.Pool, flavor dbFlavor) {
    95  	t.Helper()
    96  	logprefix := [...]string{"global", "local"}
    97  	var look []fs.FS
    98  	if sys, err := fs.Sub(extraSQL, "sql"); err != nil {
    99  		t.Fatalf("unexpected error from embed.FS: %v", err)
   100  	} else {
   101  		look = append(look, sys)
   102  	}
   103  	// NB This is relative to the test being run, _not_ this file. Because this
   104  	// is a helper library, this is different than you may expect.
   105  	if sys, err := fs.Sub(os.DirFS("."), "testdata"); err != nil {
   106  		t.Log("no testdata directory, skipping local loading")
   107  	} else {
   108  		look = append(look, sys)
   109  	}
   110  
   111  	conn, err := pool.Acquire(ctx)
   112  	if err != nil {
   113  		t.Fatalf("unable to acquire connection: %v", err)
   114  	}
   115  	defer conn.Release()
   116  	glob := []string{"all_*.psql"}
   117  	switch flavor {
   118  	case dbMatcher:
   119  		glob = append(glob, "matcher_*.psql")
   120  	case dbIndexer:
   121  		glob = append(glob, "indexer_*.psql")
   122  	}
   123  	for i, sys := range look {
   124  		for _, g := range glob {
   125  			ms, err := fs.Glob(sys, g)
   126  			if err != nil {
   127  				panic(fmt.Sprintf("programmer error: %v", err))
   128  			}
   129  			for _, f := range ms {
   130  				b, err := fs.ReadFile(sys, f)
   131  				if err != nil {
   132  					t.Error(err)
   133  					continue
   134  				}
   135  				t.Logf("loading %s %q", logprefix[i], path.Base(f))
   136  				if _, err := conn.Exec(ctx, string(b)); err != nil {
   137  					t.Error(err)
   138  				}
   139  			}
   140  		}
   141  	}
   142  }