github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/device/pools.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"sync"
    10  	"sync/atomic"
    11  )
    12  
    13  type WaitPool struct {
    14  	pool  sync.Pool
    15  	cond  sync.Cond
    16  	lock  sync.Mutex
    17  	count uint32
    18  	max   uint32
    19  }
    20  
    21  func NewWaitPool(max uint32, new func() interface{}) *WaitPool {
    22  	p := &WaitPool{pool: sync.Pool{New: new}, max: max}
    23  	p.cond = sync.Cond{L: &p.lock}
    24  	return p
    25  }
    26  
    27  func (p *WaitPool) Get() interface{} {
    28  	if p.max != 0 {
    29  		p.lock.Lock()
    30  		for atomic.LoadUint32(&p.count) >= p.max {
    31  			p.cond.Wait()
    32  		}
    33  		atomic.AddUint32(&p.count, 1)
    34  		p.lock.Unlock()
    35  	}
    36  	return p.pool.Get()
    37  }
    38  
    39  func (p *WaitPool) Put(x interface{}) {
    40  	p.pool.Put(x)
    41  	if p.max == 0 {
    42  		return
    43  	}
    44  	atomic.AddUint32(&p.count, ^uint32(0))
    45  	p.cond.Signal()
    46  }
    47  
    48  func (device *Device) PopulatePools() {
    49  	device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
    50  		return new([MaxMessageSize]byte)
    51  	})
    52  	device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
    53  		return new(QueueInboundElement)
    54  	})
    55  	device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
    56  		return new(QueueOutboundElement)
    57  	})
    58  }
    59  
    60  func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
    61  	return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
    62  }
    63  
    64  func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
    65  	device.pool.messageBuffers.Put(msg)
    66  }
    67  
    68  func (device *Device) GetInboundElement() *QueueInboundElement {
    69  	return device.pool.inboundElements.Get().(*QueueInboundElement)
    70  }
    71  
    72  func (device *Device) PutInboundElement(elem *QueueInboundElement) {
    73  	elem.clearPointers()
    74  	device.pool.inboundElements.Put(elem)
    75  }
    76  
    77  func (device *Device) GetOutboundElement() *QueueOutboundElement {
    78  	return device.pool.outboundElements.Get().(*QueueOutboundElement)
    79  }
    80  
    81  func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
    82  	elem.clearPointers()
    83  	device.pool.outboundElements.Put(elem)
    84  }