github.com/zorawar87/trillian@v1.2.1/server/interceptor/interceptor_test.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package interceptor
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/golang/mock/gomock"
    24  	"github.com/golang/protobuf/proto"
    25  	"github.com/golang/protobuf/ptypes"
    26  	"github.com/google/trillian"
    27  	"github.com/google/trillian/quota"
    28  	"github.com/google/trillian/quota/etcd/quotapb"
    29  	"github.com/google/trillian/storage"
    30  	"github.com/google/trillian/storage/testonly"
    31  	"github.com/google/trillian/trees"
    32  	"github.com/kylelemons/godebug/pretty"
    33  	"google.golang.org/grpc"
    34  	"google.golang.org/grpc/codes"
    35  	"google.golang.org/grpc/status"
    36  
    37  	serrors "github.com/google/trillian/server/errors"
    38  	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
    39  )
    40  
    41  func TestServiceName(t *testing.T) {
    42  	for _, tc := range []struct {
    43  		desc   string
    44  		method string
    45  		want   string
    46  	}{
    47  		{desc: "trillian", method: "/trillian.TrillianLog/QueueLeaf", want: "trillian.TrillianLog"},
    48  		{desc: "fullyqualified", method: "/some.package.service/method", want: "some.package.service"},
    49  		{desc: "unqualified", method: "/service.method", want: "service"},
    50  		{desc: "noleadingslash", method: "no.leading.slash/method"},
    51  		{desc: "malformed", method: "/package.service.method"},
    52  	} {
    53  		t.Run(tc.desc, func(t *testing.T) {
    54  			if got, want := serviceName(tc.method), tc.want; got != want {
    55  				t.Errorf("serviceName(%v): %v, want %v", tc.method, got, want)
    56  			}
    57  		})
    58  	}
    59  }
    60  
    61  func TestTrillianInterceptor_TreeInterception(t *testing.T) {
    62  	logTree := proto.Clone(testonly.LogTree).(*trillian.Tree)
    63  	logTree.TreeId = 10
    64  	mapTree := proto.Clone(testonly.MapTree).(*trillian.Tree)
    65  	mapTree.TreeId = 11
    66  	deletedTree := proto.Clone(testonly.LogTree).(*trillian.Tree)
    67  	deletedTree.TreeId = 12
    68  	deletedTree.Deleted = true
    69  	deletedTree.DeleteTime = ptypes.TimestampNow()
    70  	unknownTreeID := int64(999)
    71  
    72  	tests := []struct {
    73  		desc       string
    74  		method     string
    75  		req        interface{}
    76  		handlerErr error
    77  		wantErr    bool
    78  		wantTree   *trillian.Tree
    79  		cancelled  bool
    80  	}{
    81  		// TODO(codingllama): Admin requests don't benefit from tree-reading logic, but we may read
    82  		// their tree IDs for auth purposes.
    83  		{
    84  			desc:   "adminReadByID",
    85  			method: "/trillian.TrillianAdmin/GetTree",
    86  			req:    &trillian.GetTreeRequest{TreeId: logTree.TreeId},
    87  		},
    88  		{
    89  			desc:   "adminWriteByID",
    90  			method: "/trillian.TrillianAdmin/DeleteTree",
    91  			req:    &trillian.DeleteTreeRequest{TreeId: logTree.TreeId},
    92  		},
    93  		{
    94  			desc:   "adminWriteByTree",
    95  			method: "/trillian.TrillianAdmin/UpdateTree",
    96  			req:    &trillian.UpdateTreeRequest{Tree: &trillian.Tree{TreeId: logTree.TreeId}},
    97  		},
    98  		{
    99  			desc:     "logRPC",
   100  			method:   "/trillian.TrillianLog/GetLatestSignedLogRoot",
   101  			req:      &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId},
   102  			wantTree: logTree,
   103  		},
   104  		{
   105  			desc:     "mapRPC",
   106  			method:   "/trillian.TrillianMap/GetSignedMapRoot",
   107  			req:      &trillian.GetSignedMapRootRequest{MapId: mapTree.TreeId},
   108  			wantTree: mapTree,
   109  		},
   110  		{
   111  			desc:    "unknownRequest",
   112  			req:     "not-a-request",
   113  			wantErr: false,
   114  		},
   115  		{
   116  			desc:    "unknownTree",
   117  			method:  "/trillian.TrillianLog/GetLatestSignedLogRoot",
   118  			req:     &trillian.GetLatestSignedLogRootRequest{LogId: unknownTreeID},
   119  			wantErr: true,
   120  		},
   121  		{
   122  			desc:    "deletedTree",
   123  			method:  "/trillian.TrillianLog/GetLatestSignedLogRoot",
   124  			req:     &trillian.GetLatestSignedLogRootRequest{LogId: deletedTree.TreeId},
   125  			wantErr: true,
   126  		},
   127  		{
   128  			desc:      "cancelled",
   129  			method:    "/trillian.TrillianLog/GetLatestSignedLogRoot",
   130  			req:       &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId},
   131  			cancelled: true,
   132  			wantErr:   true,
   133  		},
   134  	}
   135  
   136  	ctx := context.Background()
   137  	for _, test := range tests {
   138  		t.Run(test.desc, func(t *testing.T) {
   139  			ctrl := gomock.NewController(t)
   140  			defer ctrl.Finish()
   141  			admin := storage.NewMockAdminStorage(ctrl)
   142  			adminTX := storage.NewMockReadOnlyAdminTX(ctrl)
   143  			admin.EXPECT().Snapshot(gomock.Any()).AnyTimes().Return(adminTX, nil)
   144  			adminTX.EXPECT().GetTree(gomock.Any(), logTree.TreeId).AnyTimes().Return(logTree, nil)
   145  			adminTX.EXPECT().GetTree(gomock.Any(), mapTree.TreeId).AnyTimes().Return(mapTree, nil)
   146  			adminTX.EXPECT().GetTree(gomock.Any(), deletedTree.TreeId).AnyTimes().Return(deletedTree, nil)
   147  			adminTX.EXPECT().GetTree(gomock.Any(), unknownTreeID).AnyTimes().Return(nil, errors.New("not found"))
   148  			adminTX.EXPECT().Close().AnyTimes().Return(nil)
   149  			adminTX.EXPECT().Commit().AnyTimes().Return(nil)
   150  
   151  			intercept := New(admin, quota.Noop(), false /* quotaDryRun */, nil /* mf */)
   152  			handler := &fakeHandler{resp: "handler response", err: test.handlerErr}
   153  
   154  			if test.cancelled {
   155  				// Use a context that's already been cancelled
   156  				newCtx, cancel := context.WithCancel(ctx)
   157  				cancel()
   158  				ctx = newCtx
   159  			}
   160  
   161  			resp, err := intercept.UnaryInterceptor(ctx, test.req,
   162  				&grpc.UnaryServerInfo{FullMethod: test.method},
   163  				handler.run)
   164  			if hasErr := err != nil && err != test.handlerErr; hasErr != test.wantErr {
   165  				t.Fatalf("UnaryInterceptor() returned err = %v, wantErr = %v", err, test.wantErr)
   166  			} else if hasErr {
   167  				return
   168  			}
   169  
   170  			if !handler.called {
   171  				t.Fatal("handler not called")
   172  			}
   173  			if handler.resp != resp {
   174  				t.Errorf("resp = %v, want = %v", resp, handler.resp)
   175  			}
   176  			if handler.err != err {
   177  				t.Errorf("err = %v, want = %v", err, handler.err)
   178  			}
   179  
   180  			if test.wantTree != nil {
   181  				switch tree, ok := trees.FromContext(handler.ctx); {
   182  				case !ok:
   183  					t.Error("tree not in handler ctx")
   184  				case !proto.Equal(tree, test.wantTree):
   185  					diff := pretty.Compare(tree, test.wantTree)
   186  					t.Errorf("post-FromContext diff:\n%v", diff)
   187  				}
   188  			}
   189  		})
   190  	}
   191  }
   192  
   193  func TestTrillianInterceptor_QuotaInterception(t *testing.T) {
   194  
   195  	logTree := *testonly.LogTree
   196  	logTree.TreeId = 10
   197  
   198  	mapTree := *testonly.MapTree
   199  	mapTree.TreeId = 11
   200  
   201  	preorderedTree := *testonly.PreorderedLogTree
   202  	preorderedTree.TreeId = 12
   203  
   204  	charge1 := "alpaca"
   205  	charge2 := "cama"
   206  	charges := &trillian.ChargeTo{User: []string{charge1, charge2}}
   207  	tests := []struct {
   208  		desc         string
   209  		dryRun       bool
   210  		method       string
   211  		req          interface{}
   212  		specs        []quota.Spec
   213  		getTokensErr error
   214  		wantCode     codes.Code
   215  		wantTokens   int
   216  	}{
   217  		{
   218  			desc:   "logRead",
   219  			method: "/trillian.TrillianLog/GetLatestSignedLogRoot",
   220  			req:    &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId},
   221  			specs: []quota.Spec{
   222  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   223  				{Group: quota.Global, Kind: quota.Read},
   224  			},
   225  			wantTokens: 1,
   226  		},
   227  		{
   228  			desc:   "logReadIndices",
   229  			method: "/trillian.TrillianLog/GetLeavesByIndex",
   230  			req:    &trillian.GetLeavesByIndexRequest{LogId: logTree.TreeId, LeafIndex: []int64{1, 2, 3}},
   231  			specs: []quota.Spec{
   232  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   233  				{Group: quota.Global, Kind: quota.Read},
   234  			},
   235  			wantTokens: 3,
   236  		},
   237  		{
   238  			desc:   "logReadRange",
   239  			method: "/trillian.TrillianLog/GetLeavesByRange",
   240  			req:    &trillian.GetLeavesByRangeRequest{LogId: logTree.TreeId, Count: 123},
   241  			specs: []quota.Spec{
   242  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   243  				{Group: quota.Global, Kind: quota.Read},
   244  			},
   245  			wantTokens: 123,
   246  		},
   247  		{
   248  			desc:   "logReadNegativeRange",
   249  			method: "/trillian.TrillianLog/GetLeavesByRange",
   250  			req:    &trillian.GetLeavesByRangeRequest{LogId: logTree.TreeId, Count: -123},
   251  			specs: []quota.Spec{
   252  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   253  				{Group: quota.Global, Kind: quota.Read},
   254  			},
   255  			wantTokens: 1,
   256  		},
   257  		{
   258  			desc:   "logReadZeroRange",
   259  			method: "/trillian.TrillianLog/GetLeavesByRange",
   260  			req:    &trillian.GetLeavesByRangeRequest{LogId: logTree.TreeId, Count: 0},
   261  			specs: []quota.Spec{
   262  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   263  				{Group: quota.Global, Kind: quota.Read},
   264  			},
   265  			wantTokens: 1,
   266  		},
   267  		{
   268  			desc:   "logRead with charges",
   269  			method: "/trillian.TrillianLog/GetLatestSignedLogRoot",
   270  			req:    &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId, ChargeTo: charges},
   271  			specs: []quota.Spec{
   272  				{Group: quota.User, Kind: quota.Read, User: charge1},
   273  				{Group: quota.User, Kind: quota.Read, User: charge2},
   274  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   275  				{Group: quota.Global, Kind: quota.Read},
   276  			},
   277  			wantTokens: 1,
   278  		},
   279  		{
   280  			desc:   "logWrite",
   281  			method: "/trillian.TrillianLog/QueueLeaf",
   282  			req:    &trillian.QueueLeafRequest{LogId: logTree.TreeId},
   283  			specs: []quota.Spec{
   284  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   285  				{Group: quota.Global, Kind: quota.Write},
   286  			},
   287  			wantTokens: 1,
   288  		},
   289  		{
   290  			desc:   "logWrite with charges",
   291  			method: "/trillian.TrillianLog/QueueLeaf",
   292  			req:    &trillian.QueueLeafRequest{LogId: logTree.TreeId, ChargeTo: charges},
   293  			specs: []quota.Spec{
   294  				{Group: quota.User, Kind: quota.Write, User: charge1},
   295  				{Group: quota.User, Kind: quota.Write, User: charge2},
   296  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   297  				{Group: quota.Global, Kind: quota.Write},
   298  			},
   299  			wantTokens: 1,
   300  		},
   301  		{
   302  			desc:   "mapRead",
   303  			method: "/trillian.TrillianMap/GetLeaves",
   304  			req:    &trillian.GetMapLeavesRequest{MapId: mapTree.TreeId, Index: [][]byte{{0x01}, {0x02}}},
   305  			specs: []quota.Spec{
   306  				{Group: quota.Tree, Kind: quota.Read, TreeID: mapTree.TreeId},
   307  				{Group: quota.Global, Kind: quota.Read},
   308  			},
   309  			wantTokens: 2,
   310  		},
   311  		{
   312  			desc:   "emptyBatchRequest",
   313  			method: "/trillian.TrillianLog/QueueLeaves",
   314  			req: &trillian.QueueLeavesRequest{
   315  				LogId:  logTree.TreeId,
   316  				Leaves: nil,
   317  			},
   318  		},
   319  		{
   320  			desc:   "batchLogLeavesRequest",
   321  			method: "/trillian.TrillianLog/QueueLeaves",
   322  			req: &trillian.QueueLeavesRequest{
   323  				LogId:  logTree.TreeId,
   324  				Leaves: []*trillian.LogLeaf{{}, {}, {}},
   325  			},
   326  			specs: []quota.Spec{
   327  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   328  				{Group: quota.Global, Kind: quota.Write},
   329  			},
   330  			wantTokens: 3,
   331  		},
   332  		{
   333  			desc:   "batchSequencedLogLeavesRequest",
   334  			method: "/trillian.TrillianLog/AddSequencedLeaves",
   335  			req: &trillian.AddSequencedLeavesRequest{
   336  				LogId:  preorderedTree.TreeId,
   337  				Leaves: []*trillian.LogLeaf{{}, {}, {}},
   338  			},
   339  			specs: []quota.Spec{
   340  				{Group: quota.Tree, Kind: quota.Write, TreeID: preorderedTree.TreeId},
   341  				{Group: quota.Global, Kind: quota.Write},
   342  			},
   343  			wantTokens: 3,
   344  		},
   345  		{
   346  			desc:   "batchLogLeavesRequest with charges",
   347  			method: "/trillian.TrillianLog/QueueLeaves",
   348  			req: &trillian.QueueLeavesRequest{
   349  				LogId:    logTree.TreeId,
   350  				Leaves:   []*trillian.LogLeaf{{}, {}, {}},
   351  				ChargeTo: charges,
   352  			},
   353  			specs: []quota.Spec{
   354  				{Group: quota.User, Kind: quota.Write, User: charge1},
   355  				{Group: quota.User, Kind: quota.Write, User: charge2},
   356  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   357  				{Group: quota.Global, Kind: quota.Write},
   358  			},
   359  			wantTokens: 3,
   360  		},
   361  		{
   362  			desc:   "batchMapLeavesRequest",
   363  			method: "/trillian.TrillianMap/SetLeaves",
   364  			req: &trillian.SetMapLeavesRequest{
   365  				MapId:  mapTree.TreeId,
   366  				Leaves: []*trillian.MapLeaf{{}, {}, {}, {}, {}},
   367  			},
   368  			specs: []quota.Spec{
   369  				{Group: quota.Tree, Kind: quota.Write, TreeID: mapTree.TreeId},
   370  				{Group: quota.Global, Kind: quota.Write},
   371  			},
   372  			wantTokens: 5,
   373  		},
   374  		{
   375  			desc:   "quotaError",
   376  			method: "/trillian.TrillianLog/GetLatestSignedLogRoot",
   377  			req:    &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId},
   378  			specs: []quota.Spec{
   379  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   380  				{Group: quota.Global, Kind: quota.Read},
   381  			},
   382  			getTokensErr: errors.New("not enough tokens"),
   383  			wantCode:     codes.ResourceExhausted,
   384  			wantTokens:   1,
   385  		},
   386  		{
   387  			desc:   "quotaDryRunError",
   388  			dryRun: true,
   389  			method: "/trillian.TrillianLog/GetLatestSignedLogRoot",
   390  			req:    &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId},
   391  			specs: []quota.Spec{
   392  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   393  				{Group: quota.Global, Kind: quota.Read},
   394  			},
   395  			getTokensErr: errors.New("not enough tokens"),
   396  			wantTokens:   1,
   397  		},
   398  	}
   399  
   400  	ctx := context.Background()
   401  	for _, test := range tests {
   402  		t.Run(test.desc, func(t *testing.T) {
   403  			ctrl := gomock.NewController(t)
   404  			defer ctrl.Finish()
   405  			admin := storage.NewMockAdminStorage(ctrl)
   406  			adminTX := storage.NewMockReadOnlyAdminTX(ctrl)
   407  			admin.EXPECT().Snapshot(gomock.Any()).AnyTimes().Return(adminTX, nil)
   408  			adminTX.EXPECT().GetTree(gomock.Any(), logTree.TreeId).AnyTimes().Return(&logTree, nil)
   409  			adminTX.EXPECT().GetTree(gomock.Any(), mapTree.TreeId).AnyTimes().Return(&mapTree, nil)
   410  			adminTX.EXPECT().GetTree(gomock.Any(), preorderedTree.TreeId).AnyTimes().Return(&preorderedTree, nil)
   411  			adminTX.EXPECT().Close().AnyTimes().Return(nil)
   412  			adminTX.EXPECT().Commit().AnyTimes().Return(nil)
   413  
   414  			qm := quota.NewMockManager(ctrl)
   415  			if test.wantTokens > 0 {
   416  				qm.EXPECT().GetTokens(gomock.Any(), test.wantTokens, test.specs).Return(test.getTokensErr)
   417  			}
   418  
   419  			handler := &fakeHandler{resp: "ok"}
   420  			intercept := New(admin, qm, test.dryRun, nil /* mf */)
   421  
   422  			// resp and handler assertions are done by TestTrillianInterceptor_TreeInterception,
   423  			// we're only concerned with the quota logic here.
   424  			_, err := intercept.UnaryInterceptor(ctx, test.req,
   425  				&grpc.UnaryServerInfo{FullMethod: test.method},
   426  				handler.run)
   427  			if s, ok := status.FromError(err); !ok || s.Code() != test.wantCode {
   428  				t.Errorf("UnaryInterceptor() returned err = %q, wantCode = %v", err, test.wantCode)
   429  			}
   430  		})
   431  	}
   432  }
   433  
   434  func TestTrillianInterceptor_QuotaInterception_ReturnsTokens(t *testing.T) {
   435  
   436  	logTree := *testonly.LogTree
   437  	logTree.TreeId = 10
   438  
   439  	tests := []struct {
   440  		desc                         string
   441  		method                       string
   442  		req, resp                    interface{}
   443  		specs                        []quota.Spec
   444  		handlerErr                   error
   445  		wantGetTokens, wantPutTokens int
   446  	}{
   447  		{
   448  			desc:   "badRequest",
   449  			method: "/trillian.TrillianLog/GetLatestSignedLogRoot",
   450  			req:    &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId},
   451  			specs: []quota.Spec{
   452  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   453  				{Group: quota.Global, Kind: quota.Read},
   454  			},
   455  			handlerErr:    errors.New("bad request"),
   456  			wantGetTokens: 1,
   457  			wantPutTokens: 1,
   458  		},
   459  		{
   460  			desc:   "newLeaf",
   461  			method: "/trillian.TrillianLog/QueueLeaf",
   462  			req:    &trillian.QueueLeafRequest{LogId: logTree.TreeId, Leaf: &trillian.LogLeaf{}},
   463  			resp:   &trillian.QueueLeafResponse{QueuedLeaf: &trillian.QueuedLogLeaf{}},
   464  			specs: []quota.Spec{
   465  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   466  				{Group: quota.Global, Kind: quota.Write},
   467  			},
   468  			wantGetTokens: 1,
   469  		},
   470  		{
   471  			desc:   "duplicateLeaf",
   472  			method: "/trillian.TrillianLog/QueueLeaf",
   473  			req:    &trillian.QueueLeafRequest{LogId: logTree.TreeId},
   474  			resp: &trillian.QueueLeafResponse{
   475  				QueuedLeaf: &trillian.QueuedLogLeaf{
   476  					Status: status.New(codes.AlreadyExists, "duplicate leaf").Proto(),
   477  				},
   478  			},
   479  			specs: []quota.Spec{
   480  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   481  				{Group: quota.Global, Kind: quota.Write},
   482  			},
   483  			wantGetTokens: 1,
   484  			wantPutTokens: 1,
   485  		},
   486  		{
   487  			desc:   "newLeaves",
   488  			method: "/trillian.TrillianLog/QueueLeaves",
   489  			req: &trillian.QueueLeavesRequest{
   490  				LogId:  logTree.TreeId,
   491  				Leaves: []*trillian.LogLeaf{{}, {}, {}},
   492  			},
   493  			resp: &trillian.QueueLeavesResponse{
   494  				QueuedLeaves: []*trillian.QueuedLogLeaf{{}, {}, {}}, // No explicit Status means OK
   495  			},
   496  			specs: []quota.Spec{
   497  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   498  				{Group: quota.Global, Kind: quota.Write},
   499  			},
   500  			wantGetTokens: 3,
   501  		},
   502  		{
   503  			desc:   "duplicateLeaves",
   504  			method: "/trillian.TrillianLog/QueueLeaves",
   505  			req: &trillian.QueueLeavesRequest{
   506  				LogId:  logTree.TreeId,
   507  				Leaves: []*trillian.LogLeaf{{}, {}, {}},
   508  			},
   509  			resp: &trillian.QueueLeavesResponse{
   510  				QueuedLeaves: []*trillian.QueuedLogLeaf{
   511  					{Status: status.New(codes.AlreadyExists, "duplicate leaf").Proto()},
   512  					{Status: status.New(codes.AlreadyExists, "duplicate leaf").Proto()},
   513  					{},
   514  				},
   515  			},
   516  			specs: []quota.Spec{
   517  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   518  				{Group: quota.Global, Kind: quota.Write},
   519  			},
   520  			wantGetTokens: 3,
   521  			wantPutTokens: 2,
   522  		},
   523  		{
   524  			desc:   "badQueueLeavesRequest",
   525  			method: "/trillian.TrillianLog/QueueLeaves",
   526  			req: &trillian.QueueLeavesRequest{
   527  				LogId:  logTree.TreeId,
   528  				Leaves: []*trillian.LogLeaf{{}, {}, {}},
   529  			},
   530  			specs: []quota.Spec{
   531  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   532  				{Group: quota.Global, Kind: quota.Write},
   533  			},
   534  			handlerErr:    errors.New("bad request"),
   535  			wantGetTokens: 3,
   536  			wantPutTokens: 3,
   537  		},
   538  	}
   539  
   540  	defer func(timeout time.Duration) {
   541  		PutTokensTimeout = timeout
   542  	}(PutTokensTimeout)
   543  	PutTokensTimeout = 5 * time.Second
   544  
   545  	// Use a ctx with a timeout smaller than PutTokensTimeout. Not too short or
   546  	// spurious failures will occur when the deadline expires.
   547  	ctx, cancel := context.WithTimeout(context.Background(), PutTokensTimeout-2*time.Second)
   548  	defer cancel()
   549  
   550  	for _, test := range tests {
   551  		t.Run(test.desc, func(t *testing.T) {
   552  			ctrl := gomock.NewController(t)
   553  			defer ctrl.Finish()
   554  			admin := storage.NewMockAdminStorage(ctrl)
   555  			adminTX := storage.NewMockReadOnlyAdminTX(ctrl)
   556  			admin.EXPECT().Snapshot(gomock.Any()).AnyTimes().Return(adminTX, nil)
   557  			adminTX.EXPECT().GetTree(gomock.Any(), logTree.TreeId).AnyTimes().Return(&logTree, nil)
   558  			adminTX.EXPECT().Close().AnyTimes().Return(nil)
   559  			adminTX.EXPECT().Commit().AnyTimes().Return(nil)
   560  			putTokensCh := make(chan bool, 1)
   561  			wantDeadline := time.Now().Add(PutTokensTimeout)
   562  
   563  			qm := quota.NewMockManager(ctrl)
   564  			if test.wantGetTokens > 0 {
   565  				qm.EXPECT().GetTokens(gomock.Any(), test.wantGetTokens, test.specs).Return(nil)
   566  			}
   567  			if test.wantPutTokens > 0 {
   568  				qm.EXPECT().PutTokens(gomock.Any(), test.wantPutTokens, test.specs).Do(func(ctx context.Context, numTokens int, specs []quota.Spec) {
   569  					switch d, ok := ctx.Deadline(); {
   570  					case !ok:
   571  						t.Errorf("PutTokens() ctx has no deadline: %v", ctx)
   572  					case d.Before(wantDeadline):
   573  						t.Errorf("PutTokens() ctx deadline too short, got %v, want >= %v", d, wantDeadline)
   574  					}
   575  					putTokensCh <- true
   576  				}).Return(nil)
   577  			}
   578  
   579  			handler := &fakeHandler{resp: test.resp, err: test.handlerErr}
   580  			intercept := New(admin, qm, false /* quotaDryRun */, nil /* mf */)
   581  
   582  			if _, err := intercept.UnaryInterceptor(ctx, test.req,
   583  				&grpc.UnaryServerInfo{FullMethod: test.method},
   584  				handler.run); err != test.handlerErr {
   585  				t.Errorf("UnaryInterceptor() returned err = [%v], want = [%v]", err, test.handlerErr)
   586  			}
   587  
   588  			// PutTokens may be delegated to a separate goroutine. Give it some time to complete.
   589  			select {
   590  			case <-putTokensCh:
   591  				// OK
   592  			case <-time.After(1 * time.Second):
   593  				// No need to error here, gomock will fail if the call is missing.
   594  			}
   595  		})
   596  	}
   597  }
   598  
   599  func TestTrillianInterceptor_NotIntercepted(t *testing.T) {
   600  	tests := []struct {
   601  		method string
   602  		req    interface{}
   603  	}{
   604  		// Admin
   605  		{method: "/trillian.TrillianAdmin/CreateTree", req: &trillian.CreateTreeRequest{}},
   606  		{method: "/trillian.TrillianAdmin/ListTrees", req: &trillian.ListTreesRequest{}},
   607  		// Quota
   608  		{method: "/quotapb.Quota/CreateConfig", req: &quotapb.CreateConfigRequest{}},
   609  		{method: "/quotapb.Quota/DeleteConfig", req: &quotapb.DeleteConfigRequest{}},
   610  		{method: "/quotapb.Quota/GetConfig", req: &quotapb.GetConfigRequest{}},
   611  		{method: "/quotapb.Quota/ListConfigs", req: &quotapb.ListConfigsRequest{}},
   612  		{method: "/quotapb.Quota/UpdateConfig", req: &quotapb.UpdateConfigRequest{}},
   613  	}
   614  
   615  	ctx := context.Background()
   616  	for _, test := range tests {
   617  		handler := &fakeHandler{}
   618  		intercept := New(nil /* admin */, quota.Noop(), false /* quotaDryRun */, nil /* mf */)
   619  		if _, err := intercept.UnaryInterceptor(ctx, test.req,
   620  			&grpc.UnaryServerInfo{FullMethod: test.method},
   621  			handler.run); err != nil {
   622  			t.Errorf("UnaryInterceptor(%#v) returned err = %v", test.req, err)
   623  		}
   624  		if !handler.called {
   625  			t.Errorf("UnaryInterceptor(%#v): handler not called", test.req)
   626  		}
   627  	}
   628  }
   629  
   630  // TestTrillianInterceptor_BeforeAfter tests a few Before/After interactions that are
   631  // difficult/impossible to get unless the methods are called separately (i.e., not via
   632  // UnaryInterceptor()).
   633  func TestTrillianInterceptor_BeforeAfter(t *testing.T) {
   634  	logTree := *testonly.LogTree
   635  	logTree.TreeId = 10
   636  
   637  	qm := quota.Noop()
   638  
   639  	tests := []struct {
   640  		desc          string
   641  		req, resp     interface{}
   642  		handlerErr    error
   643  		wantBeforeErr bool
   644  	}{
   645  		{
   646  			desc: "success",
   647  			req:  &trillian.CreateTreeRequest{},
   648  			resp: &trillian.Tree{},
   649  		},
   650  		{
   651  			desc:          "badRequest",
   652  			req:           "bad",
   653  			resp:          nil,
   654  			handlerErr:    errors.New("bad"),
   655  			wantBeforeErr: true,
   656  		},
   657  	}
   658  
   659  	ctx := context.Background()
   660  	for _, test := range tests {
   661  		t.Run(test.desc, func(t *testing.T) {
   662  			ctrl := gomock.NewController(t)
   663  			defer ctrl.Finish()
   664  			admin := storage.NewMockAdminStorage(ctrl)
   665  			adminTX := storage.NewMockReadOnlyAdminTX(ctrl)
   666  			admin.EXPECT().Snapshot(gomock.Any()).AnyTimes().Return(adminTX, nil)
   667  			adminTX.EXPECT().GetTree(gomock.Any(), logTree.TreeId).AnyTimes().Return(&logTree, nil)
   668  			adminTX.EXPECT().Close().AnyTimes().Return(nil)
   669  			adminTX.EXPECT().Commit().AnyTimes().Return(nil)
   670  
   671  			intercept := New(admin, qm, false /* quotaDryRun */, nil /* mf */)
   672  			p := intercept.NewProcessor()
   673  
   674  			_, err := p.Before(ctx, test.req, "/trillian.TrillianLog/foo")
   675  			if gotErr := err != nil; gotErr != test.wantBeforeErr {
   676  				t.Fatalf("Before() returned err = %v, wantErr = %v", err, test.wantBeforeErr)
   677  			}
   678  
   679  			// Other TrillianInterceptor tests assert After side-effects more in-depth, silently
   680  			// returning is good enough here.
   681  			p.After(ctx, test.resp, "", test.handlerErr)
   682  		})
   683  	}
   684  }
   685  
   686  func TestCombine(t *testing.T) {
   687  	i1 := &fakeInterceptor{key: "key1", val: "foo"}
   688  	i2 := &fakeInterceptor{key: "key2", val: "bar"}
   689  	i3 := &fakeInterceptor{key: "key3", val: "baz"}
   690  	e1 := &fakeInterceptor{err: errors.New("intercept error")}
   691  
   692  	handlerErr := errors.New("handler error")
   693  
   694  	tests := []struct {
   695  		desc         string
   696  		interceptors []*fakeInterceptor
   697  		handlerErr   error
   698  		wantCalled   int
   699  		wantErr      error
   700  	}{
   701  		{
   702  			desc: "noInterceptors",
   703  		},
   704  		{
   705  			desc:         "single",
   706  			interceptors: []*fakeInterceptor{i1},
   707  			wantCalled:   1,
   708  		},
   709  		{
   710  			desc:         "multi1",
   711  			interceptors: []*fakeInterceptor{i1, i2, i3},
   712  			wantCalled:   3,
   713  		},
   714  		{
   715  			desc:         "multi2",
   716  			interceptors: []*fakeInterceptor{i3, i1, i2},
   717  			wantCalled:   3,
   718  		},
   719  		{
   720  			desc:         "handlerErr",
   721  			interceptors: []*fakeInterceptor{i1, i2},
   722  			handlerErr:   handlerErr,
   723  			wantCalled:   2,
   724  			wantErr:      handlerErr,
   725  		},
   726  		{
   727  			desc:         "interceptErr",
   728  			interceptors: []*fakeInterceptor{i1, e1, i2},
   729  			wantCalled:   2,
   730  			wantErr:      e1.err,
   731  		},
   732  	}
   733  
   734  	ctx := context.Background()
   735  	req := "request"
   736  	info := &grpc.UnaryServerInfo{}
   737  	for _, test := range tests {
   738  		t.Run(test.desc, func(t *testing.T) {
   739  			if l := len(test.interceptors); l < test.wantCalled {
   740  				t.Fatalf("len(interceptors) = %v, want >= %v", l, test.wantCalled)
   741  			}
   742  
   743  			intercepts := []grpc.UnaryServerInterceptor{}
   744  			for _, i := range test.interceptors {
   745  				i.called = false
   746  				intercepts = append(intercepts, i.run)
   747  			}
   748  			intercept := grpc_middleware.ChainUnaryServer(intercepts...)
   749  
   750  			handler := &fakeHandler{resp: "response", err: test.handlerErr}
   751  			resp, err := intercept(ctx, req, info, handler.run)
   752  			if err != test.wantErr {
   753  				t.Fatalf("err = %q, want = %q", err, test.wantErr)
   754  			}
   755  
   756  			called := 0
   757  			callsStopped := false
   758  			for _, i := range test.interceptors {
   759  				switch {
   760  				case i.called:
   761  					if callsStopped {
   762  						t.Errorf("interceptor called out of order: %v", i)
   763  					}
   764  					called++
   765  				case !i.called:
   766  					// No calls should have happened from here on
   767  					callsStopped = true
   768  				}
   769  			}
   770  			if called != test.wantCalled {
   771  				t.Errorf("called %v interceptors, want = %v", called, test.wantCalled)
   772  			}
   773  
   774  			// Assertions below this point assume that the handler was called (ie, all
   775  			// interceptors succeeded).
   776  			if err != nil && err != test.handlerErr {
   777  				return
   778  			}
   779  
   780  			if resp != handler.resp {
   781  				t.Errorf("resp = %v, want = %v", resp, handler.resp)
   782  			}
   783  
   784  			// Chain the ctxs for all called interceptors and verify it got through to the
   785  			// handler.
   786  			wantCtx := ctx
   787  			for _, i := range test.interceptors {
   788  				h := &fakeHandler{resp: "ok"}
   789  				i.called = false
   790  				_, err = i.run(wantCtx, req, info, h.run)
   791  				if err != nil {
   792  					t.Fatalf("unexpected handler failure: %v", err)
   793  				}
   794  				wantCtx = h.ctx
   795  			}
   796  			if diff := pretty.Compare(handler.ctx, wantCtx); diff != "" {
   797  				t.Errorf("handler ctx diff:\n%v", diff)
   798  			}
   799  		})
   800  	}
   801  }
   802  
   803  func TestErrorWrapper(t *testing.T) {
   804  	badLlamaErr := status.Errorf(codes.InvalidArgument, "Bad Llama")
   805  	tests := []struct {
   806  		desc         string
   807  		resp         interface{}
   808  		err, wantErr error
   809  	}{
   810  		{
   811  			desc: "success",
   812  			resp: "ok",
   813  		},
   814  		{
   815  			desc:    "error",
   816  			err:     badLlamaErr,
   817  			wantErr: serrors.WrapError(badLlamaErr),
   818  		},
   819  	}
   820  	ctx := context.Background()
   821  	for _, test := range tests {
   822  		t.Run(test.desc, func(t *testing.T) {
   823  			handler := fakeHandler{resp: test.resp, err: test.err}
   824  			resp, err := ErrorWrapper(ctx, "req", &grpc.UnaryServerInfo{}, handler.run)
   825  			if resp != test.resp {
   826  				t.Errorf("resp = %v, want = %v", resp, test.resp)
   827  			}
   828  			if diff := pretty.Compare(err, test.wantErr); diff != "" {
   829  				t.Errorf("post-WrapErrors diff:\n%v", diff)
   830  			}
   831  		})
   832  	}
   833  }
   834  
   835  type fakeHandler struct {
   836  	called bool
   837  	resp   interface{}
   838  	err    error
   839  	// Attributes recorded by run calls
   840  	ctx context.Context
   841  	req interface{}
   842  }
   843  
   844  func (f *fakeHandler) run(ctx context.Context, req interface{}) (interface{}, error) {
   845  	if f.called {
   846  		panic("handler already called; either create a new handler or set called to false before reusing")
   847  	}
   848  	f.called = true
   849  	f.ctx = ctx
   850  	f.req = req
   851  	return f.resp, f.err
   852  }
   853  
   854  type fakeInterceptor struct {
   855  	key    interface{}
   856  	val    interface{}
   857  	called bool
   858  	err    error
   859  }
   860  
   861  func (f *fakeInterceptor) run(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
   862  	if f.called {
   863  		panic("interceptor already called; either create a new interceptor or set called to false before reusing")
   864  	}
   865  	f.called = true
   866  	if f.err != nil {
   867  		return nil, f.err
   868  	}
   869  	return handler(context.WithValue(ctx, f.key, f.val), req)
   870  }