github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/internal/sql/adapter/postgres/dialect.go (about)

     1  package postgres
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"reflect"
     7  	"strconv"
     8  	"strings"
     9  
    10  	"github.com/octohelm/storage/internal/sql/adapter"
    11  	"github.com/octohelm/storage/pkg/sqlbuilder"
    12  	typex "github.com/octohelm/x/types"
    13  )
    14  
    15  var _ adapter.Dialect = (*dialect)(nil)
    16  
    17  type dialect struct {
    18  }
    19  
    20  func (dialect) DriverName() string {
    21  	return "postgres"
    22  }
    23  
    24  func (c *dialect) indexName(key sqlbuilder.Key) string {
    25  	name := key.Name()
    26  	if name == "primary" {
    27  		name = "pkey"
    28  	}
    29  	return key.T().TableName() + "_" + name
    30  }
    31  
    32  func (c *dialect) AddIndex(key sqlbuilder.Key) sqlbuilder.SqlExpr {
    33  	if key.IsPrimary() {
    34  		e := sqlbuilder.Expr("ALTER TABLE ")
    35  		e.WriteExpr(key.T())
    36  		e.WriteQuery(" ADD PRIMARY KEY ")
    37  		e.WriteGroup(func(e *sqlbuilder.Ex) {
    38  			e.WriteExpr(key.Columns())
    39  		})
    40  		e.WriteEnd()
    41  		return e
    42  	}
    43  
    44  	e := sqlbuilder.Expr("CREATE ")
    45  
    46  	if key.IsUnique() {
    47  		e.WriteQuery("UNIQUE ")
    48  	}
    49  
    50  	e.WriteQuery("INDEX ")
    51  
    52  	e.WriteQuery(c.indexName(key))
    53  
    54  	e.WriteQuery(" ON ")
    55  	e.WriteExpr(key.T())
    56  
    57  	keyDef := key.(sqlbuilder.KeyDef)
    58  
    59  	if m := strings.ToUpper(keyDef.Method()); m != "" {
    60  		if m == "SPATIAL" {
    61  			m = "GIST"
    62  		}
    63  		e.WriteQuery(" USING ")
    64  		e.WriteQuery(m)
    65  	}
    66  
    67  	e.WriteQueryByte(' ')
    68  	e.WriteGroup(func(e *sqlbuilder.Ex) {
    69  		for i, colNameAndOpt := range keyDef.ColNameAndOptions() {
    70  			parts := strings.Split(colNameAndOpt, "/")
    71  			if i != 0 {
    72  				_ = e.WriteByte(',')
    73  			}
    74  			e.WriteExpr(key.T().F(parts[0]))
    75  			if len(parts) > 1 {
    76  				e.WriteQuery(" ")
    77  				e.WriteQuery(parts[1])
    78  			}
    79  		}
    80  	})
    81  
    82  	e.WriteEnd()
    83  	return e
    84  }
    85  
    86  func (c *dialect) DropIndex(key sqlbuilder.Key) sqlbuilder.SqlExpr {
    87  	if key.IsPrimary() {
    88  		e := sqlbuilder.Expr("ALTER TABLE ")
    89  		e.WriteExpr(key.T())
    90  		e.WriteQuery(" DROP CONSTRAINT ")
    91  		e.WriteQuery(c.indexName(key))
    92  		e.WriteEnd()
    93  		return e
    94  	}
    95  	e := sqlbuilder.Expr("DROP ")
    96  
    97  	e.WriteQuery("INDEX IF EXISTS ")
    98  	e.WriteQuery(c.indexName(key))
    99  	e.WriteEnd()
   100  
   101  	return e
   102  }
   103  
   104  func (c *dialect) CreateTableIsNotExists(t sqlbuilder.Table) (exprs []sqlbuilder.SqlExpr) {
   105  	expr := sqlbuilder.Expr("CREATE TABLE IF NOT EXISTS @table ", sqlbuilder.NamedArgSet{
   106  		"table": t,
   107  	})
   108  
   109  	expr.WriteGroup(func(e *sqlbuilder.Ex) {
   110  		cols := t.Cols()
   111  
   112  		if cols.IsNil() {
   113  			return
   114  		}
   115  
   116  		cols.RangeCol(func(col sqlbuilder.Column, idx int) bool {
   117  			def := col.Def()
   118  
   119  			if def.DeprecatedActions != nil {
   120  				return true
   121  			}
   122  
   123  			if idx > 0 {
   124  				e.WriteQueryByte(',')
   125  			}
   126  			e.WriteQueryByte('\n')
   127  			e.WriteQueryByte('\t')
   128  
   129  			e.WriteExpr(col)
   130  			e.WriteQueryByte(' ')
   131  			e.WriteExpr(c.DataType(def))
   132  
   133  			return true
   134  		})
   135  
   136  		t.Keys().RangeKey(func(key sqlbuilder.Key, idx int) bool {
   137  			if key.IsPrimary() {
   138  				e.WriteQueryByte(',')
   139  				e.WriteQueryByte('\n')
   140  				e.WriteQueryByte('\t')
   141  				e.WriteQuery("PRIMARY KEY ")
   142  				e.WriteGroup(func(e *sqlbuilder.Ex) {
   143  					e.WriteExpr(key.Columns())
   144  				})
   145  			}
   146  			return true
   147  		})
   148  
   149  		expr.WriteQueryByte('\n')
   150  	})
   151  
   152  	expr.WriteEnd()
   153  
   154  	exprs = append(exprs, expr)
   155  
   156  	t.Keys().RangeKey(func(key sqlbuilder.Key, idx int) bool {
   157  		if !key.IsPrimary() {
   158  			exprs = append(exprs, c.AddIndex(key))
   159  		}
   160  		return true
   161  	})
   162  
   163  	return
   164  }
   165  
   166  func (c *dialect) DropTable(t sqlbuilder.Table) sqlbuilder.SqlExpr {
   167  	return sqlbuilder.Expr("DROP TABLE IF EXISTS @table;", sqlbuilder.NamedArgSet{
   168  		"table": t,
   169  	})
   170  }
   171  
   172  func (c *dialect) TruncateTable(t sqlbuilder.Table) sqlbuilder.SqlExpr {
   173  	return sqlbuilder.Expr("TRUNCATE TABLE @table;", sqlbuilder.NamedArgSet{
   174  		"table": t,
   175  	})
   176  }
   177  
   178  func (c *dialect) AddColumn(col sqlbuilder.Column) sqlbuilder.SqlExpr {
   179  	return sqlbuilder.Expr("ALTER TABLE @table ADD COLUMN @col @dataType;", sqlbuilder.NamedArgSet{
   180  		"table":    col.T(),
   181  		"col":      col,
   182  		"dataType": c.DataType(col.Def()),
   183  	})
   184  }
   185  
   186  func (c *dialect) RenameColumn(col sqlbuilder.Column, target sqlbuilder.Column) sqlbuilder.SqlExpr {
   187  	return sqlbuilder.Expr("ALTER TABLE @table RENAME COLUMN @oldCol TO @newCol;", sqlbuilder.NamedArgSet{
   188  		"table":  col.T(),
   189  		"oldCol": col,
   190  		"newCol": target,
   191  	})
   192  }
   193  
   194  func (c *dialect) ModifyColumn(col sqlbuilder.Column, prev sqlbuilder.Column) sqlbuilder.SqlExpr {
   195  	def := col.Def()
   196  	prevDef := prev.Def()
   197  
   198  	if def.AutoIncrement {
   199  		return nil
   200  	}
   201  
   202  	e := sqlbuilder.Expr("ALTER TABLE ")
   203  	e.WriteExpr(col.T())
   204  
   205  	dbDataType := c.dataType(def.Type, def)
   206  	prevDbDataType := c.dataType(prevDef.Type, prevDef)
   207  
   208  	isFirstSub := true
   209  	isEmpty := true
   210  
   211  	prepareAppendSubCmd := func() {
   212  		if !isFirstSub {
   213  			e.WriteQueryByte(',')
   214  		}
   215  		isFirstSub = false
   216  		isEmpty = false
   217  	}
   218  
   219  	if dbDataType != prevDbDataType {
   220  		prepareAppendSubCmd()
   221  
   222  		e.WriteQuery(" ALTER COLUMN ")
   223  		e.WriteExpr(col)
   224  		e.WriteQuery(" TYPE ")
   225  		e.WriteQuery(dbDataType)
   226  
   227  		e.WriteQuery(" /* FROM ")
   228  		e.WriteQuery(prevDbDataType)
   229  		e.WriteQuery(" */")
   230  	}
   231  
   232  	if def.Null != prevDef.Null {
   233  		prepareAppendSubCmd()
   234  
   235  		e.WriteQuery(" ALTER COLUMN ")
   236  		e.WriteExpr(col)
   237  		if !def.Null {
   238  			e.WriteQuery(" SET NOT NULL")
   239  		} else {
   240  			e.WriteQuery(" DROP NOT NULL")
   241  		}
   242  	}
   243  
   244  	defaultValue := normalizeDefaultValue(def.Default, dbDataType)
   245  	prevDefaultValue := normalizeDefaultValue(prevDef.Default, prevDbDataType)
   246  
   247  	if defaultValue != prevDefaultValue {
   248  		prepareAppendSubCmd()
   249  
   250  		e.WriteQuery(" ALTER COLUMN ")
   251  		e.WriteExpr(col)
   252  		if def.Default != nil {
   253  			e.WriteQuery(" SET DEFAULT ")
   254  			e.WriteQuery(defaultValue)
   255  
   256  			e.WriteQuery(" /* FROM ")
   257  			e.WriteQuery(prevDefaultValue)
   258  			e.WriteQuery(" */")
   259  		} else {
   260  			e.WriteQuery(" DROP DEFAULT")
   261  		}
   262  	}
   263  
   264  	if isEmpty {
   265  		return nil
   266  	}
   267  
   268  	e.WriteEnd()
   269  
   270  	return e
   271  }
   272  
   273  func (c *dialect) DropColumn(col sqlbuilder.Column) sqlbuilder.SqlExpr {
   274  	return sqlbuilder.Expr("ALTER TABLE @table DROP COLUMN @col;", sqlbuilder.NamedArgSet{
   275  		"table": col.T(),
   276  		"col":   col,
   277  	})
   278  }
   279  
   280  func (c *dialect) DataType(columnType sqlbuilder.ColumnDef) sqlbuilder.SqlExpr {
   281  	dbDataType := dealias(c.dbDataType(columnType.Type, columnType))
   282  	return sqlbuilder.Expr(dbDataType + autocompleteSize(dbDataType, columnType) + c.dataTypeModify(columnType, dbDataType))
   283  }
   284  
   285  func (c *dialect) dataType(typ typex.Type, columnType sqlbuilder.ColumnDef) string {
   286  	dbDataType := dealias(c.dbDataType(typ, columnType))
   287  	return dbDataType + autocompleteSize(dbDataType, columnType)
   288  }
   289  
   290  func (c *dialect) dbDataType(typ typex.Type, columnType sqlbuilder.ColumnDef) string {
   291  	if columnType.DataType != "" {
   292  		return columnType.DataType
   293  	}
   294  
   295  	if rv, ok := typex.TryNew(typ); ok {
   296  		if dtd, ok := rv.Interface().(sqlbuilder.DataTypeDescriber); ok {
   297  			return dtd.DataType(c.DriverName())
   298  		}
   299  	}
   300  
   301  	switch typ.Kind() {
   302  	case reflect.Ptr:
   303  		return c.dataType(typ.Elem(), columnType)
   304  	case reflect.Bool:
   305  		return "boolean"
   306  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
   307  		if columnType.AutoIncrement {
   308  			return "serial"
   309  		}
   310  		return "integer"
   311  	case reflect.Int64, reflect.Uint64:
   312  		if columnType.AutoIncrement {
   313  			return "bigserial"
   314  		}
   315  		return "bigint"
   316  	case reflect.Float64:
   317  		return "double precision"
   318  	case reflect.Float32:
   319  		return "real"
   320  	case reflect.Slice:
   321  		if typ.Elem().Kind() == reflect.Uint8 {
   322  			return "bytea"
   323  		}
   324  	case reflect.String:
   325  		size := columnType.Length
   326  		if size < 65535/3 {
   327  			return "varchar"
   328  		}
   329  		return "text"
   330  	}
   331  
   332  	switch typ.Name() {
   333  	case "Hstore":
   334  		return "hstore"
   335  	case "NullInt64":
   336  		return "bigint"
   337  	case "NullFloat64":
   338  		return "double precision"
   339  	case "NullBool":
   340  		return "boolean"
   341  	case "Time", "NullTime":
   342  		return "timestamp with time zone"
   343  	}
   344  
   345  	panic(fmt.Errorf("unsupport type %s", typ))
   346  }
   347  
   348  func (c *dialect) dataTypeModify(columnType sqlbuilder.ColumnDef, dataType string) string {
   349  	buf := bytes.NewBuffer(nil)
   350  
   351  	if !columnType.Null {
   352  		buf.WriteString(" NOT NULL")
   353  	}
   354  
   355  	if columnType.Default != nil {
   356  		buf.WriteString(" DEFAULT ")
   357  		buf.WriteString(normalizeDefaultValue(columnType.Default, dataType))
   358  	}
   359  
   360  	return buf.String()
   361  }
   362  
   363  func normalizeDefaultValue(defaultValue *string, dataType string) string {
   364  	if defaultValue == nil {
   365  		return ""
   366  	}
   367  
   368  	dv := *defaultValue
   369  
   370  	if dv[0] == '\'' {
   371  		if strings.Contains(dv, "'::") {
   372  			return dv
   373  		}
   374  		return dv + "::" + dataType
   375  	}
   376  
   377  	_, err := strconv.ParseFloat(dv, 64)
   378  	if err == nil {
   379  		return "'" + dv + "'::" + dataType
   380  	}
   381  
   382  	return dv
   383  }
   384  
   385  func autocompleteSize(dataType string, columnType sqlbuilder.ColumnDef) string {
   386  	switch dataType {
   387  	case "character varying", "character":
   388  		size := columnType.Length
   389  		if size == 0 {
   390  			size = 255
   391  		}
   392  		return sizeModifier(size, columnType.Decimal)
   393  	case "decimal", "numeric", "real", "double precision":
   394  		if columnType.Length > 0 {
   395  			return sizeModifier(columnType.Length, columnType.Decimal)
   396  		}
   397  	}
   398  	return ""
   399  }
   400  
   401  func dealias(dataType string) string {
   402  	switch dataType {
   403  	case "varchar":
   404  		return "character varying"
   405  	case "timestamp":
   406  		return "timestamp without time zone"
   407  	}
   408  	return dataType
   409  }
   410  
   411  func sizeModifier(length uint64, decimal uint64) string {
   412  	if length > 0 {
   413  		size := strconv.FormatUint(length, 10)
   414  		if decimal > 0 {
   415  			return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")"
   416  		}
   417  		return "(" + size + ")"
   418  	}
   419  	return ""
   420  }