github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/db/open.go (about) 1 package db 2 3 import ( 4 "context" 5 "database/sql" 6 "database/sql/driver" 7 "fmt" 8 "io" 9 "strings" 10 "time" 11 12 "code.cloudfoundry.org/lager" 13 "github.com/Masterminds/squirrel" 14 "github.com/pf-qiu/concourse/v6/atc/db/encryption" 15 "github.com/pf-qiu/concourse/v6/atc/db/lock" 16 "github.com/pf-qiu/concourse/v6/atc/db/migration" 17 multierror "github.com/hashicorp/go-multierror" 18 "github.com/lib/pq" 19 ) 20 21 //go:generate counterfeiter . Conn 22 23 type Conn interface { 24 Bus() NotificationsBus 25 EncryptionStrategy() encryption.Strategy 26 27 Ping() error 28 Driver() driver.Driver 29 30 Begin() (Tx, error) 31 Exec(string, ...interface{}) (sql.Result, error) 32 Prepare(string) (*sql.Stmt, error) 33 Query(string, ...interface{}) (*sql.Rows, error) 34 QueryRow(string, ...interface{}) squirrel.RowScanner 35 36 BeginTx(context.Context, *sql.TxOptions) (Tx, error) 37 ExecContext(context.Context, string, ...interface{}) (sql.Result, error) 38 PrepareContext(context.Context, string) (*sql.Stmt, error) 39 QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) 40 QueryRowContext(context.Context, string, ...interface{}) squirrel.RowScanner 41 42 SetMaxIdleConns(int) 43 SetMaxOpenConns(int) 44 Stats() sql.DBStats 45 46 Close() error 47 Name() string 48 } 49 50 //go:generate counterfeiter . Tx 51 52 type Tx interface { 53 Commit() error 54 Exec(string, ...interface{}) (sql.Result, error) 55 Prepare(string) (*sql.Stmt, error) 56 Query(string, ...interface{}) (*sql.Rows, error) 57 QueryRow(string, ...interface{}) squirrel.RowScanner 58 ExecContext(context.Context, string, ...interface{}) (sql.Result, error) 59 PrepareContext(context.Context, string) (*sql.Stmt, error) 60 QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) 61 QueryRowContext(context.Context, string, ...interface{}) squirrel.RowScanner 62 Rollback() error 63 Stmt(*sql.Stmt) *sql.Stmt 64 EncryptionStrategy() encryption.Strategy 65 } 66 67 func Open(logger lager.Logger, driver, dsn string, newKey, oldKey *encryption.Key, name string, lockFactory lock.LockFactory) (Conn, error) { 68 for { 69 sqlDB, err := migration.NewOpenHelper(driver, dsn, lockFactory, newKey, oldKey).Open() 70 if err != nil { 71 if shouldRetry(err) { 72 logger.Error("failed-to-open-db-retrying", err) 73 time.Sleep(5 * time.Second) 74 continue 75 } 76 77 return nil, err 78 } 79 80 return NewConn(name, sqlDB, dsn, oldKey, newKey), nil 81 } 82 } 83 84 func NewConn(name string, sqlDB *sql.DB, dsn string, oldKey, newKey *encryption.Key) Conn { 85 listener := pq.NewDialListener(keepAliveDialer{}, dsn, time.Second, time.Minute, nil) 86 87 var strategy encryption.Strategy 88 if newKey != nil { 89 strategy = newKey 90 } else { 91 strategy = encryption.NewNoEncryption() 92 } 93 94 return &db{ 95 DB: sqlDB, 96 97 bus: NewNotificationsBus(listener, sqlDB), 98 encryption: strategy, 99 name: name, 100 } 101 } 102 103 func shouldRetry(err error) bool { 104 if strings.Contains(err.Error(), "dial ") { 105 return true 106 } 107 108 if pqErr, ok := err.(*pq.Error); ok { 109 return pqErr.Code.Name() == "cannot_connect_now" 110 } 111 112 return false 113 } 114 115 type db struct { 116 *sql.DB 117 118 bus NotificationsBus 119 encryption encryption.Strategy 120 name string 121 } 122 123 func (db *db) Name() string { 124 return db.name 125 } 126 127 func (db *db) Bus() NotificationsBus { 128 return db.bus 129 } 130 131 func (db *db) EncryptionStrategy() encryption.Strategy { 132 return db.encryption 133 } 134 135 func (db *db) Close() error { 136 var errs error 137 dbErr := db.DB.Close() 138 if dbErr != nil { 139 errs = multierror.Append(errs, dbErr) 140 } 141 142 busErr := db.bus.Close() 143 if busErr != nil { 144 errs = multierror.Append(errs, busErr) 145 } 146 147 return errs 148 } 149 150 // Close ignores errors, and should used with defer. 151 // makes errcheck happy that those errs are captured 152 func Close(c io.Closer) { 153 _ = c.Close() 154 } 155 156 func (db *db) Begin() (Tx, error) { 157 tx, err := db.DB.Begin() 158 if err != nil { 159 return nil, err 160 } 161 162 return &dbTx{tx, GlobalConnectionTracker.Track(), db.EncryptionStrategy()}, nil 163 } 164 165 func (db *db) Exec(query string, args ...interface{}) (sql.Result, error) { 166 defer GlobalConnectionTracker.Track().Release() 167 return db.DB.Exec(query, args...) 168 } 169 170 func (db *db) Prepare(query string) (*sql.Stmt, error) { 171 defer GlobalConnectionTracker.Track().Release() 172 return db.DB.Prepare(query) 173 } 174 175 func (db *db) Query(query string, args ...interface{}) (*sql.Rows, error) { 176 defer GlobalConnectionTracker.Track().Release() 177 return db.DB.Query(query, args...) 178 } 179 180 // to conform to squirrel.Runner interface 181 func (db *db) QueryRow(query string, args ...interface{}) squirrel.RowScanner { 182 defer GlobalConnectionTracker.Track().Release() 183 return db.DB.QueryRow(query, args...) 184 } 185 186 func (db *db) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { 187 tx, err := db.DB.BeginTx(ctx, opts) 188 if err != nil { 189 return nil, err 190 } 191 192 return &dbTx{tx, GlobalConnectionTracker.Track(), db.EncryptionStrategy()}, nil 193 } 194 195 func (db *db) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { 196 defer GlobalConnectionTracker.Track().Release() 197 return db.DB.ExecContext(ctx, query, args...) 198 } 199 200 func (db *db) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { 201 defer GlobalConnectionTracker.Track().Release() 202 return db.DB.PrepareContext(ctx, query) 203 } 204 205 func (db *db) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { 206 defer GlobalConnectionTracker.Track().Release() 207 return db.DB.QueryContext(ctx, query, args...) 208 } 209 210 // to conform to squirrel.Runner interface 211 func (db *db) QueryRowContext(ctx context.Context, query string, args ...interface{}) squirrel.RowScanner { 212 defer GlobalConnectionTracker.Track().Release() 213 return db.DB.QueryRowContext(ctx, query, args...) 214 } 215 216 type dbTx struct { 217 *sql.Tx 218 219 session *ConnectionSession 220 encryptionStrategy encryption.Strategy 221 } 222 223 // to conform to squirrel.Runner interface 224 func (tx *dbTx) QueryRow(query string, args ...interface{}) squirrel.RowScanner { 225 return tx.Tx.QueryRow(query, args...) 226 } 227 228 func (tx *dbTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) squirrel.RowScanner { 229 return tx.Tx.QueryRowContext(ctx, query, args...) 230 } 231 232 func (tx *dbTx) Commit() error { 233 defer tx.session.Release() 234 return tx.Tx.Commit() 235 } 236 237 func (tx *dbTx) Rollback() error { 238 defer tx.session.Release() 239 return tx.Tx.Rollback() 240 } 241 242 func (tx *dbTx) EncryptionStrategy() encryption.Strategy { 243 return tx.encryptionStrategy 244 } 245 246 // Rollback ignores errors, and should be used with defer. 247 // makes errcheck happy that those errs are captured 248 func Rollback(tx Tx) { 249 _ = tx.Rollback() 250 } 251 252 type NonOneRowAffectedError struct { 253 RowsAffected int64 254 } 255 256 func (err NonOneRowAffectedError) Error() string { 257 return fmt.Sprintf("expected 1 row to be updated; got %d", err.RowsAffected) 258 }