github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/router/router.go (about) 1 package router 2 3 import ( 4 "io" 5 "net" 6 "net/http" 7 "reflect" 8 "sort" 9 "sync" 10 "time" 11 12 "github.com/volts-dev/template" 13 "github.com/volts-dev/utils" 14 "github.com/volts-dev/volts/codec" 15 "github.com/volts-dev/volts/logger" 16 "github.com/volts-dev/volts/registry" 17 "github.com/volts-dev/volts/transport" 18 ) 19 20 var log = logger.New("router") 21 22 type ( 23 // Router handle serving messages 24 IRouter interface { 25 Config() *Config // Retrieve the options 26 String() string 27 Handler() interface{} // 连接入口 serveHTTP 等接口实现 28 // Register endpoint in router 29 Register(ep *registry.Endpoint) error 30 // Deregister endpoint from router 31 Deregister(ep *registry.Endpoint) error 32 // list all endpiont from router 33 Endpoints() (services map[*TGroup][]*registry.Endpoint) 34 RegisterMiddleware(middlewares ...func(IRouter) IMiddleware) 35 RegisterGroup(groups ...IGroup) 36 PrintRoutes() 37 } 38 39 // router represents an RPC router. 40 TRouter struct { 41 sync.RWMutex 42 // router is a group set 43 TGroup 44 // 45 config *Config 46 // 中间件 47 middleware *TMiddlewareManager 48 template *template.TTemplateSet 49 respPool sync.Pool 50 httpCtxPool map[int]*sync.Pool //根据Route缓存 51 rpcCtxPool map[int]*sync.Pool 52 // compiled regexp for host and path 53 exit chan bool 54 } 55 ) 56 57 // TODO Validate validates an endpoint to guarantee it won't blow up when being served 58 func Validate(e *registry.Endpoint) error { 59 /* if e == nil { 60 return errors.New("endpoint is nil") 61 } 62 63 if len(e.Name) == 0 { 64 return errors.New("name required") 65 } 66 67 for _, p := range e.Path { 68 ps := p[0] 69 pe := p[len(p)-1] 70 71 if ps == '^' && pe == '$' { 72 _, err := regexp.CompilePOSIX(p) 73 if err != nil { 74 return err 75 } 76 } else if ps == '^' && pe != '$' { 77 return errors.New("invalid path") 78 } else if ps != '^' && pe == '$' { 79 return errors.New("invalid path") 80 } 81 } 82 83 if len(e.Handler) == 0 { 84 return errors.New("invalid handler") 85 } 86 */ 87 return nil 88 } 89 90 // 新建路由 91 // NOTE 路由不支持线程安全 不推荐多服务器调用! 92 func New(opts ...Option) *TRouter { 93 cfg := newConfig(opts...) 94 router := &TRouter{ 95 TGroup: *NewGroup(), 96 config: cfg, 97 middleware: newMiddlewareManager(), 98 httpCtxPool: make(map[int]*sync.Pool), 99 rpcCtxPool: make(map[int]*sync.Pool), 100 exit: make(chan bool), 101 } 102 cfg.Router = router 103 //~router.respPool.New = func() interface{} { 104 // return &transport.THttpResponse{} 105 //} 106 107 go router.watch() // 实时订阅 108 go router.refresh() // 定时刷新 109 110 return router 111 } 112 113 func (self *TRouter) PrintRoutes() { 114 if self.Config().RouterTreePrinter { 115 self.tree.PrintTrees() 116 } 117 } 118 119 func (self *TRouter) Config() *Config { 120 return self.config 121 } 122 123 func (self *TRouter) String() string { 124 return "volts-router" 125 } 126 127 func (self *TRouter) Handler() interface{} { 128 return self 129 } 130 131 // registry a endpoion to router 132 func (self *TRouter) Register(ep *registry.Endpoint) error { 133 if err := Validate(ep); err != nil { 134 return err 135 } 136 return self.tree.AddRoute(EndpiontToRoute(ep)) 137 } 138 139 func (self *TRouter) Deregister(ep *registry.Endpoint) error { 140 path := ep.Metadata["path"] 141 return self.tree.DelRoute(path, EndpiontToRoute(ep)) 142 } 143 144 func (self *TRouter) Endpoints() (services map[*TGroup][]*registry.Endpoint) { 145 // 注册订阅列表 146 var subscriberList []ISubscriber 147 148 for e := range self.subscribers { 149 // Only advertise non internal subscribers 150 if !e.Config().Internal { 151 subscriberList = append(subscriberList, e) 152 } 153 } 154 sort.Slice(subscriberList, func(i, j int) bool { 155 return subscriberList[i].Topic() > subscriberList[j].Topic() 156 }) 157 158 eps := self.tree.Endpoints() 159 for grp, endpoints := range eps { 160 for _, e := range subscriberList { 161 endpoints = append(endpoints, e.Endpoints()...) 162 } 163 eps[grp] = endpoints 164 } 165 166 return eps 167 } 168 169 // 注册中间件 170 func (self *TRouter) RegisterMiddleware(middlewares ...func(IRouter) IMiddleware) { 171 for _, creator := range middlewares { 172 // 新建中间件 173 middleware := creator(self) 174 if mm, ok := middleware.(IMiddlewareName); ok { 175 self.middleware.Add(mm.Name(), creator) 176 } else { 177 typ := reflect.TypeOf(middleware) 178 if typ.Kind() == reflect.Ptr { 179 typ = typ.Elem() 180 } 181 name := typ.String() 182 self.middleware.Add(name, creator) 183 } 184 } 185 } 186 187 // register module 188 func (self *TRouter) RegisterGroup(groups ...IGroup) { 189 for _, group := range groups { 190 group := group 191 self.tree.Conbine(group.GetRoutes()) 192 193 for sub, lst := range group.GetSubscribers() { 194 if rawLst, has := self.TGroup.subscribers[sub]; !has { 195 self.TGroup.subscribers[sub] = lst 196 } else { 197 /* 198 var news []broker.ISubscriber 199 // 排除重复的 200 for _, sb := range lst { 201 for _, s := range rawLst { 202 if s.Topic() == sb.Topic() { 203 goto next 204 } 205 } 206 news = append(news, sb) 207 next: 208 }*/ 209 210 self.TGroup.subscribers[sub] = append(rawLst, lst...) 211 } 212 } 213 } 214 } 215 216 func (self *TRouter) ServeHTTP(w http.ResponseWriter, r *transport.THttpRequest) { 217 // 使用defer保证错误也打印 218 if self.config.RequestPrinter { 219 defer func() { 220 log.Infof("[Path]%v", r.URL.Path) 221 }() 222 } 223 224 if r.Method == "CONNECT" { // serve as a raw network server 225 conn, _, err := w.(http.Hijacker).Hijack() 226 if err != nil { 227 log.Errf("rpc hijacking %v:%v", r.RemoteAddr, ": ", err.Error()) 228 } 229 io.WriteString(conn, "HTTP/1.0 200 Connected to RPC\n\n") 230 231 msg, err := transport.ReadMessage(conn) 232 if err != nil { 233 log.Errf("rpc Read %s", err.Error()) 234 } 235 sock := transport.NewTcpTransportSocket(conn, 0, 0) 236 req := transport.NewRpcRequest(r.Context(), msg, sock) 237 rsp := transport.NewRpcResponse(r.Context(), req, sock) 238 self.ServeRPC(rsp, req) 239 } else { // serve as a web server 240 // Pool 提供TResponseWriter 241 var rsp *transport.THttpResponse 242 if v := self.respPool.Get(); v == nil { 243 rsp = transport.NewHttpResponse(r.Context(), r) 244 } else { 245 rsp = v.(*transport.THttpResponse) 246 } 247 rsp.Connect(w) 248 249 //获得的地址 250 // # match route from tree 251 route, params := self.tree.Match(r.Method, r.URL.Path) 252 if route == nil { 253 rsp.WriteHeader(http.StatusNotFound) 254 return 255 } 256 257 // # get the new context from pool 258 p, has := self.httpCtxPool[route.Id] 259 if !has { 260 p = &sync.Pool{New: func() interface{} { 261 return NewHttpContext(self) 262 }} 263 264 self.httpCtxPool[route.Id] = p 265 } 266 267 ctx := p.Get().(*THttpContext) 268 if !ctx.inited { 269 ctx.router = self 270 ctx.route = *route // fixme 复制 271 ctx.inited = true 272 ctx.Template = template.Default() 273 } 274 275 ctx.reset(rsp, r) 276 ctx.setPathParams(params) 277 ctx.route = *route // TODO 优化重复使用 278 279 self.route(route, ctx) 280 281 // 结束Route并返回内容 282 ctx.Apply() 283 284 // 回收资源 285 p.Put(ctx) // Pool Handler 286 rsp.ResponseWriter = nil 287 self.respPool.Put(rsp) // Pool 回收TResponseWriter 288 } 289 } 290 291 func (self *TRouter) ServeRPC(w *transport.RpcResponse, r *transport.RpcRequest) { 292 reqMessage := r.Message // return the packet struct 293 // 心跳包 直接返回 294 if reqMessage.IsHeartbeat() { 295 reqMessage.SetMessageType(transport.MT_RESPONSE) 296 data := reqMessage.Encode() 297 w.Write(data) 298 return 299 } 300 301 //resMetadata := make(map[string]string) 302 //newCtx := context.WithValue(context.WithValue(ctx, share.ReqMetaDataKey, req.Metadata), 303 // share.ResMetaDataKey, resMetadata) 304 305 // ctx := context.WithValue(context.Background(), RemoteConnContextKey, conn) 306 //serviceName := reqMessage.Header["ServicePath"] 307 //methodName := msg.ServiceMethod 308 st := reqMessage.SerializeType() 309 //res := transport.GetMessageFromPool() 310 //res.SetMessageType(transport.MT_RESPONSE) 311 //res.SetSerializeType(st) 312 // 获取支持的序列模式 313 coder := codec.IdentifyCodec(st) 314 if coder == nil { 315 w.WriteHeader(transport.StatusForbidden) 316 //w.Write([]byte("can not find codec for " + st.String())) 317 w.Write([]byte{}) 318 return 319 } 320 321 route, params := self.tree.Match("CONNECT", reqMessage.Path) // 匹配路由树 322 if route == nil { 323 w.WriteHeader(transport.StatusNotFound) 324 w.Write([]byte{}) 325 //w.Write([]byte("rpc: can't match route " + serviceName)) 326 return 327 } else { 328 // 初始化Context 329 p, has := self.rpcCtxPool[route.Id] 330 if !has { // TODO 优化 331 p = &sync.Pool{New: func() interface{} { 332 return NewRpcHandler(self) 333 }} 334 self.rpcCtxPool[route.Id] = p 335 } 336 337 ctx := p.Get().(*TRpcContext) 338 if !ctx.inited { 339 ctx.router = self 340 ctx.route = *route 341 ctx.inited = true 342 } 343 ctx.reset(w, r, self, route) 344 ctx.setPathParams(params) 345 346 // 执行控制器 347 self.route(route, ctx) 348 } 349 350 // 返回数据 351 if !reqMessage.IsOneway() { 352 // 序列化数据 353 /* remove 已经交由Body response处理 354 data, err := coder.Encode(ctx.replyv.Interface()) 355 //argsReplyPools.Put(mtype.ReplyType, replyv) 356 if err != nil { 357 handleError(res, err) 358 return 359 } 360 res.Payload = data 361 362 if len(resMetadata) > 0 { //copy meta in context to request 363 meta := res.Header 364 if meta == nil { 365 res.Header = resMetadata 366 } else { 367 for k, v := range resMetadata { 368 meta[k] = v 369 } 370 } 371 } 372 373 err = w.Write(res.Payload) 374 if err != nil { 375 log.Dbg(err.Error()) 376 } 377 log.Dbg("aa", string(res.Payload))*/ 378 } 379 380 return 381 } 382 383 func (self *TRouter) route(route *route, ctx IContext) { 384 if self.config.Recover { 385 defer func() { 386 if err := recover(); err != nil { 387 log.Err(err) 388 389 if self.config.RecoverHandler != nil { 390 self.config.RecoverHandler(ctx) 391 } 392 } 393 }() 394 } 395 396 // TODO:将所有需要执行的Handler 存疑列表或者树-Node保存函数和参数 397 for _, handler := range route.handlers { 398 // TODO 回收需要特殊通道 直接调用占用了处理时间 399 handler.init(self).Invoke(ctx).recycle() 400 } 401 } 402 403 func (self *TRouter) isClosed() bool { 404 select { 405 case <-self.exit: 406 return true 407 default: 408 return false 409 } 410 } 411 412 // 过滤自己 413 func (self *TRouter) filteSelf(service *registry.Service) *registry.Service { 414 localServices := self.config.Registry.LocalServices() 415 416 nodes := make([]*registry.Node, 0) 417 for _, n := range service.Nodes { 418 // TODO 解决监控拉取registry服务器服务列表比LocalServices获取注册的本地早 419 if len(localServices) == 0 && self.tree.Count.Load() > 0 { 420 break 421 } 422 423 for _, curSrv := range localServices { 424 node := curSrv.Nodes[0] 425 host, port, err := net.SplitHostPort(node.Address) 426 if err != nil { 427 log.Err(err) 428 } 429 430 if n.Id == node.Id { 431 goto out 432 } 433 434 h, p, err := net.SplitHostPort(n.Address) 435 if err != nil { 436 log.Err(err) 437 } 438 439 // 同个服务器 440 if host == h && port == p { 441 goto out 442 } 443 444 } 445 nodes = append(nodes, n) 446 out: 447 } 448 449 service.Nodes = nodes 450 /* 451 for _, curSrv := range localServices { 452 if curSrv != nil && service.Name == curSrv.Name { 453 node := curSrv.Nodes[0] 454 host, port, err := net.SplitHostPort(node.Address) 455 if err != nil { 456 log.Err(err) 457 } 458 459 nodes := make([]*registry.Node, 0) 460 var node *registry.Node 461 for _, n := range service.Nodes { 462 if n.Id == node.Id { 463 continue 464 } 465 466 h, p, err := net.SplitHostPort(n.Address) 467 if err != nil { 468 log.Err(err) 469 } 470 471 // 同个服务器 472 if host == h && port == p { 473 continue 474 } 475 476 nodes = append(nodes, n) 477 } 478 service.Nodes = nodes 479 } 480 } 481 */ 482 return service 483 } 484 485 // store local endpoint 486 func (self *TRouter) store(services []*registry.Service) { 487 // create a new endpoint mapping 488 for _, service := range services { 489 service = self.filteSelf(service) 490 if len(service.Nodes) > 0 { 491 // map per endpoint 492 for _, sep := range service.Endpoints { 493 url := &TUrl{ 494 Path: sep.Path, 495 } 496 r := EndpiontToRoute(sep) 497 if utils.InStrings("CONNECT", sep.Method...) != -1 { 498 r.handlers = append(r.handlers, generateHandler(ProxyHandler, RpcHandler, []interface{}{RpcReverseProxy}, nil, url, []*registry.Service{service})) 499 } else { 500 r.handlers = append(r.handlers, generateHandler(ProxyHandler, HttpHandler, []interface{}{HttpReverseProxy}, nil, url, []*registry.Service{service})) 501 } 502 503 err := self.tree.AddRoute(r) 504 if err != nil { 505 log.Err(err) 506 } 507 } 508 } 509 } 510 } 511 512 // watch for endpoint changes 513 func (self *TRouter) watch() { 514 var attempts int 515 516 // 5秒后才启动监测 517 time.Sleep(5 * time.Second) 518 519 for { 520 if self.isClosed() { 521 break 522 } 523 524 // watch for changes 525 w, err := self.config.Registry.Watcher() 526 if err != nil { 527 attempts++ 528 log.Errf("error watching endpoints: %v", err) 529 //time.Sleep(time.Duration(attempts) * time.Second) 530 continue 531 } 532 533 // 无监视者等待 534 if w == nil { 535 time.Sleep(60 * time.Second) 536 continue 537 } 538 539 ch := make(chan bool) 540 541 go func() { 542 select { 543 case <-ch: 544 w.Stop() 545 case <-self.exit: 546 w.Stop() 547 } 548 }() 549 550 // reset if we get here 551 attempts = 0 552 553 for { 554 // process next event 555 res, err := w.Next() 556 if err != nil { 557 log.Errf("error getting next endoint: %v", err) 558 close(ch) 559 break 560 } 561 562 // skip these things 563 if res == nil || res.Service == nil { 564 break 565 } 566 567 // get entry from cache 568 services, err := self.config.RegistryCacher.GetService(res.Service.Name) 569 if err != nil { 570 log.Errf("unable to get service: %v", err) 571 break 572 } 573 574 // update our local endpoints 575 self.store(services) 576 } 577 } 578 } 579 580 // refresh list of api services 581 func (self *TRouter) refresh() { 582 // 5秒后才启动监测 583 time.Sleep(5 * time.Second) 584 585 var ( 586 err error 587 services []*registry.Service 588 list []*registry.Service 589 attempts int 590 ) 591 592 for { 593 list, err = self.config.Registry.ListServices() 594 if err != nil { 595 attempts++ 596 log.Warnf("registry unable to list services: %v", err) 597 time.Sleep(time.Duration(attempts) * time.Second) 598 continue 599 } 600 // 无监视者等待 601 if len(list) == 0 { 602 time.Sleep(60 * time.Second) 603 continue 604 } 605 606 attempts = 0 607 608 // for each service, get service and store endpoints 609 for _, s := range list { 610 for _, local := range self.config.Registry.LocalServices() { 611 if local.Equal(s) { 612 // 不添加自己 613 goto out 614 } 615 616 } 617 618 services, err = self.config.RegistryCacher.GetService(s.Name) 619 if err != nil { 620 log.Errf("unable to get service: %v", err) 621 continue 622 } 623 self.store(services) 624 625 out: 626 } 627 628 // refresh list in 10 minutes... cruft 629 // use registry watching 630 select { 631 case <-time.After(time.Minute * 10): 632 case <-self.exit: 633 return 634 } 635 } 636 }