github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/mch/core/server.go (about) 1 package core 2 3 import ( 4 "bytes" 5 "crypto/hmac" 6 "crypto/md5" 7 "crypto/sha256" 8 "errors" 9 "fmt" 10 "io/ioutil" 11 "net/http" 12 "net/url" 13 14 "github.com/chanxuehong/util" 15 "github.com/chanxuehong/util/security" 16 17 "github.com/chanxuehong/wechat/internal/debug/mch/callback" 18 ) 19 20 type Server struct { 21 appId string 22 mchId string 23 apiKey string 24 25 subAppId string 26 subMchId string 27 28 handler Handler 29 errorHandler ErrorHandler 30 } 31 32 // NewServer 创建一个新的 Server. 33 // 34 // appId: 可选; 公众号的 appid, 如果设置了值则该 Server 只能处理 appid 为该值的消息(事件) 35 // mchId: 可选; 商户号 mch_id, 如果设置了值则该 Server 只能处理 mch_id 为该值的消息(事件) 36 // apiKey: 必选; 商户的签名 key 37 // handler: 必选; 处理微信服务器推送过来的消息(事件)的 Handler 38 // errorHandler: 可选; 用于处理 Server 在处理消息(事件)过程中产生的错误, 如果没有设置则默认使用 DefaultErrorHandler 39 func NewServer(appId, mchId, apiKey string, handler Handler, errorHandler ErrorHandler) *Server { 40 if apiKey == "" { 41 panic("empty apiKey") 42 } 43 if handler == nil { 44 panic("nil Handler") 45 } 46 if errorHandler == nil { 47 errorHandler = DefaultErrorHandler 48 } 49 50 return &Server{ 51 appId: appId, 52 mchId: mchId, 53 apiKey: apiKey, 54 handler: handler, 55 errorHandler: errorHandler, 56 } 57 } 58 59 // NewSubMchServer 创建一个新的 Server. 60 // 61 // appId: 可选; 公众号的 appid, 如果设置了值则该 Server 只能处理 appid 为该值的消息(事件) 62 // mchId: 可选; 商户号 mch_id, 如果设置了值则该 Server 只能处理 mch_id 为该值的消息(事件) 63 // apiKey: 必选; 商户的签名 key 64 // subAppId: 可选; 公众号的 sub_appid, 如果设置了值则该 Server 只能处理 sub_appid 为该值的消息(事件) 65 // subMchId: 可选; 商户号 sub_mch_id, 如果设置了值则该 Server 只能处理 sub_mch_id 为该值的消息(事件) 66 // handler: 必选; 处理微信服务器推送过来的消息(事件)的 Handler 67 // errorHandler: 可选; 用于处理 Server 在处理消息(事件)过程中产生的错误, 如果没有设置则默认使用 DefaultErrorHandler 68 func NewSubMchServer(appId, mchId, apiKey string, subAppId, subMchId string, handler Handler, errorHandler ErrorHandler) *Server { 69 if apiKey == "" { 70 panic("empty apiKey") 71 } 72 if handler == nil { 73 panic("nil Handler") 74 } 75 if errorHandler == nil { 76 errorHandler = DefaultErrorHandler 77 } 78 79 return &Server{ 80 appId: appId, 81 mchId: mchId, 82 apiKey: apiKey, 83 subAppId: subAppId, 84 subMchId: subMchId, 85 handler: handler, 86 errorHandler: errorHandler, 87 } 88 } 89 90 func (srv *Server) AppId() string { 91 return srv.appId 92 } 93 func (srv *Server) MchId() string { 94 return srv.mchId 95 } 96 func (srv *Server) ApiKey() string { 97 return srv.apiKey 98 } 99 100 func (srv *Server) SubAppId() string { 101 return srv.subAppId 102 } 103 func (srv *Server) SubMchId() string { 104 return srv.subMchId 105 } 106 107 // ServeHTTP 处理微信服务器的回调请求, query 参数可以为 nil. 108 func (srv *Server) ServeHTTP(w http.ResponseWriter, r *http.Request, query url.Values) { 109 callback.DebugPrintRequest(r) 110 errorHandler := srv.errorHandler 111 112 switch r.Method { 113 case "POST": 114 requestBody, err := ioutil.ReadAll(r.Body) 115 if err != nil { 116 errorHandler.ServeError(w, r, err) 117 return 118 } 119 callback.DebugPrintRequestMessage(requestBody) 120 121 msg, err := util.DecodeXMLToMap(bytes.NewReader(requestBody)) 122 if err != nil { 123 errorHandler.ServeError(w, r, err) 124 return 125 } 126 127 returnCode := msg["return_code"] 128 if returnCode != "" && returnCode != ReturnCodeSuccess { 129 err = &Error{ 130 ReturnCode: returnCode, 131 ReturnMsg: msg["return_msg"], 132 } 133 errorHandler.ServeError(w, r, err) 134 return 135 } 136 137 resultCode := msg["result_code"] 138 if resultCode != "" && resultCode != ResultCodeSuccess { 139 err = &BizError{ 140 ResultCode: resultCode, 141 ErrCode: msg["err_code"], 142 ErrCodeDesc: msg["err_code_des"], 143 } 144 errorHandler.ServeError(w, r, err) 145 return 146 } 147 148 if srv.appId != "" { 149 wantAppId := srv.appId 150 haveAppId := msg["appid"] 151 if haveAppId != "" && !security.SecureCompareString(haveAppId, wantAppId) { 152 err = fmt.Errorf("appid mismatch, have: %s, want: %s", haveAppId, wantAppId) 153 errorHandler.ServeError(w, r, err) 154 return 155 } 156 } 157 if srv.mchId != "" { 158 wantMchId := srv.mchId 159 haveMchId := msg["mch_id"] 160 if haveMchId != "" && !security.SecureCompareString(haveMchId, wantMchId) { 161 err = fmt.Errorf("mch_id mismatch, have: %s, want: %s", haveMchId, wantMchId) 162 errorHandler.ServeError(w, r, err) 163 return 164 } 165 } 166 167 if srv.subAppId != "" { 168 wantSubAppId := srv.subAppId 169 haveSubAppId := msg["sub_appid"] 170 if haveSubAppId != "" && !security.SecureCompareString(haveSubAppId, wantSubAppId) { 171 err = fmt.Errorf("sub_appid mismatch, have: %s, want: %s", haveSubAppId, wantSubAppId) 172 errorHandler.ServeError(w, r, err) 173 return 174 } 175 } 176 if srv.subMchId != "" { 177 wantSubMchId := srv.subMchId 178 haveSubMchId := msg["sub_mch_id"] 179 if haveSubMchId != "" && !security.SecureCompareString(haveSubMchId, wantSubMchId) { 180 err = fmt.Errorf("sub_mch_id mismatch, have: %s, want: %s", haveSubMchId, wantSubMchId) 181 errorHandler.ServeError(w, r, err) 182 return 183 } 184 } 185 186 // 认证签名 187 if haveSignature := msg["sign"]; haveSignature != "" { 188 var wantSignature string 189 switch signType := msg["sign_type"]; signType { 190 case "", SignType_MD5: 191 wantSignature = Sign2(msg, srv.apiKey, md5.New()) 192 case SignType_HMAC_SHA256: 193 wantSignature = Sign2(msg, srv.apiKey, hmac.New(sha256.New, []byte(srv.apiKey))) 194 default: 195 err = fmt.Errorf("unsupported notification sign_type: %s", signType) 196 errorHandler.ServeError(w, r, err) 197 return 198 } 199 if !security.SecureCompareString(haveSignature, wantSignature) { 200 err = fmt.Errorf("sign mismatch,\nhave: %s,\nwant: %s", haveSignature, wantSignature) 201 errorHandler.ServeError(w, r, err) 202 return 203 } 204 } else { 205 if _, ok := msg["req_info"]; !ok { // 退款结果通知没有 sign 字段 206 err = ErrNotFoundSign 207 errorHandler.ServeError(w, r, err) 208 return 209 } 210 } 211 212 ctx := &Context{ 213 Server: srv, 214 215 ResponseWriter: w, 216 Request: r, 217 218 RequestBody: requestBody, 219 Msg: msg, 220 221 handlerIndex: initHandlerIndex, 222 } 223 srv.handler.ServeMsg(ctx) 224 default: 225 errorHandler.ServeError(w, r, errors.New("Unexpected HTTP Method: "+r.Method)) 226 } 227 }