github.com/cloudwego/kitex@v0.9.0/client/service_inline.go (about)

     1  /*
     2   * Copyright 2023 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package client
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"net"
    23  	"runtime/debug"
    24  
    25  	"github.com/bytedance/gopkg/cloud/metainfo"
    26  
    27  	"github.com/cloudwego/kitex/client/callopt"
    28  	"github.com/cloudwego/kitex/internal/client"
    29  	internal_server "github.com/cloudwego/kitex/internal/server"
    30  	"github.com/cloudwego/kitex/pkg/consts"
    31  	"github.com/cloudwego/kitex/pkg/endpoint"
    32  	"github.com/cloudwego/kitex/pkg/klog"
    33  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    34  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    35  	"github.com/cloudwego/kitex/pkg/utils"
    36  )
    37  
    38  var localAddr net.Addr
    39  
    40  func init() {
    41  	localAddr = utils.NewNetAddr("tcp", "127.0.0.1")
    42  }
    43  
// ContextServiceInlineHandler lets users carry data between the client context
// and the server context of a service-inline call. It is installed on a client
// with SetContextServiceInlineHandler.
type ContextServiceInlineHandler interface {
	// WriteMeta is invoked before the server endpoint runs. It receives the
	// client context, the server context and the request. The context it
	// returns is used as the server context for the rest of the call; a
	// non-nil error aborts the call and is returned to the caller.
	WriteMeta(cliCtx, svrCtx context.Context, req interface{}) (newSvrCtx context.Context, err error)
	// ReadMeta is invoked after the server endpoint has returned. It receives
	// the client context, the server context and the response. A non-nil error
	// is returned to the caller instead of the server endpoint's result.
	ReadMeta(cliCtx, svrCtx context.Context, resp interface{}) (newCliCtx context.Context, err error)
}
    48  
// serviceInlineClient is the Client returned by NewServiceInlineClient. Its
// endpoint chain ends in a handler that calls the server endpoint (serverEps)
// directly instead of going through a remote transport.
type serviceInlineClient struct {
	// svcInfo is the service description passed to NewServiceInlineClient.
	svcInfo *serviceinfo.ServiceInfo
	// mws holds the client middlewares collected by initMiddlewares.
	mws     []endpoint.Middleware
	// eps is the middleware chain built by buildInvokeChain and run by Call.
	eps     endpoint.Endpoint
	// opt holds the client options built from the Option arguments.
	opt     *client.Options

	// inited is set once init has completed successfully.
	inited bool
	// closed is set by Close.
	closed bool

	// server info
	// serverEps is the server endpoint obtained from ServerInitialInfo.Endpoints.
	serverEps endpoint.Endpoint
	// serverOpt is the server option set obtained from ServerInitialInfo.Option.
	serverOpt *internal_server.Options

	// contextServiceInlineHandler is optional; when non-nil it is called
	// around the server endpoint on every call.
	contextServiceInlineHandler ContextServiceInlineHandler
}
    64  
// ServerInitialInfo exposes the server-side pieces NewServiceInlineClient needs
// to wire a client directly to a server in the same process.
type ServerInitialInfo interface {
	// Endpoints returns the server endpoint that inline calls are dispatched to.
	Endpoints() endpoint.Endpoint
	// Option returns the server options. NewServiceInlineClient writes to the
	// returned value's RemoteOpt (TargetSvcInfo and SvcSearchMap).
	Option() *internal_server.Options
	// GetServiceInfos returns the service infos stored into
	// RemoteOpt.SvcSearchMap by NewServiceInlineClient.
	GetServiceInfos() map[string]*serviceinfo.ServiceInfo
}
    70  
    71  // NewServiceInlineClient creates a kitex.Client with the given ServiceInfo, it is from generated code.
    72  func NewServiceInlineClient(svcInfo *serviceinfo.ServiceInfo, s ServerInitialInfo, opts ...Option) (Client, error) {
    73  	if svcInfo == nil {
    74  		return nil, errors.New("NewClient: no service info")
    75  	}
    76  	kc := &serviceInlineClient{}
    77  	kc.svcInfo = svcInfo
    78  	kc.opt = client.NewOptions(opts)
    79  	kc.serverEps = s.Endpoints()
    80  	kc.serverOpt = s.Option()
    81  	kc.serverOpt.RemoteOpt.TargetSvcInfo = svcInfo
    82  	kc.serverOpt.RemoteOpt.SvcSearchMap = s.GetServiceInfos()
    83  	if err := kc.init(); err != nil {
    84  		_ = kc.Close()
    85  		return nil, err
    86  	}
    87  	return kc, nil
    88  }
    89  
// SetContextServiceInlineHandler installs the handler that is invoked around
// the server endpoint on each call. The field is read at call time, so the
// handler may be set after the client has been created; passing nil disables
// it. The assignment is a plain field write with no locking.
func (kc *serviceInlineClient) SetContextServiceInlineHandler(simh ContextServiceInlineHandler) {
	kc.contextServiceInlineHandler = simh
}
    93  
    94  func (kc *serviceInlineClient) init() (err error) {
    95  	if err = kc.checkOptions(); err != nil {
    96  		return err
    97  	}
    98  	ctx := kc.initContext()
    99  	kc.initMiddlewares(ctx)
   100  	kc.richRemoteOption()
   101  	if err = kc.buildInvokeChain(); err != nil {
   102  		return err
   103  	}
   104  	kc.inited = true
   105  	return nil
   106  }
   107  
   108  func (kc *serviceInlineClient) checkOptions() (err error) {
   109  	if kc.opt.Svr.ServiceName == "" {
   110  		return errors.New("service name is required")
   111  	}
   112  	return nil
   113  }
   114  
   115  func (kc *serviceInlineClient) initContext() context.Context {
   116  	ctx := context.Background()
   117  	ctx = context.WithValue(ctx, endpoint.CtxEventBusKey, kc.opt.Bus)
   118  	ctx = context.WithValue(ctx, endpoint.CtxEventQueueKey, kc.opt.Events)
   119  	return ctx
   120  }
   121  
   122  func (kc *serviceInlineClient) initMiddlewares(ctx context.Context) {
   123  	builderMWs := richMWsWithBuilder(ctx, kc.opt.MWBs)
   124  	kc.mws = append(kc.mws, contextMW)
   125  	kc.mws = append(kc.mws, builderMWs...)
   126  }
   127  
// initRPCInfo initializes the RPCInfo structure and attaches it to context.
// It delegates to the package-level initRPCInfo with this client's options and
// service info; the trailing arguments are always 0 and nil here.
func (kc *serviceInlineClient) initRPCInfo(ctx context.Context, method string) (context.Context, rpcinfo.RPCInfo, *callopt.CallOptions) {
	return initRPCInfo(ctx, method, kc.opt, kc.svcInfo, 0, nil)
}
   132  
   133  // Call implements the Client interface .
   134  func (kc *serviceInlineClient) Call(ctx context.Context, method string, request, response interface{}) (err error) {
   135  	validateForCall(ctx, kc.inited, kc.closed)
   136  	var ri rpcinfo.RPCInfo
   137  	var callOpts *callopt.CallOptions
   138  	ctx, ri, callOpts = kc.initRPCInfo(ctx, method)
   139  
   140  	ctx = kc.opt.TracerCtl.DoStart(ctx, ri)
   141  	var reportErr error
   142  	defer func() {
   143  		if panicInfo := recover(); panicInfo != nil {
   144  			reportErr = rpcinfo.ClientPanicToErr(ctx, panicInfo, ri, true)
   145  		}
   146  		kc.opt.TracerCtl.DoFinish(ctx, ri, reportErr)
   147  		// If the user start a new goroutine and return before endpoint finished, it may cause panic.
   148  		// For example,, if the user writes a timeout Middleware and times out, rpcinfo will be recycled,
   149  		// but in fact, rpcinfo is still being used when it is executed inside
   150  		// So if endpoint returns err, client won't recycle rpcinfo.
   151  		if reportErr == nil {
   152  			rpcinfo.PutRPCInfo(ri)
   153  		}
   154  		callOpts.Recycle()
   155  	}()
   156  	reportErr = kc.eps(ctx, request, response)
   157  
   158  	if reportErr == nil {
   159  		err = ri.Invocation().BizStatusErr()
   160  	} else {
   161  		err = reportErr
   162  	}
   163  	return err
   164  }
   165  
// richRemoteOption stores the client's service info in the client remote options.
func (kc *serviceInlineClient) richRemoteOption() {
	kc.opt.RemoteOpt.SvcInfo = kc.svcInfo
}
   169  
   170  func (kc *serviceInlineClient) buildInvokeChain() error {
   171  	innerHandlerEp, err := kc.invokeHandleEndpoint()
   172  	if err != nil {
   173  		return err
   174  	}
   175  	kc.eps = endpoint.Chain(kc.mws...)(innerHandlerEp)
   176  	return nil
   177  }
   178  
   179  func (kc *serviceInlineClient) constructServerCtxWithMetadata(cliCtx context.Context) (serverCtx context.Context) {
   180  	serverCtx = context.Background()
   181  	// metainfo
   182  	// forward transmission
   183  	kvs := make(map[string]string, 16)
   184  	metainfo.SaveMetaInfoToMap(cliCtx, kvs)
   185  	if len(kvs) > 0 {
   186  		serverCtx = metainfo.SetMetaInfoFromMap(serverCtx, kvs)
   187  	}
   188  	serverCtx = metainfo.TransferForward(serverCtx)
   189  	// reverse transmission, backward mark
   190  	serverCtx = metainfo.WithBackwardValuesToSend(serverCtx)
   191  	return serverCtx
   192  }
   193  
   194  func (kc *serviceInlineClient) constructServerRPCInfo(svrCtx, cliCtx context.Context) (newServerCtx context.Context, svrRPCInfo rpcinfo.RPCInfo) {
   195  	rpcStats := rpcinfo.AsMutableRPCStats(rpcinfo.NewRPCStats())
   196  	if kc.serverOpt.StatsLevel != nil {
   197  		rpcStats.SetLevel(*kc.serverOpt.StatsLevel)
   198  	}
   199  	// Export read-only views to external users and keep a mapping for internal users.
   200  	ri := rpcinfo.NewRPCInfo(
   201  		rpcinfo.EmptyEndpointInfo(),
   202  		rpcinfo.FromBasicInfo(kc.serverOpt.Svr),
   203  		rpcinfo.NewServerInvocation(),
   204  		rpcinfo.AsMutableRPCConfig(kc.serverOpt.Configs).Clone().ImmutableView(),
   205  		rpcStats.ImmutableView(),
   206  	)
   207  	rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(localAddr)
   208  	svrCtx = rpcinfo.NewCtxWithRPCInfo(svrCtx, ri)
   209  
   210  	cliRpcInfo := rpcinfo.GetRPCInfo(cliCtx)
   211  	// handle common rpcinfo
   212  	method := cliRpcInfo.To().Method()
   213  	if ink, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok {
   214  		ink.SetMethodName(method)
   215  		ink.SetServiceName(kc.svcInfo.ServiceName)
   216  	}
   217  	rpcinfo.AsMutableEndpointInfo(ri.To()).SetMethod(method)
   218  	svrCtx = context.WithValue(svrCtx, consts.CtxKeyMethod, method)
   219  	return svrCtx, ri
   220  }
   221  
   222  func (kc *serviceInlineClient) invokeHandleEndpoint() (endpoint.Endpoint, error) {
   223  	svrTraceCtl := kc.serverOpt.TracerCtl
   224  	if svrTraceCtl == nil {
   225  		svrTraceCtl = &rpcinfo.TraceController{}
   226  	}
   227  
   228  	return func(ctx context.Context, req, resp interface{}) (err error) {
   229  		serverCtx := kc.constructServerCtxWithMetadata(ctx)
   230  		defer func() {
   231  			// backward key
   232  			kvs := metainfo.AllBackwardValuesToSend(serverCtx)
   233  			if len(kvs) > 0 {
   234  				metainfo.SetBackwardValuesFromMap(ctx, kvs)
   235  			}
   236  		}()
   237  		serverCtx, svrRPCInfo := kc.constructServerRPCInfo(serverCtx, ctx)
   238  		defer func() {
   239  			rpcinfo.PutRPCInfo(svrRPCInfo)
   240  		}()
   241  
   242  		// server trace
   243  		serverCtx = svrTraceCtl.DoStart(serverCtx, svrRPCInfo)
   244  
   245  		if kc.contextServiceInlineHandler != nil {
   246  			serverCtx, err = kc.contextServiceInlineHandler.WriteMeta(ctx, serverCtx, req)
   247  			if err != nil {
   248  				return err
   249  			}
   250  		}
   251  
   252  		// server logic
   253  		err = kc.serverEps(serverCtx, req, resp)
   254  		// finish server trace
   255  		// contextServiceInlineHandler may convert nil err to non nil err, so handle trace here
   256  		svrTraceCtl.DoFinish(serverCtx, svrRPCInfo, err)
   257  
   258  		if kc.contextServiceInlineHandler != nil {
   259  			var err1 error
   260  			ctx, err1 = kc.contextServiceInlineHandler.ReadMeta(ctx, serverCtx, resp)
   261  			if err1 != nil {
   262  				return err1
   263  			}
   264  		}
   265  		return err
   266  	}, nil
   267  }
   268  
   269  // Close is not concurrency safe.
   270  func (kc *serviceInlineClient) Close() error {
   271  	defer func() {
   272  		if err := recover(); err != nil {
   273  			klog.Warnf("KITEX: panic when close client, error=%s, stack=%s", err, string(debug.Stack()))
   274  		}
   275  	}()
   276  	if kc.closed {
   277  		return nil
   278  	}
   279  	kc.closed = true
   280  	var errs utils.ErrChain
   281  	for _, cb := range kc.opt.CloseCallbacks {
   282  		if err := cb(); err != nil {
   283  			errs.Append(err)
   284  		}
   285  	}
   286  	if errs.HasError() {
   287  		return errs
   288  	}
   289  	return nil
   290  }