go.temporal.io/server@v1.23.0/common/authorization/interceptor.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package authorization
    26  
    27  import (
    28  	"context"
    29  	"crypto/x509"
    30  	"crypto/x509/pkix"
    31  	"time"
    32  
    33  	"go.temporal.io/api/serviceerror"
    34  	"google.golang.org/grpc"
    35  	"google.golang.org/grpc/credentials"
    36  	"google.golang.org/grpc/metadata"
    37  	"google.golang.org/grpc/peer"
    38  
    39  	"go.temporal.io/server/common/log"
    40  	"go.temporal.io/server/common/log/tag"
    41  	"go.temporal.io/server/common/metrics"
    42  	"go.temporal.io/server/common/util"
    43  )
    44  
    // contextKeyMappedClaims is the unexported type of the MappedClaims context
// key. A dedicated unexported type keeps the key from colliding with context
// keys defined in other packages.
type contextKeyMappedClaims struct{}

// contextKeyAuthHeader is the unexported type of the AuthHeader context key.
type contextKeyAuthHeader struct{}
    49  
    50  type (
    51  	// JWTAudienceMapper returns JWT audience for a given request
    52  	JWTAudienceMapper interface {
    53  		Audience(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo) string
    54  	}
    55  )
    56  
    // RequestUnauthorized is the message of the PermissionDenied error returned
// to callers whose request is rejected by the authorization interceptor.
const RequestUnauthorized = "Request unauthorized."

