github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/filesystem/driver/googledrive/oauth.go (about)

     1  package googledrive
     2  
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/cloudreve/Cloudreve/v3/pkg/cache"
	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth"
	"github.com/cloudreve/Cloudreve/v3/pkg/request"
	"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
    16  
    17  // OAuthURL 获取OAuth认证页面URL
    18  func (client *Client) OAuthURL(ctx context.Context, scope []string) string {
    19  	query := url.Values{
    20  		"client_id":     {client.ClientID},
    21  		"scope":         {strings.Join(scope, " ")},
    22  		"response_type": {"code"},
    23  		"redirect_uri":  {client.Redirect},
    24  		"access_type":   {"offline"},
    25  		"prompt":        {"consent"},
    26  	}
    27  
    28  	u, _ := url.Parse(client.Endpoints.UserConsentEndpoint)
    29  	u.RawQuery = query.Encode()
    30  	return u.String()
    31  }
    32  
    33  // ObtainToken 通过code或refresh_token兑换token
    34  func (client *Client) ObtainToken(ctx context.Context, code, refreshToken string) (*Credential, error) {
    35  	body := url.Values{
    36  		"client_id":     {client.ClientID},
    37  		"redirect_uri":  {client.Redirect},
    38  		"client_secret": {client.ClientSecret},
    39  	}
    40  	if code != "" {
    41  		body.Add("grant_type", "authorization_code")
    42  		body.Add("code", code)
    43  	} else {
    44  		body.Add("grant_type", "refresh_token")
    45  		body.Add("refresh_token", refreshToken)
    46  	}
    47  	strBody := body.Encode()
    48  
    49  	res := client.Request.Request(
    50  		"POST",
    51  		client.Endpoints.TokenEndpoint,
    52  		io.NopCloser(strings.NewReader(strBody)),
    53  		request.WithHeader(http.Header{
    54  			"Content-Type": {"application/x-www-form-urlencoded"}},
    55  		),
    56  		request.WithContentLength(int64(len(strBody))),
    57  	)
    58  	if res.Err != nil {
    59  		return nil, res.Err
    60  	}
    61  
    62  	respBody, err := res.GetResponse()
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	var (
    68  		errResp    OAuthError
    69  		credential Credential
    70  		decodeErr  error
    71  	)
    72  
    73  	if res.Response.StatusCode != 200 {
    74  		decodeErr = json.Unmarshal([]byte(respBody), &errResp)
    75  	} else {
    76  		decodeErr = json.Unmarshal([]byte(respBody), &credential)
    77  	}
    78  	if decodeErr != nil {
    79  		return nil, decodeErr
    80  	}
    81  
    82  	if errResp.ErrorType != "" {
    83  		return nil, errResp
    84  	}
    85  
    86  	return &credential, nil
    87  }
    88  
    89  // UpdateCredential 更新凭证,并检查有效期
    90  func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error {
    91  	if isSlave {
    92  		return client.fetchCredentialFromMaster(ctx)
    93  	}
    94  
    95  	oauth.GlobalMutex.Lock(client.Policy.ID)
    96  	defer oauth.GlobalMutex.Unlock(client.Policy.ID)
    97  
    98  	// 如果已存在凭证
    99  	if client.Credential != nil && client.Credential.AccessToken != "" {
   100  		// 检查已有凭证是否过期
   101  		if client.Credential.ExpiresIn > time.Now().Unix() {
   102  			// 未过期,不要更新
   103  			return nil
   104  		}
   105  	}
   106  
   107  	// 尝试从缓存中获取凭证
   108  	if cacheCredential, ok := cache.Get(TokenCachePrefix + client.ClientID); ok {
   109  		credential := cacheCredential.(Credential)
   110  		if credential.ExpiresIn > time.Now().Unix() {
   111  			client.Credential = &credential
   112  			return nil
   113  		}
   114  	}
   115  
   116  	// 获取新的凭证
   117  	if client.Credential == nil || client.Credential.RefreshToken == "" {
   118  		// 无有效的RefreshToken
   119  		util.Log().Error("Failed to refresh credential for policy %q, please login your Google account again.", client.Policy.Name)
   120  		return ErrInvalidRefreshToken
   121  	}
   122  
   123  	credential, err := client.ObtainToken(ctx, "", client.Credential.RefreshToken)
   124  	if err != nil {
   125  		return err
   126  	}
   127  
   128  	// 更新有效期为绝对时间戳
   129  	expires := credential.ExpiresIn - 60
   130  	credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix()
   131  	// refresh token for Google Drive does not expire in production
   132  	credential.RefreshToken = client.Credential.RefreshToken
   133  	client.Credential = credential
   134  
   135  	// 更新缓存
   136  	cache.Set(TokenCachePrefix+client.ClientID, *credential, int(expires))
   137  
   138  	return nil
   139  }
   140  
   141  func (client *Client) AccessToken() string {
   142  	return client.Credential.AccessToken
   143  }
   144  
   145  // UpdateCredential 更新凭证,并检查有效期
   146  func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
   147  	res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID)
   148  	if err != nil {
   149  		return err
   150  	}
   151  
   152  	client.Credential = &Credential{AccessToken: res}
   153  	return nil
   154  }