github.com/microsoft/moc@v0.17.1/pkg/auth/auth.go (about)

     1  // Copyright (c) Microsoft Corporation.
     2  // Licensed under the Apache v2.0 license.
     3  
     4  package auth
     5  
     6  //go:generate mockgen -destination mock/auth_mock.go github.com/microsoft/moc/pkg/auth Authorizer
     7  import (
     8  	context "context"
     9  	"crypto/tls"
    10  	"crypto/x509"
    11  	"errors"
    12  	"fmt"
    13  	"io/fs"
    14  	"io/ioutil"
    15  
    16  	"github.com/microsoft/moc/pkg/config"
    17  	"github.com/microsoft/moc/pkg/marshal"
    18  	"github.com/microsoft/moc/rpc/common"
    19  	"google.golang.org/grpc/credentials"
    20  )
    21  
    22  const (
    23  	ServerName = "ServerName"
    24  )
    25  
    26  type WssdConfig struct {
    27  	CloudCertificate      string
    28  	ClientCertificate     string
    29  	ClientKey             string
    30  	IdentityName          string
    31  	ClientCertificateType LoginType //Depricated : Needs to cleaned up after removing references
    32  }
    33  
    34  type Authorizer interface {
    35  	WithTransportAuthorization() credentials.TransportCredentials
    36  	WithRPCAuthorization() credentials.PerRPCCredentials
    37  }
    38  
    39  type ManagedIdentityConfig struct {
    40  	ClientTokenPath string
    41  	WssdConfigPath  string
    42  	ServerName      string
    43  }
    44  
    45  type ClientType string
    46  
    47  const (
    48  	Admin          ClientType = "Admin"
    49  	BareMetal      ClientType = "BareMetal"
    50  	ControlPlane   ClientType = "ControlPlane"
    51  	ExternalClient ClientType = "ExternalClient"
    52  	LoadBalancer   ClientType = "LoadBalancer"
    53  	Node           ClientType = "Node"
    54  )
    55  
    56  type LoginConfig struct {
    57  	Name          string     `json:"name,omitempty"`
    58  	Token         string     `json:"token,omitempty"`
    59  	Certificate   string     `json:"certificate,omitempty"`
    60  	ClientType    ClientType `json:"clienttype,omitempty"`
    61  	CloudFqdn     string     `json:"cloudfqdn,omitempty"`
    62  	CloudPort     int32      `json:"cloudport,omitempty"`
    63  	CloudAuthPort int32      `json:"cloudauthport,omitempty"`
    64  	Location      string     `json:"location,omitempty"`
    65  	Type          LoginType  `json:"type,omitempty"` //Depricated : Needs to cleaned up after removing references
    66  }
    67  
    68  // LoginType [Depricated : Needs to cleaned up after removing references]
    69  type LoginType string
    70  
    71  const (
    72  	// SelfSigned ...
    73  	SelfSigned LoginType = "Self-Signed"
    74  	// CASigned ...
    75  	CASigned LoginType = "CA-Signed"
    76  )
    77  
    78  func LoginTypeToAuthType(authType string) common.AuthenticationType {
    79  	switch authType {
    80  	case string(SelfSigned):
    81  		return common.AuthenticationType_SELFSIGNED
    82  	case string(CASigned):
    83  		return common.AuthenticationType_CASIGNED
    84  	}
    85  	return common.AuthenticationType_SELFSIGNED
    86  }
    87  
    88  func AuthTypeToLoginType(authType common.AuthenticationType) LoginType {
    89  	switch authType {
    90  	case common.AuthenticationType_SELFSIGNED:
    91  		return SelfSigned
    92  	case common.AuthenticationType_CASIGNED:
    93  		return CASigned
    94  	}
    95  	return SelfSigned
    96  }
    97  
    98  type JwtTokenProvider struct {
    99  	RawData string `json:"rawdata"`
   100  }
   101  
   102  func (c JwtTokenProvider) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
   103  	return map[string]string{
   104  		"authorization": c.RawData,
   105  	}, nil
   106  }
   107  
   108  func (c JwtTokenProvider) RequireTransportSecurity() bool {
   109  	return true
   110  }
   111  
   112  func NewTokenCredentialProvider(token string) JwtTokenProvider {
   113  	return JwtTokenProvider{token}
   114  }
   115  
   116  func NewEmptyTokenCredentialProvider() JwtTokenProvider {
   117  	return JwtTokenProvider{}
   118  }
   119  
   120  type TransportCredentialsProvider struct {
   121  	serverName            string
   122  	certificate           []tls.Certificate
   123  	rootCAPool            *x509.CertPool
   124  	verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
   125  }
   126  
   127  func NewEmptyTransportCredential() *TransportCredentialsProvider {
   128  	return &TransportCredentialsProvider{}
   129  }
   130  
   131  func NewTransportCredentialFromAuthBase64(serverName string, rootCACertsBase64 string) (*TransportCredentialsProvider, error) {
   132  	caCertPem, err := marshal.FromBase64(rootCACertsBase64)
   133  	if err != nil {
   134  		return nil, fmt.Errorf("could not marshal the server certificate")
   135  	}
   136  
   137  	return NewTransportCredentialFromAuthFromPem(serverName, caCertPem)
   138  }
   139  
   140  func NewTransportCredentialFromAuthFromPem(serverName string, caCertPem []byte) (*TransportCredentialsProvider, error) {
   141  	certPool := x509.NewCertPool()
   142  	// Append the client certificates from the CA
   143  	if ok := certPool.AppendCertsFromPEM(caCertPem); !ok {
   144  		return nil, fmt.Errorf("could not append the server certificate")
   145  	}
   146  	return &TransportCredentialsProvider{
   147  		serverName: serverName,
   148  		rootCAPool: certPool,
   149  	}, nil
   150  }
   151  
   152  func NewTransportCredentialFromBase64(serverName, clientCertificateBase64, clientKeyBase64 string, rootCACertsBase64 string) (*TransportCredentialsProvider, error) {
   153  	transportCreds, err := NewTransportCredentialFromAuthBase64(serverName, rootCACertsBase64)
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  
   158  	clientPem, err := marshal.FromBase64(clientCertificateBase64)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	keyPem, err := marshal.FromBase64(clientKeyBase64)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  	if err = CertCheck(clientPem); err != nil {
   167  		return nil, err
   168  	}
   169  
   170  	tlsCert, err := tls.X509KeyPair(clientPem, keyPem)
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  
   175  	transportCreds.certificate = []tls.Certificate{tlsCert}
   176  
   177  	return transportCreds, nil
   178  }
   179  
   180  func NewTransportCredentialFromTlsCerts(serverName string, tlsCerts []tls.Certificate, rootCACertsPem []byte) (*TransportCredentialsProvider, error) {
   181  	transportCreds, err := NewTransportCredentialFromAuthFromPem(serverName, rootCACertsPem)
   182  	if err != nil {
   183  		return nil, err
   184  	}
   185  	transportCreds.certificate = tlsCerts
   186  	return transportCreds, nil
   187  }
   188  
   189  func NewTransportCredentialFromAccessFileLocation(serverName, accessFileLocation string) (*TransportCredentialsProvider, error) {
   190  	accessFile := WssdConfig{}
   191  	err := marshal.FromJSONFile(accessFileLocation, &accessFile)
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  	return NewTransportCredentialFromAccessFile(serverName, accessFile)
   196  }
   197  
   198  func NewTransportCredentialFromAccessFile(serverName string, accessFile WssdConfig) (*TransportCredentialsProvider, error) {
   199  	caCertPem, tlscerts, err := AccessFileToTls(accessFile)
   200  	if err != nil {
   201  		return nil, err
   202  	}
   203  	return NewTransportCredentialFromTlsCerts(serverName, []tls.Certificate{tlscerts}, caCertPem)
   204  }
   205  
   206  func (transportCredentials *TransportCredentialsProvider) GetTransportCredentials() credentials.TransportCredentials {
   207  	creds := &tls.Config{
   208  		ServerName: transportCredentials.serverName,
   209  	}
   210  	if len(transportCredentials.certificate) > 0 {
   211  		creds.Certificates = transportCredentials.certificate
   212  	}
   213  	if transportCredentials.rootCAPool != nil {
   214  		creds.RootCAs = transportCredentials.rootCAPool
   215  	}
   216  	if transportCredentials.verifyPeerCertificate != nil {
   217  		creds.VerifyPeerCertificate = transportCredentials.verifyPeerCertificate
   218  	}
   219  	return credentials.NewTLS(creds)
   220  }
   221  
   222  // BearerAuthorizer implements the bearer authorization
   223  type BearerAuthorizer struct {
   224  	tokenProvider        JwtTokenProvider
   225  	transportCredentials credentials.TransportCredentials
   226  }
   227  
   228  func (ba *BearerAuthorizer) WithRPCAuthorization() credentials.PerRPCCredentials {
   229  	return ba.tokenProvider
   230  }
   231  
   232  func (ba *BearerAuthorizer) WithTransportAuthorization() credentials.TransportCredentials {
   233  	return ba.transportCredentials
   234  }
   235  
   236  func NewEmptyBearerAuthorizer() *BearerAuthorizer {
   237  	return &BearerAuthorizer{
   238  		tokenProvider:        NewEmptyTokenCredentialProvider(),
   239  		transportCredentials: NewEmptyBearerAuthorizer().transportCredentials,
   240  	}
   241  }
   242  
   243  // NewBearerAuthorizer crates a BearerAuthorizer using the given token provider
   244  func NewBearerAuthorizer(tp JwtTokenProvider, tc credentials.TransportCredentials) *BearerAuthorizer {
   245  	return &BearerAuthorizer{
   246  		tokenProvider:        tp,
   247  		transportCredentials: tc,
   248  	}
   249  }
   250  
   251  // EnvironmentSettings contains the available authentication settings.
   252  type EnvironmentSettings struct {
   253  	Values map[string]string
   254  }
   255  
   256  func NewAuthorizerFromEnvironment(serverName string) (Authorizer, error) {
   257  	settings := GetSettingsFromEnvironment(serverName)
   258  	err := RenewCertificates(settings.GetManagedIdentityConfig().ServerName, settings.GetManagedIdentityConfig().WssdConfigPath)
   259  	if err != nil {
   260  		return nil, err
   261  	}
   262  	return settings.GetAuthorizer()
   263  }
   264  
   265  func NewAuthorizerFromEnvironmentByName(serverName, subfolder, filename string) (Authorizer, error) {
   266  	settings, err := GetSettingsFromEnvironmentByName(serverName, subfolder, filename)
   267  	if err != nil {
   268  		return nil, err
   269  	}
   270  	err = RenewCertificates(settings.GetManagedIdentityConfig().ServerName, settings.GetManagedIdentityConfig().WssdConfigPath)
   271  	if err != nil {
   272  		return nil, err
   273  	}
   274  	return settings.GetAuthorizer()
   275  }
   276  
   277  func NewAuthorizerFromInput(tlsCert tls.Certificate, serverCertificate []byte, server string) (Authorizer, error) {
   278  	transportCreds := TransportCredentialsFromNode(tlsCert, serverCertificate, server)
   279  	return NewBearerAuthorizer(NewEmptyTokenCredentialProvider(), transportCreds), nil
   280  }
   281  
   282  func NewAuthorizerForAuth(tokenString string, certificate string, server string) (Authorizer, error) {
   283  	credentials, err := NewTransportCredentialFromAuthBase64(server, certificate)
   284  	if err != nil {
   285  		return NewEmptyBearerAuthorizer(), err
   286  	}
   287  	return NewBearerAuthorizer(NewTokenCredentialProvider(tokenString), credentials.GetTransportCredentials()), nil
   288  }
   289  
   290  // GetSettingsFromEnvironment Read settings from WssdConfigLocation
   291  func GetSettingsFromEnvironment(serverName string) (s EnvironmentSettings) {
   292  	s = EnvironmentSettings{
   293  		Values: map[string]string{},
   294  	}
   295  	s.Values[ClientTokenPath] = getClientTokenLocation()
   296  	s.Values[WssdConfigPath] = GetWssdConfigLocation()
   297  
   298  	s.Values[ServerName] = serverName
   299  
   300  	return
   301  }
   302  
   303  // GetSettingsFromEnvironmentByName Read settings from GetWssdConfigLocationName
   304  func GetSettingsFromEnvironmentByName(serverName, subfolder, filename string) (s EnvironmentSettings, err error) {
   305  	s = EnvironmentSettings{
   306  		Values: map[string]string{},
   307  	}
   308  	s.Values[ClientTokenPath] = getClientTokenLocation()
   309  	s.Values[WssdConfigPath] = GetMocConfigLocationName(subfolder, filename)
   310  	s.Values[ServerName] = serverName
   311  
   312  	return
   313  }
   314  
   315  func (settings EnvironmentSettings) GetAuthorizer() (Authorizer, error) {
   316  	return settings.GetManagedIdentityConfig().Authorizer()
   317  }
   318  
   319  func (settings EnvironmentSettings) GetManagedIdentityConfig() ManagedIdentityConfig {
   320  	return ManagedIdentityConfig{
   321  		settings.Values[ClientTokenPath],
   322  		settings.Values[WssdConfigPath],
   323  		settings.Values[ServerName],
   324  	}
   325  }
   326  
   327  func (mc ManagedIdentityConfig) Authorizer() (Authorizer, error) {
   328  
   329  	jwtCreds, err := TokenProviderFromFile(mc.ClientTokenPath)
   330  	if err != nil {
   331  		return nil, err
   332  	}
   333  	transportCreds := TransportCredentialsFromFile(mc.WssdConfigPath, mc.ServerName)
   334  
   335  	return NewBearerAuthorizer(jwtCreds, transportCreds), nil
   336  }
   337  
   338  func TokenProviderFromFile(tokenLocation string) (JwtTokenProvider, error) {
   339  	if tokenLocation == "" {
   340  		return NewEmptyTokenCredentialProvider(), nil
   341  	}
   342  	loginconfig := LoginConfig{}
   343  	err := config.LoadYAMLFile(tokenLocation, &loginconfig)
   344  	if err != nil {
   345  		// if File does not exist we return no error. This to prevent any breaking changes
   346  		if errors.Is(err, fs.ErrNotExist) {
   347  			err = nil
   348  		}
   349  		return NewEmptyTokenCredentialProvider(), err
   350  	}
   351  	return NewTokenCredentialProvider(loginconfig.Token), nil
   352  }
   353  
   354  func TransportCredentialsFromFile(wssdConfigLocation string, server string) credentials.TransportCredentials {
   355  	credentials, err := NewTransportCredentialFromAccessFileLocation(server, wssdConfigLocation)
   356  	if err != nil {
   357  		return NewEmptyTransportCredential().GetTransportCredentials()
   358  	}
   359  	return credentials.GetTransportCredentials()
   360  }
   361  
   362  func ReadAccessFileToTls(accessFileLocation string) ([]byte, tls.Certificate, error) {
   363  	accessFile := WssdConfig{}
   364  	err := marshal.FromJSONFile(accessFileLocation, &accessFile)
   365  	if err != nil {
   366  		return []byte{}, tls.Certificate{}, err
   367  	}
   368  	return AccessFileToTls(accessFile)
   369  }
   370  func TransportCredentialsFromNode(tlsCert tls.Certificate, serverCertificate []byte, server string) credentials.TransportCredentials {
   371  
   372  	credential, err := NewTransportCredentialFromTlsCerts(server, []tls.Certificate{tlsCert}, serverCertificate)
   373  	if err != nil {
   374  		return NewEmptyTransportCredential().GetTransportCredentials()
   375  	}
   376  	return credential.GetTransportCredentials()
   377  
   378  }
   379  
   380  func SaveToken(tokenStr string) error {
   381  	return ioutil.WriteFile(
   382  		getClientTokenLocation(),
   383  		[]byte(tokenStr),
   384  		0644)
   385  }
   386  
   387  // PrintAccessFile stores wssdConfig in WssdConfigLocation
   388  func PrintAccessFile(accessFile WssdConfig) error {
   389  	return marshal.ToJSONFile(accessFile, GetWssdConfigLocation())
   390  }
   391  
   392  // PrintAccessFileByName stores wssdConfig in GetWssdConfigLocationName
   393  func PrintAccessFileByName(accessFile WssdConfig, subfolder, filename string) error {
   394  	return marshal.ToJSONFile(accessFile, GetMocConfigLocationName(subfolder, filename))
   395  }
   396  
   397  func AccessFileToTls(accessFile WssdConfig) ([]byte, tls.Certificate, error) {
   398  	serverPem, err := marshal.FromBase64(accessFile.CloudCertificate)
   399  	if err != nil {
   400  		return []byte{}, tls.Certificate{}, err
   401  	}
   402  	clientPem, err := marshal.FromBase64(accessFile.ClientCertificate)
   403  	if err != nil {
   404  		return []byte{}, tls.Certificate{}, err
   405  	}
   406  	keyPem, err := marshal.FromBase64(accessFile.ClientKey)
   407  	if err != nil {
   408  		return []byte{}, tls.Certificate{}, err
   409  	}
   410  
   411  	if err = CertCheck(clientPem); err != nil {
   412  		return []byte{}, tls.Certificate{}, err
   413  	}
   414  
   415  	tlsCert, err := tls.X509KeyPair(clientPem, keyPem)
   416  	if err != nil {
   417  		return []byte{}, tls.Certificate{}, err
   418  	}
   419  
   420  	return serverPem, tlsCert, nil
   421  }