github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/oauth2/api.go (about)

     1  package oauth2
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/http"
     7  	"strconv"
     8  	"time"
     9  
    10  	"github.com/chanxuehong/wechat/internal/debug/api"
    11  )
    12  
    13  // ExchangeToken 通过 code 换取网页授权 access_token.
    14  //
    15  //	NOTE: 返回的 token == clt.Token
    16  func (clt *Client) ExchangeToken(code string) (token *Token, err error) {
    17  	if clt.Endpoint == nil {
    18  		err = errors.New("nil Client.Endpoint")
    19  		return
    20  	}
    21  
    22  	var tk *Token
    23  	if clt.TokenStorage != nil {
    24  		if tk, _ = clt.TokenStorage.Token(); tk == nil {
    25  			tk = clt.Token
    26  		} else {
    27  			clt.Token = tk // update local
    28  		}
    29  	} else {
    30  		tk = clt.Token
    31  	}
    32  	if tk == nil {
    33  		tk = new(Token)
    34  	}
    35  
    36  	if err = clt.updateToken(tk, clt.Endpoint.ExchangeTokenURL(code)); err != nil {
    37  		return
    38  	}
    39  	if err = clt.putToken(tk); err != nil {
    40  		return
    41  	}
    42  	token = tk
    43  	return
    44  }
    45  
    46  // RefreshToken 刷新 access_token.
    47  //
    48  //	NOTE:
    49  //	1. refreshToken 可以为空.
    50  //	2. 返回的 token == clt.Token
    51  func (clt *Client) RefreshToken(refreshToken string) (token *Token, err error) {
    52  	if clt.Endpoint == nil {
    53  		err = errors.New("nil Client.Endpoint")
    54  		return
    55  	}
    56  
    57  	var tk *Token
    58  	if refreshToken == "" {
    59  		if tk, err = clt.GetToken(false); err != nil {
    60  			return
    61  		}
    62  		refreshToken = tk.RefreshToken
    63  	} else {
    64  		tk = new(Token)
    65  	}
    66  
    67  	if err = clt.updateToken(tk, clt.Endpoint.RefreshTokenURL(refreshToken)); err != nil {
    68  		return
    69  	}
    70  	if err = clt.putToken(tk); err != nil {
    71  		return
    72  	}
    73  	token = tk
    74  	return
    75  }
    76  
    77  func (clt *Client) updateToken(tk *Token, url string) (err error) {
    78  	api.DebugPrintGetRequest(url)
    79  	httpResp, err := clt.httpClient().Get(url)
    80  	if err != nil {
    81  		return
    82  	}
    83  	defer httpResp.Body.Close()
    84  
    85  	if httpResp.StatusCode != http.StatusOK {
    86  		return fmt.Errorf("http.Status: %s", httpResp.Status)
    87  	}
    88  
    89  	var result struct {
    90  		Error
    91  		Token
    92  	}
    93  	if err = api.DecodeJSONHttpResponse(httpResp.Body, &result); err != nil {
    94  		return
    95  	}
    96  	if result.ErrCode != ErrCodeOK {
    97  		return &result.Error
    98  	}
    99  
   100  	// 由于网络的延时 和 分布式服务器之间的时间可能不是绝对同步, access_token 过期时间留了一个缓冲区
   101  	switch {
   102  	case result.ExpiresIn > 31556952: // 60*60*24*365.2425
   103  		return errors.New("expires_in too large: " + strconv.FormatInt(result.ExpiresIn, 10))
   104  	case result.ExpiresIn > 60*60:
   105  		result.ExpiresIn -= 60 * 20
   106  	case result.ExpiresIn > 60*30:
   107  		result.ExpiresIn -= 60 * 10
   108  	case result.ExpiresIn > 60*15:
   109  		result.ExpiresIn -= 60 * 5
   110  	case result.ExpiresIn > 60*5:
   111  		result.ExpiresIn -= 60
   112  	case result.ExpiresIn > 60:
   113  		result.ExpiresIn -= 20
   114  	default:
   115  		return errors.New("expires_in too small: " + strconv.FormatInt(result.ExpiresIn, 10))
   116  	}
   117  
   118  	tk.AccessToken = result.AccessToken
   119  	tk.CreatedAt = time.Now().Unix()
   120  	tk.ExpiresIn = result.ExpiresIn
   121  	if result.RefreshToken != "" {
   122  		tk.RefreshToken = result.RefreshToken
   123  	}
   124  	if result.OpenId != "" {
   125  		tk.OpenId = result.OpenId
   126  	}
   127  	if result.UnionId != "" {
   128  		tk.UnionId = result.UnionId
   129  	}
   130  	if result.Scope != "" {
   131  		tk.Scope = result.Scope
   132  	}
   133  	return
   134  }