github.com/hashicorp/vault/sdk@v0.13.0/database/helper/connutil/sql.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package connutil
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"fmt"
    10  	"net/url"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/hashicorp/errwrap"
    16  	"github.com/hashicorp/go-secure-stdlib/parseutil"
    17  	"github.com/hashicorp/go-uuid"
    18  	"github.com/hashicorp/vault/sdk/database/dbplugin"
    19  	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
    20  	"github.com/mitchellh/mapstructure"
    21  )
    22  
    23  const (
    24  	AuthTypeGCPIAM = "gcp_iam"
    25  
    26  	dbTypePostgres   = "pgx"
    27  	cloudSQLPostgres = "cloudsql-postgres"
    28  )
    29  
    30  var _ ConnectionProducer = &SQLConnectionProducer{}
    31  
// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases.
//
// The exported, tagged fields are decoded from the plugin configuration map by
// Init. Init also fills in defaults and substitutes the credentials into
// ConnectionURL. The *sql.DB pool itself is created lazily by Connection.
// Init and Close acquire the embedded mutex; Connection does not.
type SQLConnectionProducer struct {
	// ConnectionURL is the driver connection string. Init replaces the
	// {{username}} and {{password}} template variables in it.
	ConnectionURL            string      `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
	// MaxOpenConnections caps the pool size. Init defaults it to 4 when zero.
	MaxOpenConnections       int         `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
	// MaxIdleConnections defaults to MaxOpenConnections when zero and is
	// capped at MaxOpenConnections.
	MaxIdleConnections       int         `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"`
	// MaxConnectionLifetimeRaw is the configured lifetime, in any form accepted
	// by parseutil.ParseDurationSecond. Init treats nil as "0s".
	MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
	// Username and Password are substituted into ConnectionURL. Both are URL
	// path-escaped unless DisableEscaping is set, and the password is never
	// escaped when Type is "mysql".
	Username                 string      `json:"username" mapstructure:"username" structs:"username"`
	Password                 string      `json:"password" mapstructure:"password" structs:"password"`
	// AuthType optionally selects an authentication mode. Init checks it with
	// ValidateAuthType, and AuthTypeGCPIAM triggers Cloud SQL driver registration.
	AuthType                 string      `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"`
	// ServiceAccountJSON is handed to the Cloud SQL driver registration when
	// AuthType is AuthTypeGCPIAM.
	ServiceAccountJSON       string      `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"`
	// DisableEscaping turns off URL escaping of Username and Password.
	DisableEscaping          bool        `json:"disable_escaping" mapstructure:"disable_escaping" structs:"disable_escaping"`

	// cloud options here - cloudDriverName is globally unique, but only needs to be retained for the lifetime
	// of driver registration, not across plugin restarts.
	cloudDriverName    string
	// cloudDialerCleanup is the cleanup function returned by driver
	// registration. Init only assigns it when AuthType is AuthTypeGCPIAM.
	cloudDialerCleanup func() error

	// Type is the database type. For non-IAM auth it is used as the sql driver
	// name, except that "mssql" is opened with the "sqlserver" driver. It also
	// disables password escaping for "mysql".
	Type                  string
	// RawConfig is the configuration map most recently passed to Init. Init
	// returns it on success.
	RawConfig             map[string]interface{}
	// maxConnectionLifetime is the parsed form of MaxConnectionLifetimeRaw.
	maxConnectionLifetime time.Duration
	// Initialized is set by Init once all fields are populated. Connection
	// returns ErrNotInitialized until then.
	Initialized           bool
	// db is the cached connection pool opened by Connection.
	db                    *sql.DB
	sync.Mutex
}
    56  
    57  func (c *SQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
    58  	_, err := c.Init(ctx, conf, verifyConnection)
    59  	return err
    60  }
    61  
    62  func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
    63  	c.Lock()
    64  	defer c.Unlock()
    65  
    66  	c.RawConfig = conf
    67  
    68  	err := mapstructure.WeakDecode(conf, &c)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	if len(c.ConnectionURL) == 0 {
    74  		return nil, fmt.Errorf("connection_url cannot be empty")
    75  	}
    76  
    77  	// Do not allow the username or password template pattern to be used as
    78  	// part of the user-supplied username or password
    79  	if strings.Contains(c.Username, "{{username}}") ||
    80  		strings.Contains(c.Username, "{{password}}") ||
    81  		strings.Contains(c.Password, "{{username}}") ||
    82  		strings.Contains(c.Password, "{{password}}") {
    83  
    84  		return nil, fmt.Errorf("username and/or password cannot contain the template variables")
    85  	}
    86  
    87  	// Don't escape special characters for MySQL password
    88  	// Also don't escape special characters for the username and password if
    89  	// the disable_escaping parameter is set to true
    90  	username := c.Username
    91  	password := c.Password
    92  	if !c.DisableEscaping {
    93  		username = url.PathEscape(c.Username)
    94  	}
    95  	if (c.Type != "mysql") && !c.DisableEscaping {
    96  		password = url.PathEscape(c.Password)
    97  	}
    98  
    99  	// QueryHelper doesn't do any SQL escaping, but if it starts to do so
   100  	// then maybe we won't be able to use it to do URL substitution any more.
   101  	c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
   102  		"username": username,
   103  		"password": password,
   104  	})
   105  
   106  	if c.MaxOpenConnections == 0 {
   107  		c.MaxOpenConnections = 4
   108  	}
   109  
   110  	if c.MaxIdleConnections == 0 {
   111  		c.MaxIdleConnections = c.MaxOpenConnections
   112  	}
   113  	if c.MaxIdleConnections > c.MaxOpenConnections {
   114  		c.MaxIdleConnections = c.MaxOpenConnections
   115  	}
   116  	if c.MaxConnectionLifetimeRaw == nil {
   117  		c.MaxConnectionLifetimeRaw = "0s"
   118  	}
   119  
   120  	c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
   121  	if err != nil {
   122  		return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err)
   123  	}
   124  
   125  	// validate auth_type if provided
   126  	authType := c.AuthType
   127  	if authType != "" {
   128  		if ok := ValidateAuthType(authType); !ok {
   129  			return nil, fmt.Errorf("invalid auth_type %s provided", authType)
   130  		}
   131  	}
   132  
   133  	if authType == AuthTypeGCPIAM {
   134  		c.cloudDriverName, err = uuid.GenerateUUID()
   135  		if err != nil {
   136  			return nil, fmt.Errorf("unable to generate UUID for IAM configuration: %w", err)
   137  		}
   138  
   139  		// for _most_ sql databases, the driver itself contains no state. In the case of google's cloudsql drivers,
   140  		// however, the driver might store a credentials file, in which case the state stored by the driver is in
   141  		// fact critical to the proper function of the connection. So it needs to be registered here inside the
   142  		// ConnectionProducer init.
   143  		dialerCleanup, err := c.registerDrivers(c.cloudDriverName, c.ServiceAccountJSON)
   144  		if err != nil {
   145  			return nil, err
   146  		}
   147  
   148  		c.cloudDialerCleanup = dialerCleanup
   149  	}
   150  
   151  	// Set initialized to true at this point since all fields are set,
   152  	// and the connection can be established at a later time.
   153  	c.Initialized = true
   154  
   155  	if verifyConnection {
   156  		if _, err := c.Connection(ctx); err != nil {
   157  			return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
   158  		}
   159  
   160  		if err := c.db.PingContext(ctx); err != nil {
   161  			return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
   162  		}
   163  	}
   164  
   165  	return c.RawConfig, nil
   166  }
   167  
   168  func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
   169  	if !c.Initialized {
   170  		return nil, ErrNotInitialized
   171  	}
   172  
   173  	// If we already have a DB, test it and return
   174  	if c.db != nil {
   175  		if err := c.db.PingContext(ctx); err == nil {
   176  			return c.db, nil
   177  		}
   178  		// If the ping was unsuccessful, close it and ignore errors as we'll be
   179  		// reestablishing anyways
   180  		c.db.Close()
   181  
   182  		// if IAM authentication is enabled
   183  		// ensure open dialer is also closed
   184  		if c.AuthType == AuthTypeGCPIAM {
   185  			if c.cloudDialerCleanup != nil {
   186  				c.cloudDialerCleanup()
   187  			}
   188  		}
   189  	}
   190  
   191  	// default non-IAM behavior
   192  	driverName := c.Type
   193  
   194  	if c.AuthType == AuthTypeGCPIAM {
   195  		driverName = c.cloudDriverName
   196  	} else if c.Type == "mssql" {
   197  		// For mssql backend, switch to sqlserver instead
   198  		driverName = "sqlserver"
   199  	}
   200  
   201  	// Otherwise, attempt to make connection
   202  	conn := c.ConnectionURL
   203  
   204  	// PostgreSQL specific settings
   205  	if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
   206  		// Ensure timezone is set to UTC for all the connections
   207  		if strings.Contains(conn, "?") {
   208  			conn += "&timezone=UTC"
   209  		} else {
   210  			conn += "?timezone=UTC"
   211  		}
   212  
   213  		// Ensure a reasonable application_name is set
   214  		if !strings.Contains(conn, "application_name") {
   215  			conn += "&application_name=vault"
   216  		}
   217  	}
   218  
   219  	var err error
   220  	c.db, err = sql.Open(driverName, conn)
   221  	if err != nil {
   222  		return nil, err
   223  	}
   224  
   225  	// Set some connection pool settings. We don't need much of this,
   226  	// since the request rate shouldn't be high.
   227  	c.db.SetMaxOpenConns(c.MaxOpenConnections)
   228  	c.db.SetMaxIdleConns(c.MaxIdleConnections)
   229  	c.db.SetConnMaxLifetime(c.maxConnectionLifetime)
   230  
   231  	return c.db, nil
   232  }
   233  
   234  func (c *SQLConnectionProducer) SecretValues() map[string]interface{} {
   235  	return map[string]interface{}{
   236  		c.Password: "[password]",
   237  	}
   238  }
   239  
   240  // Close attempts to close the connection
   241  func (c *SQLConnectionProducer) Close() error {
   242  	// Grab the write lock
   243  	c.Lock()
   244  	defer c.Unlock()
   245  
   246  	if c.db != nil {
   247  		c.db.Close()
   248  
   249  		// cleanup IAM dialer if it exists
   250  		if c.AuthType == AuthTypeGCPIAM {
   251  			if c.cloudDialerCleanup != nil {
   252  				c.cloudDialerCleanup()
   253  			}
   254  		}
   255  	}
   256  
   257  	c.db = nil
   258  
   259  	return nil
   260  }
   261  
   262  // SetCredentials uses provided information to set/create a user in the
   263  // database. Unlike CreateUser, this method requires a username be provided and
   264  // uses the name given, instead of generating a name. This is used for creating
   265  // and setting the password of static accounts, as well as rolling back
   266  // passwords in the database in the event an updated database fails to save in
   267  // Vault's storage.
   268  func (c *SQLConnectionProducer) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
   269  	return "", "", dbutil.Unimplemented()
   270  }