gitee.com/h79/goutils@v1.22.10/dao/db/adapter.go (about) 1 package db 2 3 import ( 4 "crypto/rsa" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 commonconfig "gitee.com/h79/goutils/common/config" 9 commonoption "gitee.com/h79/goutils/common/option" 10 "gitee.com/h79/goutils/common/result" 11 commontls "gitee.com/h79/goutils/common/tls" 12 "gitee.com/h79/goutils/dao/config" 13 "gitee.com/h79/goutils/dao/option" 14 drivermysql "github.com/go-sql-driver/mysql" 15 "gorm.io/driver/mysql" 16 "gorm.io/driver/postgres" 17 "gorm.io/driver/sqlserver" 18 "gorm.io/plugin/dbresolver" 19 "net/url" 20 "strings" 21 "time" 22 23 "gorm.io/gorm" 24 "runtime" 25 ) 26 27 type DialFunc func(string) gorm.Dialector 28 29 var openFuncs = map[string]DialFunc{ 30 "mysql": mysql.Open, 31 "postgres": postgres.Open, 32 "sqlserver": sqlserver.Open, 33 } 34 35 var _ Sql = (*Adapter)(nil) 36 37 // Adapter represents the Gorm adapter for policy storage. 38 type Adapter struct { 39 driverName string 40 databaseName string 41 dsn string 42 db *gorm.DB 43 } 44 45 type ScopesFunc func(db *gorm.DB) *gorm.DB 46 47 // finalizer is the destructor for Adapter. 48 func finalizer(a *Adapter) { 49 sqlDB, err := a.db.DB() 50 if err != nil { 51 panic(err) 52 } 53 err = sqlDB.Close() 54 if err != nil { 55 panic(err) 56 } 57 } 58 59 var DefaultDnsFunc = func(cnf *config.Database, tls, serverPubKey bool) string { 60 if cnf == nil { 61 return "" 62 } 63 if cnf.DriverType == "mysql" { 64 if cnf.Charset == "" { 65 cnf.Charset = "utf8mb4" 66 } 67 dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local", cnf.User, cnf.Pwd, cnf.Host, cnf.Port, cnf.Name, cnf.Charset) 68 if tls { 69 dsn += "&tls=" + url.QueryEscape(cnf.Tls.Key) 70 } 71 if serverPubKey { 72 dsn += "&serverPubKey=" + url.QueryEscape(cnf.ServerPubKey.Key) 73 } 74 return dsn 75 } else if cnf.DriverType == "postgres" { 76 return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s", cnf.Host, cnf.Port, cnf.User, cnf.Pwd, cnf.Name) 77 } else if cnf.DriverType == "sqlite3" { 78 return cnf.Name 79 } else if cnf.DriverType == "sql" { 80 return fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s", cnf.User, cnf.Pwd, cnf.Host, cnf.Port, cnf.Name) 81 } 82 return "" 83 } 84 85 func WithDnsOption(f func(cnf *config.Database, tls, pubKey bool) string) commonoption.Option { 86 return dnsFunc(f) 87 } 88 89 type dnsFunc func(cnf *config.Database, tls bool, pubKey bool) string 90 91 func (t dnsFunc) String() string { 92 return "sql:dns" 93 } 94 func (t dnsFunc) Type() int { return option.TypeSqlDns } 95 func (t dnsFunc) Value() interface{} { return t } 96 97 func dnsFuncExist(opts ...commonoption.Option) dnsFunc { 98 if r, ok := commonoption.Exist(option.TypeSqlDns, opts...); ok { 99 return r.Value().(dnsFunc) 100 } 101 return nil 102 } 103 104 var DefaultTlsFunc = func(cnf *config.Database) (*tls.Config, error) { 105 if strings.EqualFold(cnf.Tls.Key, "true") || 106 strings.EqualFold(cnf.Tls.Key, "false") || 107 strings.EqualFold(cnf.Tls.Key, "skip-verify") || 108 strings.EqualFold(cnf.Tls.Key, "preferred") { 109 return nil, nil 110 } 111 cert, rootCertPool, err := commontls.GetCertificate(&cnf.Tls) 112 if err != nil { 113 return nil, err 114 } 115 return &tls.Config{ 116 RootCAs: rootCertPool, 117 Certificates: []tls.Certificate{cert}, 118 }, nil 119 } 120 121 func WithTlsOption(f func(cnf *config.Database) (*tls.Config, error)) commonoption.Option { 122 return tlsFunc(f) 123 } 124 125 type tlsFunc func(cnf *config.Database) (*tls.Config, error) 126 127 func (t tlsFunc) String() string { 128 return "sql:tls" 129 } 130 func (t tlsFunc) Type() int { return option.TypeSqlTls } 131 func (t tlsFunc) Value() interface{} { return t } 132 133 func tlsFuncExist(opts ...commonoption.Option) tlsFunc { 134 if r, ok := commonoption.Exist(option.TypeSqlTls, opts...); ok { 135 return r.Value().(tlsFunc) 136 } 137 return nil 138 } 139 140 func WithServerPubKeyOption(f func(cnf *config.Database) (*rsa.PublicKey, error)) commonoption.Option { 141 return ServerPubKeyFunc(f) 142 } 143 144 type ServerPubKeyFunc func(cnf *config.Database) (*rsa.PublicKey, error) 145 146 func (t ServerPubKeyFunc) String() string { 147 return "sql:serverPubKey" 148 } 149 func (t ServerPubKeyFunc) Type() int { return option.TypeSqlServerPubKey } 150 func (t ServerPubKeyFunc) Value() interface{} { return t } 151 152 func serverPubKeyFuncExist(opts ...commonoption.Option) ServerPubKeyFunc { 153 if r, ok := commonoption.Exist(option.TypeSqlServerPubKey, opts...); ok { 154 return r.Value().(ServerPubKeyFunc) 155 } 156 return nil 157 } 158 159 func getDns(cnf *config.Database, tls, serverPubKey bool, opts ...commonoption.Option) string { 160 fn := dnsFuncExist(opts...) 161 if fn == nil { 162 fn = DefaultDnsFunc 163 } 164 return fn(cnf, tls, serverPubKey) 165 } 166 167 func UseTls(cnf *config.Database, opts ...commonoption.Option) error { 168 if cnf.DriverType != "mysql" { 169 return result.RErrNotSupport 170 } 171 fn := tlsFuncExist(opts...) 172 if fn == nil { 173 fn = DefaultTlsFunc 174 } 175 tlsCfg, err := fn(cnf) 176 if err != nil { 177 return err 178 } 179 if tlsCfg != nil { 180 return drivermysql.RegisterTLSConfig(cnf.Tls.Key, tlsCfg) 181 } 182 return nil 183 } 184 185 func UseServerPubKey(cnf *config.Database, opts ...commonoption.Option) error { 186 if cnf.DriverType != "mysql" { 187 return result.RErrNotSupport 188 } 189 fn := serverPubKeyFuncExist(opts...) 190 if fn == nil { 191 fn = func(cnf *config.Database) (*rsa.PublicKey, error) { 192 return commontls.GetServerPubKey(&cnf.ServerPubKey) 193 } 194 } 195 pk, err := fn(cnf) 196 if err != nil { 197 return err 198 } 199 drivermysql.RegisterServerPubKey(cnf.ServerPubKey.Key, pk) 200 return nil 201 } 202 203 // NewAdapter is the constructor for Adapter. 204 func NewAdapter(cfg *config.Sql, opts ...commonoption.Option) (*Adapter, error) { 205 a := &Adapter{} 206 tlsIf := false 207 err := UseTls(&cfg.Master, opts...) 208 if err == nil { 209 tlsIf = true 210 } 211 pubKeyIf := false 212 err = UseServerPubKey(&cfg.Master, opts...) 213 if err == nil { 214 pubKeyIf = true 215 } 216 a.driverName = cfg.Master.DriverType 217 a.databaseName = cfg.Master.Name 218 a.dsn = getDns(&cfg.Master, tlsIf, pubKeyIf, opts...) 219 220 // Open the DB 221 db, err := openDB(a.driverName, a.dsn) 222 if err != nil { 223 return nil, err 224 } 225 var sources []gorm.Dialector 226 var replicas []gorm.Dialector 227 for _, source := range cfg.Sources { 228 tlsIf = false 229 if err = UseTls(&source, opts...); err == nil { 230 tlsIf = true 231 } 232 pubKeyIf = false 233 err = UseServerPubKey(&source, opts...) 234 if err == nil { 235 pubKeyIf = true 236 } 237 dr, er := getDriver(source.DriverType, getDns(&source, tlsIf, pubKeyIf, opts...)) 238 if er != nil { 239 return nil, er 240 } 241 sources = append(sources, dr) 242 } 243 for _, replica := range cfg.Replicas { 244 tlsIf = false 245 if err = UseTls(&replica, opts...); err == nil { 246 tlsIf = true 247 } 248 pubKeyIf = false 249 err = UseServerPubKey(&replica, opts...) 250 if err == nil { 251 pubKeyIf = true 252 } 253 dr, er := getDriver(replica.DriverType, getDns(&replica, tlsIf, pubKeyIf, opts...)) 254 if er != nil { 255 return nil, er 256 } 257 replicas = append(replicas, dr) 258 } 259 resolver := dbresolver.Register(dbresolver.Config{ 260 Sources: sources, 261 Replicas: replicas, 262 // sources/replicas load balancing policy 263 Policy: dbresolver.RandomPolicy{}, 264 }) 265 if cfg.MaxOpenConns > 0 { 266 resolver.SetMaxOpenConns(cfg.MaxOpenConns) 267 } 268 if cfg.MaxIdleConns > 0 { 269 resolver.SetMaxIdleConns(cfg.MaxIdleConns) 270 } 271 if cfg.MaxLifetime > 0 { 272 resolver.SetConnMaxLifetime(cfg.MaxLifetime) 273 } 274 if cfg.MaxIdleTime > 0 { 275 resolver.SetConnMaxIdleTime(time.Minute * cfg.MaxIdleTime) 276 } 277 if err = db.Use(resolver); err != nil { 278 return nil, err 279 } 280 if cfg.Logger.LogLevel > 1 { 281 if cfg.Logger.SlowThreshold <= 0 { 282 cfg.Logger.SlowThreshold = 200 283 } 284 cfg.Logger.SlowThreshold = cfg.Logger.SlowThreshold * time.Millisecond 285 log := &Logger{ 286 SqlLogger: cfg.Logger, 287 } 288 db.Logger = log 289 if commonconfig.RegisterConfig != nil { 290 commonconfig.RegisterConfig("DB|"+cfg.Name, log.handlerConfig) 291 } 292 } 293 a.db = db 294 295 // Call the destructor when the object is released. 296 runtime.SetFinalizer(a, finalizer) 297 298 return a, nil 299 } 300 301 func (a *Adapter) Db() *gorm.DB { 302 return a.db 303 } 304 305 func (a *Adapter) Name() string { 306 return a.databaseName 307 } 308 309 func (a *Adapter) Close() { 310 a.db = nil 311 } 312 313 func AddDriver(driverName string, dial DialFunc) { 314 openFuncs[driverName] = dial 315 } 316 317 func getDriver(driverName, dataSourceName string) (gorm.Dialector, error) { 318 driver, ok := openFuncs[driverName] 319 if !ok { 320 return nil, errors.New("database dialect is not supported") 321 } 322 return driver(dataSourceName), nil 323 } 324 325 func openDB(driverName, dataSourceName string) (*gorm.DB, error) { 326 dr, err := getDriver(driverName, dataSourceName) 327 if err != nil { 328 return nil, err 329 } 330 return gorm.Open(dr, &gorm.Config{}) 331 }