github.com/isyscore/isc-gobase@v1.5.3-0.20231218061332-cbc7451899e9/extend/orm/gorm.go (about) 1 package orm 2 3 import ( 4 "context" 5 "database/sql" 6 driverMysql "github.com/go-sql-driver/mysql" 7 "github.com/isyscore/isc-gobase/bean" 8 "github.com/isyscore/isc-gobase/config" 9 "github.com/isyscore/isc-gobase/constants" 10 "github.com/isyscore/isc-gobase/listener" 11 baseLogger "github.com/isyscore/isc-gobase/logger" 12 "github.com/lib/pq" 13 "github.com/mattn/go-sqlite3" 14 "github.com/qustavo/sqlhooks/v2" 15 "github.com/sirupsen/logrus" 16 "gorm.io/driver/mysql" 17 "gorm.io/driver/postgres" 18 "gorm.io/driver/sqlite" 19 "gorm.io/driver/sqlserver" 20 "gorm.io/gorm" 21 "gorm.io/gorm/logger" 22 "strings" 23 "time" 24 ) 25 26 func NewGormDb() (*gorm.DB, error) { 27 return doNewGormDb("", &gorm.Config{}) 28 } 29 30 func NewGormDbWitConfig(gormConfig *gorm.Config) (*gorm.DB, error) { 31 return doNewGormDb("", gormConfig) 32 } 33 34 func NewGormDbWithName(datasourceName string) (*gorm.DB, error) { 35 return doNewGormDb(datasourceName, &gorm.Config{}) 36 } 37 38 func NewGormDbWithNameAndConfig(datasourceName string, gormConfig *gorm.Config) (*gorm.DB, error) { 39 return doNewGormDb(datasourceName, gormConfig) 40 } 41 42 func doNewGormDb(datasourceName string, gormConfig *gorm.Config) (*gorm.DB, error) { 43 datasourceConfig := config.DatasourceConfig{} 44 targetDatasourceName := "base.datasource" 45 if datasourceName != "" { 46 targetDatasourceName = "base.datasource." + datasourceName 47 } 48 err := config.GetValueObject(targetDatasourceName, &datasourceConfig) 49 if err != nil { 50 baseLogger.Warn("读取读取配置【datasource】异常") 51 return nil, err 52 } 53 54 // 注册原生的sql的hook 55 if len(gormHooks) != 0 { 56 sqlRegister(datasourceConfig.DriverName) 57 } 58 59 var gormDb *gorm.DB 60 dsn := getDbDsn(datasourceConfig.DriverName, datasourceConfig) 61 gormDb, err = gorm.Open(getDialect(dsn, datasourceConfig.DriverName), gormConfig) 62 if err != nil { 63 baseLogger.Warn("获取数据库db异常:%v", err.Error()) 64 return nil, err 65 } 66 67 d, _ := gormDb.DB() 68 69 maxIdleConns := config.GetValueInt("base.datasource.connect-pool.max-idle-conns") 70 if maxIdleConns != 0 { 71 // 设置空闲的最大连接数 72 d.SetMaxIdleConns(maxIdleConns) 73 } 74 75 maxOpenConns := config.GetValueInt("base.datasource.connect-pool.max-open-conns") 76 if maxOpenConns != 0 { 77 // 设置数据库打开连接的最大数量 78 d.SetMaxOpenConns(maxOpenConns) 79 } 80 81 maxLifeTime := config.GetValueString("base.datasource.connect-pool.max-life-time") 82 if maxLifeTime != "" { 83 // 设置连接可重复使用的最大时间 84 t, err := time.ParseDuration(maxLifeTime) 85 if err != nil { 86 baseLogger.Warn("读取配置【base.datasource.connect-pool.max-life-time】异常", err) 87 } else { 88 d.SetConnMaxLifetime(t) 89 } 90 } 91 92 maxIdleTime := config.GetValueString("base.datasource.connect-pool.max-idle-time") 93 if maxIdleTime != "" { 94 // 设置conn最大空闲时间设置连接空闲的最大时间 95 t, err := time.ParseDuration(maxIdleTime) 96 if err != nil { 97 baseLogger.Warn("读取配置【base.datasource.connect-pool.max-idle-time】异常", err) 98 } else { 99 d.SetConnMaxIdleTime(t) 100 } 101 } 102 103 gormDb.Logger = &GormLoggerAdapter{} 104 bean.AddBean(constants.BeanNameGormPre+datasourceName, gormDb) 105 // 添加orm的配置监听器 106 listener.AddListener(listener.EventOfConfigChange, ConfigChangeListenerOfOrm) 107 108 return gormDb, nil 109 } 110 111 // 特殊字符处理 112 func specialCharChange(url string) string { 113 return strings.ReplaceAll(url, "/", "%2F") 114 } 115 116 func getDialect(dsn, driverName string) gorm.Dialector { 117 switch driverName { 118 case "mysql": 119 return mysql.New(getMysqlConfig(dsn, driverName)) 120 case "postgresql": 121 return postgres.New(postgres.Config{DSN: dsn, DriverName: WrapDriverName(driverName)}) 122 case "sqlite": 123 return sqlite.Dialector{DSN: dsn, DriverName: WrapDriverName(driverName)} 124 case "sqlserver": 125 return sqlserver.New(sqlserver.Config{DSN: dsn, DriverName: WrapDriverName(driverName)}) 126 } 127 return nil 128 } 129 130 func sqlRegister(driverName string) { 131 name := WrapDriverName(driverName) 132 for _, driver := range sql.Drivers() { 133 if driver == name { 134 return 135 } 136 } 137 138 switch driverName { 139 case "mysql": 140 sql.Register(name, sqlhooks.Wrap(&driverMysql.MySQLDriver{}, &GobaseSqlHookProxy{DriverName: driverName})) 141 case "postgresql": 142 sql.Register(name, sqlhooks.Wrap(&pq.Driver{}, &GobaseSqlHookProxy{DriverName: driverName})) 143 case "sqlite": 144 sql.Register(name, sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &GobaseSqlHookProxy{DriverName: driverName})) 145 //case "sqlserver": 暂时不支持 146 // sql.Register(WrapDriverName(driverName), sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &GobaseSqlHookProxy{})) 147 } 148 } 149 150 func getMysqlConfig(dsn, driverName string) mysql.Config { 151 return mysql.Config{ 152 DriverName: driverName, 153 DSN: dsn, 154 ServerVersion: config.GetValueStringDefault("base.datasource.mysql.server-version", ""), 155 SkipInitializeWithVersion: config.GetValueBoolDefault("base.datasource.mysql.skip-initialize-with-version", false), 156 DefaultStringSize: config.GetValueUIntDefault("base.datasource.mysql.default-string-size", 0), 157 DisableWithReturning: config.GetValueBoolDefault("base.datasource.mysql.disable-with-returning", false), 158 DisableDatetimePrecision: config.GetValueBoolDefault("base.datasource.mysql.disable-datetime-precision", false), 159 DontSupportRenameIndex: config.GetValueBoolDefault("base.datasource.mysql.dont-support-rename-index", false), 160 DontSupportRenameColumn: config.GetValueBoolDefault("base.datasource.mysql.dont-support-rename-column", false), 161 DontSupportForShareClause: config.GetValueBoolDefault("base.datasource.mysql.dont-support-for-share-clause", false), 162 DontSupportNullAsDefaultValue: config.GetValueBoolDefault("base.datasource.mysql.dont-support-null-as-default-value", false), 163 } 164 } 165 166 func WrapDriverName(driverName string) string { 167 if len(gormHooks) != 0 { 168 return driverName + "Hook" 169 } 170 return driverName 171 } 172 173 type GobaseGormHook interface { 174 Before(ctx context.Context, driverName string, parameters map[string]any) (context.Context, error) 175 After(ctx context.Context, driverName string, parameters map[string]any) (context.Context, error) 176 Err(ctx context.Context, driverName string, err error, parameters map[string]any) error 177 } 178 179 var gormHooks []GobaseGormHook 180 181 func init() { 182 gormHooks = []GobaseGormHook{} 183 } 184 185 func AddGormHook(hook GobaseGormHook) { 186 gormHooks = append(gormHooks, hook) 187 } 188 189 type GobaseSqlHookProxy struct { 190 DriverName string 191 } 192 193 func (proxy *GobaseSqlHookProxy) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 194 var ctxFinal context.Context 195 for _, hook := range gormHooks { 196 parametersMap := map[string]any{ 197 "query": query, 198 "args": args, 199 } 200 _ctx, err := hook.Before(ctx, proxy.DriverName, parametersMap) 201 if err != nil { 202 return _ctx, err 203 } else { 204 ctxFinal = _ctx 205 } 206 } 207 return ctxFinal, nil 208 } 209 210 func (proxy *GobaseSqlHookProxy) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { 211 for _, hook := range gormHooks { 212 parametersMap := map[string]any{ 213 "query": query, 214 "args": args, 215 } 216 ctx, err := hook.After(ctx, proxy.DriverName, parametersMap) 217 if err != nil { 218 return ctx, err 219 } 220 } 221 return ctx, nil 222 } 223 224 func (proxy *GobaseSqlHookProxy) OnError(ctx context.Context, err error, query string, args ...interface{}) error { 225 for _, hook := range gormHooks { 226 parametersMap := map[string]any{ 227 "query": query, 228 "args": args, 229 } 230 err := hook.Err(ctx, proxy.DriverName, err, parametersMap) 231 if err != nil { 232 return err 233 } 234 } 235 return nil 236 } 237 238 type GormLoggerAdapter struct { 239 } 240 241 func (l *GormLoggerAdapter) LogMode(level logger.LogLevel) logger.Interface { 242 var levelStr logrus.Level 243 switch level { 244 case logger.Silent: 245 levelStr = logrus.TraceLevel 246 case logger.Error: 247 levelStr = logrus.ErrorLevel 248 case logger.Warn: 249 levelStr = logrus.WarnLevel 250 case logger.Info: 251 levelStr = logrus.InfoLevel 252 } 253 baseLogger.Group("orm").SetLevel(levelStr) 254 return l 255 } 256 257 func (l *GormLoggerAdapter) Info(ctx context.Context, msg string, data ...interface{}) { 258 baseLogger.Info(msg, data) 259 } 260 261 func (l *GormLoggerAdapter) Warn(ctx context.Context, msg string, data ...interface{}) { 262 baseLogger.Warn(msg, data) 263 } 264 265 func (l *GormLoggerAdapter) Error(ctx context.Context, msg string, data ...interface{}) { 266 baseLogger.Error(msg, data) 267 } 268 269 func (l *GormLoggerAdapter) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { 270 elapsed := time.Since(begin) 271 sqlStr, rowsAffected := fc() 272 if err != nil { 273 baseLogger.Group("orm").Errorf("[SQL][%v]%s; error: %v", elapsed, sqlStr, err.Error()) 274 } else { 275 baseLogger.Group("orm").Debugf("[SQL][%v][row:%v]%s", elapsed, rowsAffected, sqlStr) 276 } 277 }