github.com/blend/go-sdk@v1.20220411.3/testutil/connection.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package testutil 9 10 import ( 11 "context" 12 "errors" 13 "fmt" 14 "net" 15 "strings" 16 17 "github.com/jackc/pgconn" 18 19 "github.com/blend/go-sdk/db" 20 "github.com/blend/go-sdk/env" 21 "github.com/blend/go-sdk/ex" 22 ) 23 24 const ( 25 passwordText = "..password-redacted.." 26 requireDBErrorTemplate = `%s 27 %s 28 Connection String: 29 %q 30 ` 31 ) 32 33 func redactEnvironmentVariable(key, value string) string { 34 if key == db.EnvVarDBPassword { 35 return passwordText 36 } 37 38 if key == db.EnvVarDatabaseURL { 39 return createLoggingDSN(db.Config{DSN: value}) 40 } 41 42 return value 43 } 44 45 // allDBEnvironmentVariables returns a slice of all the environment variables 46 // used in `sdk/db/config.go::Config.Resolve()`. The accepted list of 47 // environment variables may change there over time so this hardcoded list 48 // may drift. 49 func allDBEnvironmentVariables() []string { 50 return []string{ 51 db.EnvVarDBEngine, 52 db.EnvVarDatabaseURL, 53 db.EnvVarDBHost, 54 db.EnvVarDBPort, 55 db.EnvVarDBName, 56 db.EnvVarDBSchema, 57 db.EnvVarDBApplicationName, 58 db.EnvVarDBUser, 59 db.EnvVarDBPassword, 60 db.EnvVarDBConnectTimeout, 61 db.EnvVarDBLockTimeout, 62 db.EnvVarDBStatementTimeout, 63 db.EnvVarDBSSLMode, 64 db.EnvVarDBIdleConnections, 65 db.EnvVarDBMaxConnections, 66 db.EnvVarDBMaxLifetime, 67 db.EnvVarDBBufferPoolSize, 68 db.EnvVarDBDialect, 69 } 70 } 71 72 func envResolveError(err error, ev env.Vars) error { 73 found := []string{} 74 for _, name := range allDBEnvironmentVariables() { 75 value, ok := ev[name] 76 if ok { 77 line := fmt.Sprintf("- %s=%q", name, redactEnvironmentVariable(name, value)) 78 found = append(found, line) 79 } 80 } 81 82 lines := []string{ 83 "Failed to read 'DB_*' environment variables. Error:", 84 "%s", 85 } 86 if len(found) > 0 { 87 lines = append(lines, "") 88 lines = append(lines, "Environment Variables:") 89 lines = append(lines, found...) 90 } 91 92 lines = append(lines, "") // Trailing newline 93 template := strings.Join(lines, "\n") 94 errString := fmt.Sprintf("%v", err) 95 return ex.Class(fmt.Sprintf(template, indentTwo(errString))) 96 } 97 98 // ResolveDBConfig is intended to be used to help debug issues resolving 99 // a `db.Config` from the environment. 100 // 101 // In the case of failure, this wraps the `Resolve()` error with a helpful 102 // message and a list of all relevant environment variables. 103 func ResolveDBConfig(ctx context.Context, c *db.Config) error { 104 ev := env.GetVars(ctx) 105 err := c.Resolve(ctx) 106 if err == nil { 107 return nil 108 } 109 110 return envResolveError(err, ev) 111 } 112 113 func indentTwo(s string) string { 114 lines := strings.Split(s, "\n") 115 indented := make([]string, len(lines)) 116 for i, line := range lines { 117 indented[i] = " " + line 118 } 119 return strings.Join(indented, "\n") 120 } 121 122 func getSQLErrorMessage(err error) *string { 123 errString := err.Error() 124 // NOTE: The string-munging is partially because `errors.errorString` is 125 // not exported. We could instead get around this by using `reflect` 126 // to get the underlying package and type name. Additionally, these 127 // errors may be wrapped in an `ex.Ex` as `Class`. 128 if strings.HasPrefix(errString, "sql: ") { 129 withoutPrefix := strings.TrimPrefix(errString, "sql: ") 130 return &withoutPrefix 131 } 132 133 return nil 134 } 135 136 // ValidatePool validates that 137 // - the connection string is valid 138 // - the selected `sql` driver can be used 139 // - a simple ping can be sent over the connection (is the DB reachable?) 140 // 141 // In the case of failure, this tries to diagnose the connection error and 142 // produce helpful tips on how to resolve. 143 func ValidatePool(ctx context.Context, pool *db.Connection, hints string) error { 144 if pool == nil { 145 return ex.New("Cannot validate a nil connection pool") 146 } 147 148 err := poolOpen(pool, hints) 149 if err != nil { 150 return err 151 } 152 153 return verifyConnect(ctx, pool, hints) 154 } 155 156 func formatKnownError(header, hints, dsn string) error { 157 return ex.Class(fmt.Sprintf(requireDBErrorTemplate, header, hints, dsn)) 158 } 159 160 func formatUnknownError(header, hints, dsn string) error { 161 return ex.New(fmt.Sprintf(requireDBErrorTemplate, header, hints, dsn)) 162 } 163 164 // poolOpen calls `Open()` to verify the connection string is valid and 165 // that the selected `sql` driver can be used. 166 func poolOpen(pool *db.Connection, hints string) error { 167 // Early exit if the connection is already open. 168 if pool.Connection != nil { 169 return nil 170 } 171 172 err := pool.Open() 173 if err == nil { 174 return nil 175 } 176 177 dsn := createLoggingDSN(pool.Config) 178 sqlErrorMessage := getSQLErrorMessage(err) 179 if sqlErrorMessage != nil { 180 header := fmt.Sprintf( 181 "Error from 'sql' package:\n %s\nDatabase Engine:\n %s", 182 *sqlErrorMessage, pool.Config.EngineOrDefault(), 183 ) 184 return formatKnownError(header, hints, dsn) 185 } 186 187 errString := fmt.Sprintf("%+v", err) 188 header := fmt.Sprintf("Unexpected Open() failure:\n%s", indentTwo(errString)) 189 return formatUnknownError(header, hints, dsn) 190 } 191 192 func unwrapNetOpError(err error) *net.OpError { 193 noe, ok := err.(*net.OpError) 194 if ok { 195 return noe 196 } 197 198 ue := errors.Unwrap(err) 199 noe, ok = ue.(*net.OpError) 200 if ok { 201 return noe 202 } 203 204 return nil 205 } 206 207 func isConnectionRefused(err error) bool { 208 noe := unwrapNetOpError(err) 209 if noe == nil { 210 return false 211 } 212 213 // NOTE: We could go deeper in here by type asserting `noe.Err` as an 214 // `*os.SyscallError` and checking for `syscall.ECONNREFUSED`. 215 // The string `connect: connection refused` has been verified in 216 // Go 1.12, 1.13, 1.14, 1.15 on macOS and Alpine Linux but may change 217 // in future releases. 218 return noe.Err.Error() == "connect: connection refused" 219 } 220 221 func getPGXErrorMessage(err error) *string { 222 pe, ok := err.(*pgconn.PgError) 223 if ok { 224 return &pe.Message 225 } 226 227 ue := errors.Unwrap(err) 228 pe, ok = ue.(*pgconn.PgError) 229 if ok { 230 return &pe.Message 231 } 232 233 errString := err.Error() 234 // NOTE: The string-munging is partially because `pgconn.connectError` is 235 // not exported. 236 if strings.HasPrefix(errString, "failed to connect to `host=") { 237 wrappedErrString := ue.Error() 238 return &wrappedErrString 239 } 240 241 return nil 242 } 243 244 // verifyConnect verifies that the target database is actually running and the 245 // connection pool can actually connect. 246 func verifyConnect(ctx context.Context, pool *db.Connection, hints string) error { 247 err := pool.Connection.PingContext(ctx) 248 if err == nil { 249 return nil 250 } 251 252 dsn := createLoggingDSN(pool.Config) 253 if isConnectionRefused(err) { 254 header := "Network error:\n Could not connect to database." 255 return formatKnownError(header, hints, dsn) 256 } 257 258 pgxErrorMessage := getPGXErrorMessage(err) 259 if pgxErrorMessage != nil { 260 header := fmt.Sprintf("PostgreSQL error when connecting to the database:\n %s", *pgxErrorMessage) 261 return formatKnownError(header, hints, dsn) 262 } 263 264 errString := fmt.Sprintf("%+v", err) 265 header := fmt.Sprintf("Unexpected PingContext() failure:\n%s", indentTwo(errString)) 266 return formatUnknownError(header, hints, dsn) 267 } 268 269 func createLoggingDSN(c db.Config) string { 270 if c.DSN != "" { 271 nc, err := db.NewConfigFromDSN(c.DSN) 272 if err != nil { 273 return "Failed to parse DSN: see DATABASE_URL environment variable" 274 } 275 return createLoggingDSN(nc) 276 } 277 278 dsn := c.CreateLoggingDSN() 279 if c.Username == "" || c.Password == "" { 280 return dsn 281 } 282 283 parts := strings.SplitN(dsn, "@", 2) 284 if len(parts) != 2 { 285 return dsn 286 } 287 288 return fmt.Sprintf("%s:%s@%s", parts[0], passwordText, parts[1]) 289 }