github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/middleware/session_test.go (about) 1 package middleware 2 3 import ( 4 "net/http" 5 "net/http/httptest" 6 "testing" 7 8 "github.com/cloudreve/Cloudreve/v3/pkg/util" 9 "github.com/gin-gonic/gin" 10 "github.com/stretchr/testify/assert" 11 ) 12 13 func TestSession(t *testing.T) { 14 asserts := assert.New(t) 15 16 { 17 handler := Session("2333") 18 asserts.NotNil(handler) 19 asserts.NotNil(Store) 20 asserts.IsType(emptyFunc(), handler) 21 } 22 } 23 24 func emptyFunc() gin.HandlerFunc { 25 return func(c *gin.Context) {} 26 } 27 28 func TestCSRFInit(t *testing.T) { 29 asserts := assert.New(t) 30 rec := httptest.NewRecorder() 31 sessionFunc := Session("233") 32 { 33 c, _ := gin.CreateTestContext(rec) 34 c.Request, _ = http.NewRequest("GET", "/test", nil) 35 sessionFunc(c) 36 CSRFInit()(c) 37 asserts.True(util.GetSession(c, "CSRF").(bool)) 38 } 39 } 40 41 func TestCSRFCheck(t *testing.T) { 42 asserts := assert.New(t) 43 rec := httptest.NewRecorder() 44 sessionFunc := Session("233") 45 46 // 通过检查 47 { 48 c, _ := gin.CreateTestContext(rec) 49 c.Request, _ = http.NewRequest("GET", "/test", nil) 50 sessionFunc(c) 51 CSRFInit()(c) 52 CSRFCheck()(c) 53 asserts.False(c.IsAborted()) 54 } 55 56 // 未通过检查 57 { 58 c, _ := gin.CreateTestContext(rec) 59 c.Request, _ = http.NewRequest("GET", "/test", nil) 60 sessionFunc(c) 61 CSRFCheck()(c) 62 asserts.True(c.IsAborted()) 63 } 64 }