github.com/wanddynosios/cli/v8@v8.7.9-0.20240221182337-1a92e3a7017f/api/shared/wrap_for_cf_on_k8s.go (about)

     1  package shared
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/base64"
     8  	"encoding/pem"
     9  	"errors"
    10  	"fmt"
    11  	"net/http"
    12  
    13  	"k8s.io/client-go/rest"
    14  	"k8s.io/client-go/tools/clientcmd"
    15  	"k8s.io/client-go/tools/clientcmd/api"
    16  	"k8s.io/client-go/transport"
    17  
    18  	"code.cloudfoundry.org/cli/actor/v7action"
    19  	"code.cloudfoundry.org/cli/command"
    20  
    21  	// imported for the side effects
    22  	_ "k8s.io/client-go/plugin/pkg/client/auth/azure"
    23  	_ "k8s.io/client-go/plugin/pkg/client/auth/gcp"
    24  	_ "k8s.io/client-go/plugin/pkg/client/auth/oidc"
    25  )
    26  
    27  //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 net/http.RoundTripper
    28  
    29  func WrapForCFOnK8sAuth(config command.Config, k8sConfigGetter v7action.KubernetesConfigGetter, roundTripper http.RoundTripper) (http.RoundTripper, error) {
    30  	username, err := config.CurrentUserName()
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  	if username == "" {
    35  		return nil, errors.New("current user not set")
    36  	}
    37  
    38  	k8sConfig, err := k8sConfigGetter.Get()
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  
    43  	restConfig, err := clientcmd.NewDefaultClientConfig(
    44  		*k8sConfig,
    45  		&clientcmd.ConfigOverrides{
    46  			Context: api.Context{AuthInfo: username},
    47  		},
    48  	).ClientConfig()
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	// Special case for certs, since we don't want mtls
    54  	cert, err := getCert(restConfig)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  
    59  	transportConfig, err := restConfig.TransportConfig()
    60  	if err != nil {
    61  		return nil, fmt.Errorf("failed to get transport config: %w", err)
    62  	}
    63  
    64  	if cert != nil {
    65  		return certRoundTripper{
    66  			cert:         cert,
    67  			roundTripper: roundTripper,
    68  		}, nil
    69  	}
    70  
    71  	if transportConfig.WrapTransport == nil {
    72  		// i.e. not auth-provider or exec plugin
    73  		return transport.HTTPWrappersForConfig(transportConfig, roundTripper)
    74  	}
    75  
    76  	// using auth provider to generate token
    77  	return transportConfig.WrapTransport(roundTripper), nil
    78  }
    79  
    80  func getCert(restConfig *rest.Config) (*tls.Certificate, error) {
    81  	tlsConfig, err := rest.TLSConfigFor(restConfig)
    82  	if err != nil {
    83  		return nil, fmt.Errorf("failed to get tls config: %w", err)
    84  	}
    85  
    86  	if tlsConfig != nil && tlsConfig.GetClientCertificate != nil {
    87  		cert, err := tlsConfig.GetClientCertificate(nil)
    88  		if err != nil {
    89  			return nil, fmt.Errorf("failed to get client certificate: %w", err)
    90  		}
    91  
    92  		if len(cert.Certificate) > 0 && cert.PrivateKey != nil {
    93  			return cert, nil
    94  		}
    95  	}
    96  	return nil, nil
    97  }
    98  
    99  type certRoundTripper struct {
   100  	cert         *tls.Certificate
   101  	roundTripper http.RoundTripper
   102  }
   103  
   104  func (rt certRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   105  	var buf bytes.Buffer
   106  
   107  	if err := pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: rt.cert.Certificate[0]}); err != nil {
   108  		return nil, fmt.Errorf("could not convert certificate to PEM format: %w", err)
   109  	}
   110  
   111  	key, err := x509.MarshalPKCS8PrivateKey(rt.cert.PrivateKey)
   112  	if err != nil {
   113  		return nil, fmt.Errorf("could not marshal private key: %w", err)
   114  	}
   115  
   116  	if err := pem.Encode(&buf, &pem.Block{Type: "PRIVATE KEY", Bytes: key}); err != nil {
   117  		return nil, fmt.Errorf("could not convert key to PEM format: %w", err)
   118  	}
   119  
   120  	auth := "ClientCert " + base64.StdEncoding.EncodeToString(buf.Bytes())
   121  	req.Header.Set("Authorization", auth)
   122  
   123  	return rt.roundTripper.RoundTrip(req)
   124  }