github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/testutil/db.go (about) 1 package testutil 2 3 import ( 4 "context" 5 "flag" 6 "fmt" 7 "log" 8 "os" 9 "strings" 10 "testing" 11 12 "cloud.google.com/go/storage" 13 "github.com/jackc/pgx/v5/pgxpool" 14 _ "github.com/jackc/pgx/v5/stdlib" 15 "github.com/ory/dockertest/v3" 16 "github.com/treeverse/lakefs/pkg/block" 17 "github.com/treeverse/lakefs/pkg/block/gs" 18 "github.com/treeverse/lakefs/pkg/block/mem" 19 "github.com/treeverse/lakefs/pkg/block/params" 20 blocks3 "github.com/treeverse/lakefs/pkg/block/s3" 21 ) 22 23 const ( 24 DBName = "lakefs_db" 25 DBContainerTimeoutSeconds = 60 * 30 // 30 minutes 26 27 EnvKeyUseBlockAdapter = "USE_BLOCK_ADAPTER" // pragma: allowlist secret 28 envKeyAwsKeyID = "AWS_ACCESS_KEY_ID" 29 envKeyAwsSecretKey = "AWS_SECRET_ACCESS_KEY" //nolint:gosec 30 envKeyAwsRegion = "AWS_DEFAULT_REGION" // pragma: allowlist secret 31 ) 32 33 var keepDB = flag.Bool("keep-db", false, "keep test DB instance running") 34 35 func GetDBInstance(pool *dockertest.Pool) (string, func()) { 36 // connect using docker container name 37 containerName := os.Getenv("PG_DB_CONTAINER") 38 if containerName != "" { 39 resource, ok := pool.ContainerByName(containerName) 40 if !ok { 41 log.Fatalf("Cloud not find DB container (%s)", containerName) 42 } 43 uri := formatPostgresResourceURI(resource) 44 return uri, func() {} 45 } 46 47 // connect using supply address 48 dbURI := os.Getenv("PG_TEST_URI") 49 if len(dbURI) > 0 { 50 // use supplied DB connection for testing 51 if err := verifyDBConnectionString(dbURI); err != nil { 52 log.Fatalf("could not connect to postgres: %s", err) 53 } 54 return dbURI, func() {} 55 } 56 57 // run new instance and connect 58 resource, err := pool.Run("postgres", "11", []string{ 59 "POSTGRES_USER=lakefs", 60 "POSTGRES_PASSWORD=lakefs", 61 fmt.Sprintf("POSTGRES_DB=%s", DBName), 62 }) 63 if err != nil { 64 log.Fatalf("Could not start postgresql: %s", err) 65 } 66 67 // expire the container, just to be on the safe side 68 if !*keepDB { 69 err = resource.Expire(DBContainerTimeoutSeconds) 70 if err != nil { 71 log.Fatalf("could not expire postgres container") 72 } 73 } 74 75 // format db uri 76 uri := formatPostgresResourceURI(resource) 77 78 // wait for container to start and connect to db 79 if err = pool.Retry(func() error { 80 return verifyDBConnectionString(uri) 81 }); err != nil { 82 log.Fatalf("could not connect to postgres: %s", err) 83 } 84 85 // set cleanup 86 closer := func() { 87 if *keepDB { 88 return 89 } 90 err := pool.Purge(resource) 91 if err != nil { 92 log.Fatalf("could not kill postgres container") 93 } 94 } 95 96 // return DB address and closer func 97 return uri, closer 98 } 99 100 func formatPostgresResourceURI(resource *dockertest.Resource) string { 101 dbParams := map[string]string{ 102 "POSTGRES_DB": DBName, 103 "POSTGRES_USER": "lakefs", 104 "POSTGRES_PASSWORD": "lakefs", 105 "POSTGRES_PORT": resource.GetPort("5432/tcp"), 106 } 107 env := resource.Container.Config.Env 108 for _, entry := range env { 109 for key := range dbParams { 110 if strings.HasPrefix(entry, key+"=") { 111 dbParams[key] = entry[len(key)+1:] 112 break 113 } 114 } 115 } 116 uri := fmt.Sprintf("postgres://%s:%s@localhost:%s/%s?sslmode=disable", 117 dbParams["POSTGRES_USER"], 118 dbParams["POSTGRES_PASSWORD"], 119 dbParams["POSTGRES_PORT"], 120 dbParams["POSTGRES_DB"], 121 ) 122 return uri 123 } 124 125 func verifyDBConnectionString(uri string) error { 126 ctx := context.Background() 127 pool, err := pgxpool.New(ctx, uri) 128 if err != nil { 129 return err 130 } 131 defer pool.Close() 132 return PingPG(ctx, pool) 133 } 134 135 type GetDBOptions struct { 136 ApplyDDL bool 137 } 138 139 type GetDBOption func(options *GetDBOptions) 140 141 func Must(t testing.TB, err error) { 142 t.Helper() 143 if err != nil { 144 t.Fatalf("error returned for operation: %v", err) 145 } 146 } 147 148 func MustDo(t testing.TB, what string, err error) { 149 t.Helper() 150 if err != nil { 151 t.Fatalf("%s, expected no error, got err=%s", what, err) 152 } 153 } 154 155 func NewBlockAdapterByType(t testing.TB, blockstoreType string) block.Adapter { 156 ctx := context.Background() 157 switch blockstoreType { 158 case block.BlockstoreTypeGS: 159 client, err := storage.NewClient(ctx) 160 if err != nil { 161 t.Fatal("Google Storage new client", err) 162 } 163 return gs.NewAdapter(client) 164 165 case block.BlockstoreTypeS3: 166 var s3Params params.S3 167 if awsRegion, ok := os.LookupEnv(envKeyAwsRegion); ok { 168 s3Params.Region = awsRegion 169 } else { 170 s3Params.Region = "us-east-1" 171 } 172 awsKey, keyOk := os.LookupEnv(envKeyAwsKeyID) 173 awsSecret, secretOk := os.LookupEnv(envKeyAwsSecretKey) 174 if keyOk && secretOk { 175 s3Params.Credentials.AccessKeyID = awsKey 176 s3Params.Credentials.SecretAccessKey = awsSecret 177 } 178 blockAdapter, err := blocks3.NewAdapter(ctx, s3Params) 179 if err != nil { 180 t.Fatal("Failed to create S3 block adapter", err) 181 } 182 return blockAdapter 183 184 default: 185 return mem.New(context.Background()) 186 } 187 } 188 189 func PingPG(ctx context.Context, pool *pgxpool.Pool) error { 190 conn, err := pool.Acquire(ctx) 191 if err != nil { 192 return fmt.Errorf("acquire to ping: %w", err) 193 } 194 defer conn.Release() 195 err = conn.Conn().Ping(ctx) 196 if err != nil { 197 return fmt.Errorf("ping: %w", err) 198 } 199 return nil 200 }