github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/middleware/cluster_test.go (about) 1 package middleware 2 3 import ( 4 "errors" 5 model "github.com/cloudreve/Cloudreve/v3/models" 6 "github.com/cloudreve/Cloudreve/v3/pkg/aria2/common" 7 "github.com/cloudreve/Cloudreve/v3/pkg/auth" 8 "github.com/cloudreve/Cloudreve/v3/pkg/cluster" 9 "github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock" 10 "github.com/gin-gonic/gin" 11 "github.com/jinzhu/gorm" 12 "github.com/stretchr/testify/assert" 13 "net/http/httptest" 14 "testing" 15 ) 16 17 func TestMasterMetadata(t *testing.T) { 18 a := assert.New(t) 19 masterMetaDataFunc := MasterMetadata() 20 rec := httptest.NewRecorder() 21 c, _ := gin.CreateTestContext(rec) 22 c.Request = httptest.NewRequest("GET", "/", nil) 23 24 c.Request.Header = map[string][]string{ 25 "X-Cr-Site-Id": {"expectedSiteID"}, 26 "X-Cr-Site-Url": {"expectedSiteURL"}, 27 "X-Cr-Cloudreve-Version": {"expectedMasterVersion"}, 28 } 29 masterMetaDataFunc(c) 30 siteID, _ := c.Get("MasterSiteID") 31 siteURL, _ := c.Get("MasterSiteURL") 32 siteVersion, _ := c.Get("MasterVersion") 33 34 a.Equal("expectedSiteID", siteID.(string)) 35 a.Equal("expectedSiteURL", siteURL.(string)) 36 a.Equal("expectedMasterVersion", siteVersion.(string)) 37 } 38 39 func TestSlaveRPCSignRequired(t *testing.T) { 40 a := assert.New(t) 41 np := &cluster.NodePool{} 42 np.Init() 43 slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np) 44 rec := httptest.NewRecorder() 45 46 // id parse failed 47 { 48 c, _ := gin.CreateTestContext(rec) 49 c.Request = httptest.NewRequest("GET", "/", nil) 50 c.Request.Header.Set("X-Cr-Node-Id", "unknown") 51 slaveRPCSignRequiredFunc(c) 52 a.True(c.IsAborted()) 53 } 54 55 // node id not exist 56 { 57 c, _ := gin.CreateTestContext(rec) 58 c.Request = httptest.NewRequest("GET", "/", nil) 59 c.Request.Header.Set("X-Cr-Node-Id", "38") 60 slaveRPCSignRequiredFunc(c) 61 a.True(c.IsAborted()) 62 } 63 64 // success 65 { 66 authInstance := auth.HMACAuth{SecretKey: []byte("")} 67 np.Add(&model.Node{Model: gorm.Model{ 68 ID: 38, 69 }}) 70 71 c, _ := gin.CreateTestContext(rec) 72 c.Request = httptest.NewRequest("POST", "/", nil) 73 c.Request.Header.Set("X-Cr-Node-Id", "38") 74 c.Request = auth.SignRequest(authInstance, c.Request, 0) 75 slaveRPCSignRequiredFunc(c) 76 a.False(c.IsAborted()) 77 } 78 } 79 80 func TestUseSlaveAria2Instance(t *testing.T) { 81 a := assert.New(t) 82 83 // MasterSiteID not set 84 { 85 testController := &controllermock.SlaveControllerMock{} 86 useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) 87 c, _ := gin.CreateTestContext(httptest.NewRecorder()) 88 c.Request = httptest.NewRequest("GET", "/", nil) 89 useSlaveAria2InstanceFunc(c) 90 a.True(c.IsAborted()) 91 } 92 93 // Cannot get aria2 instances 94 { 95 testController := &controllermock.SlaveControllerMock{} 96 useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) 97 c, _ := gin.CreateTestContext(httptest.NewRecorder()) 98 c.Request = httptest.NewRequest("GET", "/", nil) 99 c.Set("MasterSiteID", "expectedSiteID") 100 testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, errors.New("error")) 101 useSlaveAria2InstanceFunc(c) 102 a.True(c.IsAborted()) 103 testController.AssertExpectations(t) 104 } 105 106 // Success 107 { 108 testController := &controllermock.SlaveControllerMock{} 109 useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController) 110 c, _ := gin.CreateTestContext(httptest.NewRecorder()) 111 c.Request = httptest.NewRequest("GET", "/", nil) 112 c.Set("MasterSiteID", "expectedSiteID") 113 testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, nil) 114 useSlaveAria2InstanceFunc(c) 115 a.False(c.IsAborted()) 116 res, _ := c.Get("MasterAria2Instance") 117 a.NotNil(res) 118 testController.AssertExpectations(t) 119 } 120 }