github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/internal/sql/adapter/postgres/adapter.go (about) 1 package postgres 2 3 import ( 4 "context" 5 "database/sql" 6 "database/sql/driver" 7 "fmt" 8 "net/url" 9 "strings" 10 11 "github.com/jackc/pgx/v5/pgconn" 12 "github.com/jackc/pgx/v5/stdlib" 13 "github.com/octohelm/storage/internal/sql/adapter" 14 "github.com/octohelm/storage/internal/sql/loggingdriver" 15 "github.com/octohelm/storage/pkg/dberr" 16 "github.com/octohelm/storage/pkg/sqlbuilder" 17 "github.com/pkg/errors" 18 ) 19 20 func init() { 21 adapter.Register(&pgAdapter{}, "postgresql") 22 } 23 24 func Open(ctx context.Context, dsn *url.URL) (adapter.Adapter, error) { 25 return (&pgAdapter{}).Open(ctx, dsn) 26 } 27 28 type pgAdapter struct { 29 dialect 30 adapter.DB 31 dbName string 32 } 33 34 func (a *pgAdapter) Dialect() adapter.Dialect { 35 return &a.dialect 36 } 37 38 func (pgAdapter) DriverName() string { 39 return "postgres" 40 } 41 42 func (a *pgAdapter) Connector() driver.DriverContext { 43 return loggingdriver.Wrap(&stdlib.Driver{}, a.DriverName(), func(err error) int { 44 if pqerr, ok := dberr.UnwrapAll(err).(*pgconn.PgError); ok { 45 // unique_violation 46 if pqerr.Code == "23505" { 47 return 0 48 } 49 } 50 return 1 51 }) 52 } 53 54 func dbNameFromDSN(dsn *url.URL) string { 55 return strings.TrimLeft(dsn.Path, "/") 56 } 57 58 func (a *pgAdapter) Open(ctx context.Context, dsn *url.URL) (adapter.Adapter, error) { 59 if a.DriverName() != dsn.Scheme { 60 return nil, errors.Errorf("invalid schema %s", dsn) 61 } 62 63 dbName := dbNameFromDSN(dsn) 64 65 c, err := a.Connector().OpenConnector(dsn.String()) 66 if err != nil { 67 return nil, err 68 } 69 70 db := sql.OpenDB(c) 71 72 if err := db.PingContext(ctx); err != nil { 73 if isErrorUnknownDatabase(err) { 74 if err := a.createDatabase(ctx, dbName, *dsn); err != nil { 75 return nil, err 76 } 77 return a.Open(ctx, dsn) 78 } 79 return nil, err 80 } 81 82 return &pgAdapter{ 83 dbName: dbName, 84 DB: adapter.Wrap(db, func(err error) error { 85 if isErrorConflict(err) { 86 return dberr.New(dberr.ErrTypeConflict, err.Error()) 87 } 88 return err 89 }), 90 }, nil 91 } 92 93 func isErrorConflict(err error) bool { 94 if e, ok := dberr.UnwrapAll(err).(*pgconn.PgError); ok { 95 if e.Code == "23505" { 96 return true 97 } 98 } 99 return false 100 } 101 102 func isErrorUnknownDatabase(err error) bool { 103 if e, ok := dberr.UnwrapAll(err).(*pgconn.PgError); ok { 104 if e.Code == "3D000" { 105 return true 106 } 107 } 108 return false 109 } 110 111 func (a *pgAdapter) createDatabase(ctx context.Context, dbName string, dsn url.URL) error { 112 dsn.Path = "" 113 114 adaptor, err := a.Open(ctx, &dsn) 115 if err != nil { 116 return err 117 } 118 defer adaptor.Close() 119 120 _, err = adaptor.Exec(context.Background(), sqlbuilder.Expr(fmt.Sprintf("CREATE DATABASE %s", dbName))) 121 return err 122 }