trpc.group/trpc-go/trpc-go@v1.0.3/filter/filter.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  // Package filter implements client/server filter(interceptor) chains.
    15  //
    16  // Signatures of filters have been refactored after v0.9.0.
    17  // A lot of legacy code remains in order to keep backward compatibility.
    18  package filter
    19  
    20  import (
    21  	"context"
    22  	"sync"
    23  
    24  	"trpc.group/trpc-go/trpc-go/rpcz"
    25  )
    26  
    27  // ClientHandleFunc defines the client side filter(interceptor) function type.
    28  type ClientHandleFunc func(ctx context.Context, req, rsp interface{}) error
    29  
    30  // ServerHandleFunc defines the server side filter(interceptor) function type.
    31  type ServerHandleFunc func(ctx context.Context, req interface{}) (rsp interface{}, err error)
    32  
    33  // ClientFilter is the client side filter(interceptor) type. They are chained to process request.
    34  type ClientFilter func(ctx context.Context, req, rsp interface{}, next ClientHandleFunc) error
    35  
    36  // ServerFilter is the server side filter(interceptor) type. They are chained to process request.
    37  type ServerFilter func(ctx context.Context, req interface{}, next ServerHandleFunc) (rsp interface{}, err error)
    38  
    39  // NoopServerFilter is the noop implementation of ServerFilter.
    40  func NoopServerFilter(ctx context.Context, req interface{}, next ServerHandleFunc) (rsp interface{}, err error) {
    41  	return next(ctx, req)
    42  }
    43  
    44  // NoopClientFilter is the noop implementation of ClientFilter.
    45  func NoopClientFilter(ctx context.Context, req, rsp interface{}, next ClientHandleFunc) error {
    46  	return next(ctx, req, rsp)
    47  }
    48  
    49  // EmptyChain is an empty chain.
    50  var EmptyChain = ClientChain{}
    51  
    52  // ClientChain chains client side filters.
    53  type ClientChain []ClientFilter
    54  
    55  // Filter invokes every client side filters in the chain.
    56  func (c ClientChain) Filter(ctx context.Context, req, rsp interface{}, next ClientHandleFunc) error {
    57  	nextF := func(ctx context.Context, req, rsp interface{}) error {
    58  		_, end, ctx := rpcz.NewSpanContext(ctx, "CallFunc")
    59  		err := next(ctx, req, rsp)
    60  		end.End()
    61  		return err
    62  	}
    63  
    64  	names, ok := names(ctx)
    65  	for i := len(c) - 1; i >= 0; i-- {
    66  		curHandleFunc, curFilter, curI := nextF, c[i], i
    67  		nextF = func(ctx context.Context, req, rsp interface{}) error {
    68  			if ok {
    69  				var ender rpcz.Ender
    70  				_, ender, ctx = rpcz.NewSpanContext(ctx, name(names, curI))
    71  				defer ender.End()
    72  			}
    73  			return curFilter(ctx, req, rsp, curHandleFunc)
    74  		}
    75  	}
    76  	return nextF(ctx, req, rsp)
    77  }
    78  
    79  func names(ctx context.Context) ([]string, bool) {
    80  	names, ok := rpcz.SpanFromContext(ctx).Attribute(rpcz.TRPCAttributeFilterNames)
    81  	if !ok {
    82  		return nil, false
    83  	}
    84  	ns, ok := names.([]string)
    85  	return ns, ok
    86  }
    87  
    88  func name(names []string, index int) string {
    89  	if index >= len(names) || index < 0 {
    90  		const unknownName = "unknown"
    91  		return unknownName
    92  	}
    93  	return names[index]
    94  }
    95  
    96  // ServerChain chains server side filters.
    97  type ServerChain []ServerFilter
    98  
    99  // Filter invokes every server side filters in the chain.
   100  func (c ServerChain) Filter(ctx context.Context, req interface{}, next ServerHandleFunc) (interface{}, error) {
   101  	nextF := func(ctx context.Context, req interface{}) (rsp interface{}, err error) {
   102  		_, end, ctx := rpcz.NewSpanContext(ctx, "HandleFunc")
   103  		rsp, err = next(ctx, req)
   104  		end.End()
   105  		return rsp, err
   106  	}
   107  
   108  	names, ok := names(ctx)
   109  	for i := len(c) - 1; i >= 0; i-- {
   110  		curHandleFunc, curFilter, curI := nextF, c[i], i
   111  		nextF = func(ctx context.Context, req interface{}) (interface{}, error) {
   112  			if ok {
   113  				var ender rpcz.Ender
   114  				_, ender, ctx = rpcz.NewSpanContext(ctx, name(names, curI))
   115  				defer ender.End()
   116  			}
   117  			rsp, err := curFilter(ctx, req, curHandleFunc)
   118  			return rsp, err
   119  		}
   120  	}
   121  	return nextF(ctx, req)
   122  }
   123  
   124  var (
   125  	lock          = sync.RWMutex{}
   126  	serverFilters = make(map[string]ServerFilter)
   127  	clientFilters = make(map[string]ClientFilter)
   128  )
   129  
   130  // Register registers server/client filters by name.
   131  func Register(name string, s ServerFilter, c ClientFilter) {
   132  	lock.Lock()
   133  	defer lock.Unlock()
   134  	serverFilters[name] = s
   135  	clientFilters[name] = c
   136  }
   137  
   138  // GetServer gets the ServerFilter by name.
   139  func GetServer(name string) ServerFilter {
   140  	lock.RLock()
   141  	f := serverFilters[name]
   142  	lock.RUnlock()
   143  	return f
   144  }
   145  
   146  // GetClient gets the ClientFilter by name.
   147  func GetClient(name string) ClientFilter {
   148  	lock.RLock()
   149  	f := clientFilters[name]
   150  	lock.RUnlock()
   151  	return f
   152  }