github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/mp/core/handler.go (about)

     1  package core
     2  
     3  import (
     4  	"github.com/chanxuehong/wechat/internal/util"
     5  )
     6  
// maxHandlerChainSize is the upper bound on the total number of handlers
// (middlewares plus handlers) in one HandlerChain; combineHandlerChain panics
// when a combined chain would exceed it.
const maxHandlerChainSize = 64
     8  
// HandlerChain is an ordered list of Handlers; ServeMux stores one chain per
// message/event type, with the registered middlewares placed first.
type HandlerChain []Handler
    10  
// Handler processes a message or event described by a Context.
type Handler interface {
	// ServeMsg handles the message (event) carried by the given Context.
	ServeMsg(*Context)
}
    14  
    15  // HandlerFunc =========================================================================================================
    16  
    17  var _ Handler = HandlerFunc(nil)
    18  
    19  type HandlerFunc func(*Context)
    20  
    21  func (fn HandlerFunc) ServeMsg(ctx *Context) { fn(ctx) }
    22  
    23  // ServeMux ============================================================================================================
    24  
// Compile-time check that *ServeMux satisfies Handler.
var _ Handler = (*ServeMux)(nil)

// ServeMux is a message (event) router and is itself an implementation of Handler.
//
//	NOTE: ServeMux is not safe for concurrent use; if a concurrency-safe Handler is needed,
//	implement one using ServeMux as a reference.
type ServeMux struct {
	// startedChecker is notified by ServeMsg (start) and consulted by every
	// registration method (check).
	// NOTE(review): presumably it rejects registration once serving has begun — confirm
	// against the startedChecker definition.
	startedChecker startedChecker

	// Middlewares that are placed in front of every message / event handler
	// chain registered after them.
	msgMiddlewares   HandlerChain
	eventMiddlewares HandlerChain

	// Fallback chains used when no type-specific chain matches.
	defaultMsgHandlerChain   HandlerChain
	defaultEventHandlerChain HandlerChain

	// Type-specific chains, keyed by the type normalized through util.ToLower.
	msgHandlerChainMap   map[MsgType]HandlerChain
	eventHandlerChainMap map[EventType]HandlerChain
}
    42  
    43  func NewServeMux() *ServeMux {
    44  	return &ServeMux{
    45  		msgHandlerChainMap:   make(map[MsgType]HandlerChain),
    46  		eventHandlerChainMap: make(map[EventType]HandlerChain),
    47  	}
    48  }
    49  
