github.com/amnezia-vpn/amnezia-wg@v0.1.8/device/pools.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"sync"
    10  	"sync/atomic"
    11  )
    12  
    13  type WaitPool struct {
    14  	pool  sync.Pool
    15  	cond  sync.Cond
    16  	lock  sync.Mutex
    17  	count atomic.Uint32
    18  	max   uint32
    19  }
    20  
    21  func NewWaitPool(max uint32, new func() any) *WaitPool {
    22  	p := &WaitPool{pool: sync.Pool{New: new}, max: max}
    23  	p.cond = sync.Cond{L: &p.lock}
    24  	return p
    25  }
    26  
    27  func (p *WaitPool) Get() any {
    28  	if p.max != 0 {
    29  		p.lock.Lock()
    30  		for p.count.Load() >= p.max {
    31  			p.cond.Wait()
    32  		}
    33  		p.count.Add(1)
    34  		p.lock.Unlock()
    35  	}
    36  	return p.pool.Get()
    37  }
    38  
    39  func (p *WaitPool) Put(x any) {
    40  	p.pool.Put(x)
    41  	if p.max == 0 {
    42  		return
    43  	}
    44  	p.count.Add(^uint32(0))
    45  	p.cond.Signal()
    46  }
    47  
    48  func (device *Device) PopulatePools() {
    49  	device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    50  		s := make([]*QueueInboundElement, 0, device.BatchSize())
    51  		return &QueueInboundElementsContainer{elems: s}
    52  	})
    53  	device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    54  		s := make([]*QueueOutboundElement, 0, device.BatchSize())
    55  		return &QueueOutboundElementsContainer{elems: s}
    56  	})
    57  	device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    58  		return new([MaxMessageSize]byte)
    59  	})
    60  	device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    61  		return new(QueueInboundElement)
    62  	})
    63  	device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    64  		return new(QueueOutboundElement)
    65  	})
    66  }
    67  
    68  func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
    69  	c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
    70  	c.Mutex = sync.Mutex{}
    71  	return c
    72  }
    73  
    74  func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
    75  	for i := range c.elems {
    76  		c.elems[i] = nil
    77  	}
    78  	c.elems = c.elems[:0]
    79  	device.pool.inboundElementsContainer.Put(c)
    80  }
    81  
    82  func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
    83  	c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
    84  	c.Mutex = sync.Mutex{}
    85  	return c
    86  }
    87  
    88  func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
    89  	for i := range c.elems {
    90  		c.elems[i] = nil
    91  	}
    92  	c.elems = c.elems[:0]
    93  	device.pool.outboundElementsContainer.Put(c)
    94  }
    95  
    96  func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
    97  	return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
    98  }
    99  
   100  func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
   101  	device.pool.messageBuffers.Put(msg)
   102  }
   103  
   104  func (device *Device) GetInboundElement() *QueueInboundElement {
   105  	return device.pool.inboundElements.Get().(*QueueInboundElement)
   106  }
   107  
   108  func (device *Device) PutInboundElement(elem *QueueInboundElement) {
   109  	elem.clearPointers()
   110  	device.pool.inboundElements.Put(elem)
   111  }
   112  
   113  func (device *Device) GetOutboundElement() *QueueOutboundElement {
   114  	return device.pool.outboundElements.Get().(*QueueOutboundElement)
   115  }
   116  
   117  func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
   118  	elem.clearPointers()
   119  	device.pool.outboundElements.Put(elem)
   120  }