github.com/wfusion/gofusion@v1.1.14/common/infra/drivers/orm/gorm.go (about) 1 package orm 2 3 import ( 4 "context" 5 "fmt" 6 "path" 7 "time" 8 9 "github.com/pkg/errors" 10 "gorm.io/driver/clickhouse" 11 "gorm.io/driver/mysql" 12 "gorm.io/driver/postgres" 13 "gorm.io/driver/sqlserver" 14 "gorm.io/gorm" 15 16 "github.com/wfusion/gofusion/common/env" 17 "github.com/wfusion/gofusion/common/infra/drivers/orm/opengauss" 18 "github.com/wfusion/gofusion/common/infra/drivers/orm/sqlite" 19 "github.com/wfusion/gofusion/common/utils" 20 ) 21 22 var Gorm Dialect = new(gormDriver) 23 24 type gormDriver struct{} 25 26 type gormDriverOption struct { 27 Driver driver `yaml:"driver"` 28 Dialect dialect `yaml:"dialect"` 29 Timeout string `yaml:"timeout"` 30 ReadTimeout string `yaml:"read_timeout"` 31 WriteTimeout string `yaml:"write_timeout"` 32 User string `yaml:"user"` 33 Password string `yaml:"password"` 34 DBName string `yaml:"db_name"` 35 DBCharset string `yaml:"db_charset"` 36 DBHostname string `yaml:"db_hostname"` 37 DBPort string `yaml:"db_port"` 38 MaxIdleConns int `yaml:"max_idle_conns"` 39 MaxOpenConns int `yaml:"max_open_conns"` 40 Scheme string `yaml:"scheme"` 41 } 42 43 func (g *gormDriver) New(ctx context.Context, option Option, opts ...utils.OptionExtender) (db *DB, err error) { 44 opt := g.parseDBOption(option) 45 gormDB, dialector, err := g.open(opt.Driver, string(opt.Dialect), opt) 46 if err != nil { 47 return 48 } 49 50 sqlDB, err := gormDB.DB() 51 if err != nil { 52 return 53 } 54 55 // optional 56 if opt.MaxOpenConns > 0 { 57 sqlDB.SetMaxOpenConns(opt.MaxOpenConns) 58 } 59 if opt.MaxIdleConns > 0 { 60 sqlDB.SetMaxIdleConns(opt.MaxIdleConns) 61 } 62 if utils.IsStrNotBlank(option.ConnMaxLifeTime) { 63 if liftTime, err := time.ParseDuration(option.ConnMaxLifeTime); err == nil { 64 sqlDB.SetConnMaxLifetime(liftTime) 65 } 66 } 67 if utils.IsStrNotBlank(option.ConnMaxLifeTime) { 68 if idleTime, err := time.ParseDuration(option.ConnMaxIdleTime); err == nil { 69 sqlDB.SetConnMaxIdleTime(idleTime) 70 } 71 } 72 73 newOpt := utils.ApplyOptions[newOption](opts...) 74 if newOpt.logger != nil { 75 gormDB.Logger = newOpt.logger 76 } 77 78 return &DB{DB: gormDB.WithContext(ctx), dialector: dialector}, nil 79 } 80 81 func (g *gormDriver) open(driver driver, dialect string, opt *gormDriverOption) ( 82 db *gorm.DB, dialector gorm.Dialector, err error) { 83 // alternative driver 84 switch driver { 85 case DriverMysql: 86 dialector = mysql.New(mysql.Config{ 87 DriverName: dialect, 88 DSN: g.genMySqlDsn(opt), 89 }) 90 91 case DriverPostgres: 92 if dialect == string(DialectOpenGauss) { 93 dialector = opengauss.New(opengauss.Config{ 94 DriverName: dialect, 95 DSN: g.genPostgresDsn(opt), 96 }) 97 } else { 98 dialector = postgres.New(postgres.Config{ 99 DriverName: dialect, 100 DSN: g.genPostgresDsn(opt), 101 }) 102 } 103 104 // sqlite dsn is filepath 105 // or file::memory:?cache=shared is also available, see also https://www.sqlite.org/inmemorydb.html 106 case DriverSqlite: 107 dialector = sqlite.Open(path.Join(env.WorkDir, path.Clean(opt.DBName))) 108 109 case DriverSqlserver: 110 dialector = sqlserver.New(sqlserver.Config{ 111 DriverName: dialect, 112 DSN: g.genSqlServerDsn(opt), 113 }) 114 115 // tidb is compatible with mysql protocol 116 case DriverTiDB: 117 dialector = mysql.New(mysql.Config{ 118 DriverName: dialect, 119 DSN: g.genMySqlDsn(opt), 120 }) 121 122 case DriverClickhouse: 123 dialector = clickhouse.New(clickhouse.Config{ 124 DriverName: dialect, 125 DSN: g.genClickhouseDsn(opt), 126 }) 127 128 default: 129 panic(errors.Errorf("unknown db driver or dialect: %s %s", driver, dialect)) 130 } 131 132 db, err = gorm.Open(dialector, &gorm.Config{ 133 DisableForeignKeyConstraintWhenMigrating: true, 134 }) 135 return 136 } 137 138 func (g *gormDriver) parseDBOption(option Option) (parsed *gormDriverOption) { 139 parsed = &gormDriverOption{ 140 Driver: option.Driver, 141 Dialect: option.Dialect, 142 Timeout: option.Timeout, 143 ReadTimeout: option.ReadTimeout, 144 WriteTimeout: option.WriteTimeout, 145 User: option.User, 146 Password: option.Password, 147 DBName: option.DB, 148 DBCharset: "utf8mb4,utf8", 149 DBHostname: option.Host, 150 DBPort: fmt.Sprintf("%v", option.Port), 151 Scheme: "tcp", 152 } 153 154 if option.Driver != "" { 155 parsed.Driver = option.Driver 156 } 157 if option.MaxIdleConns > 0 { 158 parsed.MaxIdleConns = option.MaxIdleConns 159 } 160 if option.MaxOpenConns > 0 { 161 parsed.MaxOpenConns = option.MaxOpenConns 162 } 163 164 if utils.IsStrBlank(string(parsed.Dialect)) { 165 parsed.Dialect = defaultDriverDialectMapping[parsed.Driver] 166 } 167 168 return 169 } 170 171 func (g *gormDriver) genMySqlDsn(opt *gormDriverOption) (dsn string) { 172 if opt.DBCharset == "" { 173 opt.DBCharset = "utf8" 174 } 175 if opt.Scheme == "" { 176 opt.Scheme = "tcp" 177 } 178 179 const ( 180 dsnFormat = "%s:%s@%s(%s:%s)/%s?charset=%s&parseTime=True&loc=Local&timeout=%s&readTimeout=%s&writeTimeout=%s" 181 ) 182 183 return fmt.Sprintf(dsnFormat, opt.User, opt.Password, opt.Scheme, opt.DBHostname, opt.DBPort, opt.DBName, 184 opt.DBCharset, opt.Timeout, opt.ReadTimeout, opt.WriteTimeout) 185 } 186 187 func (g *gormDriver) genPostgresDsn(opt *gormDriverOption) (dsn string) { 188 const ( 189 dsnFormat = "host=%s user=%s password=%s dbname=%s port=%s sslmode=disable TimeZone=Asia/Shanghai" 190 ) 191 192 return fmt.Sprintf(dsnFormat, opt.DBHostname, opt.User, opt.Password, opt.DBName, opt.DBPort) 193 } 194 195 func (g *gormDriver) genSqlServerDsn(opt *gormDriverOption) (dsn string) { 196 const ( 197 dsnFormat = "sqlserver://%s:%s@%s:%s?database=%s&connection+timeout=%s" 198 ) 199 200 timeout := "5" // seconds 201 if utils.IsStrNotBlank(opt.Timeout) { 202 if duration, err := time.ParseDuration(opt.Timeout); err == nil { 203 timeout = fmt.Sprintf("%v", int(duration/time.Second)) 204 } 205 } 206 207 return fmt.Sprintf(dsnFormat, opt.User, opt.Password, opt.DBHostname, opt.DBPort, opt.DBName, timeout) 208 } 209 210 func (g *gormDriver) genClickhouseDsn(opt *gormDriverOption) (dsn string) { 211 const ( 212 dsnFormat = "tcp://%s:%s?database=%s&username=%s&password=%s&read_timeout=%s&write_timeout=%s" 213 ) 214 215 readTimeout := "2" // seconds 216 if utils.IsStrNotBlank(opt.ReadTimeout) { 217 if duration, err := time.ParseDuration(opt.ReadTimeout); err == nil { 218 readTimeout = fmt.Sprintf("%v", int(duration/time.Second)) 219 } 220 } 221 writeTimeout := "2" // seconds 222 if utils.IsStrNotBlank(opt.ReadTimeout) { 223 if duration, err := time.ParseDuration(opt.WriteTimeout); err == nil { 224 writeTimeout = fmt.Sprintf("%v", int(duration/time.Second)) 225 } 226 } 227 228 return fmt.Sprintf(dsnFormat, opt.DBHostname, opt.DBPort, opt.DBName, opt.User, opt.Password, 229 readTimeout, writeTimeout) 230 }