github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/mp/material/download.go (about)

     1  package material
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/url"
    10  	"os"
    11  	"unicode"
    12  
    13  	"github.com/chanxuehong/wechat/internal/debug/api"
    14  	"github.com/chanxuehong/wechat/internal/debug/api/retry"
    15  	"github.com/chanxuehong/wechat/mp/core"
    16  	"github.com/chanxuehong/wechat/util"
    17  )
    18  
    19  // Download 下载多媒体到文件.
    20  //
    21  //	对于视频素材, 先通过 GetVideo 得到 Video 信息, 然后通过 Video.DownloadURL 来下载
    22  func Download(clt *core.Client, mediaId, filepath string) (written int64, err error) {
    23  	file, err := os.Create(filepath)
    24  	if err != nil {
    25  		return
    26  	}
    27  	defer func() {
    28  		file.Close()
    29  		if err != nil {
    30  			os.Remove(filepath)
    31  		}
    32  	}()
    33  
    34  	return DownloadToWriter(clt, mediaId, file)
    35  }
    36  
    37  // DownloadToWriter 下载多媒体到 io.Writer.
    38  //
    39  //	对于视频素材, 先通过 GetVideo 得到 Video 信息, 然后通过 Video.DownloadURL 来下载
    40  func DownloadToWriter(clt *core.Client, mediaId string, writer io.Writer) (written int64, err error) {
    41  	httpClient := clt.HttpClient
    42  	if httpClient == nil {
    43  		httpClient = util.DefaultMediaHttpClient
    44  	}
    45  
    46  	buffer := bytes.NewBuffer(make([]byte, 0, 256))
    47  	encoder := json.NewEncoder(buffer)
    48  	encoder.SetEscapeHTML(false)
    49  
    50  	var request = struct {
    51  		MediaId string `json:"media_id"`
    52  	}{
    53  		MediaId: mediaId,
    54  	}
    55  	if err = encoder.Encode(&request); err != nil {
    56  		return
    57  	}
    58  	requestBodyBytes := buffer.Bytes()
    59  
    60  	var errorResult core.Error
    61  
    62  	// 先读取 64bytes 内容来判断返回的是不是错误信息
    63  	// {"errcode":40007,"errmsg":"invalid media_id"}
    64  	var buf = make([]byte, 64)
    65  
    66  	token, err := clt.Token()
    67  	if err != nil {
    68  		return
    69  	}
    70  
    71  	hasRetried := false
    72  RETRY:
    73  	finalURL := "https://api.weixin.qq.com/cgi-bin/material/get_material?access_token=" + url.QueryEscape(token)
    74  	written, err = httpDownloadToWriter(httpClient, finalURL, requestBodyBytes, buf, writer, &errorResult)
    75  	if err != nil {
    76  		return
    77  	}
    78  	if written > 0 {
    79  		return
    80  	}
    81  
    82  	switch errorResult.ErrCode {
    83  	case core.ErrCodeOK:
    84  		return // 基本不会出现
    85  	case core.ErrCodeInvalidCredential, core.ErrCodeAccessTokenExpired:
    86  		retry.DebugPrintError(errorResult.ErrCode, errorResult.ErrMsg, token)
    87  		if !hasRetried {
    88  			hasRetried = true
    89  			errorResult = core.Error{}
    90  			if token, err = clt.RefreshToken(token); err != nil {
    91  				return
    92  			}
    93  			retry.DebugPrintNewToken(token)
    94  			goto RETRY
    95  		}
    96  		retry.DebugPrintFallthrough(token)
    97  		fallthrough
    98  	default:
    99  		err = &errorResult
   100  		return
   101  	}
   102  }
   103  
   104  var (
   105  	// {"errcode":40007,"errmsg":"invalid media_id"}
   106  	errRespBeginWithCode = []byte(`{"errcode":`)
   107  	errRespBeginWithMsg  = []byte(`{"errmsg":"`)
   108  )
   109  
   110  func httpDownloadToWriter(clt *http.Client, url string, body []byte, buf []byte, writer io.Writer, errorResult *core.Error) (written int64, err error) {
   111  	api.DebugPrintPostJSONRequest(url, body)
   112  	httpResp, err := clt.Post(url, "application/json; charset=utf-8", bytes.NewReader(body))
   113  	if err != nil {
   114  		return 0, err
   115  	}
   116  	defer httpResp.Body.Close()
   117  
   118  	if httpResp.StatusCode != http.StatusOK {
   119  		return 0, fmt.Errorf("http.Status: %s", httpResp.Status)
   120  	}
   121  
   122  	buf2 := buf // 保存预先读取的少量头部信息
   123  	switch n, err := io.ReadFull(httpResp.Body, buf2); err {
   124  	case nil:
   125  		break
   126  	case io.ErrUnexpectedEOF:
   127  		buf2 = buf2[:n]
   128  		break
   129  	case io.EOF: // 基本不会出现
   130  		return 0, nil
   131  	default:
   132  		return 0, err
   133  	}
   134  
   135  	var httpRespBody io.Reader
   136  	if len(buf2) < len(buf) {
   137  		httpRespBody = bytes.NewReader(buf2)
   138  	} else {
   139  		httpRespBody = io.MultiReader(bytes.NewReader(buf2), httpResp.Body)
   140  	}
   141  
   142  	buf3 := trimLeft(buf2)
   143  	if bytes.HasPrefix(buf3, errRespBeginWithCode) || bytes.HasPrefix(buf3, errRespBeginWithMsg) {
   144  		// 返回的是错误信息
   145  		return 0, api.DecodeJSONHttpResponse(httpRespBody, errorResult)
   146  	} else {
   147  		// 返回的是媒体流
   148  		return io.Copy(writer, httpRespBody)
   149  	}
   150  }
   151  
   152  func trimLeft(s []byte) []byte {
   153  	for i := 0; i < len(s); i++ {
   154  		if isSpace(s[i]) {
   155  			continue
   156  		}
   157  		return s[i:]
   158  	}
   159  	return s
   160  }
   161  
   162  func isSpace(b byte) bool {
   163  	if b > unicode.MaxASCII {
   164  		return false
   165  	}
   166  	return unicode.IsSpace(rune(b))
   167  }