github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/mch/core/client.go (about) 1 package core 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/hmac" 7 "crypto/md5" 8 "crypto/sha256" 9 "fmt" 10 "net/http" 11 "strings" 12 "time" 13 14 "github.com/chanxuehong/util" 15 16 "github.com/chanxuehong/wechat/internal/debug/mch/api" 17 wechatutil "github.com/chanxuehong/wechat/util" 18 ) 19 20 type Client struct { 21 appId string 22 mchId string 23 apiKey string 24 25 subAppId string 26 subMchId string 27 28 httpClient *http.Client 29 } 30 31 func (clt *Client) AppId() string { 32 return clt.appId 33 } 34 func (clt *Client) MchId() string { 35 return clt.mchId 36 } 37 func (clt *Client) ApiKey() string { 38 return clt.apiKey 39 } 40 41 func (clt *Client) SubAppId() string { 42 return clt.subAppId 43 } 44 func (clt *Client) SubMchId() string { 45 return clt.subMchId 46 } 47 48 // NewClient 创建一个新的 Client. 49 // 50 // appId: 必选; 公众号的 appid 51 // mchId: 必选; 商户号 mch_id 52 // apiKey: 必选; 商户的签名 key 53 // httpClient: 可选; 默认使用 util.DefaultHttpClient 54 func NewClient(appId, mchId, apiKey string, httpClient *http.Client) *Client { 55 if httpClient == nil { 56 httpClient = wechatutil.DefaultHttpClient 57 } 58 return &Client{ 59 appId: appId, 60 mchId: mchId, 61 apiKey: apiKey, 62 httpClient: httpClient, 63 } 64 } 65 66 // NewSubMchClient 创建一个新的 Client. 67 // 68 // appId: 必选; 公众号的 appid 69 // mchId: 必选; 商户号 mch_id 70 // apiKey: 必选; 商户的签名 key 71 // subAppId: 可选; 公众号的 sub_appid 72 // subMchId: 必选; 商户号 sub_mch_id 73 // httpClient: 可选; 默认使用 util.DefaultHttpClient 74 func NewSubMchClient(appId, mchId, apiKey string, subAppId, subMchId string, httpClient *http.Client) *Client { 75 if httpClient == nil { 76 httpClient = wechatutil.DefaultHttpClient 77 } 78 return &Client{ 79 appId: appId, 80 mchId: mchId, 81 apiKey: apiKey, 82 subAppId: subAppId, 83 subMchId: subMchId, 84 httpClient: httpClient, 85 } 86 } 87 88 // PostXML 是微信支付通用请求方法. 89 // 90 // err == nil 表示 (return_code == "SUCCESS" && result_code == "SUCCESS"). 91 func (clt *Client) PostXML(url string, req map[string]string) (resp map[string]string, err error) { 92 switch url { 93 case "https://api.mch.weixin.qq.com/mmpaymkttransfers/promotion/transfers", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/promotion/transfers", // 企业付款 94 "https://api.mch.weixin.qq.com/mmpaymkttransfers/sendredpack", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/sendredpack", // 发放普通红包 95 "https://api.mch.weixin.qq.com/mmpaymkttransfers/sendgroupredpack", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/sendgroupredpack": // 发放裂变红包 96 // TODO(chanxuehong): 这几个接口没有标准的 appid 和 mch_id 字段,需要用户在 req 里填写全部参数 97 // TODO(chanxuehong): 通读整个支付文档, 可以的话重新考虑逻辑 98 default: 99 if req["appid"] == "" { 100 req["appid"] = clt.appId 101 } 102 if req["mch_id"] == "" { 103 req["mch_id"] = clt.mchId 104 } 105 if clt.subAppId != "" && req["sub_appid"] == "" { 106 req["sub_appid"] = clt.subAppId 107 } 108 if clt.subMchId != "" && req["sub_mch_id"] == "" { 109 req["sub_mch_id"] = clt.subMchId 110 } 111 } 112 113 // 获取请求参数的 sign_type 并检查其有效性 114 var reqSignType string 115 switch signType := req["sign_type"]; signType { 116 case "", SignType_MD5: 117 reqSignType = SignType_MD5 118 case SignType_HMAC_SHA256: 119 reqSignType = SignType_HMAC_SHA256 120 default: 121 return nil, fmt.Errorf("unsupported request sign_type: %s", signType) 122 } 123 124 // 如果没有签名参数补全签名 125 if req["sign"] == "" { 126 switch reqSignType { 127 case SignType_MD5: 128 req["sign"] = Sign2(req, clt.ApiKey(), md5.New()) 129 case SignType_HMAC_SHA256: 130 req["sign"] = Sign2(req, clt.ApiKey(), hmac.New(sha256.New, []byte(clt.ApiKey()))) 131 } 132 } 133 134 buffer := textBufferPool.Get().(*bytes.Buffer) 135 buffer.Reset() 136 defer textBufferPool.Put(buffer) 137 138 if err = util.EncodeXMLFromMap(buffer, req, "xml"); err != nil { 139 return nil, err 140 } 141 body := buffer.Bytes() 142 143 hasRetried := false 144 RETRY: 145 resp, needRetry, err := clt.postXML(url, body, reqSignType) 146 if err != nil { 147 if needRetry && !hasRetried { 148 // TODO(chanxuehong): 打印错误日志 149 hasRetried = true 150 url = switchRequestURL(url) 151 goto RETRY 152 } 153 return nil, err 154 } 155 return resp, nil 156 } 157 158 func (clt *Client) postXML(url string, body []byte, reqSignType string) (resp map[string]string, needRetry bool, err error) { 159 api.DebugPrintPostXMLRequest(url, body) 160 161 req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) 162 if err != nil { 163 return nil, false, err 164 } 165 req.Header.Set("Content-Type", "text/xml; charset=utf-8") 166 167 ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) 168 defer cancel() 169 170 req = req.WithContext(ctx) 171 httpResp, err := clt.httpClient.Do(req) 172 if err != nil { 173 return nil, true, err 174 } 175 defer httpResp.Body.Close() 176 177 if httpResp.StatusCode != http.StatusOK { 178 return nil, true, fmt.Errorf("http.Status: %s", httpResp.Status) 179 } 180 181 resp, err = api.DecodeXMLHttpResponse(httpResp.Body) 182 if err != nil { 183 return nil, false, err 184 } 185 186 // 判断协议状态 187 returnCode := resp["return_code"] 188 if returnCode == "" { 189 return nil, false, ErrNotFoundReturnCode 190 } 191 if returnCode != ReturnCodeSuccess { 192 return nil, false, &Error{ 193 ReturnCode: returnCode, 194 ReturnMsg: resp["return_msg"], 195 } 196 } 197 198 // 验证 appid 和 mch_id 199 appId := resp["appid"] 200 if appId != "" && appId != clt.appId { 201 return nil, false, fmt.Errorf("appid mismatch, have: %s, want: %s", appId, clt.appId) 202 } 203 mchId := resp["mch_id"] 204 if mchId != "" && mchId != clt.mchId { 205 return nil, false, fmt.Errorf("mch_id mismatch, have: %s, want: %s", mchId, clt.mchId) 206 } 207 208 // 验证 sub_appid 和 sub_mch_id 209 if clt.subAppId != "" { 210 subAppId := resp["sub_appid"] 211 if subAppId != "" && subAppId != clt.subAppId { 212 return nil, false, fmt.Errorf("sub_appid mismatch, have: %s, want: %s", subAppId, clt.subAppId) 213 } 214 } 215 if clt.subMchId != "" { 216 subMchId := resp["sub_mch_id"] 217 if subMchId != "" && subMchId != clt.subMchId { 218 return nil, false, fmt.Errorf("sub_mch_id mismatch, have: %s, want: %s", subMchId, clt.subMchId) 219 } 220 } 221 222 // 验证签名 223 signatureHave := resp["sign"] 224 if signatureHave == "" { 225 // TODO(chanxuehong): 在适当的时候更新下面的 case 226 switch url { 227 default: 228 return nil, false, ErrNotFoundSign 229 case "https://api.mch.weixin.qq.com/mmpaymkttransfers/promotion/transfers", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/promotion/transfers": 230 // do nothing 231 case "https://api.mch.weixin.qq.com/mmpaymkttransfers/gettransferinfo", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/gettransferinfo": 232 // do nothing 233 case "https://api.mch.weixin.qq.com/mmpaymkttransfers/sendredpack", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/sendredpack": 234 // do nothing 235 case "https://api.mch.weixin.qq.com/mmpaymkttransfers/sendgroupredpack", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/sendgroupredpack": 236 // do nothing 237 case "https://api.mch.weixin.qq.com/mmpaymkttransfers/gethbinfo", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/gethbinfo": 238 // do nothing 239 } 240 } else { 241 // 获取返回参数的 sign_type 并检查其有效性 242 var respSignType string 243 switch signType := resp["sign_type"]; signType { 244 case "": 245 respSignType = reqSignType // 默认使用请求参数里的算法, 至少目前是这样 246 case SignType_MD5: 247 respSignType = SignType_MD5 248 case SignType_HMAC_SHA256: 249 respSignType = SignType_HMAC_SHA256 250 default: 251 err = fmt.Errorf("unsupported response sign_type: %s", signType) 252 return nil, false, err 253 } 254 255 // 校验签名 256 var signatureWant string 257 switch respSignType { 258 case SignType_MD5: 259 signatureWant = Sign2(resp, clt.apiKey, md5.New()) 260 case SignType_HMAC_SHA256: 261 signatureWant = Sign2(resp, clt.apiKey, hmac.New(sha256.New, []byte(clt.apiKey))) 262 } 263 if signatureHave != signatureWant { 264 return nil, false, fmt.Errorf("sign mismatch,\nhave: %s,\nwant: %s", signatureHave, signatureWant) 265 } 266 } 267 268 resultCode := resp["result_code"] 269 if resultCode != "" && resultCode != ResultCodeSuccess { 270 errCode := resp["err_code"] 271 if errCode == "SYSTEMERROR" { 272 return nil, true, &BizError{ 273 ResultCode: resultCode, 274 ErrCode: errCode, 275 ErrCodeDesc: resp["err_code_des"], 276 } 277 } 278 return nil, false, &BizError{ 279 ResultCode: resultCode, 280 ErrCode: errCode, 281 ErrCodeDesc: resp["err_code_des"], 282 } 283 } 284 return resp, false, nil 285 } 286 287 func switchRequestURL(url string) string { 288 switch { 289 case strings.HasPrefix(url, "https://api.mch.weixin.qq.com/"): 290 return "https://api2.mch.weixin.qq.com/" + url[len("https://api.mch.weixin.qq.com/"):] 291 case strings.HasPrefix(url, "https://api2.mch.weixin.qq.com/"): 292 return "https://api.mch.weixin.qq.com/" + url[len("https://api2.mch.weixin.qq.com/"):] 293 default: 294 return url 295 } 296 }