github.com/0xPolygon/supernets2-node@v0.0.0-20230711153321-2fe574524eaa/db/db.go (about)

     1  package db
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/0xPolygon/supernets2-node/log"
     8  	"github.com/gobuffalo/packr/v2"
     9  	"github.com/jackc/pgx/v4"
    10  	"github.com/jackc/pgx/v4/pgxpool"
    11  	"github.com/jackc/pgx/v4/stdlib"
    12  	migrate "github.com/rubenv/sql-migrate"
    13  )
    14  
    15  const (
    16  	// StateMigrationName is the name of the migration used by packr to pack the migration file
    17  	StateMigrationName = "supernets2-state-db"
    18  	// PoolMigrationName is the name of the migration used by packr to pack the migration file
    19  	PoolMigrationName = "supernets2-pool-db"
    20  )
    21  
// packrMigrations maps each migration name to the packr box that bundles its
// SQL files: StateMigrationName -> ./migrations/state and
// PoolMigrationName -> ./migrations/pool. runMigrations and checkMigrations
// look boxes up here by name.
//
// NOTE(review): packr v2's packr2 build tool discovers boxes by statically
// parsing packr.New calls, so keep these calls in this plain literal form.
var packrMigrations = map[string]*packr.Box{
	StateMigrationName: packr.New(StateMigrationName, "./migrations/state"),
	PoolMigrationName:  packr.New(PoolMigrationName, "./migrations/pool"),
}
    26  
    27  // NewSQLDB creates a new SQL DB
    28  func NewSQLDB(cfg Config) (*pgxpool.Pool, error) {
    29  	config, err := pgxpool.ParseConfig(fmt.Sprintf("postgres://%s:%s@%s:%s/%s?pool_max_conns=%d", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name, cfg.MaxConns))
    30  	if err != nil {
    31  		log.Errorf("Unable to parse DB config: %v\n", err)
    32  		return nil, err
    33  	}
    34  	if cfg.EnableLog {
    35  		config.ConnConfig.Logger = logger{}
    36  	}
    37  	conn, err := pgxpool.ConnectConfig(context.Background(), config)
    38  	if err != nil {
    39  		log.Errorf("Unable to connect to database: %v\n", err)
    40  		return nil, err
    41  	}
    42  	return conn, nil
    43  }
    44  
    45  // RunMigrationsUp runs migrate-up for the given config.
    46  func RunMigrationsUp(cfg Config, name string) error {
    47  	log.Info("running migrations up")
    48  	return runMigrations(cfg, name, migrate.Up)
    49  }
    50  
    51  // CheckMigrations runs migrate-up for the given config.
    52  func CheckMigrations(cfg Config, name string) error {
    53  	return checkMigrations(cfg, name, migrate.Up)
    54  }
    55  
    56  // RunMigrationsDown runs migrate-down for the given config.
    57  func RunMigrationsDown(cfg Config, name string) error {
    58  	log.Info("running migrations down")
    59  	return runMigrations(cfg, name, migrate.Down)
    60  }
    61  
    62  // runMigrations will execute pending migrations if needed to keep
    63  // the database updated with the latest changes in either direction,
    64  // up or down.
    65  func runMigrations(cfg Config, packrName string, direction migrate.MigrationDirection) error {
    66  	c, err := pgx.ParseConfig(fmt.Sprintf("postgres://%s:%s@%s:%s/%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name))
    67  	if err != nil {
    68  		return err
    69  	}
    70  	db := stdlib.OpenDB(*c)
    71  
    72  	box, ok := packrMigrations[packrName]
    73  	if !ok {
    74  		return fmt.Errorf("packr box not found with name: %v", packrName)
    75  	}
    76  
    77  	var migrations = &migrate.PackrMigrationSource{Box: box}
    78  	nMigrations, err := migrate.Exec(db, "postgres", migrations, direction)
    79  	if err != nil {
    80  		return err
    81  	}
    82  
    83  	log.Info("successfully ran ", nMigrations, " migrations")
    84  	return nil
    85  }
    86  
    87  func checkMigrations(cfg Config, packrName string, direction migrate.MigrationDirection) error {
    88  	c, err := pgx.ParseConfig(fmt.Sprintf("postgres://%s:%s@%s:%s/%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Name))
    89  	if err != nil {
    90  		return err
    91  	}
    92  	db := stdlib.OpenDB(*c)
    93  
    94  	box, ok := packrMigrations[packrName]
    95  	if !ok {
    96  		return fmt.Errorf("packr box not found with name: %v", packrName)
    97  	}
    98  
    99  	migrationSource := &migrate.PackrMigrationSource{Box: box}
   100  	migrations, err := migrationSource.FindMigrations()
   101  	if err != nil {
   102  		log.Errorf("error getting migrations from source: %v", err)
   103  		return err
   104  	}
   105  
   106  	var expected int
   107  	for _, migration := range migrations {
   108  		if len(migration.Up) != 0 {
   109  			expected++
   110  		}
   111  	}
   112  
   113  	var actual int
   114  	query := `SELECT COUNT(1) FROM public.gorp_migrations`
   115  	err = db.QueryRow(query).Scan(&actual)
   116  	if err != nil {
   117  		log.Error("error getting migrations count: ", err)
   118  		return err
   119  	}
   120  	if expected == actual {
   121  		log.Infof("Found %d migrations as expected", actual)
   122  	} else {
   123  		return fmt.Errorf("error the component needs to run %d migrations before starting. DB only contains %d migrations", expected, actual)
   124  	}
   125  	return nil
   126  }