github.com/blend/go-sdk@v1.20220411.3/db/config.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package db 9 10 import ( 11 "context" 12 "fmt" 13 "net/url" 14 "strconv" 15 "strings" 16 "time" 17 18 "github.com/blend/go-sdk/configutil" 19 "github.com/blend/go-sdk/env" 20 "github.com/blend/go-sdk/ex" 21 "github.com/blend/go-sdk/stringutil" 22 ) 23 24 // NewConfigFromDSN creates a new config from a DSN. 25 // Errors can be produced by parsing the DSN. 26 func NewConfigFromDSN(dsn string) (config Config, err error) { 27 parsed, parseErr := ParseURL(dsn) 28 if parseErr != nil { 29 err = ex.New(parseErr) 30 return 31 } 32 33 pieces := stringutil.SplitSpace(parsed) 34 for _, piece := range pieces { 35 if strings.HasPrefix(piece, "host=") { 36 config.Host = strings.TrimPrefix(piece, "host=") 37 } else if strings.HasPrefix(piece, "port=") { 38 config.Port = strings.TrimPrefix(piece, "port=") 39 } else if strings.HasPrefix(piece, "dbname=") { 40 config.Database = strings.TrimPrefix(piece, "dbname=") 41 } else if strings.HasPrefix(piece, "user=") { 42 config.Username = strings.TrimPrefix(piece, "user=") 43 } else if strings.HasPrefix(piece, "password=") { 44 config.Password = strings.TrimPrefix(piece, "password=") 45 } else if strings.HasPrefix(piece, "sslmode=") { 46 config.SSLMode = strings.TrimPrefix(piece, "sslmode=") 47 } else if strings.HasPrefix(piece, "search_path=") { 48 config.Schema = strings.TrimPrefix(piece, "search_path=") 49 } else if strings.HasPrefix(piece, "application_name=") { 50 config.ApplicationName = strings.TrimPrefix(piece, "application_name=") 51 } else if strings.HasPrefix(piece, "connect_timeout=") { 52 timeout, parseErr := strconv.Atoi(strings.TrimPrefix(piece, "connect_timeout=")) 53 if parseErr != nil { 54 err = ex.New(parseErr, ex.OptMessage("field: connect_timeout")) 55 return 56 } 57 config.ConnectTimeout = time.Second * time.Duration(timeout) 58 } else if strings.HasPrefix(piece, "lock_timeout=") { 59 config.LockTimeout, parseErr = time.ParseDuration(strings.TrimPrefix(piece, "lock_timeout=")) 60 if parseErr != nil { 61 err = ex.New(parseErr, ex.OptMessage("field: lock_timeout")) 62 return 63 } 64 } else if strings.HasPrefix(piece, "statement_timeout=") { 65 config.StatementTimeout, parseErr = time.ParseDuration(strings.TrimPrefix(piece, "statement_timeout=")) 66 if parseErr != nil { 67 err = ex.New(parseErr, ex.OptMessage("field: statement_timeout")) 68 return 69 } 70 } 71 } 72 73 return 74 } 75 76 // NewConfigFromEnv returns a new config from the environment. 77 // The environment variable mappings are as follows: 78 // - DB_ENGINE = Engine 79 // - DATABASE_URL = DSN //note that this has precedence over other vars (!!) 80 // - DB_HOST = Host 81 // - DB_PORT = Port 82 // - DB_NAME = Database 83 // - DB_SCHEMA = Schema 84 // - DB_APPLICATION_NAME = ApplicationName 85 // - DB_USER = Username 86 // - DB_PASSWORD = Password 87 // - DB_CONNECT_TIMEOUT = ConnectTimeout 88 // - DB_LOCK_TIMEOUT = LockTimeout 89 // - DB_STATEMENT_TIMEOUT = StatementTimeout 90 // - DB_SSLMODE = SSLMode 91 // - DB_IDLE_CONNECTIONS = IdleConnections 92 // - DB_MAX_CONNECTIONS = MaxConnections 93 // - DB_MAX_LIFETIME = MaxLifetime 94 // - DB_BUFFER_POOL_SIZE = BufferPoolSize 95 // - DB_DIALECT = Dialect 96 func NewConfigFromEnv() (config Config, err error) { 97 if err = (&config).Resolve(env.WithVars(context.Background(), env.Env())); err != nil { 98 return 99 } 100 return 101 } 102 103 // MustNewConfigFromEnv returns a new config from the environment, 104 // it will panic if there is an error. 105 func MustNewConfigFromEnv() Config { 106 cfg, err := NewConfigFromEnv() 107 if err != nil { 108 panic(err) 109 } 110 return cfg 111 } 112 113 // Config is a set of connection config options. 114 type Config struct { 115 // Engine is the database engine. 116 Engine string `json:"engine,omitempty" yaml:"engine,omitempty" env:"DB_ENGINE"` 117 // DSN is a fully formed DSN (this skips DSN formation from all other variables outside `schema`). 118 DSN string `json:"dsn,omitempty" yaml:"dsn,omitempty" env:"DATABASE_URL"` 119 // Host is the server to connect to. 120 Host string `json:"host,omitempty" yaml:"host,omitempty" env:"DB_HOST"` 121 // Port is the port to connect to. 122 Port string `json:"port,omitempty" yaml:"port,omitempty" env:"DB_PORT"` 123 // DBName is the database name 124 Database string `json:"database,omitempty" yaml:"database,omitempty" env:"DB_NAME"` 125 // Schema is the application schema within the database, defaults to `public`. This schema is used to set the 126 // Postgres "search_path" If you want to reference tables in other schemas, you'll need to specify those schemas 127 // in your queries e.g. "SELECT * FROM schema_two.table_one..." 128 // Using the public schema in a production application is considered bad practice as newly created roles will have 129 // visibility into this data by default. We strongly recommend specifying this option and using a schema that is 130 // owned by your service's role 131 // We recommend against setting a multi-schema search_path, but if you really want to, you provide multiple comma- 132 // separated schema names as the value for this config, or you can dbc.Invoke().Exec a SET statement on a newly 133 // opened connection such as "SET search_path = 'schema_one,schema_two';" Again, we recommend against this practice 134 // and encourage you to specify schema names beyond the first in your queries. 135 Schema string `json:"schema,omitempty" yaml:"schema,omitempty" env:"DB_SCHEMA"` 136 // ApplicationName is the name set by an application connection to a database 137 // server, intended to be transmitted in the connection string. It can be 138 // used to uniquely identify an application and will be included in the 139 // `pg_stat_activity` view. 140 // 141 // See: https://www.postgresql.org/docs/12/runtime-config-logging.html#GUC-APPLICATION-NAME 142 ApplicationName string `json:"applicationName,omitempty" yaml:"applicationName,omitempty" env:"DB_APPLICATION_NAME"` 143 // Username is the username for the connection via password auth. 144 Username string `json:"username,omitempty" yaml:"username,omitempty" env:"DB_USER"` 145 // Password is the password for the connection via password auth. 146 Password string `json:"password,omitempty" yaml:"password,omitempty" env:"DB_PASSWORD"` 147 // ConnectTimeout determines the maximum wait for connection. The minimum 148 // allowed timeout is 2 seconds, so anything below is treated the same 149 // as unset. PostgreSQL will only accept second precision so this value will be 150 // rounded to the nearest second before being set on a connection string. 151 // Use `Validate()` to confirm that `ConnectTimeout` is exact to second 152 // precision. 153 // 154 // See: https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-CONNECT-CONNECT-TIMEOUT 155 ConnectTimeout time.Duration `json:"connectTimeout,omitempty" yaml:"connectTimeout,omitempty" env:"DB_CONNECT_TIMEOUT"` 156 // LockTimeout is the timeout to use when attempting to acquire a lock. 157 // PostgreSQL will only accept millisecond precision so this value will be 158 // rounded to the nearest millisecond before being set on a connection string. 159 // Use `Validate()` to confirm that `LockTimeout` is exact to millisecond 160 // precision. 161 // 162 // See: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-LOCK-TIMEOUT 163 LockTimeout time.Duration `json:"lockTimeout,omitempty" yaml:"lockTimeout,omitempty" env:"DB_LOCK_TIMEOUT"` 164 // StatementTimeout is the timeout to use when invoking a SQL statement. 165 // PostgreSQL will only accept millisecond precision so this value will be 166 // rounded to the nearest millisecond before being set on a connection string. 167 // Use `Validate()` to confirm that `StatementTimeout` is exact to millisecond 168 // precision. 169 // 170 // See: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT 171 StatementTimeout time.Duration `json:"statementTimeout,omitempty" yaml:"statementTimeout,omitempty" env:"DB_STATEMENT_TIMEOUT"` 172 // SSLMode is the sslmode for the connection. 173 SSLMode string `json:"sslMode,omitempty" yaml:"sslMode,omitempty" env:"DB_SSLMODE"` 174 // IdleConnections is the number of idle connections. 175 IdleConnections int `json:"idleConnections,omitempty" yaml:"idleConnections,omitempty" env:"DB_IDLE_CONNECTIONS"` 176 // MaxConnections is the maximum number of connections. 177 MaxConnections int `json:"maxConnections,omitempty" yaml:"maxConnections,omitempty" env:"DB_MAX_CONNECTIONS"` 178 // MaxLifetime is the maximum time a connection can be open. 179 MaxLifetime time.Duration `json:"maxLifetime,omitempty" yaml:"maxLifetime,omitempty" env:"DB_MAX_LIFETIME"` 180 // MaxIdleTime is the maximum time a connection can be idle. 181 MaxIdleTime time.Duration `json:"maxIdleTime,omitempty" yaml:"maxIdleTime,omitempty" env:"DB_MAX_IDLE_TIME"` 182 // BufferPoolSize is the number of query composition buffers to maintain. 183 BufferPoolSize int `json:"bufferPoolSize,omitempty" yaml:"bufferPoolSize,omitempty" env:"DB_BUFFER_POOL_SIZE"` 184 // Dialect includes hints to tweak specific sql semantics by database connection. 185 Dialect string `json:"dialect,omitempty" yaml:"dialect,omitempty" env:"DB_DIALECT"` 186 } 187 188 // IsZero returns if the config is unset. 189 func (c Config) IsZero() bool { 190 return c.DSN == "" && c.Host == "" && c.Port == "" && c.Database == "" && c.Schema == "" && c.Username == "" && c.Password == "" && c.SSLMode == "" 191 } 192 193 // Resolve applies any external data sources to the config. 194 func (c *Config) Resolve(ctx context.Context) error { 195 return configutil.Resolve(ctx, 196 configutil.SetString(&c.Engine, configutil.Env(EnvVarDBEngine), configutil.String(c.Engine), configutil.String(DefaultEngine)), 197 configutil.SetString(&c.DSN, configutil.Env(EnvVarDatabaseURL), configutil.String(c.DSN)), 198 configutil.SetString(&c.Host, configutil.Env(EnvVarDBHost), configutil.String(c.Host), configutil.String(DefaultHost)), 199 configutil.SetString(&c.Port, configutil.Env(EnvVarDBPort), configutil.String(c.Port), configutil.String(DefaultPort)), 200 configutil.SetString(&c.Database, configutil.Env(EnvVarDBName), configutil.String(c.Database), configutil.String(DefaultDatabase)), 201 configutil.SetString(&c.Schema, configutil.Env(EnvVarDBSchema), configutil.String(c.Schema)), 202 configutil.SetString(&c.ApplicationName, configutil.Env(EnvVarDBApplicationName), configutil.String(c.ApplicationName)), 203 configutil.SetString(&c.Username, configutil.Env(EnvVarDBUser), configutil.String(c.Username), configutil.Env("USER")), 204 configutil.SetString(&c.Password, configutil.Env(EnvVarDBPassword), configutil.String(c.Password)), 205 configutil.SetDuration(&c.ConnectTimeout, configutil.Env(EnvVarDBConnectTimeout), configutil.Duration(c.ConnectTimeout), configutil.Duration(DefaultConnectTimeout)), 206 configutil.SetDuration(&c.LockTimeout, configutil.Env(EnvVarDBLockTimeout), configutil.Duration(c.LockTimeout)), 207 configutil.SetDuration(&c.StatementTimeout, configutil.Env(EnvVarDBStatementTimeout), configutil.Duration(c.StatementTimeout)), 208 configutil.SetString(&c.SSLMode, configutil.Env(EnvVarDBSSLMode), configutil.String(c.SSLMode)), 209 configutil.SetInt(&c.IdleConnections, configutil.Env(EnvVarDBIdleConnections), configutil.Int(c.IdleConnections), configutil.Int(DefaultIdleConnections)), 210 configutil.SetInt(&c.MaxConnections, configutil.Env(EnvVarDBMaxConnections), configutil.Int(c.MaxConnections), configutil.Int(DefaultMaxConnections)), 211 configutil.SetDuration(&c.MaxLifetime, configutil.Env(EnvVarDBMaxLifetime), configutil.Duration(c.MaxLifetime), configutil.Duration(DefaultMaxLifetime)), 212 configutil.SetDuration(&c.MaxIdleTime, configutil.Env(EnvVarDBMaxIdleTime), configutil.Duration(c.MaxIdleTime), configutil.Duration(DefaultMaxIdleTime)), 213 configutil.SetInt(&c.BufferPoolSize, configutil.Env(EnvVarDBBufferPoolSize), configutil.Int(c.BufferPoolSize), configutil.Int(DefaultBufferPoolSize)), 214 configutil.SetString(&c.Dialect, configutil.Env(EnvVarDBDialect), configutil.String(c.Dialect), configutil.String(DialectPostgres)), 215 ) 216 } 217 218 // Reparse creates a DSN and reparses it, in case some values need to be coalesced. 219 func (c Config) Reparse() (Config, error) { 220 cfg, err := NewConfigFromDSN(c.CreateDSN()) 221 if err != nil { 222 return Config{}, err 223 } 224 225 cfg.IdleConnections = c.IdleConnections 226 cfg.MaxConnections = c.MaxConnections 227 cfg.BufferPoolSize = c.BufferPoolSize 228 cfg.MaxLifetime = c.MaxLifetime 229 230 return cfg, nil 231 } 232 233 // MustReparse creates a DSN and reparses it, in case some values need to be coalesced, 234 // and panics if there is an error. 235 func (c Config) MustReparse() Config { 236 cfg, err := NewConfigFromDSN(c.CreateDSN()) 237 if err != nil { 238 panic(err) 239 } 240 return cfg 241 } 242 243 // EngineOrDefault returns the database engine. 244 func (c Config) EngineOrDefault() string { 245 if c.Engine != "" { 246 return c.Engine 247 } 248 return DefaultEngine 249 } 250 251 // HostOrDefault returns the postgres host for the connection or a default. 252 func (c Config) HostOrDefault() string { 253 if c.Host != "" { 254 return c.Host 255 } 256 return DefaultHost 257 } 258 259 // PortOrDefault returns the port for a connection if it is not the standard postgres port. 260 func (c Config) PortOrDefault() string { 261 if c.Port != "" { 262 return c.Port 263 } 264 return DefaultPort 265 } 266 267 // DatabaseOrDefault returns the connection database or a default. 268 func (c Config) DatabaseOrDefault() string { 269 if c.Database != "" { 270 return c.Database 271 } 272 return DefaultDatabase 273 } 274 275 // SchemaOrDefault returns the schema on the search_path or the default ("public"). It's considered bad practice to 276 // use the public schema in production 277 func (c Config) SchemaOrDefault() string { 278 if c.Schema != "" { 279 return c.Schema 280 } 281 return DefaultSchema 282 } 283 284 // IdleConnectionsOrDefault returns the number of idle connections or a default. 285 func (c Config) IdleConnectionsOrDefault() int { 286 if c.IdleConnections > 0 { 287 return c.IdleConnections 288 } 289 return DefaultIdleConnections 290 } 291 292 // MaxConnectionsOrDefault returns the maximum number of connections or a default. 293 func (c Config) MaxConnectionsOrDefault() int { 294 if c.MaxConnections > 0 { 295 return c.MaxConnections 296 } 297 return DefaultMaxConnections 298 } 299 300 // MaxLifetimeOrDefault returns the maximum lifetime of a driver connection. 301 func (c Config) MaxLifetimeOrDefault() time.Duration { 302 if c.MaxLifetime > 0 { 303 return c.MaxLifetime 304 } 305 return DefaultMaxLifetime 306 } 307 308 // MaxIdleTimeOrDefault returns the maximum idle time of a driver connection. 309 func (c Config) MaxIdleTimeOrDefault() time.Duration { 310 if c.MaxIdleTime > 0 { 311 return c.MaxIdleTime 312 } 313 return DefaultMaxIdleTime 314 } 315 316 // BufferPoolSizeOrDefault returns the number of query buffers to maintain or a default. 317 func (c Config) BufferPoolSizeOrDefault() int { 318 if c.BufferPoolSize > 0 { 319 return c.BufferPoolSize 320 } 321 return DefaultBufferPoolSize 322 } 323 324 // DialectOrDefault returns the sql dialect or a default. 325 func (c Config) DialectOrDefault() Dialect { 326 if c.Dialect != "" { 327 return Dialect(c.Dialect) 328 } 329 return DialectPostgres 330 } 331 332 // CreateDSN creates a postgres connection string from the config. 333 func (c Config) CreateDSN() string { 334 if c.DSN != "" { 335 return c.DSN 336 } 337 338 host := c.HostOrDefault() 339 if c.PortOrDefault() != "" { 340 host = host + ":" + c.PortOrDefault() 341 } 342 343 dsn := &url.URL{ 344 Scheme: "postgres", 345 Host: host, 346 Path: c.DatabaseOrDefault(), 347 } 348 349 if len(c.Username) > 0 { 350 if len(c.Password) > 0 { 351 dsn.User = url.UserPassword(c.Username, c.Password) 352 } else { 353 dsn.User = url.User(c.Username) 354 } 355 } 356 357 queryArgs := url.Values{} 358 if len(c.SSLMode) > 0 { 359 queryArgs.Add("sslmode", c.SSLMode) 360 } 361 if c.ConnectTimeout > 0 { 362 setTimeoutSeconds(queryArgs, "connect_timeout", c.ConnectTimeout) 363 } 364 if c.LockTimeout > 0 { 365 setTimeoutMilliseconds(queryArgs, "lock_timeout", c.LockTimeout) 366 } 367 if c.StatementTimeout > 0 { 368 setTimeoutMilliseconds(queryArgs, "statement_timeout", c.StatementTimeout) 369 } 370 if c.Schema != "" { 371 queryArgs.Add("search_path", c.Schema) 372 } 373 if c.ApplicationName != "" { 374 queryArgs.Add("application_name", c.ApplicationName) 375 } 376 377 dsn.RawQuery = queryArgs.Encode() 378 return dsn.String() 379 } 380 381 // CreateLoggingDSN creates a postgres connection string from the config suitable for logging. 382 // It will not include the password. 383 func (c Config) CreateLoggingDSN() string { 384 if c.DSN != "" { 385 nc, err := NewConfigFromDSN(c.DSN) 386 if err != nil { 387 return "Failed to parse DSN: see DATABASE_URL environment variable" 388 } 389 return nc.CreateLoggingDSN() 390 } 391 392 // NOTE: Since `c` is a value receiver, we can modify it without 393 // mutating the actual value. 394 c.Password = "" 395 return c.CreateDSN() 396 } 397 398 // Validate validates that user-provided values are valid, e.g. that timeouts 399 // can be exactly rounded into a multiple of a given base value. 400 func (c Config) Validate() error { 401 if c.ConnectTimeout.Round(time.Second) != c.ConnectTimeout { 402 return ex.New(ErrDurationConversion, ex.OptMessagef("connect_timeout=%s", c.ConnectTimeout)) 403 } 404 if c.LockTimeout.Round(time.Millisecond) != c.LockTimeout { 405 return ex.New(ErrDurationConversion, ex.OptMessagef("lock_timeout=%s", c.LockTimeout)) 406 } 407 if c.StatementTimeout.Round(time.Millisecond) != c.StatementTimeout { 408 return ex.New(ErrDurationConversion, ex.OptMessagef("statement_timeout=%s", c.StatementTimeout)) 409 } 410 411 return nil 412 } 413 414 // ValidateProduction validates production configuration for the config. 415 func (c Config) ValidateProduction() error { 416 if !(len(c.SSLMode) == 0 || 417 stringutil.EqualsCaseless(c.SSLMode, SSLModeRequire) || 418 stringutil.EqualsCaseless(c.SSLMode, SSLModeVerifyCA) || 419 stringutil.EqualsCaseless(c.SSLMode, SSLModeVerifyFull)) { 420 return ex.New(ErrUnsafeSSLMode, ex.OptMessagef("sslmode: %s", c.SSLMode)) 421 } 422 if len(c.Username) == 0 { 423 return ex.New(ErrUsernameUnset) 424 } 425 if len(c.Password) == 0 { 426 return ex.New(ErrPasswordUnset) 427 } 428 return c.Validate() 429 } 430 431 // setTimeoutMilliseconds sets a timeout value in connection string query parameters. 432 // 433 // Valid units for this parameter in PostgresSQL are "ms", "s", "min", "h" 434 // and "d" and the value should be between 0 and 2147483647ms. We explicitly 435 // cast to milliseconds but leave validation on the value to PostgreSQL. 436 // 437 // blend=> BEGIN; 438 // BEGIN 439 // blend=> SET LOCAL lock_timeout TO '4000ms'; 440 // SET 441 // blend=> SHOW lock_timeout; 442 // lock_timeout 443 // -------------- 444 // 4s 445 // (1 row) 446 // -- 447 // blend=> SET LOCAL lock_timeout TO '4500ms'; 448 // SET 449 // blend=> SHOW lock_timeout; 450 // lock_timeout 451 // -------------- 452 // 4500ms 453 // (1 row) 454 // -- 455 // blend=> SET LOCAL lock_timeout = 'go'; 456 // ERROR: invalid value for parameter "lock_timeout": "go" 457 // blend=> SET LOCAL lock_timeout = '1ns'; 458 // ERROR: invalid value for parameter "lock_timeout": "1ns" 459 // HINT: Valid units for this parameter are "ms", "s", "min", "h", and "d". 460 // blend=> SET LOCAL lock_timeout = '-1ms'; 461 // ERROR: -1 is outside the valid range for parameter "lock_timeout" (0 .. 2147483647) 462 // -- 463 // blend=> COMMIT; 464 // COMMIT 465 // 466 // See: 467 // - https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-LOCK-TIMEOUT 468 // - https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT 469 func setTimeoutMilliseconds(q url.Values, name string, d time.Duration) { 470 ms := d.Round(time.Millisecond) / time.Millisecond 471 q.Add(name, fmt.Sprintf("%dms", ms)) 472 } 473 474 // setTimeoutSeconds sets a timeout value in connection string query parameters. 475 // 476 // This timeout is expected to be an exact number of seconds (as an integer) 477 // so we convert `d` to an integer first and set the value as a query parameter 478 // without units. 479 // 480 // See: 481 // - https://www.postgresql.org/docs/10/libpq-connect.html#LIBPQ-CONNECT-CONNECT-TIMEOUT 482 func setTimeoutSeconds(q url.Values, name string, d time.Duration) { 483 s := d.Round(time.Second) / time.Second 484 q.Add(name, fmt.Sprintf("%d", s)) 485 }