github.com/godaddy-x/freego@v1.0.156/rpcx/grpcx.go (about)

     1  package rpcx
     2  
     3  import (
     4  	"crypto/tls"
     5  	"crypto/x509"
     6  	"fmt"
     7  	"github.com/godaddy-x/freego/cache/limiter"
     8  	"github.com/godaddy-x/freego/rpcx/pb"
     9  	"github.com/godaddy-x/freego/rpcx/pool"
    10  	"github.com/godaddy-x/freego/utils"
    11  	"github.com/godaddy-x/freego/utils/crypto"
    12  	"github.com/godaddy-x/freego/utils/jwt"
    13  	"github.com/godaddy-x/freego/zlog"
    14  	consulapi "github.com/hashicorp/consul/api"
    15  	"google.golang.org/grpc"
    16  	"google.golang.org/grpc/credentials"
    17  	"google.golang.org/grpc/credentials/insecure"
    18  	"google.golang.org/grpc/keepalive"
    19  	"io/ioutil"
    20  	"net"
    21  	"net/http"
    22  	"sync"
    23  	"time"
    24  )
    25  
    26  var (
    27  	serverDialTLS   grpc.ServerOption
    28  	clientDialTLS   grpc.DialOption
    29  	jwtConfig       *jwt.JwtConfig
    30  	rateLimiterCall func(string) (rate.Option, error)
    31  	selectionCall   func([]*consulapi.ServiceEntry, GRPC) *consulapi.ServiceEntry
    32  	appConfigCall   func(string) (AppConfig, error)
    33  	authorizeTLS    *crypto.RsaObj
    34  	accessToken     = ""
    35  	clientOptions   []grpc.DialOption
    36  	clientConnPools = ClientConnPool{pools: make(map[string]pool.Pool, 0)}
    37  )
    38  
    39  type ClientConnPool struct {
    40  	m sync.Mutex
    41  	//once  concurrent.Once
    42  	pools map[string]pool.Pool
    43  }
    44  
    45  type GRPCManager struct {
    46  	consul       *ConsulManager
    47  	consulDs     string
    48  	authenticate bool
    49  }
    50  
    51  type TlsConfig struct {
    52  	UseTLS    bool
    53  	UseMTLS   bool
    54  	CACrtFile string
    55  	CAKeyFile string
    56  	KeyFile   string
    57  	CrtFile   string
    58  	HostName  string
    59  }
    60  
    61  type AppConfig struct {
    62  	AppId    string
    63  	AppKey   string
    64  	Status   int64
    65  	LastTime int64
    66  }
    67  
    68  type GRPC struct {
    69  	Ds      string                    // consul数据源ds
    70  	Tags    []string                  // 服务标签名称
    71  	Address string                    // 服务地址,为空时自动填充内网IP
    72  	RpcPort int                       // 服务地址端口
    73  	Service string                    // 服务名称
    74  	Cache   int                       // 服务缓存时间/秒
    75  	Timeout int                       // context timeout/毫秒
    76  	AddRPC  func(server *grpc.Server) // grpc注册proto服务
    77  }
    78  
    79  type AuthObject struct {
    80  	AppId     string
    81  	Nonce     string
    82  	Time      int64
    83  	Signature string
    84  }
    85  
    86  func GetGRPCJwtConfig() (*jwt.JwtConfig, error) {
    87  	if len(jwtConfig.TokenKey) == 0 {
    88  		return nil, utils.Error("grpc jwt key is nil")
    89  	}
    90  	return jwtConfig, nil
    91  }
    92  
    93  func GetAuthorizeTLS() (*crypto.RsaObj, error) {
    94  	if authorizeTLS == nil {
    95  		return nil, utils.Error("authorize tls is nil")
    96  	}
    97  	return authorizeTLS, nil
    98  }
    99  
   100  func GetGRPCAppConfig(appid string) (AppConfig, error) {
   101  	if appConfigCall == nil {
   102  		return AppConfig{}, utils.Error("grpc app config call is nil")
   103  	}
   104  	return appConfigCall(appid)
   105  }
   106  
   107  func (self *GRPCManager) CreateJwtConfig(tokenKey string, tokenExp ...int64) {
   108  	if jwtConfig != nil {
   109  		return
   110  	}
   111  	if len(tokenKey) < 32 {
   112  		panic("jwt tokenKey length should be >= 32")
   113  	}
   114  	var exp = int64(3600)
   115  	if len(tokenExp) > 0 && tokenExp[0] >= 3600 {
   116  		exp = tokenExp[0]
   117  	}
   118  	jwtConfig = &jwt.JwtConfig{
   119  		TokenTyp: jwt.JWT,
   120  		TokenAlg: jwt.HS256,
   121  		TokenKey: tokenKey,
   122  		TokenExp: exp,
   123  	}
   124  }
   125  
   126  func (self *GRPCManager) CreateAppConfigCall(fun func(appId string) (AppConfig, error)) {
   127  	if appConfigCall != nil {
   128  		return
   129  	}
   130  	appConfigCall = fun
   131  }
   132  
   133  func (self *GRPCManager) CreateRateLimiterCall(fun func(method string) (rate.Option, error)) {
   134  	if rateLimiterCall != nil {
   135  		return
   136  	}
   137  	rateLimiterCall = fun
   138  }
   139  
   140  func (self *GRPCManager) CreateSelectionCall(fun func([]*consulapi.ServiceEntry, GRPC) *consulapi.ServiceEntry) {
   141  	if selectionCall != nil {
   142  		return
   143  	}
   144  	selectionCall = fun
   145  }
   146  
   147  // CreateAuthorizeTLS If server TLS is used, the certificate server.key is used by default
   148  // Otherwise, the method needs to be explicitly called to set the certificate
   149  func (self *GRPCManager) CreateAuthorizeTLS(keyPath string) {
   150  	if authorizeTLS != nil {
   151  		return
   152  	}
   153  	if len(keyPath) == 0 {
   154  		panic("authorize tls key path is nil")
   155  	}
   156  	obj := &crypto.RsaObj{}
   157  	if err := obj.LoadRsaFile(keyPath); err != nil {
   158  		panic(err)
   159  	}
   160  	authorizeTLS = obj
   161  }
   162  
   163  func (self *GRPCManager) CreateServerTLS(tlsConfig TlsConfig) {
   164  	if serverDialTLS != nil {
   165  		return
   166  	}
   167  	if tlsConfig.UseTLS && tlsConfig.UseMTLS {
   168  		panic("only one UseTLS/UseMTLS can be used")
   169  	}
   170  	if len(tlsConfig.CrtFile) == 0 {
   171  		panic("server.crt file is nil")
   172  	}
   173  	if len(tlsConfig.KeyFile) == 0 {
   174  		panic("server.key file is nil")
   175  	}
   176  	if tlsConfig.UseTLS {
   177  		creds, err := credentials.NewServerTLSFromFile(tlsConfig.CrtFile, tlsConfig.KeyFile)
   178  		if err != nil {
   179  			panic(err)
   180  		}
   181  		serverDialTLS = grpc.Creds(creds)
   182  		self.CreateAuthorizeTLS(tlsConfig.KeyFile)
   183  	}
   184  	if tlsConfig.UseMTLS {
   185  		if len(tlsConfig.CACrtFile) == 0 {
   186  			panic("ca.crt file is nil")
   187  		}
   188  		certPool := x509.NewCertPool()
   189  		ca, err := ioutil.ReadFile(tlsConfig.CACrtFile)
   190  		if err != nil {
   191  			panic(err)
   192  		}
   193  		if ok := certPool.AppendCertsFromPEM(ca); !ok {
   194  			panic("failed to append certs")
   195  		}
   196  		cert, err := tls.LoadX509KeyPair(tlsConfig.CrtFile, tlsConfig.KeyFile)
   197  		if err != nil {
   198  			panic(err)
   199  		}
   200  		// 构建基于 TLS 的 TransportCredentials
   201  		creds := credentials.NewTLS(&tls.Config{
   202  			// 设置证书链,允许包含一个或多个
   203  			Certificates: []tls.Certificate{cert},
   204  			// 要求必须校验客户端的证书 可以根据实际情况选用其他参数
   205  			ClientAuth: tls.RequireAndVerifyClientCert, // NOTE: this is optional!
   206  			// 设置根证书的集合,校验方式使用 ClientAuth 中设定的模式
   207  			ClientCAs: certPool,
   208  		})
   209  		serverDialTLS = grpc.Creds(creds)
   210  		self.CreateAuthorizeTLS(tlsConfig.KeyFile)
   211  	}
   212  }
   213  
   214  func (self *GRPCManager) CreateClientTLS(tlsConfig TlsConfig) {
   215  	if clientDialTLS != nil {
   216  		return
   217  	}
   218  	if tlsConfig.UseTLS && tlsConfig.UseMTLS {
   219  		panic("only one tls mode can be used")
   220  	}
   221  	if len(tlsConfig.CrtFile) == 0 {
   222  		panic("server.crt file is nil")
   223  	}
   224  	if tlsConfig.UseTLS {
   225  		if len(tlsConfig.CrtFile) == 0 {
   226  			panic("server.crt file is nil")
   227  		}
   228  		if len(tlsConfig.HostName) == 0 {
   229  			panic("server host name is nil")
   230  		}
   231  		creds, err := credentials.NewClientTLSFromFile(tlsConfig.CrtFile, tlsConfig.HostName)
   232  		if err != nil {
   233  			panic(err)
   234  		}
   235  		clientDialTLS = grpc.WithTransportCredentials(creds)
   236  	}
   237  	if tlsConfig.UseMTLS {
   238  		if len(tlsConfig.CACrtFile) == 0 {
   239  			panic("ca.crt file is nil")
   240  		}
   241  		if len(tlsConfig.CrtFile) == 0 {
   242  			panic("client.crt file is nil")
   243  		}
   244  		if len(tlsConfig.KeyFile) == 0 {
   245  			panic("client.key file is nil")
   246  		}
   247  		if len(tlsConfig.HostName) == 0 {
   248  			panic("server host name is nil")
   249  		}
   250  		// 加载客户端证书
   251  		cert, err := tls.LoadX509KeyPair(tlsConfig.CrtFile, tlsConfig.KeyFile)
   252  		if err != nil {
   253  			panic(err)
   254  		}
   255  		// 构建CertPool以校验服务端证书有效性
   256  		certPool := x509.NewCertPool()
   257  		ca, err := ioutil.ReadFile(tlsConfig.CACrtFile)
   258  		if err != nil {
   259  			panic(err)
   260  		}
   261  		if ok := certPool.AppendCertsFromPEM(ca); !ok {
   262  			panic("failed to append ca certs")
   263  		}
   264  		creds := credentials.NewTLS(&tls.Config{
   265  			Certificates: []tls.Certificate{cert},
   266  			ServerName:   tlsConfig.HostName, // NOTE: this is required!
   267  			RootCAs:      certPool,
   268  		})
   269  		clientDialTLS = grpc.WithTransportCredentials(creds)
   270  	}
   271  }
   272  
   273  func RunServer(consulDs string, authenticate bool, objects ...*GRPC) {
   274  	if len(objects) == 0 {
   275  		panic("rpc objects is nil...")
   276  	}
   277  	c, err := NewConsul(consulDs)
   278  	if err != nil {
   279  		panic(err)
   280  	}
   281  	self := &GRPCManager{consul: c, consulDs: consulDs, authenticate: authenticate}
   282  	services, err := self.consul.GetAllService("")
   283  	if err != nil {
   284  		panic(err)
   285  	}
   286  	if err != nil {
   287  		panic(err)
   288  	}
   289  	opts := []grpc.ServerOption{
   290  		grpc.InitialWindowSize(pool.InitialWindowSize),
   291  		grpc.InitialConnWindowSize(pool.InitialConnWindowSize),
   292  		grpc.MaxSendMsgSize(pool.MaxSendMsgSize),
   293  		grpc.MaxRecvMsgSize(pool.MaxRecvMsgSize),
   294  		grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
   295  			PermitWithoutStream: true,
   296  		}),
   297  		grpc.KeepaliveParams(keepalive.ServerParameters{
   298  			Time:    pool.KeepAliveTime,
   299  			Timeout: pool.KeepAliveTimeout,
   300  		}),
   301  		grpc.UnaryInterceptor(self.ServerInterceptor),
   302  	}
   303  	if serverDialTLS != nil {
   304  		opts = append(opts, serverDialTLS)
   305  	}
   306  	grpcServer := grpc.NewServer(opts...)
   307  	for _, object := range objects {
   308  		address := utils.GetLocalIP()
   309  		port := self.consul.Config.RpcPort
   310  		if object.RpcPort > 0 {
   311  			port = object.RpcPort
   312  		}
   313  		if len(address) == 0 {
   314  			panic("local address reading failed")
   315  		}
   316  		if len(object.Address) > 0 {
   317  			address = object.Address
   318  		}
   319  		if len(object.Service) == 0 || len(object.Service) > 100 {
   320  			panic("rpc service invalid")
   321  		}
   322  		if self.consul.CheckService(services, object.Service, address) {
   323  			zlog.Println(utils.AddStr("grpc service [", object.Service, "][", address, "] exist, skip..."))
   324  			object.AddRPC(grpcServer)
   325  			continue
   326  		}
   327  		registration := new(consulapi.AgentServiceRegistration)
   328  		registration.ID = utils.GetUUID()
   329  		registration.Tags = object.Tags
   330  		registration.Name = object.Service
   331  		registration.Address = address
   332  		registration.Port = port
   333  		registration.Meta = make(map[string]string, 0)
   334  		registration.Check = &consulapi.AgentServiceCheck{
   335  			HTTP:                           fmt.Sprintf("http://%s:%d%s", registration.Address, self.consul.Config.CheckPort, self.consul.Config.CheckPath),
   336  			Timeout:                        self.consul.Config.Timeout,
   337  			Interval:                       self.consul.Config.Interval,
   338  			DeregisterCriticalServiceAfter: self.consul.Config.DestroyAfter,
   339  		}
   340  		zlog.Println(utils.AddStr("grpc service [", registration.Name, "][", registration.Address, "] added successful"))
   341  		if err := self.consul.Consulx.Agent().ServiceRegister(registration); err != nil {
   342  			panic(utils.AddStr("grpc service [", object.Service, "] add failed: ", err.Error()))
   343  		}
   344  		object.AddRPC(grpcServer)
   345  	}
   346  	go func() {
   347  		http.HandleFunc(self.consul.Config.CheckPath, self.consul.HealthCheck)
   348  		if err := http.ListenAndServe(fmt.Sprintf(":%d", self.consul.Config.CheckPort), nil); err != nil {
   349  			panic(err)
   350  		}
   351  	}()
   352  	l, err := net.Listen(self.consul.Config.Protocol, utils.AddStr(":", utils.AnyToStr(self.consul.Config.RpcPort)))
   353  	if err != nil {
   354  		panic(err)
   355  	}
   356  	zlog.Println(utils.AddStr("grpc server【", utils.AddStr(":", utils.AnyToStr(self.consul.Config.RpcPort)), "】has been started successful"))
   357  	if err := grpcServer.Serve(l); err != nil {
   358  		panic(err)
   359  	}
   360  }
   361  
   362  // RunClient Important: ensure that the service starts only once
   363  // JWT Token expires in 1 hour
   364  // The remaining 1200s will be automatically renewed and detected every 15s
   365  func RunClient(appId ...string) {
   366  	if len(clientOptions) == 0 {
   367  		c, err := NewConsul()
   368  		if err != nil {
   369  			panic(err)
   370  		}
   371  		client := &GRPCManager{consul: c, consulDs: ""}
   372  		clientOptions = append(clientOptions, grpc.WithInitialWindowSize(pool.InitialWindowSize))
   373  		clientOptions = append(clientOptions, grpc.WithInitialConnWindowSize(pool.InitialConnWindowSize))
   374  		clientOptions = append(clientOptions, grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(pool.MaxSendMsgSize)))
   375  		clientOptions = append(clientOptions, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(pool.MaxRecvMsgSize)))
   376  		clientOptions = append(clientOptions, grpc.WithKeepaliveParams(keepalive.ClientParameters{
   377  			Time:                pool.KeepAliveTime,
   378  			Timeout:             pool.KeepAliveTimeout,
   379  			PermitWithoutStream: true,
   380  		}))
   381  		clientOptions = append(clientOptions, grpc.WithUnaryInterceptor(client.ClientInterceptor))
   382  		if clientDialTLS != nil {
   383  			clientOptions = append(clientOptions, clientDialTLS)
   384  		} else {
   385  			clientOptions = append(clientOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))
   386  		}
   387  	}
   388  	if len(appId) == 0 || len(appId[0]) == 0 {
   389  		return
   390  	}
   391  	var err error
   392  	var expired int64
   393  	for {
   394  		accessToken, expired, err = callLogin(appId[0])
   395  		if err != nil {
   396  			zlog.Error("rpc login failed", 0, zlog.AddError(err))
   397  			time.Sleep(5 * time.Second)
   398  			continue
   399  		}
   400  		break
   401  	}
   402  	go renewClientToken(appId[0], expired)
   403  }
   404  
   405  func callLogin(appId string) (string, int64, error) {
   406  	appConfig, err := GetGRPCAppConfig(appId)
   407  	if err != nil {
   408  		return "", 0, err
   409  	}
   410  	if len(appConfig.AppKey) == 0 {
   411  		return "", 0, utils.Error("rpc appConfig key is nil")
   412  	}
   413  	authObject := &AuthObject{
   414  		AppId: appId,
   415  		Nonce: utils.RandStr(16),
   416  		Time:  utils.UnixSecond(),
   417  	}
   418  	authObject.Signature = utils.HMAC_SHA256(utils.AddStr(authObject.AppId, authObject.Nonce, authObject.Time), appConfig.AppKey, true)
   419  	b64, err := utils.ToJsonBase64(authObject)
   420  	if err != nil {
   421  		return "", 0, err
   422  	}
   423  	conn, err := NewClientConn(GRPC{Service: "PubWorker"})
   424  	if err != nil {
   425  		return "", 0, err
   426  	}
   427  	conn.NewContext(60000 * time.Millisecond)
   428  	defer conn.Close()
   429  	// load public key
   430  	pub, err := pb.NewPubWorkerClient(conn.Value()).PublicKey(conn.Context(), &pb.PublicKeyReq{})
   431  	if err != nil {
   432  		return "", 0, err
   433  	}
   434  	rsaObj := &crypto.RsaObj{}
   435  	if err := rsaObj.LoadRsaPemFileBase64(pub.PublicKey); err != nil {
   436  		return "", 0, err
   437  	}
   438  	content, err := rsaObj.Encrypt(utils.Str2Bytes(b64))
   439  	if err != nil {
   440  		return "", 0, err
   441  	}
   442  	req := &pb.AuthorizeReq{
   443  		Message: content,
   444  	}
   445  	res, err := pb.NewPubWorkerClient(conn.Value()).Authorize(conn.Context(), req)
   446  	if err != nil {
   447  		return "", 0, err
   448  	}
   449  	return res.Token, res.Expired, nil
   450  }
   451  
   452  func renewClientToken(appid string, expired int64) {
   453  	for {
   454  		//zlog.Warn("detecting rpc token expiration", 0, zlog.Int64("countDown", expired-utils.TimeSecond()-timeDifference))
   455  		if expired-utils.UnixSecond() > timeDifference { // TODO token过期时间大于2400s则忽略,每15s检测一次
   456  			time.Sleep(15 * time.Second)
   457  			continue
   458  		}
   459  		RunClient(appid)
   460  		zlog.Info("rpc token renewal succeeded", 0)
   461  		return
   462  	}
   463  }
   464  
   465  func NewClientConn(object GRPC) (pool.Conn, error) {
   466  	if len(object.Service) == 0 || len(object.Service) > 100 {
   467  		return nil, utils.Error("call service invalid")
   468  	}
   469  	var tag string
   470  	var timeout int
   471  	if object.Timeout <= 0 {
   472  		timeout = 60000
   473  	}
   474  	if len(object.Tags) > 0 {
   475  		tag = object.Tags[0]
   476  	}
   477  	c, err := NewConsul(object.Ds)
   478  	if err != nil {
   479  		return nil, err
   480  	}
   481  	services, err := c.GetCacheService(object.Service, tag, object.Cache)
   482  	if err != nil {
   483  		return nil, utils.Error("query service [", object.Service, "] failed: ", err)
   484  	}
   485  	var service *consulapi.AgentService
   486  	if selectionCall == nil { // 选取规则为空则默认随机
   487  		if len(services) == 1 {
   488  			service = services[0].Service
   489  		} else {
   490  			service = services[utils.ModRand(len(services))].Service
   491  		}
   492  	} else {
   493  		service = selectionCall(services, object).Service
   494  	}
   495  	return clientConnPools.getClientConn(utils.AddStr(service.Address, ":", service.Port), timeout)
   496  }
   497  
   498  //func AddClientPool(host string) error {
   499  //	_, err := clientConnPools.readyPool(host)
   500  //	if err != nil {
   501  //		return err
   502  //	}
   503  //	return nil
   504  //}
   505  
   506  func (self *ClientConnPool) getClientConn(host string, timeout int) (conn pool.Conn, err error) {
   507  	p, b := self.pools[host]
   508  	if !b || p == nil {
   509  		zlog.Warn("client pool host creating", 0, zlog.String("host", host))
   510  		p, err = self.readyPool(host)
   511  		if err != nil {
   512  			zlog.Error("client pool ready failed", 0, zlog.AddError(err))
   513  			return nil, err
   514  		}
   515  		if p == nil {
   516  			return nil, utils.Error("init client pool failed")
   517  		}
   518  	}
   519  	conn, err = p.Get()
   520  	if err != nil {
   521  		return nil, err
   522  	}
   523  	conn.NewContext(time.Duration(timeout) * time.Millisecond)
   524  	return conn, nil
   525  }
   526  
   527  func (self *ClientConnPool) readyPool(host string) (pool.Pool, error) {
   528  	self.m.Lock()
   529  	defer self.m.Unlock()
   530  	p, b := self.pools[host]
   531  	if !b || p == nil {
   532  		pool, err := pool.NewPool(pool.DefaultOptions, pool.ConnConfig{Address: host, Timeout: 10, Opts: clientOptions})
   533  		if err != nil {
   534  			return nil, err
   535  		}
   536  		self.pools[host] = pool
   537  		zlog.Info("client connection pool create successful", 0, zlog.String("host", host))
   538  		return pool, nil
   539  	}
   540  	//_, err := self.once.Do(func() (interface{}, error) {
   541  	//	pool, err := pool.NewPool(pool.DefaultOptions, pool.ConnConfig{Address: host, Timeout: 10, Opts: clientOptions})
   542  	//	if err != nil {
   543  	//		return nil, err
   544  	//	}
   545  	//	self.pools[host] = pool
   546  	//	zlog.Info("client connection pool create successful", 0, zlog.String("host", host))
   547  	//	return pool, nil
   548  	//})
   549  	//if err != nil {
   550  	//	return nil, err
   551  	//}
   552  	p, b = self.pools[host]
   553  	if !b || p == nil {
   554  		return nil, utils.Error("pool connection create failed")
   555  		//panic("pool connection not found")
   556  	}
   557  	return p, nil
   558  }