github.com/etecs-ru/gnomock@v0.13.2/preset/mariadb/preset.go (about)

     1  // Package mariadb provides a Gnomock Preset for MariaDB database
     2  package mariadb
     3  
     4  import (
     5  	"context"
     6  	"database/sql"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"log"
    10  	"sync"
    11  
    12  	"github.com/etecs-ru/gnomock"
    13  	"github.com/etecs-ru/gnomock/internal/registry"
    14  	mysqldriver "github.com/go-sql-driver/mysql"
    15  )
    16  
    17  const (
    18  	defaultUser     = "gnomock"
    19  	defaultPassword = "gnoria"
    20  	defaultDatabase = "mydb"
    21  	defaultPort     = 3306
    22  	defaultVersion  = "10.5.8"
    23  )
    24  
    25  var setLoggerOnce sync.Once
    26  
    27  func init() {
    28  	registry.Register("mariadb", func() gnomock.Preset { return &P{} })
    29  }
    30  
    31  // Preset creates a new Gmomock MariaDB preset. This preset includes a MariaDB
    32  // specific healthcheck function, default MariaDB image and port, and allows to
    33  // optionally set up initial state.
    34  //
    35  // When used without any configuration, it creates a superuser `gnomock` with
    36  // password `gnoria`, and `mydb` database. Default MariaDB version is 10.5.8.
    37  func Preset(opts ...Option) gnomock.Preset {
    38  	p := &P{}
    39  
    40  	for _, opt := range opts {
    41  		opt(p)
    42  	}
    43  
    44  	return p
    45  }
    46  
// P is a Gnomock Preset implementation for the MariaDB database.
//
// Fields are normally set through Preset options. The json tags define
// their names when P is marshaled to or unmarshaled from JSON. Empty DB,
// User, Password and Version fields are replaced with package defaults
// (see setDefaults) when Options is called.
type P struct {
	// DB is passed to the container as MYSQL_DATABASE and is the database
	// named in every connection string. Defaults to "mydb".
	DB           string   `json:"db"`
	// User is passed as MYSQL_USER and used to connect. Defaults to "gnomock".
	User         string   `json:"user"`
	// Password is passed as MYSQL_PASSWORD and used to connect.
	// Defaults to "gnoria".
	Password     string   `json:"password"`
	// Queries are executed by the init function, one Exec call per entry,
	// after the contents of QueriesFiles.
	Queries      []string `json:"queries"`
	// QueriesFiles are paths of files whose entire contents are executed by
	// the init function, one Exec call per file, before Queries.
	QueriesFiles []string `json:"queries_files"`
	// Version is the tag of the docker.io/library/mariadb image.
	// Defaults to "10.5.8".
	Version      string   `json:"version"`
}
    56  
    57  // Image returns an image that should be pulled to create this container
    58  func (p *P) Image() string {
    59  	return fmt.Sprintf("docker.io/library/mariadb:%s", p.Version)
    60  }
    61  
    62  // Ports returns ports that should be used to access this container
    63  func (p *P) Ports() gnomock.NamedPorts {
    64  	return gnomock.DefaultTCP(defaultPort)
    65  }
    66  
    67  // Options returns a list of options to configure this container
    68  func (p *P) Options() []gnomock.Option {
    69  	setLoggerOnce.Do(func() {
    70  		// err is always nil for non-nil logger
    71  		_ = mysqldriver.SetLogger(log.New(ioutil.Discard, "", -1))
    72  	})
    73  
    74  	p.setDefaults()
    75  
    76  	opts := []gnomock.Option{
    77  		gnomock.WithHealthCheck(p.healthcheck),
    78  		gnomock.WithEnv("MYSQL_USER=" + p.User),
    79  		gnomock.WithEnv("MYSQL_PASSWORD=" + p.Password),
    80  		gnomock.WithEnv("MYSQL_DATABASE=" + p.DB),
    81  		gnomock.WithEnv("MYSQL_RANDOM_ROOT_PASSWORD=yes"),
    82  		gnomock.WithInit(p.initf()),
    83  	}
    84  
    85  	return opts
    86  }
    87  
    88  func (p *P) healthcheck(ctx context.Context, c *gnomock.Container) error {
    89  	addr := c.Address(gnomock.DefaultPort)
    90  
    91  	db, err := p.connect(addr)
    92  	if err != nil {
    93  		return err
    94  	}
    95  
    96  	defer func() {
    97  		_ = db.Close()
    98  	}()
    99  
   100  	var one int
   101  
   102  	return db.QueryRow(`select 1`).Scan(&one)
   103  }
   104  
   105  func (p *P) initf() gnomock.InitFunc {
   106  	return func(ctx context.Context, c *gnomock.Container) error {
   107  		addr := c.Address(gnomock.DefaultPort)
   108  
   109  		db, err := p.connect(addr)
   110  		if err != nil {
   111  			return err
   112  		}
   113  
   114  		defer func() { _ = db.Close() }()
   115  
   116  		if len(p.QueriesFiles) > 0 {
   117  			for _, f := range p.QueriesFiles {
   118  				bs, err := ioutil.ReadFile(f) // nolint:gosec
   119  				if err != nil {
   120  					return fmt.Errorf("can't read queries file '%s': %w", f, err)
   121  				}
   122  
   123  				p.Queries = append([]string{string(bs)}, p.Queries...)
   124  			}
   125  		}
   126  
   127  		for _, q := range p.Queries {
   128  			_, err = db.Exec(q)
   129  			if err != nil {
   130  				return err
   131  			}
   132  		}
   133  
   134  		return nil
   135  	}
   136  }
   137  
   138  func (p *P) connect(addr string) (*sql.DB, error) {
   139  	connStr := fmt.Sprintf(
   140  		"%s:%s@tcp(%s)/%s?multiStatements=true",
   141  		p.User, p.Password, addr, p.DB,
   142  	)
   143  
   144  	db, err := sql.Open("mysql", connStr)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  
   149  	return db, db.Ping()
   150  }
   151  
   152  func (p *P) setDefaults() {
   153  	if p.DB == "" {
   154  		p.DB = defaultDatabase
   155  	}
   156  
   157  	if p.User == "" {
   158  		p.User = defaultUser
   159  	}
   160  
   161  	if p.Password == "" {
   162  		p.Password = defaultPassword
   163  	}
   164  
   165  	if p.Version == "" {
   166  		p.Version = defaultVersion
   167  	}
   168  }