// Metadata header names used when the caller passes empty names to
// NewAuthorizationInterceptor.
const (
	defaultAuthHeaderName      = "authorization"
	defaultAuthExtraHeaderName = "authorization-extras"
)
    63  
    64  var (
    65  	errUnauthorized = serviceerror.NewPermissionDenied(RequestUnauthorized, "")
    66  
    67  	MappedClaims contextKeyMappedClaims
    68  	AuthHeader   contextKeyAuthHeader
    69  )
    70  
    71  func (a *interceptor) Interceptor(
    72  	ctx context.Context,
    73  	req interface{},
    74  	info *grpc.UnaryServerInfo,
    75  	handler grpc.UnaryHandler,
    76  ) (interface{}, error) {
    77  
    78  	var claims *Claims
    79  
    80  	if a.claimMapper != nil && a.authorizer != nil {
    81  		var tlsSubject *pkix.Name
    82  		var authHeaders []string
    83  		var authExtraHeaders []string
    84  		var tlsConnection *credentials.TLSInfo
    85  
    86  		if md, ok := metadata.FromIncomingContext(ctx); ok {
    87  			authHeaders = md[a.authHeaderName]
    88  			authExtraHeaders = md[a.authExtraHeaderName]
    89  		}
    90  		tlsConnection = TLSInfoFormContext(ctx)
    91  		clientCert := PeerCert(tlsConnection)
    92  		if clientCert != nil {
    93  			tlsSubject = &clientCert.Subject
    94  		}
    95  
    96  		authInfoRequired := true
    97  		if cm, ok := a.claimMapper.(ClaimMapperWithAuthInfoRequired); ok {
    98  			authInfoRequired = cm.AuthInfoRequired()
    99  		}
   100  
   101  		// Add auth info to context only if there's some auth info
   102  		if tlsSubject != nil || len(authHeaders) > 0 || !authInfoRequired {
   103  			var authHeader string
   104  			var authExtraHeader string
   105  			var audience string
   106  			if len(authHeaders) > 0 {
   107  				authHeader = authHeaders[0]
   108  			}
   109  			if len(authExtraHeaders) > 0 {
   110  				authExtraHeader = authExtraHeaders[0]
   111  			}
   112  			if a.audienceGetter != nil {
   113  				audience = a.audienceGetter.Audience(ctx, req, info)
   114  			}
   115  			authInfo := AuthInfo{
   116  				AuthToken:     authHeader,
   117  				TLSSubject:    tlsSubject,
   118  				TLSConnection: tlsConnection,
   119  				ExtraData:     authExtraHeader,
   120  				Audience:      audience,
   121  			}
   122  			mappedClaims, err := a.claimMapper.GetClaims(&authInfo)
   123  			if err != nil {
   124  				a.logAuthError(err)
   125  				return nil, errUnauthorized // return a generic error to the caller without disclosing details
   126  			}
   127  			claims = mappedClaims
   128  			ctx = context.WithValue(ctx, MappedClaims, mappedClaims)
   129  			if authHeader != "" {
   130  				ctx = context.WithValue(ctx, AuthHeader, authHeader)
   131  			}
   132  		}
   133  	}
   134  
   135  	if a.authorizer != nil {
   136  		var namespace string
   137  		requestWithNamespace, ok := req.(hasNamespace)
   138  		if ok {
   139  			namespace = requestWithNamespace.GetNamespace()
   140  		}
   141  
   142  		handler := a.getMetricsHandler(metrics.AuthorizationScope, namespace)
   143  		result, err := a.authorize(ctx, claims, &CallTarget{
   144  			Namespace: namespace,
   145  			APIName:   info.FullMethod,
   146  			Request:   req,
   147  		}, handler)
   148  		if err != nil {
   149  			handler.Counter(metrics.ServiceErrAuthorizeFailedCounter.Name()).Record(1)
   150  			a.logAuthError(err)
   151  			return nil, errUnauthorized // return a generic error to the caller without disclosing details
   152  		}
   153  		if result.Decision != DecisionAllow {
   154  			handler.Counter(metrics.ServiceErrUnauthorizedCounter.Name()).Record(1)
   155  			// if a reason is included in the result, include it in the error message
   156  			if result.Reason != "" {
   157  				return nil, serviceerror.NewPermissionDenied(RequestUnauthorized, result.Reason)
   158  			}
   159  			return nil, errUnauthorized // return a generic error to the caller without disclosing details
   160  		}
   161  	}
   162  	return handler(ctx, req)
   163  }
   164  
   165  func (a *interceptor) authorize(
   166  	ctx context.Context,
   167  	claims *Claims,
   168  	callTarget *CallTarget,
   169  	metricsHandler metrics.Handler) (Result, error) {
   170  	startTime := time.Now().UTC()
   171  	defer func() {
   172  		metricsHandler.Timer(metrics.ServiceAuthorizationLatency.Name()).Record(time.Since(startTime))
   173  	}()
   174  	return a.authorizer.Authorize(ctx, claims, callTarget)
   175  }
   176  
   177  func (a *interceptor) logAuthError(err error) {
   178  	a.logger.Error("Authorization error", tag.Error(err))
   179  }
   180  
   181  type interceptor struct {
   182  	authorizer          Authorizer
   183  	claimMapper         ClaimMapper
   184  	metricsHandler      metrics.Handler
   185  	logger              log.Logger
   186  	audienceGetter      JWTAudienceMapper
   187  	authHeaderName      string
   188  	authExtraHeaderName string
   189  }
   190  
   191  // NewAuthorizationInterceptor creates an authorization interceptor and return a func that points to its Interceptor method
   192  func NewAuthorizationInterceptor(
   193  	claimMapper ClaimMapper,
   194  	authorizer Authorizer,
   195  	metricsHandler metrics.Handler,
   196  	logger log.Logger,
   197  	audienceGetter JWTAudienceMapper,
   198  	authHeaderName string,
   199  	authExtraHeaderName string,
   200  ) grpc.UnaryServerInterceptor {
   201  	return (&interceptor{
   202  		claimMapper:         claimMapper,
   203  		authorizer:          authorizer,
   204  		metricsHandler:      metricsHandler,
   205  		logger:              logger,
   206  		audienceGetter:      audienceGetter,
   207  		authHeaderName:      util.Coalesce(authHeaderName, defaultAuthHeaderName),
   208  		authExtraHeaderName: util.Coalesce(authExtraHeaderName, defaultAuthExtraHeaderName),
   209  	}).Interceptor
   210  }
   211  
   212  // getMetricsHandler return metrics handler with namespace tag
   213  func (a *interceptor) getMetricsHandler(
   214  	operation string,
   215  	namespace string,
   216  ) metrics.Handler {
   217  	var metricsHandler metrics.Handler
   218  	if namespace != "" {
   219  		metricsHandler = a.metricsHandler.WithTags(metrics.OperationTag(operation), metrics.NamespaceTag(namespace))
   220  	} else {
   221  		metricsHandler = a.metricsHandler.WithTags(metrics.OperationTag(operation), metrics.NamespaceUnknownTag())
   222  	}
   223  	return metricsHandler
   224  }
   225  
   226  func TLSInfoFormContext(ctx context.Context) *credentials.TLSInfo {
   227  
   228  	p, ok := peer.FromContext(ctx)
   229  	if !ok {
   230  		return nil
   231  	}
   232  	if tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo); ok {
   233  		return &tlsInfo
   234  	}
   235  	return nil
   236  }
   237  
   238  func PeerCert(tlsInfo *credentials.TLSInfo) *x509.Certificate {
   239  
   240  	if tlsInfo == nil || len(tlsInfo.State.VerifiedChains) == 0 || len(tlsInfo.State.VerifiedChains[0]) == 0 {
   241  		return nil
   242  	}
   243  	// The assumption here is that we only expect a single verified chain of certs (first[0]).
   244  	// It's unclear how we should handle a situation when more than one chain is presented,
   245  	// which subject to use. It's okay for us to limit ourselves to one chain.
   246  	// We can always extend this logic later.
   247  	// We take the first element in the chain ([0]) because that's the client cert
   248  	// (at the beginning of the chain), not intermediary CAs or the root CA (at the end of the chain).
   249  	return tlsInfo.State.VerifiedChains[0][0]
   250  }