github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/filesystem/driver/onedrive/oauth.go (about)

     1  package onedrive
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/url"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/cloudreve/Cloudreve/v3/pkg/cache"
    13  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth"
    14  	"github.com/cloudreve/Cloudreve/v3/pkg/request"
    15  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
    16  )
    17  
    18  // Error 实现error接口
    19  func (err OAuthError) Error() string {
    20  	return err.ErrorDescription
    21  }
    22  
    23  // OAuthURL 获取OAuth认证页面URL
    24  func (client *Client) OAuthURL(ctx context.Context, scope []string) string {
    25  	query := url.Values{
    26  		"client_id":     {client.ClientID},
    27  		"scope":         {strings.Join(scope, " ")},
    28  		"response_type": {"code"},
    29  		"redirect_uri":  {client.Redirect},
    30  	}
    31  	client.Endpoints.OAuthEndpoints.authorize.RawQuery = query.Encode()
    32  	return client.Endpoints.OAuthEndpoints.authorize.String()
    33  }
    34  
    35  // getOAuthEndpoint 根据指定的AuthURL获取详细的认证接口地址
    36  func (client *Client) getOAuthEndpoint() *oauthEndpoint {
    37  	base, err := url.Parse(client.Endpoints.OAuthURL)
    38  	if err != nil {
    39  		return nil
    40  	}
    41  	var (
    42  		token     *url.URL
    43  		authorize *url.URL
    44  	)
    45  	switch base.Host {
    46  	case "login.live.com":
    47  		token, _ = url.Parse("https://login.live.com/oauth20_token.srf")
    48  		authorize, _ = url.Parse("https://login.live.com/oauth20_authorize.srf")
    49  	case "login.chinacloudapi.cn":
    50  		client.Endpoints.isInChina = true
    51  		token, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/token")
    52  		authorize, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize")
    53  	default:
    54  		token, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/token")
    55  		authorize, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
    56  	}
    57  
    58  	return &oauthEndpoint{
    59  		token:     *token,
    60  		authorize: *authorize,
    61  	}
    62  }
    63  
    64  // ObtainToken 通过code或refresh_token兑换token
    65  func (client *Client) ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) {
    66  	options := newDefaultOption()
    67  	for _, o := range opts {
    68  		o.apply(options)
    69  	}
    70  
    71  	body := url.Values{
    72  		"client_id":     {client.ClientID},
    73  		"redirect_uri":  {client.Redirect},
    74  		"client_secret": {client.ClientSecret},
    75  	}
    76  	if options.code != "" {
    77  		body.Add("grant_type", "authorization_code")
    78  		body.Add("code", options.code)
    79  	} else {
    80  		body.Add("grant_type", "refresh_token")
    81  		body.Add("refresh_token", options.refreshToken)
    82  	}
    83  	strBody := body.Encode()
    84  
    85  	res := client.Request.Request(
    86  		"POST",
    87  		client.Endpoints.OAuthEndpoints.token.String(),
    88  		ioutil.NopCloser(strings.NewReader(strBody)),
    89  		request.WithHeader(http.Header{
    90  			"Content-Type": {"application/x-www-form-urlencoded"}},
    91  		),
    92  		request.WithContentLength(int64(len(strBody))),
    93  	)
    94  	if res.Err != nil {
    95  		return nil, res.Err
    96  	}
    97  
    98  	respBody, err := res.GetResponse()
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	var (
   104  		errResp    OAuthError
   105  		credential Credential
   106  		decodeErr  error
   107  	)
   108  
   109  	if res.Response.StatusCode != 200 {
   110  		decodeErr = json.Unmarshal([]byte(respBody), &errResp)
   111  	} else {
   112  		decodeErr = json.Unmarshal([]byte(respBody), &credential)
   113  	}
   114  	if decodeErr != nil {
   115  		return nil, decodeErr
   116  	}
   117  
   118  	if errResp.ErrorType != "" {
   119  		return nil, errResp
   120  	}
   121  
   122  	return &credential, nil
   123  
   124  }
   125  
   126  // UpdateCredential 更新凭证,并检查有效期
   127  func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error {
   128  	if isSlave {
   129  		return client.fetchCredentialFromMaster(ctx)
   130  	}
   131  
   132  	oauth.GlobalMutex.Lock(client.Policy.ID)
   133  	defer oauth.GlobalMutex.Unlock(client.Policy.ID)
   134  
   135  	// 如果已存在凭证
   136  	if client.Credential != nil && client.Credential.AccessToken != "" {
   137  		// 检查已有凭证是否过期
   138  		if client.Credential.ExpiresIn > time.Now().Unix() {
   139  			// 未过期,不要更新
   140  			return nil
   141  		}
   142  	}
   143  
   144  	// 尝试从缓存中获取凭证
   145  	if cacheCredential, ok := cache.Get("onedrive_" + client.ClientID); ok {
   146  		credential := cacheCredential.(Credential)
   147  		if credential.ExpiresIn > time.Now().Unix() {
   148  			client.Credential = &credential
   149  			return nil
   150  		}
   151  	}
   152  
   153  	// 获取新的凭证
   154  	if client.Credential == nil || client.Credential.RefreshToken == "" {
   155  		// 无有效的RefreshToken
   156  		util.Log().Error("Failed to refresh credential for policy %q, please login your Microsoft account again.", client.Policy.Name)
   157  		return ErrInvalidRefreshToken
   158  	}
   159  
   160  	credential, err := client.ObtainToken(ctx, WithRefreshToken(client.Credential.RefreshToken))
   161  	if err != nil {
   162  		return err
   163  	}
   164  
   165  	// 更新有效期为绝对时间戳
   166  	expires := credential.ExpiresIn - 60
   167  	credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix()
   168  	client.Credential = credential
   169  
   170  	// 更新存储策略的 RefreshToken
   171  	client.Policy.UpdateAccessKeyAndClearCache(credential.RefreshToken)
   172  
   173  	// 更新缓存
   174  	cache.Set("onedrive_"+client.ClientID, *credential, int(expires))
   175  
   176  	return nil
   177  }
   178  
   179  func (client *Client) AccessToken() string {
   180  	return client.Credential.AccessToken
   181  }
   182  
   183  // UpdateCredential 更新凭证,并检查有效期
   184  func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
   185  	res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID)
   186  	if err != nil {
   187  		return err
   188  	}
   189  
   190  	client.Credential = &Credential{AccessToken: res}
   191  	return nil
   192  }