// successResponseBytes is the body ServeMsg writes to the response when no
// handler chain is available for the incoming message or event.
var successResponseBytes = []byte("success")
    51  
    52  // ServeMsg 实现 Handler 接口.
    53  func (mux *ServeMux) ServeMsg(ctx *Context) {
    54  	mux.startedChecker.start()
    55  	if MsgType := ctx.MixedMsg.MsgType; MsgType != "event" {
    56  		handlers := mux.getMsgHandlerChain(MsgType)
    57  		if len(handlers) == 0 {
    58  			ctx.ResponseWriter.Write(successResponseBytes)
    59  			return
    60  		}
    61  		ctx.handlers = handlers
    62  		ctx.Next()
    63  	} else {
    64  		handlers := mux.getEventHandlerChain(ctx.MixedMsg.EventType)
    65  		if len(handlers) == 0 {
    66  			ctx.ResponseWriter.Write(successResponseBytes)
    67  			return
    68  		}
    69  		ctx.handlers = handlers
    70  		ctx.Next()
    71  	}
    72  }
    73  
    74  // getMsgHandlerChain 获取 HandlerChain 以处理消息类型为 MsgType 的消息, 如果没有找到返回 nil.
    75  func (mux *ServeMux) getMsgHandlerChain(msgType MsgType) (handlers HandlerChain) {
    76  	if m := mux.msgHandlerChainMap; len(m) > 0 {
    77  		handlers = m[MsgType(util.ToLower(string(msgType)))]
    78  		if len(handlers) == 0 {
    79  			handlers = mux.defaultMsgHandlerChain
    80  		}
    81  	} else {
    82  		handlers = mux.defaultMsgHandlerChain
    83  	}
    84  	return
    85  }
    86  
    87  // getEventHandlerChain 获取 HandlerChain 以处理事件类型为 EventType 的事件, 如果没有找到返回 nil.
    88  func (mux *ServeMux) getEventHandlerChain(eventType EventType) (handlers HandlerChain) {
    89  	if m := mux.eventHandlerChainMap; len(m) > 0 {
    90  		handlers = m[EventType(util.ToLower(string(eventType)))]
    91  		if len(handlers) == 0 {
    92  			handlers = mux.defaultEventHandlerChain
    93  		}
    94  	} else {
    95  		handlers = mux.defaultEventHandlerChain
    96  	}
    97  	return
    98  }
    99  
   100  // ServeMux: registers HandlerChain ====================================================================================
   101  
   102  // Use 注册(新增) middlewares 使其在所有消息(事件)的 Handler 之前处理该处理消息(事件).
   103  func (mux *ServeMux) Use(middlewares ...Handler) {
   104  	mux.startedChecker.check()
   105  	if len(middlewares) == 0 {
   106  		return
   107  	}
   108  	for _, h := range middlewares {
   109  		if h == nil {
   110  			panic("handler can not be nil")
   111  		}
   112  	}
   113  	mux.useForMsg(middlewares)
   114  	mux.useForEvent(middlewares)
   115  }
   116  
   117  // UseFunc 注册(新增) middlewares 使其在所有消息(事件)的 Handler 之前处理该处理消息(事件).
   118  func (mux *ServeMux) UseFunc(middlewares ...func(*Context)) {
   119  	mux.startedChecker.check()
   120  	if len(middlewares) == 0 {
   121  		return
   122  	}
   123  	for _, h := range middlewares {
   124  		if h == nil {
   125  			panic("handler can not be nil")
   126  		}
   127  	}
   128  	middlewares2 := make(HandlerChain, len(middlewares))
   129  	for i := 0; i < len(middlewares); i++ {
   130  		middlewares2[i] = HandlerFunc(middlewares[i])
   131  	}
   132  	mux.useForMsg(middlewares2)
   133  	mux.useForEvent(middlewares2)
   134  }
   135  
   136  // UseForMsg 注册(新增) middlewares 使其在所有消息的 Handler 之前处理该处理消息.
   137  func (mux *ServeMux) UseForMsg(middlewares ...Handler) {
   138  	mux.startedChecker.check()
   139  	if len(middlewares) == 0 {
   140  		return
   141  	}
   142  	for _, h := range middlewares {
   143  		if h == nil {
   144  			panic("handler can not be nil")
   145  		}
   146  	}
   147  	mux.useForMsg(middlewares)
   148  }
   149  
   150  // UseFuncForMsg 注册(新增) middlewares 使其在所有消息的 Handler 之前处理该处理消息.
   151  func (mux *ServeMux) UseFuncForMsg(middlewares ...func(*Context)) {
   152  	mux.startedChecker.check()
   153  	if len(middlewares) == 0 {
   154  		return
   155  	}
   156  	for _, h := range middlewares {
   157  		if h == nil {
   158  			panic("handler can not be nil")
   159  		}
   160  	}
   161  	middlewares2 := make(HandlerChain, len(middlewares))
   162  	for i := 0; i < len(middlewares); i++ {
   163  		middlewares2[i] = HandlerFunc(middlewares[i])
   164  	}
   165  	mux.useForMsg(middlewares2)
   166  }
   167  
   168  func (mux *ServeMux) useForMsg(middlewares []Handler) {
   169  	if len(mux.defaultMsgHandlerChain) > 0 || len(mux.msgHandlerChainMap) > 0 {
   170  		panic("please call this method before any other methods those registered handlers for message")
   171  	}
   172  	mux.msgMiddlewares = combineHandlerChain(mux.msgMiddlewares, middlewares)
   173  }
   174  
   175  // UseForEvent 注册(新增) middlewares 使其在所有事件的 Handler 之前处理该处理事件.
   176  func (mux *ServeMux) UseForEvent(middlewares ...Handler) {
   177  	mux.startedChecker.check()
   178  	if len(middlewares) == 0 {
   179  		return
   180  	}
   181  	for _, h := range middlewares {
   182  		if h == nil {
   183  			panic("handler can not be nil")
   184  		}
   185  	}
   186  	mux.useForEvent(middlewares)
   187  }
   188  
   189  // UseFuncForEvent 注册(新增) middlewares 使其在所有事件的 Handler 之前处理该处理事件.
   190  func (mux *ServeMux) UseFuncForEvent(middlewares ...func(*Context)) {
   191  	mux.startedChecker.check()
   192  	if len(middlewares) == 0 {
   193  		return
   194  	}
   195  	for _, h := range middlewares {
   196  		if h == nil {
   197  			panic("handler can not be nil")
   198  		}
   199  	}
   200  	middlewares2 := make(HandlerChain, len(middlewares))
   201  	for i := 0; i < len(middlewares); i++ {
   202  		middlewares2[i] = HandlerFunc(middlewares[i])
   203  	}
   204  	mux.useForEvent(middlewares2)
   205  }
   206  
   207  func (mux *ServeMux) useForEvent(middlewares []Handler) {
   208  	if len(mux.defaultEventHandlerChain) > 0 || len(mux.eventHandlerChainMap) > 0 {
   209  		panic("please call this method before any other methods those registered handlers for event")
   210  	}
   211  	mux.eventMiddlewares = combineHandlerChain(mux.eventMiddlewares, middlewares)
   212  }
   213  
   214  // DefaultMsgHandle 设置 handlers 以处理没有匹配到具体类型的 HandlerChain 的消息.
   215  func (mux *ServeMux) DefaultMsgHandle(handlers ...Handler) {
   216  	mux.startedChecker.check()
   217  	if len(handlers) == 0 {
   218  		return
   219  	}
   220  	for _, h := range handlers {
   221  		if h == nil {
   222  			panic("handler can not be nil")
   223  		}
   224  	}
   225  	mux.defaultMsgHandlerChain = combineHandlerChain(mux.msgMiddlewares, handlers)
   226  }
   227  
   228  // DefaultMsgHandleFunc 设置 handlers 以处理没有匹配到具体类型的 HandlerChain 的消息.
   229  func (mux *ServeMux) DefaultMsgHandleFunc(handlers ...func(*Context)) {
   230  	mux.startedChecker.check()
   231  	if len(handlers) == 0 {
   232  		return
   233  	}
   234  	for _, h := range handlers {
   235  		if h == nil {
   236  			panic("handler can not be nil")
   237  		}
   238  	}
   239  	handlers2 := make(HandlerChain, len(handlers))
   240  	for i := 0; i < len(handlers); i++ {
   241  		handlers2[i] = HandlerFunc(handlers[i])
   242  	}
   243  	mux.defaultMsgHandlerChain = combineHandlerChain(mux.msgMiddlewares, handlers2)
   244  }
   245  
   246  // DefaultEventHandle 设置 handlers 以处理没有匹配到具体类型的 HandlerChain 的事件.
   247  func (mux *ServeMux) DefaultEventHandle(handlers ...Handler) {
   248  	mux.startedChecker.check()
   249  	if len(handlers) == 0 {
   250  		return
   251  	}
   252  	for _, h := range handlers {
   253  		if h == nil {
   254  			panic("handler can not be nil")
   255  		}
   256  	}
   257  	mux.defaultEventHandlerChain = combineHandlerChain(mux.eventMiddlewares, handlers)
   258  }
   259  
   260  // DefaultEventHandleFunc 设置 handlers 以处理没有匹配到具体类型的 HandlerChain 的事件.
   261  func (mux *ServeMux) DefaultEventHandleFunc(handlers ...func(*Context)) {
   262  	mux.startedChecker.check()
   263  	if len(handlers) == 0 {
   264  		return
   265  	}
   266  	for _, h := range handlers {
   267  		if h == nil {
   268  			panic("handler can not be nil")
   269  		}
   270  	}
   271  	handlers2 := make(HandlerChain, len(handlers))
   272  	for i := 0; i < len(handlers); i++ {
   273  		handlers2[i] = HandlerFunc(handlers[i])
   274  	}
   275  	mux.defaultEventHandlerChain = combineHandlerChain(mux.eventMiddlewares, handlers2)
   276  }
   277  
   278  // MsgHandle 设置 handlers 以处理特定类型的消息.
   279  func (mux *ServeMux) MsgHandle(msgType MsgType, handlers ...Handler) {
   280  	mux.startedChecker.check()
   281  	if len(handlers) == 0 {
   282  		return
   283  	}
   284  	for _, h := range handlers {
   285  		if h == nil {
   286  			panic("handler can not be nil")
   287  		}
   288  	}
   289  	mux.msgHandlerChainMap[MsgType(util.ToLower(string(msgType)))] = combineHandlerChain(mux.msgMiddlewares, handlers)
   290  }
   291  
   292  // MsgHandleFunc 设置 handlers 以处理特定类型的消息.
   293  func (mux *ServeMux) MsgHandleFunc(msgType MsgType, handlers ...func(*Context)) {
   294  	mux.startedChecker.check()
   295  	if len(handlers) == 0 {
   296  		return
   297  	}
   298  	for _, h := range handlers {
   299  		if h == nil {
   300  			panic("handler can not be nil")
   301  		}
   302  	}
   303  	handlers2 := make(HandlerChain, len(handlers))
   304  	for i := 0; i < len(handlers); i++ {
   305  		handlers2[i] = HandlerFunc(handlers[i])
   306  	}
   307  	mux.msgHandlerChainMap[MsgType(util.ToLower(string(msgType)))] = combineHandlerChain(mux.msgMiddlewares, handlers2)
   308  }
   309  
   310  // EventHandle 设置 handlers 以处理特定类型的事件.
   311  func (mux *ServeMux) EventHandle(eventType EventType, handlers ...Handler) {
   312  	mux.startedChecker.check()
   313  	if len(handlers) == 0 {
   314  		return
   315  	}
   316  	for _, h := range handlers {
   317  		if h == nil {
   318  			panic("handler can not be nil")
   319  		}
   320  	}
   321  	mux.eventHandlerChainMap[EventType(util.ToLower(string(eventType)))] = combineHandlerChain(mux.eventMiddlewares, handlers)
   322  }
   323  
   324  // EventHandleFunc 设置 handlers 以处理特定类型的事件.
   325  func (mux *ServeMux) EventHandleFunc(eventType EventType, handlers ...func(*Context)) {
   326  	mux.startedChecker.check()
   327  	if len(handlers) == 0 {
   328  		return
   329  	}
   330  	for _, h := range handlers {
   331  		if h == nil {
   332  			panic("handler can not be nil")
   333  		}
   334  	}
   335  	handlers2 := make(HandlerChain, len(handlers))
   336  	for i := 0; i < len(handlers); i++ {
   337  		handlers2[i] = HandlerFunc(handlers[i])
   338  	}
   339  	mux.eventHandlerChainMap[EventType(util.ToLower(string(eventType)))] = combineHandlerChain(mux.eventMiddlewares, handlers2)
   340  }
   341  
   342  func combineHandlerChain(middlewares, handlers HandlerChain) HandlerChain {
   343  	if len(middlewares)+len(handlers) > maxHandlerChainSize {
   344  		panic("too many handlers")
   345  	}
   346  	return append(middlewares, handlers...)
   347  }