github.com/hashicorp/vault/sdk@v0.13.0/database/helper/connutil/sql.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package connutil 5 6 import ( 7 "context" 8 "database/sql" 9 "fmt" 10 "net/url" 11 "strings" 12 "sync" 13 "time" 14 15 "github.com/hashicorp/errwrap" 16 "github.com/hashicorp/go-secure-stdlib/parseutil" 17 "github.com/hashicorp/go-uuid" 18 "github.com/hashicorp/vault/sdk/database/dbplugin" 19 "github.com/hashicorp/vault/sdk/database/helper/dbutil" 20 "github.com/mitchellh/mapstructure" 21 ) 22 23 const ( 24 AuthTypeGCPIAM = "gcp_iam" 25 26 dbTypePostgres = "pgx" 27 cloudSQLPostgres = "cloudsql-postgres" 28 ) 29 30 var _ ConnectionProducer = &SQLConnectionProducer{} 31 32 // SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases 33 type SQLConnectionProducer struct { 34 ConnectionURL string `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"` 35 MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"` 36 MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"` 37 MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"` 38 Username string `json:"username" mapstructure:"username" structs:"username"` 39 Password string `json:"password" mapstructure:"password" structs:"password"` 40 AuthType string `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"` 41 ServiceAccountJSON string `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"` 42 DisableEscaping bool `json:"disable_escaping" mapstructure:"disable_escaping" structs:"disable_escaping"` 43 44 // cloud options here - cloudDriverName is globally unique, but only needs to be retained for the lifetime 45 // of driver registration, not across plugin restarts. 46 cloudDriverName string 47 cloudDialerCleanup func() error 48 49 Type string 50 RawConfig map[string]interface{} 51 maxConnectionLifetime time.Duration 52 Initialized bool 53 db *sql.DB 54 sync.Mutex 55 } 56 57 func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error { 58 _, err := c.Init(ctx, conf, verifyConnection) 59 return err 60 } 61 62 func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) { 63 c.Lock() 64 defer c.Unlock() 65 66 c.RawConfig = conf 67 68 err := mapstructure.WeakDecode(conf, &c) 69 if err != nil { 70 return nil, err 71 } 72 73 if len(c.ConnectionURL) == 0 { 74 return nil, fmt.Errorf("connection_url cannot be empty") 75 } 76 77 // Do not allow the username or password template pattern to be used as 78 // part of the user-supplied username or password 79 if strings.Contains(c.Username, "{{username}}") || 80 strings.Contains(c.Username, "{{password}}") || 81 strings.Contains(c.Password, "{{username}}") || 82 strings.Contains(c.Password, "{{password}}") { 83 84 return nil, fmt.Errorf("username and/or password cannot contain the template variables") 85 } 86 87 // Don't escape special characters for MySQL password 88 // Also don't escape special characters for the username and password if 89 // the disable_escaping parameter is set to true 90 username := c.Username 91 password := c.Password 92 if !c.DisableEscaping { 93 username = url.PathEscape(c.Username) 94 } 95 if (c.Type != "mysql") && !c.DisableEscaping { 96 password = url.PathEscape(c.Password) 97 } 98 99 // QueryHelper doesn't do any SQL escaping, but if it starts to do so 100 // then maybe we won't be able to use it to do URL substitution any more. 101 c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{ 102 "username": username, 103 "password": password, 104 }) 105 106 if c.MaxOpenConnections == 0 { 107 c.MaxOpenConnections = 4 108 } 109 110 if c.MaxIdleConnections == 0 { 111 c.MaxIdleConnections = c.MaxOpenConnections 112 } 113 if c.MaxIdleConnections > c.MaxOpenConnections { 114 c.MaxIdleConnections = c.MaxOpenConnections 115 } 116 if c.MaxConnectionLifetimeRaw == nil { 117 c.MaxConnectionLifetimeRaw = "0s" 118 } 119 120 c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw) 121 if err != nil { 122 return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err) 123 } 124 125 // validate auth_type if provided 126 authType := c.AuthType 127 if authType != "" { 128 if ok := ValidateAuthType(authType); !ok { 129 return nil, fmt.Errorf("invalid auth_type %s provided", authType) 130 } 131 } 132 133 if authType == AuthTypeGCPIAM { 134 c.cloudDriverName, err = uuid.GenerateUUID() 135 if err != nil { 136 return nil, fmt.Errorf("unable to generate UUID for IAM configuration: %w", err) 137 } 138 139 // for _most_ sql databases, the driver itself contains no state. In the case of google's cloudsql drivers, 140 // however, the driver might store a credentials file, in which case the state stored by the driver is in 141 // fact critical to the proper function of the connection. So it needs to be registered here inside the 142 // ConnectionProducer init. 143 dialerCleanup, err := c.registerDrivers(c.cloudDriverName, c.ServiceAccountJSON) 144 if err != nil { 145 return nil, err 146 } 147 148 c.cloudDialerCleanup = dialerCleanup 149 } 150 151 // Set initialized to true at this point since all fields are set, 152 // and the connection can be established at a later time. 153 c.Initialized = true 154 155 if verifyConnection { 156 if _, err := c.Connection(ctx); err != nil { 157 return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) 158 } 159 160 if err := c.db.PingContext(ctx); err != nil { 161 return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) 162 } 163 } 164 165 return c.RawConfig, nil 166 } 167 168 func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) { 169 if !c.Initialized { 170 return nil, ErrNotInitialized 171 } 172 173 // If we already have a DB, test it and return 174 if c.db != nil { 175 if err := c.db.PingContext(ctx); err == nil { 176 return c.db, nil 177 } 178 // If the ping was unsuccessful, close it and ignore errors as we'll be 179 // reestablishing anyways 180 c.db.Close() 181 182 // if IAM authentication is enabled 183 // ensure open dialer is also closed 184 if c.AuthType == AuthTypeGCPIAM { 185 if c.cloudDialerCleanup != nil { 186 c.cloudDialerCleanup() 187 } 188 } 189 } 190 191 // default non-IAM behavior 192 driverName := c.Type 193 194 if c.AuthType == AuthTypeGCPIAM { 195 driverName = c.cloudDriverName 196 } else if c.Type == "mssql" { 197 // For mssql backend, switch to sqlserver instead 198 driverName = "sqlserver" 199 } 200 201 // Otherwise, attempt to make connection 202 conn := c.ConnectionURL 203 204 // PostgreSQL specific settings 205 if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { 206 // Ensure timezone is set to UTC for all the connections 207 if strings.Contains(conn, "?") { 208 conn += "&timezone=UTC" 209 } else { 210 conn += "?timezone=UTC" 211 } 212 213 // Ensure a reasonable application_name is set 214 if !strings.Contains(conn, "application_name") { 215 conn += "&application_name=vault" 216 } 217 } 218 219 var err error 220 c.db, err = sql.Open(driverName, conn) 221 if err != nil { 222 return nil, err 223 } 224 225 // Set some connection pool settings. We don't need much of this, 226 // since the request rate shouldn't be high. 227 c.db.SetMaxOpenConns(c.MaxOpenConnections) 228 c.db.SetMaxIdleConns(c.MaxIdleConnections) 229 c.db.SetConnMaxLifetime(c.maxConnectionLifetime) 230 231 return c.db, nil 232 } 233 234 func (c *SQLConnectionProducer) SecretValues() map[string]interface{} { 235 return map[string]interface{}{ 236 c.Password: "[password]", 237 } 238 } 239 240 // Close attempts to close the connection 241 func (c *SQLConnectionProducer) Close() error { 242 // Grab the write lock 243 c.Lock() 244 defer c.Unlock() 245 246 if c.db != nil { 247 c.db.Close() 248 249 // cleanup IAM dialer if it exists 250 if c.AuthType == AuthTypeGCPIAM { 251 if c.cloudDialerCleanup != nil { 252 c.cloudDialerCleanup() 253 } 254 } 255 } 256 257 c.db = nil 258 259 return nil 260 } 261 262 // SetCredentials uses provided information to set/create a user in the 263 // database. Unlike CreateUser, this method requires a username be provided and 264 // uses the name given, instead of generating a name. This is used for creating 265 // and setting the password of static accounts, as well as rolling back 266 // passwords in the database in the event an updated database fails to save in 267 // Vault's storage. 268 func (c *SQLConnectionProducer) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { 269 return "", "", dbutil.Unimplemented() 270 }