github.com/navikt/knorten@v0.0.0-20240419132333-1333f46ed8b6/local/dbsetup/dbsetup.go (about)

     1  package dbsetup
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"log"
     8  	"os"
     9  	"time"
    10  
    11  	"github.com/jackc/pgx/v5"
    12  	"github.com/jackc/pgx/v5/pgtype"
    13  	"github.com/ory/dockertest/v3"
    14  )
    15  
    16  func SetupDB(ctx context.Context, dbURL, dbname string) error {
    17  	db, err := pgx.Connect(ctx, dbURL)
    18  	if err != nil {
    19  		return err
    20  	}
    21  	defer db.Close(ctx)
    22  
    23  	fmt.Println("Successfully connected!")
    24  
    25  	if err := db.QueryRow(ctx, "SELECT FROM pg_catalog.pg_database WHERE datname = $1", dbname).Scan(); err != nil {
    26  		if err == pgx.ErrNoRows {
    27  			fmt.Printf("Creating database %v\n", dbname)
    28  			_, err := db.Exec(ctx, fmt.Sprintf("CREATE DATABASE %v", dbname))
    29  			if err != nil {
    30  				return err
    31  			}
    32  		} else {
    33  			return err
    34  		}
    35  	}
    36  
    37  	err = db.Close(ctx)
    38  	if err != nil {
    39  		return err
    40  	}
    41  
    42  	db, err = pgx.Connect(ctx, dbURL+"/"+dbname)
    43  	if err != nil {
    44  		return err
    45  	}
    46  	defer db.Close(ctx)
    47  
    48  	if err := db.QueryRow(ctx, "SELECT FROM pg_tables WHERE schemaname = 'public' AND tablename  = $1", "goose_db_version").Scan(); err != nil {
    49  		if err == pgx.ErrNoRows {
    50  			fmt.Println("You need to run `make goose cmd=up`")
    51  			os.Exit(1)
    52  		} else {
    53  			return err
    54  		}
    55  	}
    56  
    57  	var oid uint32
    58  	err = db.QueryRow(context.Background(), "select oid from pg_type where typname=$1;", "chart_type").Scan(&oid)
    59  	if err != nil {
    60  		return err
    61  	}
    62  
    63  	db.TypeMap().RegisterType(&pgtype.Type{Name: "chart_type", OID: oid, Codec: &pgtype.EnumCodec{}})
    64  
    65  	airflowContainer := func(name string) string {
    66  		return fmt.Sprintf(`[{"name": "%v", "image": "registry.k8s.io/git-sync/git-sync:v3.6.3","args": ["", "", "/dags", "60"], "volumeMounts":[{"mountPath":"/dags","name":"dags"}]}]`, name)
    67  	}
    68  
    69  	fmt.Println("Time to insert dummy data for local development")
    70  	rows := [][]interface{}{
    71  		{"airflow", "config.core.dags_folder", `"/dags"`},
    72  		{"airflow", "createUserJob.serviceAccount.create", "false"},
    73  		{"airflow", "postgresql.enabled", "false"},
    74  		{"airflow", "scheduler.extraContainers", airflowContainer("git-nada")},
    75  		{"airflow", "scheduler.extraInitContainers", airflowContainer("git-nada-clone")},
    76  		{"airflow", "webserver.extraContainers", airflowContainer("git-nada")},
    77  		{"airflow", "webserver.serviceAccount.create", "false"},
    78  		{"airflow", "webserverSecretKeySecretName", "airflow-webserver"},
    79  		{"airflow", "workers.extraInitContainers", airflowContainer("git-nada")},
    80  		{"airflow", "workers.serviceAccount.create", "false"},
    81  		{"jupyterhub", "singleuser.profileList", "[]"},
    82  	}
    83  	_, err = db.CopyFrom(ctx,
    84  		pgx.Identifier{"chart_global_values"},
    85  		[]string{"chart_type", "key", "value"},
    86  		pgx.CopyFromRows(rows))
    87  	if err != nil {
    88  		return err
    89  	}
    90  
    91  	return nil
    92  }
    93  
    94  func SetupDBForTests() (string, error) {
    95  	dbString := "user=postgres dbname=knorten sslmode=disable password=postgres host=db port=5432"
    96  
    97  	if os.Getenv("CLOUDBUILD") != "true" {
    98  		dockerHost := os.Getenv("HOME") + "/.colima/docker.sock"
    99  		_, err := os.Stat(dockerHost)
   100  		if err != nil {
   101  			// uses a sensible default on windows (tcp/http) and linux/osx (socket)
   102  			dockerHost = ""
   103  		} else {
   104  			dockerHost = "unix://" + dockerHost
   105  		}
   106  
   107  		pool, err := dockertest.NewPool(dockerHost)
   108  		if err != nil {
   109  			log.Fatalf("Could not connect to docker: %s", err)
   110  		}
   111  
   112  		// pulls an image, creates a container based on it and runs it
   113  		resource, err := pool.Run("postgres", "14", []string{"POSTGRES_PASSWORD=postgres", "POSTGRES_DB=knorten"})
   114  		if err != nil {
   115  			log.Fatalf("Could not start resource: %s", err)
   116  		}
   117  
   118  		// setting resource timeout as postgres container is not terminated automatically
   119  		if err = resource.Expire(120); err != nil {
   120  			log.Fatalf("failed creating postgres expire: %v", err)
   121  		}
   122  
   123  		dbPort := resource.GetPort("5432/tcp")
   124  		dbString = fmt.Sprintf("user=postgres dbname=knorten sslmode=disable password=postgres host=localhost port=%v", dbPort)
   125  	}
   126  
   127  	if err := waitForDB(dbString); err != nil {
   128  		log.Fatal(err)
   129  	}
   130  
   131  	return dbString, nil
   132  }
   133  
   134  func waitForDB(dbString string) error {
   135  	sleepDuration := 1 * time.Second
   136  	numRetries := 60
   137  	for i := 0; i < numRetries; i++ {
   138  		time.Sleep(sleepDuration)
   139  		db, err := sql.Open("postgres", dbString)
   140  		if err != nil {
   141  			return err
   142  		}
   143  
   144  		if err := db.Ping(); err == nil {
   145  			return nil
   146  		}
   147  	}
   148  
   149  	return fmt.Errorf("unable to connect to db in %v seconds", int(sleepDuration)*numRetries/1000000000)
   150  }