github.com/wanlay/gorm-dm8@v1.0.5/dm.go (about) 1 package dm 2 3 import ( 4 "database/sql" 5 "fmt" 6 "regexp" 7 "strconv" 8 "strings" 9 10 "github.com/thoas/go-funk" 11 "github.com/wanlay/gorm-dm8/clauses" 12 _ "github.com/wanlay/gorm-dm8/dmr" 13 "gorm.io/gorm" 14 "gorm.io/gorm/callbacks" 15 "gorm.io/gorm/clause" 16 "gorm.io/gorm/logger" 17 "gorm.io/gorm/migrator" 18 "gorm.io/gorm/schema" 19 "gorm.io/gorm/utils" 20 ) 21 22 type Config struct { 23 DriverName string 24 DSN string 25 Conn gorm.ConnPool //*sql.DB 26 DefaultStringSize uint 27 } 28 29 type Dialector struct { 30 *Config 31 } 32 33 func Open(dsn string) gorm.Dialector { 34 return &Dialector{Config: &Config{DSN: dsn}} 35 } 36 37 func New(config Config) gorm.Dialector { 38 return &Dialector{Config: &config} 39 } 40 41 func (d Dialector) DummyTableName() string { 42 return "DUAL" 43 } 44 45 func (d Dialector) Name() string { 46 return "dm" 47 } 48 49 func (d Dialector) Initialize(db *gorm.DB) (err error) { 50 db.NamingStrategy = Namer{} 51 d.DefaultStringSize = 4096 52 53 // register callbacks 54 callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ 55 CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, 56 UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, 57 DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, 58 }) 59 60 d.DriverName = "dm" 61 62 if d.Conn != nil { 63 db.ConnPool = d.Conn 64 } else { 65 db.ConnPool, err = sql.Open(d.DriverName, d.DSN) 66 } 67 68 if err = db.Callback().Create().Replace("gorm:create", Create); err != nil { 69 return 70 } 71 72 for k, v := range d.ClauseBuilders() { 73 db.ClauseBuilders[k] = v 74 } 75 return 76 } 77 78 func (d Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { 79 return map[string]clause.ClauseBuilder{ 80 "LIMIT": d.RewriteLimit, 81 "WHERE": d.RewriteWhere, 82 } 83 } 84 85 func (d Dialector) RewriteWhere(c clause.Clause, builder clause.Builder) { 86 if where, ok := c.Expression.(clause.Where); ok { 87 builder.WriteString(" WHERE ") 88 89 // Switch position if the first query expression is a single Or condition 90 for idx, expr := range where.Exprs { 91 if v, ok := expr.(clause.OrConditions); !ok || len(v.Exprs) > 1 { 92 if idx != 0 { 93 where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0] 94 } 95 break 96 } 97 } 98 99 wrapInParentheses := false 100 for idx, expr := range where.Exprs { 101 if idx > 0 { 102 if v, ok := expr.(clause.OrConditions); ok && len(v.Exprs) == 1 { 103 builder.WriteString(" OR ") 104 } else { 105 builder.WriteString(" AND ") 106 } 107 } 108 109 if len(where.Exprs) > 1 { 110 switch v := expr.(type) { 111 case clause.OrConditions: 112 if len(v.Exprs) == 1 { 113 if e, ok := v.Exprs[0].(clause.Expr); ok { 114 sql := strings.ToLower(e.SQL) 115 wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") 116 } 117 } 118 case clause.AndConditions: 119 if len(v.Exprs) == 1 { 120 if e, ok := v.Exprs[0].(clause.Expr); ok { 121 sql := strings.ToLower(e.SQL) 122 wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") 123 } 124 } 125 case clause.Expr: 126 sql := strings.ToLower(v.SQL) 127 wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") 128 } 129 } 130 131 if wrapInParentheses { 132 builder.WriteString(`(`) 133 expr.Build(builder) 134 builder.WriteString(`)`) 135 wrapInParentheses = false 136 } else { 137 if e, ok := expr.(clause.IN); ok { 138 if values, ok := e.Values[0].([]interface{}); ok { 139 if len(values) > 1 { 140 newExpr := clauses.IN{ 141 Column: expr.(clause.IN).Column, 142 Values: expr.(clause.IN).Values, 143 } 144 newExpr.Build(builder) 145 continue 146 } 147 } 148 } 149 150 expr.Build(builder) 151 } 152 } 153 } 154 } 155 156 func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) { 157 if limit, ok := c.Expression.(clause.Limit); ok { 158 if stmt, ok := builder.(*gorm.Statement); ok { 159 if _, ok := stmt.Clauses["ORDER BY"]; !ok { 160 s := stmt.Schema 161 builder.WriteString("ORDER BY ") 162 if s != nil && s.PrioritizedPrimaryField != nil { 163 builder.WriteQuoted(s.PrioritizedPrimaryField.DBName) 164 builder.WriteByte(' ') 165 } else { 166 builder.WriteString("(SELECT NULL FROM ") 167 builder.WriteString(d.DummyTableName()) 168 builder.WriteString(")") 169 } 170 } 171 } 172 173 if offset := limit.Offset; offset > 0 { 174 builder.WriteString(" OFFSET ") 175 builder.WriteString(strconv.Itoa(offset)) 176 builder.WriteString(" ROWS") 177 } 178 if limit := limit.Limit; limit > 0 { 179 builder.WriteString(" FETCH NEXT ") 180 builder.WriteString(strconv.Itoa(limit)) 181 builder.WriteString(" ROWS ONLY") 182 } 183 } 184 } 185 186 func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression { 187 return clause.Expr{SQL: "VALUES (DEFAULT)"} 188 } 189 190 func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator { 191 return Migrator{ 192 Migrator: migrator.Migrator{ 193 Config: migrator.Config{ 194 DB: db, 195 Dialector: d, 196 CreateIndexAfterCreateTable: true, 197 }, 198 }, 199 Dialector: d, 200 } 201 } 202 203 func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { 204 _, _ = writer.WriteString(":") 205 _, _ = writer.WriteString(strconv.Itoa(len(stmt.Vars))) 206 } 207 208 func (d Dialector) QuoteTo(writer clause.Writer, str string) { 209 _, _ = writer.WriteString(str) 210 } 211 212 var numericPlaceholder = regexp.MustCompile(`:(\d+)`) 213 214 func (d Dialector) Explain(sql string, vars ...interface{}) string { 215 return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v interface{}) interface{} { 216 switch v := v.(type) { 217 case bool: 218 if v { 219 return 1 220 } 221 return 0 222 default: 223 return v 224 } 225 }).([]interface{})...) 226 } 227 228 func (d Dialector) DataTypeOf(field *schema.Field) string { 229 if _, found := field.TagSettings["RESTRICT"]; found { 230 delete(field.TagSettings, "RESTRICT") 231 } 232 233 var sqlType string 234 235 switch field.DataType { 236 case schema.Bool, schema.Int, schema.Uint, schema.Float: 237 sqlType = "INTEGER" 238 239 switch { 240 case field.DataType == schema.Float: 241 sqlType = "FLOAT" 242 case field.Size <= 8: 243 sqlType = "SMALLINT" 244 } 245 246 if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { 247 sqlType += " GENERATED BY DEFAULT AS IDENTITY" 248 } 249 case schema.String, "VARCHAR2": 250 size := field.Size 251 defaultSize := d.DefaultStringSize 252 253 if size == 0 { 254 if defaultSize > 0 { 255 size = int(defaultSize) 256 } else { 257 hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != "" 258 // TEXT, GEOMETRY or JSON column can't have a default value 259 if field.PrimaryKey || field.HasDefaultValue || hasIndex { 260 size = 191 // utf8mb4 261 } 262 } 263 } 264 265 if size > 4096 { 266 sqlType = "CLOB" 267 } else { 268 sqlType = fmt.Sprintf("VARCHAR2(%d)", size) 269 } 270 271 case schema.Time: 272 sqlType = "TIMESTAMP WITH TIME ZONE" 273 if field.NotNull || field.PrimaryKey { 274 sqlType += " NOT NULL" 275 } 276 case schema.Bytes: 277 sqlType = "BLOB" 278 default: 279 sqlType = string(field.DataType) 280 281 if strings.EqualFold(sqlType, "text") { 282 sqlType = "CLOB" 283 } 284 285 if sqlType == "" { 286 panic(fmt.Sprintf("invalid sql type %s (%s) for dm", field.FieldType.Name(), field.FieldType.String())) 287 } 288 289 notNull, _ := field.TagSettings["NOT NULL"] 290 unique, _ := field.TagSettings["UNIQUE"] 291 additionalType := fmt.Sprintf("%s %s", notNull, unique) 292 if value, ok := field.TagSettings["DEFAULT"]; ok { 293 additionalType = fmt.Sprintf("%s %s %s%s", "DEFAULT", value, additionalType, func() string { 294 if value, ok := field.TagSettings["COMMENT"]; ok { 295 return " COMMENT " + value 296 } 297 return "" 298 }()) 299 } 300 sqlType = fmt.Sprintf("%v %v", sqlType, additionalType) 301 } 302 303 return sqlType 304 } 305 306 func (d Dialector) SavePoint(tx *gorm.DB, name string) error { 307 tx.Exec("SAVEPOINT " + name) 308 return tx.Error 309 } 310 311 func (d Dialector) RollbackTo(tx *gorm.DB, name string) error { 312 tx.Exec("ROLLBACK TO SAVEPOINT " + name) 313 return tx.Error 314 }