github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/middleware/cluster_test.go (about)

     1  package middleware
     2  
     3  import (
     4  	"errors"
     5  	model "github.com/cloudreve/Cloudreve/v3/models"
     6  	"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
     7  	"github.com/cloudreve/Cloudreve/v3/pkg/auth"
     8  	"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
     9  	"github.com/cloudreve/Cloudreve/v3/pkg/mocks/controllermock"
    10  	"github.com/gin-gonic/gin"
    11  	"github.com/jinzhu/gorm"
    12  	"github.com/stretchr/testify/assert"
    13  	"net/http/httptest"
    14  	"testing"
    15  )
    16  
    17  func TestMasterMetadata(t *testing.T) {
    18  	a := assert.New(t)
    19  	masterMetaDataFunc := MasterMetadata()
    20  	rec := httptest.NewRecorder()
    21  	c, _ := gin.CreateTestContext(rec)
    22  	c.Request = httptest.NewRequest("GET", "/", nil)
    23  
    24  	c.Request.Header = map[string][]string{
    25  		"X-Cr-Site-Id":           {"expectedSiteID"},
    26  		"X-Cr-Site-Url":          {"expectedSiteURL"},
    27  		"X-Cr-Cloudreve-Version": {"expectedMasterVersion"},
    28  	}
    29  	masterMetaDataFunc(c)
    30  	siteID, _ := c.Get("MasterSiteID")
    31  	siteURL, _ := c.Get("MasterSiteURL")
    32  	siteVersion, _ := c.Get("MasterVersion")
    33  
    34  	a.Equal("expectedSiteID", siteID.(string))
    35  	a.Equal("expectedSiteURL", siteURL.(string))
    36  	a.Equal("expectedMasterVersion", siteVersion.(string))
    37  }
    38  
    39  func TestSlaveRPCSignRequired(t *testing.T) {
    40  	a := assert.New(t)
    41  	np := &cluster.NodePool{}
    42  	np.Init()
    43  	slaveRPCSignRequiredFunc := SlaveRPCSignRequired(np)
    44  	rec := httptest.NewRecorder()
    45  
    46  	// id parse failed
    47  	{
    48  		c, _ := gin.CreateTestContext(rec)
    49  		c.Request = httptest.NewRequest("GET", "/", nil)
    50  		c.Request.Header.Set("X-Cr-Node-Id", "unknown")
    51  		slaveRPCSignRequiredFunc(c)
    52  		a.True(c.IsAborted())
    53  	}
    54  
    55  	// node id not exist
    56  	{
    57  		c, _ := gin.CreateTestContext(rec)
    58  		c.Request = httptest.NewRequest("GET", "/", nil)
    59  		c.Request.Header.Set("X-Cr-Node-Id", "38")
    60  		slaveRPCSignRequiredFunc(c)
    61  		a.True(c.IsAborted())
    62  	}
    63  
    64  	// success
    65  	{
    66  		authInstance := auth.HMACAuth{SecretKey: []byte("")}
    67  		np.Add(&model.Node{Model: gorm.Model{
    68  			ID: 38,
    69  		}})
    70  
    71  		c, _ := gin.CreateTestContext(rec)
    72  		c.Request = httptest.NewRequest("POST", "/", nil)
    73  		c.Request.Header.Set("X-Cr-Node-Id", "38")
    74  		c.Request = auth.SignRequest(authInstance, c.Request, 0)
    75  		slaveRPCSignRequiredFunc(c)
    76  		a.False(c.IsAborted())
    77  	}
    78  }
    79  
    80  func TestUseSlaveAria2Instance(t *testing.T) {
    81  	a := assert.New(t)
    82  
    83  	// MasterSiteID not set
    84  	{
    85  		testController := &controllermock.SlaveControllerMock{}
    86  		useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
    87  		c, _ := gin.CreateTestContext(httptest.NewRecorder())
    88  		c.Request = httptest.NewRequest("GET", "/", nil)
    89  		useSlaveAria2InstanceFunc(c)
    90  		a.True(c.IsAborted())
    91  	}
    92  
    93  	// Cannot get aria2 instances
    94  	{
    95  		testController := &controllermock.SlaveControllerMock{}
    96  		useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
    97  		c, _ := gin.CreateTestContext(httptest.NewRecorder())
    98  		c.Request = httptest.NewRequest("GET", "/", nil)
    99  		c.Set("MasterSiteID", "expectedSiteID")
   100  		testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, errors.New("error"))
   101  		useSlaveAria2InstanceFunc(c)
   102  		a.True(c.IsAborted())
   103  		testController.AssertExpectations(t)
   104  	}
   105  
   106  	// Success
   107  	{
   108  		testController := &controllermock.SlaveControllerMock{}
   109  		useSlaveAria2InstanceFunc := UseSlaveAria2Instance(testController)
   110  		c, _ := gin.CreateTestContext(httptest.NewRecorder())
   111  		c.Request = httptest.NewRequest("GET", "/", nil)
   112  		c.Set("MasterSiteID", "expectedSiteID")
   113  		testController.On("GetAria2Instance", "expectedSiteID").Return(&common.DummyAria2{}, nil)
   114  		useSlaveAria2InstanceFunc(c)
   115  		a.False(c.IsAborted())
   116  		res, _ := c.Get("MasterAria2Instance")
   117  		a.NotNil(res)
   118  		testController.AssertExpectations(t)
   119  	}
   120  }