github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/creds/vault/api_client.go (about)

     1  package vault
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"path"
     9  	"strings"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"code.cloudfoundry.org/lager"
    14  	"github.com/hashicorp/go-rootcerts"
    15  	vaultapi "github.com/hashicorp/vault/api"
    16  )
    17  
// The APIClient is a SecretReader which maintains an authorized
// client using the Login and Renew functions.
//
// The currently active Vault client is kept in clientValue and is
// replaced wholesale (never mutated in place) each time Login or Renew
// obtains a token.
type APIClient struct {
	// logger is the parent logger; Login and Renew derive their own
	// sessions from it.
	logger lager.Logger

	// apiURL is the address set on every Vault client that is built.
	apiURL string
	// namespace is applied to each built client when non-empty.
	namespace string
	// tlsConfig holds the CA, client certificate and server name settings
	// applied to the HTTP transport by configureTLS.
	tlsConfig TLSConfig
	// authConfig supplies either a fixed client token or the backend and
	// parameters used to log in.
	authConfig AuthConfig

	// clientValue stores the current *vaultapi.Client; it is read with
	// Load and swapped with Store.
	clientValue *atomic.Value

	// renewable starts out true and is cleared by Renew once Vault reports
	// that the lease is not renewable, after which Renew does nothing.
	renewable bool
}
    32  
    33  // NewAPIClient with the associated authorization config and underlying vault client.
    34  func NewAPIClient(logger lager.Logger, apiURL string, tlsConfig TLSConfig, authConfig AuthConfig, namespace string) (*APIClient, error) {
    35  	ac := &APIClient{
    36  		logger: logger,
    37  
    38  		apiURL:     apiURL,
    39  		namespace:  namespace,
    40  		tlsConfig:  tlsConfig,
    41  		authConfig: authConfig,
    42  
    43  		clientValue: &atomic.Value{},
    44  
    45  		renewable: true,
    46  	}
    47  
    48  	client, err := ac.baseClient()
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	ac.setClient(client)
    54  
    55  	return ac, nil
    56  }
    57  
    58  // Read must be called after a successful login has occurred or an
    59  // un-authorized client will be used.
    60  func (ac *APIClient) Read(path string) (*vaultapi.Secret, error) {
    61  	return ac.client().Logical().Read(path)
    62  }
    63  
    64  func (ac *APIClient) loginParams() map[string]interface{} {
    65  	loginParams := make(map[string]interface{})
    66  	for k, v := range ac.authConfig.Params {
    67  		loginParams[k] = v
    68  	}
    69  
    70  	return loginParams
    71  }
    72  
    73  // Login the APIClient using the credentials passed at
    74  // construction. Returns a duration after which renew must be called.
    75  func (ac *APIClient) Login() (time.Duration, error) {
    76  	logger := ac.logger.Session("login")
    77  
    78  	// If we are configured with a client token return right away
    79  	// and trigger a renewal.
    80  	if ac.authConfig.ClientToken != "" {
    81  		newClient, err := ac.clientWithToken(ac.authConfig.ClientToken)
    82  		if err != nil {
    83  			logger.Error("failed-to-create-client", err)
    84  			return time.Second, err
    85  		}
    86  
    87  		ac.setClient(newClient)
    88  
    89  		logger.Info("token-set")
    90  
    91  		return time.Second, nil
    92  	}
    93  
    94  	client := ac.client()
    95  	loginPath := path.Join("auth", ac.authConfig.Backend, "login")
    96  
    97  	if ac.authConfig.Backend == "ldap" || ac.authConfig.Backend == "okta" {
    98  		username, ok := ac.loginParams()["username"].(string)
    99  		if !ok {
   100  			err := fmt.Errorf("failed to assert username as string")
   101  			logger.Error("failed", err)
   102  			return time.Second, err
   103  		}
   104  		loginPath = path.Join("auth", ac.authConfig.Backend, "login", username)
   105  	}
   106  
   107  	secret, err := client.Logical().Write(loginPath, ac.loginParams())
   108  	if err != nil {
   109  		logger.Error("failed", err)
   110  		return time.Second, err
   111  	}
   112  
   113  	logger.Info("succeeded", lager.Data{
   114  		"token-accessor": secret.Auth.Accessor,
   115  		"lease-duration": secret.Auth.LeaseDuration,
   116  		"policies":       secret.Auth.Policies,
   117  	})
   118  
   119  	newClient, err := ac.clientWithToken(secret.Auth.ClientToken)
   120  	if err != nil {
   121  		logger.Error("failed-to-create-client", err)
   122  		return time.Second, err
   123  	}
   124  
   125  	ac.setClient(newClient)
   126  
   127  	return time.Duration(secret.Auth.LeaseDuration) * time.Second, nil
   128  }
   129  
   130  // Renew the APIClient login using the credentials passed at
   131  // construction. Must be called after a successful login. Returns a
   132  // duration after which renew must be called again.
   133  func (ac *APIClient) Renew() (time.Duration, error) {
   134  	if !ac.renewable {
   135  		return time.Second, nil
   136  	}
   137  
   138  	logger := ac.logger.Session("renew")
   139  
   140  	client := ac.client()
   141  
   142  	secret, err := client.Auth().Token().RenewSelf(0)
   143  	if err != nil {
   144  		// When tests with a Vault dev server, renew is not allowed.
   145  		if strings.Index(err.Error(), "lease is not renewable") > 0 {
   146  			ac.renewable = false
   147  			return time.Second, nil
   148  		}
   149  		logger.Error("failed", err)
   150  		return time.Second, err
   151  	}
   152  
   153  	logger.Info("succeeded", lager.Data{
   154  		"token-accessor": secret.Auth.Accessor,
   155  		"lease-duration": secret.Auth.LeaseDuration,
   156  		"policies":       secret.Auth.Policies,
   157  	})
   158  
   159  	newClient, err := ac.clientWithToken(secret.Auth.ClientToken)
   160  	if err != nil {
   161  		logger.Error("failed-to-create-client", err)
   162  		return time.Second, err
   163  	}
   164  
   165  	ac.setClient(newClient)
   166  
   167  	return time.Duration(secret.Auth.LeaseDuration) * time.Second, nil
   168  }
   169  
   170  func (ac *APIClient) client() *vaultapi.Client {
   171  	return ac.clientValue.Load().(*vaultapi.Client)
   172  }
   173  
   174  func (ac *APIClient) setClient(client *vaultapi.Client) {
   175  	ac.clientValue.Store(client)
   176  }
   177  
   178  func (ac *APIClient) baseClient() (*vaultapi.Client, error) {
   179  	config := vaultapi.DefaultConfig()
   180  
   181  	err := ac.configureTLS(config.HttpClient.Transport.(*http.Transport).TLSClientConfig)
   182  	if err != nil {
   183  		return nil, err
   184  	}
   185  
   186  	client, err := vaultapi.NewClient(config)
   187  	if err != nil {
   188  		return nil, err
   189  	}
   190  
   191  	err = client.SetAddress(ac.apiURL)
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  
   196  	if ac.namespace != "" {
   197  		client.SetNamespace(ac.namespace)
   198  	}
   199  
   200  	return client, nil
   201  }
   202  
   203  func (ac *APIClient) configureTLS(config *tls.Config) error {
   204  	if ac.tlsConfig.CACert != "" || ac.tlsConfig.CACertFile != "" || ac.tlsConfig.CAPath != "" {
   205  		rootConfig := &rootcerts.Config{
   206  			CAFile:        ac.tlsConfig.CACertFile,
   207  			CAPath:        ac.tlsConfig.CAPath,
   208  			CACertificate: []byte(ac.tlsConfig.CACert),
   209  		}
   210  
   211  		if err := rootcerts.ConfigureTLS(config, rootConfig); err != nil {
   212  			return err
   213  		}
   214  	}
   215  
   216  	if ac.tlsConfig.ClientCertFile != "" {
   217  		content, err := ioutil.ReadFile(ac.tlsConfig.ClientCertFile)
   218  		if err != nil {
   219  			return err
   220  		}
   221  
   222  		ac.tlsConfig.ClientCert = string(content)
   223  	}
   224  
   225  	if ac.tlsConfig.ClientKeyFile != "" {
   226  		content, err := ioutil.ReadFile(ac.tlsConfig.ClientKeyFile)
   227  		if err != nil {
   228  			return err
   229  		}
   230  
   231  		ac.tlsConfig.ClientKey = string(content)
   232  	}
   233  
   234  	if ac.tlsConfig.Insecure {
   235  		config.InsecureSkipVerify = true
   236  	}
   237  
   238  	var clientCert tls.Certificate
   239  	foundClientCert := false
   240  
   241  	switch {
   242  	case ac.tlsConfig.ClientCert != "" && ac.tlsConfig.ClientKey != "":
   243  		var err error
   244  		clientCert, err = tls.X509KeyPair([]byte(ac.tlsConfig.ClientCert), []byte(ac.tlsConfig.ClientKey))
   245  		if err != nil {
   246  			return err
   247  		}
   248  
   249  		foundClientCert = true
   250  	case ac.tlsConfig.ClientCert != "" || ac.tlsConfig.ClientKey != "":
   251  		return fmt.Errorf("both client cert and client key must be provided")
   252  	}
   253  
   254  	if foundClientCert {
   255  		// We use this function to ignore the server's preferential list of
   256  		// CAs, otherwise any CA used for the cert auth backend must be in the
   257  		// server's CA pool
   258  		config.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
   259  			return &clientCert, nil
   260  		}
   261  	}
   262  
   263  	if ac.tlsConfig.ServerName != "" {
   264  		config.ServerName = ac.tlsConfig.ServerName
   265  	}
   266  
   267  	return nil
   268  }
   269  
   270  func (ac *APIClient) clientWithToken(token string) (*vaultapi.Client, error) {
   271  	client, err := ac.baseClient()
   272  	if err != nil {
   273  		return nil, err
   274  	}
   275  
   276  	client.SetToken(token)
   277  
   278  	return client, nil
   279  }
   280  
   281  func (ac *APIClient) health() (*vaultapi.HealthResponse, error) {
   282  	client, err := ac.baseClient()
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  
   287  	healthResponse, err := client.Sys().Health()
   288  	return healthResponse, err
   289  }