github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/mch/core/server.go (about)

     1  package core
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/hmac"
     6  	"crypto/md5"
     7  	"crypto/sha256"
     8  	"errors"
     9  	"fmt"
    10  	"io/ioutil"
    11  	"net/http"
    12  	"net/url"
    13  
    14  	"github.com/chanxuehong/util"
    15  	"github.com/chanxuehong/util/security"
    16  
    17  	"github.com/chanxuehong/wechat/internal/debug/mch/callback"
    18  )
    19  
    20  type Server struct {
    21  	appId  string
    22  	mchId  string
    23  	apiKey string
    24  
    25  	subAppId string
    26  	subMchId string
    27  
    28  	handler      Handler
    29  	errorHandler ErrorHandler
    30  }
    31  
    32  // NewServer 创建一个新的 Server.
    33  //
    34  //	appId:        可选; 公众号的 appid, 如果设置了值则该 Server 只能处理 appid 为该值的消息(事件)
    35  //	mchId:        可选; 商户号 mch_id, 如果设置了值则该 Server 只能处理 mch_id 为该值的消息(事件)
    36  //	apiKey:       必选; 商户的签名 key
    37  //	handler:      必选; 处理微信服务器推送过来的消息(事件)的 Handler
    38  //	errorHandler: 可选; 用于处理 Server 在处理消息(事件)过程中产生的错误, 如果没有设置则默认使用 DefaultErrorHandler
    39  func NewServer(appId, mchId, apiKey string, handler Handler, errorHandler ErrorHandler) *Server {
    40  	if apiKey == "" {
    41  		panic("empty apiKey")
    42  	}
    43  	if handler == nil {
    44  		panic("nil Handler")
    45  	}
    46  	if errorHandler == nil {
    47  		errorHandler = DefaultErrorHandler
    48  	}
    49  
    50  	return &Server{
    51  		appId:        appId,
    52  		mchId:        mchId,
    53  		apiKey:       apiKey,
    54  		handler:      handler,
    55  		errorHandler: errorHandler,
    56  	}
    57  }
    58  
    59  // NewSubMchServer 创建一个新的 Server.
    60  //
    61  //	appId:        可选; 公众号的 appid, 如果设置了值则该 Server 只能处理 appid 为该值的消息(事件)
    62  //	mchId:        可选; 商户号 mch_id, 如果设置了值则该 Server 只能处理 mch_id 为该值的消息(事件)
    63  //	apiKey:       必选; 商户的签名 key
    64  //	subAppId:     可选; 公众号的 sub_appid, 如果设置了值则该 Server 只能处理 sub_appid 为该值的消息(事件)
    65  //	subMchId:     可选; 商户号 sub_mch_id, 如果设置了值则该 Server 只能处理 sub_mch_id 为该值的消息(事件)
    66  //	handler:      必选; 处理微信服务器推送过来的消息(事件)的 Handler
    67  //	errorHandler: 可选; 用于处理 Server 在处理消息(事件)过程中产生的错误, 如果没有设置则默认使用 DefaultErrorHandler
    68  func NewSubMchServer(appId, mchId, apiKey string, subAppId, subMchId string, handler Handler, errorHandler ErrorHandler) *Server {
    69  	if apiKey == "" {
    70  		panic("empty apiKey")
    71  	}
    72  	if handler == nil {
    73  		panic("nil Handler")
    74  	}
    75  	if errorHandler == nil {
    76  		errorHandler = DefaultErrorHandler
    77  	}
    78  
    79  	return &Server{
    80  		appId:        appId,
    81  		mchId:        mchId,
    82  		apiKey:       apiKey,
    83  		subAppId:     subAppId,
    84  		subMchId:     subMchId,
    85  		handler:      handler,
    86  		errorHandler: errorHandler,
    87  	}
    88  }
    89  
    90  func (srv *Server) AppId() string {
    91  	return srv.appId
    92  }
    93  func (srv *Server) MchId() string {
    94  	return srv.mchId
    95  }
    96  func (srv *Server) ApiKey() string {
    97  	return srv.apiKey
    98  }
    99  
   100  func (srv *Server) SubAppId() string {
   101  	return srv.subAppId
   102  }
   103  func (srv *Server) SubMchId() string {
   104  	return srv.subMchId
   105  }
   106  
   107  // ServeHTTP 处理微信服务器的回调请求, query 参数可以为 nil.
   108  func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request, query url.Values) {
   109  	callback.DebugPrintRequest(r)
   110  	errorHandler := srv.errorHandler
   111  
   112  	switch r.Method {
   113  	case "POST":
   114  		requestBody, err := ioutil.ReadAll(r.Body)
   115  		if err != nil {
   116  			errorHandler.ServeError(w, r, err)
   117  			return
   118  		}
   119  		callback.DebugPrintRequestMessage(requestBody)
   120  
   121  		msg, err := util.DecodeXMLToMap(bytes.NewReader(requestBody))
   122  		if err != nil {
   123  			errorHandler.ServeError(w, r, err)
   124  			return
   125  		}
   126  
   127  		returnCode := msg["return_code"]
   128  		if returnCode != "" && returnCode != ReturnCodeSuccess {
   129  			err = &Error{
   130  				ReturnCode: returnCode,
   131  				ReturnMsg:  msg["return_msg"],
   132  			}
   133  			errorHandler.ServeError(w, r, err)
   134  			return
   135  		}
   136  
   137  		resultCode := msg["result_code"]
   138  		if resultCode != "" && resultCode != ResultCodeSuccess {
   139  			err = &BizError{
   140  				ResultCode:  resultCode,
   141  				ErrCode:     msg["err_code"],
   142  				ErrCodeDesc: msg["err_code_des"],
   143  			}
   144  			errorHandler.ServeError(w, r, err)
   145  			return
   146  		}
   147  
   148  		if srv.appId != "" {
   149  			wantAppId := srv.appId
   150  			haveAppId := msg["appid"]
   151  			if haveAppId != "" && !security.SecureCompareString(haveAppId, wantAppId) {
   152  				err = fmt.Errorf("appid mismatch, have: %s, want: %s", haveAppId, wantAppId)
   153  				errorHandler.ServeError(w, r, err)
   154  				return
   155  			}
   156  		}
   157  		if srv.mchId != "" {
   158  			wantMchId := srv.mchId
   159  			haveMchId := msg["mch_id"]
   160  			if haveMchId != "" && !security.SecureCompareString(haveMchId, wantMchId) {
   161  				err = fmt.Errorf("mch_id mismatch, have: %s, want: %s", haveMchId, wantMchId)
   162  				errorHandler.ServeError(w, r, err)
   163  				return
   164  			}
   165  		}
   166  
   167  		if srv.subAppId != "" {
   168  			wantSubAppId := srv.subAppId
   169  			haveSubAppId := msg["sub_appid"]
   170  			if haveSubAppId != "" && !security.SecureCompareString(haveSubAppId, wantSubAppId) {
   171  				err = fmt.Errorf("sub_appid mismatch, have: %s, want: %s", haveSubAppId, wantSubAppId)
   172  				errorHandler.ServeError(w, r, err)
   173  				return
   174  			}
   175  		}
   176  		if srv.subMchId != "" {
   177  			wantSubMchId := srv.subMchId
   178  			haveSubMchId := msg["sub_mch_id"]
   179  			if haveSubMchId != "" && !security.SecureCompareString(haveSubMchId, wantSubMchId) {
   180  				err = fmt.Errorf("sub_mch_id mismatch, have: %s, want: %s", haveSubMchId, wantSubMchId)
   181  				errorHandler.ServeError(w, r, err)
   182  				return
   183  			}
   184  		}
   185  
   186  		// 认证签名
   187  		if haveSignature := msg["sign"]; haveSignature != "" {
   188  			var wantSignature string
   189  			switch signType := msg["sign_type"]; signType {
   190  			case "", SignType_MD5:
   191  				wantSignature = Sign2(msg, srv.apiKey, md5.New())
   192  			case SignType_HMAC_SHA256:
   193  				wantSignature = Sign2(msg, srv.apiKey, hmac.New(sha256.New, []byte(srv.apiKey)))
   194  			default:
   195  				err = fmt.Errorf("unsupported notification sign_type: %s", signType)
   196  				errorHandler.ServeError(w, r, err)
   197  				return
   198  			}
   199  			if !security.SecureCompareString(haveSignature, wantSignature) {
   200  				err = fmt.Errorf("sign mismatch,\nhave: %s,\nwant: %s", haveSignature, wantSignature)
   201  				errorHandler.ServeError(w, r, err)
   202  				return
   203  			}
   204  		} else {
   205  			if _, ok := msg["req_info"]; !ok { // 退款结果通知没有 sign 字段
   206  				err = ErrNotFoundSign
   207  				errorHandler.ServeError(w, r, err)
   208  				return
   209  			}
   210  		}
   211  
   212  		ctx := &Context{
   213  			Server: srv,
   214  
   215  			ResponseWriter: w,
   216  			Request:        r,
   217  
   218  			RequestBody: requestBody,
   219  			Msg:         msg,
   220  
   221  			handlerIndex: initHandlerIndex,
   222  		}
   223  		srv.handler.ServeMsg(ctx)
   224  	default:
   225  		errorHandler.ServeError(w, r, errors.New("Unexpected HTTP Method: "+r.Method))
   226  	}
   227  }