github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/mch/core/client.go (about)

     1  package core
     2  
import (
	"bytes"
	"context"
	"crypto/hmac"
	"crypto/md5"
	"crypto/sha256"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/chanxuehong/util"

	"github.com/chanxuehong/wechat/internal/debug/mch/api"
	wechatutil "github.com/chanxuehong/wechat/util"
)
    19  
    20  type Client struct {
    21  	appId  string
    22  	mchId  string
    23  	apiKey string
    24  
    25  	subAppId string
    26  	subMchId string
    27  
    28  	httpClient *http.Client
    29  }
    30  
    31  func (clt *Client) AppId() string {
    32  	return clt.appId
    33  }
    34  func (clt *Client) MchId() string {
    35  	return clt.mchId
    36  }
    37  func (clt *Client) ApiKey() string {
    38  	return clt.apiKey
    39  }
    40  
    41  func (clt *Client) SubAppId() string {
    42  	return clt.subAppId
    43  }
    44  func (clt *Client) SubMchId() string {
    45  	return clt.subMchId
    46  }
    47  
    48  // NewClient 创建一个新的 Client.
    49  //
    50  //	appId:      必选; 公众号的 appid
    51  //	mchId:      必选; 商户号 mch_id
    52  //	apiKey:     必选; 商户的签名 key
    53  //	httpClient: 可选; 默认使用 util.DefaultHttpClient
    54  func NewClient(appId, mchId, apiKey string, httpClient *http.Client) *Client {
    55  	if httpClient == nil {
    56  		httpClient = wechatutil.DefaultHttpClient
    57  	}
    58  	return &Client{
    59  		appId:      appId,
    60  		mchId:      mchId,
    61  		apiKey:     apiKey,
    62  		httpClient: httpClient,
    63  	}
    64  }
    65  
    66  // NewSubMchClient 创建一个新的 Client.
    67  //
    68  //	appId:      必选; 公众号的 appid
    69  //	mchId:      必选; 商户号 mch_id
    70  //	apiKey:     必选; 商户的签名 key
    71  //	subAppId:   可选; 公众号的 sub_appid
    72  //	subMchId:   必选; 商户号 sub_mch_id
    73  //	httpClient: 可选; 默认使用 util.DefaultHttpClient
    74  func NewSubMchClient(appId, mchId, apiKey string, subAppId, subMchId string, httpClient *http.Client) *Client {
    75  	if httpClient == nil {
    76  		httpClient = wechatutil.DefaultHttpClient
    77  	}
    78  	return &Client{
    79  		appId:      appId,
    80  		mchId:      mchId,
    81  		apiKey:     apiKey,
    82  		subAppId:   subAppId,
    83  		subMchId:   subMchId,
    84  		httpClient: httpClient,
    85  	}
    86  }
    87  
    88  // PostXML 是微信支付通用请求方法.
    89  //
    90  //	err == nil 表示 (return_code == "SUCCESS" && result_code == "SUCCESS").
    91  func (clt *Client) PostXML(url string, req map[string]string) (resp map[string]string, err error) {
    92  	switch url {
    93  	case "https://api.mch.weixin.qq.com/mmpaymkttransfers/promotion/transfers", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/promotion/transfers", // 企业付款
    94  		"https://api.mch.weixin.qq.com/mmpaymkttransfers/sendredpack", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/sendredpack", // 发放普通红包
    95  		"https://api.mch.weixin.qq.com/mmpaymkttransfers/sendgroupredpack", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/sendgroupredpack": // 发放裂变红包
    96  		// TODO(chanxuehong): 这几个接口没有标准的 appid 和 mch_id 字段,需要用户在 req 里填写全部参数
    97  		// TODO(chanxuehong): 通读整个支付文档, 可以的话重新考虑逻辑
    98  	default:
    99  		if req["appid"] == "" {
   100  			req["appid"] = clt.appId
   101  		}
   102  		if req["mch_id"] == "" {
   103  			req["mch_id"] = clt.mchId
   104  		}
   105  		if clt.subAppId != "" && req["sub_appid"] == "" {
   106  			req["sub_appid"] = clt.subAppId
   107  		}
   108  		if clt.subMchId != "" && req["sub_mch_id"] == "" {
   109  			req["sub_mch_id"] = clt.subMchId
   110  		}
   111  	}
   112  
   113  	// 获取请求参数的 sign_type 并检查其有效性
   114  	var reqSignType string
   115  	switch signType := req["sign_type"]; signType {
   116  	case "", SignType_MD5:
   117  		reqSignType = SignType_MD5
   118  	case SignType_HMAC_SHA256:
   119  		reqSignType = SignType_HMAC_SHA256
   120  	default:
   121  		return nil, fmt.Errorf("unsupported request sign_type: %s", signType)
   122  	}
   123  
   124  	// 如果没有签名参数补全签名
   125  	if req["sign"] == "" {
   126  		switch reqSignType {
   127  		case SignType_MD5:
   128  			req["sign"] = Sign2(req, clt.ApiKey(), md5.New())
   129  		case SignType_HMAC_SHA256:
   130  			req["sign"] = Sign2(req, clt.ApiKey(), hmac.New(sha256.New, []byte(clt.ApiKey())))
   131  		}
   132  	}
   133  
   134  	buffer := textBufferPool.Get().(*bytes.Buffer)
   135  	buffer.Reset()
   136  	defer textBufferPool.Put(buffer)
   137  
   138  	if err = util.EncodeXMLFromMap(buffer, req, "xml"); err != nil {
   139  		return nil, err
   140  	}
   141  	body := buffer.Bytes()
   142  
   143  	hasRetried := false
   144  RETRY:
   145  	resp, needRetry, err := clt.postXML(url, body, reqSignType)
   146  	if err != nil {
   147  		if needRetry && !hasRetried {
   148  			// TODO(chanxuehong): 打印错误日志
   149  			hasRetried = true
   150  			url = switchRequestURL(url)
   151  			goto RETRY
   152  		}
   153  		return nil, err
   154  	}
   155  	return resp, nil
   156  }
   157  
   158  func (clt *Client) postXML(url string, body []byte, reqSignType string) (resp map[string]string, needRetry bool, err error) {
   159  	api.DebugPrintPostXMLRequest(url, body)
   160  
   161  	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
   162  	if err != nil {
   163  		return nil, false, err
   164  	}
   165  	req.Header.Set("Content-Type", "text/xml; charset=utf-8")
   166  
   167  	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
   168  	defer cancel()
   169  
   170  	req = req.WithContext(ctx)
   171  	httpResp, err := clt.httpClient.Do(req)
   172  	if err != nil {
   173  		return nil, true, err
   174  	}
   175  	defer httpResp.Body.Close()
   176  
   177  	if httpResp.StatusCode != http.StatusOK {
   178  		return nil, true, fmt.Errorf("http.Status: %s", httpResp.Status)
   179  	}
   180  
   181  	resp, err = api.DecodeXMLHttpResponse(httpResp.Body)
   182  	if err != nil {
   183  		return nil, false, err
   184  	}
   185  
   186  	// 判断协议状态
   187  	returnCode := resp["return_code"]
   188  	if returnCode == "" {
   189  		return nil, false, ErrNotFoundReturnCode
   190  	}
   191  	if returnCode != ReturnCodeSuccess {
   192  		return nil, false, &Error{
   193  			ReturnCode: returnCode,
   194  			ReturnMsg:  resp["return_msg"],
   195  		}
   196  	}
   197  
   198  	// 验证 appid 和 mch_id
   199  	appId := resp["appid"]
   200  	if appId != "" && appId != clt.appId {
   201  		return nil, false, fmt.Errorf("appid mismatch, have: %s, want: %s", appId, clt.appId)
   202  	}
   203  	mchId := resp["mch_id"]
   204  	if mchId != "" && mchId != clt.mchId {
   205  		return nil, false, fmt.Errorf("mch_id mismatch, have: %s, want: %s", mchId, clt.mchId)
   206  	}
   207  
   208  	// 验证 sub_appid 和 sub_mch_id
   209  	if clt.subAppId != "" {
   210  		subAppId := resp["sub_appid"]
   211  		if subAppId != "" && subAppId != clt.subAppId {
   212  			return nil, false, fmt.Errorf("sub_appid mismatch, have: %s, want: %s", subAppId, clt.subAppId)
   213  		}
   214  	}
   215  	if clt.subMchId != "" {
   216  		subMchId := resp["sub_mch_id"]
   217  		if subMchId != "" && subMchId != clt.subMchId {
   218  			return nil, false, fmt.Errorf("sub_mch_id mismatch, have: %s, want: %s", subMchId, clt.subMchId)
   219  		}
   220  	}
   221  
   222  	// 验证签名
   223  	signatureHave := resp["sign"]
   224  	if signatureHave == "" {
   225  		// TODO(chanxuehong): 在适当的时候更新下面的 case
   226  		switch url {
   227  		default:
   228  			return nil, false, ErrNotFoundSign
   229  		case "https://api.mch.weixin.qq.com/mmpaymkttransfers/promotion/transfers", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/promotion/transfers":
   230  			// do nothing
   231  		case "https://api.mch.weixin.qq.com/mmpaymkttransfers/gettransferinfo", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/gettransferinfo":
   232  		// do nothing
   233  		case "https://api.mch.weixin.qq.com/mmpaymkttransfers/sendredpack", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/sendredpack":
   234  			// do nothing
   235  		case "https://api.mch.weixin.qq.com/mmpaymkttransfers/sendgroupredpack", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/sendgroupredpack":
   236  			// do nothing
   237  		case "https://api.mch.weixin.qq.com/mmpaymkttransfers/gethbinfo", "https://api2.mch.weixin.qq.com/mmpaymkttransfers/gethbinfo":
   238  			// do nothing
   239  		}
   240  	} else {
   241  		// 获取返回参数的 sign_type 并检查其有效性
   242  		var respSignType string
   243  		switch signType := resp["sign_type"]; signType {
   244  		case "":
   245  			respSignType = reqSignType // 默认使用请求参数里的算法, 至少目前是这样
   246  		case SignType_MD5:
   247  			respSignType = SignType_MD5
   248  		case SignType_HMAC_SHA256:
   249  			respSignType = SignType_HMAC_SHA256
   250  		default:
   251  			err = fmt.Errorf("unsupported response sign_type: %s", signType)
   252  			return nil, false, err
   253  		}
   254  
   255  		// 校验签名
   256  		var signatureWant string
   257  		switch respSignType {
   258  		case SignType_MD5:
   259  			signatureWant = Sign2(resp, clt.apiKey, md5.New())
   260  		case SignType_HMAC_SHA256:
   261  			signatureWant = Sign2(resp, clt.apiKey, hmac.New(sha256.New, []byte(clt.apiKey)))
   262  		}
   263  		if signatureHave != signatureWant {
   264  			return nil, false, fmt.Errorf("sign mismatch,\nhave: %s,\nwant: %s", signatureHave, signatureWant)
   265  		}
   266  	}
   267  
   268  	resultCode := resp["result_code"]
   269  	if resultCode != "" && resultCode != ResultCodeSuccess {
   270  		errCode := resp["err_code"]
   271  		if errCode == "SYSTEMERROR" {
   272  			return nil, true, &BizError{
   273  				ResultCode:  resultCode,
   274  				ErrCode:     errCode,
   275  				ErrCodeDesc: resp["err_code_des"],
   276  			}
   277  		}
   278  		return nil, false, &BizError{
   279  			ResultCode:  resultCode,
   280  			ErrCode:     errCode,
   281  			ErrCodeDesc: resp["err_code_des"],
   282  		}
   283  	}
   284  	return resp, false, nil
   285  }
   286  
   287  func switchRequestURL(url string) string {
   288  	switch {
   289  	case strings.HasPrefix(url, "https://api.mch.weixin.qq.com/"):
   290  		return "https://api2.mch.weixin.qq.com/" + url[len("https://api.mch.weixin.qq.com/"):]
   291  	case strings.HasPrefix(url, "https://api2.mch.weixin.qq.com/"):
   292  		return "https://api.mch.weixin.qq.com/" + url[len("https://api2.mch.weixin.qq.com/"):]
   293  	default:
   294  		return url
   295  	}
   296  }