github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/testutil/db.go (about)

     1  package testutil
     2  
import (
	"context"
	"flag"
	"fmt"
	"log"
	"net/url"
	"os"
	"strings"
	"testing"

	"cloud.google.com/go/storage"
	"github.com/jackc/pgx/v5/pgxpool"
	_ "github.com/jackc/pgx/v5/stdlib"
	"github.com/ory/dockertest/v3"
	"github.com/treeverse/lakefs/pkg/block"
	"github.com/treeverse/lakefs/pkg/block/gs"
	"github.com/treeverse/lakefs/pkg/block/mem"
	"github.com/treeverse/lakefs/pkg/block/params"
	blocks3 "github.com/treeverse/lakefs/pkg/block/s3"
)
    22  
    23  const (
    24  	DBName                    = "lakefs_db"
    25  	DBContainerTimeoutSeconds = 60 * 30 // 30 minutes
    26  
    27  	EnvKeyUseBlockAdapter = "USE_BLOCK_ADAPTER" // pragma: allowlist secret
    28  	envKeyAwsKeyID        = "AWS_ACCESS_KEY_ID"
    29  	envKeyAwsSecretKey    = "AWS_SECRET_ACCESS_KEY" //nolint:gosec
    30  	envKeyAwsRegion       = "AWS_DEFAULT_REGION"    // pragma: allowlist secret
    31  )
    32  
    33  var keepDB = flag.Bool("keep-db", false, "keep test DB instance running")
    34  
    35  func GetDBInstance(pool *dockertest.Pool) (string, func()) {
    36  	// connect using docker container name
    37  	containerName := os.Getenv("PG_DB_CONTAINER")
    38  	if containerName != "" {
    39  		resource, ok := pool.ContainerByName(containerName)
    40  		if !ok {
    41  			log.Fatalf("Cloud not find DB container (%s)", containerName)
    42  		}
    43  		uri := formatPostgresResourceURI(resource)
    44  		return uri, func() {}
    45  	}
    46  
    47  	// connect using supply address
    48  	dbURI := os.Getenv("PG_TEST_URI")
    49  	if len(dbURI) > 0 {
    50  		// use supplied DB connection for testing
    51  		if err := verifyDBConnectionString(dbURI); err != nil {
    52  			log.Fatalf("could not connect to postgres: %s", err)
    53  		}
    54  		return dbURI, func() {}
    55  	}
    56  
    57  	// run new instance and connect
    58  	resource, err := pool.Run("postgres", "11", []string{
    59  		"POSTGRES_USER=lakefs",
    60  		"POSTGRES_PASSWORD=lakefs",
    61  		fmt.Sprintf("POSTGRES_DB=%s", DBName),
    62  	})
    63  	if err != nil {
    64  		log.Fatalf("Could not start postgresql: %s", err)
    65  	}
    66  
    67  	// expire the container, just to be on the safe side
    68  	if !*keepDB {
    69  		err = resource.Expire(DBContainerTimeoutSeconds)
    70  		if err != nil {
    71  			log.Fatalf("could not expire postgres container")
    72  		}
    73  	}
    74  
    75  	// format db uri
    76  	uri := formatPostgresResourceURI(resource)
    77  
    78  	// wait for container to start and connect to db
    79  	if err = pool.Retry(func() error {
    80  		return verifyDBConnectionString(uri)
    81  	}); err != nil {
    82  		log.Fatalf("could not connect to postgres: %s", err)
    83  	}
    84  
    85  	// set cleanup
    86  	closer := func() {
    87  		if *keepDB {
    88  			return
    89  		}
    90  		err := pool.Purge(resource)
    91  		if err != nil {
    92  			log.Fatalf("could not kill postgres container")
    93  		}
    94  	}
    95  
    96  	// return DB address and closer func
    97  	return uri, closer
    98  }
    99  
   100  func formatPostgresResourceURI(resource *dockertest.Resource) string {
   101  	dbParams := map[string]string{
   102  		"POSTGRES_DB":       DBName,
   103  		"POSTGRES_USER":     "lakefs",
   104  		"POSTGRES_PASSWORD": "lakefs",
   105  		"POSTGRES_PORT":     resource.GetPort("5432/tcp"),
   106  	}
   107  	env := resource.Container.Config.Env
   108  	for _, entry := range env {
   109  		for key := range dbParams {
   110  			if strings.HasPrefix(entry, key+"=") {
   111  				dbParams[key] = entry[len(key)+1:]
   112  				break
   113  			}
   114  		}
   115  	}
   116  	uri := fmt.Sprintf("postgres://%s:%s@localhost:%s/%s?sslmode=disable",
   117  		dbParams["POSTGRES_USER"],
   118  		dbParams["POSTGRES_PASSWORD"],
   119  		dbParams["POSTGRES_PORT"],
   120  		dbParams["POSTGRES_DB"],
   121  	)
   122  	return uri
   123  }
   124  
   125  func verifyDBConnectionString(uri string) error {
   126  	ctx := context.Background()
   127  	pool, err := pgxpool.New(ctx, uri)
   128  	if err != nil {
   129  		return err
   130  	}
   131  	defer pool.Close()
   132  	return PingPG(ctx, pool)
   133  }
   134  
   135  type GetDBOptions struct {
   136  	ApplyDDL bool
   137  }
   138  
   139  type GetDBOption func(options *GetDBOptions)
   140  
   141  func Must(t testing.TB, err error) {
   142  	t.Helper()
   143  	if err != nil {
   144  		t.Fatalf("error returned for operation: %v", err)
   145  	}
   146  }
   147  
   148  func MustDo(t testing.TB, what string, err error) {
   149  	t.Helper()
   150  	if err != nil {
   151  		t.Fatalf("%s, expected no error, got err=%s", what, err)
   152  	}
   153  }
   154  
   155  func NewBlockAdapterByType(t testing.TB, blockstoreType string) block.Adapter {
   156  	ctx := context.Background()
   157  	switch blockstoreType {
   158  	case block.BlockstoreTypeGS:
   159  		client, err := storage.NewClient(ctx)
   160  		if err != nil {
   161  			t.Fatal("Google Storage new client", err)
   162  		}
   163  		return gs.NewAdapter(client)
   164  
   165  	case block.BlockstoreTypeS3:
   166  		var s3Params params.S3
   167  		if awsRegion, ok := os.LookupEnv(envKeyAwsRegion); ok {
   168  			s3Params.Region = awsRegion
   169  		} else {
   170  			s3Params.Region = "us-east-1"
   171  		}
   172  		awsKey, keyOk := os.LookupEnv(envKeyAwsKeyID)
   173  		awsSecret, secretOk := os.LookupEnv(envKeyAwsSecretKey)
   174  		if keyOk && secretOk {
   175  			s3Params.Credentials.AccessKeyID = awsKey
   176  			s3Params.Credentials.SecretAccessKey = awsSecret
   177  		}
   178  		blockAdapter, err := blocks3.NewAdapter(ctx, s3Params)
   179  		if err != nil {
   180  			t.Fatal("Failed to create S3 block adapter", err)
   181  		}
   182  		return blockAdapter
   183  
   184  	default:
   185  		return mem.New(context.Background())
   186  	}
   187  }
   188  
   189  func PingPG(ctx context.Context, pool *pgxpool.Pool) error {
   190  	conn, err := pool.Acquire(ctx)
   191  	if err != nil {
   192  		return fmt.Errorf("acquire to ping: %w", err)
   193  	}
   194  	defer conn.Release()
   195  	err = conn.Conn().Ping(ctx)
   196  	if err != nil {
   197  		return fmt.Errorf("ping: %w", err)
   198  	}
   199  	return nil
   200  }