github.com/amnezia-vpn/amneziawg-go@v0.2.8/device/pools.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "sync" 10 "sync/atomic" 11 ) 12 13 type WaitPool struct { 14 pool sync.Pool 15 cond sync.Cond 16 lock sync.Mutex 17 count atomic.Uint32 18 max uint32 19 } 20 21 func NewWaitPool(max uint32, new func() any) *WaitPool { 22 p := &WaitPool{pool: sync.Pool{New: new}, max: max} 23 p.cond = sync.Cond{L: &p.lock} 24 return p 25 } 26 27 func (p *WaitPool) Get() any { 28 if p.max != 0 { 29 p.lock.Lock() 30 for p.count.Load() >= p.max { 31 p.cond.Wait() 32 } 33 p.count.Add(1) 34 p.lock.Unlock() 35 } 36 return p.pool.Get() 37 } 38 39 func (p *WaitPool) Put(x any) { 40 p.pool.Put(x) 41 if p.max == 0 { 42 return 43 } 44 p.count.Add(^uint32(0)) 45 p.cond.Signal() 46 } 47 48 func (device *Device) PopulatePools() { 49 device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { 50 s := make([]*QueueInboundElement, 0, device.BatchSize()) 51 return &QueueInboundElementsContainer{elems: s} 52 }) 53 device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { 54 s := make([]*QueueOutboundElement, 0, device.BatchSize()) 55 return &QueueOutboundElementsContainer{elems: s} 56 }) 57 device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { 58 return new([MaxMessageSize]byte) 59 }) 60 device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 61 return new(QueueInboundElement) 62 }) 63 device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 64 return new(QueueOutboundElement) 65 }) 66 } 67 68 func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { 69 c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) 70 c.Mutex = sync.Mutex{} 71 return c 72 } 73 74 func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { 75 for i := range c.elems { 76 c.elems[i] = nil 77 } 78 c.elems = c.elems[:0] 79 device.pool.inboundElementsContainer.Put(c) 80 } 81 82 func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { 83 c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) 84 c.Mutex = sync.Mutex{} 85 return c 86 } 87 88 func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { 89 for i := range c.elems { 90 c.elems[i] = nil 91 } 92 c.elems = c.elems[:0] 93 device.pool.outboundElementsContainer.Put(c) 94 } 95 96 func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { 97 return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) 98 } 99 100 func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { 101 device.pool.messageBuffers.Put(msg) 102 } 103 104 func (device *Device) GetInboundElement() *QueueInboundElement { 105 return device.pool.inboundElements.Get().(*QueueInboundElement) 106 } 107 108 func (device *Device) PutInboundElement(elem *QueueInboundElement) { 109 elem.clearPointers() 110 device.pool.inboundElements.Put(elem) 111 } 112 113 func (device *Device) GetOutboundElement() *QueueOutboundElement { 114 return device.pool.outboundElements.Get().(*QueueOutboundElement) 115 } 116 117 func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { 118 elem.clearPointers() 119 device.pool.outboundElements.Put(elem) 120 }