//
//
// Tencent is pleased to support the open source community by making tRPC available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company.
// All rights reserved.
//
// If you have downloaded a copy of the tRPC source code from Tencent,
// please note that tRPC source code is licensed under the Apache 2.0 License,
// A copy of the Apache 2.0 License is included in this file.
//
//

package client

import (
	"context"
	"sync"
)

var (
	// streamFilters is the package-wide registry of stream filters, keyed by name.
	streamFilters = make(map[string]StreamFilter)
	// lock guards every read and write of streamFilters.
	lock = sync.RWMutex{}
)

// ClientStream is the interface returned to users to call its methods.
type ClientStream interface {
	// RecvMsg receives a message into m.
	RecvMsg(m interface{}) error
	// SendMsg sends the message m.
	SendMsg(m interface{}) error
	// CloseSend closes the sending side.
	// No more messages may be sent afterwards,
	// but it is still allowed to continue receiving messages.
	CloseSend() error
	// Context returns the Context of the stream.
	Context() context.Context
}

// ClientStreamDesc is the client stream description.
type ClientStreamDesc struct {
	// StreamName is the name of the stream, corresponding to Method of unary RPC.
	StreamName string
	// ClientStreams indicates whether it's client streaming.
	ClientStreams bool
	// ServerStreams indicates whether it's server streaming.
	ServerStreams bool
}

// StreamFilter is the client stream filter.
// StreamFilter processing happens before or after the stream is established:
// a filter receives the next Streamer in line and decides when to invoke it.
type StreamFilter func(ctx context.Context, desc *ClientStreamDesc, streamer Streamer) (ClientStream, error)

// Streamer is the wrapper filter function used to filter all methods of ClientStream.
// Calling it yields the ClientStream for the stream described by desc.
type Streamer func(ctx context.Context, desc *ClientStreamDesc) (ClientStream, error)

// RegisterStreamFilter registers a StreamFilter by name.
58 func RegisterStreamFilter(name string, filter StreamFilter) { 59 lock.Lock() 60 streamFilters[name] = filter 61 lock.Unlock() 62 } 63 64 // GetStreamFilter returns a StreamFilter by name. 65 func GetStreamFilter(name string) StreamFilter { 66 lock.RLock() 67 f := streamFilters[name] 68 lock.RUnlock() 69 return f 70 } 71 72 // StreamFilterChain client stream filters chain. 73 type StreamFilterChain []StreamFilter 74 75 // Filter implements StreamFilter for multi stream filters. 76 func (c StreamFilterChain) Filter(ctx context.Context, 77 desc *ClientStreamDesc, streamer Streamer) (ClientStream, error) { 78 for i := len(c) - 1; i >= 0; i-- { 79 next, curFilter := streamer, c[i] 80 streamer = func(ctx context.Context, desc *ClientStreamDesc) (ClientStream, error) { 81 return curFilter(ctx, desc, next) 82 } 83 } 84 return streamer(ctx, desc) 85 }