github.com/lastbackend/toolkit@v0.0.0-20241020043710-cafa37b95aad/pkg/client/grpc/client.go (about)

     1  /*
     2  Copyright [2014] - [2023] The Last.Backend authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package grpc
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/lastbackend/toolkit/pkg/client"
    26  	"github.com/lastbackend/toolkit/pkg/client/grpc/resolver"
    27  	"github.com/lastbackend/toolkit/pkg/client/grpc/resolver/file"
    28  	"github.com/lastbackend/toolkit/pkg/client/grpc/resolver/local"
    29  	"github.com/lastbackend/toolkit/pkg/context/metadata"
    30  	"github.com/lastbackend/toolkit/pkg/runtime"
    31  	"github.com/lastbackend/toolkit/pkg/util/backoff"
    32  	"google.golang.org/grpc"
    33  	"google.golang.org/grpc/codes"
    34  	"google.golang.org/grpc/credentials/insecure"
    35  	"google.golang.org/grpc/encoding"
    36  	grpc_md "google.golang.org/grpc/metadata"
    37  	"google.golang.org/grpc/status"
    38  )
    39  
    40  func init() {
    41  	encoding.RegisterCodec(protoCodec{})
    42  }
    43  
    44  const (
    45  	// default prefix
    46  	defaultPrefix = "GRPC_CLIENT"
    47  	// default pool name
    48  	defaultPoolName = ""
    49  	// default GRPC port
    50  	defaultPort = 9000
    51  	// The default number of times a request is tried
    52  	defaultRetries = 0 * time.Second
    53  	// The default request timeout
    54  	defaultRequestTimeout = 15 * time.Second
    55  	// The connection pool size
    56  	defaultPoolSize = 100
    57  	// The connection pool ttl
    58  	defaultPoolTTL = time.Minute
    59  	// DefaultMaxRecvMsgSize maximum message that client can receive (16 MB).
    60  	defaultMaxRecvMsgSize = 1024 * 1024 * 16
    61  	// DefaultMaxSendMsgSize maximum message that client can send (16 MB).
    62  	defaultMaxSendMsgSize = 1024 * 1024 * 16
    63  )
    64  
// grpcClient is the client.GRPCClient implementation returned by NewClient.
// It resolves service names to addresses through a resolver.Resolver and
// obtains connections from pools kept in the pool map.
//
// NOTE(review): no mutex guards resolver or pool; concurrent use relies on
// callers — confirm whether the client is meant to be shared across goroutines.
type grpcClient struct {
	// ctx is the context captured at construction; Conn passes it to getConn.
	ctx      context.Context
	// runtime supplies configuration (Config().Parse) and is used to build
	// the default local resolver.
	runtime  runtime.Runtime
	// resolver maps service names to routes; may be nil until set by
	// NewClient, SetResolver or the lazy fallback in getResolver.
	resolver resolver.Resolver

	// opts holds defaultOptions() overridden by configuration in NewClient.
	opts Options
	// pool maps a pool key to its connection pool: defaultPoolName is created
	// by NewClient, and Conn adds one entry per service name.
	pool map[string]*pool
}
    73  
    74  func NewClient(ctx context.Context, runtime runtime.Runtime) client.GRPCClient {
    75  
    76  	client := &grpcClient{
    77  		ctx:     ctx,
    78  		runtime: runtime,
    79  		opts:    defaultOptions(),
    80  		pool:    make(map[string]*pool, 0),
    81  	}
    82  
    83  	client.pool[defaultPoolName] = newPool()
    84  	runtime.Config().Parse(&client.opts, defaultPrefix)
    85  
    86  	if client.opts.Resolver == "local" {
    87  		client.resolver = local.NewResolver(runtime)
    88  	}
    89  
    90  	if client.opts.Resolver == "file" {
    91  		client.resolver = file.NewResolver(runtime)
    92  	}
    93  
    94  	return client
    95  }
    96  
    97  func (c *grpcClient) Conn(service string) (grpc.ClientConnInterface, error) {
    98  	var p *pool
    99  	p, ok := c.pool[service]
   100  	if !ok {
   101  		p = newPool()
   102  		c.pool[service] = p
   103  
   104  	}
   105  
   106  	routes, err := c.getResolver().Lookup(service)
   107  	if err != nil && !strings.HasSuffix(err.Error(), "route not found") {
   108  		return nil, status.Error(codes.Unavailable, err.Error())
   109  	}
   110  
   111  	addresses := routes.Addresses()
   112  	if len(addresses) == 0 {
   113  		addresses = []string{fmt.Sprintf(":%d", defaultPort)}
   114  	}
   115  
   116  	next, err := c.opts.Selector.Select(addresses)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  
   121  	return p.getConn(c.ctx, next(), c.makeGrpcDialOptions()...)
   122  
   123  }
   124  
   125  func (c *grpcClient) GetResolver() resolver.Resolver {
   126  	return c.resolver
   127  }
   128  
   129  func (c *grpcClient) SetResolver(resolver resolver.Resolver) {
   130  	c.resolver = resolver
   131  }
   132  
   133  func (c *grpcClient) Call(ctx context.Context, service, method string, body, resp interface{}, opts ...client.GRPCCallOption) error {
   134  	if body == nil {
   135  		return status.Error(codes.Internal, "request is nil")
   136  	}
   137  	if resp == nil {
   138  		return status.Error(codes.Internal, "response is nil")
   139  	}
   140  
   141  	callOpts := c.opts.CallOptions
   142  	for _, opt := range opts {
   143  		opt(&callOpts)
   144  	}
   145  
   146  	ctx, cancel := context.WithTimeout(ctx, callOpts.RequestTimeout)
   147  	defer cancel()
   148  
   149  	headers := c.makeHeaders(ctx, service, callOpts)
   150  	req := client.NewGRPCRequest(service, method, body, headers)
   151  
   152  	routes, err := c.getResolver().Lookup(req.Service())
   153  	if err != nil && !strings.HasSuffix(err.Error(), "route not found") {
   154  		return status.Error(codes.Unavailable, err.Error())
   155  	}
   156  
   157  	addresses := routes.Addresses()
   158  	if len(addresses) == 0 {
   159  		addresses = []string{fmt.Sprintf(":%d", defaultPort)}
   160  	}
   161  
   162  	next, err := c.opts.Selector.Select(addresses)
   163  	if err != nil {
   164  		return err
   165  	}
   166  
   167  	invokeFunc := c.invoke
   168  
   169  	b := &backoff.Backoff{
   170  		Max: callOpts.Retries,
   171  	}
   172  
   173  	for {
   174  		select {
   175  		case <-ctx.Done():
   176  			return status.Error(codes.Canceled, ctx.Err().Error())
   177  		default:
   178  			err := invokeFunc(ctx, next(), req, resp, callOpts)
   179  			if err != nil {
   180  				d := b.Duration()
   181  				if d.Seconds() >= callOpts.Retries.Seconds() {
   182  					return err
   183  				}
   184  				time.Sleep(d)
   185  				continue
   186  			}
   187  			b.Reset()
   188  			return nil
   189  		}
   190  	}
   191  
   192  }
   193  
   194  func (c *grpcClient) Stream(ctx context.Context, service, method string, body interface{}, opts ...client.GRPCCallOption) (grpc.ClientStream, error) {
   195  
   196  	callOpts := c.opts.CallOptions
   197  	for _, opt := range opts {
   198  		opt(&callOpts)
   199  	}
   200  
   201  	streamFunc := c.stream
   202  
   203  	headers := c.makeHeaders(ctx, service, callOpts)
   204  	req := client.NewGRPCRequest(service, method, body, headers)
   205  
   206  	routes, err := c.getResolver().Lookup(req.Service())
   207  	if err != nil && !strings.HasSuffix(err.Error(), "route not found") {
   208  		return nil, status.Error(codes.Unavailable, err.Error())
   209  	}
   210  
   211  	addresses := routes.Addresses()
   212  	if len(addresses) == 0 {
   213  		addresses = []string{fmt.Sprintf(":%d", defaultPort)}
   214  	}
   215  
   216  	next, err := c.opts.Selector.Select(addresses)
   217  	if err != nil {
   218  		return nil, err
   219  	}
   220  
   221  	b := &backoff.Backoff{
   222  		Max: callOpts.Retries,
   223  	}
   224  
   225  	for {
   226  		select {
   227  		case <-ctx.Done():
   228  			return nil, status.Error(codes.Canceled, ctx.Err().Error())
   229  		default:
   230  			s, err := streamFunc(ctx, next(), req, callOpts)
   231  			if err != nil {
   232  				d := b.Duration()
   233  				if d.Seconds() >= callOpts.Retries.Seconds() {
   234  					return nil, err
   235  				}
   236  				time.Sleep(d)
   237  				continue
   238  			}
   239  			b.Reset()
   240  			return s, nil
   241  		}
   242  	}
   243  }
   244  
   245  func (c *grpcClient) invoke(ctx context.Context, addr string, req *client.GRPCRequest, rsp interface{}, opts client.GRPCCallOptions) error {
   246  
   247  	md := grpc_md.New(req.Headers())
   248  	ctx = grpc_md.NewOutgoingContext(ctx, md)
   249  
   250  	var headers grpc_md.MD
   251  
   252  	var gErr error
   253  	conn, err := c.pool[defaultPoolName].getConn(ctx, addr, c.makeGrpcDialOptions()...)
   254  	if err != nil {
   255  		return status.Error(codes.Internal, fmt.Sprintf("Failed sending request: %v", err))
   256  	}
   257  	defer conn.pool.release(addr, conn, gErr)
   258  
   259  	grpcOpts := c.makeGrpcCallOptions(opts)
   260  	grpcOpts = append(grpcOpts, grpc.Header(&headers))
   261  
   262  	ch := make(chan error, 1)
   263  	go func() {
   264  		ch <- conn.Invoke(ctx, req.Method(), req.Body(), rsp, grpcOpts...)
   265  		for k, v := range headers {
   266  			if len(v) > 0 && opts.Headers != nil {
   267  				opts.Headers[k] = v[0]
   268  			}
   269  		}
   270  	}()
   271  
   272  	select {
   273  	case err := <-ch:
   274  		gErr = err
   275  	case <-ctx.Done():
   276  		gErr = status.Error(codes.Canceled, ctx.Err().Error())
   277  	}
   278  
   279  	return gErr
   280  }
   281  
   282  func (c *grpcClient) stream(ctx context.Context, addr string, req *client.GRPCRequest, opts client.GRPCCallOptions) (grpc.ClientStream, error) {
   283  
   284  	md := grpc_md.New(req.Headers())
   285  	ctx = grpc_md.NewOutgoingContext(ctx, md)
   286  	ctx, cancel := context.WithCancel(ctx)
   287  
   288  	cc, err := c.pool[defaultPoolName].getConn(ctx, addr, c.makeGrpcDialOptions()...)
   289  	if err != nil {
   290  		cancel()
   291  		return nil, status.Error(codes.Internal, err.Error())
   292  	}
   293  
   294  	desc := &grpc.StreamDesc{
   295  		StreamName:    req.Method(),
   296  		ClientStreams: true,
   297  		ServerStreams: true,
   298  	}
   299  
   300  	st, err := cc.NewStream(ctx, desc, req.Method(), c.makeGrpcCallOptions(opts)...)
   301  	if err != nil {
   302  		cancel()
   303  		c.pool[defaultPoolName].release(addr, cc, err)
   304  		return nil, status.Error(codes.Canceled, err.Error())
   305  	}
   306  
   307  	s := &stream{
   308  		ClientStream: st,
   309  		context:      ctx,
   310  		request:      req,
   311  		conn:         cc,
   312  		close: func(err error) {
   313  			if err != nil {
   314  				cancel()
   315  			}
   316  			c.pool[defaultPoolName].release(addr, cc, err)
   317  		},
   318  	}
   319  
   320  	// wait for error response
   321  	ch := make(chan error, 1)
   322  
   323  	go func() {
   324  		// send the first message
   325  		ch <- st.SendMsg(req.Body())
   326  	}()
   327  
   328  	var grr error
   329  
   330  	select {
   331  	case err := <-ch:
   332  		grr = err
   333  	case <-ctx.Done():
   334  		grr = ctx.Err()
   335  	}
   336  
   337  	if grr != nil {
   338  		_ = st.CloseSend()
   339  		return nil, grr
   340  	}
   341  
   342  	return s, nil
   343  }
   344  
   345  func (c *grpcClient) makeGrpcCallOptions(opts client.GRPCCallOptions) []grpc.CallOption {
   346  	grpcCallOptions := make([]grpc.CallOption, 0)
   347  
   348  	//if len(opts.Headers) > 0 {
   349  	//  var header = grpc_md.New(opts.Headers)
   350  	//  grpcCallOptions = append(grpcCallOptions, grpc.Header(&header))
   351  	//}
   352  
   353  	if opts.MaxCallRecvMsgSize > 0 {
   354  		grpcCallOptions = append(grpcCallOptions, grpc.MaxCallRecvMsgSize(opts.MaxCallRecvMsgSize))
   355  	}
   356  	if opts.MaxCallSendMsgSize > 0 {
   357  		grpcCallOptions = append(grpcCallOptions, grpc.MaxCallSendMsgSize(opts.MaxCallSendMsgSize))
   358  	}
   359  	if opts.MaxRetryRPCBufferSize > 0 {
   360  		grpcCallOptions = append(grpcCallOptions, grpc.MaxRetryRPCBufferSize(opts.MaxRetryRPCBufferSize))
   361  	}
   362  	if opts.CallContentSubtype != "" {
   363  		grpcCallOptions = append(grpcCallOptions, grpc.CallContentSubtype(opts.CallContentSubtype))
   364  	}
   365  
   366  	return grpcCallOptions
   367  }
   368  
   369  func (c *grpcClient) makeGrpcDialOptions() []grpc.DialOption {
   370  	grpcDialOptions := make([]grpc.DialOption, 0)
   371  
   372  	// TODO: implement auths
   373  	grpcDialOptions = append(grpcDialOptions, grpc.WithTransportCredentials(insecure.NewCredentials()))
   374  
   375  	if c.opts.MaxRecvMsgSize != nil || c.opts.MaxSendMsgSize != nil {
   376  		var defaultCallOpts = make([]grpc.CallOption, 0)
   377  		if c.opts.MaxRecvMsgSize != nil {
   378  			defaultCallOpts = append(defaultCallOpts, grpc.MaxCallRecvMsgSize(*c.opts.MaxRecvMsgSize))
   379  		}
   380  		if c.opts.MaxSendMsgSize != nil {
   381  			defaultCallOpts = append(defaultCallOpts, grpc.MaxCallSendMsgSize(*c.opts.MaxSendMsgSize))
   382  		}
   383  		grpcDialOptions = append(grpcDialOptions, grpc.WithDefaultCallOptions(defaultCallOpts...))
   384  	}
   385  	if c.opts.WriteBufferSize != nil {
   386  		grpcDialOptions = append(grpcDialOptions, grpc.WithWriteBufferSize(*c.opts.WriteBufferSize))
   387  	}
   388  	if c.opts.ReadBufferSize != nil {
   389  		grpcDialOptions = append(grpcDialOptions, grpc.WithReadBufferSize(*c.opts.ReadBufferSize))
   390  	}
   391  	if c.opts.InitialWindowSize != nil {
   392  		grpcDialOptions = append(grpcDialOptions, grpc.WithInitialWindowSize(*c.opts.InitialWindowSize))
   393  	}
   394  	if c.opts.InitialConnWindowSize != nil {
   395  		grpcDialOptions = append(grpcDialOptions, grpc.WithInitialConnWindowSize(*c.opts.InitialConnWindowSize))
   396  	}
   397  	if c.opts.UserAgent != nil {
   398  		grpcDialOptions = append(grpcDialOptions, grpc.WithUserAgent(*c.opts.UserAgent))
   399  	}
   400  	if c.opts.MaxHeaderListSize != nil {
   401  		grpcDialOptions = append(grpcDialOptions, grpc.WithMaxHeaderListSize(uint32(*c.opts.MaxHeaderListSize)))
   402  	}
   403  
   404  	return grpcDialOptions
   405  }
   406  
   407  func (c *grpcClient) makeHeaders(ctx context.Context, service string, opts client.GRPCCallOptions) map[string]string {
   408  	var headers = make(map[string]string, 0)
   409  
   410  	if md, ok := metadata.LoadFromContext(ctx); ok {
   411  		for k, v := range md {
   412  			headers[strings.ToLower(k)] = v
   413  		}
   414  	}
   415  	if opts.Headers != nil {
   416  		for k, v := range opts.Headers {
   417  			headers[strings.ToLower(k)] = v
   418  		}
   419  	}
   420  
   421  	if _, ok := headers["content-type"]; !ok {
   422  		headers["content-type"] = c.opts.ContentType
   423  	}
   424  
   425  	headers["x-service-name"] = service
   426  
   427  	return headers
   428  }
   429  
   430  func (c *grpcClient) getResolver() resolver.Resolver {
   431  	if c.resolver == nil {
   432  		c.resolver = local.NewResolver(c.runtime)
   433  	}
   434  
   435  	return c.resolver
   436  }