github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/router/router.go (about)

     1  package router
     2  
     3  import (
     4  	"io"
     5  	"net"
     6  	"net/http"
     7  	"reflect"
     8  	"sort"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/volts-dev/template"
    13  	"github.com/volts-dev/utils"
    14  	"github.com/volts-dev/volts/codec"
    15  	"github.com/volts-dev/volts/logger"
    16  	"github.com/volts-dev/volts/registry"
    17  	"github.com/volts-dev/volts/transport"
    18  )
    19  
// log is the package-level logger of the router package, created with the
// name "router".
var log = logger.New("router")
    21  
type (
	// IRouter is the interface of a router that serves messages (HTTP and RPC).
	IRouter interface {
		Config() *Config // Retrieve the options
		// String returns the router's name.
		String() string
		Handler() interface{} // connection entry point: the value implementing ServeHTTP and similar interfaces
		// Register endpoint in router
		Register(ep *registry.Endpoint) error
		// Deregister endpoint from router
		Deregister(ep *registry.Endpoint) error
		// list all endpoints from router, grouped by the TGroup they belong to
		Endpoints() (services map[*TGroup][]*registry.Endpoint)
		// RegisterMiddleware registers middleware creators with the router.
		RegisterMiddleware(middlewares ...func(IRouter) IMiddleware)
		// RegisterGroup merges the routes and subscribers of groups into the router.
		RegisterGroup(groups ...IGroup)
		// PrintRoutes prints the route tree (when enabled in the config).
		PrintRoutes()
	}

	// TRouter is the router implementation; it serves both HTTP requests
	// (ServeHTTP) and RPC requests (ServeRPC) by matching them against its
	// route tree.
	TRouter struct {
		// NOTE(review): embedded lock; confirm which fields it is meant to guard.
		sync.RWMutex
		// router is a group set
		TGroup
		// config holds the router options; see Config().
		config *Config
		// middleware manager holding the registered middleware creators
		middleware  *TMiddlewareManager
		template    *template.TTemplateSet
		respPool    sync.Pool          // pooled *transport.THttpResponse values
		httpCtxPool map[int]*sync.Pool // HTTP context pools, keyed by route ID
		rpcCtxPool  map[int]*sync.Pool // RPC context pools, keyed by route ID
		// exit is the shutdown signal: the background watch/refresh loops stop
		// once a receive on it succeeds (e.g. after it is closed).
		exit chan bool
	}
)
    56  
