github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/device/pools.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package device 7 8 import ( 9 "sync" 10 11 "github.com/sagernet/sing/common/atomic" 12 ) 13 14 type WaitPool struct { 15 pool sync.Pool 16 cond sync.Cond 17 lock sync.Mutex 18 count atomic.Uint32 19 max uint32 20 } 21 22 func NewWaitPool(max uint32, new func() any) *WaitPool { 23 p := &WaitPool{pool: sync.Pool{New: new}, max: max} 24 p.cond = sync.Cond{L: &p.lock} 25 return p 26 } 27 28 func (p *WaitPool) Get() any { 29 if p.max != 0 { 30 p.lock.Lock() 31 for p.count.Load() >= p.max { 32 p.cond.Wait() 33 } 34 p.count.Add(1) 35 p.lock.Unlock() 36 } 37 return p.pool.Get() 38 } 39 40 func (p *WaitPool) Put(x any) { 41 p.pool.Put(x) 42 if p.max == 0 { 43 return 44 } 45 p.count.Add(^uint32(0)) 46 p.cond.Signal() 47 } 48 49 func (device *Device) PopulatePools() { 50 device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { 51 s := make([]*QueueInboundElement, 0, device.BatchSize()) 52 return &QueueInboundElementsContainer{elems: s} 53 }) 54 device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { 55 s := make([]*QueueOutboundElement, 0, device.BatchSize()) 56 return &QueueOutboundElementsContainer{elems: s} 57 }) 58 device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { 59 return new([MaxMessageSize]byte) 60 }) 61 device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 62 return new(QueueInboundElement) 63 }) 64 device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { 65 return new(QueueOutboundElement) 66 }) 67 } 68 69 func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { 70 c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) 71 c.Mutex = sync.Mutex{} 72 return c 73 } 74 75 func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { 76 for i := range c.elems { 77 c.elems[i] = nil 78 } 79 c.elems = c.elems[:0] 80 device.pool.inboundElementsContainer.Put(c) 81 } 82 83 func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { 84 c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) 85 c.Mutex = sync.Mutex{} 86 return c 87 } 88 89 func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { 90 for i := range c.elems { 91 c.elems[i] = nil 92 } 93 c.elems = c.elems[:0] 94 device.pool.outboundElementsContainer.Put(c) 95 } 96 97 func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { 98 return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) 99 } 100 101 func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { 102 device.pool.messageBuffers.Put(msg) 103 } 104 105 func (device *Device) GetInboundElement() *QueueInboundElement { 106 return device.pool.inboundElements.Get().(*QueueInboundElement) 107 } 108 109 func (device *Device) PutInboundElement(elem *QueueInboundElement) { 110 elem.clearPointers() 111 device.pool.inboundElements.Put(elem) 112 } 113 114 func (device *Device) GetOutboundElement() *QueueOutboundElement { 115 return device.pool.outboundElements.Get().(*QueueOutboundElement) 116 } 117 118 func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { 119 elem.clearPointers() 120 device.pool.outboundElements.Put(elem) 121 }