github.com/etecs-ru/gnomock@v0.13.2/preset/postgres/preset.go (about)

     1  // Package postgres provides a Gnomock Preset for PostgreSQL database.
     2  package postgres
     3  
     4  import (
     5  	"context"
     6  	"database/sql"
     7  	"fmt"
     8  	"io/ioutil"
     9  
    10  	"github.com/etecs-ru/gnomock"
    11  	"github.com/etecs-ru/gnomock/internal/registry"
    12  	_ "github.com/lib/pq" // postgres driver
    13  )
    14  
// Defaults used by this preset. connect always authenticates with
// defaultUser/defaultPassword and defaultSSLMode, regardless of the User and
// Password configured on P.
const (
	// defaultUser is the role connect uses for every connection.
	defaultUser     = "postgres"
	// defaultPassword is passed to the container as POSTGRES_PASSWORD and is
	// the password connect uses.
	defaultPassword = "password"
	// defaultDatabase is the database used by the health check, the database
	// from which a custom DB is created, and the fallback for an empty P.DB.
	defaultDatabase = "postgres"
	// defaultSSLMode is the sslmode value placed in the connection string.
	defaultSSLMode  = "disable"
	// defaultPort is the TCP port reported by Ports.
	defaultPort     = 5432
	// defaultVersion is the image tag used when P.Version is empty.
	defaultVersion  = "12.5"
)
    23  
    24  func init() {
    25  	registry.Register("postgres", func() gnomock.Preset { return &P{} })
    26  }
    27  
    28  // Preset creates a new Gmomock Postgres preset. This preset includes a Postgres
    29  // specific healthcheck function, default Postgres image and port, and allows to
    30  // optionally set up initial state.
    31  //
    32  // By default, this preset uses `postgres` user with `password` password, with
    33  // default database `postgres`. Default PostgresQL version is 12.5.
    34  func Preset(opts ...Option) gnomock.Preset {
    35  	p := &P{}
    36  
    37  	for _, opt := range opts {
    38  		opt(p)
    39  	}
    40  
    41  	return p
    42  }
    43  
// P is a Gnomock Preset implementation of PostgreSQL database
type P struct {
	// DB is the database the init queries run against. When it differs from
	// the default database, it is created during init. Empty means the
	// default database.
	DB           string   `json:"db"`
	// Queries are executed one by one, in order, during init.
	Queries      []string `json:"queries"`
	// QueriesFiles are paths of files whose contents are executed during
	// init, ahead of Queries.
	QueriesFiles []string `json:"queries_files"`
	// User and Password, when both are non-empty, make Options queue a
	// statement that creates this user as a superuser.
	User         string   `json:"user"`
	Password     string   `json:"password"`
	// Version is the tag of the postgres image. Empty means the default
	// version.
	Version      string   `json:"version"`
}
    53  
    54  // Image returns an image that should be pulled to create this container
    55  func (p *P) Image() string {
    56  	return fmt.Sprintf("docker.io/library/postgres:%s", p.Version)
    57  }
    58  
    59  // Ports returns ports that should be used to access this container
    60  func (p *P) Ports() gnomock.NamedPorts {
    61  	return gnomock.DefaultTCP(defaultPort)
    62  }
    63  
    64  // Options returns a list of options to configure this container
    65  func (p *P) Options() []gnomock.Option {
    66  	p.setDefaults()
    67  
    68  	if p.User != "" && p.Password != "" {
    69  		q := fmt.Sprintf(
    70  			`create user %s with superuser password '%s'`,
    71  			p.User, p.Password,
    72  		)
    73  
    74  		p.Queries = append(p.Queries, q)
    75  	}
    76  
    77  	opts := []gnomock.Option{
    78  		gnomock.WithHealthCheck(p.healthcheck),
    79  		gnomock.WithEnv("POSTGRES_PASSWORD=" + defaultPassword),
    80  		gnomock.WithInit(p.initf()),
    81  	}
    82  
    83  	return opts
    84  }
    85  
    86  func (p *P) healthcheck(ctx context.Context, c *gnomock.Container) error {
    87  	db, err := connect(c, defaultDatabase)
    88  	if err != nil {
    89  		return err
    90  	}
    91  
    92  	defer func() {
    93  		_ = db.Close()
    94  	}()
    95  
    96  	var one int
    97  
    98  	return db.QueryRow(`select 1`).Scan(&one)
    99  }
   100  
   101  func (p *P) initf() gnomock.InitFunc {
   102  	return func(ctx context.Context, c *gnomock.Container) error {
   103  		if p.DB != defaultDatabase {
   104  			db, err := connect(c, defaultDatabase)
   105  			if err != nil {
   106  				return err
   107  			}
   108  
   109  			_, err = db.Exec("create database " + p.DB)
   110  			if err != nil {
   111  				return err
   112  			}
   113  
   114  			_ = db.Close()
   115  		}
   116  
   117  		db, err := connect(c, p.DB)
   118  		if err != nil {
   119  			return err
   120  		}
   121  
   122  		defer func() { _ = db.Close() }()
   123  
   124  		if len(p.QueriesFiles) > 0 {
   125  			for _, f := range p.QueriesFiles {
   126  				bs, err := ioutil.ReadFile(f) // nolint:gosec
   127  				if err != nil {
   128  					return fmt.Errorf("can't read queries file '%s': %w", f, err)
   129  				}
   130  
   131  				p.Queries = append([]string{string(bs)}, p.Queries...)
   132  			}
   133  		}
   134  
   135  		for _, q := range p.Queries {
   136  			_, err = db.Exec(q)
   137  			if err != nil {
   138  				return err
   139  			}
   140  		}
   141  
   142  		return nil
   143  	}
   144  }
   145  
   146  func (p *P) setDefaults() {
   147  	if p.DB == "" {
   148  		p.DB = defaultDatabase
   149  	}
   150  
   151  	if p.Version == "" {
   152  		p.Version = defaultVersion
   153  	}
   154  }
   155  
   156  func connect(c *gnomock.Container, db string) (*sql.DB, error) {
   157  	connStr := fmt.Sprintf(
   158  		"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
   159  		c.Host, c.Port(gnomock.DefaultPort),
   160  		defaultUser, defaultPassword, db, defaultSSLMode,
   161  	)
   162  
   163  	conn, err := sql.Open("postgres", connStr)
   164  	if err != nil {
   165  		return nil, err
   166  	}
   167  
   168  	return conn, conn.Ping()
   169  }