github.com/gofiber/fiber/v2@v2.47.0/middleware/session/session_test.go (about)

     1  package session
     2  
     3  import (
     4  	"testing"
     5  	"time"
     6  
     7  	"github.com/gofiber/fiber/v2"
     8  	"github.com/gofiber/fiber/v2/internal/storage/memory"
     9  	"github.com/gofiber/fiber/v2/utils"
    10  
    11  	"github.com/valyala/fasthttp"
    12  )
    13  
    14  // go test -run Test_Session
    15  func Test_Session(t *testing.T) {
    16  	t.Parallel()
    17  
    18  	// session store
    19  	store := New()
    20  
    21  	// fiber instance
    22  	app := fiber.New()
    23  
    24  	// fiber context
    25  	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
    26  	defer app.ReleaseCtx(ctx)
    27  
    28  	// set session
    29  	ctx.Request().Header.SetCookie(store.sessionName, "123")
    30  
    31  	// get session
    32  	sess, err := store.Get(ctx)
    33  	utils.AssertEqual(t, nil, err)
    34  	utils.AssertEqual(t, true, sess.Fresh())
    35  
    36  	// get keys
    37  	keys := sess.Keys()
    38  	utils.AssertEqual(t, []string{}, keys)
    39  
    40  	// get value
    41  	name := sess.Get("name")
    42  	utils.AssertEqual(t, nil, name)
    43  
    44  	// set value
    45  	sess.Set("name", "john")
    46  
    47  	// get value
    48  	name = sess.Get("name")
    49  	utils.AssertEqual(t, "john", name)
    50  
    51  	keys = sess.Keys()
    52  	utils.AssertEqual(t, []string{"name"}, keys)
    53  
    54  	// delete key
    55  	sess.Delete("name")
    56  
    57  	// get value
    58  	name = sess.Get("name")
    59  	utils.AssertEqual(t, nil, name)
    60  
    61  	// get keys
    62  	keys = sess.Keys()
    63  	utils.AssertEqual(t, []string{}, keys)
    64  
    65  	// get id
    66  	id := sess.ID()
    67  	utils.AssertEqual(t, "123", id)
    68  
    69  	// save the old session first
    70  	err = sess.Save()
    71  	utils.AssertEqual(t, nil, err)
    72  
    73  	// requesting entirely new context to prevent falsy tests
    74  	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    75  	defer app.ReleaseCtx(ctx)
    76  
    77  	sess, err = store.Get(ctx)
    78  	utils.AssertEqual(t, nil, err)
    79  	utils.AssertEqual(t, true, sess.Fresh())
    80  
    81  	// this id should be randomly generated as session key was deleted
    82  	utils.AssertEqual(t, 36, len(sess.ID()))
    83  
    84  	// when we use the original session for the second time
    85  	// the session be should be same if the session is not expired
    86  	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    87  	defer app.ReleaseCtx(ctx)
    88  
    89  	// request the server with the old session
    90  	ctx.Request().Header.SetCookie(store.sessionName, id)
    91  	sess, err = store.Get(ctx)
    92  	utils.AssertEqual(t, nil, err)
    93  	utils.AssertEqual(t, false, sess.Fresh())
    94  	utils.AssertEqual(t, sess.id, id)
    95  }
    96  
    97  // go test -run Test_Session_Types
    98  //
    99  //nolint:forcetypeassert // TODO: Do not force-type assert
   100  func Test_Session_Types(t *testing.T) {
   101  	t.Parallel()
   102  
   103  	// session store
   104  	store := New()
   105  
   106  	// fiber instance
   107  	app := fiber.New()
   108  
   109  	// fiber context
   110  	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
   111  	defer app.ReleaseCtx(ctx)
   112  
   113  	// set cookie
   114  	ctx.Request().Header.SetCookie(store.sessionName, "123")
   115  
   116  	// get session
   117  	sess, err := store.Get(ctx)
   118  	utils.AssertEqual(t, nil, err)
   119  	utils.AssertEqual(t, true, sess.Fresh())
   120  
   121  	// the session string is no longer be 123
   122  	newSessionIDString := sess.ID()
   123  	ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString)
   124  
   125  	type User struct {
   126  		Name string
   127  	}
   128  	store.RegisterType(User{})
   129  	vuser := User{
   130  		Name: "John",
   131  	}
   132  	// set value
   133  	var (
   134  		vbool                  = true
   135  		vstring                = "str"
   136  		vint                   = 13
   137  		vint8       int8       = 13
   138  		vint16      int16      = 13
   139  		vint32      int32      = 13
   140  		vint64      int64      = 13
   141  		vuint       uint       = 13
   142  		vuint8      uint8      = 13
   143  		vuint16     uint16     = 13
   144  		vuint32     uint32     = 13
   145  		vuint64     uint64     = 13
   146  		vuintptr    uintptr    = 13
   147  		vbyte       byte       = 'k'
   148  		vrune                  = 'k'
   149  		vfloat32    float32    = 13
   150  		vfloat64    float64    = 13
   151  		vcomplex64  complex64  = 13
   152  		vcomplex128 complex128 = 13
   153  	)
   154  	sess.Set("vuser", vuser)
   155  	sess.Set("vbool", vbool)
   156  	sess.Set("vstring", vstring)
   157  	sess.Set("vint", vint)
   158  	sess.Set("vint8", vint8)
   159  	sess.Set("vint16", vint16)
   160  	sess.Set("vint32", vint32)
   161  	sess.Set("vint64", vint64)
   162  	sess.Set("vuint", vuint)
   163  	sess.Set("vuint8", vuint8)
   164  	sess.Set("vuint16", vuint16)
   165  	sess.Set("vuint32", vuint32)
   166  	sess.Set("vuint32", vuint32)
   167  	sess.Set("vuint64", vuint64)
   168  	sess.Set("vuintptr", vuintptr)
   169  	sess.Set("vbyte", vbyte)
   170  	sess.Set("vrune", vrune)
   171  	sess.Set("vfloat32", vfloat32)
   172  	sess.Set("vfloat64", vfloat64)
   173  	sess.Set("vcomplex64", vcomplex64)
   174  	sess.Set("vcomplex128", vcomplex128)
   175  
   176  	// save session
   177  	err = sess.Save()
   178  	utils.AssertEqual(t, nil, err)
   179  
   180  	// get session
   181  	sess, err = store.Get(ctx)
   182  	utils.AssertEqual(t, nil, err)
   183  	utils.AssertEqual(t, false, sess.Fresh())
   184  
   185  	// get value
   186  	utils.AssertEqual(t, vuser, sess.Get("vuser").(User))
   187  	utils.AssertEqual(t, vbool, sess.Get("vbool").(bool))
   188  	utils.AssertEqual(t, vstring, sess.Get("vstring").(string))
   189  	utils.AssertEqual(t, vint, sess.Get("vint").(int))
   190  	utils.AssertEqual(t, vint8, sess.Get("vint8").(int8))
   191  	utils.AssertEqual(t, vint16, sess.Get("vint16").(int16))
   192  	utils.AssertEqual(t, vint32, sess.Get("vint32").(int32))
   193  	utils.AssertEqual(t, vint64, sess.Get("vint64").(int64))
   194  	utils.AssertEqual(t, vuint, sess.Get("vuint").(uint))
   195  	utils.AssertEqual(t, vuint8, sess.Get("vuint8").(uint8))
   196  	utils.AssertEqual(t, vuint16, sess.Get("vuint16").(uint16))
   197  	utils.AssertEqual(t, vuint32, sess.Get("vuint32").(uint32))
   198  	utils.AssertEqual(t, vuint64, sess.Get("vuint64").(uint64))
   199  	utils.AssertEqual(t, vuintptr, sess.Get("vuintptr").(uintptr))
   200  	utils.AssertEqual(t, vbyte, sess.Get("vbyte").(byte))
   201  	utils.AssertEqual(t, vrune, sess.Get("vrune").(rune))
   202  	utils.AssertEqual(t, vfloat32, sess.Get("vfloat32").(float32))
   203  	utils.AssertEqual(t, vfloat64, sess.Get("vfloat64").(float64))
   204  	utils.AssertEqual(t, vcomplex64, sess.Get("vcomplex64").(complex64))
   205  	utils.AssertEqual(t, vcomplex128, sess.Get("vcomplex128").(complex128))
   206  }
   207  
   208  // go test -run Test_Session_Store_Reset
   209  func Test_Session_Store_Reset(t *testing.T) {
   210  	t.Parallel()
   211  	// session store
   212  	store := New()
   213  	// fiber instance
   214  	app := fiber.New()
   215  	// fiber context
   216  	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
   217  	defer app.ReleaseCtx(ctx)
   218  
   219  	// get session
   220  	sess, err := store.Get(ctx)
   221  	utils.AssertEqual(t, nil, err)
   222  	// make sure its new
   223  	utils.AssertEqual(t, true, sess.Fresh())
   224  	// set value & save
   225  	sess.Set("hello", "world")
   226  	ctx.Request().Header.SetCookie(store.sessionName, sess.ID())
   227  	utils.AssertEqual(t, nil, sess.Save())
   228  
   229  	// reset store
   230  	utils.AssertEqual(t, nil, store.Reset())
   231  
   232  	// make sure the session is recreated
   233  	sess, err = store.Get(ctx)
   234  	utils.AssertEqual(t, nil, err)
   235  	utils.AssertEqual(t, true, sess.Fresh())
   236  	utils.AssertEqual(t, nil, sess.Get("hello"))
   237  }
   238  
   239  // go test -run Test_Session_Save
   240  func Test_Session_Save(t *testing.T) {
   241  	t.Parallel()
   242  
   243  	t.Run("save to cookie", func(t *testing.T) {
   244  		// session store
   245  		store := New()
   246  		// fiber instance
   247  		app := fiber.New()
   248  		// fiber context
   249  		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
   250  		defer app.ReleaseCtx(ctx)
   251  		// get session
   252  		sess, err := store.Get(ctx)
   253  		utils.AssertEqual(t, nil, err)
   254  		// set value
   255  		sess.Set("name", "john")
   256  
   257  		// save session
   258  		err = sess.Save()
   259  		utils.AssertEqual(t, nil, err)
   260  	})
   261  
   262  	t.Run("save to header", func(t *testing.T) {
   263  		// session store
   264  		store := New(Config{
   265  			KeyLookup: "header:session_id",
   266  		})
   267  		// fiber instance
   268  		app := fiber.New()
   269  		// fiber context
   270  		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
   271  		defer app.ReleaseCtx(ctx)
   272  		// get session
   273  		sess, err := store.Get(ctx)
   274  		utils.AssertEqual(t, nil, err)
   275  		// set value
   276  		sess.Set("name", "john")
   277  
   278  		// save session
   279  		err = sess.Save()
   280  		utils.AssertEqual(t, nil, err)
   281  		utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName)))
   282  		utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName)))
   283  	})
   284  }
   285  
   286  func Test_Session_Save_Expiration(t *testing.T) {
   287  	t.Parallel()
   288  
   289  	t.Run("save to cookie", func(t *testing.T) {
   290  		t.Parallel()
   291  		// session store
   292  		store := New()
   293  		// fiber instance
   294  		app := fiber.New()
   295  		// fiber context
   296  		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
   297  		defer app.ReleaseCtx(ctx)
   298  		// get session
   299  		sess, err := store.Get(ctx)
   300  		utils.AssertEqual(t, nil, err)
   301  		// set value
   302  		sess.Set("name", "john")
   303  
   304  		// expire this session in 5 seconds
   305  		sess.SetExpiry(time.Second * 5)
   306  
   307  		// save session
   308  		err = sess.Save()
   309  		utils.AssertEqual(t, nil, err)
   310  
   311  		// here you need to get the old session yet
   312  		sess, err = store.Get(ctx)
   313  		utils.AssertEqual(t, nil, err)
   314  		utils.AssertEqual(t, "john", sess.Get("name"))
   315  
   316  		// just to make sure the session has been expired
   317  		time.Sleep(time.Second * 5)
   318  
   319  		// here you should get a new session
   320  		sess, err = store.Get(ctx)
   321  		utils.AssertEqual(t, nil, err)
   322  		utils.AssertEqual(t, nil, sess.Get("name"))
   323  	})
   324  }
   325  
   326  // go test -run Test_Session_Reset
   327  func Test_Session_Reset(t *testing.T) {
   328  	t.Parallel()
   329  
   330  	t.Run("reset from cookie", func(t *testing.T) {
   331  		t.Parallel()
   332  		// session store
   333  		store := New()
   334  		// fiber instance
   335  		app := fiber.New()
   336  		// fiber context
   337  		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
   338  		defer app.ReleaseCtx(ctx)
   339  		// get session
   340  		sess, err := store.Get(ctx)
   341  		utils.AssertEqual(t, nil, err)
   342  
   343  		sess.Set("name", "fenny")
   344  		utils.AssertEqual(t, nil, sess.Destroy())
   345  		name := sess.Get("name")
   346  		utils.AssertEqual(t, nil, name)
   347  	})
   348  
   349  	t.Run("reset from header", func(t *testing.T) {
   350  		t.Parallel()
   351  		// session store
   352  		store := New(Config{
   353  			KeyLookup: "header:session_id",
   354  		})
   355  		// fiber instance
   356  		app := fiber.New()
   357  		// fiber context
   358  		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
   359  		defer app.ReleaseCtx(ctx)
   360  		// get session
   361  		sess, err := store.Get(ctx)
   362  		utils.AssertEqual(t, nil, err)
   363  
   364  		// set value & save
   365  		sess.Set("name", "fenny")
   366  		utils.AssertEqual(t, nil, sess.Save())
   367  		sess, err = store.Get(ctx)
   368  		utils.AssertEqual(t, nil, err)
   369  
   370  		err = sess.Destroy()
   371  		utils.AssertEqual(t, nil, err)
   372  		utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
   373  		utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
   374  	})
   375  }
   376  
   377  // go test -run Test_Session_Custom_Config
   378  func Test_Session_Custom_Config(t *testing.T) {
   379  	t.Parallel()
   380  
   381  	store := New(Config{Expiration: time.Hour, KeyGenerator: func() string { return "very random" }})
   382  	utils.AssertEqual(t, time.Hour, store.Expiration)
   383  	utils.AssertEqual(t, "very random", store.KeyGenerator())
   384  
   385  	store = New(Config{Expiration: 0})
   386  	utils.AssertEqual(t, ConfigDefault.Expiration, store.Expiration)
   387  }
   388  
   389  // go test -run Test_Session_Cookie
   390  func Test_Session_Cookie(t *testing.T) {
   391  	t.Parallel()
   392  	// session store
   393  	store := New()
   394  	// fiber instance
   395  	app := fiber.New()
   396  	// fiber context
   397  	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
   398  	defer app.ReleaseCtx(ctx)
   399  
   400  	// get session
   401  	sess, err := store.Get(ctx)
   402  	utils.AssertEqual(t, nil, err)
   403  	utils.AssertEqual(t, nil, sess.Save())
   404  
   405  	// cookie should be set on Save ( even if empty data )
   406  	utils.AssertEqual(t, 84, len(ctx.Response().Header.PeekCookie(store.sessionName)))
   407  }
   408  
   409  // go test -run Test_Session_Cookie_In_Response
   410  func Test_Session_Cookie_In_Response(t *testing.T) {
   411  	t.Parallel()
   412  	store := New()
   413  	app := fiber.New()
   414  
   415  	// fiber context
   416  	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
   417  	defer app.ReleaseCtx(ctx)
   418  
   419  	// get session
   420  	sess, err := store.Get(ctx)
   421  	utils.AssertEqual(t, nil, err)
   422  	sess.Set("id", "1")
   423  	utils.AssertEqual(t, true, sess.Fresh())
   424  	utils.AssertEqual(t, nil, sess.Save())
   425  
   426  	sess, err = store.Get(ctx)
   427  	utils.AssertEqual(t, nil, err)
   428  	sess.Set("name", "john")
   429  	utils.AssertEqual(t, true, sess.Fresh())
   430  
   431  	utils.AssertEqual(t, "1", sess.Get("id"))
   432  	utils.AssertEqual(t, "john", sess.Get("name"))
   433  }
   434  
   435  // go test -run Test_Session_Deletes_Single_Key
   436  // Regression: https://github.com/gofiber/fiber/issues/1365
   437  func Test_Session_Deletes_Single_Key(t *testing.T) {
   438  	t.Parallel()
   439  	store := New()
   440  	app := fiber.New()
   441  
   442  	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
   443  	defer app.ReleaseCtx(ctx)
   444  
   445  	sess, err := store.Get(ctx)
   446  	utils.AssertEqual(t, nil, err)
   447  	ctx.Request().Header.SetCookie(store.sessionName, sess.ID())
   448  
   449  	sess.Set("id", "1")
   450  	utils.AssertEqual(t, nil, sess.Save())
   451  
   452  	sess, err = store.Get(ctx)
   453  	utils.AssertEqual(t, nil, err)
   454  	sess.Delete("id")
   455  	utils.AssertEqual(t, nil, sess.Save())
   456  
   457  	sess, err = store.Get(ctx)
   458  	utils.AssertEqual(t, nil, err)
   459  	utils.AssertEqual(t, false, sess.Fresh())
   460  	utils.AssertEqual(t, nil, sess.Get("id"))
   461  }
   462  
   463  // go test -run Test_Session_Regenerate
   464  // Regression: https://github.com/gofiber/fiber/issues/1395
   465  func Test_Session_Regenerate(t *testing.T) {
   466  	t.Parallel()
   467  	// fiber instance
   468  	app := fiber.New()
   469  	t.Run("set fresh to be true when regenerating a session", func(t *testing.T) {
   470  		// session store
   471  		store := New()
   472  		// a random session uuid
   473  		originalSessionUUIDString := ""
   474  		// fiber context
   475  		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
   476  		defer app.ReleaseCtx(ctx)
   477  
   478  		// now the session is in the storage
   479  		freshSession, err := store.Get(ctx)
   480  		utils.AssertEqual(t, nil, err)
   481  
   482  		originalSessionUUIDString = freshSession.ID()
   483  
   484  		err = freshSession.Save()
   485  		utils.AssertEqual(t, nil, err)
   486  
   487  		// set cookie
   488  		ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
   489  
   490  		// as the session is in the storage, session.fresh should be false
   491  		acquiredSession, err := store.Get(ctx)
   492  		utils.AssertEqual(t, nil, err)
   493  		utils.AssertEqual(t, false, acquiredSession.Fresh())
   494  
   495  		err = acquiredSession.Regenerate()
   496  		utils.AssertEqual(t, nil, err)
   497  
   498  		if acquiredSession.ID() == originalSessionUUIDString {
   499  			t.Fatal("regenerate should generate another different id")
   500  		}
   501  		// acquiredSession.fresh should be true after regenerating
   502  		utils.AssertEqual(t, true, acquiredSession.Fresh())
   503  	})
   504  }
   505  
   506  // go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4
   507  func Benchmark_Session(b *testing.B) {
   508  	app, store := fiber.New(), New()
   509  	c := app.AcquireCtx(&fasthttp.RequestCtx{})
   510  	defer app.ReleaseCtx(c)
   511  	c.Request().Header.SetCookie(store.sessionName, "12356789")
   512  
   513  	var err error
   514  	b.Run("default", func(b *testing.B) {
   515  		b.ReportAllocs()
   516  		b.ResetTimer()
   517  		for n := 0; n < b.N; n++ {
   518  			sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
   519  			sess.Set("john", "doe")
   520  			err = sess.Save()
   521  		}
   522  
   523  		utils.AssertEqual(b, nil, err)
   524  	})
   525  
   526  	b.Run("storage", func(b *testing.B) {
   527  		store = New(Config{
   528  			Storage: memory.New(),
   529  		})
   530  		b.ReportAllocs()
   531  		b.ResetTimer()
   532  		for n := 0; n < b.N; n++ {
   533  			sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
   534  			sess.Set("john", "doe")
   535  			err = sess.Save()
   536  		}
   537  
   538  		utils.AssertEqual(b, nil, err)
   539  	})
   540  }