github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/testserver/datastore/postgres.go (about)

     1  //go:build docker
     2  // +build docker
     3  
     4  package datastore
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/google/uuid"
    13  	"github.com/jackc/pgx/v5"
    14  	"github.com/ory/dockertest/v3"
    15  	"github.com/ory/dockertest/v3/docker"
    16  	"github.com/stretchr/testify/require"
    17  
    18  	pgmigrations "github.com/authzed/spicedb/internal/datastore/postgres/migrations"
    19  	"github.com/authzed/spicedb/pkg/datastore"
    20  	"github.com/authzed/spicedb/pkg/migrate"
    21  	"github.com/authzed/spicedb/pkg/secrets"
    22  )
    23  
    24  const (
    25  	POSTGRES_TEST_USER            = "postgres"
    26  	POSTGRES_TEST_PASSWORD        = "secret"
    27  	POSTGRES_TEST_PORT            = "5432"
    28  	POSTGRES_TEST_MAX_CONNECTIONS = "2000"
    29  	PGBOUNCER_TEST_PORT           = "6432"
    30  )
    31  
    32  type container struct {
    33  	hostHostname      string
    34  	hostPort          string
    35  	containerHostname string
    36  	containerPort     string
    37  }
    38  
    39  type postgresTester struct {
    40  	container
    41  	hostConn             *pgx.Conn
    42  	creds                string
    43  	targetMigration      string
    44  	pgbouncerProxy       *container
    45  	useContainerHostname bool
    46  }
    47  
    48  // RunPostgresForTesting returns a RunningEngineForTest for postgres
    49  func RunPostgresForTesting(t testing.TB, bridgeNetworkName string, targetMigration string, pgVersion string, enablePgbouncer bool) RunningEngineForTest {
    50  	return RunPostgresForTestingWithCommitTimestamps(t, bridgeNetworkName, targetMigration, true, pgVersion, enablePgbouncer)
    51  }
    52  
    53  func RunPostgresForTestingWithCommitTimestamps(t testing.TB, bridgeNetworkName string, targetMigration string, withCommitTimestamps bool, pgVersion string, enablePgbouncer bool) RunningEngineForTest {
    54  	pool, err := dockertest.NewPool("")
    55  	require.NoError(t, err)
    56  
    57  	bridgeSupplied := bridgeNetworkName != ""
    58  	if enablePgbouncer && !bridgeSupplied {
    59  		// We will need a network bridge if we're running pgbouncer
    60  		bridgeNetworkName = createNetworkBridge(t, pool)
    61  	}
    62  
    63  	postgresContainerHostname := fmt.Sprintf("postgres-%s", uuid.New().String())
    64  
    65  	cmd := []string{"-N", POSTGRES_TEST_MAX_CONNECTIONS}
    66  	if withCommitTimestamps {
    67  		cmd = append(cmd, "-c", "track_commit_timestamp=1")
    68  	}
    69  
    70  	postgres, err := pool.RunWithOptions(&dockertest.RunOptions{
    71  		Name:       postgresContainerHostname,
    72  		Repository: "mirror.gcr.io/library/postgres",
    73  		Tag:        pgVersion,
    74  		Env: []string{
    75  			"POSTGRES_USER=" + POSTGRES_TEST_USER,
    76  			"POSTGRES_PASSWORD=" + POSTGRES_TEST_PASSWORD,
    77  			// use md5 auth to align postgres and pgbouncer auth methods
    78  			"POSTGRES_HOST_AUTH_METHOD=md5",
    79  			"POSTGRES_INITDB_ARGS=--auth=md5",
    80  		},
    81  		ExposedPorts: []string{POSTGRES_TEST_PORT + "/tcp"},
    82  		NetworkID:    bridgeNetworkName,
    83  		Cmd:          cmd,
    84  	})
    85  	require.NoError(t, err)
    86  	t.Cleanup(func() {
    87  		require.NoError(t, pool.Purge(postgres))
    88  	})
    89  
    90  	builder := &postgresTester{
    91  		container: container{
    92  			hostHostname:      "localhost",
    93  			hostPort:          postgres.GetPort(POSTGRES_TEST_PORT + "/tcp"),
    94  			containerHostname: postgresContainerHostname,
    95  			containerPort:     POSTGRES_TEST_PORT,
    96  		},
    97  		creds:                POSTGRES_TEST_USER + ":" + POSTGRES_TEST_PASSWORD,
    98  		targetMigration:      targetMigration,
    99  		useContainerHostname: bridgeSupplied,
   100  	}
   101  
   102  	if enablePgbouncer {
   103  		// if we are running with pgbouncer enabled then set it up
   104  		builder.runPgbouncerForTesting(t, pool, bridgeNetworkName)
   105  	}
   106  
   107  	builder.hostConn = builder.initializeHostConnection(t, pool)
   108  
   109  	return builder
   110  }
   111  
   112  func (b *postgresTester) NewDatabase(t testing.TB) string {
   113  	uniquePortion, err := secrets.TokenHex(4)
   114  	require.NoError(t, err)
   115  
   116  	newDBName := "db" + uniquePortion
   117  
   118  	_, err = b.hostConn.Exec(context.Background(), "CREATE DATABASE "+newDBName)
   119  	require.NoError(t, err)
   120  
   121  	hostname, port := b.getHostnameAndPort()
   122  	return fmt.Sprintf(
   123  		"postgres://%s@%s:%s/%s?sslmode=disable",
   124  		b.creds,
   125  		hostname,
   126  		port,
   127  		newDBName,
   128  	)
   129  }
   130  
   131  const (
   132  	retryCount         = 3
   133  	timeBetweenRetries = 100 * time.Millisecond
   134  )
   135  
   136  func (b *postgresTester) NewDatastore(t testing.TB, initFunc InitFunc) datastore.Datastore {
   137  	for i := 0; i < retryCount; i++ {
   138  		connectStr := b.NewDatabase(t)
   139  
   140  		migrationDriver, err := pgmigrations.NewAlembicPostgresDriver(context.Background(), connectStr, datastore.NoCredentialsProvider)
   141  		if err == nil {
   142  			ctx := context.WithValue(context.Background(), migrate.BackfillBatchSize, uint64(1000))
   143  			require.NoError(t, pgmigrations.DatabaseMigrations.Run(ctx, migrationDriver, b.targetMigration, migrate.LiveRun))
   144  			return initFunc("postgres", connectStr)
   145  		}
   146  
   147  		if i == retryCount-1 {
   148  			require.NoError(t, err, "got error when trying to create migration driver")
   149  		}
   150  
   151  		time.Sleep(timeBetweenRetries)
   152  	}
   153  
   154  	require.Fail(t, "failed to create datastore for testing")
   155  	return nil
   156  }
   157  
   158  func createNetworkBridge(t testing.TB, pool *dockertest.Pool) string {
   159  	bridgeNetworkName := fmt.Sprintf("bridge-%s", uuid.New().String())
   160  	network, err := pool.Client.CreateNetwork(docker.CreateNetworkOptions{Name: bridgeNetworkName})
   161  
   162  	require.NoError(t, err)
   163  	t.Cleanup(func() {
   164  		pool.Client.RemoveNetwork(network.ID)
   165  	})
   166  
   167  	return bridgeNetworkName
   168  }
   169  
   170  func (b *postgresTester) runPgbouncerForTesting(t testing.TB, pool *dockertest.Pool, bridgeNetworkName string) {
   171  	uniqueID := uuid.New().String()
   172  	pgbouncerContainerHostname := fmt.Sprintf("pgbouncer-%s", uniqueID)
   173  
   174  	pgbouncer, err := pool.RunWithOptions(&dockertest.RunOptions{
   175  		Name:       pgbouncerContainerHostname,
   176  		Repository: "mirror.gcr.io/edoburu/pgbouncer",
   177  		Tag:        "latest",
   178  		Env: []string{
   179  			"DB_USER=" + POSTGRES_TEST_USER,
   180  			"DB_PASSWORD=" + POSTGRES_TEST_PASSWORD,
   181  			"DB_HOST=" + b.containerHostname,
   182  			"DB_PORT=" + b.containerPort,
   183  			"LISTEN_PORT=" + PGBOUNCER_TEST_PORT,
   184  			"DB_NAME=*",     // Needed to make pgbouncer okay with the randomly named databases generated by the test suite
   185  			"AUTH_TYPE=md5", // use the same auth type as postgres
   186  			"MAX_CLIENT_CONN=" + POSTGRES_TEST_MAX_CONNECTIONS,
   187  			// params needed for spicedb
   188  			"POOL_MODE=session",                         // https://github.com/authzed/spicedb/issues/1217
   189  			"IGNORE_STARTUP_PARAMETERS=plan_cache_mode", // Tell pgbouncer to pass this param thru to postgres.
   190  		},
   191  		ExposedPorts: []string{PGBOUNCER_TEST_PORT + "/tcp"},
   192  		NetworkID:    bridgeNetworkName,
   193  	})
   194  	require.NoError(t, err)
   195  	t.Cleanup(func() {
   196  		require.NoError(t, pool.Purge(pgbouncer))
   197  	})
   198  
   199  	b.pgbouncerProxy = &container{
   200  		hostHostname:      "localhost",
   201  		hostPort:          pgbouncer.GetPort(PGBOUNCER_TEST_PORT + "/tcp"),
   202  		containerHostname: pgbouncerContainerHostname,
   203  		containerPort:     PGBOUNCER_TEST_PORT,
   204  	}
   205  }
   206  
   207  func (b *postgresTester) initializeHostConnection(t testing.TB, pool *dockertest.Pool) (conn *pgx.Conn) {
   208  	hostname, port := b.getHostHostnameAndPort()
   209  	uri := fmt.Sprintf("postgresql://%s@%s:%s/?sslmode=disable", b.creds, hostname, port)
   210  	err := pool.Retry(func() error {
   211  		var err error
   212  		ctx, cancelConnect := context.WithTimeout(context.Background(), dockerBootTimeout)
   213  		defer cancelConnect()
   214  		conn, err = pgx.Connect(ctx, uri)
   215  		if err != nil {
   216  			return err
   217  		}
   218  		return nil
   219  	})
   220  	require.NoError(t, err)
   221  	return conn
   222  }
   223  
   224  func (b *postgresTester) getHostnameAndPort() (string, string) {
   225  	// If a bridgeNetworkName is supplied then we will return the container
   226  	// hostname and port that is resolvable from within the container network.
   227  	// If bridgeNetworkName is not supplied then the hostname and port will be
   228  	// resolvable from the host.
   229  	if b.useContainerHostname {
   230  		return b.getContainerHostnameAndPort()
   231  	}
   232  	return b.getHostHostnameAndPort()
   233  }
   234  
   235  func (b *postgresTester) getHostHostnameAndPort() (string, string) {
   236  	if b.pgbouncerProxy != nil {
   237  		return b.pgbouncerProxy.hostHostname, b.pgbouncerProxy.hostPort
   238  	}
   239  	return b.hostHostname, b.hostPort
   240  }
   241  
   242  func (b *postgresTester) getContainerHostnameAndPort() (string, string) {
   243  	if b.pgbouncerProxy != nil {
   244  		return b.pgbouncerProxy.containerHostname, b.pgbouncerProxy.containerPort
   245  	}
   246  	return b.containerHostname, b.containerPort
   247  }