github.com/zooyer/miskit@v1.0.71/micro/session.go (about)

     1  package micro
     2  
     3  import (
     4  	"context"
     5  	"encoding/base32"
     6  	"encoding/json"
     7  	"errors"
     8  	"net/http"
     9  	"strings"
    10  	"time"
    11  
    12  	ginsession "github.com/gin-contrib/sessions"
    13  	"github.com/gorilla/securecookie"
    14  	"github.com/gorilla/sessions"
    15  	"github.com/zooyer/miskit/imdb"
    16  )
    17  
// store implements the gorilla/sessions Store methods (Get, New, Save) on top
// of an imdb connection. Session values are JSON-encoded and stored under the
// key prefix+sessionID; the cookie itself carries only the encoded session ID.
type store struct {
	db      imdb.Conn            // connection that holds the serialized session values
	prefix  string               // prepended to the session ID to form the storage key
	Codecs  []securecookie.Codec // encode/decode the session ID carried in the cookie
	Options *sessions.Options    // defaults copied into every session created by New
}
    24  
// imdbStore adapts store to the gin-contrib/sessions Store interface by adding
// an Options method on top of the embedded gorilla-style store.
type imdbStore struct {
	*store
}
    28  
    29  const defaultExpires = int64(20 * time.Minute / time.Second)
    30  
    31  func newStore(db imdb.Conn, prefix string, keyPairs ...[]byte) *store {
    32  	return &store{
    33  		db:     db,
    34  		prefix: prefix,
    35  		Codecs: securecookie.CodecsFromPairs(keyPairs...),
    36  		Options: &sessions.Options{
    37  			Path:   "/",
    38  			MaxAge: int(defaultExpires),
    39  		},
    40  	}
    41  }
    42  
    43  func (s *store) Get(r *http.Request, name string) (*sessions.Session, error) {
    44  	return sessions.GetRegistry(r).Get(s, name)
    45  }
    46  
    47  func (s *store) New(r *http.Request, name string) (*sessions.Session, error) {
    48  	var (
    49  		err error
    50  		ok  bool
    51  	)
    52  
    53  	session := sessions.NewSession(s, name)
    54  	options := *s.Options
    55  	session.Options = &options
    56  	session.IsNew = true
    57  
    58  	if c, errCookie := r.Cookie(name); errCookie == nil {
    59  		err = securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...)
    60  		if err == nil {
    61  			ok, err = s.load(r.Context(), session)
    62  			session.IsNew = err != nil || !ok
    63  		}
    64  	}
    65  
    66  	return session, err
    67  }
    68  
    69  func (s *store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
    70  	if session.Options.MaxAge <= 0 {
    71  		if err := s.delete(r.Context(), session); err != nil {
    72  			return err
    73  		}
    74  		http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
    75  	} else {
    76  		if session.ID == "" {
    77  			session.ID = strings.TrimRight(base32.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(32)), "=")
    78  		}
    79  		if err := s.save(r.Context(), session); err != nil {
    80  			return err
    81  		}
    82  		encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, s.Codecs...)
    83  		if err != nil {
    84  			return err
    85  		}
    86  		http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
    87  	}
    88  	return nil
    89  }
    90  
    91  func (s *store) load(ctx context.Context, session *sessions.Session) (ok bool, err error) {
    92  	value, err := s.db.Get(ctx, s.prefix+session.ID)
    93  	if err != nil {
    94  		return
    95  	}
    96  
    97  	var values = make(map[string]interface{})
    98  	if err = json.Unmarshal([]byte(value), &values); err != nil {
    99  		return
   100  	}
   101  
   102  	for key, val := range values {
   103  		session.Values[key] = val
   104  	}
   105  
   106  	return true, nil
   107  }
   108  
   109  func (s *store) save(ctx context.Context, session *sessions.Session) (err error) {
   110  	var values = make(map[string]interface{})
   111  	for key, value := range session.Values {
   112  		key, ok := key.(string)
   113  		if !ok {
   114  			return errors.New("session key non-string")
   115  		}
   116  		values[key] = value
   117  	}
   118  
   119  	data, err := json.Marshal(values)
   120  	if err != nil {
   121  		return
   122  	}
   123  	var expires = defaultExpires
   124  	if session.Options.MaxAge > 0 {
   125  		expires = int64(session.Options.MaxAge)
   126  	}
   127  
   128  	if err = s.db.SetEx(ctx, s.prefix+session.ID, string(data), expires); err != nil {
   129  		return
   130  	}
   131  
   132  	return
   133  }
   134  
   135  func (s *store) delete(ctx context.Context, session *sessions.Session) (err error) {
   136  	if err = s.db.Del(ctx, s.prefix+session.ID); err != nil {
   137  		return
   138  	}
   139  
   140  	return
   141  }
   142  
   143  func (s *imdbStore) Options(options ginsession.Options) {
   144  	s.store.Options = options.ToGorillaOptions()
   145  }
   146  
   147  func NewStore(name, dsn, prefix string, keyPairs ...[]byte) (store ginsession.Store, err error) {
   148  	db, err := imdb.Open(name, dsn)
   149  	if err != nil {
   150  		return
   151  	}
   152  
   153  	return NewStoreWithIMDB(db, prefix, keyPairs...), nil
   154  }
   155  
   156  func NewStoreWithIMDB(db imdb.Conn, prefix string, keyPairs ...[]byte) ginsession.Store {
   157  	return &imdbStore{
   158  		store: newStore(db, prefix, keyPairs...),
   159  	}
   160  }