github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/service/callback/oauth.go (about)

     1  package callback
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	model "github.com/cloudreve/Cloudreve/v3/models"
     8  	"github.com/cloudreve/Cloudreve/v3/pkg/cache"
     9  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive"
    10  	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive"
    11  	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
    12  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
    13  	"github.com/gin-gonic/gin"
    14  	"github.com/samber/lo"
    15  	"strings"
    16  )
    17  
// OauthService is the OAuth authorization-callback service for storage
// policies. Its fields receive the callback parameters through gin's `form`
// binding tags.
type OauthService struct {
	// Code is the authorization code (`code`); the handlers pass it to the
	// storage driver's ObtainToken to obtain credentials.
	Code     string `form:"code"`
	// Error is the provider's error parameter (`error`); when non-empty the
	// handlers abort and return a parameter error.
	Error    string `form:"error"`
	// ErrorMsg is the optional human-readable description (`error_description`)
	// accompanying Error.
	ErrorMsg string `form:"error_description"`
	// Scope is the granted scope string (`scope`); GDriveAuth checks that it
	// contains every entry of googledrive.RequiredScope.
	Scope    string `form:"scope"`
}
    25  
    26  // GDriveAuth Google Drive 更新认证信息
    27  func (service *OauthService) GDriveAuth(c *gin.Context) serializer.Response {
    28  	if service.Error != "" {
    29  		return serializer.ParamErr(service.Error, nil)
    30  	}
    31  
    32  	// validate required scope
    33  	if missing, found := lo.Find[string](googledrive.RequiredScope, func(item string) bool {
    34  		return !strings.Contains(service.Scope, item)
    35  	}); found {
    36  		return serializer.ParamErr(fmt.Sprintf("Missing required scope: %s", missing), nil)
    37  	}
    38  
    39  	policyID, ok := util.GetSession(c, "googledrive_oauth_policy").(uint)
    40  	if !ok {
    41  		return serializer.Err(serializer.CodeNotFound, "", nil)
    42  	}
    43  
    44  	util.DeleteSession(c, "googledrive_oauth_policy")
    45  
    46  	policy, err := model.GetPolicyByID(policyID)
    47  	if err != nil {
    48  		return serializer.Err(serializer.CodePolicyNotExist, "", nil)
    49  	}
    50  
    51  	client, err := googledrive.NewClient(&policy)
    52  	if err != nil {
    53  		return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize Google Drive client", err)
    54  	}
    55  
    56  	credential, err := client.ObtainToken(c, service.Code, "")
    57  	if err != nil {
    58  		return serializer.Err(serializer.CodeInternalSetting, "Failed to fetch AccessToken", err)
    59  	}
    60  
    61  	// 更新存储策略的 RefreshToken
    62  	client.Policy.AccessKey = credential.RefreshToken
    63  	if err := client.Policy.SaveAndClearCache(); err != nil {
    64  		return serializer.DBErr("Failed to update RefreshToken", err)
    65  	}
    66  
    67  	cache.Deletes([]string{client.Policy.AccessKey}, googledrive.TokenCachePrefix)
    68  	return serializer.Response{}
    69  }
    70  
    71  // OdAuth OneDrive 更新认证信息
    72  func (service *OauthService) OdAuth(c *gin.Context) serializer.Response {
    73  	if service.Error != "" {
    74  		return serializer.ParamErr(service.ErrorMsg, nil)
    75  	}
    76  
    77  	policyID, ok := util.GetSession(c, "onedrive_oauth_policy").(uint)
    78  	if !ok {
    79  		return serializer.Err(serializer.CodeNotFound, "", nil)
    80  	}
    81  
    82  	util.DeleteSession(c, "onedrive_oauth_policy")
    83  
    84  	policy, err := model.GetPolicyByID(policyID)
    85  	if err != nil {
    86  		return serializer.Err(serializer.CodePolicyNotExist, "", nil)
    87  	}
    88  
    89  	client, err := onedrive.NewClient(&policy)
    90  	if err != nil {
    91  		return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize OneDrive client", err)
    92  	}
    93  
    94  	credential, err := client.ObtainToken(c, onedrive.WithCode(service.Code))
    95  	if err != nil {
    96  		return serializer.Err(serializer.CodeInternalSetting, "Failed to fetch AccessToken", err)
    97  	}
    98  
    99  	// 更新存储策略的 RefreshToken
   100  	client.Policy.AccessKey = credential.RefreshToken
   101  	if err := client.Policy.SaveAndClearCache(); err != nil {
   102  		return serializer.DBErr("Failed to update RefreshToken", err)
   103  	}
   104  
   105  	cache.Deletes([]string{client.Policy.AccessKey}, "onedrive_")
   106  	if client.Policy.OptionsSerialized.OdDriver != "" && strings.Contains(client.Policy.OptionsSerialized.OdDriver, "http") {
   107  		if err := querySharePointSiteID(c, client.Policy); err != nil {
   108  			return serializer.Err(serializer.CodeInternalSetting, "Failed to query SharePoint site ID", err)
   109  		}
   110  	}
   111  
   112  	return serializer.Response{}
   113  }
   114  
   115  func querySharePointSiteID(ctx context.Context, policy *model.Policy) error {
   116  	client, err := onedrive.NewClient(policy)
   117  	if err != nil {
   118  		return err
   119  	}
   120  
   121  	id, err := client.GetSiteIDByURL(ctx, client.Policy.OptionsSerialized.OdDriver)
   122  	if err != nil {
   123  		return err
   124  	}
   125  
   126  	client.Policy.OptionsSerialized.OdDriver = fmt.Sprintf("sites/%s/drive", id)
   127  	if err := client.Policy.SaveAndClearCache(); err != nil {
   128  		return err
   129  	}
   130  
   131  	return nil
   132  }