github.com/gofiber/fiber/v2@v2.47.0/middleware/session/store.go (about)

     1  package session
     2  
     3  import (
     4  	"encoding/gob"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"github.com/gofiber/fiber/v2"
     9  	"github.com/gofiber/fiber/v2/internal/storage/memory"
    10  	"github.com/gofiber/fiber/v2/utils"
    11  
    12  	"github.com/valyala/fasthttp"
    13  )
    14  
    15  type Store struct {
    16  	Config
    17  }
    18  
    19  var mux sync.Mutex
    20  
    21  func New(config ...Config) *Store {
    22  	// Set default config
    23  	cfg := configDefault(config...)
    24  
    25  	if cfg.Storage == nil {
    26  		cfg.Storage = memory.New()
    27  	}
    28  
    29  	return &Store{
    30  		cfg,
    31  	}
    32  }
    33  
    34  // RegisterType will allow you to encode/decode custom types
    35  // into any Storage provider
    36  func (*Store) RegisterType(i interface{}) {
    37  	gob.Register(i)
    38  }
    39  
    40  // Get will get/create a session
    41  func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
    42  	var fresh bool
    43  	loadData := true
    44  
    45  	id := s.getSessionID(c)
    46  
    47  	if len(id) == 0 {
    48  		fresh = true
    49  		var err error
    50  		if id, err = s.responseCookies(c); err != nil {
    51  			return nil, err
    52  		}
    53  	}
    54  
    55  	// If no key exist, create new one
    56  	if len(id) == 0 {
    57  		loadData = false
    58  		id = s.KeyGenerator()
    59  	}
    60  
    61  	// Create session object
    62  	sess := acquireSession()
    63  	sess.ctx = c
    64  	sess.config = s
    65  	sess.id = id
    66  	sess.fresh = fresh
    67  
    68  	// Fetch existing data
    69  	if loadData {
    70  		raw, err := s.Storage.Get(id)
    71  		// Unmarshal if we found data
    72  		if raw != nil && err == nil {
    73  			mux.Lock()
    74  			defer mux.Unlock()
    75  			_, _ = sess.byteBuffer.Write(raw) //nolint:errcheck // This will never fail
    76  			encCache := gob.NewDecoder(sess.byteBuffer)
    77  			err := encCache.Decode(&sess.data.Data)
    78  			if err != nil {
    79  				return nil, fmt.Errorf("failed to decode session data: %w", err)
    80  			}
    81  		} else if err != nil {
    82  			return nil, err
    83  		} else {
    84  			// both raw and err is nil, which means id is not in the storage
    85  			sess.fresh = true
    86  		}
    87  	}
    88  
    89  	return sess, nil
    90  }
    91  
    92  // getSessionID will return the session id from:
    93  // 1. cookie
    94  // 2. http headers
    95  // 3. query string
    96  func (s *Store) getSessionID(c *fiber.Ctx) string {
    97  	id := c.Cookies(s.sessionName)
    98  	if len(id) > 0 {
    99  		return utils.CopyString(id)
   100  	}
   101  
   102  	if s.source == SourceHeader {
   103  		id = string(c.Request().Header.Peek(s.sessionName))
   104  		if len(id) > 0 {
   105  			return id
   106  		}
   107  	}
   108  
   109  	if s.source == SourceURLQuery {
   110  		id = c.Query(s.sessionName)
   111  		if len(id) > 0 {
   112  			return utils.CopyString(id)
   113  		}
   114  	}
   115  
   116  	return ""
   117  }
   118  
   119  func (s *Store) responseCookies(c *fiber.Ctx) (string, error) {
   120  	// Get key from response cookie
   121  	cookieValue := c.Response().Header.PeekCookie(s.sessionName)
   122  	if len(cookieValue) == 0 {
   123  		return "", nil
   124  	}
   125  
   126  	cookie := fasthttp.AcquireCookie()
   127  	defer fasthttp.ReleaseCookie(cookie)
   128  	err := cookie.ParseBytes(cookieValue)
   129  	if err != nil {
   130  		return "", err
   131  	}
   132  
   133  	value := make([]byte, len(cookie.Value()))
   134  	copy(value, cookie.Value())
   135  	id := string(value)
   136  	return id, nil
   137  }
   138  
   139  // Reset will delete all session from the storage
   140  func (s *Store) Reset() error {
   141  	return s.Storage.Reset()
   142  }