github.com/polarismesh/polaris@v1.17.8/apiserver/grpcserver/base.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package grpcserver
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"net"
    24  	"net/http"
    25  	"runtime"
    26  	"strings"
    27  	"time"
    28  
    29  	apimodel "github.com/polarismesh/specification/source/go/api/v1/model"
    30  	"go.uber.org/zap"
    31  	"google.golang.org/grpc"
    32  	"google.golang.org/grpc/codes"
    33  	"google.golang.org/grpc/credentials"
    34  	"google.golang.org/grpc/metadata"
    35  	"google.golang.org/grpc/peer"
    36  	"google.golang.org/grpc/status"
    37  
    38  	api "github.com/polarismesh/polaris/common/api/v1"
    39  	connhook "github.com/polarismesh/polaris/common/conn/hook"
    40  	connlimit "github.com/polarismesh/polaris/common/conn/limit"
    41  	commonlog "github.com/polarismesh/polaris/common/log"
    42  	"github.com/polarismesh/polaris/common/metrics"
    43  	"github.com/polarismesh/polaris/common/model"
    44  	"github.com/polarismesh/polaris/common/secure"
    45  	"github.com/polarismesh/polaris/common/utils"
    46  	"github.com/polarismesh/polaris/plugin"
    47  )
    48  
    49  // InitServer BaseGrpcServer.Run 中回调函数的定义
    50  type InitServer func(*grpc.Server) error
    51  
