goyave.dev/goyave/v5@v5.0.0-rc9.0.20240517145003-d3f977d0b9f3/database/database.go (about)

     1  package database
     2  
     3  import (
     4  	"errors"
     5  	"time"
     6  
     7  	"gorm.io/gorm"
     8  	"goyave.dev/goyave/v5/config"
     9  	"goyave.dev/goyave/v5/slog"
    10  
    11  	errorutil "goyave.dev/goyave/v5/util/errors"
    12  )
    13  
    14  // New create a new connection pool using the settings defined in the given configuration.
    15  //
    16  // In order to use a specific driver / dialect ("mysql", "sqlite3", ...), you must not
    17  // forget to blank-import it in your main file.
    18  //
    19  //	import _ "goyave.dev/goyave/v5/database/dialect/mysql"
    20  //	import _ "goyave.dev/goyave/v5/database/dialect/postgres"
    21  //	import _ "goyave.dev/goyave/v5/database/dialect/sqlite"
    22  //	import _ "goyave.dev/goyave/v5/database/dialect/mssql"
    23  func New(cfg *config.Config, logger func() *slog.Logger) (*gorm.DB, error) {
    24  	driver := cfg.GetString("database.connection")
    25  
    26  	if driver == "none" {
    27  		return nil, errorutil.Errorf("Cannot create DB connection. Database is set to \"none\" in the config")
    28  	}
    29  
    30  	dialect, ok := dialects[driver]
    31  	if !ok {
    32  		return nil, errorutil.Errorf("DB Connection %q not supported, forgotten import?", driver)
    33  	}
    34  
    35  	dsn := dialect.buildDSN(cfg)
    36  	db, err := gorm.Open(dialect.initializer(dsn), newConfig(cfg, logger))
    37  	if err != nil {
    38  		return nil, errorutil.New(err)
    39  	}
    40  
    41  	if err := initTimeoutPlugin(cfg, db); err != nil {
    42  		return db, errorutil.New(err)
    43  	}
    44  
    45  	return db, initSQLDB(cfg, db)
    46  }
    47  
    48  // NewFromDialector create a new connection pool from a gorm dialector and using the settings
    49  // defined in the given configuration.
    50  //
    51  // This can be used in tests to create a mock connection pool.
    52  func NewFromDialector(cfg *config.Config, logger func() *slog.Logger, dialector gorm.Dialector) (*gorm.DB, error) {
    53  	db, err := gorm.Open(dialector, newConfig(cfg, logger))
    54  	if err != nil {
    55  		return nil, errorutil.New(err)
    56  	}
    57  
    58  	if err := initTimeoutPlugin(cfg, db); err != nil {
    59  		return db, errorutil.New(err)
    60  	}
    61  
    62  	return db, initSQLDB(cfg, db)
    63  }
    64  
    65  func newConfig(cfg *config.Config, logger func() *slog.Logger) *gorm.Config {
    66  	if !cfg.GetBool("app.debug") {
    67  		// Stay silent about DB operations when not in debug mode
    68  		logger = nil
    69  	}
    70  	return &gorm.Config{
    71  		Logger:                                   NewLogger(logger),
    72  		SkipDefaultTransaction:                   cfg.GetBool("database.config.skipDefaultTransaction"),
    73  		DryRun:                                   cfg.GetBool("database.config.dryRun"),
    74  		PrepareStmt:                              cfg.GetBool("database.config.prepareStmt"),
    75  		DisableNestedTransaction:                 cfg.GetBool("database.config.disableNestedTransaction"),
    76  		AllowGlobalUpdate:                        cfg.GetBool("database.config.allowGlobalUpdate"),
    77  		DisableAutomaticPing:                     cfg.GetBool("database.config.disableAutomaticPing"),
    78  		DisableForeignKeyConstraintWhenMigrating: cfg.GetBool("database.config.disableForeignKeyConstraintWhenMigrating"),
    79  	}
    80  }
    81  
    82  func initTimeoutPlugin(cfg *config.Config, db *gorm.DB) error {
    83  	timeoutPlugin := &TimeoutPlugin{
    84  		ReadTimeout:  time.Duration(cfg.GetInt("database.defaultReadQueryTimeout")) * time.Millisecond,
    85  		WriteTimeout: time.Duration(cfg.GetInt("database.defaultWriteQueryTimeout")) * time.Millisecond,
    86  	}
    87  	return errorutil.New(db.Use(timeoutPlugin))
    88  }
    89  
    90  func initSQLDB(cfg *config.Config, db *gorm.DB) error {
    91  	sqlDB, err := db.DB()
    92  	if err != nil {
    93  		if errors.Is(err, gorm.ErrInvalidDB) {
    94  			return nil
    95  		}
    96  		return errorutil.New(err)
    97  	}
    98  	sqlDB.SetMaxOpenConns(cfg.GetInt("database.maxOpenConnections"))
    99  	sqlDB.SetMaxIdleConns(cfg.GetInt("database.maxIdleConnections"))
   100  	sqlDB.SetConnMaxLifetime(time.Duration(cfg.GetInt("database.maxLifetime")) * time.Second)
   101  	return nil
   102  }