github.com/wfusion/gofusion@v1.1.14/common/infra/drivers/orm/opengauss/opengauss.go (about) 1 package opengauss 2 3 import ( 4 "database/sql" 5 "fmt" 6 "regexp" 7 "strconv" 8 "strings" 9 10 "github.com/spf13/cast" 11 "gorm.io/gorm" 12 "gorm.io/gorm/callbacks" 13 "gorm.io/gorm/clause" 14 "gorm.io/gorm/logger" 15 "gorm.io/gorm/migrator" 16 "gorm.io/gorm/schema" 17 18 "github.com/wfusion/gofusion/common/utils" 19 20 pq "gitee.com/opengauss/openGauss-connector-go-pq" 21 ) 22 23 type Dialector struct { 24 *Config 25 } 26 27 type Config struct { 28 DriverName string 29 DSN string 30 WithoutReturning bool 31 Conn gorm.ConnPool 32 } 33 34 func Open(dsn string) gorm.Dialector { 35 return &Dialector{&Config{DSN: dsn}} 36 } 37 38 func New(config Config) gorm.Dialector { 39 return &Dialector{Config: &config} 40 } 41 42 func (dialector Dialector) Name() string { 43 return "postgres" 44 } 45 46 var timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )") 47 48 func (dialector Dialector) Initialize(db *gorm.DB) (err error) { 49 callbackConfig := &callbacks.Config{ 50 CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"}, // remove returning when on conflict 51 UpdateClauses: []string{"UPDATE", "SET", "WHERE"}, 52 DeleteClauses: []string{"DELETE", "FROM", "WHERE"}, 53 } 54 // register callbacks 55 if !dialector.WithoutReturning { 56 callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING") 57 callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING") 58 callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING") 59 } 60 callbacks.RegisterDefaultCallbacks(db, callbackConfig) 61 for k, v := range dialector.ClauseBuilders() { 62 db.ClauseBuilders[k] = v 63 } 64 65 if dialector.Conn != nil { 66 db.ConnPool = dialector.Conn 67 } else if dialector.DriverName != "" { 68 db.ConnPool, err = sql.Open(dialector.DriverName, dialector.Config.DSN) 69 } else { 70 var config *pq.Config 71 config, err = pq.ParseConfig(dialector.Config.DSN) 72 if err != nil { 73 return 74 } 75 result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN) 76 if len(result) > 2 { 77 config.RuntimeParams["timezone"] = result[2] 78 } 79 80 connConfig := utils.Must(pq.NewConnectorConfig(config)) 81 db.ConnPool = sql.OpenDB(connConfig) 82 } 83 return 84 } 85 86 func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { 87 return Migrator{migrator.Migrator{Config: migrator.Config{ 88 DB: db, 89 Dialector: dialector, 90 CreateIndexAfterCreateTable: true, 91 }}} 92 } 93 94 const ( 95 // ClauseOnConflict for clause.ClauseBuilder ON CONFLICT key 96 ClauseOnConflict = "ON CONFLICT" 97 // ClauseValues for clause.ClauseBuilder VALUES key 98 ClauseValues = "VALUES" 99 // ClauseReturning for clause.ClauseBuilder RETURNING key 100 ClauseReturning = "RETURNING" 101 102 hasConflictKey = "~~opengauss_on_onflict~~" 103 ) 104 105 func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { 106 clauseBuilders := map[string]clause.ClauseBuilder{ 107 ClauseOnConflict: func(c clause.Clause, builder clause.Builder) { 108 onConflict, ok := c.Expression.(clause.OnConflict) 109 if !ok { 110 c.Build(builder) 111 return 112 } 113 114 if stmt, ok := builder.(*gorm.Statement); ok { 115 stmt.Set(hasConflictKey, true) 116 } 117 118 builder.WriteString("ON DUPLICATE KEY UPDATE ") 119 if len(onConflict.DoUpdates) == 0 { 120 if s := builder.(*gorm.Statement).Schema; s != nil { 121 var column clause.Column 122 onConflict.DoNothing = false 123 124 if s.PrioritizedPrimaryField != nil { 125 column = clause.Column{Name: s.PrioritizedPrimaryField.DBName} 126 } else if len(s.DBNames) > 0 { 127 column = clause.Column{Name: s.DBNames[0]} 128 } 129 130 if column.Name != "" { 131 onConflict.DoUpdates = []clause.Assignment{{Column: column, Value: column}} 132 } 133 134 builder.(*gorm.Statement).AddClause(onConflict) 135 } 136 } 137 138 for idx, assignment := range onConflict.DoUpdates { 139 if idx > 0 { 140 builder.WriteByte(',') 141 } 142 143 builder.WriteQuoted(assignment.Column) 144 builder.WriteByte('=') 145 if column, ok := assignment.Value.(clause.Column); ok && column.Table == "excluded" { 146 column.Table = "" 147 builder.WriteString("VALUES(") 148 builder.WriteQuoted(column) 149 builder.WriteByte(')') 150 } else { 151 builder.AddVar(builder, assignment.Value) 152 } 153 } 154 }, 155 ClauseValues: func(c clause.Clause, builder clause.Builder) { 156 if values, ok := c.Expression.(clause.Values); ok && len(values.Columns) == 0 { 157 builder.WriteString("VALUES()") 158 return 159 } 160 c.Build(builder) 161 }, 162 // opengauss 3.0.0 not support returning with on DUPLICATE KEY UPDATE 163 // and not support on DUPLICATE KEY UPDATE on primary key or unique key 164 ClauseReturning: func(c clause.Clause, builder clause.Builder) { 165 if _, ok := c.Expression.(clause.Returning); !ok { 166 c.Build(builder) 167 return 168 } 169 if stmt, ok := builder.(*gorm.Statement); ok { 170 if has, ok := stmt.Get(hasConflictKey); ok && cast.ToBool(has) { 171 return 172 } 173 } 174 c.Build(builder) 175 }, 176 } 177 178 return clauseBuilders 179 } 180 181 func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression { 182 return clause.Expr{SQL: "DEFAULT"} 183 } 184 185 func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { 186 writer.WriteByte('$') 187 writer.WriteString(strconv.Itoa(len(stmt.Vars))) 188 } 189 190 func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { 191 var ( 192 underQuoted, selfQuoted bool 193 continuousBacktick int8 194 shiftDelimiter int8 195 ) 196 197 for _, v := range []byte(str) { 198 switch v { 199 case '"': 200 continuousBacktick++ 201 if continuousBacktick == 2 { 202 writer.WriteString(`""`) 203 continuousBacktick = 0 204 } 205 case '.': 206 if continuousBacktick > 0 || !selfQuoted { 207 shiftDelimiter = 0 208 underQuoted = false 209 continuousBacktick = 0 210 writer.WriteByte('"') 211 } 212 writer.WriteByte(v) 213 continue 214 default: 215 if shiftDelimiter-continuousBacktick <= 0 && !underQuoted { 216 writer.WriteByte('"') 217 underQuoted = true 218 if selfQuoted = continuousBacktick > 0; selfQuoted { 219 continuousBacktick -= 1 220 } 221 } 222 223 for ; continuousBacktick > 0; continuousBacktick -= 1 { 224 writer.WriteString(`""`) 225 } 226 227 writer.WriteByte(v) 228 } 229 shiftDelimiter++ 230 } 231 232 if continuousBacktick > 0 && !selfQuoted { 233 writer.WriteString(`""`) 234 } 235 writer.WriteByte('"') 236 } 237 238 var numericPlaceholder = regexp.MustCompile(`\$(\d+)`) 239 240 func (dialector Dialector) Explain(sql string, vars ...interface{}) string { 241 return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...) 242 } 243 244 func (dialector Dialector) DataTypeOf(field *schema.Field) string { 245 switch field.DataType { 246 case schema.Bool: 247 return "boolean" 248 case schema.Int, schema.Uint: 249 size := field.Size 250 if field.DataType == schema.Uint { 251 size++ 252 } 253 if field.AutoIncrement { 254 switch { 255 case size <= 16: 256 return "smallserial" 257 case size <= 32: 258 return "serial" 259 default: 260 return "bigserial" 261 } 262 } else { 263 switch { 264 case size <= 16: 265 return "smallint" 266 case size <= 32: 267 return "integer" 268 default: 269 return "bigint" 270 } 271 } 272 case schema.Float: 273 if field.Precision > 0 { 274 if field.Scale > 0 { 275 return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale) 276 } 277 return fmt.Sprintf("numeric(%d)", field.Precision) 278 } 279 return "decimal" 280 case schema.String: 281 if field.Size > 0 { 282 return fmt.Sprintf("varchar(%d)", field.Size) 283 } 284 return "text" 285 case schema.Time: 286 if field.Precision > 0 { 287 return fmt.Sprintf("timestamptz(%d)", field.Precision) 288 } 289 return "timestamptz" 290 case schema.Bytes: 291 return "bytea" 292 default: 293 return dialector.getSchemaCustomType(field) 294 } 295 } 296 297 func (dialector Dialector) getSchemaCustomType(field *schema.Field) string { 298 sqlType := string(field.DataType) 299 300 if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), "serial") { 301 size := field.Size 302 if field.GORMDataType == schema.Uint { 303 size++ 304 } 305 switch { 306 case size <= 16: 307 sqlType = "smallserial" 308 case size <= 32: 309 sqlType = "serial" 310 default: 311 sqlType = "bigserial" 312 } 313 } 314 315 return sqlType 316 } 317 318 func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error { 319 tx.Exec("SAVEPOINT " + name) 320 return nil 321 } 322 323 func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error { 324 tx.Exec("ROLLBACK TO SAVEPOINT " + name) 325 return nil 326 } 327 328 func getSerialDatabaseType(s string) (dbType string, ok bool) { 329 switch s { 330 case "smallserial": 331 return "smallint", true 332 case "serial": 333 return "integer", true 334 case "bigserial": 335 return "bigint", true 336 default: 337 return "", false 338 } 339 }