github.com/benz9527/xboot@v0.0.0-20240504061247-c23f15593274/lib/kv/thread_safe_map.go (about)

     1  package kv
     2  
     3  import (
     4  	"io"
     5  	"reflect"
     6  	"sync"
     7  
     8  	"go.uber.org/multierr"
     9  
    10  	"github.com/benz9527/xboot/lib/infra"
    11  )
    12  
    13  type threadSafeMap[K comparable, V any] struct {
    14  	lock           sync.RWMutex
    15  	m              *swissMap[K, V]
    16  	initCap        uint32
    17  	isClosableItem bool
    18  }
    19  
    20  func (t *threadSafeMap[K, V]) AddOrUpdate(key K, obj V) error {
    21  	t.lock.Lock()
    22  	defer t.lock.Unlock()
    23  	return t.m.Put(key, obj)
    24  }
    25  
    26  func (t *threadSafeMap[K, V]) Replace(items map[K]V) error {
    27  	t.lock.Lock()
    28  	defer t.lock.Unlock()
    29  	t.m.Clear()
    30  	return t.m.MigrateFrom(items)
    31  }
    32  
    33  func (t *threadSafeMap[K, V]) Delete(key K) (V, error) {
    34  	t.lock.Lock()
    35  	defer t.lock.Unlock()
    36  	return t.m.Delete(key)
    37  }
    38  
    39  func (t *threadSafeMap[K, V]) Get(key K) (item V, exists bool) {
    40  	t.lock.RLock()
    41  	defer t.lock.RUnlock()
    42  	item, exists = t.m.Get(key)
    43  	return
    44  }
    45  
    46  func (t *threadSafeMap[K, V]) ListKeys(filters ...SafeStoreKeyFilterFunc[K]) []K {
    47  	realFilters := make([]SafeStoreKeyFilterFunc[K], 0, len(filters))
    48  	for _, filter := range filters {
    49  		if filter != nil {
    50  			realFilters = append(realFilters, filter)
    51  		}
    52  	}
    53  	if len(realFilters) == 0 {
    54  		realFilters = append(realFilters, defaultAllKeysFilter[K])
    55  	}
    56  
    57  	t.lock.RLock()
    58  	defer t.lock.RUnlock()
    59  
    60  	keys := make([]K, 0, t.m.Len())
    61  	t.m.Foreach(func(i uint64, key K, val V) bool {
    62  		for _, filter := range realFilters {
    63  			if filter(key) {
    64  				keys = append(keys, key)
    65  				break
    66  			}
    67  		}
    68  		return true
    69  	})
    70  	return keys
    71  }
    72  
    73  func (t *threadSafeMap[K, V]) ListValues(keys ...K) (items []V) {
    74  	realKeys := make([]K, 0, len(keys))
    75  	for _, key := range keys {
    76  		realKeys = append(realKeys, key)
    77  	}
    78  
    79  	contains := func(keys []K, key K) bool {
    80  		for _, k := range keys {
    81  			if k == key {
    82  				return true
    83  			}
    84  		}
    85  		return false
    86  	}
    87  
    88  	t.lock.RLock()
    89  	defer t.lock.RUnlock()
    90  	values := make([]V, 0, t.m.Len())
    91  	t.m.Foreach(func(i uint64, key K, val V) bool {
    92  		if len(realKeys) > 0 && contains(realKeys, key) {
    93  			values = append(values, val)
    94  		} else if len(realKeys) == 0 {
    95  			values = append(values, val)
    96  		}
    97  		return true
    98  	})
    99  	return values
   100  }
   101  
   102  func (t *threadSafeMap[K, V]) Purge() error {
   103  	t.lock.Lock()
   104  	defer t.lock.Unlock()
   105  
   106  	var merr error
   107  	if t.isClosableItem {
   108  		t.m.Foreach(func(i uint64, key K, val V) bool {
   109  			if reflect.ValueOf(val).IsNil() {
   110  				return true
   111  			}
   112  			typ := reflect.TypeOf(val)
   113  			if typ.Implements(reflect.TypeOf((*io.Closer)(nil)).Elem()) {
   114  				vals := reflect.ValueOf(val).MethodByName("Close").Call([]reflect.Value{})
   115  				if len(vals) > 0 && !vals[0].IsNil() {
   116  					intf := vals[0].Elem().Interface()
   117  					switch intf.(type) {
   118  					case error:
   119  						if err := intf.(error); err != nil {
   120  							merr = multierr.Append(merr, err) // FIXME: memory leak?
   121  						}
   122  					}
   123  				}
   124  			}
   125  			return true
   126  		})
   127  	}
   128  	t.m = nil
   129  	return infra.WrapErrorStack(merr)
   130  }
   131  
   132  type ThreadSafeMapOption[K comparable, V any] func(*threadSafeMap[K, V]) error
   133  
   134  func NewThreadSafeMap[K comparable, V any](opts ...ThreadSafeMapOption[K, V]) ThreadSafeStorer[K, V] {
   135  	tsm := &threadSafeMap[K, V]{}
   136  	for _, opt := range opts {
   137  		if err := opt(tsm); err != nil {
   138  			panic(err)
   139  		}
   140  	}
   141  	if tsm.initCap == 0 {
   142  		tsm.initCap = 1024
   143  	}
   144  	tsm.m = newSwissMap[K, V](tsm.initCap)
   145  	return tsm
   146  }
   147  
   148  func WithThreadSafeMapInitCap[K comparable, V any](capacity uint32) ThreadSafeMapOption[K, V] {
   149  	return func(tsm *threadSafeMap[K, V]) error {
   150  		if capacity == 0 {
   151  			capacity = 1024
   152  		}
   153  		tsm.initCap = capacity
   154  		return nil
   155  	}
   156  }
   157  
   158  func WithThreadSafeMapCloseableItemCheck[K comparable, V any]() ThreadSafeMapOption[K, V] {
   159  	return func(tsm *threadSafeMap[K, V]) error {
   160  		nilT := new(V)
   161  		if !reflect.ValueOf(nilT).IsNil() {
   162  			if reflect.TypeOf(nilT).Implements(reflect.TypeOf((*io.Closer)(nil)).Elem()) {
   163  				tsm.isClosableItem = true
   164  			}
   165  		} else {
   166  			_nilT := *new(V)
   167  			if reflect.TypeOf(_nilT).Implements(reflect.TypeOf((*io.Closer)(nil)).Elem()) {
   168  				tsm.isClosableItem = true
   169  			}
   170  		}
   171  		return nil
   172  	}
   173  }