github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/filesystem/driver/onedrive/oauth_test.go (about)

     1  package onedrive
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock"
     8  	"io"
     9  	"io/ioutil"
    10  	"net/http"
    11  	"net/url"
    12  	"strings"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/DATA-DOG/go-sqlmock"
    17  	model "github.com/cloudreve/Cloudreve/v3/models"
    18  	"github.com/cloudreve/Cloudreve/v3/pkg/cache"
    19  	"github.com/cloudreve/Cloudreve/v3/pkg/request"
    20  	"github.com/jinzhu/gorm"
    21  	"github.com/stretchr/testify/assert"
    22  	testMock "github.com/stretchr/testify/mock"
    23  )
    24  
    25  var mock sqlmock.Sqlmock
    26  
    27  // TestMain 初始化数据库Mock
    28  func TestMain(m *testing.M) {
    29  	var db *sql.DB
    30  	var err error
    31  	db, mock, err = sqlmock.New()
    32  	if err != nil {
    33  		panic("An error was not expected when opening a stub database connection")
    34  	}
    35  	model.DB, _ = gorm.Open("mysql", db)
    36  	defer db.Close()
    37  	m.Run()
    38  }
    39  
    40  func TestGetOAuthEndpoint(t *testing.T) {
    41  	asserts := assert.New(t)
    42  
    43  	// URL解析失败
    44  	{
    45  		client := Client{
    46  			Endpoints: &Endpoints{
    47  				OAuthURL: string([]byte{0x7f}),
    48  			},
    49  		}
    50  		res := client.getOAuthEndpoint()
    51  		asserts.Nil(res)
    52  	}
    53  
    54  	{
    55  		testCase := []struct {
    56  			OAuthURL string
    57  			token    string
    58  			auth     string
    59  			isChina  bool
    60  		}{
    61  			{
    62  				OAuthURL: "http://login.live.com",
    63  				token:    "https://login.live.com/oauth20_token.srf",
    64  				auth:     "https://login.live.com/oauth20_authorize.srf",
    65  				isChina:  false,
    66  			},
    67  			{
    68  				OAuthURL: "http://login.chinacloudapi.cn",
    69  				token:    "https://login.chinacloudapi.cn/common/oauth2/v2.0/token",
    70  				auth:     "https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize",
    71  				isChina:  true,
    72  			},
    73  			{
    74  				OAuthURL: "other",
    75  				token:    "https://login.microsoftonline.com/common/oauth2/v2.0/token",
    76  				auth:     "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
    77  				isChina:  false,
    78  			},
    79  		}
    80  
    81  		for i, testCase := range testCase {
    82  			client := Client{
    83  				Endpoints: &Endpoints{
    84  					OAuthURL: testCase.OAuthURL,
    85  				},
    86  			}
    87  			res := client.getOAuthEndpoint()
    88  			asserts.Equal(testCase.token, res.token.String(), "Test Case #%d", i)
    89  			asserts.Equal(testCase.auth, res.authorize.String(), "Test Case #%d", i)
    90  			asserts.Equal(testCase.isChina, client.Endpoints.isInChina, "Test Case #%d", i)
    91  		}
    92  	}
    93  }
    94  
    95  func TestClient_OAuthURL(t *testing.T) {
    96  	asserts := assert.New(t)
    97  
    98  	client := Client{
    99  		ClientID:  "client_id",
   100  		Redirect:  "http://cloudreve.org/callback",
   101  		Endpoints: &Endpoints{},
   102  	}
   103  	client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint()
   104  	res, err := url.Parse(client.OAuthURL(context.Background(), []string{"scope1", "scope2"}))
   105  	asserts.NoError(err)
   106  	query := res.Query()
   107  	asserts.Equal("client_id", query.Get("client_id"))
   108  	asserts.Equal("scope1 scope2", query.Get("scope"))
   109  	asserts.Equal(client.Redirect, query.Get("redirect_uri"))
   110  
   111  }
   112  
   113  type ClientMock struct {
   114  	testMock.Mock
   115  }
   116  
   117  func (m ClientMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
   118  	args := m.Called(method, target, body, opts)
   119  	return args.Get(0).(*request.Response)
   120  }
   121  
   122  type mockReader string
   123  
   124  func (r mockReader) Read(b []byte) (int, error) {
   125  	return 0, errors.New("read error")
   126  }
   127  
   128  func TestClient_ObtainToken(t *testing.T) {
   129  	asserts := assert.New(t)
   130  
   131  	client := Client{
   132  		Endpoints:    &Endpoints{},
   133  		ClientID:     "ClientID",
   134  		ClientSecret: "ClientSecret",
   135  		Redirect:     "Redirect",
   136  	}
   137  	client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint()
   138  
   139  	// 刷新Token 成功
   140  	{
   141  		clientMock := ClientMock{}
   142  		clientMock.On(
   143  			"Request",
   144  			"POST",
   145  			client.Endpoints.OAuthEndpoints.token.String(),
   146  			testMock.Anything,
   147  			testMock.Anything,
   148  		).Return(&request.Response{
   149  			Err: nil,
   150  			Response: &http.Response{
   151  				StatusCode: 200,
   152  				Body:       ioutil.NopCloser(strings.NewReader(`{"access_token":"i am token"}`)),
   153  			},
   154  		})
   155  		client.Request = clientMock
   156  
   157  		res, err := client.ObtainToken(context.Background())
   158  		clientMock.AssertExpectations(t)
   159  		asserts.NoError(err)
   160  		asserts.NotNil(res)
   161  		asserts.Equal("i am token", res.AccessToken)
   162  	}
   163  
   164  	// 重新获取 无法发送请求
   165  	{
   166  		clientMock := ClientMock{}
   167  		clientMock.On(
   168  			"Request",
   169  			"POST",
   170  			client.Endpoints.OAuthEndpoints.token.String(),
   171  			testMock.Anything,
   172  			testMock.Anything,
   173  		).Return(&request.Response{
   174  			Err: errors.New("error"),
   175  		})
   176  		client.Request = clientMock
   177  
   178  		res, err := client.ObtainToken(context.Background(), WithCode("code"))
   179  		clientMock.AssertExpectations(t)
   180  		asserts.Error(err)
   181  		asserts.Nil(res)
   182  	}
   183  
   184  	// 刷新Token 无法获取响应正文
   185  	{
   186  		clientMock := ClientMock{}
   187  		clientMock.On(
   188  			"Request",
   189  			"POST",
   190  			client.Endpoints.OAuthEndpoints.token.String(),
   191  			testMock.Anything,
   192  			testMock.Anything,
   193  		).Return(&request.Response{
   194  			Err: nil,
   195  			Response: &http.Response{
   196  				StatusCode: 200,
   197  				Body:       ioutil.NopCloser(mockReader("")),
   198  			},
   199  		})
   200  		client.Request = clientMock
   201  
   202  		res, err := client.ObtainToken(context.Background())
   203  		clientMock.AssertExpectations(t)
   204  		asserts.Error(err)
   205  		asserts.Nil(res)
   206  		asserts.Equal("read error", err.Error())
   207  	}
   208  
   209  	// 刷新Token OneDrive返回错误
   210  	{
   211  		clientMock := ClientMock{}
   212  		clientMock.On(
   213  			"Request",
   214  			"POST",
   215  			client.Endpoints.OAuthEndpoints.token.String(),
   216  			testMock.Anything,
   217  			testMock.Anything,
   218  		).Return(&request.Response{
   219  			Err: nil,
   220  			Response: &http.Response{
   221  				StatusCode: 400,
   222  				Body:       ioutil.NopCloser(strings.NewReader(`{"error":"i am error"}`)),
   223  			},
   224  		})
   225  		client.Request = clientMock
   226  
   227  		res, err := client.ObtainToken(context.Background())
   228  		clientMock.AssertExpectations(t)
   229  		asserts.Error(err)
   230  		asserts.Nil(res)
   231  		asserts.Equal("", err.Error())
   232  	}
   233  
   234  	// 刷新Token OneDrive未知响应
   235  	{
   236  		clientMock := ClientMock{}
   237  		clientMock.On(
   238  			"Request",
   239  			"POST",
   240  			client.Endpoints.OAuthEndpoints.token.String(),
   241  			testMock.Anything,
   242  			testMock.Anything,
   243  		).Return(&request.Response{
   244  			Err: nil,
   245  			Response: &http.Response{
   246  				StatusCode: 400,
   247  				Body:       ioutil.NopCloser(strings.NewReader(`???`)),
   248  			},
   249  		})
   250  		client.Request = clientMock
   251  
   252  		res, err := client.ObtainToken(context.Background())
   253  		clientMock.AssertExpectations(t)
   254  		asserts.Error(err)
   255  		asserts.Nil(res)
   256  	}
   257  }
   258  
   259  func TestClient_UpdateCredential(t *testing.T) {
   260  	asserts := assert.New(t)
   261  	client := Client{
   262  		Policy:       &model.Policy{Model: gorm.Model{ID: 257}},
   263  		Endpoints:    &Endpoints{},
   264  		ClientID:     "TestClient_UpdateCredential",
   265  		ClientSecret: "ClientSecret",
   266  		Redirect:     "Redirect",
   267  		Credential:   &Credential{},
   268  	}
   269  	client.Endpoints.OAuthEndpoints = client.getOAuthEndpoint()
   270  
   271  	// 无有效的RefreshToken
   272  	{
   273  		err := client.UpdateCredential(context.Background(), false)
   274  		asserts.Equal(ErrInvalidRefreshToken, err)
   275  		client.Credential = nil
   276  		err = client.UpdateCredential(context.Background(), false)
   277  		asserts.Equal(ErrInvalidRefreshToken, err)
   278  	}
   279  
   280  	// 成功
   281  	{
   282  		clientMock := ClientMock{}
   283  		clientMock.On(
   284  			"Request",
   285  			"POST",
   286  			client.Endpoints.OAuthEndpoints.token.String(),
   287  			testMock.Anything,
   288  			testMock.Anything,
   289  		).Return(&request.Response{
   290  			Err: nil,
   291  			Response: &http.Response{
   292  				StatusCode: 200,
   293  				Body:       ioutil.NopCloser(strings.NewReader(`{"expires_in":3600,"refresh_token":"new_refresh_token","access_token":"i am token"}`)),
   294  			},
   295  		})
   296  		client.Request = clientMock
   297  		client.Credential = &Credential{
   298  			RefreshToken: "old_refresh_token",
   299  		}
   300  		mock.ExpectBegin()
   301  		mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1))
   302  		mock.ExpectCommit()
   303  		err := client.UpdateCredential(context.Background(), false)
   304  		clientMock.AssertExpectations(t)
   305  		asserts.NoError(mock.ExpectationsWereMet())
   306  		asserts.NoError(err)
   307  		cacheRes, ok := cache.Get("onedrive_TestClient_UpdateCredential")
   308  		asserts.True(ok)
   309  		cacheCredential := cacheRes.(Credential)
   310  		asserts.Equal("new_refresh_token", cacheCredential.RefreshToken)
   311  		asserts.Equal("i am token", cacheCredential.AccessToken)
   312  	}
   313  
   314  	// OneDrive返回错误
   315  	{
   316  		cache.Deletes([]string{"TestClient_UpdateCredential"}, "onedrive_")
   317  		clientMock := ClientMock{}
   318  		clientMock.On(
   319  			"Request",
   320  			"POST",
   321  			client.Endpoints.OAuthEndpoints.token.String(),
   322  			testMock.Anything,
   323  			testMock.Anything,
   324  		).Return(&request.Response{
   325  			Err: nil,
   326  			Response: &http.Response{
   327  				StatusCode: 400,
   328  				Body:       ioutil.NopCloser(strings.NewReader(`{"error":"error"}`)),
   329  			},
   330  		})
   331  		client.Request = clientMock
   332  		client.Credential = &Credential{
   333  			RefreshToken: "old_refresh_token",
   334  		}
   335  		err := client.UpdateCredential(context.Background(), false)
   336  		clientMock.AssertExpectations(t)
   337  		asserts.Error(err)
   338  	}
   339  
   340  	// 从缓存中获取
   341  	{
   342  		cache.Set("onedrive_TestClient_UpdateCredential", Credential{
   343  			ExpiresIn:    time.Now().Add(time.Duration(10) * time.Second).Unix(),
   344  			AccessToken:  "AccessToken",
   345  			RefreshToken: "RefreshToken",
   346  		}, 0)
   347  		client.Credential = &Credential{
   348  			RefreshToken: "old_refresh_token",
   349  		}
   350  		err := client.UpdateCredential(context.Background(), false)
   351  		asserts.NoError(err)
   352  		asserts.Equal("AccessToken", client.Credential.AccessToken)
   353  		asserts.Equal("RefreshToken", client.Credential.RefreshToken)
   354  	}
   355  
   356  	// 无需重新获取
   357  	{
   358  		client.Credential = &Credential{
   359  			RefreshToken: "old_refresh_token",
   360  			AccessToken:  "AccessToken2",
   361  			ExpiresIn:    time.Now().Add(time.Duration(10) * time.Second).Unix(),
   362  		}
   363  		err := client.UpdateCredential(context.Background(), false)
   364  		asserts.NoError(err)
   365  		asserts.Equal("AccessToken2", client.Credential.AccessToken)
   366  	}
   367  
   368  	// slave failed
   369  	{
   370  		mockController := &controllermock.SlaveControllerMock{}
   371  		mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("", errors.New("error"))
   372  		client.ClusterController = mockController
   373  		err := client.UpdateCredential(context.Background(), true)
   374  		asserts.Error(err)
   375  	}
   376  
   377  	// slave success
   378  	{
   379  		mockController := &controllermock.SlaveControllerMock{}
   380  		mockController.On("GetPolicyOauthToken", testMock.Anything, testMock.Anything).Return("AccessToken3", nil)
   381  		client.ClusterController = mockController
   382  		err := client.UpdateCredential(context.Background(), true)
   383  		asserts.NoError(err)
   384  		asserts.Equal("AccessToken3", client.Credential.AccessToken)
   385  	}
   386  }