gitee.com/go-genie/sqlx@v1.0.3/connectors/postgresql/postgresql_connector.go (about) 1 package postgresql 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql/driver" 7 "fmt" 8 "gitee.com/go-genie/sqlx/generator" 9 "io" 10 "reflect" 11 "strconv" 12 "strings" 13 14 typex "gitee.com/go-genie/xx/types" 15 16 "gitee.com/go-genie/sqlx" 17 "gitee.com/go-genie/sqlx/builder" 18 "gitee.com/go-genie/sqlx/command" 19 "github.com/lib/pq" 20 ) 21 22 var _ interface { 23 driver.Connector 24 builder.Dialect 25 } = (*PostgreSQLConnector)(nil) 26 27 type PostgreSQLConnector struct { 28 Host string 29 DBName string 30 Extra string 31 Extensions []string 32 } 33 34 func (c *PostgreSQLConnector) Connect(ctx context.Context) (driver.Conn, error) { 35 d := c.Driver() 36 37 conn, err := d.Open(dsn(c.Host, c.DBName, c.Extra)) 38 if err != nil { 39 if c.IsErrorUnknownDatabase(err) { 40 connectForCreateDB, err := d.Open(dsn(c.Host, "", c.Extra)) 41 if err != nil { 42 return nil, err 43 } 44 if _, err := connectForCreateDB.(driver.ExecerContext).ExecContext(context.Background(), builder.ResolveExpr(c.CreateDatabase(c.DBName)).Query(), nil); err != nil { 45 return nil, err 46 } 47 if err := connectForCreateDB.Close(); err != nil { 48 return nil, err 49 } 50 return c.Connect(ctx) 51 } 52 return nil, err 53 } 54 for _, ex := range c.Extensions { 55 if _, err := conn.(driver.ExecerContext).ExecContext(context.Background(), "CREATE EXTENSION IF NOT EXISTS "+ex+";", nil); err != nil { 56 return nil, err 57 } 58 } 59 60 return conn, nil 61 } 62 63 func (PostgreSQLConnector) Driver() driver.Driver { 64 return &PostgreSQLLoggingDriver{} 65 } 66 67 func dsn(host string, dbName string, extra string) string { 68 if extra != "" { 69 extra = "?" + extra 70 } 71 return host + "/" + dbName + extra 72 } 73 74 func (c PostgreSQLConnector) WithDBName(dbName string) driver.Connector { 75 c.DBName = dbName 76 return &c 77 } 78 79 func (c *PostgreSQLConnector) Generate(ctx context.Context, db sqlx.DBExecutor) error { 80 //output := command.MigrationOutputFromContext(ctx) 81 82 prevDB, err := dbFromInformationSchema(db, COMMAND_GENERATE) 83 if err != nil { 84 return err 85 } 86 87 //cwd, _ := os.Getwd() 88 // 89 //pkg, err := packagesx.Load(cwd) 90 //if err != nil { 91 // panic(err) 92 //} 93 94 models := generator.NewModelsFromDataBase(prevDB) 95 96 for _, item := range models { 97 err = item.Generator() 98 if err != nil { 99 panic(err) 100 } 101 } 102 103 return nil 104 } 105 106 func (c *PostgreSQLConnector) Migrate(ctx context.Context, db sqlx.DBExecutor) error { 107 output := command.MigrationOutputFromContext(ctx) 108 109 prevDB, err := dbFromInformationSchema(db, COMMAND_MIGRATE) 110 if err != nil { 111 return err 112 } 113 114 d := db.D() 115 dialect := db.Dialect() 116 117 exec := func(expr builder.SqlExpr) error { 118 if expr == nil || expr.IsNil() { 119 return nil 120 } 121 122 if output != nil { 123 _, _ = io.WriteString(output, builder.ResolveExpr(expr).Query()) 124 _, _ = io.WriteString(output, "\n") 125 return nil 126 } 127 128 _, err := db.ExecExpr(expr) 129 return err 130 } 131 132 if prevDB == nil { 133 prevDB = &sqlx.Database{ 134 Name: d.Name, 135 } 136 if err := exec(dialect.CreateDatabase(d.Name)); err != nil { 137 return err 138 } 139 } 140 141 if d.Schema != "" { 142 if err := exec(dialect.CreateSchema(d.Schema)); err != nil { 143 return err 144 } 145 prevDB = prevDB.WithSchema(d.Schema) 146 } 147 148 for _, name := range d.Tables.TableNames() { 149 table := d.Table(name) 150 151 prevTable := prevDB.Table(name) 152 153 if prevTable == nil { 154 for _, expr := range dialect.CreateTableIsNotExists(table) { 155 if err := exec(expr); err != nil { 156 return err 157 } 158 } 159 continue 160 } 161 162 exprList := table.Diff(prevTable, dialect) 163 164 for _, expr := range exprList { 165 if err := exec(expr); err != nil { 166 return err 167 } 168 } 169 } 170 171 return nil 172 } 173 174 func (PostgreSQLConnector) DriverName() string { 175 return "postgres" 176 } 177 178 func (PostgreSQLConnector) PrimaryKeyName() string { 179 return "pkey" 180 } 181 182 func (PostgreSQLConnector) IsErrorUnknownDatabase(err error) bool { 183 if e, ok := sqlx.UnwrapAll(err).(*pq.Error); ok && e.Code == "3D000" { 184 return true 185 } 186 return false 187 } 188 189 func (PostgreSQLConnector) IsErrorConflict(err error) bool { 190 if e, ok := sqlx.UnwrapAll(err).(*pq.Error); ok && e.Code == "23505" { 191 return true 192 } 193 return false 194 } 195 196 func (c *PostgreSQLConnector) CreateDatabase(dbName string) builder.SqlExpr { 197 e := builder.Expr("CREATE DATABASE ") 198 e.WriteQuery(dbName) 199 e.WriteEnd() 200 return e 201 } 202 203 func (c *PostgreSQLConnector) CreateSchema(schema string) builder.SqlExpr { 204 e := builder.Expr("CREATE SCHEMA IF NOT EXISTS ") 205 e.WriteQuery(schema) 206 e.WriteEnd() 207 return e 208 } 209 210 func (c *PostgreSQLConnector) DropDatabase(dbName string) builder.SqlExpr { 211 e := builder.Expr("DROP DATABASE IF EXISTS ") 212 e.WriteQuery(dbName) 213 e.WriteEnd() 214 return e 215 } 216 217 func (c *PostgreSQLConnector) AddIndex(key *builder.Key) builder.SqlExpr { 218 if key.IsPrimary() { 219 e := builder.Expr("ALTER TABLE ") 220 e.WriteExpr(key.Table) 221 e.WriteQuery(" ADD PRIMARY KEY ") 222 e.WriteExpr(key.Def.TableExpr(key.Table)) 223 e.WriteEnd() 224 return e 225 } 226 227 e := builder.Expr("CREATE ") 228 if key.IsUnique { 229 e.WriteQuery("UNIQUE ") 230 } 231 e.WriteQuery("INDEX ") 232 233 e.WriteQuery(key.Table.Name) 234 e.WriteQuery("_") 235 e.WriteQuery(key.Name) 236 237 e.WriteQuery(" ON ") 238 e.WriteExpr(key.Table) 239 240 if m := strings.ToUpper(key.Method); m != "" { 241 if m == "SPATIAL" { 242 m = "GIST" 243 } 244 e.WriteQuery(" USING ") 245 e.WriteQuery(m) 246 } 247 248 e.WriteQueryByte(' ') 249 e.WriteExpr(key.Def.TableExpr(key.Table)) 250 251 e.WriteEnd() 252 return e 253 } 254 255 func (c *PostgreSQLConnector) DropIndex(key *builder.Key) builder.SqlExpr { 256 if key.IsPrimary() { 257 e := builder.Expr("ALTER TABLE ") 258 e.WriteExpr(key.Table) 259 e.WriteQuery(" DROP CONSTRAINT ") 260 e.WriteExpr(key.Table) 261 e.WriteQuery("_pkey") 262 e.WriteEnd() 263 return e 264 } 265 e := builder.Expr("DROP ") 266 267 e.WriteQuery("INDEX IF EXISTS ") 268 e.WriteExpr(key.Table) 269 e.WriteQueryByte('_') 270 e.WriteQuery(key.Name) 271 e.WriteEnd() 272 273 return e 274 } 275 276 func (c *PostgreSQLConnector) CreateTableIsNotExists(t *builder.Table) (exprs []builder.SqlExpr) { 277 expr := builder.Expr("CREATE TABLE IF NOT EXISTS ") 278 expr.WriteExpr(t) 279 expr.WriteQueryByte(' ') 280 expr.WriteGroup(func(e *builder.Ex) { 281 if t.Columns.IsNil() { 282 return 283 } 284 285 t.Columns.Range(func(col *builder.Column, idx int) { 286 if col.DeprecatedActions != nil { 287 return 288 } 289 290 if idx > 0 { 291 e.WriteQueryByte(',') 292 } 293 e.WriteQueryByte('\n') 294 e.WriteQueryByte('\t') 295 296 e.WriteExpr(col) 297 e.WriteQueryByte(' ') 298 e.WriteExpr(c.DataType(col.ColumnType)) 299 }) 300 301 t.Keys.Range(func(key *builder.Key, idx int) { 302 if key.IsPrimary() { 303 e.WriteQueryByte(',') 304 e.WriteQueryByte('\n') 305 e.WriteQueryByte('\t') 306 e.WriteQuery("PRIMARY KEY ") 307 e.WriteExpr(key.Def.TableExpr(key.Table)) 308 } 309 }) 310 311 expr.WriteQueryByte('\n') 312 }) 313 314 expr.WriteEnd() 315 exprs = append(exprs, expr) 316 317 t.Keys.Range(func(key *builder.Key, idx int) { 318 if !key.IsPrimary() { 319 exprs = append(exprs, c.AddIndex(key)) 320 } 321 }) 322 323 return 324 } 325 326 func (c *PostgreSQLConnector) DropTable(t *builder.Table) builder.SqlExpr { 327 e := builder.Expr("DROP TABLE IF EXISTS ") 328 e.WriteExpr(t) 329 e.WriteEnd() 330 return e 331 } 332 333 func (c *PostgreSQLConnector) TruncateTable(t *builder.Table) builder.SqlExpr { 334 e := builder.Expr("TRUNCATE TABLE ") 335 e.WriteExpr(t) 336 e.WriteEnd() 337 return e 338 } 339 340 func (c *PostgreSQLConnector) AddColumn(col *builder.Column) builder.SqlExpr { 341 e := builder.Expr("ALTER TABLE ") 342 e.WriteExpr(col.Table) 343 e.WriteQuery(" ADD COLUMN ") 344 e.WriteExpr(col) 345 e.WriteQueryByte(' ') 346 e.WriteExpr(c.DataType(col.ColumnType)) 347 e.WriteEnd() 348 return e 349 } 350 351 func (c *PostgreSQLConnector) RenameColumn(col *builder.Column, target *builder.Column) builder.SqlExpr { 352 e := builder.Expr("ALTER TABLE ") 353 e.WriteExpr(col.Table) 354 e.WriteQuery(" RENAME COLUMN ") 355 e.WriteExpr(col) 356 e.WriteQuery(" TO ") 357 e.WriteExpr(target) 358 e.WriteEnd() 359 return e 360 } 361 362 func (c *PostgreSQLConnector) ModifyColumn(col *builder.Column, prev *builder.Column) builder.SqlExpr { 363 if col.AutoIncrement { 364 return nil 365 } 366 367 e := builder.Expr("ALTER TABLE ") 368 e.WriteExpr(col.Table) 369 370 dbDataType := c.dataType(col.ColumnType.Type, col.ColumnType) 371 prevDbDataType := c.dataType(prev.ColumnType.Type, prev.ColumnType) 372 373 isFirstSub := true 374 isEmpty := true 375 376 prepareAppendSubCmd := func() { 377 if !isFirstSub { 378 e.WriteQueryByte(',') 379 } 380 isFirstSub = false 381 isEmpty = false 382 } 383 384 if dbDataType != prevDbDataType { 385 prepareAppendSubCmd() 386 387 e.WriteQuery(" ALTER COLUMN ") 388 e.WriteExpr(col) 389 e.WriteQuery(" TYPE ") 390 e.WriteQuery(dbDataType) 391 392 e.WriteQuery(" /* FROM ") 393 e.WriteQuery(prevDbDataType) 394 e.WriteQuery(" */") 395 } 396 397 if col.Null != prev.Null { 398 prepareAppendSubCmd() 399 400 e.WriteQuery(" ALTER COLUMN ") 401 e.WriteExpr(col) 402 if !col.Null { 403 e.WriteQuery(" SET NOT NULL") 404 } else { 405 e.WriteQuery(" DROP NOT NULL") 406 } 407 } 408 409 defaultValue := normalizeDefaultValue(col.Default, dbDataType) 410 prevDefaultValue := normalizeDefaultValue(prev.Default, prevDbDataType) 411 412 if defaultValue != prevDefaultValue { 413 prepareAppendSubCmd() 414 415 e.WriteQuery(" ALTER COLUMN ") 416 e.WriteExpr(col) 417 if col.Default != nil { 418 e.WriteQuery(" SET DEFAULT ") 419 e.WriteQuery(defaultValue) 420 421 e.WriteQuery(" /* FROM ") 422 e.WriteQuery(prevDefaultValue) 423 e.WriteQuery(" */") 424 } else { 425 e.WriteQuery(" DROP DEFAULT") 426 } 427 } 428 429 if isEmpty { 430 return nil 431 } 432 433 e.WriteEnd() 434 435 return e 436 } 437 438 func (c *PostgreSQLConnector) DropColumn(col *builder.Column) builder.SqlExpr { 439 e := builder.Expr("ALTER TABLE ") 440 e.WriteExpr(col.Table) 441 e.WriteQuery(" DROP COLUMN ") 442 e.WriteQuery(col.Name) 443 e.WriteEnd() 444 return e 445 } 446 447 func (c *PostgreSQLConnector) DataType(columnType *builder.ColumnType) builder.SqlExpr { 448 dbDataType := dealias(c.dbDataType(columnType.Type, columnType)) 449 return builder.Expr(dbDataType + autocompleteSize(dbDataType, columnType) + c.dataTypeModify(columnType, dbDataType)) 450 } 451 452 func (c *PostgreSQLConnector) dataType(typ typex.Type, columnType *builder.ColumnType) string { 453 dbDataType := dealias(c.dbDataType(columnType.Type, columnType)) 454 return dbDataType + autocompleteSize(dbDataType, columnType) 455 } 456 457 func (c *PostgreSQLConnector) dbDataType(typ typex.Type, columnType *builder.ColumnType) string { 458 if columnType.DataType != "" { 459 return columnType.DataType 460 } 461 462 if rv, ok := typex.TryNew(typ); ok { 463 if dtd, ok := rv.Interface().(builder.DataTypeDescriber); ok { 464 return dtd.DataType(c.DriverName()) 465 } 466 } 467 468 switch typ.Kind() { 469 case reflect.Ptr: 470 return c.dataType(typ.Elem(), columnType) 471 case reflect.Bool: 472 return "boolean" 473 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: 474 if columnType.AutoIncrement { 475 return "serial" 476 } 477 return "integer" 478 case reflect.Int64, reflect.Uint64: 479 if columnType.AutoIncrement { 480 return "bigserial" 481 } 482 return "bigint" 483 case reflect.Float64: 484 return "double precision" 485 case reflect.Float32: 486 return "real" 487 case reflect.Slice: 488 if typ.Elem().Kind() == reflect.Uint8 { 489 return "bytea" 490 } 491 case reflect.String: 492 size := columnType.Length 493 if size < 65535/3 { 494 return "varchar" 495 } 496 return "text" 497 } 498 499 switch typ.Name() { 500 case "Hstore": 501 return "hstore" 502 case "ByteaArray": 503 return c.dataType(typex.FromRType(reflect.TypeOf(pq.ByteaArray{[]byte("")}[0])), columnType) + "[]" 504 case "BoolArray": 505 return c.dataType(typex.FromRType(reflect.TypeOf(pq.BoolArray{true}[0])), columnType) + "[]" 506 case "Float64Array": 507 return c.dataType(typex.FromRType(reflect.TypeOf(pq.Float64Array{0}[0])), columnType) + "[]" 508 case "Int64Array": 509 return c.dataType(typex.FromRType(reflect.TypeOf(pq.Int64Array{0}[0])), columnType) + "[]" 510 case "StringArray": 511 return c.dataType(typex.FromRType(reflect.TypeOf(pq.StringArray{""}[0])), columnType) + "[]" 512 case "NullInt64": 513 return "bigint" 514 case "NullFloat64": 515 return "double precision" 516 case "NullBool": 517 return "boolean" 518 case "Time", "NullTime": 519 return "timestamp with time zone" 520 } 521 522 panic(fmt.Errorf("unsupport type %s", typ)) 523 } 524 525 func (c *PostgreSQLConnector) dataTypeModify(columnType *builder.ColumnType, dataType string) string { 526 buf := bytes.NewBuffer(nil) 527 528 if !columnType.Null { 529 buf.WriteString(" NOT NULL") 530 } 531 532 if columnType.Default != nil { 533 buf.WriteString(" DEFAULT ") 534 buf.WriteString(normalizeDefaultValue(columnType.Default, dataType)) 535 } 536 537 return buf.String() 538 } 539 540 func normalizeDefaultValue(defaultValue *string, dataType string) string { 541 if defaultValue == nil { 542 return "" 543 } 544 545 dv := *defaultValue 546 547 if dv[0] == '\'' { 548 if strings.Contains(dv, "'::") { 549 return dv 550 } 551 return dv + "::" + dataType 552 } 553 554 _, err := strconv.ParseFloat(dv, 64) 555 if err == nil { 556 return "'" + dv + "'::" + dataType 557 } 558 559 return dv 560 } 561 562 func autocompleteSize(dataType string, columnType *builder.ColumnType) string { 563 switch dataType { 564 case "character varying", "character": 565 size := columnType.Length 566 if size == 0 { 567 size = 255 568 } 569 return sizeModifier(size, columnType.Decimal) 570 case "decimal", "numeric", "real", "double precision": 571 if columnType.Length > 0 { 572 return sizeModifier(columnType.Length, columnType.Decimal) 573 } 574 } 575 return "" 576 } 577 578 func dealias(dataType string) string { 579 switch dataType { 580 case "varchar": 581 return "character varying" 582 case "timestamp": 583 return "timestamp without time zone" 584 } 585 return dataType 586 } 587 588 func sizeModifier(length uint64, decimal uint64) string { 589 if length > 0 { 590 size := strconv.FormatUint(length, 10) 591 if decimal > 0 { 592 return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")" 593 } 594 return "(" + size + ")" 595 } 596 return "" 597 }