github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/filesystem/driver/onedrive/oauth_test.go (about) 1 package onedrive 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock" 8 "io" 9 "io/ioutil" 10 "net/http" 11 "net/url" 12 "strings" 13 "testing" 14 "time" 15 16 "github.com/DATA-DOG/go-sqlmock" 17 model "github.com/cloudreve/Cloudreve/v3/models" 18 "github.com/cloudreve/Cloudreve/v3/pkg/cache" 19 "github.com/cloudreve/Cloudreve/v3/pkg/request" 20 "github.com/jinzhu/gorm" 21 "github.com/stretchr/testify/assert" 22 testMock "github.com/stretchr/testify/mock" 23 ) 24 25 var mock sqlmock.Sqlmock 26 27 // TestMain 初始化数据库Mock 28 func TestMain(m *testing.M) { 29 var db *sql.DB 30 var err error 31 db, mock, err = sqlmock.New() 32 if err != nil { 33 panic("An error was not expected when opening a stub database connection") 34 } 35 model.DB, _ = gorm.Open("mysql", db) 36 defer db.Close() 37 m.Run() 38 } 39 40 func TestGetOAuthEndpoint(t *testing.T) { 41 asserts := assert.New(t) 42 43 // URL解析失败 44 { 45 client := Client{ 46 Endpoints: &Endpoints{ 47 OAuthURL: string([]byte{0x7f}), 48 }, 49 } 50 res := client.getOAuthEndpoint() 51 asserts.Nil(res) 52 } 53 54 { 55 testCase := []struct { 56 OAuthURL string 57 token string 58 auth string 59 isChina bool 60 }{ 61 { 62 OAuthURL: "http://login.live.com", 63 token: "https://login.live.com/oauth20_token.srf", 64 auth: "https://login.live.com/oauth20_authorize.srf", 65 isChina: false, 66 }, 67 { 68 OAuthURL: "http://login.chinacloudapi.cn", 69 token: "https://login.chinacloudapi.cn/common/oauth2/v2.0/token", 70 auth: "https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize", 71 isChina: true, 72 }, 73 { 74 OAuthURL: "other", 75 token: "https://login.microsoftonline.com/common/oauth2/v2.0/token", 76 auth: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", 77 isChina: false, 78 }, 79 } 80 81 for i, testCase := range testCase { 82 client := Client{ 83 Endpoints: &Endpoints{ 84 OAuthURL: testCase.OAuthURL, 85 }, 86 } 87 res := client.getOAuthEndpoint() 88 asserts.Equal(testCase.token, res.token.String(), "Test Case #%d", i) 89 asserts.Equal(testCase.auth, res.authorize.String(), "Test Case #%d", i) 90 asserts.Equal(testCase.isChina, client.Endpoints.isInChina, "Test Case #%d", i) 91 } 92 } 93 } 94 95 func TestClient_OAuthURL(t *testing.T) { 96 asserts := assert.New(t) 97 98 client := Client{ 99 ClientID: "client_id", 100 Redirect: "http://cloudreve.org/callback", 101 Endpoints: &Endpoints{}, 102 } 103 client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint() 104 res, err := url.Parse(client.OAuthURL(context.Background(), []string{"scope1", "scope2"})) 105 asserts.NoError(err) 106 query := res.Query() 107 asserts.Equal("client_id", query.Get("client_id")) 108 asserts.Equal("scope1 scope2", query.Get("scope")) 109 asserts.Equal(client.Redirect, query.Get("redirect_uri")) 110 111 } 112 113 type ClientMock struct { 114 testMock.Mock 115 } 116 117 func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response { 118 args := m.Called(method, target, body, opts) 119 return args.Get(0).(*request.Response) 120 } 121 122 type mockReader string 123 124 func (r mockReader) Read(b []byte) (int, error) { 125 return 0, errors.New("read error") 126 } 127 128 func TestClient_ObtainToken(t *testing.T) { 129 asserts := assert.New(t) 130 131 client := Client{ 132 Endpoints: &Endpoints{}, 133 ClientID: "ClientID", 134 ClientSecret: "ClientSecret", 135 Redirect: "Redirect", 136 } 137 client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint() 138 139 // 刷新Token 成功 140 { 141 clientMock := ClientMock{} 142 clientMock.On( 143 "Request", 144 "POST", 145 client.Endpoints.OAuthEndpoints.token.String(), 146 testMock.Anything, 147 testMock.Anything, 148 ).Return(&request.Response{ 149 Err: nil, 150 Response: &http.Response{ 151 StatusCode: 200, 152 Body: ioutil.NopCloser(strings.NewReader(`{"access_token":"i am token"}`)), 153 }, 154 }) 155 client.Request = clientMock 156 157 res, err := client.ObtainToken(context.Background()) 158 clientMock.AssertExpectations(t) 159 asserts.NoError(err) 160 asserts.NotNil(res) 161 asserts.Equal("i am token", res.AccessToken) 162 } 163 164 // 重新获取 无法发送请求 165 { 166 clientMock := ClientMock{} 167 clientMock.On( 168 "Request", 169 "POST", 170 client.Endpoints.OAuthEndpoints.token.String(), 171 testMock.Anything, 172 testMock.Anything, 173 ).Return(&request.Response{ 174 Err: errors.New("error"), 175 }) 176 client.Request = clientMock 177 178 res, err := client.ObtainToken(context.Background(), WithCode("code")) 179 clientMock.AssertExpectations(t) 180 asserts.Error(err) 181 asserts.Nil(res) 182 } 183 184 // 刷新Token 无法获取响应正文 185 { 186 clientMock := ClientMock{} 187 clientMock.On( 188 "Request", 189 "POST", 190 client.Endpoints.OAuthEndpoints.token.String(), 191 testMock.Anything, 192 testMock.Anything, 193 ).Return(&request.Response{ 194 Err: nil, 195 Response: &http.Response{ 196 StatusCode: 200, 197 Body: ioutil.NopCloser(mockReader("")), 198 }, 199 }) 200 client.Request = clientMock 201 202 res, err := client.ObtainToken(context.Background()) 203 clientMock.AssertExpectations(t) 204 asserts.Error(err) 205 asserts.Nil(res) 206 asserts.Equal("read error", err.Error()) 207 } 208 209 // 刷新Token OneDrive返回错误 210 { 211 clientMock := ClientMock{} 212 clientMock.On( 213 "Request", 214 "POST", 215 client.Endpoints.OAuthEndpoints.token.String(), 216 testMock.Anything, 217 testMock.Anything, 218 ).Return(&request.Response{ 219 Err: nil, 220 Response: &http.Response{ 221 StatusCode: 400, 222 Body: ioutil.NopCloser(strings.NewReader(`{"error":"i am error"}`)), 223 }, 224 }) 225 client.Request = clientMock 226 227 res, err := client.ObtainToken(context.Background()) 228 clientMock.AssertExpectations(t) 229 asserts.Error(err) 230 asserts.Nil(res) 231 asserts.Equal("", err.Error()) 232 } 233 234 // 刷新Token OneDrive未知响应 235 { 236 clientMock := ClientMock{} 237 clientMock.On( 238 "Request", 239 "POST", 240 client.Endpoints.OAuthEndpoints.token.String(), 241 testMock.Anything, 242 testMock.Anything, 243 ).Return(&request.Response{ 244 Err: nil, 245 Response: &http.Response{ 246 StatusCode: 400, 247 Body: ioutil.NopCloser(strings.NewReader(`???`)), 248 }, 249 }) 250 client.Request = clientMock 251 252 res, err := client.ObtainToken(context.Background()) 253 clientMock.AssertExpectations(t) 254 asserts.Error(err) 255 asserts.Nil(res) 256 } 257 } 258 259 func TestClient_UpdateCredential(t *testing.T) { 260 asserts := assert.New(t) 261 client := Client{ 262 Policy: &model.Policy{Model: gorm.Model{ID: 257}}, 263 Endpoints: &Endpoints{}, 264 ClientID: "TestClient_UpdateCredential", 265 ClientSecret: "ClientSecret", 266 Redirect: "Redirect", 267 Credential: &Credential{}, 268 } 269 client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint() 270 271 // 无有效的RefreshToken 272 { 273 err := client.UpdateCredential(context.Background(), false) 274 asserts.Equal(ErrInvalidRefreshToken, err) 275 client.Credential = nil 276 err = client.UpdateCredential(context.Background(), false) 277 asserts.Equal(ErrInvalidRefreshToken, err) 278 } 279 280 // 成功 281 { 282 clientMock := ClientMock{} 283 clientMock.On( 284 "Request", 285 "POST", 286 client.Endpoints.OAuthEndpoints.token.String(), 287 testMock.Anything, 288 testMock.Anything, 289 ).Return(&request.Response{ 290 Err: nil, 291 Response: &http.Response{ 292 StatusCode: 200, 293 Body: ioutil.NopCloser(strings.NewReader(`{"expires_in":3600,"refresh_token":"new_refresh_token","access_token":"i am token"}`)), 294 }, 295 }) 296 client.Request = clientMock 297 client.Credential = &Credential{ 298 RefreshToken: "old_refresh_token", 299 } 300 mock.ExpectBegin() 301 mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) 302 mock.ExpectCommit() 303 err := client.UpdateCredential(context.Background(), false) 304 clientMock.AssertExpectations(t) 305 asserts.NoError(mock.ExpectationsWereMet()) 306 asserts.NoError(err) 307 cacheRes, ok := cache.Get("onedrive_TestClient_UpdateCredential") 308 asserts.True(ok) 309 cacheCredential := cacheRes.(Credential) 310 asserts.Equal("new_refresh_token", cacheCredential.RefreshToken) 311 asserts.Equal("i am token", cacheCredential.AccessToken) 312 } 313 314 // OneDrive返回错误 315 { 316 cache.Deletes([]string{"TestClient_UpdateCredential"}, "onedrive_") 317 clientMock := ClientMock{} 318 clientMock.On( 319 "Request", 320 "POST", 321 client.Endpoints.OAuthEndpoints.token.String(), 322 testMock.Anything, 323 testMock.Anything, 324 ).Return(&request.Response{ 325 Err: nil, 326 Response: &http.Response{ 327 StatusCode: 400, 328 Body: ioutil.NopCloser(strings.NewReader(`{"error":"error"}`)), 329 }, 330 }) 331 client.Request = clientMock 332 client.Credential = &Credential{ 333 RefreshToken: "old_refresh_token", 334 } 335 err := client.UpdateCredential(context.Background(), false) 336 clientMock.AssertExpectations(t) 337 asserts.Error(err) 338 } 339 340 // 从缓存中获取 341 { 342 cache.Set("onedrive_TestClient_UpdateCredential", Credential{ 343 ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(), 344 AccessToken: "AccessToken", 345 RefreshToken: "RefreshToken", 346 }, 0) 347 client.Credential = &Credential{ 348 RefreshToken: "old_refresh_token", 349 } 350 err := client.UpdateCredential(context.Background(), false) 351 asserts.NoError(err) 352 asserts.Equal("AccessToken", client.Credential.AccessToken) 353 asserts.Equal("RefreshToken", client.Credential.RefreshToken) 354 } 355 356 // 无需重新获取 357 { 358 client.Credential = &Credential{ 359 RefreshToken: "old_refresh_token", 360 AccessToken: "AccessToken2", 361 ExpiresIn: time.Now().Add(time.Duration(10) * time.Second).Unix(), 362 } 363 err := client.UpdateCredential(context.Background(), false) 364 asserts.NoError(err) 365 asserts.Equal("AccessToken2", client.Credential.AccessToken) 366 } 367 368 // slave failed 369 { 370 mockController := &controllermock.SlaveControllerMock{} 371 mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("", errors.New("error")) 372 client.ClusterController = mockController 373 err := client.UpdateCredential(context.Background(), true) 374 asserts.Error(err) 375 } 376 377 // slave success 378 { 379 mockController := &controllermock.SlaveControllerMock{} 380 mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("AccessToken3", nil) 381 client.ClusterController = mockController 382 err := client.UpdateCredential(context.Background(), true) 383 asserts.NoError(err) 384 asserts.Equal("AccessToken3", client.Credential.AccessToken) 385 } 386 }