github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/mp/core/access_token_server.go (about)

     1  package core
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"math/rand"
     7  	"net/http"
     8  	"net/url"
     9  	"strconv"
    10  	"sync/atomic"
    11  	"time"
    12  	"unsafe"
    13  
    14  	"github.com/chanxuehong/wechat/internal/debug/api"
    15  	"github.com/chanxuehong/wechat/util"
    16  )
    17  
    18  // access_token 中控服务器接口.
    19  type AccessTokenServer interface {
    20  	Token() (token string, err error)                           // 请求中控服务器返回缓存的 access_token
    21  	RefreshToken(currentToken string) (token string, err error) // 请求中控服务器刷新 access_token
    22  	IID01332E16DF5011E5A9D5A4DB30FED8E1()                       // 接口标识, 没有实际意义
    23  }
    24  
    25  var _ AccessTokenServer = (*DefaultAccessTokenServer)(nil)
    26  
    27  // DefaultAccessTokenServer 实现了 AccessTokenServer 接口.
    28  //
    29  //	NOTE:
    30  //	1. 用于单进程环境.
    31  //	2. 因为 DefaultAccessTokenServer 同时也是一个简单的中控服务器, 而不是仅仅实现 AccessTokenServer 接口,
    32  //	   所以整个系统只能存在一个 DefaultAccessTokenServer 实例!
    33  type DefaultAccessTokenServer struct {
    34  	appId      string
    35  	appSecret  string
    36  	httpClient *http.Client
    37  
    38  	refreshTokenRequestChan  chan string             // chan currentToken
    39  	refreshTokenResponseChan chan refreshTokenResult // chan {token, err}
    40  
    41  	tokenCache unsafe.Pointer // *accessToken
    42  }
    43  
    44  // NewDefaultAccessTokenServer 创建一个新的 DefaultAccessTokenServer, 如果 httpClient == nil 则默认使用 util.DefaultHttpClient.
    45  func NewDefaultAccessTokenServer(appId, appSecret string, httpClient *http.Client) (srv *DefaultAccessTokenServer) {
    46  	if httpClient == nil {
    47  		httpClient = util.DefaultHttpClient
    48  	}
    49  
    50  	srv = &DefaultAccessTokenServer{
    51  		appId:                    url.QueryEscape(appId),
    52  		appSecret:                url.QueryEscape(appSecret),
    53  		httpClient:               httpClient,
    54  		refreshTokenRequestChan:  make(chan string),
    55  		refreshTokenResponseChan: make(chan refreshTokenResult),
    56  	}
    57  
    58  	go srv.tokenUpdateDaemon(time.Hour * 24 * time.Duration(100+rand.Int63n(200)))
    59  	return
    60  }
    61  
    62  func (srv *DefaultAccessTokenServer) IID01332E16DF5011E5A9D5A4DB30FED8E1() {}
    63  
    64  func (srv *DefaultAccessTokenServer) Token() (token string, err error) {
    65  	if p := (*accessToken)(atomic.LoadPointer(&srv.tokenCache)); p != nil {
    66  		return p.Token, nil
    67  	}
    68  	return srv.RefreshToken("")
    69  }
    70  
    71  type refreshTokenResult struct {
    72  	token string
    73  	err   error
    74  }
    75  
    76  func (srv *DefaultAccessTokenServer) RefreshToken(currentToken string) (token string, err error) {
    77  	srv.refreshTokenRequestChan <- currentToken
    78  	rslt := <-srv.refreshTokenResponseChan
    79  	return rslt.token, rslt.err
    80  }
    81  
    82  func (srv *DefaultAccessTokenServer) tokenUpdateDaemon(initTickDuration time.Duration) {
    83  	tickDuration := initTickDuration
    84  
    85  NEW_TICK_DURATION:
    86  	ticker := time.NewTicker(tickDuration)
    87  	for {
    88  		select {
    89  		case currentToken := <-srv.refreshTokenRequestChan:
    90  			accessToken, cached, err := srv.updateToken(currentToken)
    91  			if err != nil {
    92  				srv.refreshTokenResponseChan <- refreshTokenResult{err: err}
    93  				break
    94  			}
    95  			srv.refreshTokenResponseChan <- refreshTokenResult{token: accessToken.Token}
    96  			if !cached {
    97  				tickDuration = time.Duration(accessToken.ExpiresIn) * time.Second
    98  				ticker.Stop()
    99  				goto NEW_TICK_DURATION
   100  			}
   101  
   102  		case <-ticker.C:
   103  			accessToken, _, err := srv.updateToken("")
   104  			if err != nil {
   105  				break
   106  			}
   107  			newTickDuration := time.Duration(accessToken.ExpiresIn) * time.Second
   108  			if abs(tickDuration-newTickDuration) > time.Second*5 {
   109  				tickDuration = newTickDuration
   110  				ticker.Stop()
   111  				goto NEW_TICK_DURATION
   112  			}
   113  		}
   114  	}
   115  }
   116  
   117  func abs(x time.Duration) time.Duration {
   118  	if x >= 0 {
   119  		return x
   120  	}
   121  	return -x
   122  }
   123  
   124  type accessToken struct {
   125  	Token     string `json:"access_token"`
   126  	ExpiresIn int64  `json:"expires_in"`
   127  }
   128  
   129  // updateToken 从微信服务器获取新的 access_token 并存入缓存, 同时返回该 access_token.
   130  func (srv *DefaultAccessTokenServer) updateToken(currentToken string) (token *accessToken, cached bool, err error) {
   131  	if currentToken != "" {
   132  		if p := (*accessToken)(atomic.LoadPointer(&srv.tokenCache)); p != nil && currentToken != p.Token {
   133  			return p, true, nil // 无需更改 p.ExpiresIn 参数值, cached == true 时用不到
   134  		}
   135  	}
   136  
   137  	url := "https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid=" + srv.appId +
   138  		"&secret=" + srv.appSecret
   139  	api.DebugPrintGetRequest(url)
   140  	httpResp, err := srv.httpClient.Get(url)
   141  	if err != nil {
   142  		atomic.StorePointer(&srv.tokenCache, nil)
   143  		return
   144  	}
   145  	defer httpResp.Body.Close()
   146  
   147  	if httpResp.StatusCode != http.StatusOK {
   148  		atomic.StorePointer(&srv.tokenCache, nil)
   149  		err = fmt.Errorf("http.Status: %s", httpResp.Status)
   150  		return
   151  	}
   152  
   153  	var result struct {
   154  		Error
   155  		accessToken
   156  	}
   157  	if err = api.DecodeJSONHttpResponse(httpResp.Body, &result); err != nil {
   158  		atomic.StorePointer(&srv.tokenCache, nil)
   159  		return
   160  	}
   161  	if result.ErrCode != ErrCodeOK {
   162  		atomic.StorePointer(&srv.tokenCache, nil)
   163  		err = &result.Error
   164  		return
   165  	}
   166  
   167  	// 由于网络的延时, access_token 过期时间留有一个缓冲区
   168  	switch {
   169  	case result.ExpiresIn > 31556952: // 60*60*24*365.2425
   170  		atomic.StorePointer(&srv.tokenCache, nil)
   171  		err = errors.New("expires_in too large: " + strconv.FormatInt(result.ExpiresIn, 10))
   172  		return
   173  	case result.ExpiresIn > 60*60:
   174  		result.ExpiresIn -= 60 * 10
   175  	case result.ExpiresIn > 60*30:
   176  		result.ExpiresIn -= 60 * 5
   177  	case result.ExpiresIn > 60*5:
   178  		result.ExpiresIn -= 60
   179  	case result.ExpiresIn > 60:
   180  		result.ExpiresIn -= 10
   181  	default:
   182  		atomic.StorePointer(&srv.tokenCache, nil)
   183  		err = errors.New("expires_in too small: " + strconv.FormatInt(result.ExpiresIn, 10))
   184  		return
   185  	}
   186  
   187  	tokenCopy := result.accessToken
   188  	atomic.StorePointer(&srv.tokenCache, unsafe.Pointer(&tokenCopy))
   189  	token = &tokenCopy
   190  	return
   191  }