github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/middleware/session_test.go (about)

     1  package middleware
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"testing"
     7  
     8  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
     9  	"github.com/gin-gonic/gin"
    10  	"github.com/stretchr/testify/assert"
    11  )
    12  
    13  func TestSession(t *testing.T) {
    14  	asserts := assert.New(t)
    15  
    16  	{
    17  		handler := Session("2333")
    18  		asserts.NotNil(handler)
    19  		asserts.NotNil(Store)
    20  		asserts.IsType(emptyFunc(), handler)
    21  	}
    22  }
    23  
    24  func emptyFunc() gin.HandlerFunc {
    25  	return func(c *gin.Context) {}
    26  }
    27  
    28  func TestCSRFInit(t *testing.T) {
    29  	asserts := assert.New(t)
    30  	rec := httptest.NewRecorder()
    31  	sessionFunc := Session("233")
    32  	{
    33  		c, _ := gin.CreateTestContext(rec)
    34  		c.Request, _ = http.NewRequest("GET", "/test", nil)
    35  		sessionFunc(c)
    36  		CSRFInit()(c)
    37  		asserts.True(util.GetSession(c, "CSRF").(bool))
    38  	}
    39  }
    40  
    41  func TestCSRFCheck(t *testing.T) {
    42  	asserts := assert.New(t)
    43  	rec := httptest.NewRecorder()
    44  	sessionFunc := Session("233")
    45  
    46  	// 通过检查
    47  	{
    48  		c, _ := gin.CreateTestContext(rec)
    49  		c.Request, _ = http.NewRequest("GET", "/test", nil)
    50  		sessionFunc(c)
    51  		CSRFInit()(c)
    52  		CSRFCheck()(c)
    53  		asserts.False(c.IsAborted())
    54  	}
    55  
    56  	// 未通过检查
    57  	{
    58  		c, _ := gin.CreateTestContext(rec)
    59  		c.Request, _ = http.NewRequest("GET", "/test", nil)
    60  		sessionFunc(c)
    61  		CSRFCheck()(c)
    62  		asserts.True(c.IsAborted())
    63  	}
    64  }