github.com/zorawar87/trillian@v1.2.1/server/interceptor/interceptor.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package interceptor defines gRPC interceptors for Trillian.
    16  package interceptor
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"regexp"
    22  	"sync"
    23  	"time"
    24  
    25  	"github.com/golang/glog"
    26  	"github.com/google/trillian"
    27  	"github.com/google/trillian/monitoring"
    28  	"github.com/google/trillian/quota"
    29  	"github.com/google/trillian/quota/etcd/quotapb"
    30  	"github.com/google/trillian/server/errors"
    31  	"github.com/google/trillian/storage"
    32  	"github.com/google/trillian/trees"
    33  	"go.opencensus.io/trace"
    34  	"google.golang.org/grpc"
    35  	"google.golang.org/grpc/codes"
    36  	"google.golang.org/grpc/status"
    37  )
    38  
    39  const (
    40  	badInfoReason            = "bad_info"
    41  	badTreeReason            = "bad_tree"
    42  	insufficientTokensReason = "insufficient_tokens"
    43  	getTreeStage             = "get_tree"
    44  	getTokensStage           = "get_tokens"
    45  	traceSpanRoot            = "github/com/google/trillian/server/interceptor"
    46  )
    47  
    48  var (
    49  	// PutTokensTimeout is the timeout used for PutTokens calls.
    50  	// PutTokens happens in a separate goroutine and with an independent context, therefore it has
    51  	// its own timeout, separate from the RPC that causes the calls.
    52  	PutTokensTimeout = 5 * time.Second
    53  
    54  	requestCounter       monitoring.Counter
    55  	requestDeniedCounter monitoring.Counter
    56  	contextErrCounter    monitoring.Counter
    57  	metricsOnce          sync.Once
    58  	enabledServices      = map[string]bool{
    59  		"trillian.TrillianLog":   true,
    60  		"trillian.TrillianMap":   true,
    61  		"trillian.TrillianAdmin": true,
    62  		"TrillianLog":            true,
    63  		"TrillianMap":            true,
    64  		"TrillianAdmin":          true,
    65  	}
    66  )
    67  
    68  // RequestProcessor encapsulates the logic to intercept a request, split into separate stages:
    69  // before and after the handler is invoked.
    70  type RequestProcessor interface {
    71  
    72  	// Before implements all interceptor logic that happens before the handler is called.
    73  	// It returns a (potentially) modified context that's passed forward to the handler (and After),
    74  	// plus an error, in case the request should be interrupted before the handler is invoked.
    75  	Before(ctx context.Context, req interface{}, method string) (context.Context, error)
    76  
    77  	// After implements all interceptor logic that happens after the handler is invoked.
    78  	// Before must be invoked prior to After and the same RequestProcessor instance must to be used
    79  	// to process a given request.
    80  	After(ctx context.Context, resp interface{}, method string, handlerErr error)
    81  }
    82  
    83  // TrillianInterceptor checks that:
    84  // * Requests addressing a tree have the correct tree type and tree state;
    85  // * TODO(codingllama): Requests are properly authenticated / authorized ; and
    86  // * Requests are rate limited appropriately.
    87  type TrillianInterceptor struct {
    88  	admin storage.AdminStorage
    89  	qm    quota.Manager
    90  
    91  	// quotaDryRun controls whether lack of tokens actually blocks requests (if set to true, no
    92  	// requests are blocked by lack of tokens).
    93  	quotaDryRun bool
    94  }
    95  
    96  // New returns a new TrillianInterceptor instance.
    97  func New(admin storage.AdminStorage, qm quota.Manager, quotaDryRun bool, mf monitoring.MetricFactory) *TrillianInterceptor {
    98  	metricsOnce.Do(func() { initMetrics(mf) })
    99  	return &TrillianInterceptor{
   100  		admin:       admin,
   101  		qm:          qm,
   102  		quotaDryRun: quotaDryRun,
   103  	}
   104  }
   105  
   106  func initMetrics(mf monitoring.MetricFactory) {
   107  	if mf == nil {
   108  		mf = monitoring.InertMetricFactory{}
   109  	}
   110  	quota.InitMetrics(mf)
   111  	requestCounter = mf.NewCounter(
   112  		"interceptor_request_count",
   113  		"Total number of intercepted requests",
   114  		monitoring.TreeIDLabel)
   115  	requestDeniedCounter = mf.NewCounter(
   116  		"interceptor_request_denied_count",
   117  		"Number of requests by denied, labeled according to the reason for denial",
   118  		"reason", monitoring.TreeIDLabel, "quota_user")
   119  	contextErrCounter = mf.NewCounter(
   120  		"interceptor_context_err_counter",
   121  		"Total number of times request context has been cancelled or deadline exceeded by stage",
   122  		"stage")
   123  }
   124  
   125  func incRequestDeniedCounter(reason string, treeID int64, quotaUser string) {
   126  	requestDeniedCounter.Inc(reason, fmt.Sprint(treeID), quotaUser)
   127  }
   128  
   129  // UnaryInterceptor executes the TrillianInterceptor logic for unary RPCs.
   130  func (i *TrillianInterceptor) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
   131  	// Implement UnaryInterceptor using a RequestProcessor, so we
   132  	// 1. exercise it
   133  	// 2. make it easier to port this logic to non-gRPC implementations.
   134  
   135  	rp := i.NewProcessor()
   136  	var err error
   137  	ctx, err = rp.Before(ctx, req, info.FullMethod)
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  	resp, err := handler(ctx, req)
   142  	rp.After(ctx, resp, info.FullMethod, err)
   143  	return resp, err
   144  }
   145  
   146  // NewProcessor returns a RequestProcessor for the TrillianInterceptor logic.
   147  func (i *TrillianInterceptor) NewProcessor() RequestProcessor {
   148  	return &trillianProcessor{parent: i}
   149  }
   150  
   151  type trillianProcessor struct {
   152  	parent *TrillianInterceptor
   153  	info   *rpcInfo
   154  }
   155  
   156  func (tp *trillianProcessor) Before(ctx context.Context, req interface{}, method string) (context.Context, error) {
   157  	// Skip if the interceptor is not enabled for this service.
   158  	if !enabledServices[serviceName(method)] {
   159  		return ctx, nil
   160  	}
   161  
   162  	ctx, span := spanFor(ctx, "Before")
   163  	defer span.End()
   164  	info, err := newRPCInfo(req)
   165  	if err != nil {
   166  		glog.Warningf("Failed to read tree info: %v", err)
   167  		incRequestDeniedCounter(badInfoReason, 0, "")
   168  		return ctx, err
   169  	}
   170  	tp.info = info
   171  	requestCounter.Inc(fmt.Sprint(info.treeID))
   172  
   173  	// TODO(codingllama): Add auth interception
   174  
   175  	if info.getTree {
   176  		tree, err := trees.GetTree(
   177  			ctx, tp.parent.admin, info.treeID, trees.NewGetOpts(trees.Admin, info.treeTypes...))
   178  		if err != nil {
   179  			incRequestDeniedCounter(badTreeReason, info.treeID, info.quotaUsers)
   180  			return ctx, err
   181  		}
   182  		if err := ctx.Err(); err != nil {
   183  			contextErrCounter.Inc(getTreeStage)
   184  			return ctx, err
   185  		}
   186  		ctx = trees.NewContext(ctx, tree)
   187  	}
   188  
   189  	if info.tokens > 0 && len(info.specs) > 0 {
   190  		err := tp.parent.qm.GetTokens(ctx, info.tokens, info.specs)
   191  		if err != nil {
   192  			if !tp.parent.quotaDryRun {
   193  				incRequestDeniedCounter(insufficientTokensReason, info.treeID, info.quotaUsers)
   194  				return ctx, status.Errorf(codes.ResourceExhausted, "quota exhausted: %v", err)
   195  			}
   196  			glog.Warningf("(quotaDryRun) Request %+v not denied due to dry run mode: %v", req, err)
   197  		}
   198  		quota.Metrics.IncAcquired(info.tokens, info.specs, err == nil)
   199  		if err = ctx.Err(); err != nil {
   200  			contextErrCounter.Inc(getTokensStage)
   201  			return ctx, err
   202  		}
   203  	}
   204  
   205  	return ctx, nil
   206  }
   207  
   208  func (tp *trillianProcessor) After(ctx context.Context, resp interface{}, method string, handlerErr error) {
   209  	if !enabledServices[serviceName(method)] {
   210  		return
   211  	}
   212  	_, span := spanFor(ctx, "After")
   213  	defer span.End()
   214  	switch {
   215  	case tp.info == nil:
   216  		glog.Warningf("After called with nil rpcInfo, resp = [%+v], handlerErr = [%v]", resp, handlerErr)
   217  		return
   218  	case tp.info.tokens == 0:
   219  		// After() currently only does quota processing
   220  		return
   221  	}
   222  
   223  	// Decide if we have to replenish tokens. There are a few situations that require tokens to
   224  	// be replenished:
   225  	// * Invalid requests (a bad request shouldn't spend sequencing-based tokens, as it won't
   226  	//   cause a corresponding sequencing to happen)
   227  	// * Requests that filter out duplicates (e.g., QueueLeaf and QueueLeaves, for the same
   228  	//   reason as above: duplicates aren't queued for sequencing)
   229  	tokens := 0
   230  	if handlerErr != nil {
   231  		// Return the tokens spent by invalid requests
   232  		tokens = tp.info.tokens
   233  	} else {
   234  		switch resp := resp.(type) {
   235  		case *trillian.QueueLeafResponse:
   236  			if !isLeafOK(resp.GetQueuedLeaf()) {
   237  				tokens = 1
   238  			}
   239  		case *trillian.AddSequencedLeavesResponse:
   240  			for _, leaf := range resp.GetResults() {
   241  				if !isLeafOK(leaf) {
   242  					tokens++
   243  				}
   244  			}
   245  		case *trillian.QueueLeavesResponse:
   246  			for _, leaf := range resp.GetQueuedLeaves() {
   247  				if !isLeafOK(leaf) {
   248  					tokens++
   249  				}
   250  			}
   251  		}
   252  	}
   253  	if len(tp.info.specs) > 0 && tokens > 0 {
   254  		// Run PutTokens in a separate goroutine and with a separate context.
   255  		// It shouldn't block RPC completion, nor should it share the RPC's context deadline.
   256  		go func() {
   257  			ctx, span := spanFor(context.Background(), "After.PutTokens")
   258  			defer span.End()
   259  			ctx, cancel := context.WithTimeout(ctx, PutTokensTimeout)
   260  			defer cancel()
   261  
   262  			// TODO(codingllama): If PutTokens turns out to be unreliable we can still leak tokens. In
   263  			// this case, we may want to keep tabs on how many tokens we failed to replenish and bundle
   264  			// them up in the next PutTokens call (possibly as a QuotaManager decorator, or internally
   265  			// in its impl).
   266  			err := tp.parent.qm.PutTokens(ctx, tokens, tp.info.specs)
   267  			if err != nil {
   268  				glog.Warningf("Failed to replenish %v tokens: %v", tokens, err)
   269  			}
   270  			quota.Metrics.IncReturned(tokens, tp.info.specs, err == nil)
   271  		}()
   272  	}
   273  }
   274  
   275  func isLeafOK(leaf *trillian.QueuedLogLeaf) bool {
   276  	// Be biased in favor of OK, as that matches TrillianLogRPCServer's behavior.
   277  	return leaf == nil || leaf.Status == nil || leaf.Status.Code == int32(codes.OK)
   278  }
   279  
   280  var fullyQualifiedRE = regexp.MustCompile(`^/([\w.]+)/(\w+)$`)
   281  var unqualifiedRE = regexp.MustCompile(`^/(\w+)\.(\w+)$`)
   282  
   283  // serviceName returns the fully qualified service name
   284  // "some.package.service" for "/some.package.service/method".
   285  // It returns the unqualified service name "service" for "/service.method".
   286  func serviceName(fullMethod string) string {
   287  	if matches := fullyQualifiedRE.FindStringSubmatch(fullMethod); len(matches) == 3 {
   288  		return matches[1]
   289  	}
   290  	if matches := unqualifiedRE.FindStringSubmatch(fullMethod); len(matches) == 3 {
   291  		return matches[1]
   292  	}
   293  	return ""
   294  }
   295  
   296  type rpcInfo struct {
   297  	// getTree indicates whether the interceptor should populate treeID.
   298  	getTree bool
   299  
   300  	readonly  bool
   301  	treeID    int64
   302  	treeTypes []trillian.TreeType
   303  
   304  	specs  []quota.Spec
   305  	tokens int
   306  	// Single string describing all of the users against which quota is requested.
   307  	quotaUsers string
   308  }
   309  
   310  // chargable is satisfied by request proto messages which contain a GetChargeTo
   311  // accessor.
   312  type chargable interface {
   313  	GetChargeTo() *trillian.ChargeTo
   314  }
   315  
   316  // chargedUsers returns user identifiers for any chargable user quotas.
   317  func chargedUsers(req interface{}) []string {
   318  	c, ok := req.(chargable)
   319  	if !ok {
   320  		return nil
   321  	}
   322  	chargeTo := c.GetChargeTo()
   323  	if chargeTo == nil {
   324  		return nil
   325  	}
   326  
   327  	return chargeTo.User
   328  }
   329  
   330  func newRPCInfoForRequest(req interface{}) (*rpcInfo, error) {
   331  	// Set "safe" defaults: enable all interception and assume requests are readonly.
   332  	info := &rpcInfo{
   333  		getTree:   true,
   334  		readonly:  true,
   335  		treeTypes: nil,
   336  		tokens:    0,
   337  	}
   338  
   339  	switch req := req.(type) {
   340  
   341  	// Not intercepted at all
   342  	case
   343  		// Quota configuration requests
   344  		*quotapb.CreateConfigRequest,
   345  		*quotapb.DeleteConfigRequest,
   346  		*quotapb.GetConfigRequest,
   347  		*quotapb.ListConfigsRequest,
   348  		*quotapb.UpdateConfigRequest:
   349  		info.getTree = false
   350  		info.readonly = false // Doesn't really matter as all interceptors are turned off
   351  
   352  	// Admin create
   353  	case *trillian.CreateTreeRequest:
   354  		info.getTree = false // Tree doesn't exist
   355  		info.readonly = false
   356  
   357  	// Admin list
   358  	case *trillian.ListTreesRequest:
   359  		info.getTree = false // Zero to many trees
   360  
   361  	// Admin / readonly
   362  	case *trillian.GetTreeRequest:
   363  		info.getTree = false // Read done within RPC handler
   364  
   365  	// Admin / readwrite
   366  	case *trillian.DeleteTreeRequest,
   367  		*trillian.UndeleteTreeRequest,
   368  		*trillian.UpdateTreeRequest:
   369  		info.getTree = false // Read-modify-write done within RPC handler
   370  		info.readonly = false
   371  
   372  	// (Log + Pre-ordered Log) / readonly
   373  	case *trillian.GetConsistencyProofRequest,
   374  		*trillian.GetEntryAndProofRequest,
   375  		*trillian.GetInclusionProofByHashRequest,
   376  		*trillian.GetInclusionProofRequest,
   377  		*trillian.GetLatestSignedLogRootRequest:
   378  		info.treeTypes = []trillian.TreeType{trillian.TreeType_LOG, trillian.TreeType_PREORDERED_LOG}
   379  		info.tokens = 1
   380  	case *trillian.GetLeavesByHashRequest:
   381  		info.treeTypes = []trillian.TreeType{trillian.TreeType_LOG, trillian.TreeType_PREORDERED_LOG}
   382  		info.tokens = len(req.GetLeafHash())
   383  	case *trillian.GetLeavesByIndexRequest:
   384  		info.treeTypes = []trillian.TreeType{trillian.TreeType_LOG, trillian.TreeType_PREORDERED_LOG}
   385  		info.tokens = len(req.GetLeafIndex())
   386  	case *trillian.GetLeavesByRangeRequest:
   387  		info.treeTypes = []trillian.TreeType{trillian.TreeType_LOG, trillian.TreeType_PREORDERED_LOG}
   388  		info.tokens = 1
   389  		if c := req.GetCount(); c > 1 {
   390  			info.tokens = int(c)
   391  		}
   392  	case *trillian.GetSequencedLeafCountRequest:
   393  		info.treeTypes = []trillian.TreeType{trillian.TreeType_LOG, trillian.TreeType_PREORDERED_LOG}
   394  
   395  	// Log / readwrite
   396  	case *trillian.QueueLeafRequest:
   397  		info.readonly = false
   398  		info.treeTypes = []trillian.TreeType{trillian.TreeType_LOG}
   399  		info.tokens = 1
   400  	case *trillian.QueueLeavesRequest:
   401  		info.readonly = false
   402  		info.treeTypes = []trillian.TreeType{trillian.TreeType_LOG}
   403  		info.tokens = len(req.GetLeaves())
   404  
   405  	// Pre-ordered Log / readwrite
   406  	case *trillian.AddSequencedLeafRequest:
   407  		info.readonly = false
   408  		info.treeTypes = []trillian.TreeType{trillian.TreeType_PREORDERED_LOG}
   409  		info.tokens = 1
   410  	case *trillian.AddSequencedLeavesRequest:
   411  		info.readonly = false
   412  		info.treeTypes = []trillian.TreeType{trillian.TreeType_PREORDERED_LOG}
   413  		info.tokens = len(req.GetLeaves())
   414  
   415  	// (Log + Pre-ordered Log) / readwrite
   416  	case *trillian.InitLogRequest:
   417  		info.readonly = false
   418  		info.treeTypes = []trillian.TreeType{trillian.TreeType_LOG, trillian.TreeType_PREORDERED_LOG}
   419  		info.tokens = 1
   420  
   421  	// Map / readonly
   422  	case *trillian.GetMapLeavesByRevisionRequest:
   423  		info.treeTypes = []trillian.TreeType{trillian.TreeType_MAP}
   424  		info.tokens = len(req.GetIndex())
   425  	case *trillian.GetMapLeavesRequest:
   426  		info.treeTypes = []trillian.TreeType{trillian.TreeType_MAP}
   427  		info.tokens = len(req.GetIndex())
   428  	case *trillian.GetSignedMapRootByRevisionRequest,
   429  		*trillian.GetSignedMapRootRequest:
   430  		info.treeTypes = []trillian.TreeType{trillian.TreeType_MAP}
   431  		info.tokens = 1
   432  
   433  	// Map / readwrite
   434  	case *trillian.SetMapLeavesRequest:
   435  		info.readonly = false
   436  		info.treeTypes = []trillian.TreeType{trillian.TreeType_MAP}
   437  		info.tokens = len(req.GetLeaves())
   438  	case *trillian.InitMapRequest:
   439  		info.readonly = false
   440  		info.treeTypes = []trillian.TreeType{trillian.TreeType_MAP}
   441  		info.tokens = 1
   442  
   443  	default:
   444  		return nil, status.Errorf(codes.Internal, "newRPCInfo: unmapped request type: %T", req)
   445  	}
   446  
   447  	return info, nil
   448  }
   449  
   450  func newRPCInfo(req interface{}) (*rpcInfo, error) {
   451  	info, err := newRPCInfoForRequest(req)
   452  	if err != nil {
   453  		return nil, err
   454  	}
   455  
   456  	if info.getTree || info.tokens > 0 {
   457  		switch req := req.(type) {
   458  		case logIDRequest:
   459  			info.treeID = req.GetLogId()
   460  		case mapIDRequest:
   461  			info.treeID = req.GetMapId()
   462  		case treeIDRequest:
   463  			info.treeID = req.GetTreeId()
   464  		case treeRequest:
   465  			info.treeID = req.GetTree().GetTreeId()
   466  		default:
   467  			return nil, status.Errorf(codes.Internal, "cannot retrieve treeID from request: %T", req)
   468  		}
   469  	}
   470  
   471  	if info.tokens > 0 {
   472  		kind := quota.Write
   473  		if info.readonly {
   474  			kind = quota.Read
   475  		}
   476  
   477  		for _, user := range chargedUsers(req) {
   478  			info.specs = append(info.specs, quota.Spec{Group: quota.User, Kind: kind, User: user})
   479  			if len(info.quotaUsers) > 0 {
   480  				info.quotaUsers += "+"
   481  			}
   482  			info.quotaUsers += user
   483  		}
   484  		info.specs = append(info.specs, []quota.Spec{
   485  			{Group: quota.Tree, Kind: kind, TreeID: info.treeID},
   486  			{Group: quota.Global, Kind: kind},
   487  		}...)
   488  	}
   489  
   490  	return info, nil
   491  }
   492  
   493  type logIDRequest interface {
   494  	GetLogId() int64
   495  }
   496  
   497  type mapIDRequest interface {
   498  	GetMapId() int64
   499  }
   500  
   501  type treeIDRequest interface {
   502  	GetTreeId() int64
   503  }
   504  
   505  type treeRequest interface {
   506  	GetTree() *trillian.Tree
   507  }
   508  
   509  // ErrorWrapper is a grpc.UnaryServerInterceptor that wraps the errors emitted by the underlying handler.
   510  func ErrorWrapper(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
   511  	ctx, span := spanFor(ctx, "ErrorWrapper")
   512  	defer span.End()
   513  	rsp, err := handler(ctx, req)
   514  	return rsp, errors.WrapError(err)
   515  }
   516  
   517  func spanFor(ctx context.Context, name string) (context.Context, *trace.Span) {
   518  	return trace.StartSpan(ctx, fmt.Sprintf("%s.%s", traceSpanRoot, name))
   519  }