github.com/wfusion/gofusion@v1.1.14/common/infra/drivers/orm/sqlite/sqlite.go (about)

     1  package sqlite
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"strconv"
     7  
     8  	"gorm.io/gorm/callbacks"
     9  
    10  	gosqlite "github.com/glebarez/go-sqlite"
    11  	sqlite3 "modernc.org/sqlite/lib"
    12  
    13  	"gorm.io/gorm"
    14  	"gorm.io/gorm/clause"
    15  	"gorm.io/gorm/logger"
    16  	"gorm.io/gorm/migrator"
    17  	"gorm.io/gorm/schema"
    18  )
    19  
// DriverName is the default database/sql driver name for SQLite. Initialize
// falls back to it when Dialector.DriverName is left empty.
const DriverName = "sqlite"
    22  
// Dialector is the gorm dialector for SQLite databases.
type Dialector struct {
	// DriverName is the database/sql driver name handed to sql.Open by
	// Initialize; when empty, the package-level DriverName constant is used.
	DriverName string
	// DSN is the data source name handed to sql.Open by Initialize when Conn
	// is nil.
	DSN        string
	// Conn, when non-nil, is installed as the gorm connection pool by
	// Initialize instead of opening a new connection from DSN.
	Conn       gorm.ConnPool
}
    28  
    29  func Open(dsn string) gorm.Dialector {
    30  	return &Dialector{DSN: dsn}
    31  }
    32  
// Name returns the name of this dialect, which is always "sqlite".
func (dialector Dialector) Name() string {
	return "sqlite"
}
    36  
    37  func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
    38  	if dialector.DriverName == "" {
    39  		dialector.DriverName = DriverName
    40  	}
    41  
    42  	if dialector.Conn != nil {
    43  		db.ConnPool = dialector.Conn
    44  	} else {
    45  		conn, err := sql.Open(dialector.DriverName, dialector.DSN)
    46  		if err != nil {
    47  			return err
    48  		}
    49  		db.ConnPool = conn
    50  	}
    51  
    52  	var version string
    53  	if err := db.ConnPool.QueryRowContext(context.Background(), "select sqlite_version()").Scan(&version); err != nil {
    54  		return err
    55  	}
    56  	// https://www.sqlite.org/releaselog/3_35_0.html
    57  	if compareVersion(version, "3.35.0") >= 0 {
    58  		callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
    59  			CreateClauses:        []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
    60  			UpdateClauses:        []string{"UPDATE", "SET", "WHERE", "RETURNING"},
    61  			DeleteClauses:        []string{"DELETE", "FROM", "WHERE", "RETURNING"},
    62  			LastInsertIDReversed: true,
    63  		})
    64  	} else {
    65  		callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
    66  			LastInsertIDReversed: true,
    67  		})
    68  	}
    69  
    70  	for k, v := range dialector.ClauseBuilders() {
    71  		db.ClauseBuilders[k] = v
    72  	}
    73  	return
    74  }
    75  
    76  func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
    77  	return map[string]clause.ClauseBuilder{
    78  		"INSERT": func(c clause.Clause, builder clause.Builder) {
    79  			if insert, ok := c.Expression.(clause.Insert); ok {
    80  				if stmt, ok := builder.(*gorm.Statement); ok {
    81  					stmt.WriteString("INSERT ")
    82  					if insert.Modifier != "" {
    83  						stmt.WriteString(insert.Modifier)
    84  						stmt.WriteByte(' ')
    85  					}
    86  
    87  					stmt.WriteString("INTO ")
    88  					if insert.Table.Name == "" {
    89  						stmt.WriteQuoted(stmt.Table)
    90  					} else {
    91  						stmt.WriteQuoted(insert.Table)
    92  					}
    93  					return
    94  				}
    95  			}
    96  
    97  			c.Build(builder)
    98  		},
    99  		"LIMIT": func(c clause.Clause, builder clause.Builder) {
   100  			if limit, ok := c.Expression.(clause.Limit); ok {
   101  				var lmt = -1
   102  				if limit.Limit != nil && *limit.Limit >= 0 {
   103  					lmt = *limit.Limit
   104  				}
   105  				if lmt >= 0 || limit.Offset > 0 {
   106  					builder.WriteString("LIMIT ")
   107  					builder.WriteString(strconv.Itoa(lmt))
   108  				}
   109  				if limit.Offset > 0 {
   110  					builder.WriteString(" OFFSET ")
   111  					builder.WriteString(strconv.Itoa(limit.Offset))
   112  				}
   113  			}
   114  		},
   115  		"FOR": func(c clause.Clause, builder clause.Builder) {
   116  			if _, ok := c.Expression.(clause.Locking); ok {
   117  				// SQLite3 does not support row-level locking.
   118  				return
   119  			}
   120  			c.Build(builder)
   121  		},
   122  	}
   123  }
   124  
   125  func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
   126  	if field.AutoIncrement {
   127  		return clause.Expr{SQL: "NULL"}
   128  	}
   129  
   130  	// doesn't work, will raise error
   131  	return clause.Expr{SQL: "DEFAULT"}
   132  }
   133  
   134  func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
   135  	return Migrator{migrator.Migrator{Config: migrator.Config{
   136  		DB:                          db,
   137  		Dialector:                   dialector,
   138  		CreateIndexAfterCreateTable: true,
   139  	}}}
   140  }
   141  
