github.com/monkeswag33/noter-go@v0.0.0-20220505233910-9d72ccb0bdb6/db/db.go (about) 1 package db 2 3 import ( 4 "github.com/monkeswag33/noter-go/errordef" 5 "github.com/monkeswag33/noter-go/types" 6 "github.com/sirupsen/logrus" 7 "github.com/spf13/viper" 8 "gorm.io/driver/postgres" 9 "gorm.io/driver/sqlite" 10 "gorm.io/gorm" 11 "gorm.io/gorm/logger" 12 ) 13 14 type Note struct { 15 ID int 16 Name string `gorm:"unique;not null"` 17 Body string `gorm:"not null"` 18 UserID int `gorm:"not null"` 19 User User `gorm:"constraint:OnDelete:CASCADE;"` 20 } 21 22 type User struct { 23 ID int 24 Username string `gorm:"unique;not null"` 25 Password string `gorm:"not null"` 26 } 27 28 type DB struct { 29 DB *gorm.DB 30 LogLevel types.LogLevelParams 31 } 32 33 var Database *DB 34 35 func SetupDB(logLevel types.LogLevelParams) error { 36 Database = &DB{ 37 LogLevel: logLevel, 38 } 39 if err := Database.Init(); err != nil { 40 return err 41 } 42 return nil 43 } 44 45 func ShutdownDB() error { 46 if err := Database.Close(); err != nil { 47 return err 48 } 49 Database = nil 50 return nil 51 } 52 53 func InitTesterDB() (*DB, error) { 54 // Specify location to be in-memory 55 // Cache is not shared, so each connection will be a seperate database 56 var location string = ":memory:" 57 gormDB, err := gorm.Open(sqlite.Open(location), &gorm.Config{}) 58 if err != nil { 59 return nil, err 60 } 61 var database DB = DB{ 62 LogLevel: types.LogLevelParams{ 63 LogLevel: "warn", 64 GormLogLevel: "info", 65 }, 66 DB: gormDB, 67 } 68 if err := database.Init(); err != nil { 69 return nil, err 70 } 71 return &database, nil 72 } 73 74 func (db *DB) getLogLevel() (loggerLevel logger.LogLevel) { 75 if val, ok := types.GormLogLevels[db.LogLevel.GormLogLevel]; ok { 76 loggerLevel = val 77 } else { 78 logrus.Warnf("Unrecognized gorm log level %q, using default value WARN", db.LogLevel.GormLogLevel) 79 db.LogLevel.GormLogLevel = "warn" 80 loggerLevel = logger.Warn 81 } 82 logrus.Debugf("GORM log level: %q", db.LogLevel.GormLogLevel) 83 return loggerLevel 84 } 85 86 func (db *DB) Init() error { 87 if db.DB == nil { 88 logrus.Info("Looking for POSTGRES_URI environment variable...") 89 var uri string = viper.GetString("POSTGRES_URI") 90 if len(uri) == 0 { 91 return errordef.ErrCouldNotFindPostgresURI 92 } 93 logrus.Info("Found POSTGRES_URI") 94 logrus.Debugf("POSTGRES_URI is: %q", uri) 95 var loggerLevel logger.Interface = logger.Default.LogMode(db.getLogLevel()) 96 logrus.Info("Got GORM log level") 97 98 logrus.Info("Connecting to database...") 99 var err error 100 db.DB, err = gorm.Open(postgres.Open(uri), &gorm.Config{ 101 Logger: loggerLevel, 102 }) 103 logrus.Info("Connected to database") 104 105 if err != nil { 106 return errordef.ErrFailedToConnect 107 } 108 } 109 110 logrus.Info("Running migrations...") 111 db.DB.AutoMigrate(&User{}) 112 logrus.Trace("Migrated users...") 113 db.DB.AutoMigrate(&Note{}) 114 logrus.Trace("Migrated notes...") 115 logrus.Info("Finished migrations") 116 return nil 117 } 118 119 func (db *DB) Close() error { 120 var err error 121 sqlDB, dbErr := db.DB.DB() // Get the actual sqlDB of the database 122 if dbErr == nil { 123 closeErr := sqlDB.Close() 124 if err == nil { 125 logrus.Info("Closed the database") 126 return nil 127 } else { 128 err = closeErr 129 } 130 } else { 131 err = dbErr 132 } 133 return err 134 }