github.com/godaddy-x/freego@v1.0.156/node/node_websocket.go (about) 1 package node 2 3 import ( 4 rate "github.com/godaddy-x/freego/cache/limiter" 5 "github.com/godaddy-x/freego/ex" 6 "github.com/godaddy-x/freego/utils" 7 "github.com/godaddy-x/freego/utils/jwt" 8 "github.com/godaddy-x/freego/zlog" 9 "golang.org/x/net/websocket" 10 "net/http" 11 "sync" 12 "time" 13 ) 14 15 type ConnPool map[string]map[string]*DevConn 16 17 const ( 18 pingCmd = "ws-health-check" 19 ) 20 21 type Handle func(*Context, []byte) (interface{}, error) // 如响应数据为nil则不回复 22 23 type WsServer struct { 24 Debug bool 25 HookNode 26 mu sync.RWMutex 27 pool ConnPool 28 ping int // 长连接心跳间隔 29 max int // 连接池总数量 30 limiter *rate.Limiter // 每秒限定连接数量 31 } 32 33 type DevConn struct { 34 Sub string 35 Dev string 36 Life int64 37 Last int64 38 Ctx *Context 39 Conn *websocket.Conn 40 } 41 42 func (self *WsServer) readyContext() { 43 self.mu.Lock() 44 defer self.mu.Unlock() 45 if self.Context == nil { 46 self.Context = &Context{} 47 self.Context.configs = &Configs{} 48 self.Context.configs.routerConfigs = make(map[string]*RouterConfig) 49 self.Context.configs.langConfigs = make(map[string]map[string]string) 50 self.Context.configs.jwtConfig = jwt.JwtConfig{} 51 self.Context.System = &System{} 52 } 53 } 54 55 func (self *WsServer) checkContextReady(path string, routerConfig *RouterConfig) { 56 self.readyContext() 57 self.addRouterConfig(path, routerConfig) 58 } 59 60 func (self *WsServer) AddJwtConfig(config jwt.JwtConfig) { 61 self.readyContext() 62 if len(config.TokenKey) == 0 { 63 panic("jwt config key is nil") 64 } 65 if config.TokenExp < 0 { 66 panic("jwt config exp invalid") 67 } 68 self.Context.configs.jwtConfig.TokenAlg = config.TokenAlg 69 self.Context.configs.jwtConfig.TokenTyp = config.TokenTyp 70 self.Context.configs.jwtConfig.TokenKey = config.TokenKey 71 self.Context.configs.jwtConfig.TokenExp = config.TokenExp 72 } 73 74 func (self *WsServer) addRouterConfig(path string, routerConfig *RouterConfig) { 75 if routerConfig == nil { 76 routerConfig = &RouterConfig{} 77 } 78 if _, b := self.Context.configs.routerConfigs[path]; !b { 79 self.Context.configs.routerConfigs[path] = routerConfig 80 } 81 } 82 83 func (self *Context) readWsToken(auth string) error { 84 self.Subject.ResetTokenBytes(utils.Str2Bytes(auth)) 85 return nil 86 } 87 88 func (self *Context) readWsBody(body []byte) error { 89 if body == nil || len(body) == 0 { 90 return ex.Throw{Code: http.StatusBadRequest, Msg: "body parameters is nil"} 91 } 92 if len(body) > (MAX_VALUE_LEN) { 93 return ex.Throw{Code: http.StatusLengthRequired, Msg: "body parameters length is too long"} 94 } 95 self.JsonBody.Data = utils.GetJsonString(body, "d") 96 self.JsonBody.Time = utils.GetJsonInt64(body, "t") 97 self.JsonBody.Nonce = utils.GetJsonString(body, "n") 98 self.JsonBody.Plan = utils.GetJsonInt64(body, "p") 99 self.JsonBody.Sign = utils.GetJsonString(body, "s") 100 if err := self.validJsonBody(); err != nil { // TODO important 101 return err 102 } 103 return nil 104 } 105 106 func (ctx *Context) writeError(ws *websocket.Conn, err error) error { 107 if err == nil { 108 return nil 109 } 110 out := ex.Catch(err) 111 if ctx.errorHandle != nil { 112 throw, ok := err.(ex.Throw) 113 if !ok { 114 throw = ex.Throw{Code: out.Code, Msg: out.Msg, Err: err, Arg: out.Arg} 115 } 116 if err = ctx.errorHandle(ctx, throw); err != nil { 117 zlog.Error("response error handle failed", 0, zlog.AddError(err)) 118 } 119 } 120 resp := &JsonResp{ 121 Code: out.Code, 122 Message: out.Msg, 123 Time: utils.UnixMilli(), 124 } 125 if !ctx.Authenticated() { 126 resp.Nonce = utils.RandNonce() 127 } else { 128 if ctx.JsonBody == nil || len(ctx.JsonBody.Nonce) == 0 { 129 resp.Nonce = utils.RandNonce() 130 } else { 131 resp.Nonce = ctx.JsonBody.Nonce 132 } 133 } 134 if ctx.RouterConfig.Guest { 135 if out.Code <= 600 { 136 ctx.Response.StatusCode = out.Code 137 } 138 return nil 139 } 140 if resp.Code == 0 { 141 resp.Code = ex.BIZ 142 } 143 result, _ := utils.JsonMarshal(resp) 144 if err := websocket.Message.Send(ws, result); err != nil { 145 zlog.Error("websocket send error", 0, zlog.AddError(err)) 146 } 147 return nil 148 } 149 150 func closeConn(msg string, object *DevConn) { 151 defer func() { 152 if err := recover(); err != nil { 153 zlog.Error("ws close panic error", 0, zlog.String("sub", object.Sub), zlog.String("dev", object.Dev), zlog.Any("error", err)) 154 } 155 }() 156 if object.Conn != nil { 157 if err := object.Conn.Close(); err != nil { 158 zlog.Error("ws close error", 0, zlog.String("msg", msg), zlog.String("sub", object.Sub), zlog.String("dev", object.Dev), zlog.AddError(err)) 159 } 160 } 161 } 162 163 func createCtx(self *WsServer, path string) *Context { 164 ctx := self.Context 165 ctxNew := &Context{} 166 ctxNew.configs = self.Context.configs 167 ctxNew.filterChain = &filterChain{} 168 ctxNew.System = &System{} 169 ctxNew.JsonBody = &JsonBody{} 170 ctxNew.Subject = &jwt.Subject{Header: &jwt.Header{}, Payload: &jwt.Payload{}} 171 ctxNew.Response = &Response{Encoding: UTF8, ContentType: APPLICATION_JSON, ContentEntity: nil} 172 ctxNew.Storage = map[string]interface{}{} 173 if ctxNew.CacheAware == nil { 174 ctxNew.CacheAware = ctx.CacheAware 175 } 176 if ctxNew.RSA == nil { 177 ctxNew.RSA = ctx.RSA 178 } 179 if ctxNew.roleRealm == nil { 180 ctxNew.roleRealm = ctx.roleRealm 181 } 182 if ctxNew.errorHandle == nil { 183 ctxNew.errorHandle = ctx.errorHandle 184 } 185 ctxNew.System = ctx.System 186 //ctxNew.postHandle = handle 187 //ctxNew.RequestCtx = request 188 //ctxNew.Method = utils.Bytes2Str(self.RequestCtx.Method()) 189 ctxNew.Path = path 190 ctxNew.RouterConfig = ctx.configs.routerConfigs[ctxNew.Path] 191 ctxNew.postCompleted = false 192 ctxNew.filterChain.pos = 0 193 return ctxNew 194 } 195 196 func wsRenderTo(ws *websocket.Conn, ctx *Context, data interface{}) error { 197 if data == nil { 198 return nil 199 } 200 routerConfig, _ := ctx.configs.routerConfigs[ctx.Path] 201 data, err := authReq(ctx.Path, data, ctx.GetTokenSecret(), routerConfig.AesResponse) 202 if err != nil { 203 return err 204 } 205 if err := websocket.Message.Send(ws, data); err != nil { 206 return ex.Throw{Code: ex.WS_SEND, Msg: "websocket send error", Err: err} 207 } 208 return nil 209 } 210 211 func validBody(ws *websocket.Conn, ctx *Context, body []byte) bool { 212 if body == nil || len(body) == 0 { 213 _ = ctx.writeError(ws, ex.Throw{Code: http.StatusBadRequest, Msg: "body parameters is nil"}) 214 return false 215 } 216 if len(body) > (MAX_VALUE_LEN) { 217 _ = ctx.writeError(ws, ex.Throw{Code: http.StatusLengthRequired, Msg: "body parameters length is too long"}) 218 return false 219 } 220 ctx.JsonBody.Data = utils.GetJsonString(body, "d") 221 ctx.JsonBody.Time = utils.GetJsonInt64(body, "t") 222 ctx.JsonBody.Nonce = utils.GetJsonString(body, "n") 223 ctx.JsonBody.Plan = utils.GetJsonInt64(body, "p") 224 ctx.JsonBody.Sign = utils.GetJsonString(body, "s") 225 if err := ctx.validJsonBody(); err != nil { // TODO important 226 _ = ctx.writeError(ws, err) 227 return false 228 } 229 return true 230 } 231 232 func (self *WsServer) SendMessage(data interface{}, subject string, dev ...string) error { 233 conn, b := self.pool[subject] 234 if !b || len(conn) == 0 { 235 return nil 236 } 237 for _, v := range conn { 238 if len(dev) > 0 && !utils.CheckStr(v.Dev, dev...) { 239 continue 240 } 241 if err := wsRenderTo(v.Conn, v.Ctx, data); err != nil { 242 return err 243 } 244 } 245 return nil 246 } 247 248 func (self *WsServer) addConn(conn *websocket.Conn, ctx *Context) error { 249 self.mu.Lock() 250 defer self.mu.Unlock() 251 sub := ctx.Subject.GetSub() 252 dev := ctx.Subject.GetDev() 253 exp := ctx.Subject.GetExp() 254 if len(dev) == 0 { 255 dev = "web" 256 } 257 258 zlog.Info("websocket client connect success", 0, zlog.String("subject", sub), zlog.String("path", ctx.Path), zlog.String("dev", dev)) 259 260 key := utils.AddStr(dev, "_", ctx.Path) 261 if self.pool == nil { 262 self.pool = make(ConnPool, 50) 263 } 264 265 if len(self.pool) >= self.max { 266 closeConn("add conn max pool close", &DevConn{Conn: conn, Dev: ctx.Subject.GetDev(), Sub: ctx.Subject.GetSub()}) 267 return utils.Error("conn pool full: ", len(self.pool)) 268 } 269 270 check, b := self.pool[sub] 271 if !b { 272 value := make(map[string]*DevConn, 2) 273 value[key] = &DevConn{Life: exp, Last: utils.UnixSecond(), Sub: sub, Dev: dev, Ctx: ctx, Conn: conn} 274 self.pool[sub] = value 275 return nil 276 } 277 devConn, b := check[key] 278 if b { 279 closeConn("add conn replace close", devConn) // 如果存在连接对象则先关闭 280 } 281 if devConn == nil { 282 check[key] = &DevConn{Life: exp, Last: utils.UnixSecond(), Sub: sub, Dev: dev, Ctx: ctx, Conn: conn} 283 return nil 284 } 285 devConn.Sub = sub 286 devConn.Life = exp 287 devConn.Dev = dev 288 devConn.Last = utils.UnixSecond() 289 devConn.Ctx = ctx 290 devConn.Conn = conn 291 //check[key] = &DevConn{Life: exp, Last: utils.UnixSecond(), Dev: dev, Ctx: ctx, Conn: conn} 292 return nil 293 } 294 295 func (self *WsServer) refConn(ctx *Context) error { 296 self.mu.Lock() 297 defer self.mu.Unlock() 298 sub := ctx.Subject.Payload.Sub 299 dev := ctx.Subject.GetDev() 300 if len(dev) == 0 { 301 dev = "web" 302 } 303 dev = utils.AddStr(dev, "_", ctx.Path) 304 if self.pool == nil { 305 return nil 306 } 307 308 check, b := self.pool[sub] 309 if !b { 310 return nil 311 } 312 devConn, b := check[dev] 313 if !b { 314 return nil 315 } 316 devConn.Last = utils.UnixSecond() 317 return nil 318 } 319 320 func (self *WsServer) NewPool(maxConn, limit, bucket, ping int) { 321 if maxConn <= 0 { 322 panic("maxConn is nil") 323 } 324 if limit <= 0 { 325 panic("limit is nil") 326 } 327 if bucket <= 0 { 328 panic("bucket is nil") 329 } 330 if ping <= 0 { 331 panic("ping is nil") 332 } 333 self.mu.Lock() 334 defer self.mu.Unlock() 335 if self.pool == nil { 336 self.pool = make(ConnPool, maxConn) 337 } 338 self.max = maxConn 339 self.ping = ping 340 341 // 设置每秒放入100个令牌,并允许最大突发50个令牌 342 self.limiter = rate.NewLimiter(rate.Limit(limit), bucket) 343 } 344 345 func (self *WsServer) withConnectionLimit(handler websocket.Handler) http.Handler { 346 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 347 if !self.limiter.Allow() { 348 http.Error(w, "limited access", http.StatusServiceUnavailable) 349 return 350 } 351 handler.ServeHTTP(w, r) 352 }) 353 } 354 355 func (self *WsServer) wsHandler(path string, handle Handle) websocket.Handler { 356 return func(ws *websocket.Conn) { 357 358 devConn := &DevConn{Conn: ws} 359 360 defer closeConn("handler close", devConn) 361 362 ctx := createCtx(self, path) 363 _ = ctx.readWsToken(ws.Request().Header.Get("Authorization")) 364 365 if len(ctx.Subject.GetRawBytes()) == 0 { 366 _ = ctx.writeError(ws, ex.Throw{Code: http.StatusUnauthorized, Msg: "token is nil"}) 367 return 368 } 369 if err := ctx.Subject.Verify(utils.Bytes2Str(ctx.Subject.GetRawBytes()), ctx.GetJwtConfig().TokenKey, true); err != nil { 370 _ = ctx.writeError(ws, ex.Throw{Code: http.StatusUnauthorized, Msg: "token invalid or expired", Err: err}) 371 return 372 } 373 374 devConn.Sub = ctx.Subject.GetSub() 375 devConn.Dev = ctx.Subject.GetDev() 376 377 if err := self.addConn(ws, ctx); err != nil { 378 zlog.Error("add conn error", 0, zlog.String("sub", devConn.Sub), zlog.String("dev", devConn.Dev), zlog.AddError(err)) 379 return 380 } 381 382 for { 383 // 读取消息 384 var body []byte 385 err := websocket.Message.Receive(ws, &body) 386 if err != nil { 387 zlog.Error("receive message error", 0, zlog.String("sub", devConn.Sub), zlog.String("dev", devConn.Dev), zlog.AddError(err)) 388 break 389 } 390 391 if self.Debug { 392 zlog.Info("websocket receive message", 0, zlog.String("sub", devConn.Sub), zlog.String("dev", devConn.Dev), zlog.String("data", string(body))) 393 } 394 395 if !validBody(ws, ctx, body) { 396 if self.Debug { 397 zlog.Info("websocket receive message invalid", 0, zlog.String("sub", devConn.Sub), zlog.String("dev", devConn.Dev), zlog.String("data", string(body))) 398 } 399 continue 400 } 401 402 dec, b := ctx.JsonBody.Data.([]byte) 403 404 if b && utils.GetJsonString(dec, "healthCheck") == pingCmd { 405 _ = self.refConn(ctx) 406 continue 407 } 408 409 reply, err := handle(ctx, dec) 410 if err != nil { 411 _ = ctx.writeError(ws, err) 412 continue 413 } 414 415 if self.Debug && reply != nil { 416 zlog.Info("websocket reply message", 0, zlog.String("sub", devConn.Sub), zlog.String("dev", devConn.Dev), zlog.Any("data", reply)) 417 } 418 419 // 回复消息 420 if err := wsRenderTo(ws, ctx, reply); err != nil { 421 zlog.Error("receive message reply error", 0, zlog.String("sub", devConn.Sub), zlog.String("dev", devConn.Dev), zlog.AddError(err)) 422 break 423 } 424 425 } 426 } 427 } 428 429 func (self *WsServer) AddRouter(path string, handle Handle, routerConfig *RouterConfig) { 430 if handle == nil { 431 panic("handle function is nil") 432 } 433 434 self.checkContextReady(path, routerConfig) 435 436 http.Handle(path, self.withConnectionLimit(self.wsHandler(path, handle))) 437 } 438 439 func (self *WsServer) StartWebsocket(addr string) { 440 go func() { 441 for { 442 time.Sleep(time.Duration(self.ping) * time.Second) 443 s := utils.UnixMilli() 444 index := 0 445 current := utils.UnixSecond() 446 for _, v := range self.pool { 447 for k1, v1 := range v { 448 if current-v1.Last > int64(self.ping*2) || current > v1.Life { 449 self.mu.Lock() 450 closeConn("check life close", v1) 451 delete(v, k1) 452 self.mu.Unlock() 453 } 454 index++ 455 } 456 } 457 if self.Debug { 458 zlog.Info("websocket check pool", 0, zlog.String("cost", utils.AddStr(utils.UnixMilli()-s, " ms")), zlog.Int("count", index)) 459 } 460 } 461 }() 462 go func() { 463 zlog.Printf("websocket【%s】service has been started successful", addr) 464 if err := http.Serve(NewGracefulListener(addr, time.Second*10), nil); err != nil { 465 panic(err) 466 } 467 }() 468 select {} 469 }