github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/mp/core/client.go (about) 1 package core 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "net/http" 8 "net/url" 9 "reflect" 10 11 "github.com/chanxuehong/wechat/internal/debug/api" 12 "github.com/chanxuehong/wechat/internal/debug/api/retry" 13 "github.com/chanxuehong/wechat/util" 14 ) 15 16 type Client struct { 17 AccessTokenServer 18 HttpClient *http.Client 19 } 20 21 // NewClient 创建一个新的 Client. 22 // 23 // 如果 clt == nil 则默认用 util.DefaultHttpClient 24 func NewClient(srv AccessTokenServer, clt *http.Client) *Client { 25 if srv == nil { 26 panic("nil AccessTokenServer") 27 } 28 if clt == nil { 29 clt = util.DefaultHttpClient 30 } 31 return &Client{ 32 AccessTokenServer: srv, 33 HttpClient: clt, 34 } 35 } 36 37 // GetJSON HTTP GET 微信资源, 然后将微信服务器返回的 JSON 用 encoding/json 解析到 response. 38 // 39 // NOTE: 40 // 1. 一般不需要调用这个方法, 请直接调用高层次的封装函数; 41 // 2. 最终的 URL == incompleteURL + access_token; 42 // 3. response 格式有要求, 要么是 *Error, 要么是下面结构体的指针(注意 Error 必须是第一个 Field): 43 // struct { 44 // Error 45 // ... 46 // } 47 func (clt *Client) GetJSON(incompleteURL string, response interface{}) (err error) { 48 ErrorStructValue, ErrorErrCodeValue := checkResponse(response) 49 50 httpClient := clt.HttpClient 51 if httpClient == nil { 52 httpClient = util.DefaultHttpClient 53 } 54 55 token, err := clt.Token() 56 if err != nil { 57 return 58 } 59 60 hasRetried := false 61 RETRY: 62 finalURL := incompleteURL + url.QueryEscape(token) 63 if err = httpGetJSON(httpClient, finalURL, response); err != nil { 64 return 65 } 66 67 switch errCode := ErrorErrCodeValue.Int(); errCode { 68 case ErrCodeOK: 69 return 70 case ErrCodeInvalidCredential, ErrCodeAccessTokenExpired: 71 errMsg := ErrorStructValue.Field(errorErrMsgIndex).String() 72 retry.DebugPrintError(errCode, errMsg, token) 73 if !hasRetried { 74 hasRetried = true 75 ErrorStructValue.Set(errorZeroValue) 76 if token, err = clt.RefreshToken(token); err != nil { 77 return 78 } 79 retry.DebugPrintNewToken(token) 80 goto RETRY 81 } 82 retry.DebugPrintFallthrough(token) 83 fallthrough 84 default: 85 return 86 } 87 } 88 89 func httpGetJSON(clt *http.Client, url string, response interface{}) error { 90 api.DebugPrintGetRequest(url) 91 httpResp, err := clt.Get(url) 92 if err != nil { 93 return err 94 } 95 defer httpResp.Body.Close() 96 97 if httpResp.StatusCode != http.StatusOK { 98 return fmt.Errorf("http.Status: %s", httpResp.Status) 99 } 100 return api.DecodeJSONHttpResponse(httpResp.Body, response) 101 } 102 103 // PostJSON 用 encoding/json 把 request marshal 为 JSON, HTTP POST 到微信服务器, 104 // 然后将微信服务器返回的 JSON 用 encoding/json 解析到 response. 105 // 106 // NOTE: 107 // 1. 一般不需要调用这个方法, 请直接调用高层次的封装函数; 108 // 2. 最终的 URL == incompleteURL + access_token; 109 // 3. response 格式有要求, 要么是 *Error, 要么是下面结构体的指针(注意 Error 必须是第一个 Field): 110 // struct { 111 // Error 112 // ... 113 // } 114 func (clt *Client) PostJSON(incompleteURL string, request interface{}, response interface{}) (err error) { 115 ErrorStructValue, ErrorErrCodeValue := checkResponse(response) 116 117 buffer := textBufferPool.Get().(*bytes.Buffer) 118 buffer.Reset() 119 defer textBufferPool.Put(buffer) 120 121 encoder := json.NewEncoder(buffer) 122 encoder.SetEscapeHTML(false) 123 if err = encoder.Encode(request); err != nil { 124 return 125 } 126 requestBodyBytes := buffer.Bytes() 127 if i := len(requestBodyBytes) - 1; i >= 0 && requestBodyBytes[i] == '\n' { 128 requestBodyBytes = requestBodyBytes[:i] // 去掉最后的 '\n', 这样能统一log格式, 不然可能多一个空白行 129 } 130 131 httpClient := clt.HttpClient 132 if httpClient == nil { 133 httpClient = util.DefaultHttpClient 134 } 135 136 token, err := clt.Token() 137 if err != nil { 138 return 139 } 140 141 hasRetried := false 142 RETRY: 143 finalURL := incompleteURL + url.QueryEscape(token) 144 if err = httpPostJSON(httpClient, finalURL, requestBodyBytes, response); err != nil { 145 return 146 } 147 148 switch errCode := ErrorErrCodeValue.Int(); errCode { 149 case ErrCodeOK: 150 return 151 case ErrCodeInvalidCredential, ErrCodeAccessTokenExpired: 152 errMsg := ErrorStructValue.Field(errorErrMsgIndex).String() 153 retry.DebugPrintError(errCode, errMsg, token) 154 if !hasRetried { 155 hasRetried = true 156 ErrorStructValue.Set(errorZeroValue) 157 if token, err = clt.RefreshToken(token); err != nil { 158 return 159 } 160 retry.DebugPrintNewToken(token) 161 goto RETRY 162 } 163 retry.DebugPrintFallthrough(token) 164 fallthrough 165 default: 166 return 167 } 168 } 169 170 func httpPostJSON(clt *http.Client, url string, body []byte, response interface{}) error { 171 api.DebugPrintPostJSONRequest(url, body) 172 httpResp, err := clt.Post(url, "application/json; charset=utf-8", bytes.NewReader(body)) 173 if err != nil { 174 return err 175 } 176 defer httpResp.Body.Close() 177 178 if httpResp.StatusCode != http.StatusOK { 179 return fmt.Errorf("http.Status: %s", httpResp.Status) 180 } 181 return api.DecodeJSONHttpResponse(httpResp.Body, response) 182 } 183 184 // checkResponse 检查 response 参数是否满足特定的结构要求, 如果不满足要求则会 panic, 否则返回相应的 reflect.Value. 185 func checkResponse(response interface{}) (ErrorStructValue, ErrorErrCodeValue reflect.Value) { 186 responseValue := reflect.ValueOf(response) 187 if responseValue.Kind() != reflect.Ptr { 188 panic("the type of response is incorrect") 189 } 190 responseStructValue := responseValue.Elem() 191 if responseStructValue.Kind() != reflect.Struct { 192 panic("the type of response is incorrect") 193 } 194 195 if t := responseStructValue.Type(); t == errorType { 196 ErrorStructValue = responseStructValue 197 } else { 198 if t.NumField() == 0 { 199 panic("the type of response is incorrect") 200 } 201 v := responseStructValue.Field(0) 202 if v.Type() != errorType { 203 panic("the type of response is incorrect") 204 } 205 ErrorStructValue = v 206 } 207 ErrorErrCodeValue = ErrorStructValue.Field(errorErrCodeIndex) 208 return 209 }