github.com/timandy/routine@v1.1.4-0.20240507073150-e4a3e1fe2ba5/thread_local_inheritable.go (about)

     1  package routine
     2  
     3  import "sync/atomic"
     4  
     5  var inheritableThreadLocalIndex int32 = -1
     6  
     7  func nextInheritableThreadLocalIndex() int {
     8  	index := atomic.AddInt32(&inheritableThreadLocalIndex, 1)
     9  	if index < 0 {
    10  		atomic.AddInt32(&inheritableThreadLocalIndex, -1)
    11  		panic("too many inheritable-thread-local indexed variables")
    12  	}
    13  	return int(index)
    14  }
    15  
// inheritableThreadLocal is a thread-local variable whose values are kept in
// the owning thread's inheritableThreadLocals map rather than in any other
// per-thread map.
//
// NOTE(review): the name suggests these values are copied into goroutines
// started from the owning one. That copying is not done in this file, so
// confirm it against the code that creates child threads.
type inheritableThreadLocal[T any] struct {
	index    int         // slot in the per-thread map; presumably reserved via nextInheritableThreadLocalIndex by the constructor — confirm
	supplier Supplier[T] // optional producer of the initial value; when nil, T's zero value is used
}
    20  
    21  func (tls *inheritableThreadLocal[T]) Get() T {
    22  	t := currentThread(true)
    23  	mp := tls.getMap(t)
    24  	if mp != nil {
    25  		v := mp.get(tls.index)
    26  		if v != unset {
    27  			return entryValue[T](v)
    28  		}
    29  	}
    30  	return tls.setInitialValue(t)
    31  }
    32  
    33  func (tls *inheritableThreadLocal[T]) Set(value T) {
    34  	t := currentThread(true)
    35  	mp := tls.getMap(t)
    36  	if mp != nil {
    37  		mp.set(tls.index, entry(value))
    38  	} else {
    39  		tls.createMap(t, value)
    40  	}
    41  }
    42  
    43  func (tls *inheritableThreadLocal[T]) Remove() {
    44  	t := currentThread(false)
    45  	if t == nil {
    46  		return
    47  	}
    48  	mp := tls.getMap(t)
    49  	if mp != nil {
    50  		mp.remove(tls.index)
    51  	}
    52  }
    53  
    54  func (tls *inheritableThreadLocal[T]) getMap(t *thread) *threadLocalMap {
    55  	return t.inheritableThreadLocals
    56  }
    57  
    58  func (tls *inheritableThreadLocal[T]) createMap(t *thread, firstValue T) {
    59  	mp := &threadLocalMap{}
    60  	mp.set(tls.index, entry(firstValue))
    61  	t.inheritableThreadLocals = mp
    62  }
    63  
    64  func (tls *inheritableThreadLocal[T]) setInitialValue(t *thread) T {
    65  	value := tls.initialValue()
    66  	mp := tls.getMap(t)
    67  	if mp != nil {
    68  		mp.set(tls.index, entry(value))
    69  	} else {
    70  		tls.createMap(t, value)
    71  	}
    72  	return value
    73  }
    74  
    75  func (tls *inheritableThreadLocal[T]) initialValue() T {
    76  	if tls.supplier == nil {
    77  		var defaultValue T
    78  		return defaultValue
    79  	}
    80  	return tls.supplier()
    81  }