github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/mp/core/client_upload.go (about)

     1  package core
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"mime/multipart"
     8  	"net/http"
     9  	"net/url"
    10  
    11  	"github.com/chanxuehong/wechat/internal/debug/api"
    12  	"github.com/chanxuehong/wechat/internal/debug/api/retry"
    13  	"github.com/chanxuehong/wechat/util"
    14  )
    15  
    16  type MultipartFormField struct {
    17  	IsFile   bool
    18  	Name     string
    19  	FileName string
    20  	Value    io.Reader
    21  }
    22  
    23  // PostMultipartForm 通用上传接口.
    24  //
    25  //	--BOUNDARY
    26  //	Content-Disposition: form-data; name="FIELDNAME"; filename="FILENAME"
    27  //	Content-Type: application/octet-stream
    28  //
    29  //	FILE-CONTENT
    30  //	--BOUNDARY
    31  //	Content-Disposition: form-data; name="FIELDNAME"
    32  //
    33  //	JSON-DESCRIPTION
    34  //	--BOUNDARY--
    35  //
    36  //
    37  //	NOTE:
    38  //	1. 一般不需要调用这个方法, 请直接调用高层次的封装函数;
    39  //	2. 最终的 URL == incompleteURL + access_token;
    40  //	3. response 格式有要求, 要么是 *Error, 要么是下面结构体的指针(注意 Error 必须是第一个 Field):
    41  //	    struct {
    42  //	        Error
    43  //	        ...
    44  //	    }
    45  func (clt *Client) PostMultipartForm(incompleteURL string, fields []MultipartFormField, response interface{}) (err error) {
    46  	ErrorStructValue, ErrorErrCodeValue := checkResponse(response)
    47  
    48  	buffer := mediaBufferPool.Get().(*bytes.Buffer)
    49  	buffer.Reset()
    50  	defer mediaBufferPool.Put(buffer)
    51  
    52  	multipartWriter := multipart.NewWriter(buffer)
    53  	for i := 0; i < len(fields); i++ {
    54  		if field := &fields[i]; field.IsFile {
    55  			partWriter, err3 := multipartWriter.CreateFormFile(field.Name, field.FileName)
    56  			if err3 != nil {
    57  				return err3
    58  			}
    59  			if _, err3 = io.Copy(partWriter, field.Value); err3 != nil {
    60  				return err3
    61  			}
    62  		} else {
    63  			partWriter, err3 := multipartWriter.CreateFormField(field.Name)
    64  			if err3 != nil {
    65  				return err3
    66  			}
    67  			if _, err3 = io.Copy(partWriter, field.Value); err3 != nil {
    68  				return err3
    69  			}
    70  		}
    71  	}
    72  	if err = multipartWriter.Close(); err != nil {
    73  		return
    74  	}
    75  	requestBodyBytes := buffer.Bytes()
    76  	requestBodyType := multipartWriter.FormDataContentType()
    77  
    78  	httpClient := clt.HttpClient
    79  	if httpClient == nil {
    80  		httpClient = util.DefaultMediaHttpClient
    81  	}
    82  
    83  	token, err := clt.Token()
    84  	if err != nil {
    85  		return
    86  	}
    87  
    88  	hasRetried := false
    89  RETRY:
    90  	finalURL := incompleteURL + url.QueryEscape(token)
    91  	if err = httpPostMultipartForm(httpClient, finalURL, requestBodyType, requestBodyBytes, response); err != nil {
    92  		return
    93  	}
    94  
    95  	switch errCode := ErrorErrCodeValue.Int(); errCode {
    96  	case ErrCodeOK:
    97  		return
    98  	case ErrCodeInvalidCredential, ErrCodeAccessTokenExpired:
    99  		errMsg := ErrorStructValue.Field(errorErrMsgIndex).String()
   100  		retry.DebugPrintError(errCode, errMsg, token)
   101  		if !hasRetried {
   102  			hasRetried = true
   103  			ErrorStructValue.Set(errorZeroValue)
   104  			if token, err = clt.RefreshToken(token); err != nil {
   105  				return
   106  			}
   107  			retry.DebugPrintNewToken(token)
   108  			goto RETRY
   109  		}
   110  		retry.DebugPrintFallthrough(token)
   111  		fallthrough
   112  	default:
   113  		return
   114  	}
   115  }
   116  
   117  func httpPostMultipartForm(clt *http.Client, url, bodyType string, body []byte, response interface{}) error {
   118  	api.DebugPrintPostMultipartRequest(url, body)
   119  	httpResp, err := clt.Post(url, bodyType, bytes.NewReader(body))
   120  	if err != nil {
   121  		return err
   122  	}
   123  	defer httpResp.Body.Close()
   124  
   125  	if httpResp.StatusCode != http.StatusOK {
   126  		return fmt.Errorf("http.Status: %s", httpResp.Status)
   127  	}
   128  	return api.DecodeJSONHttpResponse(httpResp.Body, response)
   129  }