github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/mp/core/client.go (about)

     1  package core
     2  
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"reflect"

	"github.com/chanxuehong/wechat/internal/debug/api"
	"github.com/chanxuehong/wechat/internal/debug/api/retry"
	"github.com/chanxuehong/wechat/util"
)
    15  
    16  type Client struct {
    17  	AccessTokenServer
    18  	HttpClient *http.Client
    19  }
    20  
    21  // NewClient 创建一个新的 Client.
    22  //
    23  //	如果 clt == nil 则默认用 util.DefaultHttpClient
    24  func NewClient(srv AccessTokenServer, clt *http.Client) *Client {
    25  	if srv == nil {
    26  		panic("nil AccessTokenServer")
    27  	}
    28  	if clt == nil {
    29  		clt = util.DefaultHttpClient
    30  	}
    31  	return &Client{
    32  		AccessTokenServer: srv,
    33  		HttpClient:        clt,
    34  	}
    35  }
    36  
    37  // GetJSON HTTP GET 微信资源, 然后将微信服务器返回的 JSON 用 encoding/json 解析到 response.
    38  //
    39  //	NOTE:
    40  //	1. 一般不需要调用这个方法, 请直接调用高层次的封装函数;
    41  //	2. 最终的 URL == incompleteURL + access_token;
    42  //	3. response 格式有要求, 要么是 *Error, 要么是下面结构体的指针(注意 Error 必须是第一个 Field):
    43  //	    struct {
    44  //	        Error
    45  //	        ...
    46  //	    }
    47  func (clt *Client) GetJSON(incompleteURL string, response interface{}) (err error) {
    48  	ErrorStructValue, ErrorErrCodeValue := checkResponse(response)
    49  
    50  	httpClient := clt.HttpClient
    51  	if httpClient == nil {
    52  		httpClient = util.DefaultHttpClient
    53  	}
    54  
    55  	token, err := clt.Token()
    56  	if err != nil {
    57  		return
    58  	}
    59  
    60  	hasRetried := false
    61  RETRY:
    62  	finalURL := incompleteURL + url.QueryEscape(token)
    63  	if err = httpGetJSON(httpClient, finalURL, response); err != nil {
    64  		return
    65  	}
    66  
    67  	switch errCode := ErrorErrCodeValue.Int(); errCode {
    68  	case ErrCodeOK:
    69  		return
    70  	case ErrCodeInvalidCredential, ErrCodeAccessTokenExpired:
    71  		errMsg := ErrorStructValue.Field(errorErrMsgIndex).String()
    72  		retry.DebugPrintError(errCode, errMsg, token)
    73  		if !hasRetried {
    74  			hasRetried = true
    75  			ErrorStructValue.Set(errorZeroValue)
    76  			if token, err = clt.RefreshToken(token); err != nil {
    77  				return
    78  			}
    79  			retry.DebugPrintNewToken(token)
    80  			goto RETRY
    81  		}
    82  		retry.DebugPrintFallthrough(token)
    83  		fallthrough
    84  	default:
    85  		return
    86  	}
    87  }
    88  
    89  func httpGetJSON(clt *http.Client, url string, response interface{}) error {
    90  	api.DebugPrintGetRequest(url)
    91  	httpResp, err := clt.Get(url)
    92  	if err != nil {
    93  		return err
    94  	}
    95  	defer httpResp.Body.Close()
    96  
    97  	if httpResp.StatusCode != http.StatusOK {
    98  		return fmt.Errorf("http.Status: %s", httpResp.Status)
    99  	}
   100  	return api.DecodeJSONHttpResponse(httpResp.Body, response)
   101  }
   102  
   103  // PostJSON 用 encoding/json 把 request marshal 为 JSON, HTTP POST 到微信服务器,
   104  // 然后将微信服务器返回的 JSON 用 encoding/json 解析到 response.
   105  //
   106  //	NOTE:
   107  //	1. 一般不需要调用这个方法, 请直接调用高层次的封装函数;
   108  //	2. 最终的 URL == incompleteURL + access_token;
   109  //	3. response 格式有要求, 要么是 *Error, 要么是下面结构体的指针(注意 Error 必须是第一个 Field):
   110  //	    struct {
   111  //	        Error
   112  //	        ...
   113  //	    }
   114  func (clt *Client) PostJSON(incompleteURL string, request interface{}, response interface{}) (err error) {
   115  	ErrorStructValue, ErrorErrCodeValue := checkResponse(response)
   116  
   117  	buffer := textBufferPool.Get().(*bytes.Buffer)
   118  	buffer.Reset()
   119  	defer textBufferPool.Put(buffer)
   120  
   121  	encoder := json.NewEncoder(buffer)
   122  	encoder.SetEscapeHTML(false)
   123  	if err = encoder.Encode(request); err != nil {
   124  		return
   125  	}
   126  	requestBodyBytes := buffer.Bytes()
   127  	if i := len(requestBodyBytes) - 1; i >= 0 && requestBodyBytes[i] == '\n' {
   128  		requestBodyBytes = requestBodyBytes[:i] // 去掉最后的 '\n', 这样能统一log格式, 不然可能多一个空白行
   129  	}
   130  
   131  	httpClient := clt.HttpClient
   132  	if httpClient == nil {
   133  		httpClient = util.DefaultHttpClient
   134  	}
   135  
   136  	token, err := clt.Token()
   137  	if err != nil {
   138  		return
   139  	}
   140  
   141  	hasRetried := false
   142  RETRY:
   143  	finalURL := incompleteURL + url.QueryEscape(token)
   144  	if err = httpPostJSON(httpClient, finalURL, requestBodyBytes, response); err != nil {
   145  		return
   146  	}
   147  
   148  	switch errCode := ErrorErrCodeValue.Int(); errCode {
   149  	case ErrCodeOK:
   150  		return
   151  	case ErrCodeInvalidCredential, ErrCodeAccessTokenExpired:
   152  		errMsg := ErrorStructValue.Field(errorErrMsgIndex).String()
   153  		retry.DebugPrintError(errCode, errMsg, token)
   154  		if !hasRetried {
   155  			hasRetried = true
   156  			ErrorStructValue.Set(errorZeroValue)
   157  			if token, err = clt.RefreshToken(token); err != nil {
   158  				return
   159  			}
   160  			retry.DebugPrintNewToken(token)
   161  			goto RETRY
   162  		}
   163  		retry.DebugPrintFallthrough(token)
   164  		fallthrough
   165  	default:
   166  		return
   167  	}
   168  }
   169  
   170  func httpPostJSON(clt *http.Client, url string, body []byte, response interface{}) error {
   171  	api.DebugPrintPostJSONRequest(url, body)
   172  	httpResp, err := clt.Post(url, "application/json; charset=utf-8", bytes.NewReader(body))
   173  	if err != nil {
   174  		return err
   175  	}
   176  	defer httpResp.Body.Close()
   177  
   178  	if httpResp.StatusCode != http.StatusOK {
   179  		return fmt.Errorf("http.Status: %s", httpResp.Status)
   180  	}
   181  	return api.DecodeJSONHttpResponse(httpResp.Body, response)
   182  }
   183  
   184  // checkResponse 检查 response 参数是否满足特定的结构要求, 如果不满足要求则会 panic, 否则返回相应的 reflect.Value.
   185  func checkResponse(response interface{}) (ErrorStructValue, ErrorErrCodeValue reflect.Value) {
   186  	responseValue := reflect.ValueOf(response)
   187  	if responseValue.Kind() != reflect.Ptr {
   188  		panic("the type of response is incorrect")
   189  	}
   190  	responseStructValue := responseValue.Elem()
   191  	if responseStructValue.Kind() != reflect.Struct {
   192  		panic("the type of response is incorrect")
   193  	}
   194  
   195  	if t := responseStructValue.Type(); t == errorType {
   196  		ErrorStructValue = responseStructValue
   197  	} else {
   198  		if t.NumField() == 0 {
   199  			panic("the type of response is incorrect")
   200  		}
   201  		v := responseStructValue.Field(0)
   202  		if v.Type() != errorType {
   203  			panic("the type of response is incorrect")
   204  		}
   205  		ErrorStructValue = v
   206  	}
   207  	ErrorErrCodeValue = ErrorStructValue.Field(errorErrCodeIndex)
   208  	return
   209  }