gitee.com/h79/goutils@v1.22.10/dao/db/adapter.go (about)

     1  package db
     2  
     3  import (
     4  	"crypto/rsa"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	commonconfig "gitee.com/h79/goutils/common/config"
     9  	commonoption "gitee.com/h79/goutils/common/option"
    10  	"gitee.com/h79/goutils/common/result"
    11  	commontls "gitee.com/h79/goutils/common/tls"
    12  	"gitee.com/h79/goutils/dao/config"
    13  	"gitee.com/h79/goutils/dao/option"
    14  	drivermysql "github.com/go-sql-driver/mysql"
    15  	"gorm.io/driver/mysql"
    16  	"gorm.io/driver/postgres"
    17  	"gorm.io/driver/sqlserver"
    18  	"gorm.io/plugin/dbresolver"
    19  	"net/url"
    20  	"strings"
    21  	"time"
    22  
    23  	"gorm.io/gorm"
    24  	"runtime"
    25  )
    26  
    27  type DialFunc func(string) gorm.Dialector
    28  
    29  var openFuncs = map[string]DialFunc{
    30  	"mysql":     mysql.Open,
    31  	"postgres":  postgres.Open,
    32  	"sqlserver": sqlserver.Open,
    33  }
    34  
    35  var _ Sql = (*Adapter)(nil)
    36  
    37  // Adapter represents the Gorm adapter for policy storage.
    38  type Adapter struct {
    39  	driverName   string
    40  	databaseName string
    41  	dsn          string
    42  	db           *gorm.DB
    43  }
    44  
    45  type ScopesFunc func(db *gorm.DB) *gorm.DB
    46  
    47  // finalizer is the destructor for Adapter.
    48  func finalizer(a *Adapter) {
    49  	sqlDB, err := a.db.DB()
    50  	if err != nil {
    51  		panic(err)
    52  	}
    53  	err = sqlDB.Close()
    54  	if err != nil {
    55  		panic(err)
    56  	}
    57  }
    58  
    59  var DefaultDnsFunc = func(cnf *config.Database, tls, serverPubKey bool) string {
    60  	if cnf == nil {
    61  		return ""
    62  	}
    63  	if cnf.DriverType == "mysql" {
    64  		if cnf.Charset == "" {
    65  			cnf.Charset = "utf8mb4"
    66  		}
    67  		dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local", cnf.User, cnf.Pwd, cnf.Host, cnf.Port, cnf.Name, cnf.Charset)
    68  		if tls {
    69  			dsn += "&tls=" + url.QueryEscape(cnf.Tls.Key)
    70  		}
    71  		if serverPubKey {
    72  			dsn += "&serverPubKey=" + url.QueryEscape(cnf.ServerPubKey.Key)
    73  		}
    74  		return dsn
    75  	} else if cnf.DriverType == "postgres" {
    76  		return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s", cnf.Host, cnf.Port, cnf.User, cnf.Pwd, cnf.Name)
    77  	} else if cnf.DriverType == "sqlite3" {
    78  		return cnf.Name
    79  	} else if cnf.DriverType == "sql" {
    80  		return fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s", cnf.User, cnf.Pwd, cnf.Host, cnf.Port, cnf.Name)
    81  	}
    82  	return ""
    83  }
    84  
    85  func WithDnsOption(f func(cnf *config.Database, tls, pubKey bool) string) commonoption.Option {
    86  	return dnsFunc(f)
    87  }
    88  
    89  type dnsFunc func(cnf *config.Database, tls bool, pubKey bool) string
    90  
    91  func (t dnsFunc) String() string {
    92  	return "sql:dns"
    93  }
    94  func (t dnsFunc) Type() int          { return option.TypeSqlDns }
    95  func (t dnsFunc) Value() interface{} { return t }
    96  
    97  func dnsFuncExist(opts ...commonoption.Option) dnsFunc {
    98  	if r, ok := commonoption.Exist(option.TypeSqlDns, opts...); ok {
    99  		return r.Value().(dnsFunc)
   100  	}
   101  	return nil
   102  }
   103  
   104  var DefaultTlsFunc = func(cnf *config.Database) (*tls.Config, error) {
   105  	if strings.EqualFold(cnf.Tls.Key, "true") ||
   106  		strings.EqualFold(cnf.Tls.Key, "false") ||
   107  		strings.EqualFold(cnf.Tls.Key, "skip-verify") ||
   108  		strings.EqualFold(cnf.Tls.Key, "preferred") {
   109  		return nil, nil
   110  	}
   111  	cert, rootCertPool, err := commontls.GetCertificate(&cnf.Tls)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  	return &tls.Config{
   116  		RootCAs:      rootCertPool,
   117  		Certificates: []tls.Certificate{cert},
   118  	}, nil
   119  }
   120  
   121  func WithTlsOption(f func(cnf *config.Database) (*tls.Config, error)) commonoption.Option {
   122  	return tlsFunc(f)
   123  }
   124  
   125  type tlsFunc func(cnf *config.Database) (*tls.Config, error)
   126  
   127  func (t tlsFunc) String() string {
   128  	return "sql:tls"
   129  }
   130  func (t tlsFunc) Type() int          { return option.TypeSqlTls }
   131  func (t tlsFunc) Value() interface{} { return t }
   132  
   133  func tlsFuncExist(opts ...commonoption.Option) tlsFunc {
   134  	if r, ok := commonoption.Exist(option.TypeSqlTls, opts...); ok {
   135  		return r.Value().(tlsFunc)
   136  	}
   137  	return nil
   138  }
   139  
   140  func WithServerPubKeyOption(f func(cnf *config.Database) (*rsa.PublicKey, error)) commonoption.Option {
   141  	return ServerPubKeyFunc(f)
   142  }
   143  
   144  type ServerPubKeyFunc func(cnf *config.Database) (*rsa.PublicKey, error)
   145  
   146  func (t ServerPubKeyFunc) String() string {
   147  	return "sql:serverPubKey"
   148  }
   149  func (t ServerPubKeyFunc) Type() int          { return option.TypeSqlServerPubKey }
   150  func (t ServerPubKeyFunc) Value() interface{} { return t }
   151  
   152  func serverPubKeyFuncExist(opts ...commonoption.Option) ServerPubKeyFunc {
   153  	if r, ok := commonoption.Exist(option.TypeSqlServerPubKey, opts...); ok {
   154  		return r.Value().(ServerPubKeyFunc)
   155  	}
   156  	return nil
   157  }
   158  
   159  func getDns(cnf *config.Database, tls, serverPubKey bool, opts ...commonoption.Option) string {
   160  	fn := dnsFuncExist(opts...)
   161  	if fn == nil {
   162  		fn = DefaultDnsFunc
   163  	}
   164  	return fn(cnf, tls, serverPubKey)
   165  }
   166  
   167  func UseTls(cnf *config.Database, opts ...commonoption.Option) error {
   168  	if cnf.DriverType != "mysql" {
   169  		return result.RErrNotSupport
   170  	}
   171  	fn := tlsFuncExist(opts...)
   172  	if fn == nil {
   173  		fn = DefaultTlsFunc
   174  	}
   175  	tlsCfg, err := fn(cnf)
   176  	if err != nil {
   177  		return err
   178  	}
   179  	if tlsCfg != nil {
   180  		return drivermysql.RegisterTLSConfig(cnf.Tls.Key, tlsCfg)
   181  	}
   182  	return nil
   183  }
   184  
   185  func UseServerPubKey(cnf *config.Database, opts ...commonoption.Option) error {
   186  	if cnf.DriverType != "mysql" {
   187  		return result.RErrNotSupport
   188  	}
   189  	fn := serverPubKeyFuncExist(opts...)
   190  	if fn == nil {
   191  		fn = func(cnf *config.Database) (*rsa.PublicKey, error) {
   192  			return commontls.GetServerPubKey(&cnf.ServerPubKey)
   193  		}
   194  	}
   195  	pk, err := fn(cnf)
   196  	if err != nil {
   197  		return err
   198  	}
   199  	drivermysql.RegisterServerPubKey(cnf.ServerPubKey.Key, pk)
   200  	return nil
   201  }
   202  
   203  // NewAdapter is the constructor for Adapter.
   204  func NewAdapter(cfg *config.Sql, opts ...commonoption.Option) (*Adapter, error) {
   205  	a := &Adapter{}
   206  	tlsIf := false
   207  	err := UseTls(&cfg.Master, opts...)
   208  	if err == nil {
   209  		tlsIf = true
   210  	}
   211  	pubKeyIf := false
   212  	err = UseServerPubKey(&cfg.Master, opts...)
   213  	if err == nil {
   214  		pubKeyIf = true
   215  	}
   216  	a.driverName = cfg.Master.DriverType
   217  	a.databaseName = cfg.Master.Name
   218  	a.dsn = getDns(&cfg.Master, tlsIf, pubKeyIf, opts...)
   219  
   220  	// Open the DB
   221  	db, err := openDB(a.driverName, a.dsn)
   222  	if err != nil {
   223  		return nil, err
   224  	}
   225  	var sources []gorm.Dialector
   226  	var replicas []gorm.Dialector
   227  	for _, source := range cfg.Sources {
   228  		tlsIf = false
   229  		if err = UseTls(&source, opts...); err == nil {
   230  			tlsIf = true
   231  		}
   232  		pubKeyIf = false
   233  		err = UseServerPubKey(&source, opts...)
   234  		if err == nil {
   235  			pubKeyIf = true
   236  		}
   237  		dr, er := getDriver(source.DriverType, getDns(&source, tlsIf, pubKeyIf, opts...))
   238  		if er != nil {
   239  			return nil, er
   240  		}
   241  		sources = append(sources, dr)
   242  	}
   243  	for _, replica := range cfg.Replicas {
   244  		tlsIf = false
   245  		if err = UseTls(&replica, opts...); err == nil {
   246  			tlsIf = true
   247  		}
   248  		pubKeyIf = false
   249  		err = UseServerPubKey(&replica, opts...)
   250  		if err == nil {
   251  			pubKeyIf = true
   252  		}
   253  		dr, er := getDriver(replica.DriverType, getDns(&replica, tlsIf, pubKeyIf, opts...))
   254  		if er != nil {
   255  			return nil, er
   256  		}
   257  		replicas = append(replicas, dr)
   258  	}
   259  	resolver := dbresolver.Register(dbresolver.Config{
   260  		Sources:  sources,
   261  		Replicas: replicas,
   262  		// sources/replicas load balancing policy
   263  		Policy: dbresolver.RandomPolicy{},
   264  	})
   265  	if cfg.MaxOpenConns > 0 {
   266  		resolver.SetMaxOpenConns(cfg.MaxOpenConns)
   267  	}
   268  	if cfg.MaxIdleConns > 0 {
   269  		resolver.SetMaxIdleConns(cfg.MaxIdleConns)
   270  	}
   271  	if cfg.MaxLifetime > 0 {
   272  		resolver.SetConnMaxLifetime(cfg.MaxLifetime)
   273  	}
   274  	if cfg.MaxIdleTime > 0 {
   275  		resolver.SetConnMaxIdleTime(time.Minute * cfg.MaxIdleTime)
   276  	}
   277  	if err = db.Use(resolver); err != nil {
   278  		return nil, err
   279  	}
   280  	if cfg.Logger.LogLevel > 1 {
   281  		if cfg.Logger.SlowThreshold <= 0 {
   282  			cfg.Logger.SlowThreshold = 200
   283  		}
   284  		cfg.Logger.SlowThreshold = cfg.Logger.SlowThreshold * time.Millisecond
   285  		log := &Logger{
   286  			SqlLogger: cfg.Logger,
   287  		}
   288  		db.Logger = log
   289  		if commonconfig.RegisterConfig != nil {
   290  			commonconfig.RegisterConfig("DB|"+cfg.Name, log.handlerConfig)
   291  		}
   292  	}
   293  	a.db = db
   294  
   295  	// Call the destructor when the object is released.
   296  	runtime.SetFinalizer(a, finalizer)
   297  
   298  	return a, nil
   299  }
   300  
   301  func (a *Adapter) Db() *gorm.DB {
   302  	return a.db
   303  }
   304  
   305  func (a *Adapter) Name() string {
   306  	return a.databaseName
   307  }
   308  
   309  func (a *Adapter) Close() {
   310  	a.db = nil
   311  }
   312  
   313  func AddDriver(driverName string, dial DialFunc) {
   314  	openFuncs[driverName] = dial
   315  }
   316  
   317  func getDriver(driverName, dataSourceName string) (gorm.Dialector, error) {
   318  	driver, ok := openFuncs[driverName]
   319  	if !ok {
   320  		return nil, errors.New("database dialect is not supported")
   321  	}
   322  	return driver(dataSourceName), nil
   323  }
   324  
   325  func openDB(driverName, dataSourceName string) (*gorm.DB, error) {
   326  	dr, err := getDriver(driverName, dataSourceName)
   327  	if err != nil {
   328  		return nil, err
   329  	}
   330  	return gorm.Open(dr, &gorm.Config{})
   331  }