github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/internal/sql/adapter/sqlite/adapter.go (about)

     1  package sqlite
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"net/url"
     9  
    10  	"github.com/pkg/errors"
    11  	"modernc.org/sqlite"
    12  
    13  	"github.com/octohelm/storage/internal/sql/adapter"
    14  	"github.com/octohelm/storage/internal/sql/loggingdriver"
    15  	"github.com/octohelm/storage/pkg/dberr"
    16  	"github.com/octohelm/storage/pkg/sqlbuilder"
    17  )
    18  
    19  func init() {
    20  	adapter.Register(&sqliteAdapter{})
    21  }
    22  
    23  func Open(ctx context.Context, dsn *url.URL) (adapter.Adapter, error) {
    24  	return (&sqliteAdapter{}).Open(ctx, dsn)
    25  }
    26  
    27  type sqliteAdapter struct {
    28  	dialect
    29  	adapter.DB
    30  	mutexSet
    31  }
    32  
    33  func (sqliteAdapter) DriverName() string {
    34  	return "sqlite"
    35  }
    36  
    37  func (a *sqliteAdapter) Dialect() adapter.Dialect {
    38  	return &a.dialect
    39  }
    40  
    41  func (a *sqliteAdapter) Connector() driver.DriverContext {
    42  	return loggingdriver.Wrap(
    43  		&sqlite.Driver{},
    44  		a.DriverName(),
    45  		func(err error) int {
    46  			if e, ok := dberr.UnwrapAll(err).(*sqlite.Error); ok {
    47  				if e.Code() == 2067 {
    48  					return 0
    49  				}
    50  			}
    51  			return 1
    52  		},
    53  	)
    54  }
    55  
    56  func (a *sqliteAdapter) Open(ctx context.Context, dsn *url.URL) (adapter.Adapter, error) {
    57  	if a.DriverName() != dsn.Scheme {
    58  		return nil, errors.Errorf("invalid schema %s", dsn)
    59  	}
    60  
    61  	dbUri := dsn.Path + "?" + dsn.Query().Encode()
    62  
    63  	connector := &driverContextWithMutex{
    64  		DriverContext: a.Connector(),
    65  		Mutex:         a.of(dbUri),
    66  	}
    67  
    68  	conn, err := connector.OpenConnector(dbUri)
    69  	if err != nil {
    70  		return nil, errors.Wrapf(err, "connect failed with %s", dsn.Path)
    71  	}
    72  
    73  	db := sql.OpenDB(conn)
    74  
    75  	return &sqliteAdapter{
    76  		DB: adapter.Wrap(db, func(err error) error {
    77  			if isErrorConflict(err) {
    78  				return dberr.New(dberr.ErrTypeConflict, err.Error())
    79  			}
    80  			return err
    81  		}),
    82  	}, nil
    83  }
    84  
    85  func isErrorConflict(err error) bool {
    86  	if e, ok := dberr.UnwrapAll(err).(*sqlite.Error); ok && e.Code() == 2067 {
    87  		return true
    88  	}
    89  	return false
    90  }
    91  
    92  func (a *sqliteAdapter) createDatabase(ctx context.Context, dbName string, dsn url.URL) error {
    93  	dsn.Path = ""
    94  
    95  	adaptor, err := a.Open(ctx, &dsn)
    96  	if err != nil {
    97  		return err
    98  	}
    99  	defer adaptor.Close()
   100  
   101  	_, err = adaptor.Exec(context.Background(), sqlbuilder.Expr(fmt.Sprintf("CREATE DATABASE %s", dbName)))
   102  	return err
   103  }