github.com/cloudwego/localsession@v0.0.2/manager.go (about)

     1  // Copyright 2023 CloudWeGo Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package localsession
    16  
    17  import (
    18  	"sync"
    19  	"sync/atomic"
    20  	"time"
    21  )
    22  
// ManagerOptions configures a SessionManager created by NewSessionManager.
type ManagerOptions struct {
	// EnableImplicitlyTransmitAsync enables transparently transmitting the
	// current session to child goroutines. When set, BindSession also records
	// the session ID for the current goroutine, UnbindSession clears it, and
	// GetSession falls back to that recorded ID when a lookup misses.
	//
	// WARNING: Once this option is enabled, if you want to use `pprof.Do()`,
	// it must be called before `BindSession()`; otherwise transmitting will be
	// dysfunctional.
	EnableImplicitlyTransmitAsync bool

	// ShardNumber is the number of shards session IDs are spread across
	// (a session with id lives in shard id % ShardNumber). It must be larger
	// than zero; NewSessionManager panics otherwise.
	ShardNumber int

	// GCInterval decides the interval of the scheduled GC loop.
	// A value <= 0 disables scheduled GC. Any positive value must be at least
	// one second, otherwise NewSessionManager panics.
	GCInterval time.Duration
}
    39  
// shard is one partition of a SessionManager's session table.
// m maps SessionID to Session and is guarded by lock: readers take RLock,
// writers (including GC when it replaces m) take Lock.
type shard struct {
	lock sync.RWMutex
	m    map[SessionID]Session
}
    44  
// SessionManager maintains and manages sessions, spreading them across
// shards keyed by SessionID.
//
// Fields:
//   - shards: opts.ShardNumber partitions; id maps to
//     shards[uint64(id)%uint64(opts.ShardNumber)].
//   - inGC: flag that GC compare-and-swaps to skip overlapping runs.
//   - tik: ticker driving the scheduled GC loop; nil unless
//     opts.GCInterval > 0.
//   - opts: the configuration the manager was created with.
//
// NOTE(review): NewSessionManager returns a SessionManager by value, and
// Options, GC and Close use value receivers. Copies share the same shards
// (the slice holds *shard pointers), but inGC is copied per value, so it
// does not coordinate GC calls made through different copies.
type SessionManager struct {
	shards []*shard
	inGC   uint32
	tik    *time.Ticker
	opts   ManagerOptions
}
    52  
    53  var defaultShardCap = 10
    54  
    55  func newShard() *shard {
    56  	ret := new(shard)
    57  	ret.m = make(map[SessionID]Session, defaultShardCap)
    58  	return ret
    59  }
    60  
    61  // NewSessionManager creates a SessionManager with default containers
    62  // If opts.GCInterval > 0, it will start scheduled GC() loop automatically
    63  func NewSessionManager(opts ManagerOptions) SessionManager {
    64  	if opts.ShardNumber <= 0 {
    65  		panic("ShardNumber must be larger than zero")
    66  	}
    67  	shards := make([]*shard, opts.ShardNumber)
    68  	for i := range shards {
    69  		shards[i] = newShard()
    70  	}
    71  	ret := SessionManager{
    72  		shards: shards,
    73  		opts:   opts,
    74  	}
    75  
    76  	if opts.GCInterval > 0 {
    77  		ret.startGC()
    78  	}
    79  	return ret
    80  }
    81  
    82  // Options shows the manager's Options
    83  func (self SessionManager) Options() ManagerOptions {
    84  	return self.opts
    85  }
    86  
// SessionID is the identity of a session. SessionManager uses it as the
// shard-map key and selects the shard by uint64(id) % ShardNumber.
type SessionID uint64
    89  
    90  //go:nocheckptr
    91  func (s *shard) Load(id SessionID) (Session, bool) {
    92  	s.lock.RLock()
    93  	session, ok := s.m[id]
    94  	s.lock.RUnlock()
    95  	return session, ok
    96  }
    97  
    98  func (s *shard) Store(id SessionID, se Session) {
    99  	s.lock.Lock()
   100  	s.m[id] = se
   101  	s.lock.Unlock()
   102  }
   103  
   104  func (s *shard) Delete(id SessionID) {
   105  	s.lock.Lock()
   106  	delete(s.m, id)
   107  	s.lock.Unlock()
   108  }
   109  
   110  // Get gets specific session
   111  // or get inherited session if option EnableImplicitlyTransmitAsync is true
   112  func (self *SessionManager) GetSession(id SessionID) (Session, bool) {
   113  	shard := self.shards[uint64(id)%uint64(self.opts.ShardNumber)]
   114  	session, ok := shard.Load(id)
   115  	if ok {
   116  		return session, ok
   117  	}
   118  	if !self.opts.EnableImplicitlyTransmitAsync {
   119  		return nil, false
   120  	}
   121  
   122  	id, ok = getSessionID()
   123  	if !ok {
   124  		return nil, false
   125  	}
   126  	shard = self.shards[uint64(id)%uint64(self.opts.ShardNumber)]
   127  	return shard.Load(id)
   128  }
   129  
   130  // BindSession binds the session with current goroutine
   131  func (self *SessionManager) BindSession(id SessionID, s Session) {
   132  	shard := self.shards[uint64(id)%uint64(self.opts.ShardNumber)]
   133  
   134  	shard.Store(id, s)
   135  
   136  	if self.opts.EnableImplicitlyTransmitAsync {
   137  		transmitSessionID(id)
   138  	}
   139  }
   140  
   141  // UnbindSession clears current session
   142  //
   143  // Notice: If you want to end the session,
   144  // please call `Disable()` (or whatever make the session invalid)
   145  // on your session's implementation
   146  func (self *SessionManager) UnbindSession(id SessionID) {
   147  	shard := self.shards[uint64(id)%uint64(self.opts.ShardNumber)]
   148  
   149  	_, ok := shard.Load(id)
   150  	if ok {
   151  		shard.Delete(id)
   152  	}
   153  
   154  	if self.opts.EnableImplicitlyTransmitAsync {
   155  		clearSessionID()
   156  	}
   157  }
   158  
   159  // GC sweep invalid sessions and release unused memory
   160  //
   161  //go:nocheckptr
   162  func (self SessionManager) GC() {
   163  	if !atomic.CompareAndSwapUint32(&self.inGC, 0, 1) {
   164  		return
   165  	}
   166  
   167  	for _, shard := range self.shards {
   168  		shard.lock.RLock()
   169  		n := shard.m
   170  		m := make(map[SessionID]Session, len(n))
   171  		for id, s := range n {
   172  			// Warning: may panic here?
   173  			if s.IsValid() {
   174  				m[id] = s
   175  			}
   176  		}
   177  		shard.m = m
   178  		shard.lock.RUnlock()
   179  	}
   180  
   181  	atomic.StoreUint32(&self.inGC, 0)
   182  }
   183  
   184  // startGC start a scheduled goroutine to call GC() according to GCInterval
   185  func (self *SessionManager) startGC() {
   186  	if self.opts.GCInterval < time.Second {
   187  		panic("GCInterval must be larger than 1 second")
   188  	}
   189  	self.tik = time.NewTicker(self.opts.GCInterval)
   190  	go func() {
   191  		for range self.tik.C {
   192  			self.GC()
   193  		}
   194  	}()
   195  }
   196  
   197  // Close stop persistent work for the manager, like GC
   198  func (self SessionManager) Close() {
   199  	if self.tik != nil {
   200  		self.tik.Stop()
   201  	}
   202  }