github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/mch/pay/downloadbill.go (about)

     1  package pay
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/hmac"
     6  	"crypto/md5"
     7  	"crypto/sha256"
     8  	"encoding/xml"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net/http"
    13  	"os"
    14  	"unicode"
    15  
    16  	"github.com/chanxuehong/util"
    17  
    18  	"github.com/chanxuehong/wechat/mch/core"
    19  	wechatutil "github.com/chanxuehong/wechat/util"
    20  )
    21  
    22  type DownloadBillRequest struct {
    23  	XMLName struct{} `xml:"xml" json:"-"`
    24  
    25  	// 必选参数
    26  	BillDate string `xml:"bill_date"` // 下载对账单的日期,格式:20140603
    27  	BillType string `xml:"bill_type"` // 账单类型
    28  
    29  	// 可选参数
    30  	DeviceInfo string `xml:"device_info"` // 微信支付分配的终端设备号
    31  	NonceStr   string `xml:"nonce_str"`   // 随机字符串,不长于32位。推荐随机数生成算法
    32  	SignType   string `xml:"sign_type"`   // 签名类型,目前支持HMAC-SHA256和MD5,默认为MD5
    33  	TarType    string `xml:"tar_type"`    // 压缩账单
    34  }
    35  
    36  // 下载对账单到到文件.
    37  func DownloadBill(clt *core.Client, filepath string, req *DownloadBillRequest, httpClient *http.Client) (written int64, err error) {
    38  	if req == nil {
    39  		return 0, errors.New("nil request req")
    40  	}
    41  
    42  	file, err := os.Create(filepath)
    43  	if err != nil {
    44  		return 0, err
    45  	}
    46  	defer func() {
    47  		file.Close()
    48  		if err != nil {
    49  			os.Remove(filepath)
    50  		}
    51  	}()
    52  	return downloadBillToWriter(clt, file, req, httpClient)
    53  }
    54  
    55  // 下载对账单到 io.Writer.
    56  func DownloadBillToWriter(clt *core.Client, writer io.Writer, req *DownloadBillRequest, httpClient *http.Client) (written int64, err error) {
    57  	if writer == nil {
    58  		return 0, errors.New("nil writer")
    59  	}
    60  	if req == nil {
    61  		return 0, errors.New("nil request req")
    62  	}
    63  	return downloadBillToWriter(clt, writer, req, httpClient)
    64  }
    65  
    66  var (
    67  	// <xml><return_code><![CDATA[FAIL]]></return_code>
    68  	// <return_msg><![CDATA[require POST method]]></return_msg>
    69  	// </xml>
    70  	downloadBillErrorRootNodeStartElement       = []byte("<xml>")
    71  	downloadBillErrorReturnCodeNodeStartElement = []byte("<return_code>")
    72  	downloadBillErrorReturnMsgNodeStartElement  = []byte("<return_msg>")
    73  )
    74  
    75  // 下载对账单到 io.Writer.
    76  func downloadBillToWriter(clt *core.Client, writer io.Writer, req *DownloadBillRequest, httpClient *http.Client) (written int64, err error) {
    77  	if httpClient == nil {
    78  		httpClient = wechatutil.DefaultMediaHttpClient
    79  	}
    80  
    81  	m1 := make(map[string]string, 8)
    82  	m1["appid"] = clt.AppId()
    83  	m1["mch_id"] = clt.MchId()
    84  	if subAppId := clt.SubAppId(); subAppId != "" {
    85  		m1["sub_appid"] = subAppId
    86  	}
    87  	if subMchId := clt.SubMchId(); subMchId != "" {
    88  		m1["sub_mch_id"] = subMchId
    89  	}
    90  	m1["bill_date"] = req.BillDate
    91  	if req.BillType != "" {
    92  		m1["bill_type"] = req.BillType
    93  	}
    94  	if req.DeviceInfo != "" {
    95  		m1["device_info"] = req.DeviceInfo
    96  	}
    97  	if req.NonceStr != "" {
    98  		m1["nonce_str"] = req.NonceStr
    99  	} else {
   100  		m1["nonce_str"] = wechatutil.NonceStr()
   101  	}
   102  	if req.TarType != "" {
   103  		m1["tar_type"] = req.TarType
   104  	}
   105  
   106  	// 签名
   107  	switch req.SignType {
   108  	case "":
   109  		m1["sign"] = core.Sign2(m1, clt.ApiKey(), md5.New())
   110  	case core.SignType_MD5:
   111  		m1["sign_type"] = core.SignType_MD5
   112  		m1["sign"] = core.Sign2(m1, clt.ApiKey(), md5.New())
   113  	case core.SignType_HMAC_SHA256:
   114  		m1["sign_type"] = core.SignType_HMAC_SHA256
   115  		m1["sign"] = core.Sign2(m1, clt.ApiKey(), hmac.New(sha256.New, []byte(clt.ApiKey())))
   116  	default:
   117  		err = fmt.Errorf("unsupported request sign_type: %s", req.SignType)
   118  		return 0, err
   119  	}
   120  
   121  	buffer := make([]byte, 32<<10) // 与 io.copyBuffer 里的默认大小一致
   122  
   123  	requestBuffer := bytes.NewBuffer(buffer[:0])
   124  	if err = util.EncodeXMLFromMap(requestBuffer, m1, "xml"); err != nil {
   125  		return 0, err
   126  	}
   127  
   128  	httpResp, err := httpClient.Post(core.APIBaseURL()+"/pay/downloadbill", "text/xml; charset=utf-8", requestBuffer)
   129  	if err != nil {
   130  		return 0, err
   131  	}
   132  	defer httpResp.Body.Close()
   133  
   134  	if httpResp.StatusCode != http.StatusOK {
   135  		err = fmt.Errorf("http.Status: %s", httpResp.Status)
   136  		return 0, err
   137  	}
   138  
   139  	switch n, err := io.ReadFull(httpResp.Body, buffer); err {
   140  	case nil:
   141  		// n == len(buffer) == 32KB, 可以认为返回的是对账单而不是xml格式的错误信息
   142  		written, err = bytes.NewReader(buffer).WriteTo(writer)
   143  		if err != nil {
   144  			return written, err
   145  		}
   146  		var n2 int64
   147  		n2, err = io.CopyBuffer(writer, httpResp.Body, buffer)
   148  		written += n2
   149  		return written, err
   150  	case io.ErrUnexpectedEOF:
   151  		content := buffer[:n]
   152  		if bs := trimLeft(content); bytes.HasPrefix(bs, downloadBillErrorRootNodeStartElement) {
   153  			bs = trimLeft(bs[len(downloadBillErrorRootNodeStartElement):])
   154  			if bytes.HasPrefix(bs, downloadBillErrorReturnCodeNodeStartElement) || bytes.HasPrefix(bs, downloadBillErrorReturnMsgNodeStartElement) {
   155  				// 可以认为是错误信息了, 尝试解析xml
   156  				var result core.Error
   157  				if err = xml.Unmarshal(content, &result); err == nil {
   158  					return 0, &result
   159  				}
   160  			}
   161  		}
   162  		return bytes.NewReader(content).WriteTo(writer)
   163  	case io.EOF: // 返回空的body
   164  		return 0, nil
   165  	default: // 其他的错误
   166  		return 0, err
   167  	}
   168  }
   169  
   170  func trimLeft(s []byte) []byte {
   171  	for i := 0; i < len(s); i++ {
   172  		if isSpace(s[i]) {
   173  			continue
   174  		}
   175  		return s[i:]
   176  	}
   177  	return s
   178  }
   179  
   180  func isSpace(b byte) bool {
   181  	if b > unicode.MaxASCII {
   182  		return false
   183  	}
   184  	return unicode.IsSpace(rune(b))
   185  }