github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/filesystem/driver/onedrive/oauth.go (about) 1 package onedrive 2 3 import ( 4 "context" 5 "encoding/json" 6 "io/ioutil" 7 "net/http" 8 "net/url" 9 "strings" 10 "time" 11 12 "github.com/cloudreve/Cloudreve/v3/pkg/cache" 13 "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth" 14 "github.com/cloudreve/Cloudreve/v3/pkg/request" 15 "github.com/cloudreve/Cloudreve/v3/pkg/util" 16 ) 17 18 // Error 实现error接口 19 func (err OAuthError) Error() string { 20 return err.ErrorDescription 21 } 22 23 // OAuthURL 获取OAuth认证页面URL 24 func (client *Client) OAuthURL(ctx context.Context, scope []string) string { 25 query := url.Values{ 26 "client_id": {client.ClientID}, 27 "scope": {strings.Join(scope, " ")}, 28 "response_type": {"code"}, 29 "redirect_uri": {client.Redirect}, 30 } 31 client.Endpoints.OAuthEndpoints.authorize.RawQuery = query.Encode() 32 return client.Endpoints.OAuthEndpoints.authorize.String() 33 } 34 35 // getOAuthEndpoint 根据指定的AuthURL获取详细的认证接口地址 36 func (client *Client) getOAuthEndpoint() *oauthEndpoint { 37 base, err := url.Parse(client.Endpoints.OAuthURL) 38 if err != nil { 39 return nil 40 } 41 var ( 42 token *url.URL 43 authorize *url.URL 44 ) 45 switch base.Host { 46 case "login.live.com": 47 token, _ = url.Parse("https://login.live.com/oauth20_token.srf") 48 authorize, _ = url.Parse("https://login.live.com/oauth20_authorize.srf") 49 case "login.chinacloudapi.cn": 50 client.Endpoints.isInChina = true 51 token, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/token") 52 authorize, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize") 53 default: 54 token, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/token") 55 authorize, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/authorize") 56 } 57 58 return &oauthEndpoint{ 59 token: *token, 60 authorize: *authorize, 61 } 62 } 63 64 // ObtainToken 通过code或refresh_token兑换token 65 func (client *Client) ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) { 66 options := newDefaultOption() 67 for _, o := range opts { 68 o.apply(options) 69 } 70 71 body := url.Values{ 72 "client_id": {client.ClientID}, 73 "redirect_uri": {client.Redirect}, 74 "client_secret": {client.ClientSecret}, 75 } 76 if options.code != "" { 77 body.Add("grant_type", "authorization_code") 78 body.Add("code", options.code) 79 } else { 80 body.Add("grant_type", "refresh_token") 81 body.Add("refresh_token", options.refreshToken) 82 } 83 strBody := body.Encode() 84 85 res := client.Request.Request( 86 "POST", 87 client.Endpoints.OAuthEndpoints.token.String(), 88 ioutil.NopCloser(strings.NewReader(strBody)), 89 request.WithHeader(http.Header{ 90 "Content-Type": {"application/x-www-form-urlencoded"}}, 91 ), 92 request.WithContentLength(int64(len(strBody))), 93 ) 94 if res.Err != nil { 95 return nil, res.Err 96 } 97 98 respBody, err := res.GetResponse() 99 if err != nil { 100 return nil, err 101 } 102 103 var ( 104 errResp OAuthError 105 credential Credential 106 decodeErr error 107 ) 108 109 if res.Response.StatusCode != 200 { 110 decodeErr = json.Unmarshal([]byte(respBody), &errResp) 111 } else { 112 decodeErr = json.Unmarshal([]byte(respBody), &credential) 113 } 114 if decodeErr != nil { 115 return nil, decodeErr 116 } 117 118 if errResp.ErrorType != "" { 119 return nil, errResp 120 } 121 122 return &credential, nil 123 124 } 125 126 // UpdateCredential 更新凭证,并检查有效期 127 func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error { 128 if isSlave { 129 return client.fetchCredentialFromMaster(ctx) 130 } 131 132 oauth.GlobalMutex.Lock(client.Policy.ID) 133 defer oauth.GlobalMutex.Unlock(client.Policy.ID) 134 135 // 如果已存在凭证 136 if client.Credential != nil && client.Credential.AccessToken != "" { 137 // 检查已有凭证是否过期 138 if client.Credential.ExpiresIn > time.Now().Unix() { 139 // 未过期,不要更新 140 return nil 141 } 142 } 143 144 // 尝试从缓存中获取凭证 145 if cacheCredential, ok := cache.Get("onedrive_" + client.ClientID); ok { 146 credential := cacheCredential.(Credential) 147 if credential.ExpiresIn > time.Now().Unix() { 148 client.Credential = &credential 149 return nil 150 } 151 } 152 153 // 获取新的凭证 154 if client.Credential == nil || client.Credential.RefreshToken == "" { 155 // 无有效的RefreshToken 156 util.Log().Error("Failed to refresh credential for policy %q, please login your Microsoft account again.", client.Policy.Name) 157 return ErrInvalidRefreshToken 158 } 159 160 credential, err := client.ObtainToken(ctx, WithRefreshToken(client.Credential.RefreshToken)) 161 if err != nil { 162 return err 163 } 164 165 // 更新有效期为绝对时间戳 166 expires := credential.ExpiresIn - 60 167 credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix() 168 client.Credential = credential 169 170 // 更新存储策略的 RefreshToken 171 client.Policy.UpdateAccessKeyAndClearCache(credential.RefreshToken) 172 173 // 更新缓存 174 cache.Set("onedrive_"+client.ClientID, *credential, int(expires)) 175 176 return nil 177 } 178 179 func (client *Client) AccessToken() string { 180 return client.Credential.AccessToken 181 } 182 183 // UpdateCredential 更新凭证,并检查有效期 184 func (client *Client) fetchCredentialFromMaster(ctx context.Context) error { 185 res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID) 186 if err != nil { 187 return err 188 } 189 190 client.Credential = &Credential{AccessToken: res} 191 return nil 192 }