github.com/binbinly/pkg@v0.0.11-0.20240321014439-f4fbf666eb0f/storage/orm/gorm.go (about) 1 package orm 2 3 import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "os" 8 9 "gorm.io/driver/mysql" 10 "gorm.io/driver/postgres" 11 "gorm.io/gorm" 12 "gorm.io/gorm/logger" 13 "gorm.io/gorm/schema" 14 "gorm.io/plugin/opentelemetry/tracing" 15 ) 16 17 const ( 18 DriverPostgres = "postgres" 19 DriverMysql = "mysql" 20 ) 21 22 // NewDB create a db 23 func NewDB(c *Config) (db *gorm.DB) { 24 var sqlDB *sql.DB 25 if c.Driver == DriverPostgres { 26 sqlDB = openPostgres(c) 27 } else { 28 sqlDB = openMysql(c) 29 } 30 31 // set for db connection 32 // 用于设置最大打开的连接数,默认值为0表示不限制.设置最大的连接数,可以避免并发太高导致连接mysql出现too many connections的错误。 33 if c.MaxOpenConn > 0 { 34 sqlDB.SetMaxOpenConns(c.MaxOpenConn) 35 } 36 // 用于设置闲置的连接数.设置闲置的连接数则当开启的一个连接使用完成后可以放在池里等候下一次使用。 37 if c.MaxIdleConn > 0 { 38 sqlDB.SetMaxIdleConns(c.MaxIdleConn) 39 } 40 // 单个连接最大存活时间,建议设置比数据库超时时长(wait_timeout)稍小一些 41 if c.ConnMaxLifeTime > 0 { 42 sqlDB.SetConnMaxLifetime(c.ConnMaxLifeTime) 43 } 44 45 var err error 46 if c.Driver == DriverPostgres { 47 db, err = gorm.Open(postgres.New(postgres.Config{Conn: sqlDB, PreferSimpleProtocol: true}), gormConfig(c)) 48 } else { 49 db, err = gorm.Open(mysql.New(mysql.Config{Conn: sqlDB}), gormConfig(c)) 50 } 51 db.Set("gorm:table_options", "CHARSET=utf8mb4") 52 if err != nil { 53 log.Panicf("database %s connection failed. database name: %s, err: %+v", c.Driver, c.Database, err) 54 } 55 56 if c.Trace { //链路追踪 57 if err = db.Use(tracing.NewPlugin(tracing.WithoutMetrics())); err != nil { 58 log.Panicf("use tracing failed. database name: %s, err: %+v", c.Database, err) 59 } 60 } 61 62 return db 63 } 64 65 // gormConfig 根据配置决定是否开启日志 66 func gormConfig(c *Config) *gorm.Config { 67 conf := &gorm.Config{ 68 DisableForeignKeyConstraintWhenMigrating: true, //禁用自动创建数据库外键约束 69 PrepareStmt: true, //PreparedStmt 在执行任何 SQL 时都会创建一个 prepared statement 并将其缓存,以提高后续的效率 70 Logger: logger.Default.LogMode(logger.Warn), 71 } 72 if c.TablePrefix != "" { 73 conf.NamingStrategy = schema.NamingStrategy{ 74 TablePrefix: c.TablePrefix, // 表名前缀,`User` 的表名应该是 `t_users` 75 SingularTable: true, // 使用单数表名,启用该选项,此时,`User` 的表名应该是 `t_user` 76 } 77 } 78 // 打印所有SQL 79 if c.Debug { 80 conf.Logger = logger.Default.LogMode(logger.Info) 81 } 82 // 只打印慢查询 83 if c.SlowThreshold > 0 { 84 conf.Logger = logger.New( 85 //将标准输出作为Writer 86 log.New(os.Stdout, "\r\n", log.LstdFlags), 87 logger.Config{ 88 //设定慢查询时间阈值 89 SlowThreshold: c.SlowThreshold, // nolint: golint 90 //设置日志级别,只有指定级别以上会输出慢查询日志 91 LogLevel: logger.Warn, 92 }, 93 ) 94 } 95 return conf 96 } 97 98 func openMysql(c *Config) *sql.DB { 99 dsn := c.Dsn 100 if dsn == "" { 101 dsn = fmt.Sprintf("%s:%s@tcp(%s:%v)/%s?charset=utf8mb4&parseTime=%t&loc=%s", 102 c.User, c.Password, c.Host, c.Port, c.Database, true, "Local") 103 } 104 sqlDB, err := sql.Open("mysql", dsn) 105 if err != nil { 106 log.Panicf("open %s failed. database name: %s, err: %+v", c.Driver, c.Database, err) 107 } 108 109 return sqlDB 110 } 111 112 func openPostgres(c *Config) *sql.DB { 113 dsn := c.Dsn 114 if dsn == "" { 115 dsn = fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=%s TimeZone=%s", 116 c.Host, c.User, c.Password, c.Database, c.Port, "disable", "Asia/Shanghai") 117 } 118 119 sqlDB, err := sql.Open("pgx", dsn) 120 if err != nil { 121 log.Panicf("open %s failed. database name: %s, err: %+v", c.Driver, c.Database, err) 122 } 123 124 return sqlDB 125 }