github.com/cloudwego/kitex@v0.9.0/client/middlewares.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package client
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"time"
    24  
    25  	"github.com/apache/thrift/lib/go/thrift"
    26  
    27  	"github.com/cloudwego/kitex/internal"
    28  	"github.com/cloudwego/kitex/pkg/discovery"
    29  	"github.com/cloudwego/kitex/pkg/endpoint"
    30  	"github.com/cloudwego/kitex/pkg/event"
    31  	"github.com/cloudwego/kitex/pkg/kerrors"
    32  	"github.com/cloudwego/kitex/pkg/klog"
    33  	"github.com/cloudwego/kitex/pkg/loadbalance/lbcache"
    34  	"github.com/cloudwego/kitex/pkg/proxy"
    35  	"github.com/cloudwego/kitex/pkg/remote"
    36  	"github.com/cloudwego/kitex/pkg/remote/codec/protobuf"
    37  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    38  	"github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo"
    39  )
    40  
    41  const maxRetry = 6
    42  
    43  func newProxyMW(prx proxy.ForwardProxy) endpoint.Middleware {
    44  	// If you want to customize the processing logic of proxy middleware,
    45  	// you can implement this interface to replace the default implementation.
    46  	if p, ok := prx.(proxy.WithMiddleware); ok {
    47  		return p.ProxyMiddleware()
    48  	}
    49  	return func(next endpoint.Endpoint) endpoint.Endpoint {
    50  		return func(ctx context.Context, request, response interface{}) error {
    51  			err := prx.ResolveProxyInstance(ctx)
    52  			if err != nil {
    53  				return err
    54  			}
    55  			err = next(ctx, request, response)
    56  			return err
    57  		}
    58  	}
    59  }
    60  
    61  func discoveryEventHandler(name string, bus event.Bus, queue event.Queue) func(d *discovery.Change) {
    62  	return func(d *discovery.Change) {
    63  		now := time.Now()
    64  		bus.Dispatch(&event.Event{
    65  			Name:  name,
    66  			Time:  now,
    67  			Extra: d,
    68  		})
    69  		queue.Push(&event.Event{
    70  			Name: name,
    71  			Time: now,
    72  			Extra: map[string]interface{}{
    73  				"Added":   wrapInstances(d.Added),
    74  				"Updated": wrapInstances(d.Updated),
    75  				"Removed": wrapInstances(d.Removed),
    76  			},
    77  		})
    78  	}
    79  }
    80  
    81  // newResolveMWBuilder creates a middleware for service discovery.
    82  // This middleware selects an appropriate instance based on the resolver and loadbalancer given.
    83  // If retryable error is encountered, it will retry until timeout or an unretryable error is returned.
    84  func newResolveMWBuilder(lbf *lbcache.BalancerFactory) endpoint.MiddlewareBuilder {
    85  	return func(ctx context.Context) endpoint.Middleware {
    86  		return func(next endpoint.Endpoint) endpoint.Endpoint {
    87  			return func(ctx context.Context, request, response interface{}) error {
    88  				rpcInfo := rpcinfo.GetRPCInfo(ctx)
    89  
    90  				dest := rpcInfo.To()
    91  				if dest == nil {
    92  					return kerrors.ErrNoDestService
    93  				}
    94  
    95  				remote := remoteinfo.AsRemoteInfo(dest)
    96  				if remote == nil {
    97  					err := fmt.Errorf("unsupported target EndpointInfo type: %T", dest)
    98  					return kerrors.ErrInternalException.WithCause(err)
    99  				}
   100  				if remote.GetInstance() != nil {
   101  					return next(ctx, request, response)
   102  				}
   103  				lb, err := lbf.Get(ctx, dest)
   104  				if err != nil {
   105  					return kerrors.ErrServiceDiscovery.WithCause(err)
   106  				}
   107  
   108  				var lastErr error
   109  				for i := 0; i < maxRetry; i++ {
   110  					select {
   111  					case <-ctx.Done():
   112  						return kerrors.ErrRPCTimeout
   113  					default:
   114  					}
   115  
   116  					// we always need to get a new picker every time, because when downstream update deployment,
   117  					// we may get an old picker that include all outdated instances which will cause connect always failed.
   118  					picker := lb.GetPicker()
   119  					ins := picker.Next(ctx, request)
   120  					if ins == nil {
   121  						err = kerrors.ErrNoMoreInstance.WithCause(fmt.Errorf("last error: %w", lastErr))
   122  					} else {
   123  						remote.SetInstance(ins)
   124  						// TODO: generalize retry strategy
   125  						err = next(ctx, request, response)
   126  					}
   127  					if r, ok := picker.(internal.Reusable); ok {
   128  						r.Recycle()
   129  					}
   130  					if err == nil {
   131  						return nil
   132  					}
   133  					if retryable(err) {
   134  						lastErr = err
   135  						klog.CtxWarnf(ctx, "KITEX: auto retry retryable error, retry=%d error=%s", i+1, err.Error())
   136  						continue
   137  					}
   138  					return err
   139  				}
   140  				return lastErr
   141  			}
   142  		}
   143  	}
   144  }
   145  
   146  // newIOErrorHandleMW provides a hook point for io error handling.
   147  func newIOErrorHandleMW(errHandle func(context.Context, error) error) endpoint.Middleware {
   148  	if errHandle == nil {
   149  		errHandle = DefaultClientErrorHandler
   150  	}
   151  	return func(next endpoint.Endpoint) endpoint.Endpoint {
   152  		return func(ctx context.Context, request, response interface{}) (err error) {
   153  			err = next(ctx, request, response)
   154  			if err == nil {
   155  				return
   156  			}
   157  			return errHandle(ctx, err)
   158  		}
   159  	}
   160  }
   161  
   162  // DefaultClientErrorHandler is Default ErrorHandler for client
   163  // when no ErrorHandler is specified with Option `client.WithErrorHandler`, this ErrorHandler will be injected.
   164  // for thrift、KitexProtobuf, >= v0.4.0 wrap protocol error to TransError, which will be more friendly.
   165  func DefaultClientErrorHandler(ctx context.Context, err error) error {
   166  	switch err.(type) {
   167  	// for thrift、KitexProtobuf, actually check *remote.TransError is enough
   168  	case *remote.TransError, thrift.TApplicationException, protobuf.PBError:
   169  		// Add 'remote' prefix to distinguish with local err.
   170  		// Because it cannot make sure which side err when decode err happen
   171  		return kerrors.ErrRemoteOrNetwork.WithCauseAndExtraMsg(err, "remote")
   172  	}
   173  	return kerrors.ErrRemoteOrNetwork.WithCause(err)
   174  }
   175  
   176  // ClientErrorHandlerWithAddr is ErrorHandler for client, which will add remote addr info into error
   177  func ClientErrorHandlerWithAddr(ctx context.Context, err error) error {
   178  	addrStr := getRemoteAddr(ctx)
   179  	switch err.(type) {
   180  	// for thrift、KitexProtobuf, actually check *remote.TransError is enough
   181  	case *remote.TransError, thrift.TApplicationException, protobuf.PBError:
   182  		// Add 'remote' prefix to distinguish with local err.
   183  		// Because it cannot make sure which side err when decode err happen
   184  		extraMsg := "remote"
   185  		if addrStr != "" {
   186  			extraMsg = "remote-" + addrStr
   187  		}
   188  		return kerrors.ErrRemoteOrNetwork.WithCauseAndExtraMsg(err, extraMsg)
   189  	}
   190  	return kerrors.ErrRemoteOrNetwork.WithCauseAndExtraMsg(err, addrStr)
   191  }
   192  
   193  type instInfo struct {
   194  	Address string
   195  	Weight  int
   196  }
   197  
   198  func wrapInstances(insts []discovery.Instance) []*instInfo {
   199  	if len(insts) == 0 {
   200  		return nil
   201  	}
   202  	instInfos := make([]*instInfo, 0, len(insts))
   203  	for i := range insts {
   204  		inst := insts[i]
   205  		addr := fmt.Sprintf("%s://%s", inst.Address().Network(), inst.Address().String())
   206  		instInfos = append(instInfos, &instInfo{Address: addr, Weight: inst.Weight()})
   207  	}
   208  	return instInfos
   209  }
   210  
   211  func retryable(err error) bool {
   212  	return errors.Is(err, kerrors.ErrGetConnection) || errors.Is(err, kerrors.ErrCircuitBreak)
   213  }
   214  
   215  func getRemoteAddr(ctx context.Context) string {
   216  	if ri := rpcinfo.GetRPCInfo(ctx); ri != nil && ri.To() != nil && ri.To().Address() != nil {
   217  		return ri.To().Address().String()
   218  	}
   219  	return ""
   220  }