github.com/wfusion/gofusion@v1.1.14/common/infra/drivers/orm/opengauss/opengauss.go (about)

     1  package opengauss
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"regexp"
     7  	"strconv"
     8  	"strings"
     9  
    10  	"github.com/spf13/cast"
    11  	"gorm.io/gorm"
    12  	"gorm.io/gorm/callbacks"
    13  	"gorm.io/gorm/clause"
    14  	"gorm.io/gorm/logger"
    15  	"gorm.io/gorm/migrator"
    16  	"gorm.io/gorm/schema"
    17  
    18  	"github.com/wfusion/gofusion/common/utils"
    19  
    20  	pq "gitee.com/opengauss/openGauss-connector-go-pq"
    21  )
    22  
// Dialector is the gorm.Dialector implementation for openGauss.
// All of its options live on the embedded *Config.
type Dialector struct {
	*Config
}
    26  
// Config holds the options consumed by Dialector.Initialize.
type Config struct {
	// DriverName, when non-empty and Conn is nil, is passed to sql.Open
	// together with DSN instead of using the built-in openGauss connector.
	DriverName       string
	// DSN is the data source name. When the built-in connector is used, a
	// time_zone= or TimeZone= parameter found in it is forwarded as the
	// "timezone" runtime parameter.
	DSN              string
	// WithoutReturning disables the RETURNING clause in the default create,
	// update and delete callbacks.
	WithoutReturning bool
	// Conn, when non-nil, is used directly as the connection pool and takes
	// precedence over DriverName and DSN.
	Conn             gorm.ConnPool
}
    33  
    34  func Open(dsn string) gorm.Dialector {
    35  	return &Dialector{&Config{DSN: dsn}}
    36  }
    37  
    38  func New(config Config) gorm.Dialector {
    39  	return &Dialector{Config: &config}
    40  }
    41  
    42  func (dialector Dialector) Name() string {
    43  	return "postgres"
    44  }
    45  
    46  var timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )")
    47  
    48  func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
    49  	callbackConfig := &callbacks.Config{
    50  		CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"}, // remove returning when on conflict
    51  		UpdateClauses: []string{"UPDATE", "SET", "WHERE"},
    52  		DeleteClauses: []string{"DELETE", "FROM", "WHERE"},
    53  	}
    54  	// register callbacks
    55  	if !dialector.WithoutReturning {
    56  		callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
    57  		callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
    58  		callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
    59  	}
    60  	callbacks.RegisterDefaultCallbacks(db, callbackConfig)
    61  	for k, v := range dialector.ClauseBuilders() {
    62  		db.ClauseBuilders[k] = v
    63  	}
    64  
    65  	if dialector.Conn != nil {
    66  		db.ConnPool = dialector.Conn
    67  	} else if dialector.DriverName != "" {
    68  		db.ConnPool, err = sql.Open(dialector.DriverName, dialector.Config.DSN)
    69  	} else {
    70  		var config *pq.Config
    71  		config, err = pq.ParseConfig(dialector.Config.DSN)
    72  		if err != nil {
    73  			return
    74  		}
    75  		result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN)
    76  		if len(result) > 2 {
    77  			config.RuntimeParams["timezone"] = result[2]
    78  		}
    79  
    80  		connConfig := utils.Must(pq.NewConnectorConfig(config))
    81  		db.ConnPool = sql.OpenDB(connConfig)
    82  	}
    83  	return
    84  }
    85  
    86  func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
    87  	return Migrator{migrator.Migrator{Config: migrator.Config{
    88  		DB:                          db,
    89  		Dialector:                   dialector,
    90  		CreateIndexAfterCreateTable: true,
    91  	}}}
    92  }
    93  
    94  const (
    95  	// ClauseOnConflict for clause.ClauseBuilder ON CONFLICT key
    96  	ClauseOnConflict = "ON CONFLICT"
    97  	// ClauseValues for clause.ClauseBuilder VALUES key
    98  	ClauseValues = "VALUES"
    99  	// ClauseReturning for clause.ClauseBuilder RETURNING key
   100  	ClauseReturning = "RETURNING"
   101  
   102  	hasConflictKey = "~~opengauss_on_onflict~~"
   103  )
   104  
   105  func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
   106  	clauseBuilders := map[string]clause.ClauseBuilder{
   107  		ClauseOnConflict: func(c clause.Clause, builder clause.Builder) {
   108  			onConflict, ok := c.Expression.(clause.OnConflict)
   109  			if !ok {
   110  				c.Build(builder)
   111  				return
   112  			}
   113  
   114  			if stmt, ok := builder.(*gorm.Statement); ok {
   115  				stmt.Set(hasConflictKey, true)
   116  			}
   117  
   118  			builder.WriteString("ON DUPLICATE KEY UPDATE ")
   119  			if len(onConflict.DoUpdates) == 0 {
   120  				if s := builder.(*gorm.Statement).Schema; s != nil {
   121  					var column clause.Column
   122  					onConflict.DoNothing = false
   123  
   124  					if s.PrioritizedPrimaryField != nil {
   125  						column = clause.Column{Name: s.PrioritizedPrimaryField.DBName}
   126  					} else if len(s.DBNames) > 0 {
   127  						column = clause.Column{Name: s.DBNames[0]}
   128  					}
   129  
   130  					if column.Name != "" {
   131  						onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}}
   132  					}
   133  
   134  					builder.(*gorm.Statement).AddClause(onConflict)
   135  				}
   136  			}
   137  
   138  			for idx, assignment := range onConflict.DoUpdates {
   139  				if idx > 0 {
   140  					builder.WriteByte(',')
   141  				}
   142  
   143  				builder.WriteQuoted(assignment.Column)
   144  				builder.WriteByte('=')
   145  				if column, ok := assignment.Value.(clause.Column); ok && column.Table == "excluded" {
   146  					column.Table = ""
   147  					builder.WriteString("VALUES(")
   148  					builder.WriteQuoted(column)
   149  					builder.WriteByte(')')
   150  				} else {
   151  					builder.AddVar(builder, assignment.Value)
   152  				}
   153  			}
   154  		},
   155  		ClauseValues: func(c clause.Clause, builder clause.Builder) {
   156  			if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 {
   157  				builder.WriteString("VALUES()")
   158  				return
   159  			}
   160  			c.Build(builder)
   161  		},
   162  		// opengauss 3.0.0 not support returning with on DUPLICATE KEY UPDATE
   163  		// and not support on DUPLICATE KEY UPDATE on primary key or unique key
   164  		ClauseReturning: func(c clause.Clause, builder clause.Builder) {
   165  			if _, ok := c.Expression.(clause.Returning); !ok {
   166  				c.Build(builder)
   167  				return
   168  			}
   169  			if stmt, ok := builder.(*gorm.Statement); ok {
   170  				if has, ok := stmt.Get(hasConflictKey); ok && cast.ToBool(has) {
   171  					return
   172  				}
   173  			}
   174  			c.Build(builder)
   175  		},
   176  	}
   177  
   178  	return clauseBuilders
   179  }
   180  
   181  func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
   182  	return clause.Expr{SQL: "DEFAULT"}
   183  }
   184  
   185  func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
   186  	writer.WriteByte('$')
   187  	writer.WriteString(strconv.Itoa(len(stmt.Vars)))
   188  }
   189  
