github.com/zorawar87/trillian@v1.2.1/server/interceptor/interceptor_test.go (about) 1 // Copyright 2017 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package interceptor 16 17 import ( 18 "context" 19 "errors" 20 "testing" 21 "time" 22 23 "github.com/golang/mock/gomock" 24 "github.com/golang/protobuf/proto" 25 "github.com/golang/protobuf/ptypes" 26 "github.com/google/trillian" 27 "github.com/google/trillian/quota" 28 "github.com/google/trillian/quota/etcd/quotapb" 29 "github.com/google/trillian/storage" 30 "github.com/google/trillian/storage/testonly" 31 "github.com/google/trillian/trees" 32 "github.com/kylelemons/godebug/pretty" 33 "google.golang.org/grpc" 34 "google.golang.org/grpc/codes" 35 "google.golang.org/grpc/status" 36 37 serrors "github.com/google/trillian/server/errors" 38 grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" 39 ) 40 41 func TestServiceName(t *testing.T) { 42 for _, tc := range []struct { 43 desc string 44 method string 45 want string 46 }{ 47 {desc: "trillian", method: "/trillian.TrillianLog/QueueLeaf", want: "trillian.TrillianLog"}, 48 {desc: "fullyqualified", method: "/some.package.service/method", want: "some.package.service"}, 49 {desc: "unqualified", method: "/service.method", want: "service"}, 50 {desc: "noleadingslash", method: "no.leading.slash/method"}, 51 {desc: "malformed", method: "/package.service.method"}, 52 } { 53 t.Run(tc.desc, func(t *testing.T) { 54 if got, want := serviceName(tc.method), tc.want; got != want { 55 t.Errorf("serviceName(%v): %v, want %v", tc.method, got, want) 56 } 57 }) 58 } 59 } 60 61 func TestTrillianInterceptor_TreeInterception(t *testing.T) { 62 logTree := proto.Clone(testonly.LogTree).(*trillian.Tree) 63 logTree.TreeId = 10 64 mapTree := proto.Clone(testonly.MapTree).(*trillian.Tree) 65 mapTree.TreeId = 11 66 deletedTree := proto.Clone(testonly.LogTree).(*trillian.Tree) 67 deletedTree.TreeId = 12 68 deletedTree.Deleted = true 69 deletedTree.DeleteTime = ptypes.TimestampNow() 70 unknownTreeID := int64(999) 71 72 tests := []struct { 73 desc string 74 method string 75 req interface{} 76 handlerErr error 77 wantErr bool 78 wantTree *trillian.Tree 79 cancelled bool 80 }{ 81 // TODO(codingllama): Admin requests don't benefit from tree-reading logic, but we may read 82 // their tree IDs for auth purposes. 83 { 84 desc: "adminReadByID", 85 method: "/trillian.TrillianAdmin/GetTree", 86 req: &trillian.GetTreeRequest{TreeId: logTree.TreeId}, 87 }, 88 { 89 desc: "adminWriteByID", 90 method: "/trillian.TrillianAdmin/DeleteTree", 91 req: &trillian.DeleteTreeRequest{TreeId: logTree.TreeId}, 92 }, 93 { 94 desc: "adminWriteByTree", 95 method: "/trillian.TrillianAdmin/UpdateTree", 96 req: &trillian.UpdateTreeRequest{Tree: &trillian.Tree{TreeId: logTree.TreeId}}, 97 }, 98 { 99 desc: "logRPC", 100 method: "/trillian.TrillianLog/GetLatestSignedLogRoot", 101 req: &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId}, 102 wantTree: logTree, 103 }, 104 { 105 desc: "mapRPC", 106 method: "/trillian.TrillianMap/GetSignedMapRoot", 107 req: &trillian.GetSignedMapRootRequest{MapId: mapTree.TreeId}, 108 wantTree: mapTree, 109 }, 110 { 111 desc: "unknownRequest", 112 req: "not-a-request", 113 wantErr: false, 114 }, 115 { 116 desc: "unknownTree", 117 method: "/trillian.TrillianLog/GetLatestSignedLogRoot", 118 req: &trillian.GetLatestSignedLogRootRequest{LogId: unknownTreeID}, 119 wantErr: true, 120 }, 121 { 122 desc: "deletedTree", 123 method: "/trillian.TrillianLog/GetLatestSignedLogRoot", 124 req: &trillian.GetLatestSignedLogRootRequest{LogId: deletedTree.TreeId}, 125 wantErr: true, 126 }, 127 { 128 desc: "cancelled", 129 method: "/trillian.TrillianLog/GetLatestSignedLogRoot", 130 req: &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId}, 131 cancelled: true, 132 wantErr: true, 133 }, 134 } 135 136 ctx := context.Background() 137 for _, test := range tests { 138 t.Run(test.desc, func(t *testing.T) { 139 ctrl := gomock.NewController(t) 140 defer ctrl.Finish() 141 admin := storage.NewMockAdminStorage(ctrl) 142 adminTX := storage.NewMockReadOnlyAdminTX(ctrl) 143 admin.EXPECT().Snapshot(gomock.Any()).AnyTimes().Return(adminTX, nil) 144 adminTX.EXPECT().GetTree(gomock.Any(), logTree.TreeId).AnyTimes().Return(logTree, nil) 145 adminTX.EXPECT().GetTree(gomock.Any(), mapTree.TreeId).AnyTimes().Return(mapTree, nil) 146 adminTX.EXPECT().GetTree(gomock.Any(), deletedTree.TreeId).AnyTimes().Return(deletedTree, nil) 147 adminTX.EXPECT().GetTree(gomock.Any(), unknownTreeID).AnyTimes().Return(nil, errors.New("not found")) 148 adminTX.EXPECT().Close().AnyTimes().Return(nil) 149 adminTX.EXPECT().Commit().AnyTimes().Return(nil) 150 151 intercept := New(admin, quota.Noop(), false /* quotaDryRun */, nil /* mf */) 152 handler := &fakeHandler{resp: "handler response", err: test.handlerErr} 153 154 if test.cancelled { 155 // Use a context that's already been cancelled 156 newCtx, cancel := context.WithCancel(ctx) 157 cancel() 158 ctx = newCtx 159 } 160 161 resp, err := intercept.UnaryInterceptor(ctx, test.req, 162 &grpc.UnaryServerInfo{FullMethod: test.method}, 163 handler.run) 164 if hasErr := err != nil && err != test.handlerErr; hasErr != test.wantErr { 165 t.Fatalf("UnaryInterceptor() returned err = %v, wantErr = %v", err, test.wantErr) 166 } else if hasErr { 167 return 168 } 169 170 if !handler.called { 171 t.Fatal("handler not called") 172 } 173 if handler.resp != resp { 174 t.Errorf("resp = %v, want = %v", resp, handler.resp) 175 } 176 if handler.err != err { 177 t.Errorf("err = %v, want = %v", err, handler.err) 178 } 179 180 if test.wantTree != nil { 181 switch tree, ok := trees.FromContext(handler.ctx); { 182 case !ok: 183 t.Error("tree not in handler ctx") 184 case !proto.Equal(tree, test.wantTree): 185 diff := pretty.Compare(tree, test.wantTree) 186 t.Errorf("post-FromContext diff:\n%v", diff) 187 } 188 } 189 }) 190 } 191 } 192 193 func TestTrillianInterceptor_QuotaInterception(t *testing.T) { 194 195 logTree := *testonly.LogTree 196 logTree.TreeId = 10 197 198 mapTree := *testonly.MapTree 199 mapTree.TreeId = 11 200 201 preorderedTree := *testonly.PreorderedLogTree 202 preorderedTree.TreeId = 12 203 204 charge1 := "alpaca" 205 charge2 := "cama" 206 charges := &trillian.ChargeTo{User: []string{charge1, charge2}} 207 tests := []struct { 208 desc string 209 dryRun bool 210 method string 211 req interface{} 212 specs []quota.Spec 213 getTokensErr error 214 wantCode codes.Code 215 wantTokens int 216 }{ 217 { 218 desc: "logRead", 219 method: "/trillian.TrillianLog/GetLatestSignedLogRoot", 220 req: &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId}, 221 specs: []quota.Spec{ 222 {Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId}, 223 {Group: quota.Global, Kind: quota.Read}, 224 }, 225 wantTokens: 1, 226 }, 227 { 228 desc: "logReadIndices", 229 method: "/trillian.TrillianLog/GetLeavesByIndex", 230 req: &trillian.GetLeavesByIndexRequest{LogId: logTree.TreeId, LeafIndex: []int64{1, 2, 3}}, 231 specs: []quota.Spec{ 232 {Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId}, 233 {Group: quota.Global, Kind: quota.Read}, 234 }, 235 wantTokens: 3, 236 }, 237 { 238 desc: "logReadRange", 239 method: "/trillian.TrillianLog/GetLeavesByRange", 240 req: &trillian.GetLeavesByRangeRequest{LogId: logTree.TreeId, Count: 123}, 241 specs: []quota.Spec{ 242 {Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId}, 243 {Group: quota.Global, Kind: quota.Read}, 244 }, 245 wantTokens: 123, 246 }, 247 { 248 desc: "logReadNegativeRange", 249 method: "/trillian.TrillianLog/GetLeavesByRange", 250 req: &trillian.GetLeavesByRangeRequest{LogId: logTree.TreeId, Count: -123}, 251 specs: []quota.Spec{ 252 {Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId}, 253 {Group: quota.Global, Kind: quota.Read}, 254 }, 255 wantTokens: 1, 256 }, 257 { 258 desc: "logReadZeroRange", 259 method: "/trillian.TrillianLog/GetLeavesByRange", 260 req: &trillian.GetLeavesByRangeRequest{LogId: logTree.TreeId, Count: 0}, 261 specs: []quota.Spec{ 262 {Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId}, 263 {Group: quota.Global, Kind: quota.Read}, 264 }, 265 wantTokens: 1, 266 }, 267 { 268 desc: "logRead with charges", 269 method: "/trillian.TrillianLog/GetLatestSignedLogRoot", 270 req: &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId, ChargeTo: charges}, 271 specs: []quota.Spec{ 272 {Group: quota.User, Kind: quota.Read, User: charge1}, 273 {Group: quota.User, Kind: quota.Read, User: charge2}, 274 {Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId}, 275 {Group: quota.Global, Kind: quota.Read}, 276 }, 277 wantTokens: 1, 278 }, 279 { 280 desc: "logWrite", 281 method: "/trillian.TrillianLog/QueueLeaf", 282 req: &trillian.QueueLeafRequest{LogId: logTree.TreeId}, 283 specs: []quota.Spec{ 284 {Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId}, 285 {Group: quota.Global, Kind: quota.Write}, 286 }, 287 wantTokens: 1, 288 }, 289 { 290 desc: "logWrite with charges", 291 method: "/trillian.TrillianLog/QueueLeaf", 292 req: &trillian.QueueLeafRequest{LogId: logTree.TreeId, ChargeTo: charges}, 293 specs: []quota.Spec{ 294 {Group: quota.User, Kind: quota.Write, User: charge1}, 295 {Group: quota.User, Kind: quota.Write, User: charge2}, 296 {Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId}, 297 {Group: quota.Global, Kind: quota.Write}, 298 }, 299 wantTokens: 1, 300 }, 301 { 302 desc: "mapRead", 303 method: "/trillian.TrillianMap/GetLeaves", 304 req: &trillian.GetMapLeavesRequest{MapId: mapTree.TreeId, Index: [][]byte{{0x01}, {0x02}}}, 305 specs: []quota.Spec{ 306 {Group: quota.Tree, Kind: quota.Read, TreeID: mapTree.TreeId}, 307 {Group: quota.Global, Kind: quota.Read}, 308 }, 309 wantTokens: 2, 310 }, 311 { 312 desc: "emptyBatchRequest", 313 method: "/trillian.TrillianLog/QueueLeaves", 314 req: &trillian.QueueLeavesRequest{ 315 LogId: logTree.TreeId, 316 Leaves: nil, 317 }, 318 }, 319 { 320 desc: "batchLogLeavesRequest", 321 method: "/trillian.TrillianLog/QueueLeaves", 322 req: &trillian.QueueLeavesRequest{ 323 LogId: logTree.TreeId, 324 Leaves: []*trillian.LogLeaf{{}, {}, {}}, 325 }, 326 specs: []quota.Spec{ 327 {Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId}, 328 {Group: quota.Global, Kind: quota.Write}, 329 }, 330 wantTokens: 3, 331 }, 332 { 333 desc: "batchSequencedLogLeavesRequest", 334 method: "/trillian.TrillianLog/AddSequencedLeaves", 335 req: &trillian.AddSequencedLeavesRequest{ 336 LogId: preorderedTree.TreeId, 337 Leaves: []*trillian.LogLeaf{{}, {}, {}}, 338 }, 339 specs: []quota.Spec{ 340 {Group: quota.Tree, Kind: quota.Write, TreeID: preorderedTree.TreeId}, 341 {Group: quota.Global, Kind: quota.Write}, 342 }, 343 wantTokens: 3, 344 }, 345 { 346 desc: "batchLogLeavesRequest with charges", 347 method: "/trillian.TrillianLog/QueueLeaves", 348 req: &trillian.QueueLeavesRequest{ 349 LogId: logTree.TreeId, 350 Leaves: []*trillian.LogLeaf{{}, {}, {}}, 351 ChargeTo: charges, 352 }, 353 specs: []quota.Spec{ 354 {Group: quota.User, Kind: quota.Write, User: charge1}, 355 {Group: quota.User, Kind: quota.Write, User: charge2}, 356 {Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId}, 357 {Group: quota.Global, Kind: quota.Write}, 358 }, 359 wantTokens: 3, 360 }, 361 { 362 desc: "batchMapLeavesRequest", 363 method: "/trillian.TrillianMap/SetLeaves", 364 req: &trillian.SetMapLeavesRequest{ 365 MapId: mapTree.TreeId, 366 Leaves: []*trillian.MapLeaf{{}, {}, {}, {}, {}}, 367 }, 368 specs: []quota.Spec{ 369 {Group: quota.Tree, Kind: quota.Write, TreeID: mapTree.TreeId}, 370 {Group: quota.Global, Kind: quota.Write}, 371 }, 372 wantTokens: 5, 373 }, 374 { 375 desc: "quotaError", 376 method: "/trillian.TrillianLog/GetLatestSignedLogRoot", 377 req: &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId}, 378 specs: []quota.Spec{ 379 {Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId}, 380 {Group: quota.Global, Kind: quota.Read}, 381 }, 382 getTokensErr: errors.New("not enough tokens"), 383 wantCode: codes.ResourceExhausted, 384 wantTokens: 1, 385 }, 386 { 387 desc: "quotaDryRunError", 388 dryRun: true, 389 method: "/trillian.TrillianLog/GetLatestSignedLogRoot", 390 req: &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId}, 391 specs: []quota.Spec{ 392 {Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId}, 393 {Group: quota.Global, Kind: quota.Read}, 394 }, 395 getTokensErr: errors.New("not enough tokens"), 396 wantTokens: 1, 397 }, 398 } 399 400 ctx := context.Background() 401 for _, test := range tests { 402 t.Run(test.desc, func(t *testing.T) { 403 ctrl := gomock.NewController(t) 404 defer ctrl.Finish() 405 admin := storage.NewMockAdminStorage(ctrl) 406 adminTX := storage.NewMockReadOnlyAdminTX(ctrl) 407 admin.EXPECT().Snapshot(gomock.Any()).AnyTimes().Return(adminTX, nil) 408 adminTX.EXPECT().GetTree(gomock.Any(), logTree.TreeId).AnyTimes().Return(&logTree, nil) 409 adminTX.EXPECT().GetTree(gomock.Any(), mapTree.TreeId).AnyTimes().Return(&mapTree, nil) 410 adminTX.EXPECT().GetTree(gomock.Any(), preorderedTree.TreeId).AnyTimes().Return(&preorderedTree, nil) 411 adminTX.EXPECT().Close().AnyTimes().Return(nil) 412 adminTX.EXPECT().Commit().AnyTimes().Return(nil) 413 414 qm := quota.NewMockManager(ctrl) 415 if test.wantTokens > 0 { 416 qm.EXPECT().GetTokens(gomock.Any(), test.wantTokens, test.specs).Return(test.getTokensErr) 417 } 418 419 handler := &fakeHandler{resp: "ok"} 420 intercept := New(admin, qm, test.dryRun, nil /* mf */) 421 422 // resp and handler assertions are done by TestTrillianInterceptor_TreeInterception, 423 // we're only concerned with the quota logic here. 424 _, err := intercept.UnaryInterceptor(ctx, test.req, 425 &grpc.UnaryServerInfo{FullMethod: test.method}, 426 handler.run) 427 if s, ok := status.FromError(err); !ok || s.Code() != test.wantCode { 428 t.Errorf("UnaryInterceptor() returned err = %q, wantCode = %v", err, test.wantCode) 429 } 430 }) 431 } 432 } 433 434 func TestTrillianInterceptor_QuotaInterception_ReturnsTokens(t *testing.T) { 435 436 logTree := *testonly.LogTree 437 logTree.TreeId = 10 438 439 tests := []struct { 440 desc string 441 method string 442 req, resp interface{} 443 specs []quota.Spec 444 handlerErr error 445 wantGetTokens, wantPutTokens int 446 }{ 447 { 448 desc: "badRequest", 449 method: "/trillian.TrillianLog/GetLatestSignedLogRoot", 450 req: &trillian.GetLatestSignedLogRootRequest{LogId: logTree.TreeId}, 451 specs: []quota.Spec{ 452 {Group: quota.Tree, Kind: quota.Read, TreeID: logTree.TreeId}, 453 {Group: quota.Global, Kind: quota.Read}, 454 }, 455 handlerErr: errors.New("bad request"), 456 wantGetTokens: 1, 457 wantPutTokens: 1, 458 }, 459 { 460 desc: "newLeaf", 461 method: "/trillian.TrillianLog/QueueLeaf", 462 req: &trillian.QueueLeafRequest{LogId: logTree.TreeId, Leaf: &trillian.LogLeaf{}}, 463 resp: &trillian.QueueLeafResponse{QueuedLeaf: &trillian.QueuedLogLeaf{}}, 464 specs: []quota.Spec{ 465 {Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId}, 466 {Group: quota.Global, Kind: quota.Write}, 467 }, 468 wantGetTokens: 1, 469 }, 470 { 471 desc: "duplicateLeaf", 472 method: "/trillian.TrillianLog/QueueLeaf", 473 req: &trillian.QueueLeafRequest{LogId: logTree.TreeId}, 474 resp: &trillian.QueueLeafResponse{ 475 QueuedLeaf: &trillian.QueuedLogLeaf{ 476 Status: status.New(codes.AlreadyExists, "duplicate leaf").Proto(), 477 }, 478 }, 479 specs: []quota.Spec{ 480 {Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId}, 481 {Group: quota.Global, Kind: quota.Write}, 482 }, 483 wantGetTokens: 1, 484 wantPutTokens: 1, 485 }, 486 { 487 desc: "newLeaves", 488 method: "/trillian.TrillianLog/QueueLeaves", 489 req: &trillian.QueueLeavesRequest{ 490 LogId: logTree.TreeId, 491 Leaves: []*trillian.LogLeaf{{}, {}, {}}, 492 }, 493 resp: &trillian.QueueLeavesResponse{ 494 QueuedLeaves: []*trillian.QueuedLogLeaf{{}, {}, {}}, // No explicit Status means OK 495 }, 496 specs: []quota.Spec{ 497 {Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId}, 498 {Group: quota.Global, Kind: quota.Write}, 499 }, 500 wantGetTokens: 3, 501 }, 502 { 503 desc: "duplicateLeaves", 504 method: "/trillian.TrillianLog/QueueLeaves", 505 req: &trillian.QueueLeavesRequest{ 506 LogId: logTree.TreeId, 507 Leaves: []*trillian.LogLeaf{{}, {}, {}}, 508 }, 509 resp: &trillian.QueueLeavesResponse{ 510 QueuedLeaves: []*trillian.QueuedLogLeaf{ 511 {Status: status.New(codes.AlreadyExists, "duplicate leaf").Proto()}, 512 {Status: status.New(codes.AlreadyExists, "duplicate leaf").Proto()}, 513 {}, 514 }, 515 }, 516 specs: []quota.Spec{ 517 {Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId}, 518 {Group: quota.Global, Kind: quota.Write}, 519 }, 520 wantGetTokens: 3, 521 wantPutTokens: 2, 522 }, 523 { 524 desc: "badQueueLeavesRequest", 525 method: "/trillian.TrillianLog/QueueLeaves", 526 req: &trillian.QueueLeavesRequest{ 527 LogId: logTree.TreeId, 528 Leaves: []*trillian.LogLeaf{{}, {}, {}}, 529 }, 530 specs: []quota.Spec{ 531 {Group: quota.Tree, Kind: quota.Write, TreeID: logTree.TreeId}, 532 {Group: quota.Global, Kind: quota.Write}, 533 }, 534 handlerErr: errors.New("bad request"), 535 wantGetTokens: 3, 536 wantPutTokens: 3, 537 }, 538 } 539 540 defer func(timeout time.Duration) { 541 PutTokensTimeout = timeout 542 }(PutTokensTimeout) 543 PutTokensTimeout = 5 * time.Second 544 545 // Use a ctx with a timeout smaller than PutTokensTimeout. Not too short or 546 // spurious failures will occur when the deadline expires. 547 ctx, cancel := context.WithTimeout(context.Background(), PutTokensTimeout-2*time.Second) 548 defer cancel() 549 550 for _, test := range tests { 551 t.Run(test.desc, func(t *testing.T) { 552 ctrl := gomock.NewController(t) 553 defer ctrl.Finish() 554 admin := storage.NewMockAdminStorage(ctrl) 555 adminTX := storage.NewMockReadOnlyAdminTX(ctrl) 556 admin.EXPECT().Snapshot(gomock.Any()).AnyTimes().Return(adminTX, nil) 557 adminTX.EXPECT().GetTree(gomock.Any(), logTree.TreeId).AnyTimes().Return(&logTree, nil) 558 adminTX.EXPECT().Close().AnyTimes().Return(nil) 559 adminTX.EXPECT().Commit().AnyTimes().Return(nil) 560 putTokensCh := make(chan bool, 1) 561 wantDeadline := time.Now().Add(PutTokensTimeout) 562 563 qm := quota.NewMockManager(ctrl) 564 if test.wantGetTokens > 0 { 565 qm.EXPECT().GetTokens(gomock.Any(), test.wantGetTokens, test.specs).Return(nil) 566 } 567 if test.wantPutTokens > 0 { 568 qm.EXPECT().PutTokens(gomock.Any(), test.wantPutTokens, test.specs).Do(func(ctx context.Context, numTokens int, specs []quota.Spec) { 569 switch d, ok := ctx.Deadline(); { 570 case !ok: 571 t.Errorf("PutTokens() ctx has no deadline: %v", ctx) 572 case d.Before(wantDeadline): 573 t.Errorf("PutTokens() ctx deadline too short, got %v, want >= %v", d, wantDeadline) 574 } 575 putTokensCh <- true 576 }).Return(nil) 577 } 578 579 handler := &fakeHandler{resp: test.resp, err: test.handlerErr} 580 intercept := New(admin, qm, false /* quotaDryRun */, nil /* mf */) 581 582 if _, err := intercept.UnaryInterceptor(ctx, test.req, 583 &grpc.UnaryServerInfo{FullMethod: test.method}, 584 handler.run); err != test.handlerErr { 585 t.Errorf("UnaryInterceptor() returned err = [%v], want = [%v]", err, test.handlerErr) 586 } 587 588 // PutTokens may be delegated to a separate goroutine. Give it some time to complete. 589 select { 590 case <-putTokensCh: 591 // OK 592 case <-time.After(1 * time.Second): 593 // No need to error here, gomock will fail if the call is missing. 594 } 595 }) 596 } 597 } 598 599 func TestTrillianInterceptor_NotIntercepted(t *testing.T) { 600 tests := []struct { 601 method string 602 req interface{} 603 }{ 604 // Admin 605 {method: "/trillian.TrillianAdmin/CreateTree", req: &trillian.CreateTreeRequest{}}, 606 {method: "/trillian.TrillianAdmin/ListTrees", req: &trillian.ListTreesRequest{}}, 607 // Quota 608 {method: "/quotapb.Quota/CreateConfig", req: "apb.CreateConfigRequest{}}, 609 {method: "/quotapb.Quota/DeleteConfig", req: "apb.DeleteConfigRequest{}}, 610 {method: "/quotapb.Quota/GetConfig", req: "apb.GetConfigRequest{}}, 611 {method: "/quotapb.Quota/ListConfigs", req: "apb.ListConfigsRequest{}}, 612 {method: "/quotapb.Quota/UpdateConfig", req: "apb.UpdateConfigRequest{}}, 613 } 614 615 ctx := context.Background() 616 for _, test := range tests { 617 handler := &fakeHandler{} 618 intercept := New(nil /* admin */, quota.Noop(), false /* quotaDryRun */, nil /* mf */) 619 if _, err := intercept.UnaryInterceptor(ctx, test.req, 620 &grpc.UnaryServerInfo{FullMethod: test.method}, 621 handler.run); err != nil { 622 t.Errorf("UnaryInterceptor(%#v) returned err = %v", test.req, err) 623 } 624 if !handler.called { 625 t.Errorf("UnaryInterceptor(%#v): handler not called", test.req) 626 } 627 } 628 } 629 630 // TestTrillianInterceptor_BeforeAfter tests a few Before/After interactions that are 631 // difficult/impossible to get unless the methods are called separately (i.e., not via 632 // UnaryInterceptor()). 633 func TestTrillianInterceptor_BeforeAfter(t *testing.T) { 634 logTree := *testonly.LogTree 635 logTree.TreeId = 10 636 637 qm := quota.Noop() 638 639 tests := []struct { 640 desc string 641 req, resp interface{} 642 handlerErr error 643 wantBeforeErr bool 644 }{ 645 { 646 desc: "success", 647 req: &trillian.CreateTreeRequest{}, 648 resp: &trillian.Tree{}, 649 }, 650 { 651 desc: "badRequest", 652 req: "bad", 653 resp: nil, 654 handlerErr: errors.New("bad"), 655 wantBeforeErr: true, 656 }, 657 } 658 659 ctx := context.Background() 660 for _, test := range tests { 661 t.Run(test.desc, func(t *testing.T) { 662 ctrl := gomock.NewController(t) 663 defer ctrl.Finish() 664 admin := storage.NewMockAdminStorage(ctrl) 665 adminTX := storage.NewMockReadOnlyAdminTX(ctrl) 666 admin.EXPECT().Snapshot(gomock.Any()).AnyTimes().Return(adminTX, nil) 667 adminTX.EXPECT().GetTree(gomock.Any(), logTree.TreeId).AnyTimes().Return(&logTree, nil) 668 adminTX.EXPECT().Close().AnyTimes().Return(nil) 669 adminTX.EXPECT().Commit().AnyTimes().Return(nil) 670 671 intercept := New(admin, qm, false /* quotaDryRun */, nil /* mf */) 672 p := intercept.NewProcessor() 673 674 _, err := p.Before(ctx, test.req, "/trillian.TrillianLog/foo") 675 if gotErr := err != nil; gotErr != test.wantBeforeErr { 676 t.Fatalf("Before() returned err = %v, wantErr = %v", err, test.wantBeforeErr) 677 } 678 679 // Other TrillianInterceptor tests assert After side-effects more in-depth, silently 680 // returning is good enough here. 681 p.After(ctx, test.resp, "", test.handlerErr) 682 }) 683 } 684 } 685 686 func TestCombine(t *testing.T) { 687 i1 := &fakeInterceptor{key: "key1", val: "foo"} 688 i2 := &fakeInterceptor{key: "key2", val: "bar"} 689 i3 := &fakeInterceptor{key: "key3", val: "baz"} 690 e1 := &fakeInterceptor{err: errors.New("intercept error")} 691 692 handlerErr := errors.New("handler error") 693 694 tests := []struct { 695 desc string 696 interceptors []*fakeInterceptor 697 handlerErr error 698 wantCalled int 699 wantErr error 700 }{ 701 { 702 desc: "noInterceptors", 703 }, 704 { 705 desc: "single", 706 interceptors: []*fakeInterceptor{i1}, 707 wantCalled: 1, 708 }, 709 { 710 desc: "multi1", 711 interceptors: []*fakeInterceptor{i1, i2, i3}, 712 wantCalled: 3, 713 }, 714 { 715 desc: "multi2", 716 interceptors: []*fakeInterceptor{i3, i1, i2}, 717 wantCalled: 3, 718 }, 719 { 720 desc: "handlerErr", 721 interceptors: []*fakeInterceptor{i1, i2}, 722 handlerErr: handlerErr, 723 wantCalled: 2, 724 wantErr: handlerErr, 725 }, 726 { 727 desc: "interceptErr", 728 interceptors: []*fakeInterceptor{i1, e1, i2}, 729 wantCalled: 2, 730 wantErr: e1.err, 731 }, 732 } 733 734 ctx := context.Background() 735 req := "request" 736 info := &grpc.UnaryServerInfo{} 737 for _, test := range tests { 738 t.Run(test.desc, func(t *testing.T) { 739 if l := len(test.interceptors); l < test.wantCalled { 740 t.Fatalf("len(interceptors) = %v, want >= %v", l, test.wantCalled) 741 } 742 743 intercepts := []grpc.UnaryServerInterceptor{} 744 for _, i := range test.interceptors { 745 i.called = false 746 intercepts = append(intercepts, i.run) 747 } 748 intercept := grpc_middleware.ChainUnaryServer(intercepts...) 749 750 handler := &fakeHandler{resp: "response", err: test.handlerErr} 751 resp, err := intercept(ctx, req, info, handler.run) 752 if err != test.wantErr { 753 t.Fatalf("err = %q, want = %q", err, test.wantErr) 754 } 755 756 called := 0 757 callsStopped := false 758 for _, i := range test.interceptors { 759 switch { 760 case i.called: 761 if callsStopped { 762 t.Errorf("interceptor called out of order: %v", i) 763 } 764 called++ 765 case !i.called: 766 // No calls should have happened from here on 767 callsStopped = true 768 } 769 } 770 if called != test.wantCalled { 771 t.Errorf("called %v interceptors, want = %v", called, test.wantCalled) 772 } 773 774 // Assertions below this point assume that the handler was called (ie, all 775 // interceptors succeeded). 776 if err != nil && err != test.handlerErr { 777 return 778 } 779 780 if resp != handler.resp { 781 t.Errorf("resp = %v, want = %v", resp, handler.resp) 782 } 783 784 // Chain the ctxs for all called interceptors and verify it got through to the 785 // handler. 786 wantCtx := ctx 787 for _, i := range test.interceptors { 788 h := &fakeHandler{resp: "ok"} 789 i.called = false 790 _, err = i.run(wantCtx, req, info, h.run) 791 if err != nil { 792 t.Fatalf("unexpected handler failure: %v", err) 793 } 794 wantCtx = h.ctx 795 } 796 if diff := pretty.Compare(handler.ctx, wantCtx); diff != "" { 797 t.Errorf("handler ctx diff:\n%v", diff) 798 } 799 }) 800 } 801 } 802 803 func TestErrorWrapper(t *testing.T) { 804 badLlamaErr := status.Errorf(codes.InvalidArgument, "Bad Llama") 805 tests := []struct { 806 desc string 807 resp interface{} 808 err, wantErr error 809 }{ 810 { 811 desc: "success", 812 resp: "ok", 813 }, 814 { 815 desc: "error", 816 err: badLlamaErr, 817 wantErr: serrors.WrapError(badLlamaErr), 818 }, 819 } 820 ctx := context.Background() 821 for _, test := range tests { 822 t.Run(test.desc, func(t *testing.T) { 823 handler := fakeHandler{resp: test.resp, err: test.err} 824 resp, err := ErrorWrapper(ctx, "req", &grpc.UnaryServerInfo{}, handler.run) 825 if resp != test.resp { 826 t.Errorf("resp = %v, want = %v", resp, test.resp) 827 } 828 if diff := pretty.Compare(err, test.wantErr); diff != "" { 829 t.Errorf("post-WrapErrors diff:\n%v", diff) 830 } 831 }) 832 } 833 } 834 835 type fakeHandler struct { 836 called bool 837 resp interface{} 838 err error 839 // Attributes recorded by run calls 840 ctx context.Context 841 req interface{} 842 } 843 844 func (f *fakeHandler) run(ctx context.Context, req interface{}) (interface{}, error) { 845 if f.called { 846 panic("handler already called; either create a new handler or set called to false before reusing") 847 } 848 f.called = true 849 f.ctx = ctx 850 f.req = req 851 return f.resp, f.err 852 } 853 854 type fakeInterceptor struct { 855 key interface{} 856 val interface{} 857 called bool 858 err error 859 } 860 861 func (f *fakeInterceptor) run(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 862 if f.called { 863 panic("interceptor already called; either create a new interceptor or set called to false before reusing") 864 } 865 f.called = true 866 if f.err != nil { 867 return nil, f.err 868 } 869 return handler(context.WithValue(ctx, f.key, f.val), req) 870 }