github.com/cs3org/reva/v2@v2.27.7/internal/http/interceptors/auth/auth.go (about)

     1  // Copyright 2018-2021 CERN
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // In applying this license, CERN does not waive the privileges and immunities
    16  // granted to it by virtue of its status as an Intergovernmental Organization
    17  // or submit itself to any jurisdiction.
    18  
    19  package auth
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"net/http"
    25  	"strings"
    26  	"sync"
    27  	"time"
    28  
    29  	"github.com/bluele/gcache"
    30  	authpb "github.com/cs3org/go-cs3apis/cs3/auth/provider/v1beta1"
    31  	gateway "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
    32  	userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
    33  	rpc "github.com/cs3org/go-cs3apis/cs3/rpc/v1beta1"
    34  	"github.com/cs3org/reva/v2/internal/http/interceptors/auth/credential/registry"
    35  	tokenregistry "github.com/cs3org/reva/v2/internal/http/interceptors/auth/token/registry"
    36  	tokenwriterregistry "github.com/cs3org/reva/v2/internal/http/interceptors/auth/tokenwriter/registry"
    37  	"github.com/cs3org/reva/v2/pkg/appctx"
    38  	"github.com/cs3org/reva/v2/pkg/auth"
    39  	"github.com/cs3org/reva/v2/pkg/auth/scope"
    40  	ctxpkg "github.com/cs3org/reva/v2/pkg/ctx"
    41  	"github.com/cs3org/reva/v2/pkg/errtypes"
    42  	"github.com/cs3org/reva/v2/pkg/rgrpc/status"
    43  	"github.com/cs3org/reva/v2/pkg/rgrpc/todo/pool"
    44  	"github.com/cs3org/reva/v2/pkg/rhttp/global"
    45  	"github.com/cs3org/reva/v2/pkg/sharedconf"
    46  	"github.com/cs3org/reva/v2/pkg/token"
    47  	tokenmgr "github.com/cs3org/reva/v2/pkg/token/manager/registry"
    48  	"github.com/cs3org/reva/v2/pkg/utils"
    49  	"github.com/mitchellh/mapstructure"
    50  	"github.com/pkg/errors"
    51  	"github.com/rs/zerolog"
    52  	semconv "go.opentelemetry.io/otel/semconv/v1.20.0"
    53  	"go.opentelemetry.io/otel/trace"
    54  	"google.golang.org/grpc/metadata"
    55  )
    56  
    57  // name is the Tracer name used to identify this instrumentation library.
    58  const tracerName = "auth"
    59  
    60  var (
    61  	cacheOnce       sync.Once
    62  	userGroupsCache gcache.Cache
    63  )
    64  
    65  type config struct {
    66  	Priority   int    `mapstructure:"priority"`
    67  	GatewaySvc string `mapstructure:"gatewaysvc"`
    68  	// TODO(jdf): Realm is optional, will be filled with request host if not given?
    69  	Realm                  string                            `mapstructure:"realm"`
    70  	CredentialsByUserAgent map[string]string                 `mapstructure:"credentials_by_user_agent"`
    71  	CredentialChain        []string                          `mapstructure:"credential_chain"`
    72  	CredentialStrategies   map[string]map[string]interface{} `mapstructure:"credential_strategies"`
    73  	TokenStrategyChain     []string                          `mapstructure:"token_strategy_chain"`
    74  	TokenStrategies        map[string]map[string]interface{} `mapstructure:"token_strategies"`
    75  	TokenManager           string                            `mapstructure:"token_manager"`
    76  	TokenManagers          map[string]map[string]interface{} `mapstructure:"token_managers"`
    77  	TokenWriter            string                            `mapstructure:"token_writer"`
    78  	TokenWriters           map[string]map[string]interface{} `mapstructure:"token_writers"`
    79  	UserGroupsCacheSize    int                               `mapstructure:"usergroups_cache_size"`
    80  }
    81  
    82  func parseConfig(m map[string]interface{}) (*config, error) {
    83  	c := &config{}
    84  	if err := mapstructure.Decode(m, c); err != nil {
    85  		err = errors.Wrap(err, "error decoding conf")
    86  		return nil, err
    87  	}
    88  	return c, nil
    89  }
    90  
    91  // New returns a new middleware with defined priority.
    92  func New(m map[string]interface{}, unprotected []string, tp trace.TracerProvider) (global.Middleware, error) {
    93  	conf, err := parseConfig(m)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	conf.GatewaySvc = sharedconf.GetGatewaySVC(conf.GatewaySvc)
    99  
   100  	// set defaults
   101  	if len(conf.TokenStrategyChain) == 0 {
   102  		conf.TokenStrategyChain = []string{"header"}
   103  	}
   104  
   105  	if conf.TokenWriter == "" {
   106  		conf.TokenWriter = "header"
   107  	}
   108  
   109  	if conf.TokenManager == "" {
   110  		conf.TokenManager = "jwt"
   111  	}
   112  
   113  	if conf.CredentialsByUserAgent == nil {
   114  		conf.CredentialsByUserAgent = map[string]string{}
   115  	}
   116  
   117  	if conf.UserGroupsCacheSize == 0 {
   118  		conf.UserGroupsCacheSize = 5000
   119  	}
   120  
   121  	cacheOnce.Do(func() {
   122  		userGroupsCache = gcache.New(conf.UserGroupsCacheSize).LFU().Build()
   123  	})
   124  
   125  	credChain := map[string]auth.CredentialStrategy{}
   126  	for i, key := range conf.CredentialChain {
   127  		f, ok := registry.NewCredentialFuncs[conf.CredentialChain[i]]
   128  		if !ok {
   129  			return nil, fmt.Errorf("credential strategy not found: %s", conf.CredentialChain[i])
   130  		}
   131  
   132  		credStrategy, err := f(conf.CredentialStrategies[conf.CredentialChain[i]])
   133  		if err != nil {
   134  			return nil, err
   135  		}
   136  		credChain[key] = credStrategy
   137  	}
   138  
   139  	tokenStrategyChain := make([]auth.TokenStrategy, 0, len(conf.TokenStrategyChain))
   140  	for _, strategy := range conf.TokenStrategyChain {
   141  		g, ok := tokenregistry.NewTokenFuncs[strategy]
   142  		if !ok {
   143  			return nil, fmt.Errorf("token strategy not found: %s", strategy)
   144  		}
   145  		tokenStrategy, err := g(conf.TokenStrategies[strategy])
   146  		if err != nil {
   147  			return nil, err
   148  		}
   149  		tokenStrategyChain = append(tokenStrategyChain, tokenStrategy)
   150  	}
   151  
   152  	h, ok := tokenmgr.NewFuncs[conf.TokenManager]
   153  	if !ok {
   154  		return nil, fmt.Errorf("token manager not found: %s", conf.TokenManager)
   155  	}
   156  
   157  	tokenManager, err := h(conf.TokenManagers[conf.TokenManager])
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  
   162  	i, ok := tokenwriterregistry.NewTokenFuncs[conf.TokenWriter]
   163  	if !ok {
   164  		return nil, fmt.Errorf("token writer not found: %s", conf.TokenWriter)
   165  	}
   166  
   167  	tokenWriter, err := i(conf.TokenWriters[conf.TokenWriter])
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	chain := func(h http.Handler) http.Handler {
   173  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   174  			// OPTION requests need to pass for preflight requests
   175  			// TODO(labkode): this will break options for auth protected routes.
   176  			// Maybe running the CORS middleware before auth kicks in is enough.
   177  			ctx := r.Context()
   178  			span := trace.SpanFromContext(ctx)
   179  			defer span.End()
   180  			if !span.SpanContext().HasTraceID() {
   181  				_, span = tp.Tracer(tracerName).Start(ctx, "http auth interceptor")
   182  			}
   183  
   184  			if r.Method == "OPTIONS" {
   185  				h.ServeHTTP(w, r)
   186  				return
   187  			}
   188  
   189  			log := appctx.GetLogger(r.Context())
   190  			isUnprotectedEndpoint := false
   191  
   192  			// For unprotected URLs, we try to authenticate the request in case some service needs it,
   193  			// but don't return any errors if it fails.
   194  			if utils.Skip(r.URL.Path, unprotected) {
   195  				log.Info().Msg("skipping auth check for: " + r.URL.Path)
   196  				isUnprotectedEndpoint = true
   197  			}
   198  
   199  			ctx, err := authenticateUser(w, r, conf, tokenStrategyChain, tokenManager, tokenWriter, credChain, isUnprotectedEndpoint)
   200  			if err != nil {
   201  				if !isUnprotectedEndpoint {
   202  					return
   203  				}
   204  			} else {
   205  				u, ok := ctxpkg.ContextGetUser(ctx)
   206  				if ok {
   207  					span.SetAttributes(semconv.EnduserIDKey.String(u.Id.OpaqueId))
   208  				}
   209  
   210  				r = r.WithContext(ctx)
   211  			}
   212  			h.ServeHTTP(w, r)
   213  
   214  		})
   215  	}
   216  	return chain, nil
   217  }
   218  
   219  func authenticateUser(w http.ResponseWriter, r *http.Request, conf *config, tokenStrategies []auth.TokenStrategy, tokenManager token.Manager, tokenWriter auth.TokenWriter, credChain map[string]auth.CredentialStrategy, isUnprotectedEndpoint bool) (context.Context, error) {
   220  	ctx := r.Context()
   221  	log := appctx.GetLogger(ctx)
   222  
   223  	// Add the request user-agent to the ctx
   224  	ctx = metadata.NewIncomingContext(ctx, metadata.New(map[string]string{ctxpkg.UserAgentHeader: r.UserAgent()}))
   225  
   226  	client, err := pool.GetGatewayServiceClient(conf.GatewaySvc)
   227  	if err != nil {
   228  		logError(isUnprotectedEndpoint, log, err, "error getting the authsvc client", http.StatusUnauthorized, w)
   229  		return nil, err
   230  	}
   231  
   232  	// reva token or auth token can be passed using the same technique (for example bearer)
   233  	// before validating it against an auth provider, we can check directly if it's a reva
   234  	// token and if not try to use it for authenticating the user.
   235  	for _, tokenStrategy := range tokenStrategies {
   236  		token := tokenStrategy.GetToken(r)
   237  		if token != "" {
   238  			if user, tokenScope, ok := isTokenValid(r, tokenManager, token); ok {
   239  				if err := insertGroupsInUser(ctx, userGroupsCache, client, user); err != nil {
   240  					logError(isUnprotectedEndpoint, log, err, "got an error retrieving groups for user "+user.Username, http.StatusInternalServerError, w)
   241  					return nil, err
   242  				}
   243  				return ctxWithUserInfo(ctx, r, user, token, tokenScope, r.Header.Get(ctxpkg.InitiatorHeader)), nil
   244  			}
   245  		}
   246  	}
   247  
   248  	log.Warn().Msg("core access token not set")
   249  
   250  	userAgentCredKeys := getCredsForUserAgent(r.UserAgent(), conf.CredentialsByUserAgent, conf.CredentialChain)
   251  
   252  	// obtain credentials (basic auth, bearer token, ...) based on user agent
   253  	var creds *auth.Credentials
   254  	for _, k := range userAgentCredKeys {
   255  		creds, err = credChain[k].GetCredentials(w, r)
   256  		if err != nil {
   257  			log.Debug().Err(err).Msg("error retrieving credentials")
   258  		}
   259  
   260  		if creds != nil {
   261  			log.Debug().Msgf("credentials obtained from credential strategy: type: %s, client_id: %s", creds.Type, creds.ClientID)
   262  			break
   263  		}
   264  	}
   265  
   266  	// if no credentials are found, reply with authentication challenge depending on user agent
   267  	if creds == nil {
   268  		if !isUnprotectedEndpoint {
   269  			for _, key := range userAgentCredKeys {
   270  				if cred, ok := credChain[key]; ok {
   271  					cred.AddWWWAuthenticate(w, r, conf.Realm)
   272  				} else {
   273  					log.Error().Msg("auth credential strategy: " + key + "must have been loaded in init method")
   274  					w.WriteHeader(http.StatusInternalServerError)
   275  					return nil, errtypes.InternalError("no credentials found")
   276  				}
   277  			}
   278  			w.WriteHeader(http.StatusUnauthorized)
   279  		}
   280  		return nil, errtypes.PermissionDenied("no credentials found")
   281  	}
   282  
   283  	req := &gateway.AuthenticateRequest{
   284  		Type:         creds.Type,
   285  		ClientId:     creds.ClientID,
   286  		ClientSecret: creds.ClientSecret,
   287  	}
   288  
   289  	log.Debug().Msgf("AuthenticateRequest: type: %s, client_id: %s against %s", req.Type, req.ClientId, conf.GatewaySvc)
   290  
   291  	res, err := client.Authenticate(ctx, req)
   292  	if err != nil {
   293  		logError(isUnprotectedEndpoint, log, err, "error calling Authenticate", http.StatusUnauthorized, w)
   294  		return nil, err
   295  	}
   296  
   297  	if res.Status.Code != rpc.Code_CODE_OK {
   298  		err := status.NewErrorFromCode(res.Status.Code, "auth")
   299  		logError(isUnprotectedEndpoint, log, err, "error generating access token from credentials", http.StatusUnauthorized, w)
   300  		return nil, err
   301  	}
   302  
   303  	log.Info().Msg("core access token generated") // write token to response
   304  
   305  	// write token to response
   306  	token := res.Token
   307  	tokenWriter.WriteToken(token, w)
   308  
   309  	// validate token
   310  	u, tokenScope, err := tokenManager.DismantleToken(r.Context(), token)
   311  	if err != nil {
   312  		logError(isUnprotectedEndpoint, log, err, "error dismantling token", http.StatusUnauthorized, w)
   313  		return nil, err
   314  	}
   315  
   316  	if sharedconf.SkipUserGroupsInToken() {
   317  		var groups []string
   318  		if groupsIf, err := userGroupsCache.Get(u.Id.OpaqueId); err == nil {
   319  			groups = groupsIf.([]string)
   320  		} else {
   321  			groupsRes, err := client.GetUserGroups(ctx, &userpb.GetUserGroupsRequest{UserId: u.Id})
   322  			if err != nil {
   323  				logError(isUnprotectedEndpoint, log, err, "error retrieving user groups", http.StatusInternalServerError, w)
   324  				return nil, err
   325  			}
   326  			groups = groupsRes.Groups
   327  			_ = userGroupsCache.SetWithExpire(u.Id.OpaqueId, groupsRes.Groups, 3600*time.Second)
   328  		}
   329  		u.Groups = groups
   330  	}
   331  
   332  	// ensure access to the resource is allowed
   333  	ok, err := scope.VerifyScope(ctx, tokenScope, r.URL.Path)
   334  	if err != nil {
   335  		logError(isUnprotectedEndpoint, log, err, "error verifying scope of access token", http.StatusInternalServerError, w)
   336  		return nil, err
   337  	}
   338  	if !ok {
   339  		err := errtypes.PermissionDenied("access to resource not allowed")
   340  		logError(isUnprotectedEndpoint, log, err, "access to resource not allowed", http.StatusUnauthorized, w)
   341  		return nil, err
   342  	}
   343  
   344  	return ctxWithUserInfo(ctx, r, u, token, tokenScope, r.Header.Get(ctxpkg.InitiatorHeader)), nil
   345  }
   346  
   347  func ctxWithUserInfo(ctx context.Context, r *http.Request, user *userpb.User, token string, tokenScope map[string]*authpb.Scope, initiatorid string) context.Context {
   348  	ctx = ctxpkg.ContextSetUser(ctx, user)
   349  	ctx = ctxpkg.ContextSetToken(ctx, token)
   350  	ctx = ctxpkg.ContextSetInitiator(ctx, initiatorid)
   351  	ctx = metadata.AppendToOutgoingContext(ctx, ctxpkg.TokenHeader, token)
   352  	ctx = metadata.AppendToOutgoingContext(ctx, ctxpkg.UserAgentHeader, r.UserAgent())
   353  	ctx = metadata.AppendToOutgoingContext(ctx, ctxpkg.InitiatorHeader, initiatorid)
   354  	ctx = ctxpkg.ContextSetScopes(ctx, tokenScope)
   355  	return ctx
   356  }
   357  
   358  func insertGroupsInUser(ctx context.Context, userGroupsCache gcache.Cache, client gateway.GatewayAPIClient, user *userpb.User) error {
   359  	if sharedconf.SkipUserGroupsInToken() {
   360  		var groups []string
   361  		if groupsIf, err := userGroupsCache.Get(user.Id.OpaqueId); err == nil {
   362  			groups = groupsIf.([]string)
   363  		} else {
   364  			groupsRes, err := client.GetUserGroups(ctx, &userpb.GetUserGroupsRequest{UserId: user.Id})
   365  			if err != nil {
   366  				return err
   367  			}
   368  			groups = groupsRes.Groups
   369  			_ = userGroupsCache.SetWithExpire(user.Id.OpaqueId, groupsRes.Groups, 3600*time.Second)
   370  		}
   371  		user.Groups = groups
   372  	}
   373  	return nil
   374  }
   375  
   376  func isTokenValid(r *http.Request, tokenManager token.Manager, token string) (*userpb.User, map[string]*authpb.Scope, bool) {
   377  	ctx := r.Context()
   378  
   379  	u, tokenScope, err := tokenManager.DismantleToken(ctx, token)
   380  	if err != nil {
   381  		return nil, nil, false
   382  	}
   383  
   384  	// ensure access to the resource is allowed
   385  	ok, err := scope.VerifyScope(ctx, tokenScope, r.URL.Path)
   386  	if err != nil {
   387  		return nil, nil, false
   388  	}
   389  
   390  	return u, tokenScope, ok
   391  }
   392  
   393  func logError(isUnprotectedEndpoint bool, log *zerolog.Logger, err error, msg string, status int, w http.ResponseWriter) {
   394  	if !isUnprotectedEndpoint {
   395  		log.Error().Err(err).Msg(msg)
   396  		w.WriteHeader(status)
   397  	}
   398  }
   399  
   400  // getCredsForUserAgent returns the WWW Authenticate challenges keys to use given an http request
   401  // and available credentials.
   402  func getCredsForUserAgent(ua string, uam map[string]string, creds []string) []string {
   403  	if ua == "" || len(uam) == 0 {
   404  		return creds
   405  	}
   406  
   407  	for u, cred := range uam {
   408  		if strings.Contains(ua, u) {
   409  			for _, v := range creds {
   410  				if v == cred {
   411  					return []string{cred}
   412  				}
   413  			}
   414  			return creds
   415  
   416  		}
   417  	}
   418  
   419  	return creds
   420  }