github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/pool/pool.go (about)

     1  package pool
     2  
     3  import (
     4  	"container/list"
     5  	"context"
     6  	"fmt"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/jonboulle/clockwork"
    11  
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/retry"
    15  )
    16  
    17  type (
    18  	itemInfo struct {
    19  		idle    *list.Element
    20  		touched time.Time
    21  	}
    22  	Pool[T any] struct {
    23  		clock clockwork.Clock
    24  
    25  		createItem func(ctx context.Context, onClose func(item *T)) (*T, error)
    26  		deleteItem func(ctx context.Context, item *T) error
    27  		checkErr   func(err error) bool
    28  
    29  		mu                xsync.Mutex
    30  		index             map[*T]itemInfo
    31  		createInProgress  int
    32  		limit             int        // Upper bound for Pool size.
    33  		idle              *list.List // list<*T>
    34  		waitQ             *list.List // list<*chan *T>
    35  		waitChPool        sync.Pool
    36  		testHookGetWaitCh func() // nil except some tests.
    37  		wg                sync.WaitGroup
    38  		done              chan struct{}
    39  	}
    40  	option[T any] func(p *Pool[T])
    41  )
    42  
    43  func New[T any](
    44  	limit int,
    45  	createItem func(ctx context.Context, onClose func(item *T)) (*T, error),
    46  	deleteItem func(ctx context.Context, item *T) error,
    47  	checkErr func(err error) bool,
    48  	opts ...option[T],
    49  ) *Pool[T] {
    50  	p := &Pool[T]{
    51  		clock:      clockwork.NewRealClock(),
    52  		createItem: createItem,
    53  		deleteItem: deleteItem,
    54  		checkErr:   checkErr,
    55  		index:      make(map[*T]itemInfo),
    56  		idle:       list.New(),
    57  		waitQ:      list.New(),
    58  		limit:      limit,
    59  		waitChPool: sync.Pool{
    60  			New: func() interface{} {
    61  				ch := make(chan *T)
    62  
    63  				return &ch
    64  			},
    65  		},
    66  		done: make(chan struct{}),
    67  	}
    68  	for _, opt := range opts {
    69  		opt(p)
    70  	}
    71  
    72  	return p
    73  }
    74  
    75  func (p *Pool[T]) try(ctx context.Context, f func(ctx context.Context, item *T) error) error {
    76  	item, err := p.get(ctx)
    77  	if err != nil {
    78  		return xerrors.WithStackTrace(err)
    79  	}
    80  
    81  	defer func() {
    82  		select {
    83  		case <-p.done:
    84  			_ = p.deleteItem(ctx, item)
    85  		default:
    86  			p.mu.Lock()
    87  			defer p.mu.Unlock()
    88  
    89  			if p.idle.Len() >= p.limit {
    90  				_ = p.deleteItem(ctx, item)
    91  			}
    92  
    93  			if !p.notify(item) {
    94  				p.pushIdle(item, p.clock.Now())
    95  			}
    96  		}
    97  	}()
    98  
    99  	if err = f(ctx, item); err != nil {
   100  		if p.checkErr(err) {
   101  			_ = p.deleteItem(ctx, item)
   102  		}
   103  
   104  		return xerrors.WithStackTrace(err)
   105  	}
   106  
   107  	return nil
   108  }
   109  
   110  func (p *Pool[T]) With(ctx context.Context, f func(ctx context.Context, item *T) error) error {
   111  	err := retry.Retry(ctx, func(ctx context.Context) error {
   112  		err := p.try(ctx, f)
   113  		if err != nil {
   114  			return xerrors.WithStackTrace(err)
   115  		}
   116  
   117  		return nil
   118  	})
   119  	if err != nil {
   120  		return xerrors.WithStackTrace(err)
   121  	}
   122  
   123  	return nil
   124  }
   125  
   126  func (p *Pool[T]) newItem(ctx context.Context) (item *T, err error) {
   127  	select {
   128  	case <-p.done:
   129  		return nil, xerrors.WithStackTrace(errClosedPool)
   130  	default:
   131  		// pre-check the Client size
   132  		var enoughSpace bool
   133  		p.mu.WithLock(func() {
   134  			enoughSpace = p.createInProgress+len(p.index) < p.limit
   135  			if enoughSpace {
   136  				p.createInProgress++
   137  			}
   138  		})
   139  
   140  		if !enoughSpace {
   141  			return nil, xerrors.WithStackTrace(errPoolOverflow)
   142  		}
   143  
   144  		defer func() {
   145  			p.mu.WithLock(func() {
   146  				p.createInProgress--
   147  			})
   148  		}()
   149  
   150  		item, err = p.createItem(ctx, p.removeItem)
   151  		if err != nil {
   152  			return nil, xerrors.WithStackTrace(err)
   153  		}
   154  
   155  		return item, nil
   156  	}
   157  }
   158  
   159  func (p *Pool[T]) removeItem(item *T) {
   160  	p.mu.WithLock(func() {
   161  		info, has := p.index[item]
   162  		if !has {
   163  			return
   164  		}
   165  
   166  		delete(p.index, item)
   167  
   168  		select {
   169  		case <-p.done:
   170  		default:
   171  			p.notify(nil)
   172  		}
   173  
   174  		if info.idle != nil {
   175  			p.idle.Remove(info.idle)
   176  		}
   177  	})
   178  }
   179  
   180  func (p *Pool[T]) get(ctx context.Context) (item *T, err error) {
   181  	for {
   182  		select {
   183  		case <-ctx.Done():
   184  			return nil, xerrors.WithStackTrace(ctx.Err())
   185  		case <-p.done:
   186  			return nil, xerrors.WithStackTrace(errClosedPool)
   187  		default:
   188  			// First, we try to get item from idle
   189  			p.mu.WithLock(func() {
   190  				item = p.removeFirstIdle()
   191  			})
   192  			if item != nil {
   193  				return item, nil
   194  			}
   195  
   196  			// Second, we try to create item.
   197  			item, _ = p.newItem(ctx)
   198  			if item != nil {
   199  				return item, nil
   200  			}
   201  
   202  			// Third, we try to wait for a touched item - Pool is full.
   203  			//
   204  			// This should be done only if number of currently waiting goroutines
   205  			// are less than maximum amount of touched item. That is, we want to
   206  			// be fair here and not to lock more goroutines than we could ship
   207  			// item to.
   208  			item, _ = p.waitFromCh(ctx)
   209  			if item != nil {
   210  				return item, nil
   211  			}
   212  		}
   213  	}
   214  }
   215  
   216  func (p *Pool[T]) waitFromCh(ctx context.Context) (item *T, err error) {
   217  	var (
   218  		ch      *chan *T
   219  		element *list.Element // Element in the wait queue.
   220  		ok      bool
   221  	)
   222  
   223  	p.mu.WithLock(func() {
   224  		ch = p.getWaitCh()
   225  		element = p.waitQ.PushBack(ch)
   226  	})
   227  
   228  	select {
   229  	case <-ctx.Done():
   230  		p.mu.WithLock(func() {
   231  			p.waitQ.Remove(element)
   232  		})
   233  
   234  		return nil, xerrors.WithStackTrace(ctx.Err())
   235  
   236  	case <-p.done:
   237  		p.mu.WithLock(func() {
   238  			p.waitQ.Remove(element)
   239  		})
   240  
   241  		return nil, xerrors.WithStackTrace(errClosedPool)
   242  
   243  	case item, ok = <-*ch:
   244  		// Note that race may occur and some goroutine may try to write
   245  		// item into channel after it was enqueued but before it being
   246  		// read here. In that case we will receive nil here and will retry.
   247  		//
   248  		// The same way will work when some item become deleted - the
   249  		// nil value will be sent into the channel.
   250  		if ok {
   251  			// Put only filled and not closed channel back to the Pool.
   252  			// That is, we need to avoid races on filling reused channel
   253  			// for the next waiter – item could be lost for a long time.
   254  			p.putWaitCh(ch)
   255  		}
   256  
   257  		return item, nil
   258  	}
   259  }
   260  
   261  // Close deletes all stored items inside Pool.
   262  // It also stops all underlying timers and goroutines.
   263  // It returns first error occurred during stale items' deletion.
   264  // Note that even on error it calls Close() on each item.
   265  func (p *Pool[T]) Close(ctx context.Context) (err error) {
   266  	p.mu.WithLock(func() {
   267  		select {
   268  		case <-p.done:
   269  			return
   270  
   271  		default:
   272  			close(p.done)
   273  
   274  			p.limit = 0
   275  
   276  			for element := p.waitQ.Front(); element != nil; element = element.Next() {
   277  				ch := element.Value.(*chan *T)
   278  				close(*ch)
   279  			}
   280  
   281  			for element := p.idle.Front(); element != nil; element = element.Next() {
   282  				item := element.Value.(*T)
   283  				p.wg.Add(1)
   284  				go func() {
   285  					defer p.wg.Done()
   286  					_ = p.deleteItem(ctx, item)
   287  				}()
   288  			}
   289  		}
   290  	})
   291  
   292  	p.wg.Wait()
   293  
   294  	return nil
   295  }
   296  
   297  // getWaitCh returns pointer to a channel of items.
   298  //
   299  // Note that returning a pointer reduces allocations on sync.Pool usage –
   300  // sync.Pool.Get() returns empty interface, which leads to allocation for
   301  // non-pointer values.
   302  func (p *Pool[T]) getWaitCh() *chan *T { //nolint:gocritic
   303  	if p.testHookGetWaitCh != nil {
   304  		p.testHookGetWaitCh()
   305  	}
   306  	ch := p.waitChPool.Get()
   307  	s, ok := ch.(*chan *T)
   308  	if !ok {
   309  		panic(fmt.Sprintf("%T is not a chan of items", ch))
   310  	}
   311  
   312  	return s
   313  }
   314  
   315  // putWaitCh receives pointer to a channel and makes it available for further
   316  // use.
   317  // Note that ch MUST NOT be owned by any goroutine at the call moment and ch
   318  // MUST NOT contain any value.
   319  func (p *Pool[T]) putWaitCh(ch *chan *T) { //nolint:gocritic
   320  	p.waitChPool.Put(ch)
   321  }
   322  
   323  // c.mu must be held.
   324  func (p *Pool[T]) peekFirstIdle() (item *T, touched time.Time) {
   325  	element := p.idle.Front()
   326  	if element == nil {
   327  		return
   328  	}
   329  	item = element.Value.(*T)
   330  	info, has := p.index[item]
   331  	if !has || element != info.idle {
   332  		panic("inconsistent item in pool index")
   333  	}
   334  
   335  	return item, info.touched
   336  }
   337  
   338  // removes first item from idle and resets the keepAliveCount
   339  // to prevent item from dying in the internalPoolGC after it was returned
   340  // to be used only in outgoing functions that make item busy.
   341  // c.mu must be held.
   342  func (p *Pool[T]) removeFirstIdle() *T {
   343  	item, _ := p.peekFirstIdle()
   344  	if item != nil {
   345  		p.index[item] = p.removeIdle(item)
   346  	}
   347  
   348  	return item
   349  }
   350  
   351  // c.mu must be held.
   352  func (p *Pool[T]) notify(item *T) (notified bool) {
   353  	for element := p.waitQ.Front(); element != nil; element = p.waitQ.Front() {
   354  		// Some goroutine is waiting for a item.
   355  		//
   356  		// It could be in this states:
   357  		//   1) Reached the select code and awaiting for a value in channel.
   358  		//   2) Reached the select code but already in branch of deadline
   359  		//   cancellation. In this case it is locked on p.mu.Lock().
   360  		//   3) Not reached the select code and thus not reading yet from the
   361  		//   channel.
   362  		//
   363  		// For cases (2) and (3) we close the channel to signal that goroutine
   364  		// missed something and may want to retry (especially for case (3)).
   365  		//
   366  		// After that we taking a next waiter and repeat the same.
   367  		ch := p.waitQ.Remove(element).(*chan *T)
   368  		select {
   369  		case *ch <- item:
   370  			// Case (1).
   371  			return true
   372  
   373  		case <-p.done:
   374  			// Case (2) or (3).
   375  			close(*ch)
   376  
   377  		default:
   378  			// Case (2) or (3).
   379  			close(*ch)
   380  		}
   381  	}
   382  
   383  	return false
   384  }
   385  
   386  // c.mu must be held.
   387  func (p *Pool[T]) removeIdle(item *T) itemInfo {
   388  	info := p.index[item]
   389  	p.idle.Remove(info.idle)
   390  	info.idle = nil
   391  	p.index[item] = info
   392  
   393  	return info
   394  }
   395  
   396  // c.mu must be held.
   397  func (p *Pool[T]) pushIdle(item *T, now time.Time) {
   398  	p.handlePushIdle(item, now, p.idle.PushBack(item))
   399  }
   400  
   401  // c.mu must be held.
   402  func (p *Pool[T]) handlePushIdle(item *T, now time.Time, element *list.Element) {
   403  	info := p.index[item]
   404  	info.touched = now
   405  	info.idle = element
   406  	p.index[item] = info
   407  }