// BindVarTo writes the bind-variable placeholder '?' to writer. The stmt and
// v arguments are not used.
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
	writer.WriteByte('?')
}
   145  
// QuoteTo writes str to writer as a backtick-quoted identifier.
//
// The input is scanned byte by byte. A '.' is treated as a separator between
// identifier parts, so each part gets its own pair of backticks (for example
// "a.b" is written as "`a`.`b`"). Backticks already present in the input are
// tracked so that a part that arrives quoted is not quoted a second time, and
// a pair of consecutive backticks is written out as "``".
//
// Note: a closing backtick is always written at the end, so an empty str
// produces a single "`".
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
	var (
		// underQuoted: an opening backtick has been written for the current part.
		// selfQuoted: the current part started with its own backtick in the input.
		underQuoted, selfQuoted bool
		// continuousBacktick counts consecutive backticks seen but not yet emitted.
		continuousBacktick      int8
		// shiftDelimiter counts bytes consumed since the last handled '.'.
		shiftDelimiter          int8
	)

	for _, v := range []byte(str) {
		switch v {
		case '`':
			continuousBacktick++
			// Two backticks in a row are emitted as an escaped backtick pair.
			if continuousBacktick == 2 {
				writer.WriteString("``")
				continuousBacktick = 0
			}
		case '.':
			// Close the current part unless we are inside a self-quoted part
			// whose closing backtick has not been seen yet.
			if continuousBacktick > 0 || !selfQuoted {
				shiftDelimiter = 0
				underQuoted = false
				continuousBacktick = 0
				writer.WriteString("`")
			}
			writer.WriteByte(v)
			// Skip the shiftDelimiter increment below: the separator itself is
			// not counted as part of the next identifier part.
			continue
		default:
			// First ordinary byte of a part: open the quote. A pending backtick
			// from the input is consumed as that opening quote.
			if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
				writer.WriteString("`")
				underQuoted = true
				if selfQuoted = continuousBacktick > 0; selfQuoted {
					continuousBacktick -= 1
				}
			}

			// Any backticks still pending are emitted as escaped pairs.
			for ; continuousBacktick > 0; continuousBacktick -= 1 {
				writer.WriteString("``")
			}

			writer.WriteByte(v)
		}
		shiftDelimiter++
	}

	// A trailing backtick that does not close a self-quoted part is escaped.
	if continuousBacktick > 0 && !selfQuoted {
		writer.WriteString("``")
	}
	writer.WriteString("`")
}
   193  
// Explain returns the result of logger.ExplainSQL for sql and vars, called
// with a nil second argument and a double quote (`"`) as the third argument.
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
	return logger.ExplainSQL(sql, nil, `"`, vars...)
}
   197  
   198  func (dialector Dialector) DataTypeOf(field *schema.Field) string {
   199  	switch field.DataType {
   200  	case schema.Bool:
   201  		return "numeric"
   202  	case schema.Int, schema.Uint:
   203  		if field.AutoIncrement {
   204  			// doesn't check `PrimaryKey`, to keep backward compatibility
   205  			// https://www.sqlite.org/autoinc.html
   206  			return "integer PRIMARY KEY AUTOINCREMENT"
   207  		} else {
   208  			return "integer"
   209  		}
   210  	case schema.Float:
   211  		return "real"
   212  	case schema.String:
   213  		return "text"
   214  	case schema.Time:
   215  		// Distinguish between schema.Time and tag time
   216  		if val, ok := field.TagSettings["TYPE"]; ok {
   217  			return val
   218  		} else {
   219  			return "datetime"
   220  		}
   221  	case schema.Bytes:
   222  		return "blob"
   223  	}
   224  
   225  	return string(field.DataType)
   226  }
   227  
   228  func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
   229  	tx.Exec("SAVEPOINT " + name)
   230  	return nil
   231  }
   232  
   233  func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
   234  	tx.Exec("ROLLBACK TO SAVEPOINT " + name)
   235  	return nil
   236  }
   237  
   238  func (dialector Dialector) Translate(err error) error {
   239  	switch terr := err.(type) {
   240  	case *gosqlite.Error:
   241  		switch terr.Code() {
   242  		case sqlite3.SQLITE_CONSTRAINT_UNIQUE:
   243  			return gorm.ErrDuplicatedKey
   244  		case sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY:
   245  			return gorm.ErrDuplicatedKey
   246  		case sqlite3.SQLITE_CONSTRAINT_FOREIGNKEY:
   247  			return gorm.ErrForeignKeyViolated
   248  		}
   249  	}
   250  	return err
   251  }
   252  
   253  func compareVersion(version1, version2 string) int {
   254  	n, m := len(version1), len(version2)
   255  	i, j := 0, 0
   256  	for i < n || j < m {
   257  		x := 0
   258  		for ; i < n && version1[i] != '.'; i++ {
   259  			x = x*10 + int(version1[i]-'0')
   260  		}
   261  		i++
   262  		y := 0
   263  		for ; j < m && version2[j] != '.'; j++ {
   264  			y = y*10 + int(version2[j]-'0')
   265  		}
   266  		j++
   267  		if x > y {
   268  			return 1
   269  		}
   270  		if x < y {
   271  			return -1
   272  		}
   273  	}
   274  	return 0
   275  }