github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/device/pools.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"sync"
    10  
    11  	"github.com/sagernet/sing/common/atomic"
    12  )
    13  
// WaitPool wraps a sync.Pool and can optionally cap how many objects are
// checked out at once. When max is non-zero, Get blocks while max objects
// are outstanding (obtained via Get and not yet handed back via Put). When
// max is zero, no counting is done and the pool behaves like a plain
// sync.Pool.
type WaitPool struct {
	pool  sync.Pool     // backing storage; its New func allocates on a miss
	cond  sync.Cond     // Get waits on it while the cap is reached; its Locker is &lock
	lock  sync.Mutex    // guards Get's check-and-reserve of count; Locker for cond
	count atomic.Uint32 // outstanding objects (Gets minus Puts); maintained only when max != 0
	max   uint32        // cap on outstanding objects; 0 disables the cap
	// NOTE: Get reads count while holding lock before calling cond.Wait. Per
	// the sync.Cond contract, count must also be modified while holding lock.
	// Otherwise a Signal sent between Get's check and its Wait can be missed.
}
    21  
    22  func NewWaitPool(max uint32, new func() any) *WaitPool {
    23  	p := &WaitPool{pool: sync.Pool{New: new}, max: max}
    24  	p.cond = sync.Cond{L: &p.lock}
    25  	return p
    26  }
    27  
    28  func (p *WaitPool) Get() any {
    29  	if p.max != 0 {
    30  		p.lock.Lock()
    31  		for p.count.Load() >= p.max {
    32  			p.cond.Wait()
    33  		}
    34  		p.count.Add(1)
    35  		p.lock.Unlock()
    36  	}
    37  	return p.pool.Get()
    38  }
    39  
    40  func (p *WaitPool) Put(x any) {
    41  	p.pool.Put(x)
    42  	if p.max == 0 {
    43  		return
    44  	}
    45  	p.count.Add(^uint32(0))
    46  	p.cond.Signal()
    47  }
    48  
    49  func (device *Device) PopulatePools() {
    50  	device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    51  		s := make([]*QueueInboundElement, 0, device.BatchSize())
    52  		return &QueueInboundElementsContainer{elems: s}
    53  	})
    54  	device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    55  		s := make([]*QueueOutboundElement, 0, device.BatchSize())
    56  		return &QueueOutboundElementsContainer{elems: s}
    57  	})
    58  	device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    59  		return new([MaxMessageSize]byte)
    60  	})
    61  	device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    62  		return new(QueueInboundElement)
    63  	})
    64  	device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    65  		return new(QueueOutboundElement)
    66  	})
    67  }
    68  
    69  func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
    70  	c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
    71  	c.Mutex = sync.Mutex{}
    72  	return c
    73  }
    74  
    75  func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
    76  	for i := range c.elems {
    77  		c.elems[i] = nil
    78  	}
    79  	c.elems = c.elems[:0]
    80  	device.pool.inboundElementsContainer.Put(c)
    81  }
    82  
    83  func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
    84  	c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
    85  	c.Mutex = sync.Mutex{}
    86  	return c
    87  }
    88  
    89  func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
    90  	for i := range c.elems {
    91  		c.elems[i] = nil
    92  	}
    93  	c.elems = c.elems[:0]
    94  	device.pool.outboundElementsContainer.Put(c)
    95  }
    96  
    97  func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
    98  	return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
    99  }
   100  
   101  func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
   102  	device.pool.messageBuffers.Put(msg)
   103  }
   104  
   105  func (device *Device) GetInboundElement() *QueueInboundElement {
   106  	return device.pool.inboundElements.Get().(*QueueInboundElement)
   107  }
   108  
   109  func (device *Device) PutInboundElement(elem *QueueInboundElement) {
   110  	elem.clearPointers()
   111  	device.pool.inboundElements.Put(elem)
   112  }
   113  
   114  func (device *Device) GetOutboundElement() *QueueOutboundElement {
   115  	return device.pool.outboundElements.Get().(*QueueOutboundElement)
   116  }
   117  
   118  func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
   119  	elem.clearPointers()
   120  	device.pool.outboundElements.Put(elem)
   121  }