github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/service/callback/oauth.go (about) 1 package callback 2 3 import ( 4 "context" 5 "fmt" 6 7 model "github.com/cloudreve/Cloudreve/v3/models" 8 "github.com/cloudreve/Cloudreve/v3/pkg/cache" 9 "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive" 10 "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" 11 "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 12 "github.com/cloudreve/Cloudreve/v3/pkg/util" 13 "github.com/gin-gonic/gin" 14 "github.com/samber/lo" 15 "strings" 16 ) 17 18 // OauthService OAuth 存储策略授权回调服务 19 type OauthService struct { 20 Code string `form:"code"` 21 Error string `form:"error"` 22 ErrorMsg string `form:"error_description"` 23 Scope string `form:"scope"` 24 } 25 26 // GDriveAuth Google Drive 更新认证信息 27 func (service *OauthService) GDriveAuth(c *gin.Context) serializer.Response { 28 if service.Error != "" { 29 return serializer.ParamErr(service.Error, nil) 30 } 31 32 // validate required scope 33 if missing, found := lo.Find[string](googledrive.RequiredScope, func(item string) bool { 34 return !strings.Contains(service.Scope, item) 35 }); found { 36 return serializer.ParamErr(fmt.Sprintf("Missing required scope: %s", missing), nil) 37 } 38 39 policyID, ok := util.GetSession(c, "googledrive_oauth_policy").(uint) 40 if !ok { 41 return serializer.Err(serializer.CodeNotFound, "", nil) 42 } 43 44 util.DeleteSession(c, "googledrive_oauth_policy") 45 46 policy, err := model.GetPolicyByID(policyID) 47 if err != nil { 48 return serializer.Err(serializer.CodePolicyNotExist, "", nil) 49 } 50 51 client, err := googledrive.NewClient(&policy) 52 if err != nil { 53 return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize Google Drive client", err) 54 } 55 56 credential, err := client.ObtainToken(c, service.Code, "") 57 if err != nil { 58 return serializer.Err(serializer.CodeInternalSetting, "Failed to fetch AccessToken", err) 59 } 60 61 // 更新存储策略的 RefreshToken 62 client.Policy.AccessKey = credential.RefreshToken 63 if err := client.Policy.SaveAndClearCache(); err != nil { 64 return serializer.DBErr("Failed to update RefreshToken", err) 65 } 66 67 cache.Deletes([]string{client.Policy.AccessKey}, googledrive.TokenCachePrefix) 68 return serializer.Response{} 69 } 70 71 // OdAuth OneDrive 更新认证信息 72 func (service *OauthService) OdAuth(c *gin.Context) serializer.Response { 73 if service.Error != "" { 74 return serializer.ParamErr(service.ErrorMsg, nil) 75 } 76 77 policyID, ok := util.GetSession(c, "onedrive_oauth_policy").(uint) 78 if !ok { 79 return serializer.Err(serializer.CodeNotFound, "", nil) 80 } 81 82 util.DeleteSession(c, "onedrive_oauth_policy") 83 84 policy, err := model.GetPolicyByID(policyID) 85 if err != nil { 86 return serializer.Err(serializer.CodePolicyNotExist, "", nil) 87 } 88 89 client, err := onedrive.NewClient(&policy) 90 if err != nil { 91 return serializer.Err(serializer.CodeInternalSetting, "Failed to initialize OneDrive client", err) 92 } 93 94 credential, err := client.ObtainToken(c, onedrive.WithCode(service.Code)) 95 if err != nil { 96 return serializer.Err(serializer.CodeInternalSetting, "Failed to fetch AccessToken", err) 97 } 98 99 // 更新存储策略的 RefreshToken 100 client.Policy.AccessKey = credential.RefreshToken 101 if err := client.Policy.SaveAndClearCache(); err != nil { 102 return serializer.DBErr("Failed to update RefreshToken", err) 103 } 104 105 cache.Deletes([]string{client.Policy.AccessKey}, "onedrive_") 106 if client.Policy.OptionsSerialized.OdDriver != "" && strings.Contains(client.Policy.OptionsSerialized.OdDriver, "http") { 107 if err := querySharePointSiteID(c, client.Policy); err != nil { 108 return serializer.Err(serializer.CodeInternalSetting, "Failed to query SharePoint site ID", err) 109 } 110 } 111 112 return serializer.Response{} 113 } 114 115 func querySharePointSiteID(ctx context.Context, policy *model.Policy) error { 116 client, err := onedrive.NewClient(policy) 117 if err != nil { 118 return err 119 } 120 121 id, err := client.GetSiteIDByURL(ctx, client.Policy.OptionsSerialized.OdDriver) 122 if err != nil { 123 return err 124 } 125 126 client.Policy.OptionsSerialized.OdDriver = fmt.Sprintf("sites/%s/drive", id) 127 if err := client.Policy.SaveAndClearCache(); err != nil { 128 return err 129 } 130 131 return nil 132 }