// BaseGrpcServer base utilities and functions for gRPC Connector.
//
// It holds the listener settings read by Initialize, the *grpc.Server built by
// Run, and the plugins (statis, ratelimit) consulted by the interceptors.
type BaseGrpcServer struct {
	// listenIP / listenPort come from conf["listenIP"] / conf["listenPort"] in Initialize.
	listenIP        string
	listenPort      uint32
	// connLimitConfig is parsed from conf["connLimit"]; nil when not configured.
	connLimitConfig *connlimit.Config
	// tlsInfo is built from conf["tls"]; nil when not configured.
	tlsInfo         *secure.TLSInfo
	// start is set by Run while it executes; Restart checks it before waiting on exitCh.
	start           bool
	// restart is set by Restart and cleared once re-initialization succeeds.
	restart         bool
	// exitCh is created by Run and closed when Run returns.
	exitCh          chan struct{}

	// protocol is the protocol name used in log messages.
	protocol string

	// bz is handed to the connection-counter hook to pick module-specific metrics.
	bz model.BzModule

	// server is the running gRPC server, assigned by Run after initServer succeeds.
	server     *grpc.Server
	// statis receives call metrics; assigned by Run from plugin.GetStatis().
	statis     plugin.Statis
	// ratelimit is the optional rate-limit plugin; nil disables rate limiting.
	ratelimit  plugin.Ratelimit
	// OpenMethod whitelists full method names; when empty every method is allowed.
	OpenMethod map[string]bool

	// NOTE(review): cache and convert are not referenced in this file;
	// presumably used for response caching elsewhere in the package — confirm.
	cache   Cache
	convert MessageToCache

	// log is the logger used by all server methods.
	log *commonlog.Scope
}
    76  
    77  // GetPort get the connector listen port value
    78  func (b *BaseGrpcServer) GetPort() uint32 {
    79  	return b.listenPort
    80  }
    81  
    82  // Initialize init the gRPC server
    83  func (b *BaseGrpcServer) Initialize(ctx context.Context, conf map[string]interface{}, initOptions ...InitOption) error {
    84  	for i := range initOptions {
    85  		initOptions[i](b)
    86  	}
    87  
    88  	b.listenIP = conf["listenIP"].(string)
    89  	b.listenPort = uint32(conf["listenPort"].(int))
    90  
    91  	if raw, _ := conf["connLimit"].(map[interface{}]interface{}); raw != nil {
    92  		connConfig, err := connlimit.ParseConnLimitConfig(raw)
    93  		if err != nil {
    94  			return err
    95  		}
    96  		b.connLimitConfig = connConfig
    97  	}
    98  
    99  	if raw, _ := conf["tls"].(map[interface{}]interface{}); raw != nil {
   100  		tlsConfig, err := secure.ParseTLSConfig(raw)
   101  		if err != nil {
   102  			return err
   103  		}
   104  		b.tlsInfo = &secure.TLSInfo{
   105  			CertFile:      tlsConfig.CertFile,
   106  			KeyFile:       tlsConfig.KeyFile,
   107  			TrustedCAFile: tlsConfig.TrustedCAFile,
   108  		}
   109  	}
   110  
   111  	if ratelimit := plugin.GetRatelimit(); ratelimit != nil {
   112  		b.log.Infof("[API-Server] %s server open the ratelimit", b.protocol)
   113  		b.ratelimit = ratelimit
   114  	}
   115  
   116  	return nil
   117  }
   118  
   119  // Stop stopping the gRPC server
   120  func (b *BaseGrpcServer) Stop(protocol string) {
   121  	connlimit.RemoveLimitListener(protocol)
   122  	if b.server != nil {
   123  		b.server.Stop()
   124  	}
   125  }
   126  
   127  // Run server main loop
   128  func (b *BaseGrpcServer) Run(errCh chan error, protocol string, initServer InitServer) {
   129  	b.log.Infof("[API-Server] start %s server", protocol)
   130  	b.exitCh = make(chan struct{})
   131  	b.start = true
   132  	defer func() {
   133  		close(b.exitCh)
   134  		b.start = false
   135  	}()
   136  
   137  	address := fmt.Sprintf("%v:%v", b.listenIP, b.listenPort)
   138  	listener, err := net.Listen("tcp", address)
   139  	if err != nil {
   140  		b.log.Error("[API-Server][GRPC] %v", zap.Error(err))
   141  		errCh <- err
   142  		return
   143  	}
   144  	defer listener.Close()
   145  
   146  	// 如果设置最大连接数
   147  	if b.connLimitConfig != nil && b.connLimitConfig.OpenConnLimit {
   148  		b.log.Infof("[API-Server][GRPC] grpc server use max connection limit: %d, grpc max limit: %d",
   149  			b.connLimitConfig.MaxConnPerHost, b.connLimitConfig.MaxConnLimit)
   150  		listener, err = connlimit.NewListener(listener, protocol, b.connLimitConfig)
   151  		if err != nil {
   152  			b.log.Error("[API-Server][GRPC] conn limit init", zap.Error(err))
   153  			errCh <- err
   154  			return
   155  		}
   156  	}
   157  
   158  	b.log.Infof("[API-Server][GRPC] open connection counter net.Listener")
   159  	listener = connhook.NewHookListener(listener, &connCounterHook{
   160  		bz: b.bz,
   161  	})
   162  
   163  	// 指定使用服务端证书创建一个 TLS credentials
   164  	var creds credentials.TransportCredentials
   165  	if !b.tlsInfo.IsEmpty() {
   166  		creds, err = credentials.NewServerTLSFromFile(b.tlsInfo.CertFile, b.tlsInfo.KeyFile)
   167  		if err != nil {
   168  			b.log.Error("failed to create credentials: %v", zap.Error(err))
   169  			errCh <- err
   170  			return
   171  		}
   172  	}
   173  
   174  	// 设置 grpc server options
   175  	opts := []grpc.ServerOption{
   176  		grpc.UnaryInterceptor(b.unaryInterceptor),
   177  		grpc.StreamInterceptor(b.streamInterceptor),
   178  	}
   179  	if creds != nil {
   180  		// 指定使用 TLS credentials
   181  		opts = append(opts, grpc.Creds(creds))
   182  	}
   183  	server := grpc.NewServer(opts...)
   184  
   185  	if err = initServer(server); err != nil {
   186  		errCh <- err
   187  		return
   188  	}
   189  	b.server = server
   190  
   191  	b.statis = plugin.GetStatis()
   192  
   193  	if err := server.Serve(listener); err != nil {
   194  		b.log.Errorf("[API-Server][GRPC] %v", err)
   195  		errCh <- err
   196  		return
   197  	}
   198  
   199  	b.log.Infof("[API-Server] %s server stop", protocol)
   200  }
   201  
