github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/filesystem/driver/googledrive/oauth.go (about) 1 package googledrive 2 3 import ( 4 "context" 5 "encoding/json" 6 "github.com/cloudreve/Cloudreve/v3/pkg/cache" 7 "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth" 8 "github.com/cloudreve/Cloudreve/v3/pkg/request" 9 "github.com/cloudreve/Cloudreve/v3/pkg/util" 10 "io" 11 "net/http" 12 "net/url" 13 "strings" 14 "time" 15 ) 16 17 // OAuthURL 获取OAuth认证页面URL 18 func (client *Client) OAuthURL(ctx context.Context, scope []string) string { 19 query := url.Values{ 20 "client_id": {client.ClientID}, 21 "scope": {strings.Join(scope, " ")}, 22 "response_type": {"code"}, 23 "redirect_uri": {client.Redirect}, 24 "access_type": {"offline"}, 25 "prompt": {"consent"}, 26 } 27 28 u, _ := url.Parse(client.Endpoints.UserConsentEndpoint) 29 u.RawQuery = query.Encode() 30 return u.String() 31 } 32 33 // ObtainToken 通过code或refresh_token兑换token 34 func (client *Client) ObtainToken(ctx context.Context, code, refreshToken string) (*Credential, error) { 35 body := url.Values{ 36 "client_id": {client.ClientID}, 37 "redirect_uri": {client.Redirect}, 38 "client_secret": {client.ClientSecret}, 39 } 40 if code != "" { 41 body.Add("grant_type", "authorization_code") 42 body.Add("code", code) 43 } else { 44 body.Add("grant_type", "refresh_token") 45 body.Add("refresh_token", refreshToken) 46 } 47 strBody := body.Encode() 48 49 res := client.Request.Request( 50 "POST", 51 client.Endpoints.TokenEndpoint, 52 io.NopCloser(strings.NewReader(strBody)), 53 request.WithHeader(http.Header{ 54 "Content-Type": {"application/x-www-form-urlencoded"}}, 55 ), 56 request.WithContentLength(int64(len(strBody))), 57 ) 58 if res.Err != nil { 59 return nil, res.Err 60 } 61 62 respBody, err := res.GetResponse() 63 if err != nil { 64 return nil, err 65 } 66 67 var ( 68 errResp OAuthError 69 credential Credential 70 decodeErr error 71 ) 72 73 if res.Response.StatusCode != 200 { 74 decodeErr = json.Unmarshal([]byte(respBody), &errResp) 75 } else { 76 decodeErr = json.Unmarshal([]byte(respBody), &credential) 77 } 78 if decodeErr != nil { 79 return nil, decodeErr 80 } 81 82 if errResp.ErrorType != "" { 83 return nil, errResp 84 } 85 86 return &credential, nil 87 } 88 89 // UpdateCredential 更新凭证,并检查有效期 90 func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error { 91 if isSlave { 92 return client.fetchCredentialFromMaster(ctx) 93 } 94 95 oauth.GlobalMutex.Lock(client.Policy.ID) 96 defer oauth.GlobalMutex.Unlock(client.Policy.ID) 97 98 // 如果已存在凭证 99 if client.Credential != nil && client.Credential.AccessToken != "" { 100 // 检查已有凭证是否过期 101 if client.Credential.ExpiresIn > time.Now().Unix() { 102 // 未过期,不要更新 103 return nil 104 } 105 } 106 107 // 尝试从缓存中获取凭证 108 if cacheCredential, ok := cache.Get(TokenCachePrefix + client.ClientID); ok { 109 credential := cacheCredential.(Credential) 110 if credential.ExpiresIn > time.Now().Unix() { 111 client.Credential = &credential 112 return nil 113 } 114 } 115 116 // 获取新的凭证 117 if client.Credential == nil || client.Credential.RefreshToken == "" { 118 // 无有效的RefreshToken 119 util.Log().Error("Failed to refresh credential for policy %q, please login your Google account again.", client.Policy.Name) 120 return ErrInvalidRefreshToken 121 } 122 123 credential, err := client.ObtainToken(ctx, "", client.Credential.RefreshToken) 124 if err != nil { 125 return err 126 } 127 128 // 更新有效期为绝对时间戳 129 expires := credential.ExpiresIn - 60 130 credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix() 131 // refresh token for Google Drive does not expire in production 132 credential.RefreshToken = client.Credential.RefreshToken 133 client.Credential = credential 134 135 // 更新缓存 136 cache.Set(TokenCachePrefix+client.ClientID, *credential, int(expires)) 137 138 return nil 139 } 140 141 func (client *Client) AccessToken() string { 142 return client.Credential.AccessToken 143 } 144 145 // UpdateCredential 更新凭证,并检查有效期 146 func (client *Client) fetchCredentialFromMaster(ctx context.Context) error { 147 res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID) 148 if err != nil { 149 return err 150 } 151 152 client.Credential = &Credential{AccessToken: res} 153 return nil 154 }