github.com/etecs-ru/gnomock@v0.13.2/preset/mssql/preset.go (about)

     1  // Package mssql provides a Gnomock Preset for Microsoft SQL Server database
     2  package mssql
     3  
import (
	"context"
	"database/sql"
	"fmt"
	"io/ioutil"
	"net/url"

	_ "github.com/denisenkom/go-mssqldb" // mssql driver
	"github.com/etecs-ru/gnomock"
	"github.com/etecs-ru/gnomock/internal/registry"
)
    14  
    15  const (
    16  	masterDB        = "master"
    17  	defaultPassword = "Gn0m!ck~"
    18  	defaultDatabase = "mydb"
    19  	defaultPort     = 1433
    20  	defaultVersion  = "2019-latest"
    21  )
    22  
    23  func init() {
    24  	registry.Register("mssql", func() gnomock.Preset { return &P{} })
    25  }
    26  
    27  // Preset creates a new Gmomock Microsoft SQL Server preset. This preset
    28  // includes a mssql specific healthcheck function, default mssql image and
    29  // port, and allows to optionally set up initial state.
    30  //
    31  // When used without any configuration, it uses `mydb` database, and `Gn0m!ck~`
    32  // administrator password (user: `sa`). You must accept EULA to use this image
    33  // (`WithLicense` option). By default, version `2019-latest` is used.
    34  func Preset(opts ...Option) gnomock.Preset {
    35  	p := &P{}
    36  
    37  	for _, opt := range opts {
    38  		opt(p)
    39  	}
    40  
    41  	return p
    42  }
    43  
    44  // P is a Gnomock Preset implementation of Microsoft SQL Server database
    45  type P struct {
    46  	DB           string   `json:"db"`
    47  	Password     string   `json:"password"`
    48  	Queries      []string `json:"queries"`
    49  	QueriesFiles []string `json:"queries_files"`
    50  	License      bool     `json:"license"`
    51  	Version      string   `json:"version"`
    52  }
    53  
    54  // Image returns an image that should be pulled to create this container
    55  func (p *P) Image() string {
    56  	return fmt.Sprintf("mcr.microsoft.com/mssql/server:%s", p.Version)
    57  }
    58  
    59  // Ports returns ports that should be used to access this container
    60  func (p *P) Ports() gnomock.NamedPorts {
    61  	return gnomock.DefaultTCP(defaultPort)
    62  }
    63  
    64  // Options returns a list of options to configure this container
    65  func (p *P) Options() []gnomock.Option {
    66  	p.setDefaults()
    67  
    68  	opts := []gnomock.Option{
    69  		gnomock.WithHealthCheck(p.healthcheck),
    70  		gnomock.WithEnv("SA_PASSWORD=" + p.Password),
    71  		gnomock.WithInit(p.initf()),
    72  	}
    73  
    74  	if p.License {
    75  		opts = append(opts, gnomock.WithEnv("ACCEPT_EULA=Y"))
    76  	}
    77  
    78  	return opts
    79  }
    80  
    81  func (p *P) healthcheck(ctx context.Context, c *gnomock.Container) error {
    82  	addr := c.Address(gnomock.DefaultPort)
    83  
    84  	db, err := p.connect(addr, masterDB)
    85  	if err != nil {
    86  		return err
    87  	}
    88  
    89  	defer func() {
    90  		_ = db.Close()
    91  	}()
    92  
    93  	var one int
    94  
    95  	return db.QueryRow(`select 1`).Scan(&one)
    96  }
    97  
    98  func (p *P) initf() gnomock.InitFunc {
    99  	return func(ctx context.Context, c *gnomock.Container) error {
   100  		addr := c.Address(gnomock.DefaultPort)
   101  
   102  		db, err := p.connect(addr, masterDB)
   103  		if err != nil {
   104  			return err
   105  		}
   106  
   107  		defer func() { _ = db.Close() }()
   108  
   109  		_, err = db.Exec("create database " + p.DB)
   110  		if err != nil {
   111  			return fmt.Errorf("can't create database '%s': %w", p.DB, err)
   112  		}
   113  
   114  		db, err = p.connect(addr, p.DB)
   115  		if err != nil {
   116  			return err
   117  		}
   118  
   119  		if len(p.QueriesFiles) > 0 {
   120  			for _, f := range p.QueriesFiles {
   121  				bs, err := ioutil.ReadFile(f) // nolint:gosec
   122  				if err != nil {
   123  					return fmt.Errorf("can't read queries file '%s': %w", f, err)
   124  				}
   125  
   126  				p.Queries = append([]string{string(bs)}, p.Queries...)
   127  			}
   128  		}
   129  
   130  		for _, q := range p.Queries {
   131  			_, err = db.Exec(q)
   132  			if err != nil {
   133  				return err
   134  			}
   135  		}
   136  
   137  		return nil
   138  	}
   139  }
   140  
   141  func (p *P) connect(addr, db string) (*sql.DB, error) {
   142  	connStr := fmt.Sprintf(
   143  		"sqlserver://sa:%s@%s?database=%s",
   144  		p.Password, addr, db,
   145  	)
   146  
   147  	return sql.Open("sqlserver", connStr)
   148  }
   149  
   150  func (p *P) setDefaults() {
   151  	if p.DB == "" {
   152  		p.DB = defaultDatabase
   153  	}
   154  
   155  	if p.Password == "" {
   156  		p.Password = defaultPassword
   157  	}
   158  
   159  	if p.Version == "" {
   160  		p.Version = defaultVersion
   161  	}
   162  }