github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/db/open.go (about)

     1  package db
     2  
import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"code.cloudfoundry.org/lager"
	"github.com/Masterminds/squirrel"
	multierror "github.com/hashicorp/go-multierror"
	"github.com/lib/pq"
	"github.com/pf-qiu/concourse/v6/atc/db/encryption"
	"github.com/pf-qiu/concourse/v6/atc/db/lock"
	"github.com/pf-qiu/concourse/v6/atc/db/migration"
)
    20  
    21  //go:generate counterfeiter . Conn
    22  
    23  type Conn interface {
    24  	Bus() NotificationsBus
    25  	EncryptionStrategy() encryption.Strategy
    26  
    27  	Ping() error
    28  	Driver() driver.Driver
    29  
    30  	Begin() (Tx, error)
    31  	Exec(string, ...interface{}) (sql.Result, error)
    32  	Prepare(string) (*sql.Stmt, error)
    33  	Query(string, ...interface{}) (*sql.Rows, error)
    34  	QueryRow(string, ...interface{}) squirrel.RowScanner
    35  
    36  	BeginTx(context.Context, *sql.TxOptions) (Tx, error)
    37  	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
    38  	PrepareContext(context.Context, string) (*sql.Stmt, error)
    39  	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
    40  	QueryRowContext(context.Context, string, ...interface{}) squirrel.RowScanner
    41  
    42  	SetMaxIdleConns(int)
    43  	SetMaxOpenConns(int)
    44  	Stats() sql.DBStats
    45  
    46  	Close() error
    47  	Name() string
    48  }
    49  
    50  //go:generate counterfeiter . Tx
    51  
    52  type Tx interface {
    53  	Commit() error
    54  	Exec(string, ...interface{}) (sql.Result, error)
    55  	Prepare(string) (*sql.Stmt, error)
    56  	Query(string, ...interface{}) (*sql.Rows, error)
    57  	QueryRow(string, ...interface{}) squirrel.RowScanner
    58  	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
    59  	PrepareContext(context.Context, string) (*sql.Stmt, error)
    60  	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
    61  	QueryRowContext(context.Context, string, ...interface{}) squirrel.RowScanner
    62  	Rollback() error
    63  	Stmt(*sql.Stmt) *sql.Stmt
    64  	EncryptionStrategy() encryption.Strategy
    65  }
    66  
    67  func Open(logger lager.Logger, driver, dsn string, newKey, oldKey *encryption.Key, name string, lockFactory lock.LockFactory) (Conn, error) {
    68  	for {
    69  		sqlDB, err := migration.NewOpenHelper(driver, dsn, lockFactory, newKey, oldKey).Open()
    70  		if err != nil {
    71  			if shouldRetry(err) {
    72  				logger.Error("failed-to-open-db-retrying", err)
    73  				time.Sleep(5 * time.Second)
    74  				continue
    75  			}
    76  
    77  			return nil, err
    78  		}
    79  
    80  		return NewConn(name, sqlDB, dsn, oldKey, newKey), nil
    81  	}
    82  }
    83  
    84  func NewConn(name string, sqlDB *sql.DB, dsn string, oldKey, newKey *encryption.Key) Conn {
    85  	listener := pq.NewDialListener(keepAliveDialer{}, dsn, time.Second, time.Minute, nil)
    86  
    87  	var strategy encryption.Strategy
    88  	if newKey != nil {
    89  		strategy = newKey
    90  	} else {
    91  		strategy = encryption.NewNoEncryption()
    92  	}
    93  
    94  	return &db{
    95  		DB: sqlDB,
    96  
    97  		bus:        NewNotificationsBus(listener, sqlDB),
    98  		encryption: strategy,
    99  		name:       name,
   100  	}
   101  }
   102  
   103  func shouldRetry(err error) bool {
   104  	if strings.Contains(err.Error(), "dial ") {
   105  		return true
   106  	}
   107  
   108  	if pqErr, ok := err.(*pq.Error); ok {
   109  		return pqErr.Code.Name() == "cannot_connect_now"
   110  	}
   111  
   112  	return false
   113  }
   114  
   115  type db struct {
   116  	*sql.DB
   117  
   118  	bus        NotificationsBus
   119  	encryption encryption.Strategy
   120  	name       string
   121  }
   122  
   123  func (db *db) Name() string {
   124  	return db.name
   125  }
   126  
   127  func (db *db) Bus() NotificationsBus {
   128  	return db.bus
   129  }
   130  
   131  func (db *db) EncryptionStrategy() encryption.Strategy {
   132  	return db.encryption
   133  }
   134  
   135  func (db *db) Close() error {
   136  	var errs error
   137  	dbErr := db.DB.Close()
   138  	if dbErr != nil {
   139  		errs = multierror.Append(errs, dbErr)
   140  	}
   141  
   142  	busErr := db.bus.Close()
   143  	if busErr != nil {
   144  		errs = multierror.Append(errs, busErr)
   145  	}
   146  
   147  	return errs
   148  }
   149  
   150  // Close ignores errors, and should used with defer.
   151  // makes errcheck happy that those errs are captured
   152  func Close(c io.Closer) {
   153  	_ = c.Close()
   154  }
   155  
   156  func (db *db) Begin() (Tx, error) {
   157  	tx, err := db.DB.Begin()
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  
   162  	return &dbTx{tx, GlobalConnectionTracker.Track(), db.EncryptionStrategy()}, nil
   163  }
   164  
   165  func (db *db) Exec(query string, args ...interface{}) (sql.Result, error) {
   166  	defer GlobalConnectionTracker.Track().Release()
   167  	return db.DB.Exec(query, args...)
   168  }
   169  
   170  func (db *db) Prepare(query string) (*sql.Stmt, error) {
   171  	defer GlobalConnectionTracker.Track().Release()
   172  	return db.DB.Prepare(query)
   173  }
   174  
   175  func (db *db) Query(query string, args ...interface{}) (*sql.Rows, error) {
   176  	defer GlobalConnectionTracker.Track().Release()
   177  	return db.DB.Query(query, args...)
   178  }
   179  
   180  // to conform to squirrel.Runner interface
   181  func (db *db) QueryRow(query string, args ...interface{}) squirrel.RowScanner {
   182  	defer GlobalConnectionTracker.Track().Release()
   183  	return db.DB.QueryRow(query, args...)
   184  }
   185  
   186  func (db *db) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
   187  	tx, err := db.DB.BeginTx(ctx, opts)
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  
   192  	return &dbTx{tx, GlobalConnectionTracker.Track(), db.EncryptionStrategy()}, nil
   193  }
   194  
   195  func (db *db) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
   196  	defer GlobalConnectionTracker.Track().Release()
   197  	return db.DB.ExecContext(ctx, query, args...)
   198  }
   199  
   200  func (db *db) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
   201  	defer GlobalConnectionTracker.Track().Release()
   202  	return db.DB.PrepareContext(ctx, query)
   203  }
   204  
   205  func (db *db) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
   206  	defer GlobalConnectionTracker.Track().Release()
   207  	return db.DB.QueryContext(ctx, query, args...)
   208  }
   209  
   210  // to conform to squirrel.Runner interface
   211  func (db *db) QueryRowContext(ctx context.Context, query string, args ...interface{}) squirrel.RowScanner {
   212  	defer GlobalConnectionTracker.Track().Release()
   213  	return db.DB.QueryRowContext(ctx, query, args...)
   214  }
   215  
   216  type dbTx struct {
   217  	*sql.Tx
   218  
   219  	session            *ConnectionSession
   220  	encryptionStrategy encryption.Strategy
   221  }
   222  
   223  // to conform to squirrel.Runner interface
   224  func (tx *dbTx) QueryRow(query string, args ...interface{}) squirrel.RowScanner {
   225  	return tx.Tx.QueryRow(query, args...)
   226  }
   227  
   228  func (tx *dbTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) squirrel.RowScanner {
   229  	return tx.Tx.QueryRowContext(ctx, query, args...)
   230  }
   231  
   232  func (tx *dbTx) Commit() error {
   233  	defer tx.session.Release()
   234  	return tx.Tx.Commit()
   235  }
   236  
   237  func (tx *dbTx) Rollback() error {
   238  	defer tx.session.Release()
   239  	return tx.Tx.Rollback()
   240  }
   241  
   242  func (tx *dbTx) EncryptionStrategy() encryption.Strategy {
   243  	return tx.encryptionStrategy
   244  }
   245  
   246  // Rollback ignores errors, and should be used with defer.
   247  // makes errcheck happy that those errs are captured
   248  func Rollback(tx Tx) {
   249  	_ = tx.Rollback()
   250  }
   251  
   252  type NonOneRowAffectedError struct {
   253  	RowsAffected int64
   254  }
   255  
   256  func (err NonOneRowAffectedError) Error() string {
   257  	return fmt.Sprintf("expected 1 row to be updated; got %d", err.RowsAffected)
   258  }