code.gitea.io/gitea@v1.22.3/modules/session/db.go (about) 1 // Copyright 2020 The Gitea Authors. All rights reserved. 2 // SPDX-License-Identifier: MIT 3 4 package session 5 6 import ( 7 "log" 8 "sync" 9 10 "code.gitea.io/gitea/models/auth" 11 "code.gitea.io/gitea/models/db" 12 "code.gitea.io/gitea/modules/timeutil" 13 14 "gitea.com/go-chi/session" 15 ) 16 17 // DBStore represents a session store implementation based on the DB. 18 type DBStore struct { 19 sid string 20 lock sync.RWMutex 21 data map[any]any 22 } 23 24 // NewDBStore creates and returns a DB session store. 25 func NewDBStore(sid string, kv map[any]any) *DBStore { 26 return &DBStore{ 27 sid: sid, 28 data: kv, 29 } 30 } 31 32 // Set sets value to given key in session. 33 func (s *DBStore) Set(key, val any) error { 34 s.lock.Lock() 35 defer s.lock.Unlock() 36 37 s.data[key] = val 38 return nil 39 } 40 41 // Get gets value by given key in session. 42 func (s *DBStore) Get(key any) any { 43 s.lock.RLock() 44 defer s.lock.RUnlock() 45 46 return s.data[key] 47 } 48 49 // Delete delete a key from session. 50 func (s *DBStore) Delete(key any) error { 51 s.lock.Lock() 52 defer s.lock.Unlock() 53 54 delete(s.data, key) 55 return nil 56 } 57 58 // ID returns current session ID. 59 func (s *DBStore) ID() string { 60 return s.sid 61 } 62 63 // Release releases resource and save data to provider. 64 func (s *DBStore) Release() error { 65 // Skip encoding if the data is empty 66 if len(s.data) == 0 { 67 return nil 68 } 69 70 data, err := session.EncodeGob(s.data) 71 if err != nil { 72 return err 73 } 74 75 return auth.UpdateSession(db.DefaultContext, s.sid, data) 76 } 77 78 // Flush deletes all session data. 79 func (s *DBStore) Flush() error { 80 s.lock.Lock() 81 defer s.lock.Unlock() 82 83 s.data = make(map[any]any) 84 return nil 85 } 86 87 // DBProvider represents a DB session provider implementation. 88 type DBProvider struct { 89 maxLifetime int64 90 } 91 92 // Init initializes DB session provider. 93 // connStr: username:password@protocol(address)/dbname?param=value 94 func (p *DBProvider) Init(maxLifetime int64, connStr string) error { 95 p.maxLifetime = maxLifetime 96 return nil 97 } 98 99 // Read returns raw session store by session ID. 100 func (p *DBProvider) Read(sid string) (session.RawStore, error) { 101 s, err := auth.ReadSession(db.DefaultContext, sid) 102 if err != nil { 103 return nil, err 104 } 105 106 var kv map[any]any 107 if len(s.Data) == 0 || s.Expiry.Add(p.maxLifetime) <= timeutil.TimeStampNow() { 108 kv = make(map[any]any) 109 } else { 110 kv, err = session.DecodeGob(s.Data) 111 if err != nil { 112 return nil, err 113 } 114 } 115 116 return NewDBStore(sid, kv), nil 117 } 118 119 // Exist returns true if session with given ID exists. 120 func (p *DBProvider) Exist(sid string) bool { 121 has, err := auth.ExistSession(db.DefaultContext, sid) 122 if err != nil { 123 panic("session/DB: error checking existence: " + err.Error()) 124 } 125 return has 126 } 127 128 // Destroy deletes a session by session ID. 129 func (p *DBProvider) Destroy(sid string) error { 130 return auth.DestroySession(db.DefaultContext, sid) 131 } 132 133 // Regenerate regenerates a session store from old session ID to new one. 134 func (p *DBProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { 135 s, err := auth.RegenerateSession(db.DefaultContext, oldsid, sid) 136 if err != nil { 137 return nil, err 138 } 139 140 var kv map[any]any 141 if len(s.Data) == 0 || s.Expiry.Add(p.maxLifetime) <= timeutil.TimeStampNow() { 142 kv = make(map[any]any) 143 } else { 144 kv, err = session.DecodeGob(s.Data) 145 if err != nil { 146 return nil, err 147 } 148 } 149 150 return NewDBStore(sid, kv), nil 151 } 152 153 // Count counts and returns number of sessions. 154 func (p *DBProvider) Count() int { 155 total, err := auth.CountSessions(db.DefaultContext) 156 if err != nil { 157 panic("session/DB: error counting records: " + err.Error()) 158 } 159 return int(total) 160 } 161 162 // GC calls GC to clean expired sessions. 163 func (p *DBProvider) GC() { 164 if err := auth.CleanupSessions(db.DefaultContext, p.maxLifetime); err != nil { 165 log.Printf("session/DB: error garbage collecting: %v", err) 166 } 167 } 168 169 func init() { 170 session.Register("db", &DBProvider{}) 171 }