github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/device/pools.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "sync" 10 "sync/atomic" 11 ) 12 13 type WaitPool struct { 14 pool sync.Pool 15 cond sync.Cond 16 lock sync.Mutex 17 count atomic.Uint32 18 max uint32 19 } 20 21 func NewWaitPool(max uint32, new func() any) *WaitPool { 22 p := &WaitPool{pool: sync.Pool{New: new}, max: max} 23 p.cond = sync.Cond{L: &p.lock} 24 return p 25 } 26 27 func (p *WaitPool) Get() any { 28 if p.max != 0 { 29 p.lock.Lock() 30 for p.count.Load() >= p.max { 31 p.cond.Wait() 32 } 33 p.count.Add(1) 34 p.lock.Unlock() 35 } 36 return p.pool.Get() 37 } 38 39 func (p *WaitPool) Put(x any) { 40 p.pool.Put(x) 41 if p.max == 0 { 42 return 43 } 44 p.count.Add(^uint32(0)) 45 p.cond.Signal() 46 } 47 48 func (device *Device) PopulatePools() { 49 device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { 50 s := make([]*QueueOutboundElement, 0, device.BatchSize()) 51 return &s 52 }) 53 device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { 54 s := make([]*QueueInboundElement, 0, device.BatchSize()) 55 return &s 56 }) 57 device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { 58 return new([MaxMessageSize]byte) 59 }) 60 device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 61 return new(QueueInboundElement) 62 }) 63 device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 64 return new(QueueOutboundElement) 65 }) 66 } 67 68 func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement { 69 return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement) 70 } 71 72 func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) { 73 for i := range *s { 74 (*s)[i] = nil 75 } 76 *s = (*s)[:0] 77 device.pool.outboundElementsSlice.Put(s) 78 } 79 80 func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement { 81 return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement) 82 } 83 84 func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) { 85 for i := range *s { 86 (*s)[i] = nil 87 } 88 *s = (*s)[:0] 89 device.pool.inboundElementsSlice.Put(s) 90 } 91 92 func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { 93 return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) 94 } 95 96 func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { 97 device.pool.messageBuffers.Put(msg) 98 } 99 100 func (device *Device) GetInboundElement() *QueueInboundElement { 101 return device.pool.inboundElements.Get().(*QueueInboundElement) 102 } 103 104 func (device *Device) PutInboundElement(elem *QueueInboundElement) { 105 elem.clearPointers() 106 device.pool.inboundElements.Put(elem) 107 } 108 109 func (device *Device) GetOutboundElement() *QueueOutboundElement { 110 return device.pool.outboundElements.Get().(*QueueOutboundElement) 111 } 112 113 func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { 114 elem.clearPointers() 115 device.pool.outboundElements.Put(elem) 116 }