github.com/reusee/pr2@v0.0.0-20230630035947-72a20ff5e864/pool.go (about)

     1  package pr2
     2  
     3  import (
     4  	"sync"
     5  	"sync/atomic"
     6  	_ "unsafe"
     7  )
     8  
// Pool hands out values of type T from a fixed-size batch of
// reference-counted elements. Use NewPool to construct one.
type Pool[T any] struct {
	// l serializes allocElems so that only one new batch is built at a time.
	l sync.Mutex
	// newFunc creates the value stored in each element of a batch.
	newFunc func() T
	// elems points at the current batch; allocElems replaces it atomically.
	elems atomic.Pointer[[]_PoolElem[T]]
	// capacity is the number of elements in every batch, and the modulus
	// used by GetRC when picking a random slot.
	capacity uint32
}
    15  
// _PoolElem is one slot of a pool batch: a value plus its reference count
// and the closures that manipulate that count.
type _PoolElem[T any] struct {
	// refs is the reference count; GetRC claims a slot by swapping it 0 -> 1.
	refs atomic.Int32
	// put drops one reference and reports whether the count reached zero.
	// It panics with "bad put" if the count becomes negative.
	put func() bool
	// incRef adds one reference to the slot.
	incRef func()
	// value is the pooled value produced by the pool's newFunc.
	value T
}
    22  
    23  func NewPool[T any](
    24  	capacity uint32,
    25  	newFunc func() T,
    26  ) *Pool[T] {
    27  	pool := &Pool[T]{
    28  		capacity: capacity,
    29  		newFunc:  newFunc,
    30  	}
    31  	pool.allocElems(nil)
    32  	return pool
    33  }
    34  
    35  func (p *Pool[T]) allocElems(old *[]_PoolElem[T]) {
    36  	p.l.Lock()
    37  	defer p.l.Unlock()
    38  	if old != nil && p.elems.Load() != old {
    39  		// refreshed
    40  		return
    41  	}
    42  	elems := make([]_PoolElem[T], p.capacity)
    43  	for i := uint32(0); i < p.capacity; i++ {
    44  		i := i
    45  		ptr := p.newFunc()
    46  		elems[i] = _PoolElem[T]{
    47  			value: ptr,
    48  			put: func() bool {
    49  				if c := elems[i].refs.Add(-1); c == 0 {
    50  					return true
    51  				} else if c < 0 {
    52  					panic("bad put")
    53  				}
    54  				return false
    55  			},
    56  			incRef: func() {
    57  				elems[i].refs.Add(1)
    58  			},
    59  		}
    60  	}
    61  	p.elems.Store(&elems)
    62  }
    63  
    64  func (p *Pool[T]) Get(ptr *T) (put func() bool) {
    65  	put, _ = p.GetRC(ptr)
    66  	return
    67  }
    68  
    69  func (p *Pool[T]) GetRC(ptr *T) (
    70  	put func() bool,
    71  	incRef func(),
    72  ) {
    73  
    74  	for {
    75  		cur := p.elems.Load()
    76  		elems := *cur
    77  		for i := 0; i < 16; i++ {
    78  			idx := fastrand() % p.capacity
    79  			if elems[idx].refs.CompareAndSwap(0, 1) {
    80  				*ptr = elems[idx].value
    81  				put = elems[idx].put
    82  				incRef = elems[idx].incRef
    83  				return
    84  			}
    85  		}
    86  
    87  		p.allocElems(cur)
    88  	}
    89  
    90  }
    91  
    92  func (p *Pool[T]) Getter() (
    93  	get func(*T),
    94  	putAll func(),
    95  ) {
    96  
    97  	var l sync.Mutex
    98  	var curPut func()
    99  
   100  	get = func(ptr *T) {
   101  		put := p.Get(ptr)
   102  		l.Lock()
   103  		if curPut != nil {
   104  			cur := curPut
   105  			newPut := func() {
   106  				put()
   107  				cur()
   108  			}
   109  			curPut = newPut
   110  		} else {
   111  			curPut = func() {
   112  				put()
   113  			}
   114  		}
   115  		l.Unlock()
   116  	}
   117  
   118  	putAll = func() {
   119  		l.Lock()
   120  		put := curPut
   121  		curPut = nil
   122  		l.Unlock()
   123  		put()
   124  	}
   125  
   126  	return
   127  }
   128  
// fastrand returns a uint32 used by GetRC to pick a random slot index.
// It has no body here: the go:linkname directive below binds it to
// runtime.fastrand.
// NOTE(review): the blank "unsafe" import in this file is presumably what
// enables the directive — confirm against the compiler documentation.
//
//go:linkname fastrand runtime.fastrand
func fastrand() uint32