github.com/isyscore/isc-gobase@v1.5.3-0.20231218061332-cbc7451899e9/extend/orm/gorm.go (about)

     1  package orm
     2  
import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"time"

	driverMysql "github.com/go-sql-driver/mysql"
	"github.com/isyscore/isc-gobase/bean"
	"github.com/isyscore/isc-gobase/config"
	"github.com/isyscore/isc-gobase/constants"
	"github.com/isyscore/isc-gobase/listener"
	baseLogger "github.com/isyscore/isc-gobase/logger"
	"github.com/lib/pq"
	"github.com/mattn/go-sqlite3"
	"github.com/qustavo/sqlhooks/v2"
	"github.com/sirupsen/logrus"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlite"
	"gorm.io/driver/sqlserver"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)
    25  
    26  func NewGormDb() (*gorm.DB, error) {
    27  	return doNewGormDb("", &gorm.Config{})
    28  }
    29  
    30  func NewGormDbWitConfig(gormConfig *gorm.Config) (*gorm.DB, error) {
    31  	return doNewGormDb("", gormConfig)
    32  }
    33  
    34  func NewGormDbWithName(datasourceName string) (*gorm.DB, error) {
    35  	return doNewGormDb(datasourceName, &gorm.Config{})
    36  }
    37  
    38  func NewGormDbWithNameAndConfig(datasourceName string, gormConfig *gorm.Config) (*gorm.DB, error) {
    39  	return doNewGormDb(datasourceName, gormConfig)
    40  }
    41  
    42  func doNewGormDb(datasourceName string, gormConfig *gorm.Config) (*gorm.DB, error) {
    43  	datasourceConfig := config.DatasourceConfig{}
    44  	targetDatasourceName := "base.datasource"
    45  	if datasourceName != "" {
    46  		targetDatasourceName = "base.datasource." + datasourceName
    47  	}
    48  	err := config.GetValueObject(targetDatasourceName, &datasourceConfig)
    49  	if err != nil {
    50  		baseLogger.Warn("读取读取配置【datasource】异常")
    51  		return nil, err
    52  	}
    53  
    54  	// 注册原生的sql的hook
    55  	if len(gormHooks) != 0 {
    56  		sqlRegister(datasourceConfig.DriverName)
    57  	}
    58  
    59  	var gormDb *gorm.DB
    60  	dsn := getDbDsn(datasourceConfig.DriverName, datasourceConfig)
    61  	gormDb, err = gorm.Open(getDialect(dsn, datasourceConfig.DriverName), gormConfig)
    62  	if err != nil {
    63  		baseLogger.Warn("获取数据库db异常:%v", err.Error())
    64  		return nil, err
    65  	}
    66  
    67  	d, _ := gormDb.DB()
    68  
    69  	maxIdleConns := config.GetValueInt("base.datasource.connect-pool.max-idle-conns")
    70  	if maxIdleConns != 0 {
    71  		// 设置空闲的最大连接数
    72  		d.SetMaxIdleConns(maxIdleConns)
    73  	}
    74  
    75  	maxOpenConns := config.GetValueInt("base.datasource.connect-pool.max-open-conns")
    76  	if maxOpenConns != 0 {
    77  		// 设置数据库打开连接的最大数量
    78  		d.SetMaxOpenConns(maxOpenConns)
    79  	}
    80  
    81  	maxLifeTime := config.GetValueString("base.datasource.connect-pool.max-life-time")
    82  	if maxLifeTime != "" {
    83  		// 设置连接可重复使用的最大时间
    84  		t, err := time.ParseDuration(maxLifeTime)
    85  		if err != nil {
    86  			baseLogger.Warn("读取配置【base.datasource.connect-pool.max-life-time】异常", err)
    87  		} else {
    88  			d.SetConnMaxLifetime(t)
    89  		}
    90  	}
    91  
    92  	maxIdleTime := config.GetValueString("base.datasource.connect-pool.max-idle-time")
    93  	if maxIdleTime != "" {
    94  		// 设置conn最大空闲时间设置连接空闲的最大时间
    95  		t, err := time.ParseDuration(maxIdleTime)
    96  		if err != nil {
    97  			baseLogger.Warn("读取配置【base.datasource.connect-pool.max-idle-time】异常", err)
    98  		} else {
    99  			d.SetConnMaxIdleTime(t)
   100  		}
   101  	}
   102  
   103  	gormDb.Logger = &GormLoggerAdapter{}
   104  	bean.AddBean(constants.BeanNameGormPre+datasourceName, gormDb)
   105  	// 添加orm的配置监听器
   106  	listener.AddListener(listener.EventOfConfigChange, ConfigChangeListenerOfOrm)
   107  
   108  	return gormDb, nil
   109  }
   110  
   111  // 特殊字符处理
   112  func specialCharChange(url string) string {
   113  	return strings.ReplaceAll(url, "/", "%2F")
   114  }
   115  
   116  func getDialect(dsn, driverName string) gorm.Dialector {
   117  	switch driverName {
   118  	case "mysql":
   119  		return mysql.New(getMysqlConfig(dsn, driverName))
   120  	case "postgresql":
   121  		return postgres.New(postgres.Config{DSN: dsn, DriverName: WrapDriverName(driverName)})
   122  	case "sqlite":
   123  		return sqlite.Dialector{DSN: dsn, DriverName: WrapDriverName(driverName)}
   124  	case "sqlserver":
   125  		return sqlserver.New(sqlserver.Config{DSN: dsn, DriverName: WrapDriverName(driverName)})
   126  	}
   127  	return nil
   128  }
   129  
   130  func sqlRegister(driverName string) {
   131  	name := WrapDriverName(driverName)
   132  	for _, driver := range sql.Drivers() {
   133  		if driver == name {
   134  			return
   135  		}
   136  	}
   137  
   138  	switch driverName {
   139  	case "mysql":
   140  		sql.Register(name, sqlhooks.Wrap(&driverMysql.MySQLDriver{}, &GobaseSqlHookProxy{DriverName: driverName}))
   141  	case "postgresql":
   142  		sql.Register(name, sqlhooks.Wrap(&pq.Driver{}, &GobaseSqlHookProxy{DriverName: driverName}))
   143  	case "sqlite":
   144  		sql.Register(name, sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &GobaseSqlHookProxy{DriverName: driverName}))
   145  		//case "sqlserver": 暂时不支持
   146  		//	sql.Register(WrapDriverName(driverName), sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &GobaseSqlHookProxy{}))
   147  	}
   148  }
   149  
   150  func getMysqlConfig(dsn, driverName string) mysql.Config {
   151  	return mysql.Config{
   152  		DriverName: driverName,
   153  		DSN: dsn,
   154  		ServerVersion:                 config.GetValueStringDefault("base.datasource.mysql.server-version", ""),
   155  		SkipInitializeWithVersion:     config.GetValueBoolDefault("base.datasource.mysql.skip-initialize-with-version", false),
   156  		DefaultStringSize:             config.GetValueUIntDefault("base.datasource.mysql.default-string-size", 0),
   157  		DisableWithReturning:          config.GetValueBoolDefault("base.datasource.mysql.disable-with-returning", false),
   158  		DisableDatetimePrecision:      config.GetValueBoolDefault("base.datasource.mysql.disable-datetime-precision", false),
   159  		DontSupportRenameIndex:        config.GetValueBoolDefault("base.datasource.mysql.dont-support-rename-index", false),
   160  		DontSupportRenameColumn:       config.GetValueBoolDefault("base.datasource.mysql.dont-support-rename-column", false),
   161  		DontSupportForShareClause:     config.GetValueBoolDefault("base.datasource.mysql.dont-support-for-share-clause", false),
   162  		DontSupportNullAsDefaultValue: config.GetValueBoolDefault("base.datasource.mysql.dont-support-null-as-default-value", false),
   163  	}
   164  }
   165  
   166  func WrapDriverName(driverName string) string {
   167  	if len(gormHooks) != 0 {
   168  		return driverName + "Hook"
   169  	}
   170  	return driverName
   171  }
   172  
   173  type GobaseGormHook interface {
   174  	Before(ctx context.Context, driverName string, parameters map[string]any) (context.Context, error)
   175  	After(ctx context.Context, driverName string, parameters map[string]any) (context.Context, error)
   176  	Err(ctx context.Context, driverName string, err error, parameters map[string]any) error
   177  }
   178  
   179  var gormHooks []GobaseGormHook
   180  
   181  func init() {
   182  	gormHooks = []GobaseGormHook{}
   183  }
   184  
   185  func AddGormHook(hook GobaseGormHook) {
   186  	gormHooks = append(gormHooks, hook)
   187  }
   188  
   189  type GobaseSqlHookProxy struct {
   190  	DriverName string
   191  }
   192  
   193  func (proxy *GobaseSqlHookProxy) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
   194  	var ctxFinal context.Context
   195  	for _, hook := range gormHooks {
   196  		parametersMap := map[string]any{
   197  			"query": query,
   198  			"args":  args,
   199  		}
   200  		_ctx, err := hook.Before(ctx, proxy.DriverName, parametersMap)
   201  		if err != nil {
   202  			return _ctx, err
   203  		} else {
   204  			ctxFinal = _ctx
   205  		}
   206  	}
   207  	return ctxFinal, nil
   208  }
   209  
   210  func (proxy *GobaseSqlHookProxy) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
   211  	for _, hook := range gormHooks {
   212  		parametersMap := map[string]any{
   213  			"query": query,
   214  			"args":  args,
   215  		}
   216  		ctx, err := hook.After(ctx, proxy.DriverName, parametersMap)
   217  		if err != nil {
   218  			return ctx, err
   219  		}
   220  	}
   221  	return ctx, nil
   222  }
   223  
   224  func (proxy *GobaseSqlHookProxy) OnError(ctx context.Context, err error, query string, args ...interface{}) error {
   225  	for _, hook := range gormHooks {
   226  		parametersMap := map[string]any{
   227  			"query": query,
   228  			"args":  args,
   229  		}
   230  		err := hook.Err(ctx, proxy.DriverName, err, parametersMap)
   231  		if err != nil {
   232  			return err
   233  		}
   234  	}
   235  	return nil
   236  }
   237  
   238  type GormLoggerAdapter struct {
   239  }
   240  
   241  func (l *GormLoggerAdapter) LogMode(level logger.LogLevel) logger.Interface {
   242  	var levelStr logrus.Level
   243  	switch level {
   244  	case logger.Silent:
   245  		levelStr = logrus.TraceLevel
   246  	case logger.Error:
   247  		levelStr = logrus.ErrorLevel
   248  	case logger.Warn:
   249  		levelStr = logrus.WarnLevel
   250  	case logger.Info:
   251  		levelStr = logrus.InfoLevel
   252  	}
   253  	baseLogger.Group("orm").SetLevel(levelStr)
   254  	return l
   255  }
   256  
   257  func (l *GormLoggerAdapter) Info(ctx context.Context, msg string, data ...interface{}) {
   258  	baseLogger.Info(msg, data)
   259  }
   260  
   261  func (l *GormLoggerAdapter) Warn(ctx context.Context, msg string, data ...interface{}) {
   262  	baseLogger.Warn(msg, data)
   263  }
   264  
   265  func (l *GormLoggerAdapter) Error(ctx context.Context, msg string, data ...interface{}) {
   266  	baseLogger.Error(msg, data)
   267  }
   268  
   269  func (l *GormLoggerAdapter) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
   270  	elapsed := time.Since(begin)
   271  	sqlStr, rowsAffected := fc()
   272  	if err != nil {
   273  		baseLogger.Group("orm").Errorf("[SQL][%v]%s; error: %v", elapsed, sqlStr, err.Error())
   274  	} else {
   275  		baseLogger.Group("orm").Debugf("[SQL][%v][row:%v]%s", elapsed, rowsAffected, sqlStr)
   276  	}
   277  }