github.com/polarismesh/polaris@v1.17.8/apiserver/grpcserver/stream.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package grpcserver
    19  
    20  import (
    21  	"context"
    22  	"io"
    23  	"strings"
    24  	"time"
    25  
    26  	"go.uber.org/zap"
    27  	"google.golang.org/grpc"
    28  	"google.golang.org/grpc/metadata"
    29  	"google.golang.org/grpc/peer"
    30  
    31  	commonlog "github.com/polarismesh/polaris/common/log"
    32  )
    33  
    34  // initVirtualStream 对 VirtualStream 的一些初始化动作
    35  type initVirtualStream func(vStream *VirtualStream)
    36  
    37  // WithVirtualStreamMethod 设置 method
    38  func WithVirtualStreamMethod(method string) initVirtualStream {
    39  	return func(vStream *VirtualStream) {
    40  		vStream.Method = method
    41  	}
    42  }
    43  
    44  // WithVirtualStreamServerStream 设置 grpc.ServerStream
    45  func WithVirtualStreamServerStream(stream grpc.ServerStream) initVirtualStream {
    46  	return func(vStream *VirtualStream) {
    47  		vStream.stream = stream
    48  	}
    49  }
    50  
    51  // WithVirtualStreamPreProcessFunc 设置 PreProcessFunc
    52  func WithVirtualStreamPreProcessFunc(preprocess PreProcessFunc) initVirtualStream {
    53  	return func(vStream *VirtualStream) {
    54  		vStream.preprocess = preprocess
    55  	}
    56  }
    57  
    58  // WithVirtualStreamPostProcessFunc 设置 PostProcessFunc
    59  func WithVirtualStreamPostProcessFunc(postprocess PostProcessFunc) initVirtualStream {
    60  	return func(vStream *VirtualStream) {
    61  		vStream.postprocess = postprocess
    62  	}
    63  }
    64  
    65  // WithVirtualStreamBaseServer 设置 BaseGrpcServer
    66  func WithVirtualStreamBaseServer(server *BaseGrpcServer) initVirtualStream {
    67  	return func(vStream *VirtualStream) {
    68  		vStream.server = server
    69  	}
    70  }
    71  
    72  // WithVirtualStreamLogger 设置 Logger
    73  func WithVirtualStreamLogger(log *commonlog.Scope) initVirtualStream {
    74  	return func(vStream *VirtualStream) {
    75  		vStream.log = log
    76  	}
    77  }
    78  
    79  func newVirtualStream(ctx context.Context, initOptions ...initVirtualStream) *VirtualStream {
    80  	var clientAddress string
    81  	var clientIP string
    82  	var userAgent string
    83  	var requestID string
    84  
    85  	peerAddress, exist := peer.FromContext(ctx)
    86  	if exist {
    87  		clientAddress = peerAddress.Addr.String()
    88  		// 解析获取clientIP
    89  		items := strings.Split(clientAddress, ":")
    90  		if len(items) == 2 {
    91  			clientIP = items[0]
    92  		}
    93  	}
    94  
    95  	meta, exist := metadata.FromIncomingContext(ctx)
    96  	if exist {
    97  		agents := meta["user-agent"]
    98  		if len(agents) > 0 {
    99  			userAgent = agents[0]
   100  		}
   101  
   102  		ids := meta["request-id"]
   103  		if len(ids) > 0 {
   104  			requestID = ids[0]
   105  		}
   106  	}
   107  
   108  	virtualStream := &VirtualStream{
   109  		ClientAddress: clientAddress,
   110  		ClientIP:      clientIP,
   111  		UserAgent:     userAgent,
   112  		RequestID:     requestID,
   113  		server:        nil,
   114  		stream:        nil,
   115  		Code:          0,
   116  	}
   117  
   118  	for i := range initOptions {
   119  		initOptions[i](virtualStream)
   120  	}
   121  
   122  	return virtualStream
   123  }
   124  
   125  // VirtualStream 虚拟Stream 继承ServerStream
   126  type VirtualStream struct {
   127  	server *BaseGrpcServer
   128  
   129  	Method        string
   130  	ClientAddress string
   131  	ClientIP      string
   132  	UserAgent     string
   133  	RequestID     string
   134  
   135  	stream grpc.ServerStream
   136  
   137  	Code int
   138  
   139  	preprocess  PreProcessFunc
   140  	postprocess PostProcessFunc
   141  
   142  	StartTime time.Time
   143  
   144  	log *commonlog.Scope
   145  }
   146  
   147  // SetHeader sets the header metadata. It may be called multiple times.
   148  // When call multiple times, all the provided metadata will be merged.
   149  // All the metadata will be sent out when one of the following happens:
   150  //   - ServerStream.SendHeader() is called;
   151  //   - The first response is sent out;
   152  //   - An RPC status is sent out (error or success).
   153  func (v *VirtualStream) SetHeader(md metadata.MD) error {
   154  	return v.stream.SetHeader(md)
   155  }
   156  
   157  // SendHeader sends the header metadata.
   158  // The provided md and headers set by SetHeader() will be sent.
   159  // It fails if called multiple times.
   160  func (v *VirtualStream) SendHeader(md metadata.MD) error {
   161  	return v.stream.SendHeader(md)
   162  }
   163  
   164  // SetTrailer sets the trailer metadata which will be sent with the RPC status.
   165  // When called more than once, all the provided metadata will be merged.
   166  func (v *VirtualStream) SetTrailer(md metadata.MD) {
   167  	v.stream.SetTrailer(md)
   168  }
   169  
   170  // Context returns the context for this stream.
   171  func (v *VirtualStream) Context() context.Context {
   172  	return v.stream.Context()
   173  }
   174  
   175  // RecvMsg blocks until it receives a message into m or the stream is
   176  // done. It returns io.EOF when the client has performed a CloseSend. On
   177  // any non-EOF error, the stream is aborted and the error contains the
   178  // RPC status.
   179  //
   180  // It is safe to have a goroutine calling SendMsg and another goroutine
   181  // calling RecvMsg on the same stream at the same time, but it is not
   182  // safe to call RecvMsg on the same stream in different goroutines.
   183  func (v *VirtualStream) RecvMsg(m interface{}) error {
   184  	err := v.stream.RecvMsg(m)
   185  	if err == io.EOF {
   186  		return err
   187  	}
   188  
   189  	if err == nil {
   190  		err = v.preprocess(v, false)
   191  	} else {
   192  		v.Code = -1
   193  	}
   194  
   195  	return err
   196  }
   197  
   198  // SendMsg sends a message. On error, SendMsg aborts the stream and the
   199  // error is returned directly.
   200  //
   201  // SendMsg blocks until:
   202  //   - There is sufficient flow control to schedule m with the transport, or
   203  //   - The stream is done, or
   204  //   - The stream breaks.
   205  //
   206  // SendMsg does not wait until the message is received by the client. An
   207  // untimely stream closure may result in lost messages.
   208  //
   209  // It is safe to have a goroutine calling SendMsg and another goroutine
   210  // calling RecvMsg on the same stream at the same time, but it is not safe
   211  // to call SendMsg on the same stream in different goroutines.
   212  func (v *VirtualStream) SendMsg(m interface{}) error {
   213  	v.postprocess(v, m)
   214  	m = v.handleResponse(v.stream, m)
   215  	err := v.stream.SendMsg(m)
   216  	if err != nil {
   217  		v.Code = -2
   218  	}
   219  	return err
   220  }
   221  
   222  func (v *VirtualStream) handleResponse(stream grpc.ServerStream, m interface{}) interface{} {
   223  	if v.server.cache == nil {
   224  		return m
   225  	}
   226  
   227  	cacheVal := v.server.convert(m)
   228  	if cacheVal == nil {
   229  		return m
   230  	}
   231  
   232  	if saveVal := v.server.cache.Get(cacheVal.CacheType, cacheVal.Key); saveVal != nil {
   233  		return saveVal.GetPreparedMessage()
   234  	}
   235  
   236  	if err := cacheVal.PrepareMessage(stream); err != nil {
   237  		v.log.Warn("[Grpc][ProtoCache] prepare message fail, direct send msg", zap.String("key", cacheVal.Key),
   238  			zap.Error(err))
   239  		return m
   240  	}
   241  
   242  	cacheVal, ok := v.server.cache.Put(cacheVal)
   243  	if !ok {
   244  		v.log.Warn("[Grpc][ProtoCache] put cache ignore", zap.String("key", cacheVal.Key),
   245  			zap.String("cacheType", cacheVal.CacheType))
   246  	}
   247  	if cacheVal == nil {
   248  		return m
   249  	}
   250  
   251  	return cacheVal.GetPreparedMessage()
   252  }