github.com/kisexp/xdchain@v0.0.0-20211206025815-490d6b732aa7/rpc/security.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net/http"
     8  	"net/url"
     9  	"os"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/kisexp/xdchain/core/types"
    14  	"github.com/kisexp/xdchain/log"
    15  	"github.com/kisexp/xdchain/multitenancy"
    16  	"github.com/kisexp/xdchain/plugin/security"
    17  	"github.com/golang/protobuf/ptypes"
    18  	"github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
    19  )
    20  
    21  type securityContextSupport interface {
    22  	securityContextConfigurer
    23  	SecurityContextResolver
    24  }
    25  
    26  type securityContextConfigurer interface {
    27  	Configure(secCtx SecurityContext)
    28  }
    29  
    30  type SecurityContextResolver interface {
    31  	Resolve() SecurityContext
    32  }
    33  
    34  type securityError struct{ message string }
    35  
    36  // Provider function to return token being injected in Authorization http request header
    37  type HttpCredentialsProviderFunc func(ctx context.Context) (string, error)
    38  
    39  // Provider function to return a string value which will be
    40  // 1. injected in HttpPrivateStateIdentifierHeader http request header for HTTP/WS transports
    41  // 2. encoded in JSON MessageID for IPC/InProc transports
    42  type PSIProviderFunc func(ctx context.Context) (types.PrivateStateIdentifier, error)
    43  
    44  func (e *securityError) ErrorCode() int { return -32001 }
    45  
    46  func (e *securityError) Error() string { return e.message }
    47  
    48  func extractToken(req *http.Request) (string, bool) {
    49  	token := req.Header.Get(HttpAuthorizationHeader)
    50  	return token, token != ""
    51  }
    52  
    53  func verifyExpiration(token *proto.PreAuthenticatedAuthenticationToken) error {
    54  	if token == nil {
    55  		return nil
    56  	}
    57  	expiredAt, err := ptypes.Timestamp(token.ExpiredAt)
    58  	if err != nil {
    59  		return fmt.Errorf("invalid timestamp in token: %s", err)
    60  	}
    61  	if time.Now().Before(expiredAt) {
    62  		return nil
    63  	}
    64  	return &securityError{"token expired"}
    65  }
    66  
    67  func verifyAccess(service, method string, authorities []*proto.GrantedAuthority) error {
    68  	for _, authority := range authorities {
    69  		if authority.Service == "*" && authority.Method == "*" {
    70  			return nil
    71  		}
    72  		if authority.Service == "*" && authority.Method == method {
    73  			return nil
    74  		}
    75  		if authority.Service == service && authority.Method == "*" {
    76  			return nil
    77  		}
    78  		if authority.Service == service && authority.Method == method {
    79  			return nil
    80  		}
    81  	}
    82  	return &securityError{fmt.Sprintf("%s%s%s - access denied", service, serviceMethodSeparator, method)}
    83  }
    84  
    85  // verify if a call is authorized using information available in the security context
    86  // it also checks for token expiration. That means if this is called multiple times (batch processing),
    87  // token expiration is checked multiple times.
    88  //
    89  // It returns the verfied security context for caller to use.
    90  func SecureCall(resolver SecurityContextResolver, method string) (context.Context, error) {
    91  	secCtx := resolver.Resolve()
    92  	if secCtx == nil {
    93  		return context.Background(), nil
    94  	}
    95  	if err, hasError := secCtx.Value(ctxAuthenticationError).(error); hasError {
    96  		return nil, err
    97  	}
    98  	if authToken := PreauthenticatedTokenFromContext(secCtx); authToken != nil {
    99  		if err := verifyExpiration(authToken); err != nil {
   100  			return nil, err
   101  		}
   102  		elem := strings.SplitN(method, serviceMethodSeparator, 2)
   103  		if len(elem) != 2 {
   104  			log.Warn("unsupported method when performing authorization check", "method", method)
   105  		} else if err := verifyAccess(elem[0], elem[1], authToken.Authorities); err != nil {
   106  			return nil, err
   107  		}
   108  		// authorization check for PSI when multitenancy is enabled
   109  		if isMultitenant := IsMultitenantFromContext(secCtx); isMultitenant {
   110  			var authorizedPSI types.PrivateStateIdentifier
   111  			var err error
   112  			// does user provide PSI in the request
   113  			if requestPSI, ok := secCtx.Value(ctxRequestPrivateStateIdentifier).(types.PrivateStateIdentifier); !ok {
   114  				// let's try to extract from token
   115  				authorizedPSI, err = multitenancy.ExtractPSI(authToken)
   116  				if err != nil {
   117  					return nil, err
   118  				}
   119  			} else {
   120  				isAuthorized, err := multitenancy.IsPSIAuthorized(authToken, requestPSI)
   121  				if err != nil {
   122  					return nil, err
   123  				}
   124  				if !isAuthorized {
   125  					return nil, multitenancy.ErrNotAuthorized
   126  				}
   127  				authorizedPSI = requestPSI
   128  			}
   129  			secCtx = WithPrivateStateIdentifier(secCtx, authorizedPSI)
   130  			log.Debug("Determined authorized PSI", "psi", authorizedPSI)
   131  		}
   132  	}
   133  	return secCtx, nil
   134  }
   135  
   136  // AuthenticateHttpRequest uses the provided authManager to authenticate an http request and populates
   137  // the provided ctx with additional information useful for consumers
   138  func AuthenticateHttpRequest(ctx context.Context, r *http.Request, authManager security.AuthenticationManager) (securityContext context.Context) {
   139  	securityContext = ctx
   140  	userProvidedPSI, found := extractPSI(r)
   141  	if found {
   142  		securityContext = context.WithValue(securityContext, ctxRequestPrivateStateIdentifier, userProvidedPSI)
   143  	}
   144  	if isAuthEnabled, err := authManager.IsEnabled(context.Background()); err != nil {
   145  		// this indicates a failure in the plugin. We don't want any subsequent request unchecked
   146  		log.Error("failure when checking if authentication manager is enabled", "err", err)
   147  		securityContext = context.WithValue(securityContext, ctxAuthenticationError, &securityError{"internal error"})
   148  		return
   149  	} else if !isAuthEnabled {
   150  		// node is not configured to be multitenant but MPS is enabled
   151  		securityContext = WithPrivateStateIdentifier(securityContext, userProvidedPSI)
   152  		return
   153  	}
   154  	if token, hasToken := extractToken(r); hasToken {
   155  		if authToken, err := authManager.Authenticate(context.Background(), token); err != nil {
   156  			securityContext = context.WithValue(securityContext, ctxAuthenticationError, &securityError{err.Error()})
   157  		} else {
   158  			securityContext = WithPreauthenticatedToken(securityContext, authToken)
   159  		}
   160  	} else {
   161  		securityContext = context.WithValue(securityContext, ctxAuthenticationError, &securityError{"missing access token"})
   162  	}
   163  	return
   164  }
   165  
   166  // construct JSON RPC error message which has the ID of the request
   167  func securityErrorMessage(forMsg *jsonrpcMessage, err error) *jsonrpcMessage {
   168  	msg := &jsonrpcMessage{Version: vsn, ID: forMsg.ID, Error: &jsonError{
   169  		Code:    defaultErrorCode,
   170  		Message: err.Error(),
   171  	}}
   172  	ec, ok := err.(Error)
   173  	if ok {
   174  		msg.Error.Code = ec.ErrorCode()
   175  	}
   176  	return msg
   177  }
   178  
   179  // extractPSI tries to extract the PSI from the HTTP Header then the URL
   180  // otherwise return the default value but still signal the caller
   181  // that user doesn't provide PSI
   182  func extractPSI(r *http.Request) (types.PrivateStateIdentifier, bool) {
   183  	psi := r.Header.Get(HttpPrivateStateIdentifierHeader)
   184  	if len(psi) == 0 {
   185  		psi = r.URL.Query().Get(QueryPrivateStateIdentifierParamName)
   186  	}
   187  	if len(psi) == 0 {
   188  		return types.DefaultPrivateStateIdentifier, false
   189  	}
   190  	return types.PrivateStateIdentifier(psi), true
   191  }
   192  
   193  // resolvePSIProvider enriches the given context with PSIProviderFunc if PSI value found
   194  // in URL Query or env variable
   195  func resolvePSIProvider(ctx context.Context, endpoint string) (newCtx context.Context) {
   196  	newCtx = ctx
   197  	var rawPSI string
   198  	// first take from endpoint
   199  	parsedUrl, err := url.Parse(endpoint)
   200  	if err != nil {
   201  		return
   202  	}
   203  	switch parsedUrl.Scheme {
   204  	case "http", "https", "ws", "wss":
   205  		rawPSI = parsedUrl.Query().Get(QueryPrivateStateIdentifierParamName)
   206  	default:
   207  	}
   208  	// then from the env variable
   209  	if value := os.Getenv(EnvVarPrivateStateIdentifier); len(value) > 0 {
   210  		rawPSI = value
   211  	}
   212  	if len(rawPSI) > 0 {
   213  		// must declare type here so the context value reflects the same
   214  		var f PSIProviderFunc = func(_ context.Context) (types.PrivateStateIdentifier, error) {
   215  			return types.PrivateStateIdentifier(rawPSI), nil
   216  		}
   217  		newCtx = WithPSIProvider(ctx, f)
   218  	}
   219  	return
   220  }
   221  
   222  // encodePSI includes counter and PSI value in an JSON message ID.
   223  // i.e.: <counter> becomes "<psi>/32"
   224  func encodePSI(idCounterBytes []byte, psi types.PrivateStateIdentifier) json.RawMessage {
   225  	if len(psi) == 0 {
   226  		return idCounterBytes
   227  	}
   228  	newID := make([]byte, len(idCounterBytes)+len(psi)+3) // including 2 double quotes and '@'
   229  	newID[0], newID[len(newID)-1] = '"', '"'
   230  	copy(newID[1:len(psi)+1], psi)
   231  	copy(newID[len(psi)+1:], append([]byte("/"), idCounterBytes...))
   232  	return newID
   233  }
   234  
   235  // decodePSI extracts PSI value from an encoded JSON message ID. Return DefaultPrivateStateIdentifier
   236  // if not found
   237  // i.e.: "<counter>/<psi>" returns <psi>
   238  func decodePSI(id json.RawMessage) types.PrivateStateIdentifier {
   239  	idStr := string(id)
   240  	if !strings.HasPrefix(idStr, "\"") || !strings.HasSuffix(idStr, "\"") {
   241  		return types.DefaultPrivateStateIdentifier
   242  	}
   243  	sepIdx := strings.Index(idStr, "/")
   244  	if sepIdx == -1 {
   245  		return types.DefaultPrivateStateIdentifier
   246  	}
   247  	return types.PrivateStateIdentifier(id[1:sepIdx])
   248  }