trpc.group/trpc-go/trpc-go@v1.0.3/client/stream_filter.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package client
    15  
    16  import (
    17  	"context"
    18  	"sync"
    19  )
    20  
    21  var (
    22  	streamFilters = make(map[string]StreamFilter)
    23  	lock          = sync.RWMutex{}
    24  )
    25  
    26  // ClientStream is the interface returned to users to call its methods.
    27  type ClientStream interface {
    28  	// RecvMsg receives messages.
    29  	RecvMsg(m interface{}) error
    30  	// SendMsg sends messages.
    31  	SendMsg(m interface{}) error
    32  	// CloseSend closes sender.
    33  	// No more sending messages,
    34  	// but it's still allowed to continue to receive messages.
    35  	CloseSend() error
    36  	// Context gets the Context.
    37  	Context() context.Context
    38  }
    39  
    40  // ClientStreamDesc is the client stream description.
    41  type ClientStreamDesc struct {
    42  	// StreamName is the name of the stream, corresponding to Method of unary RPC.
    43  	StreamName string
    44  	// ClientStreams indicates whether it's client streaming.
    45  	ClientStreams bool
    46  	// ServerStreams indicates whether it's server streaming.
    47  	ServerStreams bool
    48  }
    49  
    50  // StreamFilter is the client stream filter.
    51  // StreamFilter processing happens before or after stream's establishing.
    52  type StreamFilter func(ctx context.Context, desc *ClientStreamDesc, streamer Streamer) (ClientStream, error)
    53  
    54  // Streamer is the wrapper filter function used to filter all methods of ClientStream.
    55  type Streamer func(ctx context.Context, desc *ClientStreamDesc) (ClientStream, error)
    56  
    57  // RegisterStreamFilter registers a StreamFilter by name.
    58  func RegisterStreamFilter(name string, filter StreamFilter) {
    59  	lock.Lock()
    60  	streamFilters[name] = filter
    61  	lock.Unlock()
    62  }
    63  
    64  // GetStreamFilter returns a StreamFilter by name.
    65  func GetStreamFilter(name string) StreamFilter {
    66  	lock.RLock()
    67  	f := streamFilters[name]
    68  	lock.RUnlock()
    69  	return f
    70  }
    71  
    72  // StreamFilterChain client stream filters chain.
    73  type StreamFilterChain []StreamFilter
    74  
    75  // Filter implements StreamFilter for multi stream filters.
    76  func (c StreamFilterChain) Filter(ctx context.Context,
    77  	desc *ClientStreamDesc, streamer Streamer) (ClientStream, error) {
    78  	for i := len(c) - 1; i >= 0; i-- {
    79  		next, curFilter := streamer, c[i]
    80  		streamer = func(ctx context.Context, desc *ClientStreamDesc) (ClientStream, error) {
    81  			return curFilter(ctx, desc, next)
    82  		}
    83  	}
    84  	return streamer(ctx, desc)
    85  }