github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/pkg/cluster/controller_test.go (about)

     1  package cluster
     2  
     3  import (
     4  	model "github.com/cloudreve/Cloudreve/v3/models"
     5  	"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
     6  	"github.com/cloudreve/Cloudreve/v3/pkg/auth"
     7  	"github.com/cloudreve/Cloudreve/v3/pkg/mq"
     8  	"github.com/cloudreve/Cloudreve/v3/pkg/request"
     9  	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
    10  	"github.com/stretchr/testify/assert"
    11  	testMock "github.com/stretchr/testify/mock"
    12  	"io"
    13  	"io/ioutil"
    14  	"net/http"
    15  	"strings"
    16  	"testing"
    17  )
    18  
    19  func TestInitController(t *testing.T) {
    20  	assert.NotPanics(t, func() {
    21  		InitController()
    22  	})
    23  }
    24  
    25  func TestSlaveController_HandleHeartBeat(t *testing.T) {
    26  	a := assert.New(t)
    27  	c := &slaveController{
    28  		masters: make(map[string]MasterInfo),
    29  	}
    30  
    31  	// first heart beat
    32  	{
    33  		_, err := c.HandleHeartBeat(&serializer.NodePingReq{
    34  			SiteID: "1",
    35  			Node:   &model.Node{},
    36  		})
    37  		a.NoError(err)
    38  
    39  		_, err = c.HandleHeartBeat(&serializer.NodePingReq{
    40  			SiteID: "2",
    41  			Node:   &model.Node{},
    42  		})
    43  		a.NoError(err)
    44  
    45  		a.Len(c.masters, 2)
    46  	}
    47  
    48  	// second heart beat, no fresh
    49  	{
    50  		_, err := c.HandleHeartBeat(&serializer.NodePingReq{
    51  			SiteID:  "1",
    52  			SiteURL: "http://127.0.0.1",
    53  			Node:    &model.Node{},
    54  		})
    55  		a.NoError(err)
    56  		a.Len(c.masters, 2)
    57  		a.Empty(c.masters["1"].URL)
    58  	}
    59  
    60  	// second heart beat, fresh
    61  	{
    62  		_, err := c.HandleHeartBeat(&serializer.NodePingReq{
    63  			SiteID:   "1",
    64  			IsUpdate: true,
    65  			SiteURL:  "http://127.0.0.1",
    66  			Node:     &model.Node{},
    67  		})
    68  		a.NoError(err)
    69  		a.Len(c.masters, 2)
    70  		a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
    71  	}
    72  
    73  	// second heart beat, fresh, url illegal
    74  	{
    75  		_, err := c.HandleHeartBeat(&serializer.NodePingReq{
    76  			SiteID:   "1",
    77  			IsUpdate: true,
    78  			SiteURL:  string([]byte{0x7f}),
    79  			Node:     &model.Node{},
    80  		})
    81  		a.Error(err)
    82  		a.Len(c.masters, 2)
    83  		a.Equal("http://127.0.0.1", c.masters["1"].URL.String())
    84  	}
    85  }
    86  
    87  type nodeMock struct {
    88  	testMock.Mock
    89  }
    90  
    91  func (n nodeMock) Init(node *model.Node) {
    92  	n.Called(node)
    93  }
    94  
    95  func (n nodeMock) IsFeatureEnabled(feature string) bool {
    96  	args := n.Called(feature)
    97  	return args.Bool(0)
    98  }
    99  
   100  func (n nodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
   101  	n.Called(callback)
   102  }
   103  
   104  func (n nodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
   105  	args := n.Called(req)
   106  	return args.Get(0).(*serializer.NodePingResp), args.Error(1)
   107  }
   108  
   109  func (n nodeMock) IsActive() bool {
   110  	args := n.Called()
   111  	return args.Bool(0)
   112  }
   113  
   114  func (n nodeMock) GetAria2Instance() common.Aria2 {
   115  	args := n.Called()
   116  	return args.Get(0).(common.Aria2)
   117  }
   118  
   119  func (n nodeMock) ID() uint {
   120  	args := n.Called()
   121  	return args.Get(0).(uint)
   122  }
   123  
   124  func (n nodeMock) Kill() {
   125  	n.Called()
   126  }
   127  
   128  func (n nodeMock) IsMater() bool {
   129  	args := n.Called()
   130  	return args.Bool(0)
   131  }
   132  
   133  func (n nodeMock) MasterAuthInstance() auth.Auth {
   134  	args := n.Called()
   135  	return args.Get(0).(auth.Auth)
   136  }
   137  
   138  func (n nodeMock) SlaveAuthInstance() auth.Auth {
   139  	args := n.Called()
   140  	return args.Get(0).(auth.Auth)
   141  }
   142  
   143  func (n nodeMock) DBModel() *model.Node {
   144  	args := n.Called()
   145  	return args.Get(0).(*model.Node)
   146  }
   147  
   148  func TestSlaveController_GetAria2Instance(t *testing.T) {
   149  	a := assert.New(t)
   150  	mockNode := &nodeMock{}
   151  	mockNode.On("GetAria2Instance").Return(&common.DummyAria2{})
   152  	c := &slaveController{
   153  		masters: map[string]MasterInfo{
   154  			"1": {Instance: mockNode},
   155  		},
   156  	}
   157  
   158  	// node node found
   159  	{
   160  		res, err := c.GetAria2Instance("2")
   161  		a.Nil(res)
   162  		a.Equal(ErrMasterNotFound, err)
   163  	}
   164  
   165  	// node found
   166  	{
   167  		res, err := c.GetAria2Instance("1")
   168  		a.NotNil(res)
   169  		a.NoError(err)
   170  		mockNode.AssertExpectations(t)
   171  	}
   172  
   173  }
   174  
   175  type requestMock struct {
   176  	testMock.Mock
   177  }
   178  
   179  func (r requestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
   180  	return r.Called(method, target, body, opts).Get(0).(*request.Response)
   181  }
   182  
   183  func TestSlaveController_SendNotification(t *testing.T) {
   184  	a := assert.New(t)
   185  	c := &slaveController{
   186  		masters: map[string]MasterInfo{
   187  			"1": {},
   188  		},
   189  	}
   190  
   191  	// node not exit
   192  	{
   193  		a.Equal(ErrMasterNotFound, c.SendNotification("2", "", mq.Message{}))
   194  	}
   195  
   196  	// gob encode error
   197  	{
   198  		type randomType struct{}
   199  		a.Error(c.SendNotification("1", "", mq.Message{
   200  			Content: randomType{},
   201  		}))
   202  	}
   203  
   204  	// return none 200
   205  	{
   206  		mockRequest := &requestMock{}
   207  		mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s1", testMock.Anything, testMock.Anything).Return(&request.Response{
   208  			Response: &http.Response{StatusCode: http.StatusConflict},
   209  		})
   210  		c := &slaveController{
   211  			masters: map[string]MasterInfo{
   212  				"1": {Client: mockRequest},
   213  			},
   214  		}
   215  		a.Error(c.SendNotification("1", "s1", mq.Message{}))
   216  		mockRequest.AssertExpectations(t)
   217  	}
   218  
   219  	// master return error
   220  	{
   221  		mockRequest := &requestMock{}
   222  		mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s2", testMock.Anything, testMock.Anything).Return(&request.Response{
   223  			Response: &http.Response{
   224  				StatusCode: 200,
   225  				Body:       ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
   226  			},
   227  		})
   228  		c := &slaveController{
   229  			masters: map[string]MasterInfo{
   230  				"1": {Client: mockRequest},
   231  			},
   232  		}
   233  		a.Equal(1, c.SendNotification("1", "s2", mq.Message{}).(serializer.AppError).Code)
   234  		mockRequest.AssertExpectations(t)
   235  	}
   236  
   237  	// success
   238  	{
   239  		mockRequest := &requestMock{}
   240  		mockRequest.On("Request", "PUT", "/api/v3/slave/notification/s3", testMock.Anything, testMock.Anything).Return(&request.Response{
   241  			Response: &http.Response{
   242  				StatusCode: 200,
   243  				Body:       ioutil.NopCloser(strings.NewReader("{\"code\":0}")),
   244  			},
   245  		})
   246  		c := &slaveController{
   247  			masters: map[string]MasterInfo{
   248  				"1": {Client: mockRequest},
   249  			},
   250  		}
   251  		a.NoError(c.SendNotification("1", "s3", mq.Message{}))
   252  		mockRequest.AssertExpectations(t)
   253  	}
   254  }
   255  
   256  func TestSlaveController_SubmitTask(t *testing.T) {
   257  	a := assert.New(t)
   258  	c := &slaveController{
   259  		masters: map[string]MasterInfo{
   260  			"1": {
   261  				jobTracker: map[string]bool{},
   262  			},
   263  		},
   264  	}
   265  
   266  	// node not exit
   267  	{
   268  		a.Equal(ErrMasterNotFound, c.SubmitTask("2", "", "", nil))
   269  	}
   270  
   271  	// success
   272  	{
   273  		submitted := false
   274  		a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
   275  			submitted = true
   276  		}))
   277  		a.True(submitted)
   278  	}
   279  
   280  	// job already submitted
   281  	{
   282  		submitted := false
   283  		a.NoError(c.SubmitTask("1", "", "hash", func(i interface{}) {
   284  			submitted = true
   285  		}))
   286  		a.False(submitted)
   287  	}
   288  }
   289  
   290  func TestSlaveController_GetMasterInfo(t *testing.T) {
   291  	a := assert.New(t)
   292  	c := &slaveController{
   293  		masters: map[string]MasterInfo{
   294  			"1": {},
   295  		},
   296  	}
   297  
   298  	// node not exit
   299  	{
   300  		res, err := c.GetMasterInfo("2")
   301  		a.Equal(ErrMasterNotFound, err)
   302  		a.Nil(res)
   303  	}
   304  
   305  	// success
   306  	{
   307  		res, err := c.GetMasterInfo("1")
   308  		a.NoError(err)
   309  		a.NotNil(res)
   310  	}
   311  }
   312  
   313  func TestSlaveController_GetOneDriveToken(t *testing.T) {
   314  	a := assert.New(t)
   315  	c := &slaveController{
   316  		masters: map[string]MasterInfo{
   317  			"1": {},
   318  		},
   319  	}
   320  
   321  	// node not exit
   322  	{
   323  		res, err := c.GetPolicyOauthToken("2", 1)
   324  		a.Equal(ErrMasterNotFound, err)
   325  		a.Empty(res)
   326  	}
   327  
   328  	// return none 200
   329  	{
   330  		mockRequest := &requestMock{}
   331  		mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
   332  			Response: &http.Response{StatusCode: http.StatusConflict},
   333  		})
   334  		c := &slaveController{
   335  			masters: map[string]MasterInfo{
   336  				"1": {Client: mockRequest},
   337  			},
   338  		}
   339  		res, err := c.GetPolicyOauthToken("1", 1)
   340  		a.Error(err)
   341  		a.Empty(res)
   342  		mockRequest.AssertExpectations(t)
   343  	}
   344  
   345  	// master return error
   346  	{
   347  		mockRequest := &requestMock{}
   348  		mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
   349  			Response: &http.Response{
   350  				StatusCode: 200,
   351  				Body:       ioutil.NopCloser(strings.NewReader("{\"code\":1}")),
   352  			},
   353  		})
   354  		c := &slaveController{
   355  			masters: map[string]MasterInfo{
   356  				"1": {Client: mockRequest},
   357  			},
   358  		}
   359  		res, err := c.GetPolicyOauthToken("1", 1)
   360  		a.Equal(1, err.(serializer.AppError).Code)
   361  		a.Empty(res)
   362  		mockRequest.AssertExpectations(t)
   363  	}
   364  
   365  	// success
   366  	{
   367  		mockRequest := &requestMock{}
   368  		mockRequest.On("Request", "GET", "/api/v3/slave/credential/1", testMock.Anything, testMock.Anything).Return(&request.Response{
   369  			Response: &http.Response{
   370  				StatusCode: 200,
   371  				Body:       ioutil.NopCloser(strings.NewReader("{\"data\":\"expected\"}")),
   372  			},
   373  		})
   374  		c := &slaveController{
   375  			masters: map[string]MasterInfo{
   376  				"1": {Client: mockRequest},
   377  			},
   378  		}
   379  		res, err := c.GetPolicyOauthToken("1", 1)
   380  		a.NoError(err)
   381  		a.Equal("expected", res)
   382  		mockRequest.AssertExpectations(t)
   383  	}
   384  
   385  }