github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/device/pools.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"sync"
    10  	"sync/atomic"
    11  )
    12  
    13  type WaitPool struct {
    14  	pool  sync.Pool
    15  	cond  sync.Cond
    16  	lock  sync.Mutex
    17  	count atomic.Uint32
    18  	max   uint32
    19  }
    20  
    21  func NewWaitPool(max uint32, new func() any) *WaitPool {
    22  	p := &WaitPool{pool: sync.Pool{New: new}, max: max}
    23  	p.cond = sync.Cond{L: &p.lock}
    24  	return p
    25  }
    26  
    27  func (p *WaitPool) Get() any {
    28  	if p.max != 0 {
    29  		p.lock.Lock()
    30  		for p.count.Load() >= p.max {
    31  			p.cond.Wait()
    32  		}
    33  		p.count.Add(1)
    34  		p.lock.Unlock()
    35  	}
    36  	return p.pool.Get()
    37  }
    38  
    39  func (p *WaitPool) Put(x any) {
    40  	p.pool.Put(x)
    41  	if p.max == 0 {
    42  		return
    43  	}
    44  	p.count.Add(^uint32(0))
    45  	p.cond.Signal()
    46  }
    47  
    48  func (device *Device) PopulatePools() {
    49  	device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    50  		s := make([]*QueueOutboundElement, 0, device.BatchSize())
    51  		return &s
    52  	})
    53  	device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    54  		s := make([]*QueueInboundElement, 0, device.BatchSize())
    55  		return &s
    56  	})
    57  	device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    58  		return new([MaxMessageSize]byte)
    59  	})
    60  	device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    61  		return new(QueueInboundElement)
    62  	})
    63  	device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    64  		return new(QueueOutboundElement)
    65  	})
    66  }
    67  
    68  func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement {
    69  	return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement)
    70  }
    71  
    72  func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) {
    73  	for i := range *s {
    74  		(*s)[i] = nil
    75  	}
    76  	*s = (*s)[:0]
    77  	device.pool.outboundElementsSlice.Put(s)
    78  }
    79  
    80  func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement {
    81  	return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement)
    82  }
    83  
    84  func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) {
    85  	for i := range *s {
    86  		(*s)[i] = nil
    87  	}
    88  	*s = (*s)[:0]
    89  	device.pool.inboundElementsSlice.Put(s)
    90  }
    91  
    92  func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
    93  	return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
    94  }
    95  
    96  func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
    97  	device.pool.messageBuffers.Put(msg)
    98  }
    99  
   100  func (device *Device) GetInboundElement() *QueueInboundElement {
   101  	return device.pool.inboundElements.Get().(*QueueInboundElement)
   102  }
   103  
   104  func (device *Device) PutInboundElement(elem *QueueInboundElement) {
   105  	elem.clearPointers()
   106  	device.pool.inboundElements.Put(elem)
   107  }
   108  
   109  func (device *Device) GetOutboundElement() *QueueOutboundElement {
   110  	return device.pool.outboundElements.Get().(*QueueOutboundElement)
   111  }
   112  
   113  func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
   114  	elem.clearPointers()
   115  	device.pool.outboundElements.Put(elem)
   116  }