github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/internal/sql/adapter/postgres/dialect.go (about) 1 package postgres 2 3 import ( 4 "bytes" 5 "fmt" 6 "reflect" 7 "strconv" 8 "strings" 9 10 "github.com/octohelm/storage/internal/sql/adapter" 11 "github.com/octohelm/storage/pkg/sqlbuilder" 12 typex "github.com/octohelm/x/types" 13 ) 14 15 var _ adapter.Dialect = (*dialect)(nil) 16 17 type dialect struct { 18 } 19 20 func (dialect) DriverName() string { 21 return "postgres" 22 } 23 24 func (c *dialect) indexName(key sqlbuilder.Key) string { 25 name := key.Name() 26 if name == "primary" { 27 name = "pkey" 28 } 29 return key.T().TableName() + "_" + name 30 } 31 32 func (c *dialect) AddIndex(key sqlbuilder.Key) sqlbuilder.SqlExpr { 33 if key.IsPrimary() { 34 e := sqlbuilder.Expr("ALTER TABLE ") 35 e.WriteExpr(key.T()) 36 e.WriteQuery(" ADD PRIMARY KEY ") 37 e.WriteGroup(func(e *sqlbuilder.Ex) { 38 e.WriteExpr(key.Columns()) 39 }) 40 e.WriteEnd() 41 return e 42 } 43 44 e := sqlbuilder.Expr("CREATE ") 45 46 if key.IsUnique() { 47 e.WriteQuery("UNIQUE ") 48 } 49 50 e.WriteQuery("INDEX ") 51 52 e.WriteQuery(c.indexName(key)) 53 54 e.WriteQuery(" ON ") 55 e.WriteExpr(key.T()) 56 57 keyDef := key.(sqlbuilder.KeyDef) 58 59 if m := strings.ToUpper(keyDef.Method()); m != "" { 60 if m == "SPATIAL" { 61 m = "GIST" 62 } 63 e.WriteQuery(" USING ") 64 e.WriteQuery(m) 65 } 66 67 e.WriteQueryByte(' ') 68 e.WriteGroup(func(e *sqlbuilder.Ex) { 69 for i, colNameAndOpt := range keyDef.ColNameAndOptions() { 70 parts := strings.Split(colNameAndOpt, "/") 71 if i != 0 { 72 _ = e.WriteByte(',') 73 } 74 e.WriteExpr(key.T().F(parts[0])) 75 if len(parts) > 1 { 76 e.WriteQuery(" ") 77 e.WriteQuery(parts[1]) 78 } 79 } 80 }) 81 82 e.WriteEnd() 83 return e 84 } 85 86 func (c *dialect) DropIndex(key sqlbuilder.Key) sqlbuilder.SqlExpr { 87 if key.IsPrimary() { 88 e := sqlbuilder.Expr("ALTER TABLE ") 89 e.WriteExpr(key.T()) 90 e.WriteQuery(" DROP CONSTRAINT ") 91 e.WriteQuery(c.indexName(key)) 92 e.WriteEnd() 93 return e 94 } 95 e := sqlbuilder.Expr("DROP ") 96 97 e.WriteQuery("INDEX IF EXISTS ") 98 e.WriteQuery(c.indexName(key)) 99 e.WriteEnd() 100 101 return e 102 } 103 104 func (c *dialect) CreateTableIsNotExists(t sqlbuilder.Table) (exprs []sqlbuilder.SqlExpr) { 105 expr := sqlbuilder.Expr("CREATE TABLE IF NOT EXISTS @table ", sqlbuilder.NamedArgSet{ 106 "table": t, 107 }) 108 109 expr.WriteGroup(func(e *sqlbuilder.Ex) { 110 cols := t.Cols() 111 112 if cols.IsNil() { 113 return 114 } 115 116 cols.RangeCol(func(col sqlbuilder.Column, idx int) bool { 117 def := col.Def() 118 119 if def.DeprecatedActions != nil { 120 return true 121 } 122 123 if idx > 0 { 124 e.WriteQueryByte(',') 125 } 126 e.WriteQueryByte('\n') 127 e.WriteQueryByte('\t') 128 129 e.WriteExpr(col) 130 e.WriteQueryByte(' ') 131 e.WriteExpr(c.DataType(def)) 132 133 return true 134 }) 135 136 t.Keys().RangeKey(func(key sqlbuilder.Key, idx int) bool { 137 if key.IsPrimary() { 138 e.WriteQueryByte(',') 139 e.WriteQueryByte('\n') 140 e.WriteQueryByte('\t') 141 e.WriteQuery("PRIMARY KEY ") 142 e.WriteGroup(func(e *sqlbuilder.Ex) { 143 e.WriteExpr(key.Columns()) 144 }) 145 } 146 return true 147 }) 148 149 expr.WriteQueryByte('\n') 150 }) 151 152 expr.WriteEnd() 153 154 exprs = append(exprs, expr) 155 156 t.Keys().RangeKey(func(key sqlbuilder.Key, idx int) bool { 157 if !key.IsPrimary() { 158 exprs = append(exprs, c.AddIndex(key)) 159 } 160 return true 161 }) 162 163 return 164 } 165 166 func (c *dialect) DropTable(t sqlbuilder.Table) sqlbuilder.SqlExpr { 167 return sqlbuilder.Expr("DROP TABLE IF EXISTS @table;", sqlbuilder.NamedArgSet{ 168 "table": t, 169 }) 170 } 171 172 func (c *dialect) TruncateTable(t sqlbuilder.Table) sqlbuilder.SqlExpr { 173 return sqlbuilder.Expr("TRUNCATE TABLE @table;", sqlbuilder.NamedArgSet{ 174 "table": t, 175 }) 176 } 177 178 func (c *dialect) AddColumn(col sqlbuilder.Column) sqlbuilder.SqlExpr { 179 return sqlbuilder.Expr("ALTER TABLE @table ADD COLUMN @col @dataType;", sqlbuilder.NamedArgSet{ 180 "table": col.T(), 181 "col": col, 182 "dataType": c.DataType(col.Def()), 183 }) 184 } 185 186 func (c *dialect) RenameColumn(col sqlbuilder.Column, target sqlbuilder.Column) sqlbuilder.SqlExpr { 187 return sqlbuilder.Expr("ALTER TABLE @table RENAME COLUMN @oldCol TO @newCol;", sqlbuilder.NamedArgSet{ 188 "table": col.T(), 189 "oldCol": col, 190 "newCol": target, 191 }) 192 } 193 194 func (c *dialect) ModifyColumn(col sqlbuilder.Column, prev sqlbuilder.Column) sqlbuilder.SqlExpr { 195 def := col.Def() 196 prevDef := prev.Def() 197 198 if def.AutoIncrement { 199 return nil 200 } 201 202 e := sqlbuilder.Expr("ALTER TABLE ") 203 e.WriteExpr(col.T()) 204 205 dbDataType := c.dataType(def.Type, def) 206 prevDbDataType := c.dataType(prevDef.Type, prevDef) 207 208 isFirstSub := true 209 isEmpty := true 210 211 prepareAppendSubCmd := func() { 212 if !isFirstSub { 213 e.WriteQueryByte(',') 214 } 215 isFirstSub = false 216 isEmpty = false 217 } 218 219 if dbDataType != prevDbDataType { 220 prepareAppendSubCmd() 221 222 e.WriteQuery(" ALTER COLUMN ") 223 e.WriteExpr(col) 224 e.WriteQuery(" TYPE ") 225 e.WriteQuery(dbDataType) 226 227 e.WriteQuery(" /* FROM ") 228 e.WriteQuery(prevDbDataType) 229 e.WriteQuery(" */") 230 } 231 232 if def.Null != prevDef.Null { 233 prepareAppendSubCmd() 234 235 e.WriteQuery(" ALTER COLUMN ") 236 e.WriteExpr(col) 237 if !def.Null { 238 e.WriteQuery(" SET NOT NULL") 239 } else { 240 e.WriteQuery(" DROP NOT NULL") 241 } 242 } 243 244 defaultValue := normalizeDefaultValue(def.Default, dbDataType) 245 prevDefaultValue := normalizeDefaultValue(prevDef.Default, prevDbDataType) 246 247 if defaultValue != prevDefaultValue { 248 prepareAppendSubCmd() 249 250 e.WriteQuery(" ALTER COLUMN ") 251 e.WriteExpr(col) 252 if def.Default != nil { 253 e.WriteQuery(" SET DEFAULT ") 254 e.WriteQuery(defaultValue) 255 256 e.WriteQuery(" /* FROM ") 257 e.WriteQuery(prevDefaultValue) 258 e.WriteQuery(" */") 259 } else { 260 e.WriteQuery(" DROP DEFAULT") 261 } 262 } 263 264 if isEmpty { 265 return nil 266 } 267 268 e.WriteEnd() 269 270 return e 271 } 272 273 func (c *dialect) DropColumn(col sqlbuilder.Column) sqlbuilder.SqlExpr { 274 return sqlbuilder.Expr("ALTER TABLE @table DROP COLUMN @col;", sqlbuilder.NamedArgSet{ 275 "table": col.T(), 276 "col": col, 277 }) 278 } 279 280 func (c *dialect) DataType(columnType sqlbuilder.ColumnDef) sqlbuilder.SqlExpr { 281 dbDataType := dealias(c.dbDataType(columnType.Type, columnType)) 282 return sqlbuilder.Expr(dbDataType + autocompleteSize(dbDataType, columnType) + c.dataTypeModify(columnType, dbDataType)) 283 } 284 285 func (c *dialect) dataType(typ typex.Type, columnType sqlbuilder.ColumnDef) string { 286 dbDataType := dealias(c.dbDataType(typ, columnType)) 287 return dbDataType + autocompleteSize(dbDataType, columnType) 288 } 289 290 func (c *dialect) dbDataType(typ typex.Type, columnType sqlbuilder.ColumnDef) string { 291 if columnType.DataType != "" { 292 return columnType.DataType 293 } 294 295 if rv, ok := typex.TryNew(typ); ok { 296 if dtd, ok := rv.Interface().(sqlbuilder.DataTypeDescriber); ok { 297 return dtd.DataType(c.DriverName()) 298 } 299 } 300 301 switch typ.Kind() { 302 case reflect.Ptr: 303 return c.dataType(typ.Elem(), columnType) 304 case reflect.Bool: 305 return "boolean" 306 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: 307 if columnType.AutoIncrement { 308 return "serial" 309 } 310 return "integer" 311 case reflect.Int64, reflect.Uint64: 312 if columnType.AutoIncrement { 313 return "bigserial" 314 } 315 return "bigint" 316 case reflect.Float64: 317 return "double precision" 318 case reflect.Float32: 319 return "real" 320 case reflect.Slice: 321 if typ.Elem().Kind() == reflect.Uint8 { 322 return "bytea" 323 } 324 case reflect.String: 325 size := columnType.Length 326 if size < 65535/3 { 327 return "varchar" 328 } 329 return "text" 330 } 331 332 switch typ.Name() { 333 case "Hstore": 334 return "hstore" 335 case "NullInt64": 336 return "bigint" 337 case "NullFloat64": 338 return "double precision" 339 case "NullBool": 340 return "boolean" 341 case "Time", "NullTime": 342 return "timestamp with time zone" 343 } 344 345 panic(fmt.Errorf("unsupport type %s", typ)) 346 } 347 348 func (c *dialect) dataTypeModify(columnType sqlbuilder.ColumnDef, dataType string) string { 349 buf := bytes.NewBuffer(nil) 350 351 if !columnType.Null { 352 buf.WriteString(" NOT NULL") 353 } 354 355 if columnType.Default != nil { 356 buf.WriteString(" DEFAULT ") 357 buf.WriteString(normalizeDefaultValue(columnType.Default, dataType)) 358 } 359 360 return buf.String() 361 } 362 363 func normalizeDefaultValue(defaultValue *string, dataType string) string { 364 if defaultValue == nil { 365 return "" 366 } 367 368 dv := *defaultValue 369 370 if dv[0] == '\'' { 371 if strings.Contains(dv, "'::") { 372 return dv 373 } 374 return dv + "::" + dataType 375 } 376 377 _, err := strconv.ParseFloat(dv, 64) 378 if err == nil { 379 return "'" + dv + "'::" + dataType 380 } 381 382 return dv 383 } 384 385 func autocompleteSize(dataType string, columnType sqlbuilder.ColumnDef) string { 386 switch dataType { 387 case "character varying", "character": 388 size := columnType.Length 389 if size == 0 { 390 size = 255 391 } 392 return sizeModifier(size, columnType.Decimal) 393 case "decimal", "numeric", "real", "double precision": 394 if columnType.Length > 0 { 395 return sizeModifier(columnType.Length, columnType.Decimal) 396 } 397 } 398 return "" 399 } 400 401 func dealias(dataType string) string { 402 switch dataType { 403 case "varchar": 404 return "character varying" 405 case "timestamp": 406 return "timestamp without time zone" 407 } 408 return dataType 409 } 410 411 func sizeModifier(length uint64, decimal uint64) string { 412 if length > 0 { 413 size := strconv.FormatUint(length, 10) 414 if decimal > 0 { 415 return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")" 416 } 417 return "(" + size + ")" 418 } 419 return "" 420 }