github.com/letsencrypt/trillian@v1.1.2-0.20180615153820-ae375a99d36a/server/interceptor/interceptor_test.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package interceptor
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/golang/mock/gomock"
    24  	"github.com/golang/protobuf/proto"
    25  	"github.com/golang/protobuf/ptypes"
    26  	"github.com/google/trillian"
    27  	"github.com/google/trillian/quota"
    28  	"github.com/google/trillian/quota/etcd/quotapb"
    29  	"github.com/google/trillian/storage"
    30  	"github.com/google/trillian/storage/testonly"
    31  	"github.com/google/trillian/trees"
    32  	"github.com/kylelemons/godebug/pretty"
    33  	"google.golang.org/grpc"
    34  	"google.golang.org/grpc/codes"
    35  	"google.golang.org/grpc/status"
    36  
    37  	serrors "github.com/google/trillian/server/errors"
    38  )
    39  
    40  func TestTrillianInterceptor_TreeInterception(t *testing.T) {
    41  	logTree := proto.Clone(testonly.LogTree).(*trillian.Tree)
    42  	logTree.TreeId = 10
    43  	mapTree := proto.Clone(testonly.MapTree).(*trillian.Tree)
    44  	mapTree.TreeId = 11
    45  	deletedTree := proto.Clone(testonly.LogTree).(*trillian.Tree)
    46  	deletedTree.TreeId = 12
    47  	deletedTree.Deleted = true
    48  	deletedTree.DeleteTime = ptypes.TimestampNow()
    49  	unknownTreeID := int64(999)
    50  
    51  	tests := []struct {
    52  		desc       string
    53  		req        interface{}
    54  		handlerErr error
    55  		wantErr    bool
    56  		wantTree   *trillian.Tree
    57  		cancelled  bool
    58  	}{
    59  		// TODO(codingllama): Admin requests don't benefit from tree-reading logic, but we may read
    60  		// their tree IDs for auth purposes.
    61  		{
    62  			desc: "adminReadByID",
    63  			req:  &trillian.GetTreeRequest{TreeId: logTree.TreeId},
    64  		},
    65  		{
    66  			desc: "adminWriteByID",
    67  			req:  &trillian.DeleteTreeRequest{TreeId: logTree.TreeId},
    68  		},
    69  		{
    70  			desc: "adminWriteByTree",
    71  			req:  &trillian.UpdateTreeRequest{Tree: &trillian.Tree{TreeId: logTree.TreeId}},
    72  		},
    73  		{
    74  			desc:     "logRPC",
    75  			req:      &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId},
    76  			wantTree: logTree,
    77  		},
    78  		{
    79  			desc:     "mapRPC",
    80  			req:      &trillian.GetSignedMapRootRequest{MapId: mapTree.TreeId},
    81  			wantTree: mapTree,
    82  		},
    83  		{
    84  			desc:    "unknownRequest",
    85  			req:     "not-a-request",
    86  			wantErr: true,
    87  		},
    88  		{
    89  			desc:    "unknownTree",
    90  			req:     &trillian.GetLatestSignedLogRootRequest{LogId: unknownTreeID},
    91  			wantErr: true,
    92  		},
    93  		{
    94  			desc:    "deletedTree",
    95  			req:     &trillian.GetLatestSignedLogRootRequest{LogId: deletedTree.TreeId},
    96  			wantErr: true,
    97  		},
    98  		{
    99  			desc:      "cancelled",
   100  			req:       &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId},
   101  			cancelled: true,
   102  			wantErr:   true,
   103  		},
   104  	}
   105  
   106  	ctx := context.Background()
   107  	for _, test := range tests {
   108  		t.Run(test.desc, func(t *testing.T) {
   109  			ctrl := gomock.NewController(t)
   110  			defer ctrl.Finish()
   111  			admin := storage.NewMockAdminStorage(ctrl)
   112  			adminTX := storage.NewMockReadOnlyAdminTX(ctrl)
   113  			admin.EXPECT().Snapshot(gomock.Any()).AnyTimes().Return(adminTX, nil)
   114  			adminTX.EXPECT().GetTree(gomock.Any(), logTree.TreeId).AnyTimes().Return(logTree, nil)
   115  			adminTX.EXPECT().GetTree(gomock.Any(), mapTree.TreeId).AnyTimes().Return(mapTree, nil)
   116  			adminTX.EXPECT().GetTree(gomock.Any(), deletedTree.TreeId).AnyTimes().Return(deletedTree, nil)
   117  			adminTX.EXPECT().GetTree(gomock.Any(), unknownTreeID).AnyTimes().Return(nil, errors.New("not found"))
   118  			adminTX.EXPECT().Close().AnyTimes().Return(nil)
   119  			adminTX.EXPECT().Commit().AnyTimes().Return(nil)
   120  
   121  			intercept := New(admin, quota.Noop(), false /* quotaDryRun */, nil /* mf */)
   122  			handler := &fakeHandler{resp: "handler response", err: test.handlerErr}
   123  
   124  			if test.cancelled {
   125  				// Use a context that's already been cancelled
   126  				newCtx, cancel := context.WithCancel(ctx)
   127  				cancel()
   128  				ctx = newCtx
   129  			}
   130  
   131  			resp, err := intercept.UnaryInterceptor(ctx, test.req, &grpc.UnaryServerInfo{}, handler.run)
   132  			if hasErr := err != nil && err != test.handlerErr; hasErr != test.wantErr {
   133  				t.Fatalf("UnaryInterceptor() returned err = %v, wantErr = %v", err, test.wantErr)
   134  			} else if hasErr {
   135  				return
   136  			}
   137  
   138  			if !handler.called {
   139  				t.Fatal("handler not called")
   140  			}
   141  			if handler.resp != resp {
   142  				t.Errorf("resp = %v, want = %v", resp, handler.resp)
   143  			}
   144  			if handler.err != err {
   145  				t.Errorf("err = %v, want = %v", err, handler.err)
   146  			}
   147  
   148  			if test.wantTree != nil {
   149  				switch tree, ok := trees.FromContext(handler.ctx); {
   150  				case !ok:
   151  					t.Error("tree not in handler ctx")
   152  				case !proto.Equal(tree, test.wantTree):
   153  					diff := pretty.Compare(tree, test.wantTree)
   154  					t.Errorf("post-FromContext diff:\n%v", diff)
   155  				}
   156  			}
   157  		})
   158  	}
   159  }
   160  
   161  func TestTrillianInterceptor_QuotaInterception(t *testing.T) {
   162  
   163  	logTree := *testonly.LogTree
   164  	logTree.TreeId = 10
   165  
   166  	mapTree := *testonly.MapTree
   167  	mapTree.TreeId = 11
   168  
   169  	preorderedTree := *testonly.PreorderedLogTree
   170  	preorderedTree.TreeId = 12
   171  
   172  	user := "llama"
   173  	charge1 := "alpaca"
   174  	charge2 := "cama"
   175  	charges := &trillian.ChargeTo{User: []string{charge1, charge2}}
   176  	tests := []struct {
   177  		desc         string
   178  		dryRun       bool
   179  		req          interface{}
   180  		specs        []quota.Spec
   181  		getTokensErr error
   182  		wantCode     codes.Code
   183  		wantTokens   int
   184  	}{
   185  		{
   186  			desc: "logRead",
   187  			req:  &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId},
   188  			specs: []quota.Spec{
   189  				{Group: quota.User, Kind: quota.Read, User: user},
   190  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   191  				{Group: quota.Global, Kind: quota.Read},
   192  			},
   193  			wantTokens: 1,
   194  		},
   195  		{
   196  			desc: "logReadIndices",
   197  			req:  &trillian.GetLeavesByIndexRequest{LogId: logTree.TreeId, LeafIndex: []int64{1, 2, 3}},
   198  			specs: []quota.Spec{
   199  				{Group: quota.User, Kind: quota.Read, User: user},
   200  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   201  				{Group: quota.Global, Kind: quota.Read},
   202  			},
   203  			wantTokens: 3,
   204  		},
   205  		{
   206  			desc: "logReadRange",
   207  			req:  &trillian.GetLeavesByRangeRequest{LogId: logTree.TreeId, Count: 123},
   208  			specs: []quota.Spec{
   209  				{Group: quota.User, Kind: quota.Read, User: user},
   210  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   211  				{Group: quota.Global, Kind: quota.Read},
   212  			},
   213  			wantTokens: 123,
   214  		},
   215  		{
   216  			desc: "logReadNegativeRange",
   217  			req:  &trillian.GetLeavesByRangeRequest{LogId: logTree.TreeId, Count: -123},
   218  			specs: []quota.Spec{
   219  				{Group: quota.User, Kind: quota.Read, User: user},
   220  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   221  				{Group: quota.Global, Kind: quota.Read},
   222  			},
   223  			wantTokens: 1,
   224  		},
   225  		{
   226  			desc: "logReadZeroRange",
   227  			req:  &trillian.GetLeavesByRangeRequest{LogId: logTree.TreeId, Count: 0},
   228  			specs: []quota.Spec{
   229  				{Group: quota.User, Kind: quota.Read, User: user},
   230  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   231  				{Group: quota.Global, Kind: quota.Read},
   232  			},
   233  			wantTokens: 1,
   234  		},
   235  		{
   236  			desc: "logRead with charges",
   237  			req:  &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId, ChargeTo: charges},
   238  			specs: []quota.Spec{
   239  				{Group: quota.User, Kind: quota.Read, User: charge1},
   240  				{Group: quota.User, Kind: quota.Read, User: charge2},
   241  				{Group: quota.User, Kind: quota.Read, User: user},
   242  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   243  				{Group: quota.Global, Kind: quota.Read},
   244  			},
   245  			wantTokens: 1,
   246  		},
   247  		{
   248  			desc: "logWrite",
   249  			req:  &trillian.QueueLeafRequest{LogId: logTree.TreeId},
   250  			specs: []quota.Spec{
   251  				{Group: quota.User, Kind: quota.Write, User: user},
   252  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   253  				{Group: quota.Global, Kind: quota.Write},
   254  			},
   255  			wantTokens: 1,
   256  		},
   257  		{
   258  			desc: "logWrite with charges",
   259  			req:  &trillian.QueueLeafRequest{LogId: logTree.TreeId, ChargeTo: charges},
   260  			specs: []quota.Spec{
   261  				{Group: quota.User, Kind: quota.Write, User: charge1},
   262  				{Group: quota.User, Kind: quota.Write, User: charge2},
   263  				{Group: quota.User, Kind: quota.Write, User: user},
   264  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   265  				{Group: quota.Global, Kind: quota.Write},
   266  			},
   267  			wantTokens: 1,
   268  		},
   269  		{
   270  			desc: "mapRead",
   271  			req:  &trillian.GetMapLeavesRequest{MapId: mapTree.TreeId, Index: [][]byte{{0x01}, {0x02}}},
   272  			specs: []quota.Spec{
   273  				{Group: quota.User, Kind: quota.Read, User: user},
   274  				{Group: quota.Tree, Kind: quota.Read, TreeID: mapTree.TreeId},
   275  				{Group: quota.Global, Kind: quota.Read},
   276  			},
   277  			wantTokens: 2,
   278  		},
   279  		{
   280  			desc: "emptyBatchRequest",
   281  			req: &trillian.QueueLeavesRequest{
   282  				LogId:  logTree.TreeId,
   283  				Leaves: nil,
   284  			},
   285  		},
   286  		{
   287  			desc: "batchLogLeavesRequest",
   288  			req: &trillian.QueueLeavesRequest{
   289  				LogId:  logTree.TreeId,
   290  				Leaves: []*trillian.LogLeaf{{}, {}, {}},
   291  			},
   292  			specs: []quota.Spec{
   293  				{Group: quota.User, Kind: quota.Write, User: user},
   294  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   295  				{Group: quota.Global, Kind: quota.Write},
   296  			},
   297  			wantTokens: 3,
   298  		},
   299  		{
   300  			desc: "batchSequencedLogLeavesRequest",
   301  			req: &trillian.AddSequencedLeavesRequest{
   302  				LogId:  preorderedTree.TreeId,
   303  				Leaves: []*trillian.LogLeaf{{}, {}, {}},
   304  			},
   305  			specs: []quota.Spec{
   306  				{Group: quota.User, Kind: quota.Write, User: user},
   307  				{Group: quota.Tree, Kind: quota.Write, TreeID: preorderedTree.TreeId},
   308  				{Group: quota.Global, Kind: quota.Write},
   309  			},
   310  			wantTokens: 3,
   311  		},
   312  		{
   313  			desc: "batchLogLeavesRequest with charges",
   314  			req: &trillian.QueueLeavesRequest{
   315  				LogId:    logTree.TreeId,
   316  				Leaves:   []*trillian.LogLeaf{{}, {}, {}},
   317  				ChargeTo: charges,
   318  			},
   319  			specs: []quota.Spec{
   320  				{Group: quota.User, Kind: quota.Write, User: charge1},
   321  				{Group: quota.User, Kind: quota.Write, User: charge2},
   322  				{Group: quota.User, Kind: quota.Write, User: user},
   323  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   324  				{Group: quota.Global, Kind: quota.Write},
   325  			},
   326  			wantTokens: 3,
   327  		},
   328  		{
   329  			desc: "batchMapLeavesRequest",
   330  			req: &trillian.SetMapLeavesRequest{
   331  				MapId:  mapTree.TreeId,
   332  				Leaves: []*trillian.MapLeaf{{}, {}, {}, {}, {}},
   333  			},
   334  			specs: []quota.Spec{
   335  				{Group: quota.User, Kind: quota.Write, User: user},
   336  				{Group: quota.Tree, Kind: quota.Write, TreeID: mapTree.TreeId},
   337  				{Group: quota.Global, Kind: quota.Write},
   338  			},
   339  			wantTokens: 5,
   340  		},
   341  		{
   342  			desc: "quotaError",
   343  			req:  &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId},
   344  			specs: []quota.Spec{
   345  				{Group: quota.User, Kind: quota.Read, User: user},
   346  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   347  				{Group: quota.Global, Kind: quota.Read},
   348  			},
   349  			getTokensErr: errors.New("not enough tokens"),
   350  			wantCode:     codes.ResourceExhausted,
   351  			wantTokens:   1,
   352  		},
   353  		{
   354  			desc:   "quotaDryRunError",
   355  			dryRun: true,
   356  			req:    &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId},
   357  			specs: []quota.Spec{
   358  				{Group: quota.User, Kind: quota.Read, User: user},
   359  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   360  				{Group: quota.Global, Kind: quota.Read},
   361  			},
   362  			getTokensErr: errors.New("not enough tokens"),
   363  			wantTokens:   1,
   364  		},
   365  	}
   366  
   367  	ctx := context.Background()
   368  	for _, test := range tests {
   369  		t.Run(test.desc, func(t *testing.T) {
   370  			ctrl := gomock.NewController(t)
   371  			defer ctrl.Finish()
   372  			admin := storage.NewMockAdminStorage(ctrl)
   373  			adminTX := storage.NewMockReadOnlyAdminTX(ctrl)
   374  			admin.EXPECT().Snapshot(gomock.Any()).AnyTimes().Return(adminTX, nil)
   375  			adminTX.EXPECT().GetTree(gomock.Any(), logTree.TreeId).AnyTimes().Return(&logTree, nil)
   376  			adminTX.EXPECT().GetTree(gomock.Any(), mapTree.TreeId).AnyTimes().Return(&mapTree, nil)
   377  			adminTX.EXPECT().GetTree(gomock.Any(), preorderedTree.TreeId).AnyTimes().Return(&preorderedTree, nil)
   378  			adminTX.EXPECT().Close().AnyTimes().Return(nil)
   379  			adminTX.EXPECT().Commit().AnyTimes().Return(nil)
   380  
   381  			qm := quota.NewMockManager(ctrl)
   382  			qm.EXPECT().GetUser(gomock.Any(), test.req).MaxTimes(1).Return(user)
   383  			if test.wantTokens > 0 {
   384  				qm.EXPECT().GetTokens(gomock.Any(), test.wantTokens, test.specs).Return(test.getTokensErr)
   385  			}
   386  
   387  			handler := &fakeHandler{resp: "ok"}
   388  			intercept := New(admin, qm, test.dryRun, nil /* mf */)
   389  
   390  			// resp and handler assertions are done by TestTrillianInterceptor_TreeInterception,
   391  			// we're only concerned with the quota logic here.
   392  			_, err := intercept.UnaryInterceptor(ctx, test.req, &grpc.UnaryServerInfo{}, handler.run)
   393  			if s, ok := status.FromError(err); !ok || s.Code() != test.wantCode {
   394  				t.Errorf("UnaryInterceptor() returned err = %q, wantCode = %v", err, test.wantCode)
   395  			}
   396  		})
   397  	}
   398  }
   399  
   400  func TestTrillianInterceptor_QuotaInterception_ReturnsTokens(t *testing.T) {
   401  
   402  	logTree := *testonly.LogTree
   403  	logTree.TreeId = 10
   404  
   405  	user := "llama"
   406  	tests := []struct {
   407  		desc                         string
   408  		req, resp                    interface{}
   409  		specs                        []quota.Spec
   410  		handlerErr                   error
   411  		wantGetTokens, wantPutTokens int
   412  	}{
   413  		{
   414  			desc: "badRequest",
   415  			req:  &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId},
   416  			specs: []quota.Spec{
   417  				{Group: quota.User, Kind: quota.Read, User: user},
   418  				{Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId},
   419  				{Group: quota.Global, Kind: quota.Read},
   420  			},
   421  			handlerErr:    errors.New("bad request"),
   422  			wantGetTokens: 1,
   423  			wantPutTokens: 1,
   424  		},
   425  		{
   426  			desc: "newLeaf",
   427  			req:  &trillian.QueueLeafRequest{LogId: logTree.TreeId, Leaf: &trillian.LogLeaf{}},
   428  			resp: &trillian.QueueLeafResponse{QueuedLeaf: &trillian.QueuedLogLeaf{}},
   429  			specs: []quota.Spec{
   430  				{Group: quota.User, Kind: quota.Write, User: user},
   431  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   432  				{Group: quota.Global, Kind: quota.Write},
   433  			},
   434  			wantGetTokens: 1,
   435  		},
   436  		{
   437  			desc: "duplicateLeaf",
   438  			req:  &trillian.QueueLeafRequest{LogId: logTree.TreeId},
   439  			resp: &trillian.QueueLeafResponse{
   440  				QueuedLeaf: &trillian.QueuedLogLeaf{
   441  					Status: status.New(codes.AlreadyExists, "duplicate leaf").Proto(),
   442  				},
   443  			},
   444  			specs: []quota.Spec{
   445  				{Group: quota.User, Kind: quota.Write, User: user},
   446  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   447  				{Group: quota.Global, Kind: quota.Write},
   448  			},
   449  			wantGetTokens: 1,
   450  			wantPutTokens: 1,
   451  		},
   452  		{
   453  			desc: "newLeaves",
   454  			req: &trillian.QueueLeavesRequest{
   455  				LogId:  logTree.TreeId,
   456  				Leaves: []*trillian.LogLeaf{{}, {}, {}},
   457  			},
   458  			resp: &trillian.QueueLeavesResponse{
   459  				QueuedLeaves: []*trillian.QueuedLogLeaf{{}, {}, {}}, // No explicit Status means OK
   460  			},
   461  			specs: []quota.Spec{
   462  				{Group: quota.User, Kind: quota.Write, User: user},
   463  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   464  				{Group: quota.Global, Kind: quota.Write},
   465  			},
   466  			wantGetTokens: 3,
   467  		},
   468  		{
   469  			desc: "duplicateLeaves",
   470  			req: &trillian.QueueLeavesRequest{
   471  				LogId:  logTree.TreeId,
   472  				Leaves: []*trillian.LogLeaf{{}, {}, {}},
   473  			},
   474  			resp: &trillian.QueueLeavesResponse{
   475  				QueuedLeaves: []*trillian.QueuedLogLeaf{
   476  					{Status: status.New(codes.AlreadyExists, "duplicate leaf").Proto()},
   477  					{Status: status.New(codes.AlreadyExists, "duplicate leaf").Proto()},
   478  					{},
   479  				},
   480  			},
   481  			specs: []quota.Spec{
   482  				{Group: quota.User, Kind: quota.Write, User: user},
   483  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   484  				{Group: quota.Global, Kind: quota.Write},
   485  			},
   486  			wantGetTokens: 3,
   487  			wantPutTokens: 2,
   488  		},
   489  		{
   490  			desc: "badQueueLeavesRequest",
   491  			req: &trillian.QueueLeavesRequest{
   492  				LogId:  logTree.TreeId,
   493  				Leaves: []*trillian.LogLeaf{{}, {}, {}},
   494  			},
   495  			specs: []quota.Spec{
   496  				{Group: quota.User, Kind: quota.Write, User: user},
   497  				{Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId},
   498  				{Group: quota.Global, Kind: quota.Write},
   499  			},
   500  			handlerErr:    errors.New("bad request"),
   501  			wantGetTokens: 3,
   502  			wantPutTokens: 3,
   503  		},
   504  	}
   505  
   506  	defer func(timeout time.Duration) {
   507  		PutTokensTimeout = timeout
   508  	}(PutTokensTimeout)
   509  	PutTokensTimeout = 5 * time.Second
   510  
   511  	// Use a ctx with a timeout smaller than PutTokensTimeout. Not too short or
   512  	// spurious failures will occur when the deadline expires.
   513  	ctx, cancel := context.WithTimeout(context.Background(), PutTokensTimeout-2*time.Second)
   514  	defer cancel()
   515  
   516  	for _, test := range tests {
   517  		t.Run(test.desc, func(t *testing.T) {
   518  			ctrl := gomock.NewController(t)
   519  			defer ctrl.Finish()
   520  			admin := storage.NewMockAdminStorage(ctrl)
   521  			adminTX := storage.NewMockReadOnlyAdminTX(ctrl)
   522  			admin.EXPECT().Snapshot(gomock.Any()).AnyTimes().Return(adminTX, nil)
   523  			adminTX.EXPECT().GetTree(gomock.Any(), logTree.TreeId).AnyTimes().Return(&logTree, nil)
   524  			adminTX.EXPECT().Close().AnyTimes().Return(nil)
   525  			adminTX.EXPECT().Commit().AnyTimes().Return(nil)
   526  			putTokensCh := make(chan bool, 1)
   527  			wantDeadline := time.Now().Add(PutTokensTimeout)
   528  
   529  			qm := quota.NewMockManager(ctrl)
   530  			qm.EXPECT().GetUser(gomock.Any(), test.req).MaxTimes(1).Return(user)
   531  			if test.wantGetTokens > 0 {
   532  				qm.EXPECT().GetTokens(gomock.Any(), test.wantGetTokens, test.specs).Return(nil)
   533  			}
   534  			if test.wantPutTokens > 0 {
   535  				qm.EXPECT().PutTokens(gomock.Any(), test.wantPutTokens, test.specs).Do(func(ctx context.Context, numTokens int, specs []quota.Spec) {
   536  					switch d, ok := ctx.Deadline(); {
   537  					case !ok:
   538  						t.Errorf("PutTokens() ctx has no deadline: %v", ctx)
   539  					case d.Before(wantDeadline):
   540  						t.Errorf("PutTokens() ctx deadline too short, got %v, want >= %v", d, wantDeadline)
   541  					}
   542  					putTokensCh <- true
   543  				}).Return(nil)
   544  			}
   545  
   546  			handler := &fakeHandler{resp: test.resp, err: test.handlerErr}
   547  			intercept := New(admin, qm, false /* quotaDryRun */, nil /* mf */)
   548  
   549  			if _, err := intercept.UnaryInterceptor(ctx, test.req, &grpc.UnaryServerInfo{}, handler.run); err != test.handlerErr {
   550  				t.Errorf("UnaryInterceptor() returned err = [%v], want = [%v]", err, test.handlerErr)
   551  			}
   552  
   553  			// PutTokens may be delegated to a separate goroutine. Give it some time to complete.
   554  			select {
   555  			case <-putTokensCh:
   556  				// OK
   557  			case <-time.After(1 * time.Second):
   558  				// No need to error here, gomock will fail if the call is missing.
   559  			}
   560  		})
   561  	}
   562  }
   563  
   564  func TestTrillianInterceptor_NotIntercepted(t *testing.T) {
   565  	tests := []struct {
   566  		req interface{}
   567  	}{
   568  		// Admin
   569  		{req: &trillian.CreateTreeRequest{}},
   570  		{req: &trillian.ListTreesRequest{}},
   571  		// Quota
   572  		{req: &quotapb.CreateConfigRequest{}},
   573  		{req: &quotapb.DeleteConfigRequest{}},
   574  		{req: &quotapb.GetConfigRequest{}},
   575  		{req: &quotapb.ListConfigsRequest{}},
   576  		{req: &quotapb.UpdateConfigRequest{}},
   577  	}
   578  
   579  	ctx := context.Background()
   580  	for _, test := range tests {
   581  		handler := &fakeHandler{}
   582  		intercept := New(nil /* admin */, quota.Noop(), false /* quotaDryRun */, nil /* mf */)
   583  		if _, err := intercept.UnaryInterceptor(ctx, test.req, &grpc.UnaryServerInfo{}, handler.run); err != nil {
   584  			t.Errorf("UnaryInterceptor(%#v) returned err = %v", test.req, err)
   585  		}
   586  		if !handler.called {
   587  			t.Errorf("UnaryInterceptor(%#v): handler not called", test.req)
   588  		}
   589  	}
   590  }
   591  
   592  // TestTrillianInterceptor_BeforeAfter tests a few Before/After interactions that are
   593  // difficult/impossible to get unless the methods are called separately (i.e., not via
   594  // UnaryInterceptor()).
   595  func TestTrillianInterceptor_BeforeAfter(t *testing.T) {
   596  	logTree := *testonly.LogTree
   597  	logTree.TreeId = 10
   598  
   599  	qm := quota.Noop()
   600  
   601  	tests := []struct {
   602  		desc          string
   603  		req, resp     interface{}
   604  		handlerErr    error
   605  		wantBeforeErr bool
   606  	}{
   607  		{
   608  			desc: "success",
   609  			req:  &trillian.CreateTreeRequest{},
   610  			resp: &trillian.Tree{},
   611  		},
   612  		{
   613  			desc:          "badRequest",
   614  			req:           "bad",
   615  			resp:          nil,
   616  			handlerErr:    errors.New("bad"),
   617  			wantBeforeErr: true,
   618  		},
   619  	}
   620  
   621  	ctx := context.Background()
   622  	for _, test := range tests {
   623  		t.Run(test.desc, func(t *testing.T) {
   624  			ctrl := gomock.NewController(t)
   625  			defer ctrl.Finish()
   626  			admin := storage.NewMockAdminStorage(ctrl)
   627  			adminTX := storage.NewMockReadOnlyAdminTX(ctrl)
   628  			admin.EXPECT().Snapshot(gomock.Any()).AnyTimes().Return(adminTX, nil)
   629  			adminTX.EXPECT().GetTree(gomock.Any(), logTree.TreeId).AnyTimes().Return(&logTree, nil)
   630  			adminTX.EXPECT().Close().AnyTimes().Return(nil)
   631  			adminTX.EXPECT().Commit().AnyTimes().Return(nil)
   632  
   633  			intercept := New(admin, qm, false /* quotaDryRun */, nil /* mf */)
   634  			p := intercept.NewProcessor()
   635  
   636  			_, err := p.Before(ctx, test.req)
   637  			if gotErr := err != nil; gotErr != test.wantBeforeErr {
   638  				t.Fatalf("Before() returned err = %v, wantErr = %v", err, test.wantBeforeErr)
   639  			}
   640  
   641  			// Other TrillianInterceptor tests assert After side-effects more in-depth, silently
   642  			// returning is good enough here.
   643  			p.After(ctx, test.resp, test.handlerErr)
   644  		})
   645  	}
   646  }
   647  
   648  func TestCombine(t *testing.T) {
   649  	i1 := &fakeInterceptor{key: "key1", val: "foo"}
   650  	i2 := &fakeInterceptor{key: "key2", val: "bar"}
   651  	i3 := &fakeInterceptor{key: "key3", val: "baz"}
   652  	e1 := &fakeInterceptor{err: errors.New("intercept error")}
   653  
   654  	handlerErr := errors.New("handler error")
   655  
   656  	tests := []struct {
   657  		desc         string
   658  		interceptors []*fakeInterceptor
   659  		handlerErr   error
   660  		wantCalled   int
   661  		wantErr      error
   662  	}{
   663  		{
   664  			desc: "noInterceptors",
   665  		},
   666  		{
   667  			desc:         "single",
   668  			interceptors: []*fakeInterceptor{i1},
   669  			wantCalled:   1,
   670  		},
   671  		{
   672  			desc:         "multi1",
   673  			interceptors: []*fakeInterceptor{i1, i2, i3},
   674  			wantCalled:   3,
   675  		},
   676  		{
   677  			desc:         "multi2",
   678  			interceptors: []*fakeInterceptor{i3, i1, i2},
   679  			wantCalled:   3,
   680  		},
   681  		{
   682  			desc:         "handlerErr",
   683  			interceptors: []*fakeInterceptor{i1, i2},
   684  			handlerErr:   handlerErr,
   685  			wantCalled:   2,
   686  			wantErr:      handlerErr,
   687  		},
   688  		{
   689  			desc:         "interceptErr",
   690  			interceptors: []*fakeInterceptor{i1, e1, i2},
   691  			wantCalled:   2,
   692  			wantErr:      e1.err,
   693  		},
   694  	}
   695  
   696  	ctx := context.Background()
   697  	req := "request"
   698  	info := &grpc.UnaryServerInfo{}
   699  	for _, test := range tests {
   700  		t.Run(test.desc, func(t *testing.T) {
   701  			if l := len(test.interceptors); l < test.wantCalled {
   702  				t.Fatalf("len(interceptors) = %v, want >= %v", l, test.wantCalled)
   703  			}
   704  
   705  			intercepts := []grpc.UnaryServerInterceptor{}
   706  			for _, i := range test.interceptors {
   707  				i.called = false
   708  				intercepts = append(intercepts, i.run)
   709  			}
   710  			intercept := Combine(intercepts...)
   711  
   712  			handler := &fakeHandler{resp: "response", err: test.handlerErr}
   713  			resp, err := intercept(ctx, req, info, handler.run)
   714  			if err != test.wantErr {
   715  				t.Fatalf("err = %q, want = %q", err, test.wantErr)
   716  			}
   717  
   718  			called := 0
   719  			callsStopped := false
   720  			for _, i := range test.interceptors {
   721  				switch {
   722  				case i.called:
   723  					if callsStopped {
   724  						t.Errorf("interceptor called out of order: %v", i)
   725  					}
   726  					called++
   727  				case !i.called:
   728  					// No calls should have happened from here on
   729  					callsStopped = true
   730  				}
   731  			}
   732  			if called != test.wantCalled {
   733  				t.Errorf("called %v interceptors, want = %v", called, test.wantCalled)
   734  			}
   735  
   736  			// Assertions below this point assume that the handler was called (ie, all
   737  			// interceptors succeeded).
   738  			if err != nil && err != test.handlerErr {
   739  				return
   740  			}
   741  
   742  			if resp != handler.resp {
   743  				t.Errorf("resp = %v, want = %v", resp, handler.resp)
   744  			}
   745  
   746  			// Chain the ctxs for all called interceptors and verify it got through to the
   747  			// handler.
   748  			wantCtx := ctx
   749  			for _, i := range test.interceptors {
   750  				h := &fakeHandler{resp: "ok"}
   751  				i.called = false
   752  				_, err = i.run(wantCtx, req, info, h.run)
   753  				if err != nil {
   754  					t.Fatalf("unexpected handler failure: %v", err)
   755  				}
   756  				wantCtx = h.ctx
   757  			}
   758  			if diff := pretty.Compare(handler.ctx, wantCtx); diff != "" {
   759  				t.Errorf("handler ctx diff:\n%v", diff)
   760  			}
   761  		})
   762  	}
   763  }
   764  
   765  func TestErrorWrapper(t *testing.T) {
   766  	badLlamaErr := status.Errorf(codes.InvalidArgument, "Bad Llama")
   767  	tests := []struct {
   768  		desc         string
   769  		resp         interface{}
   770  		err, wantErr error
   771  	}{
   772  		{
   773  			desc: "success",
   774  			resp: "ok",
   775  		},
   776  		{
   777  			desc:    "error",
   778  			err:     badLlamaErr,
   779  			wantErr: serrors.WrapError(badLlamaErr),
   780  		},
   781  	}
   782  	ctx := context.Background()
   783  	for _, test := range tests {
   784  		t.Run(test.desc, func(t *testing.T) {
   785  			handler := fakeHandler{resp: test.resp, err: test.err}
   786  			resp, err := ErrorWrapper(ctx, "req", &grpc.UnaryServerInfo{}, handler.run)
   787  			if resp != test.resp {
   788  				t.Errorf("resp = %v, want = %v", resp, test.resp)
   789  			}
   790  			if diff := pretty.Compare(err, test.wantErr); diff != "" {
   791  				t.Errorf("post-WrapErrors diff:\n%v", diff)
   792  			}
   793  		})
   794  	}
   795  }
   796  
   797  type fakeHandler struct {
   798  	called bool
   799  	resp   interface{}
   800  	err    error
   801  	// Attributes recorded by run calls
   802  	ctx context.Context
   803  	req interface{}
   804  }
   805  
   806  func (f *fakeHandler) run(ctx context.Context, req interface{}) (interface{}, error) {
   807  	if f.called {
   808  		panic("handler already called; either create a new handler or set called to false before reusing")
   809  	}
   810  	f.called = true
   811  	f.ctx = ctx
   812  	f.req = req
   813  	return f.resp, f.err
   814  }
   815  
   816  type fakeInterceptor struct {
   817  	key    interface{}
   818  	val    interface{}
   819  	called bool
   820  	err    error
   821  }
   822  
   823  func (f *fakeInterceptor) run(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
   824  	if f.called {
   825  		panic("interceptor already called; either create a new interceptor or set called to false before reusing")
   826  	}
   827  	f.called = true
   828  	if f.err != nil {
   829  		return nil, f.err
   830  	}
   831  	return handler(context.WithValue(ctx, f.key, f.val), req)
   832  }