github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/mp/core/access_token_server.go (about) 1 package core 2 3 import ( 4 "errors" 5 "fmt" 6 "math/rand" 7 "net/http" 8 "net/url" 9 "strconv" 10 "sync/atomic" 11 "time" 12 "unsafe" 13 14 "github.com/chanxuehong/wechat/internal/debug/api" 15 "github.com/chanxuehong/wechat/util" 16 ) 17 18 // access_token 中控服务器接口. 19 type AccessTokenServer interface { 20 Token() (token string, err error) // 请求中控服务器返回缓存的 access_token 21 RefreshToken(currentToken string) (token string, err error) // 请求中控服务器刷新 access_token 22 IID01332E16DF5011E5A9D5A4DB30FED8E1() // 接口标识, 没有实际意义 23 } 24 25 var _ AccessTokenServer = (*DefaultAccessTokenServer)(nil) 26 27 // DefaultAccessTokenServer 实现了 AccessTokenServer 接口. 28 // 29 // NOTE: 30 // 1. 用于单进程环境. 31 // 2. 因为 DefaultAccessTokenServer 同时也是一个简单的中控服务器, 而不是仅仅实现 AccessTokenServer 接口, 32 // 所以整个系统只能存在一个 DefaultAccessTokenServer 实例! 33 type DefaultAccessTokenServer struct { 34 appId string 35 appSecret string 36 httpClient *http.Client 37 38 refreshTokenRequestChan chan string // chan currentToken 39 refreshTokenResponseChan chan refreshTokenResult // chan {token, err} 40 41 tokenCache unsafe.Pointer // *accessToken 42 } 43 44 // NewDefaultAccessTokenServer 创建一个新的 DefaultAccessTokenServer, 如果 httpClient == nil 则默认使用 util.DefaultHttpClient. 45 func NewDefaultAccessTokenServer(appId, appSecret string, httpClient *http.Client) (srv *DefaultAccessTokenServer) { 46 if httpClient == nil { 47 httpClient = util.DefaultHttpClient 48 } 49 50 srv = &DefaultAccessTokenServer{ 51 appId: url.QueryEscape(appId), 52 appSecret: url.QueryEscape(appSecret), 53 httpClient: httpClient, 54 refreshTokenRequestChan: make(chan string), 55 refreshTokenResponseChan: make(chan refreshTokenResult), 56 } 57 58 go srv.tokenUpdateDaemon(time.Hour * 24 * time.Duration(100+rand.Int63n(200))) 59 return 60 } 61 62 func (srv *DefaultAccessTokenServer) IID01332E16DF5011E5A9D5A4DB30FED8E1() {} 63 64 func (srv *DefaultAccessTokenServer) Token() (token string, err error) { 65 if p := (*accessToken)(atomic.LoadPointer(&srv.tokenCache)); p != nil { 66 return p.Token, nil 67 } 68 return srv.RefreshToken("") 69 } 70 71 type refreshTokenResult struct { 72 token string 73 err error 74 } 75 76 func (srv *DefaultAccessTokenServer) RefreshToken(currentToken string) (token string, err error) { 77 srv.refreshTokenRequestChan <- currentToken 78 rslt := <-srv.refreshTokenResponseChan 79 return rslt.token, rslt.err 80 } 81 82 func (srv *DefaultAccessTokenServer) tokenUpdateDaemon(initTickDuration time.Duration) { 83 tickDuration := initTickDuration 84 85 NEW_TICK_DURATION: 86 ticker := time.NewTicker(tickDuration) 87 for { 88 select { 89 case currentToken := <-srv.refreshTokenRequestChan: 90 accessToken, cached, err := srv.updateToken(currentToken) 91 if err != nil { 92 srv.refreshTokenResponseChan <- refreshTokenResult{err: err} 93 break 94 } 95 srv.refreshTokenResponseChan <- refreshTokenResult{token: accessToken.Token} 96 if !cached { 97 tickDuration = time.Duration(accessToken.ExpiresIn) * time.Second 98 ticker.Stop() 99 goto NEW_TICK_DURATION 100 } 101 102 case <-ticker.C: 103 accessToken, _, err := srv.updateToken("") 104 if err != nil { 105 break 106 } 107 newTickDuration := time.Duration(accessToken.ExpiresIn) * time.Second 108 if abs(tickDuration-newTickDuration) > time.Second*5 { 109 tickDuration = newTickDuration 110 ticker.Stop() 111 goto NEW_TICK_DURATION 112 } 113 } 114 } 115 } 116 117 func abs(x time.Duration) time.Duration { 118 if x >= 0 { 119 return x 120 } 121 return -x 122 } 123 124 type accessToken struct { 125 Token string `json:"access_token"` 126 ExpiresIn int64 `json:"expires_in"` 127 } 128 129 // updateToken 从微信服务器获取新的 access_token 并存入缓存, 同时返回该 access_token. 130 func (srv *DefaultAccessTokenServer) updateToken(currentToken string) (token *accessToken, cached bool, err error) { 131 if currentToken != "" { 132 if p := (*accessToken)(atomic.LoadPointer(&srv.tokenCache)); p != nil && currentToken != p.Token { 133 return p, true, nil // 无需更改 p.ExpiresIn 参数值, cached == true 时用不到 134 } 135 } 136 137 url := "https://api.weixin.qq.com/cgi-bin/token?grant_type=client_credential&appid=" + srv.appId + 138 "&secret=" + srv.appSecret 139 api.DebugPrintGetRequest(url) 140 httpResp, err := srv.httpClient.Get(url) 141 if err != nil { 142 atomic.StorePointer(&srv.tokenCache, nil) 143 return 144 } 145 defer httpResp.Body.Close() 146 147 if httpResp.StatusCode != http.StatusOK { 148 atomic.StorePointer(&srv.tokenCache, nil) 149 err = fmt.Errorf("http.Status: %s", httpResp.Status) 150 return 151 } 152 153 var result struct { 154 Error 155 accessToken 156 } 157 if err = api.DecodeJSONHttpResponse(httpResp.Body, &result); err != nil { 158 atomic.StorePointer(&srv.tokenCache, nil) 159 return 160 } 161 if result.ErrCode != ErrCodeOK { 162 atomic.StorePointer(&srv.tokenCache, nil) 163 err = &result.Error 164 return 165 } 166 167 // 由于网络的延时, access_token 过期时间留有一个缓冲区 168 switch { 169 case result.ExpiresIn > 31556952: // 60*60*24*365.2425 170 atomic.StorePointer(&srv.tokenCache, nil) 171 err = errors.New("expires_in too large: " + strconv.FormatInt(result.ExpiresIn, 10)) 172 return 173 case result.ExpiresIn > 60*60: 174 result.ExpiresIn -= 60 * 10 175 case result.ExpiresIn > 60*30: 176 result.ExpiresIn -= 60 * 5 177 case result.ExpiresIn > 60*5: 178 result.ExpiresIn -= 60 179 case result.ExpiresIn > 60: 180 result.ExpiresIn -= 10 181 default: 182 atomic.StorePointer(&srv.tokenCache, nil) 183 err = errors.New("expires_in too small: " + strconv.FormatInt(result.ExpiresIn, 10)) 184 return 185 } 186 187 tokenCopy := result.accessToken 188 atomic.StorePointer(&srv.tokenCache, unsafe.Pointer(&tokenCopy)) 189 token = &tokenCopy 190 return 191 }