github.com/wfusion/gofusion@v1.1.14/common/infra/drivers/orm/gorm.go (about)

     1  package orm
     2  
import (
	"context"
	"fmt"
	"net/url"
	"path"
	"strings"
	"time"

	"github.com/pkg/errors"
	"gorm.io/driver/clickhouse"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlserver"
	"gorm.io/gorm"

	"github.com/wfusion/gofusion/common/env"
	"github.com/wfusion/gofusion/common/infra/drivers/orm/opengauss"
	"github.com/wfusion/gofusion/common/infra/drivers/orm/sqlite"
	"github.com/wfusion/gofusion/common/utils"
)
    21  
    22  var Gorm Dialect = new(gormDriver)
    23  
    24  type gormDriver struct{}
    25  
// gormDriverOption is the driver-facing view of Option. It is produced by
// parseDBOption and consumed by open and the gen*Dsn helpers.
//
// Field notes:
//   - Timeout, ReadTimeout and WriteTimeout hold Go duration strings. They are
//     inserted into the MySQL DSN as-is. They are converted to whole seconds
//     via time.ParseDuration for SQL Server (Timeout) and ClickHouse
//     (ReadTimeout/WriteTimeout). The Postgres and SQLite paths do not use them.
//   - DBPort is the textual form of Option.Port.
//   - DBCharset and Scheme are only read by the MySQL DSN builder.
//     parseDBOption sets them to "utf8mb4,utf8" and "tcp".
//   - DBName doubles as the SQLite database location for DriverSqlite.
//   - New applies MaxIdleConns/MaxOpenConns to the connection pool only when
//     they are positive.
//   - NOTE(review): the yaml tags suggest this struct may also be decoded from
//     configuration directly — TODO confirm whether anything relies on that.
type gormDriverOption struct {
	Driver       driver  `yaml:"driver"`
	Dialect      dialect `yaml:"dialect"`
	Timeout      string  `yaml:"timeout"`
	ReadTimeout  string  `yaml:"read_timeout"`
	WriteTimeout string  `yaml:"write_timeout"`
	User         string  `yaml:"user"`
	Password     string  `yaml:"password"`
	DBName       string  `yaml:"db_name"`
	DBCharset    string  `yaml:"db_charset"`
	DBHostname   string  `yaml:"db_hostname"`
	DBPort       string  `yaml:"db_port"`
	MaxIdleConns int     `yaml:"max_idle_conns"`
	MaxOpenConns int     `yaml:"max_open_conns"`
	Scheme       string  `yaml:"scheme"`
}
    42  
    43  func (g *gormDriver) New(ctx context.Context, option Option, opts ...utils.OptionExtender) (db *DB, err error) {
    44  	opt := g.parseDBOption(option)
    45  	gormDB, dialector, err := g.open(opt.Driver, string(opt.Dialect), opt)
    46  	if err != nil {
    47  		return
    48  	}
    49  
    50  	sqlDB, err := gormDB.DB()
    51  	if err != nil {
    52  		return
    53  	}
    54  
    55  	// optional
    56  	if opt.MaxOpenConns > 0 {
    57  		sqlDB.SetMaxOpenConns(opt.MaxOpenConns)
    58  	}
    59  	if opt.MaxIdleConns > 0 {
    60  		sqlDB.SetMaxIdleConns(opt.MaxIdleConns)
    61  	}
    62  	if utils.IsStrNotBlank(option.ConnMaxLifeTime) {
    63  		if liftTime, err := time.ParseDuration(option.ConnMaxLifeTime); err == nil {
    64  			sqlDB.SetConnMaxLifetime(liftTime)
    65  		}
    66  	}
    67  	if utils.IsStrNotBlank(option.ConnMaxLifeTime) {
    68  		if idleTime, err := time.ParseDuration(option.ConnMaxIdleTime); err == nil {
    69  			sqlDB.SetConnMaxIdleTime(idleTime)
    70  		}
    71  	}
    72  
    73  	newOpt := utils.ApplyOptions[newOption](opts...)
    74  	if newOpt.logger != nil {
    75  		gormDB.Logger = newOpt.logger
    76  	}
    77  
    78  	return &DB{DB: gormDB.WithContext(ctx), dialector: dialector}, nil
    79  }
    80  
    81  func (g *gormDriver) open(driver driver, dialect string, opt *gormDriverOption) (
    82  	db *gorm.DB, dialector gorm.Dialector, err error) {
    83  	// alternative driver
    84  	switch driver {
    85  	case DriverMysql:
    86  		dialector = mysql.New(mysql.Config{
    87  			DriverName: dialect,
    88  			DSN:        g.genMySqlDsn(opt),
    89  		})
    90  
    91  	case DriverPostgres:
    92  		if dialect == string(DialectOpenGauss) {
    93  			dialector = opengauss.New(opengauss.Config{
    94  				DriverName: dialect,
    95  				DSN:        g.genPostgresDsn(opt),
    96  			})
    97  		} else {
    98  			dialector = postgres.New(postgres.Config{
    99  				DriverName: dialect,
   100  				DSN:        g.genPostgresDsn(opt),
   101  			})
   102  		}
   103  
   104  	// sqlite dsn is filepath
   105  	// or file::memory:?cache=shared is also available, see also https://www.sqlite.org/inmemorydb.html
   106  	case DriverSqlite:
   107  		dialector = sqlite.Open(path.Join(env.WorkDir, path.Clean(opt.DBName)))
   108  
   109  	case DriverSqlserver:
   110  		dialector = sqlserver.New(sqlserver.Config{
   111  			DriverName: dialect,
   112  			DSN:        g.genSqlServerDsn(opt),
   113  		})
   114  
   115  	// tidb is compatible with mysql protocol
   116  	case DriverTiDB:
   117  		dialector = mysql.New(mysql.Config{
   118  			DriverName: dialect,
   119  			DSN:        g.genMySqlDsn(opt),
   120  		})
   121  
   122  	case DriverClickhouse:
   123  		dialector = clickhouse.New(clickhouse.Config{
   124  			DriverName: dialect,
   125  			DSN:        g.genClickhouseDsn(opt),
   126  		})
   127  
   128  	default:
   129  		panic(errors.Errorf("unknown db driver or dialect: %s %s", driver, dialect))
   130  	}
   131  
   132  	db, err = gorm.Open(dialector, &gorm.Config{
   133  		DisableForeignKeyConstraintWhenMigrating: true,
   134  	})
   135  	return
   136  }
   137  
   138  func (g *gormDriver) parseDBOption(option Option) (parsed *gormDriverOption) {
   139  	parsed = &gormDriverOption{
   140  		Driver:       option.Driver,
   141  		Dialect:      option.Dialect,
   142  		Timeout:      option.Timeout,
   143  		ReadTimeout:  option.ReadTimeout,
   144  		WriteTimeout: option.WriteTimeout,
   145  		User:         option.User,
   146  		Password:     option.Password,
   147  		DBName:       option.DB,
   148  		DBCharset:    "utf8mb4,utf8",
   149  		DBHostname:   option.Host,
   150  		DBPort:       fmt.Sprintf("%v", option.Port),
   151  		Scheme:       "tcp",
   152  	}
   153  
   154  	if option.Driver != "" {
   155  		parsed.Driver = option.Driver
   156  	}
   157  	if option.MaxIdleConns > 0 {
   158  		parsed.MaxIdleConns = option.MaxIdleConns
   159  	}
   160  	if option.MaxOpenConns > 0 {
   161  		parsed.MaxOpenConns = option.MaxOpenConns
   162  	}
   163  
   164  	if utils.IsStrBlank(string(parsed.Dialect)) {
   165  		parsed.Dialect = defaultDriverDialectMapping[parsed.Driver]
   166  	}
   167  
   168  	return
   169  }
   170  
   171  func (g *gormDriver) genMySqlDsn(opt *gormDriverOption) (dsn string) {
   172  	if opt.DBCharset == "" {
   173  		opt.DBCharset = "utf8"
   174  	}
   175  	if opt.Scheme == "" {
   176  		opt.Scheme = "tcp"
   177  	}
   178  
   179  	const (
   180  		dsnFormat = "%s:%s@%s(%s:%s)/%s?charset=%s&parseTime=True&loc=Local&timeout=%s&readTimeout=%s&writeTimeout=%s"
   181  	)
   182  
   183  	return fmt.Sprintf(dsnFormat, opt.User, opt.Password, opt.Scheme, opt.DBHostname, opt.DBPort, opt.DBName,
   184  		opt.DBCharset, opt.Timeout, opt.ReadTimeout, opt.WriteTimeout)
   185  }
   186  
   187  func (g *gormDriver) genPostgresDsn(opt *gormDriverOption) (dsn string) {
   188  	const (
   189  		dsnFormat = "host=%s user=%s password=%s dbname=%s port=%s sslmode=disable TimeZone=Asia/Shanghai"
   190  	)
   191  
   192  	return fmt.Sprintf(dsnFormat, opt.DBHostname, opt.User, opt.Password, opt.DBName, opt.DBPort)
   193  }
   194  
   195  func (g *gormDriver) genSqlServerDsn(opt *gormDriverOption) (dsn string) {
   196  	const (
   197  		dsnFormat = "sqlserver://%s:%s@%s:%s?database=%s&connection+timeout=%s"
   198  	)
   199  
   200  	timeout := "5" // seconds
   201  	if utils.IsStrNotBlank(opt.Timeout) {
   202  		if duration, err := time.ParseDuration(opt.Timeout); err == nil {
   203  			timeout = fmt.Sprintf("%v", int(duration/time.Second))
   204  		}
   205  	}
   206  
   207  	return fmt.Sprintf(dsnFormat, opt.User, opt.Password, opt.DBHostname, opt.DBPort, opt.DBName, timeout)
   208  }
   209  
   210  func (g *gormDriver) genClickhouseDsn(opt *gormDriverOption) (dsn string) {
   211  	const (
   212  		dsnFormat = "tcp://%s:%s?database=%s&username=%s&password=%s&read_timeout=%s&write_timeout=%s"
   213  	)
   214  
   215  	readTimeout := "2" // seconds
   216  	if utils.IsStrNotBlank(opt.ReadTimeout) {
   217  		if duration, err := time.ParseDuration(opt.ReadTimeout); err == nil {
   218  			readTimeout = fmt.Sprintf("%v", int(duration/time.Second))
   219  		}
   220  	}
   221  	writeTimeout := "2" // seconds
   222  	if utils.IsStrNotBlank(opt.ReadTimeout) {
   223  		if duration, err := time.ParseDuration(opt.WriteTimeout); err == nil {
   224  			writeTimeout = fmt.Sprintf("%v", int(duration/time.Second))
   225  		}
   226  	}
   227  
   228  	return fmt.Sprintf(dsnFormat, opt.DBHostname, opt.DBPort, opt.DBName, opt.User, opt.Password,
   229  		readTimeout, writeTimeout)
   230  }