github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/core/client/client.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"github.com/nyan233/littlerpc/core/client/loadbalance"
     8  	"github.com/nyan233/littlerpc/core/common/errorhandler"
     9  	"github.com/nyan233/littlerpc/core/common/logger"
    10  	"github.com/nyan233/littlerpc/core/common/metadata"
    11  	transport2 "github.com/nyan233/littlerpc/core/common/transport"
    12  	metaDataUtil "github.com/nyan233/littlerpc/core/common/utils/metadata"
    13  	container2 "github.com/nyan233/littlerpc/core/container"
    14  	error2 "github.com/nyan233/littlerpc/core/protocol/error"
    15  	"github.com/nyan233/littlerpc/core/protocol/message"
    16  	"github.com/nyan233/littlerpc/core/utils/random"
    17  	"github.com/nyan233/littlerpc/internal/pool"
    18  	"reflect"
    19  )
    20  
// Complete is the element type of the notify channel that sendCallMsg returns
// and readMsgAndDecodeReply consumes. It pairs a message with an error
// description (presumably the reply for one call and the error attached to
// it — confirm against the code that produces it).
type Complete struct {
	Message *message.Message
	Error   error2.LErrorDesc
}
    25  
// Client is the RPC client.
//
// Using synchronous and asynchronous calls together on one Client causes the
// synchronous calls to block the sending of all asynchronous requests on the
// affected connection.
type Client struct {
	cfg *Config
	// Used for connection management.
	balancer loadbalance.Balancer
	// The client's event-driven engine.
	engine transport2.ClientBuilder
	// Resources allocated for each connection.
	connSourceSet *container2.RWMutexMap[transport2.ConnAdapter, *connSource]
	contextM      *contextManager
	// Starting value for context ids; assigned randomly in New.
	contextInitId uint64
	// services supports calls on different instances.
	// All operations on it are thread-safe.
	services *container2.RCUMap[string, *metadata.Process]
	// NOTE(review): the original comment here read "used for keepalive", which
	// does not obviously describe a logger field — confirm the intent.
	logger logger.LLogger
	// Goroutine pool used for timeout management and for simulating
	// asynchronous calls. New leaves it nil when Config.PoolSize <= 0.
	gp pool.TaskPool[string]
	// Plugins for the client.
	pluginManager *pluginManager
	// Error-handling interface.
	eHandle error2.LErrors
}
    51  
    52  func New(opts ...Option) (*Client, error) {
    53  	config := &Config{}
    54  	WithDefault()(config)
    55  	for _, v := range opts {
    56  		v(config)
    57  	}
    58  	client := &Client{
    59  		cfg: config,
    60  	}
    61  	client.logger = config.Logger
    62  	client.eHandle = config.ErrHandler
    63  	// init engine
    64  	client.engine = transport2.Manager.GetClientEngine(config.NetWork)()
    65  	eventD := client.engine.EventDriveInter()
    66  	eventD.OnOpen(client.onOpen)
    67  	eventD.OnMessage(client.onMessage)
    68  	eventD.OnClose(client.onClose)
    69  	if config.RegisterMPOnRead {
    70  		eventD.OnRead(client.onRead)
    71  	}
    72  	err := client.engine.Client().Start()
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	// 初始化负载均衡功能
    77  	client.connSourceSet = new(container2.RWMutexMap[transport2.ConnAdapter, *connSource])
    78  	defer client.initBalancer(config)()
    79  	// init goroutine pool
    80  	if config.PoolSize <= 0 {
    81  		// 关闭Async模式
    82  		client.gp = nil
    83  	} else if config.ExecPoolBuilder != nil {
    84  		client.gp = config.ExecPoolBuilder.Builder(
    85  			pool.MaxTaskPoolSize/4, config.PoolSize, config.PoolSize*2, func(poolId int, err interface{}) {
    86  				client.logger.Error(fmt.Sprintf("poolId : %d -> Panic : %v", poolId, err))
    87  			})
    88  	} else {
    89  		client.gp = pool.NewTaskPool[string](
    90  			pool.MaxTaskPoolSize/4, config.PoolSize, config.PoolSize*2, func(poolId int, err interface{}) {
    91  				client.logger.Error(fmt.Sprintf("poolId : %d -> Panic : %v", poolId, err))
    92  			})
    93  	}
    94  	// plugins
    95  	client.pluginManager = newPluginManager(config.Plugins)
    96  	// init ErrHandler
    97  	client.eHandle = config.ErrHandler
    98  	// init service map
    99  	client.services = container2.NewRCUMap[string, *metadata.Process](64)
   100  	// init context manager
   101  	client.contextM = newContextManager()
   102  	client.contextInitId = uint64(random.FastRand())
   103  	return client, nil
   104  }
   105  
   106  func (c *Client) initBalancer(config *Config) (afterStart func()) {
   107  	bConfig := new(loadbalance.Config)
   108  	bConfig.Logger = c.logger
   109  	bConfig.MuxConnSize = config.MuxConnection
   110  	bConfig.ConnectionFactory = func(node loadbalance.RpcNode) (transport2.ConnAdapter, error) {
   111  		conn, err := c.engine.Client().NewConn(transport2.NetworkClientConfig{
   112  			ServerAddr: node.Address,
   113  			KeepAlive:  config.KeepAlive,
   114  		})
   115  		if err != nil {
   116  			return nil, err
   117  		}
   118  		connSrc := newConnSource(c.cfg.ParserFactory, conn, node)
   119  		c.connSourceSet.Store(conn, connSrc)
   120  		return connSrc, nil
   121  	}
   122  	bConfig.CloseFunc = func(conn transport2.ConnAdapter) {
   123  		connSrc, ok := conn.(*connSource)
   124  		if !ok {
   125  			panic("closeFunc the conn type is not *connSource")
   126  		}
   127  		ableRelease, err := connSrc.halfClose()
   128  		if err != nil {
   129  			c.logger.Warn(err.Error())
   130  			return
   131  		}
   132  		if ableRelease {
   133  			c.logger.Debug("LRPC: balancer click CloseFunc : %v", connSrc.ConnAdapter.Close())
   134  		}
   135  	}
   136  	if !config.OpenLoadBalance {
   137  		bConfig.Resolver = func() ([]loadbalance.RpcNode, error) {
   138  			return []loadbalance.RpcNode{
   139  				{Address: config.ServerAddr},
   140  			}, nil
   141  		}
   142  		bConfig.ResolverUpdateInterval = -1
   143  		bConfig.Scheme = "scheme"
   144  	} else {
   145  		bConfig.Resolver = config.BalancerResolverFunc
   146  		bConfig.ResolverUpdateInterval = config.ResolverUpdateInterval
   147  	}
   148  	return func() {
   149  		c.balancer = config.BalancerFactory(*bConfig)
   150  	}
   151  }
   152  
   153  func (c *Client) BindFunc(sourceName string, i interface{}) error {
   154  	if i == nil {
   155  		return errors.New("register elem is nil")
   156  	}
   157  	if sourceName == "" {
   158  		return errors.New("the typ name is not defined")
   159  	}
   160  	source := new(metadata.Source)
   161  	source.InstanceType = reflect.TypeOf(i)
   162  	value := reflect.ValueOf(i)
   163  	// init map
   164  	source.ProcessSet = make(map[string]*metadata.Process, value.NumMethod())
   165  	// NOTE: 这里的判断不能依靠map的len/cap来确定实例用于多少的绑定方法
   166  	// 因为len/cap都不能提供准确的信息,调用make()时指定的cap只是给真正创建map的函数一个提示
   167  	// 并不代表真实大小,对没有插入过数据的map调用len()永远为0
   168  	if value.NumMethod() == 0 {
   169  		return errors.New("instance no method")
   170  	}
   171  	for i := 0; i < value.NumMethod(); i++ {
   172  		method := source.InstanceType.Method(i)
   173  		if !method.IsExported() {
   174  			continue
   175  		}
   176  		// 2022/02/22 : 生成器可能直接使用/间接使用*Client作为内嵌对象
   177  		// 这个时候需要防止Client自己的方法被添加到列表中
   178  		switch method.Name {
   179  		case "Call", "RawCall", "Request", "Requests", "AsyncCall", "BindFunc", "Close":
   180  			continue
   181  		}
   182  		opt := &metadata.Process{
   183  			Value: value.Method(i),
   184  		}
   185  		for j := 0; j < method.Type.NumIn(); j++ {
   186  			// 检查输入参数的最后一项是否为(...CallOption)
   187  			if j == (method.Type.NumIn()-1) && method.Type.In(j) == reflect.TypeOf([]CallOption{}) {
   188  				break
   189  			}
   190  			opt.ArgsType = append(opt.ArgsType, method.Type.In(j))
   191  		}
   192  		for j := 0; j < method.Type.NumOut(); j++ {
   193  			// 检查输入参数的最后一项是否为(error)
   194  			// NOTE: 2022/11/22 目前没有优雅的方法比较参数列表的接口类型为什么接口
   195  			// 值为nil的非空接口在转换成空接口时不会将数据类型assign给空接口, 只能通过类型的指针来比较
   196  			if j == (method.Type.NumOut()-1) && reflect.PtrTo(method.Type.Out(j)) == reflect.TypeOf(new(error)) {
   197  				break
   198  			}
   199  			opt.ResultsType = append(opt.ResultsType, method.Type.Out(j))
   200  		}
   201  		metaDataUtil.IFContextOrStream(opt, method.Type)
   202  		source.ProcessSet[method.Name] = opt
   203  	}
   204  	kvs := make([]container2.RCUMapElement[string, *metadata.Process], 0, len(source.ProcessSet))
   205  	for k, v := range source.ProcessSet {
   206  		serviceName := fmt.Sprintf("%s.%s", sourceName, k)
   207  		_, ok := c.services.LoadOk(serviceName)
   208  		if ok {
   209  			return errors.New("service name already usage")
   210  		}
   211  		kvs = append(kvs, container2.RCUMapElement[string, *metadata.Process]{
   212  			Key:   serviceName,
   213  			Value: v,
   214  		})
   215  	}
   216  	c.services.StoreMulti(kvs)
   217  	return nil
   218  }
   219  
   220  // RawCall 该调用和Client.Call不同, 这个调用不会识别Method和对应的in/out list
   221  // 只会对除context.Context/stream.LStream外的args/reps直接序列化
   222  func (c *Client) RawCall(service string, opts []CallOption, args ...interface{}) ([]interface{}, error) {
   223  	return c.call(service, opts, args, nil, false, false)
   224  }
   225  
   226  // Request req/rep风格的RPC调用, 这要求rep必须是指针类型, 否则会返回ErrCallArgsType
   227  func (c *Client) Request(service string, ctx context.Context, request interface{}, response interface{}, opts ...CallOption) error {
   228  	if response == nil {
   229  		return c.eHandle.LWarpErrorDesc(errorhandler.ErrCallArgsType, "response pointer equal nil")
   230  	}
   231  	_, err := c.call(service, opts, []interface{}{ctx, request}, []interface{}{response}, false, true)
   232  	return err
   233  }
   234  
   235  // Requests multi request and response
   236  func (c *Client) Requests(service string, requests []interface{}, responses []interface{}, opts ...CallOption) error {
   237  	// TODO: 修改检查的逻辑
   238  	if responses == nil || len(responses) > 0 {
   239  		return c.eHandle.LWarpErrorDesc(errorhandler.ErrCallArgsType, "responses length equal zero")
   240  	}
   241  	for _, response := range responses {
   242  		if response == nil {
   243  			return c.eHandle.LWarpErrorDesc(errorhandler.ErrCallArgsType, "response pointer equal nil")
   244  		}
   245  	}
   246  	_, err := c.call(service, opts, requests, responses, false, true)
   247  	return err
   248  }
   249  
   250  // Call 返回的error可能是由Server/Client本身产生的, 也有可能是调用用户过程返回的, 这些都会被Call
   251  // 视为错误, args为用户参数, 即context.Context & stream.LStream都会被放置在此, 如果存在的话.
   252  // Call实现context.Context传播的语义, 即传递的Context cancel时, client会同时将server端的
   253  // Context cancel, 但不会影响到自身的调用过程, 如果cancel之后, remote process不返回, 那么这次调用将会阻塞
   254  // 注册了元信息的过程返回的result数量始终等于自身结果数量-1, 因为error不包括在reps中, 不管发生了什么错误, 除非
   255  // 找不到注册的元信息
   256  func (c *Client) Call(service string, opts []CallOption, args ...interface{}) ([]interface{}, error) {
   257  	return c.call(service, opts, args, nil, true, false)
   258  }
   259  
   260  func (c *Client) call(
   261  	service string,
   262  	opts []CallOption,
   263  	args []interface{},
   264  	reps []interface{},
   265  	check bool,
   266  	bind bool,
   267  ) (completeReps []interface{}, completeErr error2.LErrorDesc) {
   268  
   269  	defer func() {
   270  		if completeErr != nil && check && (completeReps == nil || len(completeReps) == 0) {
   271  			if serviceInstance, ok := c.services.LoadOk(service); ok {
   272  				completeReps = make([]interface{}, serviceInstance.Value.Type().NumOut()-1)
   273  			}
   274  		}
   275  	}()
   276  	cs, err := c.takeConnSource(service)
   277  	if err != nil && check {
   278  		return nil, c.eHandle.LWarpErrorDesc(errorhandler.ErrClient, err)
   279  	}
   280  	mp := sharedPool.TakeMessagePool()
   281  	writeMsg := mp.Get().(*message.Message)
   282  	defer mp.Put(writeMsg)
   283  	writeMsg.Reset()
   284  	pCtx := c.pluginManager.GetContext()
   285  	defer c.pluginManager.FreeContext(pCtx)
   286  	if err := c.pluginManager.Request4C(pCtx, args, writeMsg); err != nil {
   287  		return nil, err
   288  	}
   289  	cc := &callConfig{
   290  		Writer: c.cfg.Writer,
   291  		Codec:  c.cfg.Codec,
   292  		Packer: c.cfg.Packer,
   293  	}
   294  	if opts != nil && len(opts) > 0 {
   295  		for _, opt := range opts {
   296  			opt(cc)
   297  		}
   298  	}
   299  	process, ctx, ctxId, err := c.identArgAndEncode(service, cc, writeMsg, args, bind)
   300  	if err != nil {
   301  		_ = c.pluginManager.Send4C(pCtx, writeMsg, err)
   302  		return nil, err
   303  	}
   304  	var notifyChannel chan Complete
   305  	notifyChannel, err = c.sendCallMsg(pCtx, cc, ctxId, writeMsg, cs, false)
   306  	if err != nil {
   307  		switch err.Code() {
   308  		case error2.ConnectionErr:
   309  			// TODO 连接错误启动重试
   310  			return nil, err
   311  		default:
   312  			return nil, err
   313  		}
   314  	}
   315  	if len(reps) == 0 {
   316  		if check {
   317  			reps = make([]interface{}, len(process.ResultsType))
   318  		}
   319  	}
   320  	reps, err = c.readMsgAndDecodeReply(ctx, notifyChannel, pCtx, cc, writeMsg.GetMsgId(), cs, process, reps, !check)
   321  	// 插件错误中断后续的处理
   322  	if err != nil && (err.Code() == errorhandler.ErrPlugin.Code()) {
   323  		return reps, err
   324  	}
   325  	if err := c.pluginManager.AfterReceive4C(pCtx, reps, err); err != nil {
   326  		return reps, err
   327  	}
   328  	if err == nil {
   329  		return reps, nil
   330  	}
   331  	switch err.Code() {
   332  	case error2.ConnectionErr:
   333  		// TODO 连接错误启动重试
   334  		return reps, err
   335  	default:
   336  		return reps, err
   337  	}
   338  }
   339  
   340  // AsyncCall TODO 改进这个不合时宜的API
   341  // AsyncCall 该函数返回时至少数据已经经过Codec的序列化,调用者有责任检查error
   342  // 该函数可能会传递来自Codec和内部组件的错误,因为它在发送消息之前完成
   343  func (c *Client) AsyncCall(service string, opts []CallOption, args []interface{}, callBack func(results []interface{}, err error)) error {
   344  	if callBack == nil {
   345  		return c.eHandle.LWarpErrorDesc(errorhandler.ErrCallArgsType, "callBack is empty")
   346  	}
   347  	msg := message.New()
   348  	cc := &callConfig{
   349  		Writer: c.cfg.Writer,
   350  		Codec:  c.cfg.Codec,
   351  		Packer: c.cfg.Packer,
   352  	}
   353  	if opts != nil && len(opts) > 0 {
   354  		cc = new(callConfig)
   355  		for _, opt := range opts {
   356  			opt(cc)
   357  		}
   358  	}
   359  	process, ctx, ctxId, err := c.identArgAndEncode(service, cc, msg, args, false)
   360  	if err != nil {
   361  		return err
   362  	}
   363  	return c.gp.Push(service, func() {
   364  		// 在池中获取一个底层传输的连接
   365  		conn, err := c.takeConnSource(service)
   366  		if err != nil {
   367  			callBack(nil, err)
   368  			return
   369  		}
   370  		var notifyChannel chan Complete
   371  		notifyChannel, err = c.sendCallMsg(nil, cc, ctxId, msg, conn, false)
   372  		if err != nil {
   373  			callBack(nil, err)
   374  			return
   375  		}
   376  		reps := make([]interface{}, len(process.ResultsType))
   377  		reps, err = c.readMsgAndDecodeReply(ctx, notifyChannel, nil, cc, msg.GetMsgId(), conn, process, reps, false)
   378  		callBack(reps, err)
   379  	})
   380  }
   381  
   382  func (c *Client) takeConnSource(service string) (*connSource, error2.LErrorDesc) {
   383  	conn := c.balancer.Target(service)
   384  	cs, ok := conn.(*connSource)
   385  	if !ok {
   386  		return nil, c.eHandle.LWarpErrorDesc(errorhandler.ErrClient, "target result is not connSource type")
   387  	}
   388  	return cs, nil
   389  }
   390  
   391  func (c *Client) Close() error {
   392  	if c.gp != nil {
   393  		if err := c.gp.Stop(); err != nil {
   394  			return err
   395  		}
   396  	}
   397  	err := c.engine.Client().Stop()
   398  	if err != nil {
   399  		return err
   400  	}
   401  	return c.balancer.Exit()
   402  }