// Validate is meant to check that an endpoint will not blow up when it is
// served.
//
// TODO: not implemented yet. It currently accepts every endpoint (including
// nil) and always returns nil. The commented-out draft below sketches the
// intended checks: non-nil endpoint, non-empty name, "^...$" path patterns
// that compile, and a non-empty handler. It refers to the errors and regexp
// packages, which this file does not import.
func Validate(e *registry.Endpoint) error {
	/*	if e == nil {
			return errors.New("endpoint is nil")
		}

		if len(e.Name) == 0 {
			return errors.New("name required")
		}

		for _, p := range e.Path {
			ps := p[0]
			pe := p[len(p)-1]

			if ps == '^' && pe == '$' {
				_, err := regexp.CompilePOSIX(p)
				if err != nil {
					return err
				}
			} else if ps == '^' && pe != '$' {
				return errors.New("invalid path")
			} else if ps != '^' && pe == '$' {
				return errors.New("invalid path")
			}
		}

		if len(e.Handler) == 0 {
			return errors.New("invalid handler")
		}
	*/
	return nil
}
    89  
    90  // 新建路由
    91  // NOTE 路由不支持线程安全 不推荐多服务器调用!
    92  func New(opts ...Option) *TRouter {
    93  	cfg := newConfig(opts...)
    94  	router := &TRouter{
    95  		TGroup:      *NewGroup(),
    96  		config:      cfg,
    97  		middleware:  newMiddlewareManager(),
    98  		httpCtxPool: make(map[int]*sync.Pool),
    99  		rpcCtxPool:  make(map[int]*sync.Pool),
   100  		exit:        make(chan bool),
   101  	}
   102  	cfg.Router = router
   103  	//~router.respPool.New = func() interface{} {
   104  	//	return &transport.THttpResponse{}
   105  	//}
   106  
   107  	go router.watch()   // 实时订阅
   108  	go router.refresh() // 定时刷新
   109  
   110  	return router
   111  }
   112  
   113  func (self *TRouter) PrintRoutes() {
   114  	if self.Config().RouterTreePrinter {
   115  		self.tree.PrintTrees()
   116  	}
   117  }
   118  
   119  func (self *TRouter) Config() *Config {
   120  	return self.config
   121  }
   122  
   123  func (self *TRouter) String() string {
   124  	return "volts-router"
   125  }
   126  
   127  func (self *TRouter) Handler() interface{} {
   128  	return self
   129  }
   130  
   131  // registry a endpoion to router
   132  func (self *TRouter) Register(ep *registry.Endpoint) error {
   133  	if err := Validate(ep); err != nil {
   134  		return err
   135  	}
   136  	return self.tree.AddRoute(EndpiontToRoute(ep))
   137  }
   138  
   139  func (self *TRouter) Deregister(ep *registry.Endpoint) error {
   140  	path := ep.Metadata["path"]
   141  	return self.tree.DelRoute(path, EndpiontToRoute(ep))
   142  }
   143  
   144  func (self *TRouter) Endpoints() (services map[*TGroup][]*registry.Endpoint) {
   145  	// 注册订阅列表
   146  	var subscriberList []ISubscriber
   147  
   148  	for e := range self.subscribers {
   149  		// Only advertise non internal subscribers
   150  		if !e.Config().Internal {
   151  			subscriberList = append(subscriberList, e)
   152  		}
   153  	}
   154  	sort.Slice(subscriberList, func(i, j int) bool {
   155  		return subscriberList[i].Topic() > subscriberList[j].Topic()
   156  	})
   157  
   158  	eps := self.tree.Endpoints()
   159  	for grp, endpoints := range eps {
   160  		for _, e := range subscriberList {
   161  			endpoints = append(endpoints, e.Endpoints()...)
   162  		}
   163  		eps[grp] = endpoints
   164  	}
   165  
   166  	return eps
   167  }
   168  
   169  // 注册中间件
   170  func (self *TRouter) RegisterMiddleware(middlewares ...func(IRouter) IMiddleware) {
   171  	for _, creator := range middlewares {
   172  		// 新建中间件
   173  		middleware := creator(self)
   174  		if mm, ok := middleware.(IMiddlewareName); ok {
   175  			self.middleware.Add(mm.Name(), creator)
   176  		} else {
   177  			typ := reflect.TypeOf(middleware)
   178  			if typ.Kind() == reflect.Ptr {
   179  				typ = typ.Elem()
   180  			}
   181  			name := typ.String()
   182  			self.middleware.Add(name, creator)
   183  		}
   184  	}
   185  }
   186  
   187  // register module
   188  func (self *TRouter) RegisterGroup(groups ...IGroup) {
   189  	for _, group := range groups {
   190  		group := group
   191  		self.tree.Conbine(group.GetRoutes())
   192  
   193  		for sub, lst := range group.GetSubscribers() {
   194  			if rawLst, has := self.TGroup.subscribers[sub]; !has {
   195  				self.TGroup.subscribers[sub] = lst
   196  			} else {
   197  				/*
   198  					var news []broker.ISubscriber
   199  					// 排除重复的
   200  					for _, sb := range lst {
   201  						for _, s := range rawLst {
   202  							if s.Topic() == sb.Topic() {
   203  								goto next
   204  							}
   205  						}
   206  						news = append(news, sb)
   207  					next:
   208  					}*/
   209  
   210  				self.TGroup.subscribers[sub] = append(rawLst, lst...)
   211  			}
   212  		}
   213  	}
   214  }
   215  
   216  func (self *TRouter) ServeHTTP(w http.ResponseWriter, r *transport.THttpRequest) {
   217  	// 使用defer保证错误也打印
   218  	if self.config.RequestPrinter {
   219  		defer func() {
   220  			log.Infof("[Path]%v", r.URL.Path)
   221  		}()
   222  	}
   223  
   224  	if r.Method == "CONNECT" { // serve as a raw network server
   225  		conn, _, err := w.(http.Hijacker).Hijack()
   226  		if err != nil {
   227  			log.Errf("rpc hijacking %v:%v", r.RemoteAddr, ": ", err.Error())
   228  		}
   229  		io.WriteString(conn, "HTTP/1.0 200 Connected to RPC\n\n")
   230  
   231  		msg, err := transport.ReadMessage(conn)
   232  		if err != nil {
   233  			log.Errf("rpc Read %s", err.Error())
   234  		}
   235  		sock := transport.NewTcpTransportSocket(conn, 0, 0)
   236  		req := transport.NewRpcRequest(r.Context(), msg, sock)
   237  		rsp := transport.NewRpcResponse(r.Context(), req, sock)
   238  		self.ServeRPC(rsp, req)
   239  	} else { // serve as a web server
   240  		// Pool 提供TResponseWriter
   241  		var rsp *transport.THttpResponse
   242  		if v := self.respPool.Get(); v == nil {
   243  			rsp = transport.NewHttpResponse(r.Context(), r)
   244  		} else {
   245  			rsp = v.(*transport.THttpResponse)
   246  		}
   247  		rsp.Connect(w)
   248  
   249  		//获得的地址
   250  		// # match route from tree
   251  		route, params := self.tree.Match(r.Method, r.URL.Path)
   252  		if route == nil {
   253  			rsp.WriteHeader(http.StatusNotFound)
   254  			return
   255  		}
   256  
   257  		// # get the new context from pool
   258  		p, has := self.httpCtxPool[route.Id]
   259  		if !has {
   260  			p = &sync.Pool{New: func() interface{} {
   261  				return NewHttpContext(self)
   262  			}}
   263  
   264  			self.httpCtxPool[route.Id] = p
   265  		}
   266  
   267  		ctx := p.Get().(*THttpContext)
   268  		if !ctx.inited {
   269  			ctx.router = self
   270  			ctx.route = *route // fixme 复制
   271  			ctx.inited = true
   272  			ctx.Template = template.Default()
   273  		}
   274  
   275  		ctx.reset(rsp, r)
   276  		ctx.setPathParams(params)
   277  		ctx.route = *route // TODO 优化重复使用
   278  
   279  		self.route(route, ctx)
   280  
   281  		// 结束Route并返回内容
   282  		ctx.Apply()
   283  
   284  		// 回收资源
   285  		p.Put(ctx) // Pool Handler
   286  		rsp.ResponseWriter = nil
   287  		self.respPool.Put(rsp) // Pool 回收TResponseWriter
   288  	}
   289  }
   290  
   291  func (self *TRouter) ServeRPC(w *transport.RpcResponse, r *transport.RpcRequest) {
   292  	reqMessage := r.Message // return the packet struct
   293  	// 心跳包 直接返回
   294  	if reqMessage.IsHeartbeat() {
   295  		reqMessage.SetMessageType(transport.MT_RESPONSE)
   296  		data := reqMessage.Encode()
   297  		w.Write(data)
   298  		return
   299  	}
   300  
   301  	//resMetadata := make(map[string]string)
   302  	//newCtx := context.WithValue(context.WithValue(ctx, share.ReqMetaDataKey, req.Metadata),
   303  	//	share.ResMetaDataKey, resMetadata)
   304  
   305  	//		ctx := context.WithValue(context.Background(), RemoteConnContextKey, conn)
   306  	//serviceName := reqMessage.Header["ServicePath"]
   307  	//methodName := msg.ServiceMethod
   308  	st := reqMessage.SerializeType()
   309  	//res := transport.GetMessageFromPool()
   310  	//res.SetMessageType(transport.MT_RESPONSE)
   311  	//res.SetSerializeType(st)
   312  	// 获取支持的序列模式
   313  	coder := codec.IdentifyCodec(st)
   314  	if coder == nil {
   315  		w.WriteHeader(transport.StatusForbidden)
   316  		//w.Write([]byte("can not find codec for " + st.String()))
   317  		w.Write([]byte{})
   318  		return
   319  	}
   320  
   321  	route, params := self.tree.Match("CONNECT", reqMessage.Path) // 匹配路由树
   322  	if route == nil {
   323  		w.WriteHeader(transport.StatusNotFound)
   324  		w.Write([]byte{})
   325  		//w.Write([]byte("rpc: can't match route " + serviceName))
   326  		return
   327  	} else {
   328  		// 初始化Context
   329  		p, has := self.rpcCtxPool[route.Id]
   330  		if !has { // TODO 优化
   331  			p = &sync.Pool{New: func() interface{} {
   332  				return NewRpcHandler(self)
   333  			}}
   334  			self.rpcCtxPool[route.Id] = p
   335  		}
   336  
   337  		ctx := p.Get().(*TRpcContext)
   338  		if !ctx.inited {
   339  			ctx.router = self
   340  			ctx.route = *route
   341  			ctx.inited = true
   342  		}
   343  		ctx.reset(w, r, self, route)
   344  		ctx.setPathParams(params)
   345  
   346  		// 执行控制器
   347  		self.route(route, ctx)
   348  	}
   349  
   350  	// 返回数据
   351  	if !reqMessage.IsOneway() {
   352  		// 序列化数据
   353  		/* remove 已经交由Body response处理
   354  		data, err := coder.Encode(ctx.replyv.Interface())
   355  		//argsReplyPools.Put(mtype.ReplyType, replyv)
   356  		if err != nil {
   357  			handleError(res, err)
   358  			return
   359  		}
   360  		res.Payload = data
   361  
   362  		if len(resMetadata) > 0 { //copy meta in context to request
   363  			meta := res.Header
   364  			if meta == nil {
   365  				res.Header = resMetadata
   366  			} else {
   367  				for k, v := range resMetadata {
   368  					meta[k] = v
   369  				}
   370  			}
   371  		}
   372  
   373  		err = w.Write(res.Payload)
   374  		if err != nil {
   375  			log.Dbg(err.Error())
   376  		}
   377  		log.Dbg("aa", string(res.Payload))*/
   378  	}
   379  
   380  	return
   381  }
   382  
   383  func (self *TRouter) route(route *route, ctx IContext) {
   384  	if self.config.Recover {
   385  		defer func() {
   386  			if err := recover(); err != nil {
   387  				log.Err(err)
   388  
   389  				if self.config.RecoverHandler != nil {
   390  					self.config.RecoverHandler(ctx)
   391  				}
   392  			}
   393  		}()
   394  	}
   395  
   396  	// TODO:将所有需要执行的Handler 存疑列表或者树-Node保存函数和参数
   397  	for _, handler := range route.handlers {
   398  		// TODO 回收需要特殊通道 直接调用占用了处理时间
   399  		handler.init(self).Invoke(ctx).recycle()
   400  	}
   401  }
   402  
   403  func (self *TRouter) isClosed() bool {
   404  	select {
   405  	case <-self.exit:
   406  		return true
   407  	default:
   408  		return false
   409  	}
   410  }
   411  
   412  // 过滤自己
   413  func (self *TRouter) filteSelf(service *registry.Service) *registry.Service {
   414  	localServices := self.config.Registry.LocalServices()
   415  
   416  	nodes := make([]*registry.Node, 0)
   417  	for _, n := range service.Nodes {
   418  		//  TODO 解决监控拉取registry服务器服务列表比LocalServices获取注册的本地早
   419  		if len(localServices) == 0 && self.tree.Count.Load() > 0 {
   420  			break
   421  		}
   422  
   423  		for _, curSrv := range localServices {
   424  			node := curSrv.Nodes[0]
   425  			host, port, err := net.SplitHostPort(node.Address)
   426  			if err != nil {
   427  				log.Err(err)
   428  			}
   429  
   430  			if n.Id == node.Id {
   431  				goto out
   432  			}
   433  
   434  			h, p, err := net.SplitHostPort(n.Address)
   435  			if err != nil {
   436  				log.Err(err)
   437  			}
   438  
   439  			// 同个服务器
   440  			if host == h && port == p {
   441  				goto out
   442  			}
   443  
   444  		}
   445  		nodes = append(nodes, n)
   446  	out:
   447  	}
   448  
   449  	service.Nodes = nodes
   450  	/*
   451  		for _, curSrv := range localServices {
   452  			if curSrv != nil && service.Name == curSrv.Name {
   453  				node := curSrv.Nodes[0]
   454  				host, port, err := net.SplitHostPort(node.Address)
   455  				if err != nil {
   456  					log.Err(err)
   457  				}
   458  
   459  				nodes := make([]*registry.Node, 0)
   460  				var node *registry.Node
   461  				for _, n := range service.Nodes {
   462  					if n.Id == node.Id {
   463  						continue
   464  					}
   465  
   466  					h, p, err := net.SplitHostPort(n.Address)
   467  					if err != nil {
   468  						log.Err(err)
   469  					}
   470  
   471  					// 同个服务器
   472  					if host == h && port == p {
   473  						continue
   474  					}
   475  
   476  					nodes = append(nodes, n)
   477  				}
   478  				service.Nodes = nodes
   479  			}
   480  		}
   481  	*/
   482  	return service
   483  }
   484  
   485  // store local endpoint
   486  func (self *TRouter) store(services []*registry.Service) {
   487  	// create a new endpoint mapping
   488  	for _, service := range services {
   489  		service = self.filteSelf(service)
   490  		if len(service.Nodes) > 0 {
   491  			// map per endpoint
   492  			for _, sep := range service.Endpoints {
   493  				url := &TUrl{
   494  					Path: sep.Path,
   495  				}
   496  				r := EndpiontToRoute(sep)
   497  				if utils.InStrings("CONNECT", sep.Method...) != -1 {
   498  					r.handlers = append(r.handlers, generateHandler(ProxyHandler, RpcHandler, []interface{}{RpcReverseProxy}, nil, url, []*registry.Service{service}))
   499  				} else {
   500  					r.handlers = append(r.handlers, generateHandler(ProxyHandler, HttpHandler, []interface{}{HttpReverseProxy}, nil, url, []*registry.Service{service}))
   501  				}
   502  
   503  				err := self.tree.AddRoute(r)
   504  				if err != nil {
   505  					log.Err(err)
   506  				}
   507  			}
   508  		}
   509  	}
   510  }
   511  
   512  // watch for endpoint changes
   513  func (self *TRouter) watch() {
   514  	var attempts int
   515  
   516  	// 5秒后才启动监测
   517  	time.Sleep(5 * time.Second)
   518  
   519  	for {
   520  		if self.isClosed() {
   521  			break
   522  		}
   523  
   524  		// watch for changes
   525  		w, err := self.config.Registry.Watcher()
   526  		if err != nil {
   527  			attempts++
   528  			log.Errf("error watching endpoints: %v", err)
   529  			//time.Sleep(time.Duration(attempts) * time.Second)
   530  			continue
   531  		}
   532  
   533  		// 无监视者等待
   534  		if w == nil {
   535  			time.Sleep(60 * time.Second)
   536  			continue
   537  		}
   538  
   539  		ch := make(chan bool)
   540  
   541  		go func() {
   542  			select {
   543  			case <-ch:
   544  				w.Stop()
   545  			case <-self.exit:
   546  				w.Stop()
   547  			}
   548  		}()
   549  
   550  		// reset if we get here
   551  		attempts = 0
   552  
   553  		for {
   554  			// process next event
   555  			res, err := w.Next()
   556  			if err != nil {
   557  				log.Errf("error getting next endoint: %v", err)
   558  				close(ch)
   559  				break
   560  			}
   561  
   562  			// skip these things
   563  			if res == nil || res.Service == nil {
   564  				break
   565  			}
   566  
   567  			// get entry from cache
   568  			services, err := self.config.RegistryCacher.GetService(res.Service.Name)
   569  			if err != nil {
   570  				log.Errf("unable to get service: %v", err)
   571  				break
   572  			}
   573  
   574  			// update our local endpoints
   575  			self.store(services)
   576  		}
   577  	}
   578  }
   579  
   580  // refresh list of api services
   581  func (self *TRouter) refresh() {
   582  	// 5秒后才启动监测
   583  	time.Sleep(5 * time.Second)
   584  
   585  	var (
   586  		err      error
   587  		services []*registry.Service
   588  		list     []*registry.Service
   589  		attempts int
   590  	)
   591  
   592  	for {
   593  		list, err = self.config.Registry.ListServices()
   594  		if err != nil {
   595  			attempts++
   596  			log.Warnf("registry unable to list services: %v", err)
   597  			time.Sleep(time.Duration(attempts) * time.Second)
   598  			continue
   599  		}
   600  		// 无监视者等待
   601  		if len(list) == 0 {
   602  			time.Sleep(60 * time.Second)
   603  			continue
   604  		}
   605  
   606  		attempts = 0
   607  
   608  		// for each service, get service and store endpoints
   609  		for _, s := range list {
   610  			for _, local := range self.config.Registry.LocalServices() {
   611  				if local.Equal(s) {
   612  					// 不添加自己
   613  					goto out
   614  				}
   615  
   616  			}
   617  
   618  			services, err = self.config.RegistryCacher.GetService(s.Name)
   619  			if err != nil {
   620  				log.Errf("unable to get service: %v", err)
   621  				continue
   622  			}
   623  			self.store(services)
   624  
   625  		out:
   626  		}
   627  
   628  		// refresh list in 10 minutes... cruft
   629  		// use registry watching
   630  		select {
   631  		case <-time.After(time.Minute * 10):
   632  		case <-self.exit:
   633  			return
   634  		}
   635  	}
   636  }