github.com/monkeswag33/noter-go@v0.0.0-20220505233910-9d72ccb0bdb6/db/db.go (about)

     1  package db
     2  
     3  import (
     4  	"github.com/monkeswag33/noter-go/errordef"
     5  	"github.com/monkeswag33/noter-go/types"
     6  	"github.com/sirupsen/logrus"
     7  	"github.com/spf13/viper"
     8  	"gorm.io/driver/postgres"
     9  	"gorm.io/driver/sqlite"
    10  	"gorm.io/gorm"
    11  	"gorm.io/gorm/logger"
    12  )
    13  
// Note is a named text note belonging to a User. GORM maps it to a table,
// which (*DB).Init creates/updates via AutoMigrate.
    14  type Note struct {
    15  	ID     int // primary key by GORM convention for a field named ID
    16  	Name   string `gorm:"unique;not null"` // unique across all notes, not per user
    17  	Body   string `gorm:"not null"`
    18  	UserID int    `gorm:"not null"` // owning user's ID (belongs-to User below)
    19  	User   User   `gorm:"constraint:OnDelete:CASCADE;"` // deleting the user deletes their notes
    20  }
    21  
// User is an account that can own Notes. GORM maps it to a table, which
// (*DB).Init creates/updates via AutoMigrate.
    22  type User struct {
    23  	ID       int // primary key by GORM convention for a field named ID
    24  	Username string `gorm:"unique;not null"`
    25  	Password string `gorm:"not null"` // NOTE(review): stored as given here — confirm callers hash it first
    26  }
    27  
// DB pairs a GORM handle with the log-level settings used to configure it.
// Init connects (only when DB is nil) and runs migrations; Close closes the
// underlying connection pool.
    28  type DB struct {
    29  	DB       *gorm.DB // nil until Init connects, unless supplied beforehand (InitTesterDB does this)
    30  	LogLevel types.LogLevelParams // GormLogLevel is read by getLogLevel; the LogLevel field is not read in this file
    31  }
    32  
// Database is the process-wide handle assigned by SetupDB and reset to nil
// by a successful ShutdownDB. It is nil until SetupDB is called.
    33  var Database *DB
    34  
    35  func SetupDB(logLevel types.LogLevelParams) error {
    36  	Database = &DB{
    37  		LogLevel: logLevel,
    38  	}
    39  	if err := Database.Init(); err != nil {
    40  		return err
    41  	}
    42  	return nil
    43  }
    44  
    45  func ShutdownDB() error {
    46  	if err := Database.Close(); err != nil {
    47  		return err
    48  	}
    49  	Database = nil
    50  	return nil
    51  }
    52  
    53  func InitTesterDB() (*DB, error) {
    54  	// Specify location to be in-memory
    55  	// Cache is not shared, so each connection will be a seperate database
    56  	var location string = ":memory:"
    57  	gormDB, err := gorm.Open(sqlite.Open(location), &gorm.Config{})
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	var database DB = DB{
    62  		LogLevel: types.LogLevelParams{
    63  			LogLevel:     "warn",
    64  			GormLogLevel: "info",
    65  		},
    66  		DB: gormDB,
    67  	}
    68  	if err := database.Init(); err != nil {
    69  		return nil, err
    70  	}
    71  	return &database, nil
    72  }
    73  
    74  func (db *DB) getLogLevel() (loggerLevel logger.LogLevel) {
    75  	if val, ok := types.GormLogLevels[db.LogLevel.GormLogLevel]; ok {
    76  		loggerLevel = val
    77  	} else {
    78  		logrus.Warnf("Unrecognized gorm log level %q, using default value WARN", db.LogLevel.GormLogLevel)
    79  		db.LogLevel.GormLogLevel = "warn"
    80  		loggerLevel = logger.Warn
    81  	}
    82  	logrus.Debugf("GORM log level: %q", db.LogLevel.GormLogLevel)
    83  	return loggerLevel
    84  }
    85  
    86  func (db *DB) Init() error {
    87  	if db.DB == nil {
    88  		logrus.Info("Looking for POSTGRES_URI environment variable...")
    89  		var uri string = viper.GetString("POSTGRES_URI")
    90  		if len(uri) == 0 {
    91  			return errordef.ErrCouldNotFindPostgresURI
    92  		}
    93  		logrus.Info("Found POSTGRES_URI")
    94  		logrus.Debugf("POSTGRES_URI is: %q", uri)
    95  		var loggerLevel logger.Interface = logger.Default.LogMode(db.getLogLevel())
    96  		logrus.Info("Got GORM log level")
    97  
    98  		logrus.Info("Connecting to database...")
    99  		var err error
   100  		db.DB, err = gorm.Open(postgres.Open(uri), &gorm.Config{
   101  			Logger: loggerLevel,
   102  		})
   103  		logrus.Info("Connected to database")
   104  
   105  		if err != nil {
   106  			return errordef.ErrFailedToConnect
   107  		}
   108  	}
   109  
   110  	logrus.Info("Running migrations...")
   111  	db.DB.AutoMigrate(&User{})
   112  	logrus.Trace("Migrated users...")
   113  	db.DB.AutoMigrate(&Note{})
   114  	logrus.Trace("Migrated notes...")
   115  	logrus.Info("Finished migrations")
   116  	return nil
   117  }
   118  
   119  func (db *DB) Close() error {
   120  	var err error
   121  	sqlDB, dbErr := db.DB.DB() // Get the actual sqlDB of the database
   122  	if dbErr == nil {
   123  		closeErr := sqlDB.Close()
   124  		if err == nil {
   125  			logrus.Info("Closed the database")
   126  			return nil
   127  		} else {
   128  			err = closeErr
   129  		}
   130  	} else {
   131  		err = dbErr
   132  	}
   133  	return err
   134  }