github.com/godaddy-x/freego@v1.0.156/node/node_websocket.go (about)

     1  package node
     2  
     3  import (
     4  	rate "github.com/godaddy-x/freego/cache/limiter"
     5  	"github.com/godaddy-x/freego/ex"
     6  	"github.com/godaddy-x/freego/utils"
     7  	"github.com/godaddy-x/freego/utils/jwt"
     8  	"github.com/godaddy-x/freego/zlog"
     9  	"golang.org/x/net/websocket"
    10  	"net/http"
    11  	"sync"
    12  	"time"
    13  )
    14  
// ConnPool holds the live websocket connections.
// The outer key is the token subject; the inner key is "<dev>_<path>"
// as built by addConn, so each subject has at most one connection per
// device/route pair.
type ConnPool map[string]map[string]*DevConn
    16  
const (
	// pingCmd is the value of the "healthCheck" field that marks an incoming
	// message as a heartbeat; such messages only refresh the connection's
	// last-seen time and are not passed to the route handler.
	pingCmd = "ws-health-check"
)
    20  
// Handle processes one decoded websocket message for a route.
// If the returned data is nil, no reply is sent to the client.
type Handle func(*Context, []byte) (interface{}, error)
    22  
// WsServer is a websocket server that keeps authenticated connections in a
// pool keyed by token subject and device.
type WsServer struct {
	// Debug enables verbose logging of received/replied messages and pool sweeps.
	Debug bool
	HookNode
	mu      sync.RWMutex  // guards pool and the lazily created base Context
	pool    ConnPool      // live connections, see ConnPool
	ping    int           // heartbeat interval for long-lived connections, in seconds
	max     int           // maximum size of the connection pool
	limiter *rate.Limiter // limits how many connections are accepted per second
}
    32  
// DevConn is one pooled websocket connection together with its bookkeeping.
type DevConn struct {
	Sub  string          // token subject that owns the connection
	Dev  string          // device name from the token ("web" when the token has none)
	Life int64           // token expiry taken from Subject.GetExp(); compared against utils.UnixSecond()
	Last int64           // utils.UnixSecond() of the last add/heartbeat
	Ctx  *Context        // per-connection context created by createCtx
	Conn *websocket.Conn // underlying websocket connection
}
    41  
    42  func (self *WsServer) readyContext() {
    43  	self.mu.Lock()
    44  	defer self.mu.Unlock()
    45  	if self.Context == nil {
    46  		self.Context = &Context{}
    47  		self.Context.configs = &Configs{}
    48  		self.Context.configs.routerConfigs = make(map[string]*RouterConfig)
    49  		self.Context.configs.langConfigs = make(map[string]map[string]string)
    50  		self.Context.configs.jwtConfig = jwt.JwtConfig{}
    51  		self.Context.System = &System{}
    52  	}
    53  }
    54  
