code.vegaprotocol.io/vega@v0.79.0/datanode/utils/databasetest/setup.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package databasetest
    17  
    18  import (
    19  	"bytes"
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"io/fs"
    24  	"math/rand"
    25  	"net"
    26  	"os"
    27  	"path/filepath"
    28  	"testing"
    29  	"time"
    30  
    31  	"code.vegaprotocol.io/vega/datanode/sqlstore"
    32  	"code.vegaprotocol.io/vega/logging"
    33  
    34  	"github.com/cenkalti/backoff/v4"
    35  	"github.com/jackc/pgx/v4"
    36  )
    37  
    38  var (
    39  	sqlTestsEnabled       = true
    40  	minPort               = 30000
    41  	maxPort               = 40000
    42  	postgresServerTimeout = time.Second * 10
    43  )
    44  
    45  func TestMain(m *testing.M, mainCtx context.Context, onSetupComplete func(sqlstore.Config, *sqlstore.ConnectionSource, *bytes.Buffer),
    46  	postgresRuntimePath string, sqlFs fs.FS,
    47  ) int {
    48  	testDBSocketDir := filepath.Join(postgresRuntimePath)
    49  	testDBPort := 5432
    50  	testDBHost := ""
    51  	sqlConfig := NewTestConfig(testDBPort, testDBHost, testDBSocketDir)
    52  
    53  	if sqlTestsEnabled {
    54  		log := logging.NewTestLogger()
    55  
    56  		err := os.Mkdir(postgresRuntimePath, fs.ModePerm)
    57  		if err != nil {
    58  			panic(err)
    59  		}
    60  		defer os.RemoveAll(postgresRuntimePath)
    61  
    62  		postgresLog := &bytes.Buffer{}
    63  		embeddedPostgres, err := sqlstore.StartEmbeddedPostgres(log, sqlConfig, postgresRuntimePath, postgresLog)
    64  		if err != nil {
    65  			log.Errorf("failed to start postgres: %s", postgresLog.String())
    66  			panic(err)
    67  		}
    68  
    69  		log.Infof("Test DB Socket Directory: %s", testDBSocketDir)
    70  		log.Infof("Test DB Port: %d", testDBPort)
    71  
    72  		// Make sure the database has started before we run the tests.
    73  		ctx, cancel := context.WithTimeout(mainCtx, postgresServerTimeout)
    74  
    75  		op := func() error {
    76  			connStr := sqlConfig.ConnectionConfig.GetConnectionString()
    77  			conn, err := pgx.Connect(ctx, connStr)
    78  			if err != nil {
    79  				return err
    80  			}
    81  
    82  			return conn.Ping(ctx)
    83  		}
    84  
    85  		if err := backoff.Retry(op, backoff.NewExponentialBackOff()); err != nil {
    86  			cancel()
    87  			panic(err)
    88  		}
    89  
    90  		cancel()
    91  		connectionSource, err := sqlstore.NewTransactionalConnectionSource(mainCtx, log, sqlConfig.ConnectionConfig)
    92  		if err != nil {
    93  			panic(err)
    94  		}
    95  		defer embeddedPostgres.Stop()
    96  
    97  		if err = sqlstore.WipeDatabaseAndMigrateSchemaToLatestVersion(log, sqlConfig.ConnectionConfig, sqlFs, false); err != nil {
    98  			log.Errorf("failed to wipe database and migrate schema, dumping postgres log:\n %s", postgresLog.String())
    99  			panic(err)
   100  		}
   101  
   102  		if err = sqlstore.ApplyDataRetentionPolicies(sqlConfig, log); err != nil {
   103  			panic(err)
   104  		}
   105  
   106  		onSetupComplete(sqlConfig, connectionSource, postgresLog)
   107  
   108  		return m.Run()
   109  	}
   110  
   111  	return 0
   112  }
   113  
   114  func NewTestConfig(port int, host, socketDir string) sqlstore.Config {
   115  	sqlConfig := sqlstore.NewDefaultConfig()
   116  	sqlConfig.UseEmbedded = true
   117  	sqlConfig.ConnectionConfig.Port = port
   118  	sqlConfig.ConnectionConfig.Host = host
   119  	sqlConfig.ConnectionConfig.SocketDir = socketDir
   120  
   121  	return sqlConfig
   122  }
   123  
   124  func GetNextFreePort() int {
   125  	rand.Seed(time.Now().UnixNano())
   126  	for {
   127  		port := rand.Intn(maxPort-minPort+1) + minPort
   128  		timeout := time.Millisecond * 100
   129  		conn, err := net.DialTimeout("tcp", net.JoinHostPort("localhost", fmt.Sprintf("%d", port)), timeout)
   130  		if err != nil && !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
   131  			return port
   132  		}
   133  
   134  		if conn != nil {
   135  			conn.Close()
   136  		}
   137  	}
   138  }