trpc.group/trpc-go/trpc-go@v1.0.3/filter/filter.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 // Package filter implements client/server filter(interceptor) chains. 15 // 16 // Signatures of filters have been refactored after v0.9.0. 17 // There remains lots of dirty codes to keep backward compatibility. 18 package filter 19 20 import ( 21 "context" 22 "sync" 23 24 "trpc.group/trpc-go/trpc-go/rpcz" 25 ) 26 27 // ClientHandleFunc defines the client side filter(interceptor) function type. 28 type ClientHandleFunc func(ctx context.Context, req, rsp interface{}) error 29 30 // ServerHandleFunc defines the server side filter(interceptor) function type. 31 type ServerHandleFunc func(ctx context.Context, req interface{}) (rsp interface{}, err error) 32 33 // ClientFilter is the client side filter(interceptor) type. They are chained to process request. 34 type ClientFilter func(ctx context.Context, req, rsp interface{}, next ClientHandleFunc) error 35 36 // ServerFilter is the server side filter(interceptor) type. They are chained to process request. 37 type ServerFilter func(ctx context.Context, req interface{}, next ServerHandleFunc) (rsp interface{}, err error) 38 39 // NoopServerFilter is the noop implementation of ServerFilter. 40 func NoopServerFilter(ctx context.Context, req interface{}, next ServerHandleFunc) (rsp interface{}, err error) { 41 return next(ctx, req) 42 } 43 44 // NoopClientFilter is the noop implementation of ClientFilter. 45 func NoopClientFilter(ctx context.Context, req, rsp interface{}, next ClientHandleFunc) error { 46 return next(ctx, req, rsp) 47 } 48 49 // EmptyChain is an empty chain. 50 var EmptyChain = ClientChain{} 51 52 // ClientChain chains client side filters. 53 type ClientChain []ClientFilter 54 55 // Filter invokes every client side filters in the chain. 56 func (c ClientChain) Filter(ctx context.Context, req, rsp interface{}, next ClientHandleFunc) error { 57 nextF := func(ctx context.Context, req, rsp interface{}) error { 58 _, end, ctx := rpcz.NewSpanContext(ctx, "CallFunc") 59 err := next(ctx, req, rsp) 60 end.End() 61 return err 62 } 63 64 names, ok := names(ctx) 65 for i := len(c) - 1; i >= 0; i-- { 66 curHandleFunc, curFilter, curI := nextF, c[i], i 67 nextF = func(ctx context.Context, req, rsp interface{}) error { 68 if ok { 69 var ender rpcz.Ender 70 _, ender, ctx = rpcz.NewSpanContext(ctx, name(names, curI)) 71 defer ender.End() 72 } 73 return curFilter(ctx, req, rsp, curHandleFunc) 74 } 75 } 76 return nextF(ctx, req, rsp) 77 } 78 79 func names(ctx context.Context) ([]string, bool) { 80 names, ok := rpcz.SpanFromContext(ctx).Attribute(rpcz.TRPCAttributeFilterNames) 81 if !ok { 82 return nil, false 83 } 84 ns, ok := names.([]string) 85 return ns, ok 86 } 87 88 func name(names []string, index int) string { 89 if index >= len(names) || index < 0 { 90 const unknownName = "unknown" 91 return unknownName 92 } 93 return names[index] 94 } 95 96 // ServerChain chains server side filters. 97 type ServerChain []ServerFilter 98 99 // Filter invokes every server side filters in the chain. 100 func (c ServerChain) Filter(ctx context.Context, req interface{}, next ServerHandleFunc) (interface{}, error) { 101 nextF := func(ctx context.Context, req interface{}) (rsp interface{}, err error) { 102 _, end, ctx := rpcz.NewSpanContext(ctx, "HandleFunc") 103 rsp, err = next(ctx, req) 104 end.End() 105 return rsp, err 106 } 107 108 names, ok := names(ctx) 109 for i := len(c) - 1; i >= 0; i-- { 110 curHandleFunc, curFilter, curI := nextF, c[i], i 111 nextF = func(ctx context.Context, req interface{}) (interface{}, error) { 112 if ok { 113 var ender rpcz.Ender 114 _, ender, ctx = rpcz.NewSpanContext(ctx, name(names, curI)) 115 defer ender.End() 116 } 117 rsp, err := curFilter(ctx, req, curHandleFunc) 118 return rsp, err 119 } 120 } 121 return nextF(ctx, req) 122 } 123 124 var ( 125 lock = sync.RWMutex{} 126 serverFilters = make(map[string]ServerFilter) 127 clientFilters = make(map[string]ClientFilter) 128 ) 129 130 // Register registers server/client filters by name. 131 func Register(name string, s ServerFilter, c ClientFilter) { 132 lock.Lock() 133 defer lock.Unlock() 134 serverFilters[name] = s 135 clientFilters[name] = c 136 } 137 138 // GetServer gets the ServerFilter by name. 139 func GetServer(name string) ServerFilter { 140 lock.RLock() 141 f := serverFilters[name] 142 lock.RUnlock() 143 return f 144 } 145 146 // GetClient gets the ClientFilter by name. 147 func GetClient(name string) ClientFilter { 148 lock.RLock() 149 f := clientFilters[name] 150 lock.RUnlock() 151 return f 152 }