// notPrintableMethods lists full gRPC method names for which unaryInterceptor
// passes printable=false to preprocess, so the "receive request" log line is
// skipped. Only key presence is checked; the bool value is not read.
var notPrintableMethods = map[string]bool{
	"/v1.PolarisGRPC/Heartbeat": true,
}
   205  
   206  func (b *BaseGrpcServer) unaryInterceptor(ctx context.Context, req interface{},
   207  	info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (rsp interface{}, err error) {
   208  	stream := newVirtualStream(ctx,
   209  		WithVirtualStreamBaseServer(b),
   210  		WithVirtualStreamLogger(b.log),
   211  		WithVirtualStreamMethod(info.FullMethod),
   212  		WithVirtualStreamPreProcessFunc(b.preprocess),
   213  		WithVirtualStreamPostProcessFunc(b.postprocess),
   214  	)
   215  
   216  	func() {
   217  		_, ok := notPrintableMethods[info.FullMethod]
   218  		var printable = !ok
   219  		if err := b.preprocess(stream, printable); err != nil {
   220  			return
   221  		}
   222  
   223  		// 判断是否允许访问
   224  		if ok := b.AllowAccess(stream.Method); !ok {
   225  			rsp = api.NewResponse(apimodel.Code_ClientAPINotOpen)
   226  			return
   227  		}
   228  
   229  		// handler执行前,限流
   230  		if code := b.EnterRatelimit(stream.ClientIP, stream.Method); code != uint32(api.ExecuteSuccess) {
   231  			rsp = api.NewResponse(apimodel.Code(code))
   232  			return
   233  		}
   234  		defer func() {
   235  			if panicInfo := recover(); panicInfo != nil {
   236  				var buf [4086]byte
   237  				n := runtime.Stack(buf[:], false)
   238  				b.log.Errorf("panic %+v", string(buf[:n]))
   239  			}
   240  		}()
   241  
   242  		rsp, err = handler(ctx, req)
   243  	}()
   244  
   245  	b.postprocess(stream, rsp)
   246  
   247  	return
   248  }
   249  
   250  func (b *BaseGrpcServer) recoverFunc(i interface{}, w http.ResponseWriter) {
   251  
   252  }
   253  
   254  func (b *BaseGrpcServer) streamInterceptor(srv interface{}, ss grpc.ServerStream,
   255  	info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
   256  	stream := newVirtualStream(ss.Context(),
   257  		WithVirtualStreamBaseServer(b),
   258  		WithVirtualStreamServerStream(ss),
   259  		WithVirtualStreamMethod(info.FullMethod),
   260  		WithVirtualStreamPreProcessFunc(b.preprocess),
   261  		WithVirtualStreamPostProcessFunc(b.postprocess),
   262  	)
   263  
   264  	defer func() {
   265  		if err := recover(); err != nil {
   266  			var buf [4086]byte
   267  			n := runtime.Stack(buf[:], false)
   268  			b.log.Errorf("panic %+v", string(buf[:n]))
   269  		}
   270  	}()
   271  
   272  	err = handler(srv, stream)
   273  	if err != nil {
   274  		fromError, ok := status.FromError(err)
   275  		if ok && fromError.Code() == codes.Canceled {
   276  			// 存在非EOF读错误或者写错误
   277  			b.log.Info("[API-Server][GRPC] handler stream is canceled by client",
   278  				zap.String("client-address", stream.ClientAddress),
   279  				zap.String("user-agent", stream.UserAgent),
   280  				utils.ZapRequestID(stream.RequestID),
   281  				zap.String("method", stream.Method),
   282  				zap.Error(err),
   283  			)
   284  		} else {
   285  			// 存在非EOF读错误或者写错误
   286  			b.log.Error("[API-Server][GRPC] handler stream",
   287  				zap.String("client-address", stream.ClientAddress),
   288  				zap.String("user-agent", stream.UserAgent),
   289  				utils.ZapRequestID(stream.RequestID),
   290  				zap.String("method", stream.Method),
   291  				zap.Error(err),
   292  			)
   293  		}
   294  
   295  		b.statis.ReportCallMetrics(metrics.CallMetric{
   296  			Type:     metrics.ServerCallMetric,
   297  			API:      stream.Method,
   298  			Protocol: "gRPC",
   299  			Code:     int(stream.Code),
   300  			Duration: 0,
   301  		})
   302  	}
   303  	return
   304  }
   305  
   306  // PreProcessFunc preprocess function define
   307  type PreProcessFunc func(stream *VirtualStream, isPrint bool) error
   308  
   309  func (b *BaseGrpcServer) preprocess(stream *VirtualStream, isPrint bool) error {
   310  	// 设置开始时间
   311  	stream.StartTime = time.Now()
   312  
   313  	if isPrint {
   314  		// 打印请求
   315  		b.log.Info("[API-Server][GRPC] receive request",
   316  			zap.String("client-address", stream.ClientAddress),
   317  			zap.String("user-agent", stream.UserAgent),
   318  			utils.ZapRequestID(stream.RequestID),
   319  			zap.String("method", stream.Method),
   320  		)
   321  	}
   322  
   323  	return nil
   324  }
   325  
   326  // PostProcessFunc postprocess function define
   327  type PostProcessFunc func(stream *VirtualStream, m interface{})
   328  
   329  func (b *BaseGrpcServer) postprocess(stream *VirtualStream, m interface{}) {
   330  	var code int
   331  	var polarisCode uint32
   332  	if response, ok := m.(api.ResponseMessage); ok {
   333  		polarisCode = response.GetCode().GetValue()
   334  		code = api.CalcCode(response)
   335  
   336  		// 打印回复
   337  		if code != http.StatusOK {
   338  			b.log.Error("[API-Server][GRPC] send response",
   339  				zap.String("client-address", stream.ClientAddress),
   340  				zap.String("user-agent", stream.UserAgent),
   341  				utils.ZapRequestID(stream.RequestID),
   342  				zap.String("method", stream.Method),
   343  				zap.String("response", response.String()),
   344  			)
   345  		}
   346  		if polarisCode > 0 {
   347  			code = int(polarisCode)
   348  		}
   349  	} else {
   350  		code = stream.Code
   351  		// 打印回复
   352  		if code != int(codes.OK) {
   353  			b.log.Error("[API-Server][GRPC] send response",
   354  				zap.String("client-address", stream.ClientAddress),
   355  				zap.String("user-agent", stream.UserAgent),
   356  				utils.ZapRequestID(stream.RequestID),
   357  				zap.String("method", stream.Method),
   358  				zap.String("response", response.String()),
   359  			)
   360  		}
   361  	}
   362  
   363  	// 接口调用统计
   364  	diff := time.Since(stream.StartTime)
   365  
   366  	// 打印耗时超过1s的请求
   367  	if diff > time.Second {
   368  		b.log.Info("[API-Server][GRPC] handling time > 1s",
   369  			zap.String("client-address", stream.ClientAddress),
   370  			zap.String("user-agent", stream.UserAgent),
   371  			utils.ZapRequestID(stream.RequestID),
   372  			zap.String("method", stream.Method),
   373  			zap.Duration("handling-time", diff),
   374  		)
   375  	}
   376  
   377  	b.statis.ReportCallMetrics(metrics.CallMetric{
   378  		Type:     metrics.ServerCallMetric,
   379  		API:      stream.Method,
   380  		Protocol: "gRPC",
   381  		Code:     int(stream.Code),
   382  		Duration: diff,
   383  	})
   384  }
   385  
   386  // Restart restart gRPC server
   387  func (b *BaseGrpcServer) Restart(
   388  	initialize func() error, run func(), protocol string, option map[string]interface{}) error {
   389  	b.log.Infof("[API-Server][GRPC] restart %s server with new config: %+v", protocol, option)
   390  
   391  	b.restart = true
   392  	b.Stop(protocol)
   393  	if b.start {
   394  		<-b.exitCh
   395  	}
   396  
   397  	b.log.Infof("[API-Server][GRPC] old %s server has stopped, begin restarting it", protocol)
   398  	if err := initialize(); err != nil {
   399  		b.log.Errorf("restart %s server err: %s", protocol, err.Error())
   400  		return err
   401  	}
   402  
   403  	b.log.Infof("[API-Server][GRPC] init %s server successfully, restart it", protocol)
   404  	b.restart = false
   405  	go run()
   406  
   407  	return nil
   408  }
   409  
   410  // EnterRatelimit api ratelimit
   411  func (b *BaseGrpcServer) EnterRatelimit(ip string, method string) uint32 {
   412  	if b.ratelimit == nil {
   413  		return api.ExecuteSuccess
   414  	}
   415  
   416  	// ipRatelimit
   417  	if ok := b.ratelimit.Allow(plugin.IPRatelimit, ip); !ok {
   418  		b.log.Error("[API-Server][GRPC] ip ratelimit is not allow", zap.String("client-ip", ip),
   419  			zap.String("method", method))
   420  		return api.IPRateLimit
   421  	}
   422  	// apiRatelimit
   423  	if ok := b.ratelimit.Allow(plugin.APIRatelimit, method); !ok {
   424  		b.log.Error("[API-Server][GRPC] api rate limit is not allow", zap.String("client-ip", ip),
   425  			zap.String("method", method))
   426  		return api.APIRateLimit
   427  	}
   428  
   429  	return api.ExecuteSuccess
   430  }
   431  
   432  // AllowAccess api allow access
   433  func (b *BaseGrpcServer) AllowAccess(method string) bool {
   434  	if len(b.OpenMethod) == 0 {
   435  		return true
   436  	}
   437  	_, ok := b.OpenMethod[method]
   438  	return ok
   439  }
   440  
   441  // ConvertContext 将GRPC上下文转换成内部上下文
   442  func ConvertContext(ctx context.Context) context.Context {
   443  	var (
   444  		requestID = ""
   445  		userAgent = ""
   446  	)
   447  	meta, exist := metadata.FromIncomingContext(ctx)
   448  	if exist {
   449  		ids := meta["request-id"]
   450  		if len(ids) > 0 {
   451  			requestID = ids[0]
   452  		}
   453  		agents := meta["user-agent"]
   454  		if len(agents) > 0 {
   455  			userAgent = agents[0]
   456  		}
   457  	} else {
   458  		meta = metadata.MD{}
   459  	}
   460  
   461  	var (
   462  		clientIP = ""
   463  		address  = ""
   464  	)
   465  	if pr, ok := peer.FromContext(ctx); ok && pr.Addr != nil {
   466  		address = pr.Addr.String()
   467  		addrSlice := strings.Split(address, ":")
   468  		if len(addrSlice) == 2 {
   469  			clientIP = addrSlice[0]
   470  		}
   471  	}
   472  
   473  	ctx = context.Background()
   474  	ctx = context.WithValue(ctx, utils.ContextGrpcHeader, meta)
   475  	ctx = context.WithValue(ctx, utils.StringContext("request-id"), requestID)
   476  	ctx = context.WithValue(ctx, utils.StringContext("client-ip"), clientIP)
   477  	ctx = context.WithValue(ctx, utils.ContextClientAddress, address)
   478  	ctx = context.WithValue(ctx, utils.StringContext("user-agent"), userAgent)
   479  
   480  	return ctx
   481  }
   482  
// connCounterHook is the connhook hook that Run installs on the server
// listener. It keeps the client-connection counters in the metrics package
// up to date.
type connCounterHook struct {
	// bz selects which module-specific counter (discovery or configuration) is updated.
	bz model.BzModule
}
   486  
   487  func (h *connCounterHook) OnAccept(conn net.Conn) {
   488  	if h.bz == model.DiscoverModule {
   489  		metrics.AddDiscoveryClientConn()
   490  	}
   491  	if h.bz == model.ConfigModule {
   492  		metrics.AddConfigurationClientConn()
   493  	}
   494  	metrics.AddSDKClientConn()
   495  }
   496  
   497  func (h *connCounterHook) OnRelease(conn net.Conn) {
   498  	if h.bz == model.DiscoverModule {
   499  		metrics.RemoveDiscoveryClientConn()
   500  	}
   501  	if h.bz == model.ConfigModule {
   502  		metrics.RemoveConfigurationClientConn()
   503  	}
   504  	metrics.RemoveSDKClientConn()
   505  }
   506  
// OnClose resets the discovery, configuration and SDK client connection
// counters, regardless of bz.
// NOTE(review): presumably invoked by connhook when the listener closes — confirm.
func (h *connCounterHook) OnClose() {
	metrics.ResetDiscoveryClientConn()
	metrics.ResetConfigurationClientConn()
	metrics.ResetSDKClientConn()
}