// QuoteTo writes str to writer as a double-quoted identifier.
//
// str is split on '.', and each part is quoted separately, so tbl.col is
// written as "tbl"."col". Double quotes already present in str are handled
// as follows:
//   - A part that starts with '"' is treated as already quoted by the
//     caller. A '.' inside such a part is written literally instead of
//     starting a new part.
//   - A lone '"' inside a part is doubled to `""`.
//   - A pair of consecutive '"' is written once as `""`.
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
	var (
		// underQuoted: an opening '"' has been written for the current part.
		// selfQuoted: the current part began with a '"' taken from str.
		underQuoted, selfQuoted bool
		// continuousBacktick counts consecutive '"' bytes from str that have
		// not been written yet. The name is inherited; here it counts
		// double quotes, not backticks.
		continuousBacktick      int8
		// shiftDelimiter counts bytes consumed since the last separating '.'.
		shiftDelimiter          int8
	)

	for _, v := range []byte(str) {
		switch v {
		case '"':
			continuousBacktick++
			if continuousBacktick == 2 {
				// Two quotes in a row are emitted as one escaped quote.
				writer.WriteString(`""`)
				continuousBacktick = 0
			}
		case '.':
			// Close the current part unless this dot sits inside a part the
			// caller quoted themselves (with no pending quote before it).
			if continuousBacktick > 0 || !selfQuoted {
				shiftDelimiter = 0
				underQuoted = false
				continuousBacktick = 0
				writer.WriteByte('"')
			}
			writer.WriteByte(v)
			continue
		default:
			// First non-quote byte of a part: open the quote. If the part
			// began with '"', one pending quote serves as the opening quote.
			if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
				writer.WriteByte('"')
				underQuoted = true
				if selfQuoted = continuousBacktick > 0; selfQuoted {
					continuousBacktick -= 1
				}
			}

			// Each remaining pending quote is escaped by doubling.
			for ; continuousBacktick > 0; continuousBacktick -= 1 {
				writer.WriteString(`""`)
			}

			writer.WriteByte(v)
		}
		shiftDelimiter++
	}

	// A pending trailing quote in a part not quoted by the caller is escaped.
	if continuousBacktick > 0 && !selfQuoted {
		writer.WriteString(`""`)
	}
	writer.WriteByte('"')
}
   237  
   238  var numericPlaceholder = regexp.MustCompile(`\$(\d+)`)
   239  
   240  func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
   241  	return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
   242  }
   243  
   244  func (dialector Dialector) DataTypeOf(field *schema.Field) string {
   245  	switch field.DataType {
   246  	case schema.Bool:
   247  		return "boolean"
   248  	case schema.Int, schema.Uint:
   249  		size := field.Size
   250  		if field.DataType == schema.Uint {
   251  			size++
   252  		}
   253  		if field.AutoIncrement {
   254  			switch {
   255  			case size <= 16:
   256  				return "smallserial"
   257  			case size <= 32:
   258  				return "serial"
   259  			default:
   260  				return "bigserial"
   261  			}
   262  		} else {
   263  			switch {
   264  			case size <= 16:
   265  				return "smallint"
   266  			case size <= 32:
   267  				return "integer"
   268  			default:
   269  				return "bigint"
   270  			}
   271  		}
   272  	case schema.Float:
   273  		if field.Precision > 0 {
   274  			if field.Scale > 0 {
   275  				return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale)
   276  			}
   277  			return fmt.Sprintf("numeric(%d)", field.Precision)
   278  		}
   279  		return "decimal"
   280  	case schema.String:
   281  		if field.Size > 0 {
   282  			return fmt.Sprintf("varchar(%d)", field.Size)
   283  		}
   284  		return "text"
   285  	case schema.Time:
   286  		if field.Precision > 0 {
   287  			return fmt.Sprintf("timestamptz(%d)", field.Precision)
   288  		}
   289  		return "timestamptz"
   290  	case schema.Bytes:
   291  		return "bytea"
   292  	default:
   293  		return dialector.getSchemaCustomType(field)
   294  	}
   295  }
   296  
   297  func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
   298  	sqlType := string(field.DataType)
   299  
   300  	if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), "serial") {
   301  		size := field.Size
   302  		if field.GORMDataType == schema.Uint {
   303  			size++
   304  		}
   305  		switch {
   306  		case size <= 16:
   307  			sqlType = "smallserial"
   308  		case size <= 32:
   309  			sqlType = "serial"
   310  		default:
   311  			sqlType = "bigserial"
   312  		}
   313  	}
   314  
   315  	return sqlType
   316  }
   317  
   318  func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
   319  	tx.Exec("SAVEPOINT " + name)
   320  	return nil
   321  }
   322  
   323  func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
   324  	tx.Exec("ROLLBACK TO SAVEPOINT " + name)
   325  	return nil
   326  }
   327  
   328  func getSerialDatabaseType(s string) (dbType string, ok bool) {
   329  	switch s {
   330  	case "smallserial":
   331  		return "smallint", true
   332  	case "serial":
   333  		return "integer", true
   334  	case "bigserial":
   335  		return "bigint", true
   336  	default:
   337  		return "", false
   338  	}
   339  }