github.com/eden-framework/sqlx@v0.0.2/postgresqlconnector/postgresql_connector.go (about) 1 package postgresqlconnector 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql/driver" 7 "fmt" 8 "log" 9 "reflect" 10 "strconv" 11 "strings" 12 13 "github.com/eden-framework/sqlx" 14 "github.com/eden-framework/sqlx/builder" 15 "github.com/eden-framework/sqlx/migration" 16 "github.com/lib/pq" 17 "github.com/sirupsen/logrus" 18 ) 19 20 var _ interface { 21 driver.Connector 22 builder.Dialect 23 } = (*PostgreSQLConnector)(nil) 24 25 type PostgreSQLConnector struct { 26 Host string 27 DBName string 28 Extra string 29 Extensions []string 30 } 31 32 func dsn(host string, dbName string, extra string) string { 33 if extra != "" { 34 extra = "?" + extra 35 } 36 return host + "/" + dbName + extra 37 } 38 39 func (c PostgreSQLConnector) WithDBName(dbName string) driver.Connector { 40 c.DBName = dbName 41 return &c 42 } 43 44 func (c *PostgreSQLConnector) Migrate(ctx context.Context, db sqlx.DBExecutor) error { 45 opts := migration.MigrationOptsFromContext(ctx) 46 47 prevDB := dbFromInformationSchema(db) 48 49 d := db.D() 50 dialect := db.Dialect() 51 52 if prevDB == nil { 53 prevDB = &sqlx.Database{ 54 Name: d.Name, 55 } 56 if _, err := db.ExecExpr(dialect.CreateDatabase(d.Name)); err != nil { 57 return err 58 } 59 } 60 61 if d.Schema != "" { 62 if _, err := db.ExecExpr(dialect.CreateSchema(d.Schema)); err != nil { 63 return err 64 } 65 prevDB = prevDB.WithSchema(d.Schema) 66 } 67 68 for _, name := range d.Tables.TableNames() { 69 table := d.Table(name) 70 71 prevTable := prevDB.Table(name) 72 73 if prevTable == nil { 74 for _, expr := range dialect.CreateTableIsNotExists(table) { 75 if _, err := db.ExecExpr(expr); err != nil { 76 return err 77 } 78 } 79 continue 80 } 81 82 exprList := table.Diff(prevTable, dialect) 83 84 for _, expr := range exprList { 85 if !(expr == nil || expr.IsNil()) { 86 if opts.DryRun { 87 log.Printf(builder.ResolveExpr(expr).Query()) 88 } else { 89 if _, err := db.ExecExpr(expr); err != nil { 90 return err 91 } 92 } 93 } 94 } 95 } 96 97 return nil 98 } 99 100 func (c *PostgreSQLConnector) Connect(ctx context.Context) (driver.Conn, error) { 101 d := c.Driver() 102 conn, err := d.Open(dsn(c.Host, c.DBName, c.Extra)) 103 if err != nil { 104 if c.IsErrorUnknownDatabase(err) { 105 connectForCreateDB, err := d.Open(dsn(c.Host, "", c.Extra)) 106 if err != nil { 107 return nil, err 108 } 109 if _, err := connectForCreateDB.(driver.ExecerContext). 110 ExecContext(context.Background(), builder.ResolveExpr(c.CreateDatabase(c.DBName)).Query(), nil); err != nil { 111 return nil, err 112 } 113 if err := connectForCreateDB.Close(); err != nil { 114 return nil, err 115 } 116 return c.Connect(ctx) 117 } 118 return nil, err 119 } 120 for _, ex := range c.Extensions { 121 if _, err := conn.(driver.ExecerContext). 122 ExecContext(context.Background(), "CREATE EXTENSION IF NOT EXISTS "+ex+";", nil); err != nil { 123 return nil, err 124 } 125 } 126 return conn, nil 127 } 128 129 func (PostgreSQLConnector) Driver() driver.Driver { 130 return &PostgreSQLLoggingDriver{Driver: &pq.Driver{}, Logger: logrus.StandardLogger()} 131 } 132 133 func (PostgreSQLConnector) DriverName() string { 134 return "postgres" 135 } 136 137 func (PostgreSQLConnector) PrimaryKeyName() string { 138 return "pkey" 139 } 140 141 func (PostgreSQLConnector) IsErrorUnknownDatabase(err error) bool { 142 if e, ok := err.(*pq.Error); ok && e.Code == "3D000" { 143 return true 144 } 145 return false 146 } 147 148 func (PostgreSQLConnector) IsErrorConflict(err error) bool { 149 if e, ok := err.(*pq.Error); ok && e.Code == "23505" { 150 return true 151 } 152 return false 153 } 154 155 func (c *PostgreSQLConnector) CreateDatabase(dbName string) builder.SqlExpr { 156 e := builder.Expr("CREATE DATABASE ") 157 e.WriteString(dbName) 158 e.WriteEnd() 159 return e 160 } 161 162 func (c *PostgreSQLConnector) CreateSchema(schema string) builder.SqlExpr { 163 e := builder.Expr("CREATE SCHEMA IF NOT EXISTS ") 164 e.WriteString(schema) 165 e.WriteEnd() 166 return e 167 } 168 169 func (c *PostgreSQLConnector) DropDatabase(dbName string) builder.SqlExpr { 170 e := builder.Expr("DROP DATABASE IF EXISTS ") 171 e.WriteString(dbName) 172 e.WriteEnd() 173 return e 174 } 175 176 func (c *PostgreSQLConnector) AddIndex(key *builder.Key) builder.SqlExpr { 177 if key.IsPrimary() { 178 e := builder.Expr("ALTER TABLE ") 179 e.WriteExpr(key.Table) 180 e.WriteString(" ADD PRIMARY KEY ") 181 e.WriteGroup(func(e *builder.Ex) { 182 e.WriteExpr(key.Columns) 183 }) 184 e.WriteEnd() 185 return e 186 } 187 188 e := builder.Expr("CREATE ") 189 if key.IsUnique { 190 e.WriteString("UNIQUE ") 191 } 192 e.WriteString("INDEX ") 193 194 e.WriteString(key.Table.Name) 195 e.WriteString("_") 196 e.WriteString(key.Name) 197 198 e.WriteString(" ON ") 199 e.WriteExpr(key.Table) 200 201 if map[string]bool{ 202 "BTREE": true, 203 "HASH": true, 204 "SPATIAL": true, 205 "GIST": true, 206 }[strings.ToUpper(key.Method)] { 207 e.WriteString(" USING ") 208 209 if key.Method == "SPATIAL" { 210 e.WriteString("GIST") 211 } else { 212 e.WriteString(key.Method) 213 } 214 } 215 216 e.WriteByte(' ') 217 e.WriteGroup(func(e *builder.Ex) { 218 e.WriteExpr(key.Columns) 219 }) 220 221 e.WriteEnd() 222 return e 223 } 224 225 func (c *PostgreSQLConnector) DropIndex(key *builder.Key) builder.SqlExpr { 226 if key.IsPrimary() { 227 e := builder.Expr("ALTER TABLE ") 228 e.WriteExpr(key.Table) 229 e.WriteString(" DROP CONSTRAINT ") 230 e.WriteExpr(key.Table) 231 e.WriteString("_pkey") 232 e.WriteEnd() 233 return e 234 } 235 e := builder.Expr("DROP ") 236 237 e.WriteString("INDEX IF EXISTS ") 238 e.WriteExpr(key.Table) 239 e.WriteByte('_') 240 e.WriteString(key.Name) 241 242 return e 243 } 244 245 func (c *PostgreSQLConnector) CreateTableIsNotExists(t *builder.Table) (exprs []builder.SqlExpr) { 246 expr := builder.Expr("CREATE TABLE IF NOT EXISTS ") 247 expr.WriteExpr(t) 248 expr.WriteByte(' ') 249 expr.WriteGroup(func(e *builder.Ex) { 250 if t.Columns.IsNil() { 251 return 252 } 253 254 t.Columns.Range(func(col *builder.Column, idx int) { 255 if col.DeprecatedActions != nil { 256 return 257 } 258 259 if idx > 0 { 260 e.WriteByte(',') 261 } 262 e.WriteByte('\n') 263 e.WriteByte('\t') 264 265 e.WriteExpr(col) 266 e.WriteByte(' ') 267 e.WriteExpr(c.DataType(col.ColumnType)) 268 }) 269 270 t.Keys.Range(func(key *builder.Key, idx int) { 271 if key.IsPrimary() { 272 e.WriteByte(',') 273 e.WriteByte('\n') 274 e.WriteByte('\t') 275 e.WriteString("PRIMARY KEY ") 276 e.WriteGroup(func(e *builder.Ex) { 277 e.WriteExpr(key.Columns) 278 }) 279 } 280 }) 281 282 expr.WriteByte('\n') 283 }) 284 285 expr.WriteEnd() 286 exprs = append(exprs, expr) 287 288 t.Keys.Range(func(key *builder.Key, idx int) { 289 if !key.IsPrimary() { 290 exprs = append(exprs, c.AddIndex(key)) 291 } 292 }) 293 294 return 295 } 296 297 func (c *PostgreSQLConnector) DropTable(t *builder.Table) builder.SqlExpr { 298 e := builder.Expr("DROP TABLE IF EXISTS ") 299 e.WriteExpr(t) 300 e.WriteEnd() 301 return e 302 } 303 304 func (c *PostgreSQLConnector) TruncateTable(t *builder.Table) builder.SqlExpr { 305 e := builder.Expr("TRUNCATE TABLE ") 306 e.WriteExpr(t) 307 e.WriteEnd() 308 return e 309 } 310 311 func (c *PostgreSQLConnector) AddColumn(col *builder.Column) builder.SqlExpr { 312 e := builder.Expr("ALTER TABLE ") 313 e.WriteExpr(col.Table) 314 e.WriteString(" ADD COLUMN ") 315 e.WriteExpr(col) 316 e.WriteByte(' ') 317 e.WriteExpr(c.DataType(col.ColumnType)) 318 e.WriteEnd() 319 return e 320 } 321 322 func (c *PostgreSQLConnector) RenameColumn(col *builder.Column, target *builder.Column) builder.SqlExpr { 323 e := builder.Expr("ALTER TABLE ") 324 e.WriteExpr(col.Table) 325 e.WriteString(" RENAME COLUMN ") 326 e.WriteExpr(col) 327 e.WriteString(" TO ") 328 e.WriteExpr(target) 329 e.WriteEnd() 330 return e 331 } 332 333 func (c *PostgreSQLConnector) ModifyColumn(col *builder.Column) builder.SqlExpr { 334 if col.AutoIncrement { 335 return nil 336 } 337 338 e := builder.Expr("ALTER TABLE ") 339 e.WriteExpr(col.Table) 340 341 e.WriteString(" ALTER COLUMN ") 342 e.WriteExpr(col) 343 e.WriteString(" TYPE ") 344 e.WriteString(c.dataType(col.ColumnType.Type, col.ColumnType)) 345 346 { 347 e.WriteString(", ALTER COLUMN ") 348 e.WriteExpr(col) 349 if !col.Null { 350 e.WriteString(" SET NOT NULL") 351 } else { 352 e.WriteString(" DROP NOT NULL") 353 } 354 } 355 356 { 357 e.WriteString(", ALTER COLUMN ") 358 e.WriteExpr(col) 359 if col.Default != nil { 360 e.WriteString(" SET DEFAULT ") 361 e.WriteString(*col.Default) 362 } else { 363 e.WriteString(" DROP DEFAULT") 364 } 365 } 366 367 e.WriteEnd() 368 369 return e 370 } 371 372 func (c *PostgreSQLConnector) DropColumn(col *builder.Column) builder.SqlExpr { 373 e := builder.Expr("ALTER TABLE ") 374 e.WriteExpr(col.Table) 375 e.WriteString(" DROP COLUMN ") 376 e.WriteString(col.Name) 377 e.WriteEnd() 378 return e 379 } 380 381 func (c *PostgreSQLConnector) DataType(columnType *builder.ColumnType) builder.SqlExpr { 382 return builder.Expr(c.dataType(columnType.Type, columnType) + c.dataTypeModify(columnType)) 383 } 384 385 func (c *PostgreSQLConnector) dataType(typ reflect.Type, columnType *builder.ColumnType) string { 386 if columnType.GetDataType != nil { 387 return columnType.GetDataType(c.DriverName()) 388 } 389 390 switch typ.Kind() { 391 case reflect.Ptr: 392 return c.dataType(typ.Elem(), columnType) 393 case reflect.Bool: 394 return "boolean" 395 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: 396 if columnType.AutoIncrement { 397 return "serial" 398 } 399 return "integer" 400 case reflect.Int64, reflect.Uint64: 401 if columnType.AutoIncrement { 402 return "bigserial" 403 } 404 return "bigint" 405 case reflect.Float64: 406 return "double precision" 407 case reflect.Float32: 408 return "real" 409 case reflect.Slice: 410 if typ.Elem().Kind() == reflect.Uint8 { 411 return "bytea" 412 } 413 case reflect.String: 414 size := columnType.Length 415 if size == 0 { 416 size = 255 417 } 418 if size < 65535/3 { 419 return "varchar" + sizeModifier(size, 0) 420 } 421 return "text" 422 } 423 424 switch typ.Name() { 425 case "Hstore": 426 return "hstore" 427 case "ByteaArray": 428 return c.dataType(reflect.TypeOf(pq.ByteaArray{[]byte("")}[0]), columnType) + "[]" 429 case "BoolArray": 430 return c.dataType(reflect.TypeOf(pq.BoolArray{true}[0]), columnType) + "[]" 431 case "Float64Array": 432 return c.dataType(reflect.TypeOf(pq.Float64Array{0}[0]), columnType) + "[]" 433 case "Int64Array": 434 return c.dataType(reflect.TypeOf(pq.Int64Array{0}[0]), columnType) + "[]" 435 case "StringArray": 436 return c.dataType(reflect.TypeOf(pq.StringArray{""}[0]), columnType) + "[]" 437 case "NullInt64": 438 return "bigint" 439 case "NullFloat64": 440 return "double precision" 441 case "NullBool": 442 return "boolean" 443 case "Time", "NullTime": 444 return "timestamp with time zone" 445 } 446 447 panic(fmt.Errorf("unsupport type %s", typ)) 448 } 449 450 func (c *PostgreSQLConnector) dataTypeModify(columnType *builder.ColumnType) string { 451 buf := bytes.NewBuffer(nil) 452 453 if !columnType.Null { 454 buf.WriteString(" NOT NULL") 455 } 456 457 if columnType.Default != nil { 458 buf.WriteString(" DEFAULT ") 459 buf.WriteString(*columnType.Default) 460 } 461 462 return buf.String() 463 } 464 465 func sizeModifier(length uint64, decimal uint64) string { 466 if length > 0 { 467 size := strconv.FormatUint(length, 10) 468 if decimal > 0 { 469 return "(" + size + "," + strconv.FormatUint(decimal, 10) + ")" 470 } 471 return "(" + size + ")" 472 } 473 return "" 474 }