github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/internal/sql/adapter/postgres/adapter.go (about)

     1  package postgres
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"net/url"
     9  	"strings"
    10  
    11  	"github.com/jackc/pgx/v5/pgconn"
    12  	"github.com/jackc/pgx/v5/stdlib"
    13  	"github.com/octohelm/storage/internal/sql/adapter"
    14  	"github.com/octohelm/storage/internal/sql/loggingdriver"
    15  	"github.com/octohelm/storage/pkg/dberr"
    16  	"github.com/octohelm/storage/pkg/sqlbuilder"
    17  	"github.com/pkg/errors"
    18  )
    19  
    20  func init() {
    21  	adapter.Register(&pgAdapter{}, "postgresql")
    22  }
    23  
    24  func Open(ctx context.Context, dsn *url.URL) (adapter.Adapter, error) {
    25  	return (&pgAdapter{}).Open(ctx, dsn)
    26  }
    27  
// pgAdapter is the PostgreSQL implementation of adapter.Adapter.
//
// A zero value is used for registration (see init) and as the receiver for
// Open; Open returns a new pgAdapter with DB and dbName populated.
type pgAdapter struct {
	// dialect is the embedded PostgreSQL dialect; Dialect returns a pointer to it.
	dialect
	// DB is set by Open to adapter.Wrap of the opened *sql.DB; it is unset on
	// the zero-value adapter.
	adapter.DB
	// dbName is the database name taken from the DSN path by Open.
	dbName string
}
    33  
    34  func (a *pgAdapter) Dialect() adapter.Dialect {
    35  	return &a.dialect
    36  }
    37  
    38  func (pgAdapter) DriverName() string {
    39  	return "postgres"
    40  }
    41  
    42  func (a *pgAdapter) Connector() driver.DriverContext {
    43  	return loggingdriver.Wrap(&stdlib.Driver{}, a.DriverName(), func(err error) int {
    44  		if pqerr, ok := dberr.UnwrapAll(err).(*pgconn.PgError); ok {
    45  			// unique_violation
    46  			if pqerr.Code == "23505" {
    47  				return 0
    48  			}
    49  		}
    50  		return 1
    51  	})
    52  }
    53  
    54  func dbNameFromDSN(dsn *url.URL) string {
    55  	return strings.TrimLeft(dsn.Path, "/")
    56  }
    57  
    58  func (a *pgAdapter) Open(ctx context.Context, dsn *url.URL) (adapter.Adapter, error) {
    59  	if a.DriverName() != dsn.Scheme {
    60  		return nil, errors.Errorf("invalid schema %s", dsn)
    61  	}
    62  
    63  	dbName := dbNameFromDSN(dsn)
    64  
    65  	c, err := a.Connector().OpenConnector(dsn.String())
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  
    70  	db := sql.OpenDB(c)
    71  
    72  	if err := db.PingContext(ctx); err != nil {
    73  		if isErrorUnknownDatabase(err) {
    74  			if err := a.createDatabase(ctx, dbName, *dsn); err != nil {
    75  				return nil, err
    76  			}
    77  			return a.Open(ctx, dsn)
    78  		}
    79  		return nil, err
    80  	}
    81  
    82  	return &pgAdapter{
    83  		dbName: dbName,
    84  		DB: adapter.Wrap(db, func(err error) error {
    85  			if isErrorConflict(err) {
    86  				return dberr.New(dberr.ErrTypeConflict, err.Error())
    87  			}
    88  			return err
    89  		}),
    90  	}, nil
    91  }
    92  
    93  func isErrorConflict(err error) bool {
    94  	if e, ok := dberr.UnwrapAll(err).(*pgconn.PgError); ok {
    95  		if e.Code == "23505" {
    96  			return true
    97  		}
    98  	}
    99  	return false
   100  }
   101  
   102  func isErrorUnknownDatabase(err error) bool {
   103  	if e, ok := dberr.UnwrapAll(err).(*pgconn.PgError); ok {
   104  		if e.Code == "3D000" {
   105  			return true
   106  		}
   107  	}
   108  	return false
   109  }
   110  
   111  func (a *pgAdapter) createDatabase(ctx context.Context, dbName string, dsn url.URL) error {
   112  	dsn.Path = ""
   113  
   114  	adaptor, err := a.Open(ctx, &dsn)
   115  	if err != nil {
   116  		return err
   117  	}
   118  	defer adaptor.Close()
   119  
   120  	_, err = adaptor.Exec(context.Background(), sqlbuilder.Expr(fmt.Sprintf("CREATE DATABASE %s", dbName)))
   121  	return err
   122  }