github.com/cloudwego/kitex@v0.9.0/pkg/endpoint/endpoint_test.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package endpoint 18 19 import ( 20 "context" 21 "testing" 22 23 "github.com/cloudwego/kitex/internal/test" 24 ) 25 26 type val struct { 27 str string 28 } 29 30 var ( 31 biz = "Biz" 32 beforeMW0 = "BeforeMiddleware0" 33 afterMW0 = "AfterMiddleware0" 34 beforeMW1 = "BeforeMiddleware1" 35 afterMW1 = "AfterMiddleware1" 36 ) 37 38 func invoke(ctx context.Context, req, resp interface{}) (err error) { 39 val, ok := req.(*val) 40 if ok { 41 val.str += biz 42 } 43 return nil 44 } 45 46 func mockMW0(next Endpoint) Endpoint { 47 return func(ctx context.Context, req, resp interface{}) (err error) { 48 val, ok := req.(*val) 49 if ok { 50 val.str += beforeMW0 51 } 52 err = next(ctx, req, resp) 53 if err != nil { 54 return err 55 } 56 if ok { 57 val.str += afterMW0 58 } 59 return nil 60 } 61 } 62 63 func mockMW1(next Endpoint) Endpoint { 64 return func(ctx context.Context, req, resp interface{}) (err error) { 65 val, ok := req.(*val) 66 if ok { 67 val.str += beforeMW1 68 } 69 err = next(ctx, req, resp) 70 if err != nil { 71 return err 72 } 73 if ok { 74 val.str += afterMW1 75 } 76 return nil 77 } 78 } 79 80 func TestChain(t *testing.T) { 81 mws := Chain(mockMW0, mockMW1) 82 req := &val{} 83 mws(invoke)(context.Background(), req, nil) 84 final := beforeMW0 + beforeMW1 + biz + afterMW1 + afterMW0 85 test.Assert(t, req.str == final) 86 } 87 88 func TestBuild(t *testing.T) { 89 test.Assert(t, Build(nil)(DummyEndpoint)(context.Background(), nil, nil) == nil) 90 mws := Build([]Middleware{mockMW0, mockMW1}) 91 req := &val{} 92 mws(invoke)(context.Background(), req, nil) 93 final := beforeMW0 + beforeMW1 + biz + afterMW1 + afterMW0 94 test.Assert(t, req.str == final) 95 }