github.com/kunlun-qilian/sqlx/v2@v2.24.0/connectors/postgresql/postgresql_connector.go (about) 1 package postgresql 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql/driver" 7 "fmt" 8 "io" 9 "reflect" 10 "strconv" 11 "strings" 12 13 typex "github.com/go-courier/x/types" 14 15 "github.com/kunlun-qilian/sqlx/v2" 16 "github.com/kunlun-qilian/sqlx/v2/builder" 17 "github.com/kunlun-qilian/sqlx/v2/migration" 18 "github.com/lib/pq" 19 ) 20 21 var _ interface { 22 driver.Connector 23 builder.Dialect 24 } = (*PostgreSQLConnector)(nil) 25 26 type PostgreSQLConnector struct { 27 Host string 28 DBName string 29 Extra string 30 Extensions []string 31 } 32 33 func (c *PostgreSQLConnector) Connect(ctx context.Context) (driver.Conn, error) { 34 d := c.Driver() 35 36 conn, err := d.Open(dsn(c.Host, c.DBName, c.Extra)) 37 if err != nil { 38 if c.IsErrorUnknownDatabase(err) { 39 connectForCreateDB, err := d.Open(dsn(c.Host, "", c.Extra)) 40 if err != nil { 41 return nil, err 42 } 43 if _, err := connectForCreateDB.(driver.ExecerContext).ExecContext(context.Background(), builder.ResolveExpr(c.CreateDatabase(c.DBName)).Query(), nil); err != nil { 44 return nil, err 45 } 46 if err := connectForCreateDB.Close(); err != nil { 47 return nil, err 48 } 49 return c.Connect(ctx) 50 } 51 return nil, err 52 } 53 for _, ex := range c.Extensions { 54 if _, err := conn.(driver.ExecerContext).ExecContext(context.Background(), "CREATE EXTENSION IF NOT EXISTS "+ex+";", nil); err != nil { 55 return nil, err 56 } 57 } 58 59 return conn, nil 60 } 61 62 func (PostgreSQLConnector) Driver() driver.Driver { 63 return &PostgreSQLLoggingDriver{} 64 } 65 66 func dsn(host string, dbName string, extra string) string { 67 if extra != "" { 68 extra = "?" + extra 69 } 70 return host + "/" + dbName + extra 71 } 72 73 func (c PostgreSQLConnector) WithDBName(dbName string) driver.Connector { 74 c.DBName = dbName 75 return &c 76 } 77 78 func (c *PostgreSQLConnector) Migrate(ctx context.Context, db sqlx.DBExecutor) error { 79 output := migration.MigrationOutputFromContext(ctx) 80 81 prevDB, err := dbFromInformationSchema(db) 82 if err != nil { 83 return err 84 } 85 86 d := db.D() 87 dialect := db.Dialect() 88 89 exec := func(expr builder.SqlExpr) error { 90 if expr == nil || expr.IsNil() { 91 return nil 92 } 93 94 if output != nil { 95 _, _ = io.WriteString(output, builder.ResolveExpr(expr).Query()) 96 _, _ = io.WriteString(output, "\n") 97 return nil 98 } 99 100 _, err := db.ExecExpr(expr) 101 return err 102 } 103 104 if prevDB == nil { 105 prevDB = &sqlx.Database{ 106 Name: d.Name, 107 } 108 if err := exec(dialect.CreateDatabase(d.Name)); err != nil { 109 return err 110 } 111 } 112 113 if d.Schema != "" { 114 if err := exec(dialect.CreateSchema(d.Schema)); err != nil { 115 return err 116 } 117 prevDB = prevDB.WithSchema(d.Schema) 118 } 119 120 for _, name := range d.Tables.TableNames() { 121 table := d.Table(name) 122 123 prevTable := prevDB.Table(name) 124 125 if prevTable == nil { 126 for _, expr := range dialect.CreateTableIsNotExists(table) { 127 if err := exec(expr); err != nil { 128 return err 129 } 130 } 131 continue 132 } 133 134 exprList := table.Diff(prevTable, dialect) 135 136 for _, expr := range exprList { 137 if err := exec(expr); err != nil { 138 return err 139 } 140 } 141 } 142 143 return nil 144 } 145 146 func (PostgreSQLConnector) DriverName() string { 147 return "postgres" 148 } 149 150 func (PostgreSQLConnector) PrimaryKeyName() string { 151 return "pkey" 152 } 153 154 func (PostgreSQLConnector) IsErrorUnknownDatabase(err error) bool { 155 if e, ok := sqlx.UnwrapAll(err).(*pq.Error); ok && e.Code == "3D000" { 156 return true 157 } 158 return false 159 } 160 161 func (PostgreSQLConnector) IsErrorConflict(err error) bool { 162 if e, ok := sqlx.UnwrapAll(err).(*pq.Error); ok && e.Code == "23505" { 163 return true 164 } 165 return false 166 } 167 168 func (c *PostgreSQLConnector) CreateDatabase(dbName string) builder.SqlExpr { 169 e := builder.Expr("CREATE DATABASE ") 170 e.WriteQuery(dbName) 171 e.WriteEnd() 172 return e 173 } 174 175 func (c *PostgreSQLConnector) CreateSchema(schema string) builder.SqlExpr { 176 e := builder.Expr("CREATE SCHEMA IF NOT EXISTS ") 177 e.WriteQuery(schema) 178 e.WriteEnd() 179 return e 180 } 181 182 func (c *PostgreSQLConnector) DropDatabase(dbName string) builder.SqlExpr { 183 e := builder.Expr("DROP DATABASE IF EXISTS ") 184 e.WriteQuery(dbName) 185 e.WriteEnd() 186 return e 187 } 188 189 func (c *PostgreSQLConnector) AddIndex(key *builder.Key) builder.SqlExpr { 190 if key.IsPrimary() { 191 e := builder.Expr("ALTER TABLE ") 192 e.WriteExpr(key.Table) 193 e.WriteQuery(" ADD PRIMARY KEY ") 194 e.WriteExpr(key.Def.TableExpr(key.Table)) 195 e.WriteEnd() 196 return e 197 } 198 199 e := builder.Expr("CREATE ") 200 if key.IsUnique { 201 e.WriteQuery("UNIQUE ") 202 } 203 e.WriteQuery("INDEX ") 204 205 e.WriteQuery(key.Table.Name) 206 e.WriteQuery("_") 207 e.WriteQuery(key.Name) 208 209 e.WriteQuery(" ON ") 210 e.WriteExpr(key.Table) 211 212 if m := strings.ToUpper(key.Method); m != "" { 213 if m == "SPATIAL" { 214 m = "GIST" 215 } 216 e.WriteQuery(" USING ") 217 e.WriteQuery(m) 218 } 219 220 e.WriteQueryByte(' ') 221 e.WriteExpr(key.Def.TableExpr(key.Table)) 222 223 e.WriteEnd() 224 return e 225 } 226 227 func (c *PostgreSQLConnector) DropIndex(key *builder.Key) builder.SqlExpr { 228 if key.IsPrimary() { 229 e := builder.Expr("ALTER TABLE ") 230 e.WriteExpr(key.Table) 231 e.WriteQuery(" DROP CONSTRAINT ") 232 e.WriteExpr(key.Table) 233 e.WriteQuery("_pkey") 234 e.WriteEnd() 235 return e 236 } 237 e := builder.Expr("DROP ") 238 239 e.WriteQuery("INDEX IF EXISTS ") 240 e.WriteExpr(key.Table) 241 e.WriteQueryByte('_') 242 e.WriteQuery(key.Name) 243 e.WriteEnd() 244 245 return e 246 } 247 248 func (c *PostgreSQLConnector) CreateTableIsNotExists(t *builder.Table) (exprs []builder.SqlExpr) { 249 expr := builder.Expr("CREATE TABLE IF NOT EXISTS ") 250 expr.WriteExpr(t) 251 expr.WriteQueryByte(' ') 252 expr.WriteGroup(func(e *builder.Ex) { 253 if t.Columns.IsNil() { 254 return 255 } 256 257 t.Columns.Range(func(col *builder.Column, idx int) { 258 if col.DeprecatedActions != nil { 259 return 260 } 261 262 if idx > 0 { 263 e.WriteQueryByte(',') 264 } 265 e.WriteQueryByte('\n') 266 e.WriteQueryByte('\t') 267 268 e.WriteExpr(col) 269 e.WriteQueryByte(' ') 270 e.WriteExpr(c.DataType(col.ColumnType)) 271 }) 272 273 t.Keys.Range(func(key *builder.Key, idx int) { 274 if key.IsPrimary() { 275 e.WriteQueryByte(',') 276 e.WriteQueryByte('\n') 277 e.WriteQueryByte('\t') 278 e.WriteQuery("PRIMARY KEY ") 279 e.WriteExpr(key.Def.TableExpr(key.Table)) 280 } 281 }) 282 283 expr.WriteQueryByte('\n') 284 }) 285 286 t.Keys.Range(func(key *builder.Key, idx int) { 287 if key.IsPartition() { 288 expr.WriteQuery(" PARTITION BY ") 289 expr.WriteQuery(key.Method) 290 expr.WriteQueryByte(' ') 291 expr.WriteExpr(key.Def.TableExpr(key.Table)) 292 } 293 }) 294 295 expr.WriteEnd() 296 exprs = append(exprs, expr) 297 298 t.Keys.Range(func(key *builder.Key, idx int) { 299 if !key.IsPrimary() && !key.IsPartition() { 300 exprs = append(exprs, c.AddIndex(key)) 301 } 302 }) 303 304 return 305 } 306 307 func (c *PostgreSQLConnector) DropTable(t *builder.Table) builder.SqlExpr { 308 e := builder.Expr("DROP TABLE IF EXISTS ") 309 e.WriteExpr(t) 310 e.WriteEnd() 311 return e 312 } 313 314 func (c *PostgreSQLConnector) TruncateTable(t *builder.Table) builder.SqlExpr { 315 e := builder.Expr("TRUNCATE TABLE ") 316 e.WriteExpr(t) 317 e.WriteEnd() 318 return e 319 } 320 321 func (c *PostgreSQLConnector) AddColumn(col *builder.Column) builder.SqlExpr { 322 e := builder.Expr("ALTER TABLE ") 323 e.WriteExpr(col.Table) 324 e.WriteQuery(" ADD COLUMN ") 325 e.WriteExpr(col) 326 e.WriteQueryByte(' ') 327 e.WriteExpr(c.DataType(col.ColumnType)) 328 e.WriteEnd() 329 return e 330 } 331 332 func (c *PostgreSQLConnector) RenameColumn(col *builder.Column, target *builder.Column) builder.SqlExpr { 333 e := builder.Expr("ALTER TABLE ") 334 e.WriteExpr(col.Table) 335 e.WriteQuery(" RENAME COLUMN ") 336 e.WriteExpr(col) 337 e.WriteQuery(" TO ") 338 e.WriteExpr(target) 339 e.WriteEnd() 340 return e 341 } 342 343 func (c *PostgreSQLConnector) ModifyColumn(col *builder.Column, prev *builder.Column) builder.SqlExpr { 344 if col.AutoIncrement { 345 return nil 346 } 347 348 e := builder.Expr("ALTER TABLE ") 349 e.WriteExpr(col.Table) 350 351 dbDataType := c.dataType(col.ColumnType.Type, col.ColumnType) 352 prevDbDataType := c.dataType(prev.ColumnType.Type, prev.ColumnType) 353 354 isFirstSub := true 355 isEmpty := true 356 357 prepareAppendSubCmd := func() { 358 if !isFirstSub { 359 e.WriteQueryByte(',') 360 } 361 isFirstSub = false 362 isEmpty = false 363 } 364 365 if dbDataType != prevDbDataType { 366 prepareAppendSubCmd() 367 368 e.WriteQuery(" ALTER COLUMN ") 369 e.WriteExpr(col) 370 e.WriteQuery(" TYPE ") 371 e.WriteQuery(dbDataType) 372 373 e.WriteQuery(" /* FROM ") 374 e.WriteQuery(prevDbDataType) 375 e.WriteQuery(" */") 376 } 377 378 if col.Null != prev.Null { 379 prepareAppendSubCmd() 380 381 e.WriteQuery(" ALTER COLUMN ") 382 e.WriteExpr(col) 383 if !col.Null { 384 e.WriteQuery(" SET NOT NULL") 385 } else { 386 e.WriteQuery(" DROP NOT NULL") 387 } 388 } 389 390 defaultValue := normalizeDefaultValue(col.Default, dbDataType) 391 prevDefaultValue := normalizeDefaultValue(prev.Default, prevDbDataType) 392 393 if defaultValue != prevDefaultValue { 394 prepareAppendSubCmd() 395 396 e.WriteQuery(" ALTER COLUMN ") 397 e.WriteExpr(col) 398 if col.Default != nil { 399 e.WriteQuery(" SET DEFAULT ") 400 e.WriteQuery(defaultValue) 401 402 e.WriteQuery(" /* FROM ") 403 e.WriteQuery(prevDefaultValue) 404 e.WriteQuery(" */") 405 } else { 406 e.WriteQuery(" DROP DEFAULT") 407 } 408 } 409 410 if isEmpty { 411 return nil 412 } 413 414 e.WriteEnd() 415 416 return e 417 } 418 419 func (c *PostgreSQLConnector) DropColumn(col *builder.Column) builder.SqlExpr { 420 e := builder.Expr("ALTER TABLE ") 421 e.WriteExpr(col.Table) 422 e.WriteQuery(" DROP COLUMN ") 423 e.WriteQuery(col.Name) 424 e.WriteEnd() 425 return e 426 } 427 428 func (c *PostgreSQLConnector) DataType(columnType *builder.ColumnType) builder.SqlExpr { 429 dbDataType := dealias(c.dbDataType(columnType.Type, columnType)) 430 return builder.Expr(dbDataType + autocompleteSize(dbDataType, columnType) + c.dataTypeModify(columnType, dbDataType)) 431 } 432 433 func (c *PostgreSQLConnector) dataType(typ typex.Type, columnType *builder.ColumnType) string { 434 dbDataType := dealias(c.dbDataType(columnType.Type, columnType)) 435 return dbDataType + autocompleteSize(dbDataType, columnType) 436 } 437 438 func (c *PostgreSQLConnector) dbDataType(typ typex.Type, columnType *builder.ColumnType) string { 439 if columnType.DataType != "" { 440 return columnType.DataType 441 } 442 443 if rv, ok := typex.TryNew(typ); ok { 444 if dtd, ok := rv.Interface().(builder.DataTypeDescriber); ok { 445 return dtd.DataType(c.DriverName()) 446 } 447 } 448 449 switch typ.Kind() { 450 case reflect.Ptr: 451 return c.dataType(typ.Elem(), columnType) 452 case reflect.Bool: 453 return "boolean" 454 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: 455 if columnType.AutoIncrement { 456 return "serial" 457 } 458 return "integer" 459 case reflect.Int64, reflect.Uint64: 460 if columnType.AutoIncrement { 461 return "bigserial" 462 } 463 return "bigint" 464 case reflect.Float64: 465 return "double precision" 466 case reflect.Float32: 467 return "real" 468 case reflect.Slice: 469 if typ.Elem().Kind() == reflect.Uint8 { 470 return "bytea" 471 } 472 case reflect.String: 473 size := columnType.Length 474 if size < 65535/3 { 475 return "varchar" 476 } 477 return "text" 478 } 479 480 switch typ.Name() { 481 case "Hstore": 482 return "hstore" 483 case "ByteaArray": 484 return c.dataType(typex.FromRType(reflect.TypeOf(pq.ByteaArray{[]byte("")}[0])), columnType) + "[]" 485 case "BoolArray": 486 return c.dataType(typex.FromRType(reflect.TypeOf(pq.BoolArray{true}[0])), columnType) + "[]" 487 case "Float64Array": 488 return c.dataType(typex.FromRType(reflect.TypeOf(pq.Float64Array{0}[0])), columnType) + "[]" 489 case "Int64Array": 490 return c.dataType(typex.FromRType(reflect.TypeOf(pq.Int64Array{0}[0])), columnType) + "[]" 491 case "StringArray": 492 return c.dataType(typex.FromRType(reflect.TypeOf(pq.StringArray{""}[0])), columnType) + "[]" 493 case "NullInt64": 494 return "bigint" 495 case "NullFloat64": 496 return "double precision" 497 case "NullBool": 498 return "boolean" 499 case "Time", "NullTime": 500 return "timestamp with time zone" 501 } 502 503 panic(fmt.Errorf("unsupport type %s", typ)) 504 } 505 506 func (c *PostgreSQLConnector) dataTypeModify(columnType *builder.ColumnType, dataType string) string { 507 buf := bytes.NewBuffer(nil) 508 509 if !columnType.Null { 510 buf.WriteString(" NOT NULL") 511 } 512 513 if columnType.Default != nil { 514 buf.WriteString(" DEFAULT ") 515 buf.WriteString(normalizeDefaultValue(columnType.Default, dataType)) 516 } 517 518 return buf.String() 519 } 520 521 func normalizeDefaultValue(defaultValue *string, dataType string) string { 522 if defaultValue == nil { 523 return "" 524 } 525 526 dv := *defaultValue 527 528 if dv[0] == '\'' { 529 if strings.Contains(dv, "'::") { 530 return dv 531 } 532 return dv + "::" + dataType 533 } 534 535 _, err := strconv.ParseFloat(dv, 64) 536 if err == nil { 537 return "'" + dv + "'::" + dataType 538 } 539 540 return dv 541 } 542 543 func autocompleteSize(dataType string, columnType *builder.ColumnType) string { 544 switch dataType { 545 case "character varying", "character": 546 size := columnType.Length 547 if size == 0 { 548 size = 255 549 } 550 return sizeModifier(size, columnType.Decimal) 551 case "decimal", "numeric", "real", "double precision": 552 if columnType.Length > 0 { 553 return sizeModifier(columnType.Length, columnType.Decimal) 554 } 555 } 556 return "" 557 } 558 559 func dealias(dataType string) string { 560 switch dataType { 561 case "varchar": 562 return "character varying" 563 case "timestamp": 564 return "timestamp without time zone" 565 } 566 return dataType 567 } 568 569 func sizeModifier(length uint64, decimal uint64) string { 570 if length > 0 { 571 size := strconv.FormatUint(length, 10) 572 if decimal > 0 { 573 return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")" 574 } 575 return "(" + size + ")" 576 } 577 return "" 578 }