github.com/cloudreve/Cloudreve/v3@v3.0.0-20240224133659-3edb00a6484c/middleware/session.go (about)

     1  package middleware
     2  
     3  import (
     4  	"github.com/cloudreve/Cloudreve/v3/pkg/cache"
     5  	"github.com/cloudreve/Cloudreve/v3/pkg/sessionstore"
     6  	"net/http"
     7  	"strings"
     8  
     9  	"github.com/cloudreve/Cloudreve/v3/pkg/conf"
    10  	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
    11  	"github.com/cloudreve/Cloudreve/v3/pkg/util"
    12  	"github.com/gin-contrib/sessions"
    13  	"github.com/gin-gonic/gin"
    14  )
    15  
    16  // Store session存储
    17  var Store sessions.Store
    18  
    19  // Session 初始化session
    20  func Session(secret string) gin.HandlerFunc {
    21  	// Redis设置不为空,且非测试模式时使用Redis
    22  	Store = sessionstore.NewStore(cache.Store, []byte(secret))
    23  
    24  	sameSiteMode := http.SameSiteDefaultMode
    25  	switch strings.ToLower(conf.CORSConfig.SameSite) {
    26  	case "default":
    27  		sameSiteMode = http.SameSiteDefaultMode
    28  	case "none":
    29  		sameSiteMode = http.SameSiteNoneMode
    30  	case "strict":
    31  		sameSiteMode = http.SameSiteStrictMode
    32  	case "lax":
    33  		sameSiteMode = http.SameSiteLaxMode
    34  	}
    35  
    36  	// Also set Secure: true if using SSL, you should though
    37  	Store.Options(sessions.Options{
    38  		HttpOnly: true,
    39  		MaxAge:   60 * 86400,
    40  		Path:     "/",
    41  		SameSite: sameSiteMode,
    42  		Secure:   conf.CORSConfig.Secure,
    43  	})
    44  
    45  	return sessions.Sessions("cloudreve-session", Store)
    46  }
    47  
    48  // CSRFInit 初始化CSRF标记
    49  func CSRFInit() gin.HandlerFunc {
    50  	return func(c *gin.Context) {
    51  		util.SetSession(c, map[string]interface{}{"CSRF": true})
    52  		c.Next()
    53  	}
    54  }
    55  
    56  // CSRFCheck 检查CSRF标记
    57  func CSRFCheck() gin.HandlerFunc {
    58  	return func(c *gin.Context) {
    59  		if check, ok := util.GetSession(c, "CSRF").(bool); ok && check {
    60  			c.Next()
    61  			return
    62  		}
    63  
    64  		c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "Invalid origin", nil))
    65  		c.Abort()
    66  	}
    67  }