github.com/go-courier/sqlx/v2@v2.23.13/connectors/postgresql/postgresql_connector.go (about) 1 package postgresql 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql/driver" 7 "fmt" 8 "io" 9 "reflect" 10 "strconv" 11 "strings" 12 13 typex "github.com/go-courier/x/types" 14 15 "github.com/go-courier/sqlx/v2" 16 "github.com/go-courier/sqlx/v2/builder" 17 "github.com/go-courier/sqlx/v2/migration" 18 "github.com/lib/pq" 19 ) 20 21 var _ interface { 22 driver.Connector 23 builder.Dialect 24 } = (*PostgreSQLConnector)(nil) 25 26 type PostgreSQLConnector struct { 27 Host string 28 DBName string 29 Extra string 30 Extensions []string 31 } 32 33 func (c *PostgreSQLConnector) Connect(ctx context.Context) (driver.Conn, error) { 34 d := c.Driver() 35 36 conn, err := d.Open(dsn(c.Host, c.DBName, c.Extra)) 37 if err != nil { 38 if c.IsErrorUnknownDatabase(err) { 39 connectForCreateDB, err := d.Open(dsn(c.Host, "", c.Extra)) 40 if err != nil { 41 return nil, err 42 } 43 if _, err := connectForCreateDB.(driver.ExecerContext).ExecContext(context.Background(), builder.ResolveExpr(c.CreateDatabase(c.DBName)).Query(), nil); err != nil { 44 return nil, err 45 } 46 if err := connectForCreateDB.Close(); err != nil { 47 return nil, err 48 } 49 return c.Connect(ctx) 50 } 51 return nil, err 52 } 53 for _, ex := range c.Extensions { 54 if _, err := conn.(driver.ExecerContext).ExecContext(context.Background(), "CREATE EXTENSION IF NOT EXISTS "+ex+";", nil); err != nil { 55 return nil, err 56 } 57 } 58 59 return conn, nil 60 } 61 62 func (PostgreSQLConnector) Driver() driver.Driver { 63 return &PostgreSQLLoggingDriver{} 64 } 65 66 func dsn(host string, dbName string, extra string) string { 67 if extra != "" { 68 extra = "?" + extra 69 } 70 return host + "/" + dbName + extra 71 } 72 73 func (c PostgreSQLConnector) WithDBName(dbName string) driver.Connector { 74 c.DBName = dbName 75 return &c 76 } 77 78 func (c *PostgreSQLConnector) Migrate(ctx context.Context, db sqlx.DBExecutor) error { 79 output := migration.MigrationOutputFromContext(ctx) 80 81 prevDB, err := dbFromInformationSchema(db) 82 if err != nil { 83 return err 84 } 85 86 d := db.D() 87 dialect := db.Dialect() 88 89 exec := func(expr builder.SqlExpr) error { 90 if expr == nil || expr.IsNil() { 91 return nil 92 } 93 94 if output != nil { 95 _, _ = io.WriteString(output, builder.ResolveExpr(expr).Query()) 96 _, _ = io.WriteString(output, "\n") 97 return nil 98 } 99 100 _, err := db.ExecExpr(expr) 101 return err 102 } 103 104 if prevDB == nil { 105 prevDB = &sqlx.Database{ 106 Name: d.Name, 107 } 108 if err := exec(dialect.CreateDatabase(d.Name)); err != nil { 109 return err 110 } 111 } 112 113 if d.Schema != "" { 114 if err := exec(dialect.CreateSchema(d.Schema)); err != nil { 115 return err 116 } 117 prevDB = prevDB.WithSchema(d.Schema) 118 } 119 120 for _, name := range d.Tables.TableNames() { 121 table := d.Table(name) 122 123 prevTable := prevDB.Table(name) 124 125 if prevTable == nil { 126 for _, expr := range dialect.CreateTableIsNotExists(table) { 127 if err := exec(expr); err != nil { 128 return err 129 } 130 } 131 continue 132 } 133 134 exprList := table.Diff(prevTable, dialect) 135 136 for _, expr := range exprList { 137 if err := exec(expr); err != nil { 138 return err 139 } 140 } 141 } 142 143 return nil 144 } 145 146 func (PostgreSQLConnector) DriverName() string { 147 return "postgres" 148 } 149 150 func (PostgreSQLConnector) PrimaryKeyName() string { 151 return "pkey" 152 } 153 154 func (PostgreSQLConnector) IsErrorUnknownDatabase(err error) bool { 155 if e, ok := sqlx.UnwrapAll(err).(*pq.Error); ok && e.Code == "3D000" { 156 return true 157 } 158 return false 159 } 160 161 func (PostgreSQLConnector) IsErrorConflict(err error) bool { 162 if e, ok := sqlx.UnwrapAll(err).(*pq.Error); ok && e.Code == "23505" { 163 return true 164 } 165 return false 166 } 167 168 func (c *PostgreSQLConnector) CreateDatabase(dbName string) builder.SqlExpr { 169 e := builder.Expr("CREATE DATABASE ") 170 e.WriteQuery(dbName) 171 e.WriteEnd() 172 return e 173 } 174 175 func (c *PostgreSQLConnector) CreateSchema(schema string) builder.SqlExpr { 176 e := builder.Expr("CREATE SCHEMA IF NOT EXISTS ") 177 e.WriteQuery(schema) 178 e.WriteEnd() 179 return e 180 } 181 182 func (c *PostgreSQLConnector) DropDatabase(dbName string) builder.SqlExpr { 183 e := builder.Expr("DROP DATABASE IF EXISTS ") 184 e.WriteQuery(dbName) 185 e.WriteEnd() 186 return e 187 } 188 189 func (c *PostgreSQLConnector) AddIndex(key *builder.Key) builder.SqlExpr { 190 if key.IsPrimary() { 191 e := builder.Expr("ALTER TABLE ") 192 e.WriteExpr(key.Table) 193 e.WriteQuery(" ADD PRIMARY KEY ") 194 e.WriteExpr(key.Def.TableExpr(key.Table)) 195 e.WriteEnd() 196 return e 197 } 198 199 e := builder.Expr("CREATE ") 200 if key.IsUnique { 201 e.WriteQuery("UNIQUE ") 202 } 203 e.WriteQuery("INDEX ") 204 205 e.WriteQuery(key.Table.Name) 206 e.WriteQuery("_") 207 e.WriteQuery(key.Name) 208 209 e.WriteQuery(" ON ") 210 e.WriteExpr(key.Table) 211 212 if m := strings.ToUpper(key.Method); m != "" { 213 if m == "SPATIAL" { 214 m = "GIST" 215 } 216 e.WriteQuery(" USING ") 217 e.WriteQuery(m) 218 } 219 220 e.WriteQueryByte(' ') 221 e.WriteExpr(key.Def.TableExpr(key.Table)) 222 223 e.WriteEnd() 224 return e 225 } 226 227 func (c *PostgreSQLConnector) DropIndex(key *builder.Key) builder.SqlExpr { 228 if key.IsPrimary() { 229 e := builder.Expr("ALTER TABLE ") 230 e.WriteExpr(key.Table) 231 e.WriteQuery(" DROP CONSTRAINT ") 232 e.WriteExpr(key.Table) 233 e.WriteQuery("_pkey") 234 e.WriteEnd() 235 return e 236 } 237 e := builder.Expr("DROP ") 238 239 e.WriteQuery("INDEX IF EXISTS ") 240 e.WriteExpr(key.Table) 241 e.WriteQueryByte('_') 242 e.WriteQuery(key.Name) 243 e.WriteEnd() 244 245 return e 246 } 247 248 func (c *PostgreSQLConnector) CreateTableIsNotExists(t *builder.Table) (exprs []builder.SqlExpr) { 249 expr := builder.Expr("CREATE TABLE IF NOT EXISTS ") 250 expr.WriteExpr(t) 251 expr.WriteQueryByte(' ') 252 expr.WriteGroup(func(e *builder.Ex) { 253 if t.Columns.IsNil() { 254 return 255 } 256 257 t.Columns.Range(func(col *builder.Column, idx int) { 258 if col.DeprecatedActions != nil { 259 return 260 } 261 262 if idx > 0 { 263 e.WriteQueryByte(',') 264 } 265 e.WriteQueryByte('\n') 266 e.WriteQueryByte('\t') 267 268 e.WriteExpr(col) 269 e.WriteQueryByte(' ') 270 e.WriteExpr(c.DataType(col.ColumnType)) 271 }) 272 273 t.Keys.Range(func(key *builder.Key, idx int) { 274 if key.IsPrimary() { 275 e.WriteQueryByte(',') 276 e.WriteQueryByte('\n') 277 e.WriteQueryByte('\t') 278 e.WriteQuery("PRIMARY KEY ") 279 e.WriteExpr(key.Def.TableExpr(key.Table)) 280 } 281 }) 282 283 expr.WriteQueryByte('\n') 284 }) 285 286 expr.WriteEnd() 287 exprs = append(exprs, expr) 288 289 t.Keys.Range(func(key *builder.Key, idx int) { 290 if !key.IsPrimary() { 291 exprs = append(exprs, c.AddIndex(key)) 292 } 293 }) 294 295 return 296 } 297 298 func (c *PostgreSQLConnector) DropTable(t *builder.Table) builder.SqlExpr { 299 e := builder.Expr("DROP TABLE IF EXISTS ") 300 e.WriteExpr(t) 301 e.WriteEnd() 302 return e 303 } 304 305 func (c *PostgreSQLConnector) TruncateTable(t *builder.Table) builder.SqlExpr { 306 e := builder.Expr("TRUNCATE TABLE ") 307 e.WriteExpr(t) 308 e.WriteEnd() 309 return e 310 } 311 312 func (c *PostgreSQLConnector) AddColumn(col *builder.Column) builder.SqlExpr { 313 e := builder.Expr("ALTER TABLE ") 314 e.WriteExpr(col.Table) 315 e.WriteQuery(" ADD COLUMN ") 316 e.WriteExpr(col) 317 e.WriteQueryByte(' ') 318 e.WriteExpr(c.DataType(col.ColumnType)) 319 e.WriteEnd() 320 return e 321 } 322 323 func (c *PostgreSQLConnector) RenameColumn(col *builder.Column, target *builder.Column) builder.SqlExpr { 324 e := builder.Expr("ALTER TABLE ") 325 e.WriteExpr(col.Table) 326 e.WriteQuery(" RENAME COLUMN ") 327 e.WriteExpr(col) 328 e.WriteQuery(" TO ") 329 e.WriteExpr(target) 330 e.WriteEnd() 331 return e 332 } 333 334 func (c *PostgreSQLConnector) ModifyColumn(col *builder.Column, prev *builder.Column) builder.SqlExpr { 335 if col.AutoIncrement { 336 return nil 337 } 338 339 e := builder.Expr("ALTER TABLE ") 340 e.WriteExpr(col.Table) 341 342 dbDataType := c.dataType(col.ColumnType.Type, col.ColumnType) 343 prevDbDataType := c.dataType(prev.ColumnType.Type, prev.ColumnType) 344 345 isFirstSub := true 346 isEmpty := true 347 348 prepareAppendSubCmd := func() { 349 if !isFirstSub { 350 e.WriteQueryByte(',') 351 } 352 isFirstSub = false 353 isEmpty = false 354 } 355 356 if dbDataType != prevDbDataType { 357 prepareAppendSubCmd() 358 359 e.WriteQuery(" ALTER COLUMN ") 360 e.WriteExpr(col) 361 e.WriteQuery(" TYPE ") 362 e.WriteQuery(dbDataType) 363 364 e.WriteQuery(" /* FROM ") 365 e.WriteQuery(prevDbDataType) 366 e.WriteQuery(" */") 367 } 368 369 if col.Null != prev.Null { 370 prepareAppendSubCmd() 371 372 e.WriteQuery(" ALTER COLUMN ") 373 e.WriteExpr(col) 374 if !col.Null { 375 e.WriteQuery(" SET NOT NULL") 376 } else { 377 e.WriteQuery(" DROP NOT NULL") 378 } 379 } 380 381 defaultValue := normalizeDefaultValue(col.Default, dbDataType) 382 prevDefaultValue := normalizeDefaultValue(prev.Default, prevDbDataType) 383 384 if defaultValue != prevDefaultValue { 385 prepareAppendSubCmd() 386 387 e.WriteQuery(" ALTER COLUMN ") 388 e.WriteExpr(col) 389 if col.Default != nil { 390 e.WriteQuery(" SET DEFAULT ") 391 e.WriteQuery(defaultValue) 392 393 e.WriteQuery(" /* FROM ") 394 e.WriteQuery(prevDefaultValue) 395 e.WriteQuery(" */") 396 } else { 397 e.WriteQuery(" DROP DEFAULT") 398 } 399 } 400 401 if isEmpty { 402 return nil 403 } 404 405 e.WriteEnd() 406 407 return e 408 } 409 410 func (c *PostgreSQLConnector) DropColumn(col *builder.Column) builder.SqlExpr { 411 e := builder.Expr("ALTER TABLE ") 412 e.WriteExpr(col.Table) 413 e.WriteQuery(" DROP COLUMN ") 414 e.WriteQuery(col.Name) 415 e.WriteEnd() 416 return e 417 } 418 419 func (c *PostgreSQLConnector) DataType(columnType *builder.ColumnType) builder.SqlExpr { 420 dbDataType := dealias(c.dbDataType(columnType.Type, columnType)) 421 return builder.Expr(dbDataType + autocompleteSize(dbDataType, columnType) + c.dataTypeModify(columnType, dbDataType)) 422 } 423 424 func (c *PostgreSQLConnector) dataType(typ typex.Type, columnType *builder.ColumnType) string { 425 dbDataType := dealias(c.dbDataType(columnType.Type, columnType)) 426 return dbDataType + autocompleteSize(dbDataType, columnType) 427 } 428 429 func (c *PostgreSQLConnector) dbDataType(typ typex.Type, columnType *builder.ColumnType) string { 430 if columnType.DataType != "" { 431 return columnType.DataType 432 } 433 434 if rv, ok := typex.TryNew(typ); ok { 435 if dtd, ok := rv.Interface().(builder.DataTypeDescriber); ok { 436 return dtd.DataType(c.DriverName()) 437 } 438 } 439 440 switch typ.Kind() { 441 case reflect.Ptr: 442 return c.dataType(typ.Elem(), columnType) 443 case reflect.Bool: 444 return "boolean" 445 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: 446 if columnType.AutoIncrement { 447 return "serial" 448 } 449 return "integer" 450 case reflect.Int64, reflect.Uint64: 451 if columnType.AutoIncrement { 452 return "bigserial" 453 } 454 return "bigint" 455 case reflect.Float64: 456 return "double precision" 457 case reflect.Float32: 458 return "real" 459 case reflect.Slice: 460 if typ.Elem().Kind() == reflect.Uint8 { 461 return "bytea" 462 } 463 case reflect.String: 464 size := columnType.Length 465 if size < 65535/3 { 466 return "varchar" 467 } 468 return "text" 469 } 470 471 switch typ.Name() { 472 case "Hstore": 473 return "hstore" 474 case "ByteaArray": 475 return c.dataType(typex.FromRType(reflect.TypeOf(pq.ByteaArray{[]byte("")}[0])), columnType) + "[]" 476 case "BoolArray": 477 return c.dataType(typex.FromRType(reflect.TypeOf(pq.BoolArray{true}[0])), columnType) + "[]" 478 case "Float64Array": 479 return c.dataType(typex.FromRType(reflect.TypeOf(pq.Float64Array{0}[0])), columnType) + "[]" 480 case "Int64Array": 481 return c.dataType(typex.FromRType(reflect.TypeOf(pq.Int64Array{0}[0])), columnType) + "[]" 482 case "StringArray": 483 return c.dataType(typex.FromRType(reflect.TypeOf(pq.StringArray{""}[0])), columnType) + "[]" 484 case "NullInt64": 485 return "bigint" 486 case "NullFloat64": 487 return "double precision" 488 case "NullBool": 489 return "boolean" 490 case "Time", "NullTime": 491 return "timestamp with time zone" 492 } 493 494 panic(fmt.Errorf("unsupport type %s", typ)) 495 } 496 497 func (c *PostgreSQLConnector) dataTypeModify(columnType *builder.ColumnType, dataType string) string { 498 buf := bytes.NewBuffer(nil) 499 500 if !columnType.Null { 501 buf.WriteString(" NOT NULL") 502 } 503 504 if columnType.Default != nil { 505 buf.WriteString(" DEFAULT ") 506 buf.WriteString(normalizeDefaultValue(columnType.Default, dataType)) 507 } 508 509 return buf.String() 510 } 511 512 func normalizeDefaultValue(defaultValue *string, dataType string) string { 513 if defaultValue == nil { 514 return "" 515 } 516 517 dv := *defaultValue 518 519 if dv[0] == '\'' { 520 if strings.Contains(dv, "'::") { 521 return dv 522 } 523 return dv + "::" + dataType 524 } 525 526 _, err := strconv.ParseFloat(dv, 64) 527 if err == nil { 528 return "'" + dv + "'::" + dataType 529 } 530 531 return dv 532 } 533 534 func autocompleteSize(dataType string, columnType *builder.ColumnType) string { 535 switch dataType { 536 case "character varying", "character": 537 size := columnType.Length 538 if size == 0 { 539 size = 255 540 } 541 return sizeModifier(size, columnType.Decimal) 542 case "decimal", "numeric", "real", "double precision": 543 if columnType.Length > 0 { 544 return sizeModifier(columnType.Length, columnType.Decimal) 545 } 546 } 547 return "" 548 } 549 550 func dealias(dataType string) string { 551 switch dataType { 552 case "varchar": 553 return "character varying" 554 case "timestamp": 555 return "timestamp without time zone" 556 } 557 return dataType 558 } 559 560 func sizeModifier(length uint64, decimal uint64) string { 561 if length > 0 { 562 size := strconv.FormatUint(length, 10) 563 if decimal > 0 { 564 return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")" 565 } 566 return "(" + size + ")" 567 } 568 return "" 569 }