trpc.group/trpc-go/trpc-go@v1.0.3/server/stream_filter.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package server
    15  
    16  import "sync"
    17  
    18  var (
    19  	streamFilters = make(map[string]StreamFilter)
    20  	lock          = sync.RWMutex{}
    21  )
    22  
    23  // StreamServerInfo is stream information on server side.
    24  type StreamServerInfo struct {
    25  	// FullMethod is the full RPC method string, i.e., /package.service/method.
    26  	FullMethod string
    27  	// IsClientStream indicates whether the RPC is a client streaming RPC.
    28  	IsClientStream bool
    29  	// IsServerStream indicates whether the RPC is a server streaming RPC.
    30  	IsServerStream bool
    31  }
    32  
    33  // StreamFilter is server stream filter.
    34  type StreamFilter func(ss Stream, info *StreamServerInfo, handler StreamHandler) error
    35  
    36  // StreamFilterChain  server stream filters chain.
    37  type StreamFilterChain []StreamFilter
    38  
    39  // Filter implements StreamFilter for multi stream filters.
    40  func (c StreamFilterChain) Filter(ss Stream, info *StreamServerInfo, handler StreamHandler) error {
    41  	for i := len(c) - 1; i >= 0; i-- {
    42  		next, curFilter := handler, c[i]
    43  		handler = func(ss Stream) error {
    44  			return curFilter(ss, info, next)
    45  		}
    46  	}
    47  	return handler(ss)
    48  }
    49  
    50  // RegisterStreamFilter registers server stream filter with name.
    51  func RegisterStreamFilter(name string, filter StreamFilter) {
    52  	lock.Lock()
    53  	streamFilters[name] = filter
    54  	lock.Unlock()
    55  }
    56  
    57  // GetStreamFilter gets server stream filter by name.
    58  func GetStreamFilter(name string) StreamFilter {
    59  	lock.RLock()
    60  	f := streamFilters[name]
    61  	lock.RUnlock()
    62  	return f
    63  }