// checkContextReady makes sure the base Context exists and then registers
// routerConfig for path (an already registered path is left untouched, see
// addRouterConfig).
func (self *WsServer) checkContextReady(path string, routerConfig *RouterConfig) {
	self.readyContext()
	self.addRouterConfig(path, routerConfig)
}
    59  
    60  func (self *WsServer) AddJwtConfig(config jwt.JwtConfig) {
    61  	self.readyContext()
    62  	if len(config.TokenKey) == 0 {
    63  		panic("jwt config key is nil")
    64  	}
    65  	if config.TokenExp < 0 {
    66  		panic("jwt config exp invalid")
    67  	}
    68  	self.Context.configs.jwtConfig.TokenAlg = config.TokenAlg
    69  	self.Context.configs.jwtConfig.TokenTyp = config.TokenTyp
    70  	self.Context.configs.jwtConfig.TokenKey = config.TokenKey
    71  	self.Context.configs.jwtConfig.TokenExp = config.TokenExp
    72  }
    73  
    74  func (self *WsServer) addRouterConfig(path string, routerConfig *RouterConfig) {
    75  	if routerConfig == nil {
    76  		routerConfig = &RouterConfig{}
    77  	}
    78  	if _, b := self.Context.configs.routerConfigs[path]; !b {
    79  		self.Context.configs.routerConfigs[path] = routerConfig
    80  	}
    81  }
    82  
    83  func (self *Context) readWsToken(auth string) error {
    84  	self.Subject.ResetTokenBytes(utils.Str2Bytes(auth))
    85  	return nil
    86  }
    87  
    88  func (self *Context) readWsBody(body []byte) error {
    89  	if body == nil || len(body) == 0 {
    90  		return ex.Throw{Code: http.StatusBadRequest, Msg: "body parameters is nil"}
    91  	}
    92  	if len(body) > (MAX_VALUE_LEN) {
    93  		return ex.Throw{Code: http.StatusLengthRequired, Msg: "body parameters length is too long"}
    94  	}
    95  	self.JsonBody.Data = utils.GetJsonString(body, "d")
    96  	self.JsonBody.Time = utils.GetJsonInt64(body, "t")
    97  	self.JsonBody.Nonce = utils.GetJsonString(body, "n")
    98  	self.JsonBody.Plan = utils.GetJsonInt64(body, "p")
    99  	self.JsonBody.Sign = utils.GetJsonString(body, "s")
   100  	if err := self.validJsonBody(); err != nil { // TODO important
   101  		return err
   102  	}
   103  	return nil
   104  }
   105  
   106  func (ctx *Context) writeError(ws *websocket.Conn, err error) error {
   107  	if err == nil {
   108  		return nil
   109  	}
   110  	out := ex.Catch(err)
   111  	if ctx.errorHandle != nil {
   112  		throw, ok := err.(ex.Throw)
   113  		if !ok {
   114  			throw = ex.Throw{Code: out.Code, Msg: out.Msg, Err: err, Arg: out.Arg}
   115  		}
   116  		if err = ctx.errorHandle(ctx, throw); err != nil {
   117  			zlog.Error("response error handle failed", 0, zlog.AddError(err))
   118  		}
   119  	}
   120  	resp := &JsonResp{
   121  		Code:    out.Code,
   122  		Message: out.Msg,
   123  		Time:    utils.UnixMilli(),
   124  	}
   125  	if !ctx.Authenticated() {
   126  		resp.Nonce = utils.RandNonce()
   127  	} else {
   128  		if ctx.JsonBody == nil || len(ctx.JsonBody.Nonce) == 0 {
   129  			resp.Nonce = utils.RandNonce()
   130  		} else {
   131  			resp.Nonce = ctx.JsonBody.Nonce
   132  		}
   133  	}
   134  	if ctx.RouterConfig.Guest {
   135  		if out.Code <= 600 {
   136  			ctx.Response.StatusCode = out.Code
   137  		}
   138  		return nil
   139  	}
   140  	if resp.Code == 0 {
   141  		resp.Code = ex.BIZ
   142  	}
   143  	result, _ := utils.JsonMarshal(resp)
   144  	if err := websocket.Message.Send(ws, result); err != nil {
   145  		zlog.Error("websocket send error", 0, zlog.AddError(err))
   146  	}
   147  	return nil
   148  }
   149  
   150  func closeConn(msg string, object *DevConn) {
   151  	defer func() {
   152  		if err := recover(); err != nil {
   153  			zlog.Error("ws close panic error", 0, zlog.String("sub", object.Sub), zlog.String("dev", object.Dev), zlog.Any("error", err))
   154  		}
   155  	}()
   156  	if object.Conn != nil {
   157  		if err := object.Conn.Close(); err != nil {
   158  			zlog.Error("ws close error", 0, zlog.String("msg", msg), zlog.String("sub", object.Sub), zlog.String("dev", object.Dev), zlog.AddError(err))
   159  		}
   160  	}
   161  }
   162  
   163  func createCtx(self *WsServer, path string) *Context {
   164  	ctx := self.Context
   165  	ctxNew := &Context{}
   166  	ctxNew.configs = self.Context.configs
   167  	ctxNew.filterChain = &filterChain{}
   168  	ctxNew.System = &System{}
   169  	ctxNew.JsonBody = &JsonBody{}
   170  	ctxNew.Subject = &jwt.Subject{Header: &jwt.Header{}, Payload: &jwt.Payload{}}
   171  	ctxNew.Response = &Response{Encoding: UTF8, ContentType: APPLICATION_JSON, ContentEntity: nil}
   172  	ctxNew.Storage = map[string]interface{}{}
   173  	if ctxNew.CacheAware == nil {
   174  		ctxNew.CacheAware = ctx.CacheAware
   175  	}
   176  	if ctxNew.RSA == nil {
   177  		ctxNew.RSA = ctx.RSA
   178  	}
   179  	if ctxNew.roleRealm == nil {
   180  		ctxNew.roleRealm = ctx.roleRealm
   181  	}
   182  	if ctxNew.errorHandle == nil {
   183  		ctxNew.errorHandle = ctx.errorHandle
   184  	}
   185  	ctxNew.System = ctx.System
   186  	//ctxNew.postHandle = handle
   187  	//ctxNew.RequestCtx = request
   188  	//ctxNew.Method = utils.Bytes2Str(self.RequestCtx.Method())
   189  	ctxNew.Path = path
   190  	ctxNew.RouterConfig = ctx.configs.routerConfigs[ctxNew.Path]
   191  	ctxNew.postCompleted = false
   192  	ctxNew.filterChain.pos = 0
   193  	return ctxNew
   194  }
   195  
   196  func wsRenderTo(ws *websocket.Conn, ctx *Context, data interface{}) error {
   197  	if data == nil {
   198  		return nil
   199  	}
   200  	routerConfig, _ := ctx.configs.routerConfigs[ctx.Path]
   201  	data, err := authReq(ctx.Path, data, ctx.GetTokenSecret(), routerConfig.AesResponse)
   202  	if err != nil {
   203  		return err
   204  	}
   205  	if err := websocket.Message.Send(ws, data); err != nil {
   206  		return ex.Throw{Code: ex.WS_SEND, Msg: "websocket send error", Err: err}
   207  	}
   208  	return nil
   209  }
   210  
   211  func validBody(ws *websocket.Conn, ctx *Context, body []byte) bool {
   212  	if body == nil || len(body) == 0 {
   213  		_ = ctx.writeError(ws, ex.Throw{Code: http.StatusBadRequest, Msg: "body parameters is nil"})
   214  		return false
   215  	}
   216  	if len(body) > (MAX_VALUE_LEN) {
   217  		_ = ctx.writeError(ws, ex.Throw{Code: http.StatusLengthRequired, Msg: "body parameters length is too long"})
   218  		return false
   219  	}
   220  	ctx.JsonBody.Data = utils.GetJsonString(body, "d")
   221  	ctx.JsonBody.Time = utils.GetJsonInt64(body, "t")
   222  	ctx.JsonBody.Nonce = utils.GetJsonString(body, "n")
   223  	ctx.JsonBody.Plan = utils.GetJsonInt64(body, "p")
   224  	ctx.JsonBody.Sign = utils.GetJsonString(body, "s")
   225  	if err := ctx.validJsonBody(); err != nil { // TODO important
   226  		_ = ctx.writeError(ws, err)
   227  		return false
   228  	}
   229  	return true
   230  }
   231  
   232  func (self *WsServer) SendMessage(data interface{}, subject string, dev ...string) error {
   233  	conn, b := self.pool[subject]
   234  	if !b || len(conn) == 0 {
   235  		return nil
   236  	}
   237  	for _, v := range conn {
   238  		if len(dev) > 0 && !utils.CheckStr(v.Dev, dev...) {
   239  			continue
   240  		}
   241  		if err := wsRenderTo(v.Conn, v.Ctx, data); err != nil {
   242  			return err
   243  		}
   244  	}
   245  	return nil
   246  }
   247  
   248  func (self *WsServer) addConn(conn *websocket.Conn, ctx *Context) error {
   249  	self.mu.Lock()
   250  	defer self.mu.Unlock()
   251  	sub := ctx.Subject.GetSub()
   252  	dev := ctx.Subject.GetDev()
   253  	exp := ctx.Subject.GetExp()
   254  	if len(dev) == 0 {
   255  		dev = "web"
   256  	}
   257  
   258  	zlog.Info("websocket client connect success", 0, zlog.String("subject", sub), zlog.String("path", ctx.Path), zlog.String("dev", dev))
   259  
   260  	key := utils.AddStr(dev, "_", ctx.Path)
   261  	if self.pool == nil {
   262  		self.pool = make(ConnPool, 50)
   263  	}
   264  
   265  	if len(self.pool) >= self.max {
   266  		closeConn("add conn max pool close", &DevConn{Conn: conn, Dev: ctx.Subject.GetDev(), Sub: ctx.Subject.GetSub()})
   267  		return utils.Error("conn pool full: ", len(self.pool))
   268  	}
   269  
   270  	check, b := self.pool[sub]
   271  	if !b {
   272  		value := make(map[string]*DevConn, 2)
   273  		value[key] = &DevConn{Life: exp, Last: utils.UnixSecond(), Sub: sub, Dev: dev, Ctx: ctx, Conn: conn}
   274  		self.pool[sub] = value
   275  		return nil
   276  	}
   277  	devConn, b := check[key]
   278  	if b {
   279  		closeConn("add conn replace close", devConn) // 如果存在连接对象则先关闭
   280  	}
   281  	if devConn == nil {
   282  		check[key] = &DevConn{Life: exp, Last: utils.UnixSecond(), Sub: sub, Dev: dev, Ctx: ctx, Conn: conn}
   283  		return nil
   284  	}
   285  	devConn.Sub = sub
   286  	devConn.Life = exp
   287  	devConn.Dev = dev
   288  	devConn.Last = utils.UnixSecond()
   289  	devConn.Ctx = ctx
   290  	devConn.Conn = conn
   291  	//check[key] = &DevConn{Life: exp, Last: utils.UnixSecond(), Dev: dev, Ctx: ctx, Conn: conn}
   292  	return nil
   293  }
   294  
   295  func (self *WsServer) refConn(ctx *Context) error {
   296  	self.mu.Lock()
   297  	defer self.mu.Unlock()
   298  	sub := ctx.Subject.Payload.Sub
   299  	dev := ctx.Subject.GetDev()
   300  	if len(dev) == 0 {
   301  		dev = "web"
   302  	}
   303  	dev = utils.AddStr(dev, "_", ctx.Path)
   304  	if self.pool == nil {
   305  		return nil
   306  	}
   307  
   308  	check, b := self.pool[sub]
   309  	if !b {
   310  		return nil
   311  	}
   312  	devConn, b := check[dev]
   313  	if !b {
   314  		return nil
   315  	}
   316  	devConn.Last = utils.UnixSecond()
   317  	return nil
   318  }
   319  
   320  func (self *WsServer) NewPool(maxConn, limit, bucket, ping int) {
   321  	if maxConn <= 0 {
   322  		panic("maxConn is nil")
   323  	}
   324  	if limit <= 0 {
   325  		panic("limit is nil")
   326  	}
   327  	if bucket <= 0 {
   328  		panic("bucket is nil")
   329  	}
   330  	if ping <= 0 {
   331  		panic("ping is nil")
   332  	}
   333  	self.mu.Lock()
   334  	defer self.mu.Unlock()
   335  	if self.pool == nil {
   336  		self.pool = make(ConnPool, maxConn)
   337  	}
   338  	self.max = maxConn
   339  	self.ping = ping
   340  
   341  	// 设置每秒放入100个令牌,并允许最大突发50个令牌
   342  	self.limiter = rate.NewLimiter(rate.Limit(limit), bucket)
   343  }
   344  
   345  func (self *WsServer) withConnectionLimit(handler websocket.Handler) http.Handler {
   346  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   347  		if !self.limiter.Allow() {
   348  			http.Error(w, "limited access", http.StatusServiceUnavailable)
   349  			return
   350  		}
   351  		handler.ServeHTTP(w, r)
   352  	})
   353  }
   354  
   355  func (self *WsServer) wsHandler(path string, handle Handle) websocket.Handler {
   356  	return func(ws *websocket.Conn) {
   357  
   358  		devConn := &DevConn{Conn: ws}
   359  
   360  		defer closeConn("handler close", devConn)
   361  
   362  		ctx := createCtx(self, path)
   363  		_ = ctx.readWsToken(ws.Request().Header.Get("Authorization"))
   364  
   365  		if len(ctx.Subject.GetRawBytes()) == 0 {
   366  			_ = ctx.writeError(ws, ex.Throw{Code: http.StatusUnauthorized, Msg: "token is nil"})
   367  			return
   368  		}
   369  		if err := ctx.Subject.Verify(utils.Bytes2Str(ctx.Subject.GetRawBytes()), ctx.GetJwtConfig().TokenKey, true); err != nil {
   370  			_ = ctx.writeError(ws, ex.Throw{Code: http.StatusUnauthorized, Msg: "token invalid or expired", Err: err})
   371  			return
   372  		}
   373  
   374  		devConn.Sub = ctx.Subject.GetSub()
   375  		devConn.Dev = ctx.Subject.GetDev()
   376  
   377  		if err := self.addConn(ws, ctx); err != nil {
   378  			zlog.Error("add conn error", 0, zlog.String("sub", devConn.Sub), zlog.String("dev", devConn.Dev), zlog.AddError(err))
   379  			return
   380  		}
   381  
   382  		for {
   383  			// 读取消息
   384  			var body []byte
   385  			err := websocket.Message.Receive(ws, &body)
   386  			if err != nil {
   387  				zlog.Error("receive message error", 0, zlog.String("sub", devConn.Sub), zlog.String("dev", devConn.Dev), zlog.AddError(err))
   388  				break
   389  			}
   390  
   391  			if self.Debug {
   392  				zlog.Info("websocket receive message", 0, zlog.String("sub", devConn.Sub), zlog.String("dev", devConn.Dev), zlog.String("data", string(body)))
   393  			}
   394  
   395  			if !validBody(ws, ctx, body) {
   396  				if self.Debug {
   397  					zlog.Info("websocket receive message invalid", 0, zlog.String("sub", devConn.Sub), zlog.String("dev", devConn.Dev), zlog.String("data", string(body)))
   398  				}
   399  				continue
   400  			}
   401  
   402  			dec, b := ctx.JsonBody.Data.([]byte)
   403  
   404  			if b && utils.GetJsonString(dec, "healthCheck") == pingCmd {
   405  				_ = self.refConn(ctx)
   406  				continue
   407  			}
   408  
   409  			reply, err := handle(ctx, dec)
   410  			if err != nil {
   411  				_ = ctx.writeError(ws, err)
   412  				continue
   413  			}
   414  
   415  			if self.Debug && reply != nil {
   416  				zlog.Info("websocket reply message", 0, zlog.String("sub", devConn.Sub), zlog.String("dev", devConn.Dev), zlog.Any("data", reply))
   417  			}
   418  
   419  			// 回复消息
   420  			if err := wsRenderTo(ws, ctx, reply); err != nil {
   421  				zlog.Error("receive message reply error", 0, zlog.String("sub", devConn.Sub), zlog.String("dev", devConn.Dev), zlog.AddError(err))
   422  				break
   423  			}
   424  
   425  		}
   426  	}
   427  }
   428  
   429  func (self *WsServer) AddRouter(path string, handle Handle, routerConfig *RouterConfig) {
   430  	if handle == nil {
   431  		panic("handle function is nil")
   432  	}
   433  
   434  	self.checkContextReady(path, routerConfig)
   435  
   436  	http.Handle(path, self.withConnectionLimit(self.wsHandler(path, handle)))
   437  }
   438  
   439  func (self *WsServer) StartWebsocket(addr string) {
   440  	go func() {
   441  		for {
   442  			time.Sleep(time.Duration(self.ping) * time.Second)
   443  			s := utils.UnixMilli()
   444  			index := 0
   445  			current := utils.UnixSecond()
   446  			for _, v := range self.pool {
   447  				for k1, v1 := range v {
   448  					if current-v1.Last > int64(self.ping*2) || current > v1.Life {
   449  						self.mu.Lock()
   450  						closeConn("check life close", v1)
   451  						delete(v, k1)
   452  						self.mu.Unlock()
   453  					}
   454  					index++
   455  				}
   456  			}
   457  			if self.Debug {
   458  				zlog.Info("websocket check pool", 0, zlog.String("cost", utils.AddStr(utils.UnixMilli()-s, " ms")), zlog.Int("count", index))
   459  			}
   460  		}
   461  	}()
   462  	go func() {
   463  		zlog.Printf("websocket【%s】service has been started successful", addr)
   464  		if err := http.Serve(NewGracefulListener(addr, time.Second*10), nil); err != nil {
   465  			panic(err)
   466  		}
   467  	}()
   468  	select {}
   469  }