github.com/wanlay/gorm-dm8@v1.0.5/dm.go (about)

     1  package dm
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"regexp"
     7  	"strconv"
     8  	"strings"
     9  
    10  	"github.com/thoas/go-funk"
    11  	"github.com/wanlay/gorm-dm8/clauses"
    12  	_ "github.com/wanlay/gorm-dm8/dmr"
    13  	"gorm.io/gorm"
    14  	"gorm.io/gorm/callbacks"
    15  	"gorm.io/gorm/clause"
    16  	"gorm.io/gorm/logger"
    17  	"gorm.io/gorm/migrator"
    18  	"gorm.io/gorm/schema"
    19  	"gorm.io/gorm/utils"
    20  )
    21  
// Config holds the settings used to build a DM Dialector.
type Config struct {
	// DriverName is the database/sql driver name given to sql.Open;
	// Initialize fills in "dm".
	DriverName        string
	// DSN is the data source name passed to sql.Open when Conn is nil.
	DSN               string
	// Conn, when non-nil, is used as the connection pool instead of opening
	// a new one from DSN.
	Conn              gorm.ConnPool //*sql.DB
	// DefaultStringSize is the VARCHAR2 length used for string fields that
	// declare no size; Initialize applies a default of 4096.
	DefaultStringSize uint
}
    28  
