github.com/chanxuehong/wechat@v0.0.0-20230222024006-36f0325263cd/oauth2/api.go (about) 1 package oauth2 2 3 import ( 4 "errors" 5 "fmt" 6 "net/http" 7 "strconv" 8 "time" 9 10 "github.com/chanxuehong/wechat/internal/debug/api" 11 ) 12 13 // ExchangeToken 通过 code 换取网页授权 access_token. 14 // 15 // NOTE: 返回的 token == clt.Token 16 func (clt *Client) ExchangeToken(code string) (token *Token, err error) { 17 if clt.Endpoint == nil { 18 err = errors.New("nil Client.Endpoint") 19 return 20 } 21 22 var tk *Token 23 if clt.TokenStorage != nil { 24 if tk, _ = clt.TokenStorage.Token(); tk == nil { 25 tk = clt.Token 26 } else { 27 clt.Token = tk // update local 28 } 29 } else { 30 tk = clt.Token 31 } 32 if tk == nil { 33 tk = new(Token) 34 } 35 36 if err = clt.updateToken(tk, clt.Endpoint.ExchangeTokenURL(code)); err != nil { 37 return 38 } 39 if err = clt.putToken(tk); err != nil { 40 return 41 } 42 token = tk 43 return 44 } 45 46 // RefreshToken 刷新 access_token. 47 // 48 // NOTE: 49 // 1. refreshToken 可以为空. 50 // 2. 返回的 token == clt.Token 51 func (clt *Client) RefreshToken(refreshToken string) (token *Token, err error) { 52 if clt.Endpoint == nil { 53 err = errors.New("nil Client.Endpoint") 54 return 55 } 56 57 var tk *Token 58 if refreshToken == "" { 59 if tk, err = clt.GetToken(false); err != nil { 60 return 61 } 62 refreshToken = tk.RefreshToken 63 } else { 64 tk = new(Token) 65 } 66 67 if err = clt.updateToken(tk, clt.Endpoint.RefreshTokenURL(refreshToken)); err != nil { 68 return 69 } 70 if err = clt.putToken(tk); err != nil { 71 return 72 } 73 token = tk 74 return 75 } 76 77 func (clt *Client) updateToken(tk *Token, url string) (err error) { 78 api.DebugPrintGetRequest(url) 79 httpResp, err := clt.httpClient().Get(url) 80 if err != nil { 81 return 82 } 83 defer httpResp.Body.Close() 84 85 if httpResp.StatusCode != http.StatusOK { 86 return fmt.Errorf("http.Status: %s", httpResp.Status) 87 } 88 89 var result struct { 90 Error 91 Token 92 } 93 if err = api.DecodeJSONHttpResponse(httpResp.Body, &result); err != nil { 94 return 95 } 96 if result.ErrCode != ErrCodeOK { 97 return &result.Error 98 } 99 100 // 由于网络的延时 和 分布式服务器之间的时间可能不是绝对同步, access_token 过期时间留了一个缓冲区 101 switch { 102 case result.ExpiresIn > 31556952: // 60*60*24*365.2425 103 return errors.New("expires_in too large: " + strconv.FormatInt(result.ExpiresIn, 10)) 104 case result.ExpiresIn > 60*60: 105 result.ExpiresIn -= 60 * 20 106 case result.ExpiresIn > 60*30: 107 result.ExpiresIn -= 60 * 10 108 case result.ExpiresIn > 60*15: 109 result.ExpiresIn -= 60 * 5 110 case result.ExpiresIn > 60*5: 111 result.ExpiresIn -= 60 112 case result.ExpiresIn > 60: 113 result.ExpiresIn -= 20 114 default: 115 return errors.New("expires_in too small: " + strconv.FormatInt(result.ExpiresIn, 10)) 116 } 117 118 tk.AccessToken = result.AccessToken 119 tk.CreatedAt = time.Now().Unix() 120 tk.ExpiresIn = result.ExpiresIn 121 if result.RefreshToken != "" { 122 tk.RefreshToken = result.RefreshToken 123 } 124 if result.OpenId != "" { 125 tk.OpenId = result.OpenId 126 } 127 if result.UnionId != "" { 128 tk.UnionId = result.UnionId 129 } 130 if result.Scope != "" { 131 tk.Scope = result.Scope 132 } 133 return 134 }