github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/internal/sql/adapter/sqlite/dialect.go (about)

     1  package sqlite
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"reflect"
     7  	"strings"
     8  
     9  	"github.com/octohelm/storage/internal/sql/adapter"
    10  	"github.com/octohelm/storage/pkg/sqlbuilder"
    11  	typex "github.com/octohelm/x/types"
    12  )
    13  
    14  var _ adapter.Dialect = (*dialect)(nil)
    15  
    16  type dialect struct {
    17  }
    18  
    19  func (dialect) DriverName() string {
    20  	return "sqlite"
    21  }
    22  
    23  func (c *dialect) AddIndex(key sqlbuilder.Key) sqlbuilder.SqlExpr {
    24  	if key.IsPrimary() {
    25  		e := sqlbuilder.Expr("ALTER TABLE ")
    26  		e.WriteExpr(key.T())
    27  		e.WriteQuery(" ADD PRIMARY KEY ")
    28  		e.WriteGroup(func(e *sqlbuilder.Ex) {
    29  			e.WriteExpr(key.Columns())
    30  		})
    31  		e.WriteEnd()
    32  		return e
    33  	}
    34  
    35  	e := sqlbuilder.Expr("CREATE ")
    36  	if key.IsUnique() {
    37  		e.WriteQuery("UNIQUE ")
    38  	}
    39  	e.WriteQuery("INDEX ")
    40  
    41  	e.WriteExpr(c.indexName(key))
    42  
    43  	e.WriteQuery(" ON ")
    44  	e.WriteExpr(key.T())
    45  
    46  	keyDef := key.(sqlbuilder.KeyDef)
    47  
    48  	e.WriteQueryByte(' ')
    49  	e.WriteGroup(func(e *sqlbuilder.Ex) {
    50  		for i, colNameAndOpt := range keyDef.ColNameAndOptions() {
    51  			parts := strings.Split(colNameAndOpt, "/")
    52  			if i != 0 {
    53  				_ = e.WriteByte(',')
    54  			}
    55  			e.WriteExpr(key.T().F(parts[0]))
    56  			if len(parts) > 1 {
    57  				e.WriteQuery(" ")
    58  				e.WriteQuery(parts[1])
    59  			}
    60  		}
    61  	})
    62  
    63  	e.WriteEnd()
    64  	return e
    65  }
    66  
    67  func (c *dialect) DropIndex(key sqlbuilder.Key) sqlbuilder.SqlExpr {
    68  	if key.IsPrimary() {
    69  		// pk could not changed
    70  		return nil
    71  	}
    72  
    73  	return sqlbuilder.Expr("DROP INDEX IF EXISTS @index;", sqlbuilder.NamedArgSet{
    74  		"index": c.indexName(key),
    75  	})
    76  }
    77  
    78  func (c *dialect) CreateTableIsNotExists(t sqlbuilder.Table) (exprs []sqlbuilder.SqlExpr) {
    79  	expr := sqlbuilder.Expr("CREATE TABLE IF NOT EXISTS @table ", sqlbuilder.NamedArgSet{
    80  		"table": t,
    81  	})
    82  
    83  	expr.WriteGroup(func(e *sqlbuilder.Ex) {
    84  		cols := t.Cols()
    85  
    86  		if cols.IsNil() {
    87  			return
    88  		}
    89  
    90  		var autoIncrement sqlbuilder.Column
    91  
    92  		cols.RangeCol(func(col sqlbuilder.Column, idx int) bool {
    93  			def := col.Def()
    94  
    95  			if def.DeprecatedActions != nil {
    96  				return true
    97  			}
    98  
    99  			if def.AutoIncrement {
   100  				autoIncrement = col
   101  			}
   102  
   103  			if idx > 0 {
   104  				e.WriteQueryByte(',')
   105  			}
   106  			e.WriteQueryByte('\n')
   107  			e.WriteQueryByte('\t')
   108  
   109  			e.WriteExpr(col)
   110  			e.WriteQueryByte(' ')
   111  			e.WriteExpr(c.DataType(col.Def()))
   112  
   113  			return true
   114  		})
   115  
   116  		t.Keys().RangeKey(func(key sqlbuilder.Key, idx int) bool {
   117  			if key.IsPrimary() {
   118  				var skip = false
   119  
   120  				if autoIncrement != nil {
   121  					key.Columns().RangeCol(func(col sqlbuilder.Column, idx int) bool {
   122  						if autoIncrement.Name() == col.Name() {
   123  							skip = true
   124  							// auto increment pk will create when table define
   125  							return false
   126  						}
   127  						return true
   128  					})
   129  				}
   130  
   131  				if skip {
   132  					return true
   133  				}
   134  
   135  				e.WriteQueryByte(',')
   136  				e.WriteQueryByte('\n')
   137  				e.WriteQueryByte('\t')
   138  				e.WriteQuery("PRIMARY KEY ")
   139  				e.WriteGroup(func(e *sqlbuilder.Ex) {
   140  					e.WriteExpr(key.Columns())
   141  				})
   142  			}
   143  
   144  			return true
   145  		})
   146  
   147  		expr.WriteQueryByte('\n')
   148  	})
   149  
   150  	expr.WriteEnd()
   151  	exprs = append(exprs, expr)
   152  
   153  	t.Keys().RangeKey(func(key sqlbuilder.Key, idx int) bool {
   154  		if !key.IsPrimary() {
   155  			exprs = append(exprs, c.AddIndex(key))
   156  		}
   157  		return true
   158  	})
   159  
   160  	return
   161  }
   162  
   163  func (c *dialect) DropTable(t sqlbuilder.Table) sqlbuilder.SqlExpr {
   164  	return sqlbuilder.Expr("DROP TABLE IF EXISTS @table;", sqlbuilder.NamedArgSet{
   165  		"table": t,
   166  	})
   167  }
   168  
   169  func (c *dialect) TruncateTable(t sqlbuilder.Table) sqlbuilder.SqlExpr {
   170  	return sqlbuilder.Expr("TRUNCATE TABLE @table;", sqlbuilder.NamedArgSet{
   171  		"table": t,
   172  	})
   173  }
   174  
   175  func (c *dialect) AddColumn(col sqlbuilder.Column) sqlbuilder.SqlExpr {
   176  	return sqlbuilder.Expr("ALTER TABLE @table ADD COLUMN @col @dataType;", sqlbuilder.NamedArgSet{
   177  		"table":    col.T(),
   178  		"col":      col,
   179  		"dataType": c.DataType(col.Def()),
   180  	})
   181  }
   182  
   183  func (c *dialect) RenameColumn(col sqlbuilder.Column, target sqlbuilder.Column) sqlbuilder.SqlExpr {
   184  	return sqlbuilder.Expr("ALTER TABLE @table RENAME COLUMN @oldCol TO @newCol;", sqlbuilder.NamedArgSet{
   185  		"table":  col.T(),
   186  		"oldCol": col,
   187  		"newCol": target,
   188  	})
   189  }
   190  
   191  func (c *dialect) ModifyColumn(col sqlbuilder.Column, prevCol sqlbuilder.Column) sqlbuilder.SqlExpr {
   192  	def := col.Def()
   193  
   194  	// incr id never modified
   195  	if def.AutoIncrement {
   196  		return nil
   197  	}
   198  
   199  	prevTmpCol := sqlbuilder.Col("__"+prevCol.Name(), sqlbuilder.ColDef(prevCol.Def())).Of(prevCol.T())
   200  
   201  	e := sqlbuilder.Expr("")
   202  
   203  	e.WriteExpr(sqlbuilder.Expr("ALTER TABLE @table RENAME COLUMN @prevCol TO @tmpCol;", sqlbuilder.NamedArgSet{
   204  		"table":   prevCol.T(),
   205  		"prevCol": prevCol,
   206  		"tmpCol":  prevTmpCol,
   207  	}))
   208  
   209  	e.WriteExpr(c.AddColumn(col))
   210  
   211  	e.WriteExpr(sqlbuilder.Expr("UPDATE @table SET @col = @tmpCol;", sqlbuilder.NamedArgSet{
   212  		"table":  col.T(),
   213  		"col":    col,
   214  		"tmpCol": prevTmpCol,
   215  	}))
   216  
   217  	e.WriteExpr(c.DropColumn(prevTmpCol))
   218  
   219  	return e
   220  }
   221  
   222  func (c *dialect) DropColumn(col sqlbuilder.Column) sqlbuilder.SqlExpr {
   223  	return sqlbuilder.Expr("ALTER TABLE @table DROP COLUMN @col;", sqlbuilder.NamedArgSet{
   224  		"table": col.T(),
   225  		"col":   col,
   226  	})
   227  }
   228  
   229  func (c *dialect) DataType(columnType sqlbuilder.ColumnDef) sqlbuilder.SqlExpr {
   230  	dbDataType := c.dbDataType(columnType.Type, columnType)
   231  	return sqlbuilder.Expr(dbDataType + c.dataTypeModify(columnType, dbDataType))
   232  }
   233  
   234  func (c *dialect) dataType(typ typex.Type, columnType sqlbuilder.ColumnDef) string {
   235  	return c.dbDataType(columnType.Type, columnType)
   236  }
   237  
   238  func (c *dialect) dbDataType(typ typex.Type, columnType sqlbuilder.ColumnDef) string {
   239  	if columnType.DataType != "" {
   240  		// for type from catalog
   241  		return columnType.DataType
   242  	}
   243  
   244  	if rv, ok := typex.TryNew(typ); ok {
   245  		v := rv.Interface()
   246  
   247  		if dtd, ok := v.(sqlbuilder.DataTypeDescriber); ok {
   248  			return dtd.DataType(c.DriverName())
   249  		}
   250  	}
   251  
   252  	if columnType.AutoIncrement {
   253  		return "INTEGER PRIMARY KEY AUTOINCREMENT"
   254  	}
   255  
   256  	switch typ.Kind() {
   257  	case reflect.Ptr:
   258  		return c.dataType(typ.Elem(), columnType)
   259  	case reflect.Bool:
   260  		return "BOOLEAN"
   261  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
   262  		return "INTEGER"
   263  	case reflect.Int64:
   264  		return "BIGINT"
   265  	case reflect.Uint64:
   266  		return "UNSIGNED BIG INT"
   267  	case reflect.Float32:
   268  		return "FLOAT"
   269  	case reflect.Float64:
   270  		return "DOUBLE"
   271  	case reflect.Slice:
   272  		if typ.Elem().Kind() == reflect.Uint8 {
   273  			return "BLOB"
   274  		}
   275  	case reflect.String:
   276  		return "TEXT"
   277  	default:
   278  		if typ.Name() == "Time" && typ.PkgPath() == "time" {
   279  			return "DATETIME"
   280  		}
   281  	}
   282  
   283  	panic(fmt.Errorf("unsupport type %s", typ))
   284  }
   285  
   286  func (c *dialect) dataTypeModify(columnType sqlbuilder.ColumnDef, dataType string) string {
   287  	buf := bytes.NewBuffer(nil)
   288  
   289  	if !columnType.Null {
   290  		buf.WriteString(" NOT NULL")
   291  	}
   292  
   293  	if columnType.Default != nil {
   294  		buf.WriteString(" DEFAULT ")
   295  		buf.WriteString(normalizeDefaultValue(columnType.Default, dataType))
   296  	}
   297  
   298  	return buf.String()
   299  }
   300  
   301  func (c dialect) indexName(key sqlbuilder.Key) sqlbuilder.SqlExpr {
   302  	return sqlbuilder.Expr(fmt.Sprintf("%s_%s", key.T().TableName(), key.Name()))
   303  }
   304  
   305  func normalizeDefaultValue(defaultValue *string, dataType string) string {
   306  	if defaultValue == nil {
   307  		return ""
   308  	}
   309  	return *defaultValue
   310  }