// Dialector is the GORM dialector for DM databases (Open and New return it
// as a gorm.Dialector). It embeds *Config, so every copy of a Dialector value
// shares, and can mutate, the same Config.
type Dialector struct {
	*Config
}
    32  
    33  func Open(dsn string) gorm.Dialector {
    34  	return &Dialector{Config: &Config{DSN: dsn}}
    35  }
    36  
    37  func New(config Config) gorm.Dialector {
    38  	return &Dialector{Config: &config}
    39  }
    40  
    41  func (d Dialector) DummyTableName() string {
    42  	return "DUAL"
    43  }
    44  
    45  func (d Dialector) Name() string {
    46  	return "dm"
    47  }
    48  
    49  func (d Dialector) Initialize(db *gorm.DB) (err error) {
    50  	db.NamingStrategy = Namer{}
    51  	d.DefaultStringSize = 4096
    52  
    53  	// register callbacks
    54  	callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
    55  		CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
    56  		UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
    57  		DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
    58  	})
    59  
    60  	d.DriverName = "dm"
    61  
    62  	if d.Conn != nil {
    63  		db.ConnPool = d.Conn
    64  	} else {
    65  		db.ConnPool, err = sql.Open(d.DriverName, d.DSN)
    66  	}
    67  
    68  	if err = db.Callback().Create().Replace("gorm:create", Create); err != nil {
    69  		return
    70  	}
    71  
    72  	for k, v := range d.ClauseBuilders() {
    73  		db.ClauseBuilders[k] = v
    74  	}
    75  	return
    76  }
    77  
    78  func (d Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
    79  	return map[string]clause.ClauseBuilder{
    80  		"LIMIT": d.RewriteLimit,
    81  		"WHERE": d.RewriteWhere,
    82  	}
    83  }
    84  
    85  func (d Dialector) RewriteWhere(c clause.Clause, builder clause.Builder) {
    86  	if where, ok := c.Expression.(clause.Where); ok {
    87  		builder.WriteString(" WHERE ")
    88  
    89  		// Switch position if the first query expression is a single Or condition
    90  		for idx, expr := range where.Exprs {
    91  			if v, ok := expr.(clause.OrConditions); !ok || len(v.Exprs) > 1 {
    92  				if idx != 0 {
    93  					where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0]
    94  				}
    95  				break
    96  			}
    97  		}
    98  
    99  		wrapInParentheses := false
   100  		for idx, expr := range where.Exprs {
   101  			if idx > 0 {
   102  				if v, ok := expr.(clause.OrConditions); ok && len(v.Exprs) == 1 {
   103  					builder.WriteString(" OR ")
   104  				} else {
   105  					builder.WriteString(" AND ")
   106  				}
   107  			}
   108  
   109  			if len(where.Exprs) > 1 {
   110  				switch v := expr.(type) {
   111  				case clause.OrConditions:
   112  					if len(v.Exprs) == 1 {
   113  						if e, ok := v.Exprs[0].(clause.Expr); ok {
   114  							sql := strings.ToLower(e.SQL)
   115  							wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
   116  						}
   117  					}
   118  				case clause.AndConditions:
   119  					if len(v.Exprs) == 1 {
   120  						if e, ok := v.Exprs[0].(clause.Expr); ok {
   121  							sql := strings.ToLower(e.SQL)
   122  							wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
   123  						}
   124  					}
   125  				case clause.Expr:
   126  					sql := strings.ToLower(v.SQL)
   127  					wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or")
   128  				}
   129  			}
   130  
   131  			if wrapInParentheses {
   132  				builder.WriteString(`(`)
   133  				expr.Build(builder)
   134  				builder.WriteString(`)`)
   135  				wrapInParentheses = false
   136  			} else {
   137  				if e, ok := expr.(clause.IN); ok {
   138  					if values, ok := e.Values[0].([]interface{}); ok {
   139  						if len(values) > 1 {
   140  							newExpr := clauses.IN{
   141  								Column: expr.(clause.IN).Column,
   142  								Values: expr.(clause.IN).Values,
   143  							}
   144  							newExpr.Build(builder)
   145  							continue
   146  						}
   147  					}
   148  				}
   149  
   150  				expr.Build(builder)
   151  			}
   152  		}
   153  	}
   154  }
   155  
   156  func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) {
   157  	if limit, ok := c.Expression.(clause.Limit); ok {
   158  		if stmt, ok := builder.(*gorm.Statement); ok {
   159  			if _, ok := stmt.Clauses["ORDER BY"]; !ok {
   160  				s := stmt.Schema
   161  				builder.WriteString("ORDER BY ")
   162  				if s != nil && s.PrioritizedPrimaryField != nil {
   163  					builder.WriteQuoted(s.PrioritizedPrimaryField.DBName)
   164  					builder.WriteByte(' ')
   165  				} else {
   166  					builder.WriteString("(SELECT NULL FROM ")
   167  					builder.WriteString(d.DummyTableName())
   168  					builder.WriteString(")")
   169  				}
   170  			}
   171  		}
   172  
   173  		if offset := limit.Offset; offset > 0 {
   174  			builder.WriteString(" OFFSET ")
   175  			builder.WriteString(strconv.Itoa(offset))
   176  			builder.WriteString(" ROWS")
   177  		}
   178  		if limit := limit.Limit; limit > 0 {
   179  			builder.WriteString(" FETCH NEXT ")
   180  			builder.WriteString(strconv.Itoa(limit))
   181  			builder.WriteString(" ROWS ONLY")
   182  		}
   183  	}
   184  }
   185  
   186  func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression {
   187  	return clause.Expr{SQL: "VALUES (DEFAULT)"}
   188  }
   189  
   190  func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator {
   191  	return Migrator{
   192  		Migrator: migrator.Migrator{
   193  			Config: migrator.Config{
   194  				DB:                          db,
   195  				Dialector:                   d,
   196  				CreateIndexAfterCreateTable: true,
   197  			},
   198  		},
   199  		Dialector: d,
   200  	}
   201  }
   202  
   203  func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
   204  	_, _ = writer.WriteString(":")
   205  	_, _ = writer.WriteString(strconv.Itoa(len(stmt.Vars)))
   206  }
   207  
   208  func (d Dialector) QuoteTo(writer clause.Writer, str string) {
   209  	_, _ = writer.WriteString(str)
   210  }
   211  
   212  var numericPlaceholder = regexp.MustCompile(`:(\d+)`)
   213  
   214  func (d Dialector) Explain(sql string, vars ...interface{}) string {
   215  	return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v interface{}) interface{} {
   216  		switch v := v.(type) {
   217  		case bool:
   218  			if v {
   219  				return 1
   220  			}
   221  			return 0
   222  		default:
   223  			return v
   224  		}
   225  	}).([]interface{})...)
   226  }
   227  
   228  func (d Dialector) DataTypeOf(field *schema.Field) string {
   229  	if _, found := field.TagSettings["RESTRICT"]; found {
   230  		delete(field.TagSettings, "RESTRICT")
   231  	}
   232  
   233  	var sqlType string
   234  
   235  	switch field.DataType {
   236  	case schema.Bool, schema.Int, schema.Uint, schema.Float:
   237  		sqlType = "INTEGER"
   238  
   239  		switch {
   240  		case field.DataType == schema.Float:
   241  			sqlType = "FLOAT"
   242  		case field.Size <= 8:
   243  			sqlType = "SMALLINT"
   244  		}
   245  
   246  		if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
   247  			sqlType += " GENERATED BY DEFAULT AS IDENTITY"
   248  		}
   249  	case schema.String, "VARCHAR2":
   250  		size := field.Size
   251  		defaultSize := d.DefaultStringSize
   252  
   253  		if size == 0 {
   254  			if defaultSize > 0 {
   255  				size = int(defaultSize)
   256  			} else {
   257  				hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != ""
   258  				// TEXT, GEOMETRY or JSON column can't have a default value
   259  				if field.PrimaryKey || field.HasDefaultValue || hasIndex {
   260  					size = 191 // utf8mb4
   261  				}
   262  			}
   263  		}
   264  
   265  		if size > 4096 {
   266  			sqlType = "CLOB"
   267  		} else {
   268  			sqlType = fmt.Sprintf("VARCHAR2(%d)", size)
   269  		}
   270  
   271  	case schema.Time:
   272  		sqlType = "TIMESTAMP WITH TIME ZONE"
   273  		if field.NotNull || field.PrimaryKey {
   274  			sqlType += " NOT NULL"
   275  		}
   276  	case schema.Bytes:
   277  		sqlType = "BLOB"
   278  	default:
   279  		sqlType = string(field.DataType)
   280  
   281  		if strings.EqualFold(sqlType, "text") {
   282  			sqlType = "CLOB"
   283  		}
   284  
   285  		if sqlType == "" {
   286  			panic(fmt.Sprintf("invalid sql type %s (%s) for dm", field.FieldType.Name(), field.FieldType.String()))
   287  		}
   288  
   289  		notNull, _ := field.TagSettings["NOT NULL"]
   290  		unique, _ := field.TagSettings["UNIQUE"]
   291  		additionalType := fmt.Sprintf("%s %s", notNull, unique)
   292  		if value, ok := field.TagSettings["DEFAULT"]; ok {
   293  			additionalType = fmt.Sprintf("%s %s %s%s", "DEFAULT", value, additionalType, func() string {
   294  				if value, ok := field.TagSettings["COMMENT"]; ok {
   295  					return " COMMENT " + value
   296  				}
   297  				return ""
   298  			}())
   299  		}
   300  		sqlType = fmt.Sprintf("%v %v", sqlType, additionalType)
   301  	}
   302  
   303  	return sqlType
   304  }
   305  
   306  func (d Dialector) SavePoint(tx *gorm.DB, name string) error {
   307  	tx.Exec("SAVEPOINT " + name)
   308  	return tx.Error
   309  }
   310  
   311  func (d Dialector) RollbackTo(tx *gorm.DB, name string) error {
   312  	tx.Exec("ROLLBACK TO SAVEPOINT " + name)
   313  	return tx.Error
   314  }