github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/transport/pools.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   *
     9   * Portions of this file are based on code originally from wireguard-go,
    10   *
    11   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
    12   *
    13   * Permission is hereby granted, free of charge, to any person obtaining a copy of
    14   * this software and associated documentation files (the "Software"), to deal in
    15   * the Software without restriction, including without limitation the rights to
    16   * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
    17   * of the Software, and to permit persons to whom the Software is furnished to do
    18   * so, subject to the following conditions:
    19   *
    20   * The above copyright notice and this permission notice shall be included in all
    21   * copies or substantial portions of the Software.
    22   *
    23   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    24   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    25   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    26   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    27   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    28   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    29   * SOFTWARE.
    30   */
    31  
    32  package transport
    33  
    34  import (
    35  	"sync"
    36  	"sync/atomic"
    37  )
    38  
    39  type WaitPool struct {
    40  	pool  sync.Pool
    41  	cond  sync.Cond
    42  	lock  sync.Mutex
    43  	count atomic.Uint32
    44  	max   uint32
    45  }
    46  
    47  func NewWaitPool(max uint32, new func() any) *WaitPool {
    48  	p := &WaitPool{pool: sync.Pool{New: new}, max: max}
    49  	p.cond = sync.Cond{L: &p.lock}
    50  	return p
    51  }
    52  
    53  func (p *WaitPool) Get() any {
    54  	if p.max != 0 {
    55  		p.lock.Lock()
    56  		for p.count.Load() >= p.max {
    57  			p.cond.Wait()
    58  		}
    59  		p.count.Add(1)
    60  		p.lock.Unlock()
    61  	}
    62  	return p.pool.Get()
    63  }
    64  
    65  func (p *WaitPool) Put(x any) {
    66  	p.pool.Put(x)
    67  	if p.max == 0 {
    68  		return
    69  	}
    70  	p.count.Add(^uint32(0))
    71  	p.cond.Signal()
    72  }
    73  
    74  func (transport *Transport) PopulatePools() {
    75  	transport.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    76  		s := make([]*QueueInboundElement, 0, transport.BatchSize())
    77  		return &QueueInboundElementsContainer{elems: s}
    78  	})
    79  	transport.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    80  		s := make([]*QueueOutboundElement, 0, transport.BatchSize())
    81  		return &QueueOutboundElementsContainer{elems: s}
    82  	})
    83  	transport.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    84  		return new([MaxMessageSize]byte)
    85  	})
    86  	transport.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    87  		return new(QueueInboundElement)
    88  	})
    89  	transport.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
    90  		return new(QueueOutboundElement)
    91  	})
    92  }
    93  
    94  func (transport *Transport) GetInboundElementsContainer() *QueueInboundElementsContainer {
    95  	c := transport.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
    96  	c.Mutex = sync.Mutex{}
    97  	return c
    98  }
    99  
   100  func (transport *Transport) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
   101  	for i := range c.elems {
   102  		c.elems[i] = nil
   103  	}
   104  	c.elems = c.elems[:0]
   105  	transport.pool.inboundElementsContainer.Put(c)
   106  }
   107  
   108  func (transport *Transport) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
   109  	c := transport.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
   110  	c.Mutex = sync.Mutex{}
   111  	return c
   112  }
   113  
   114  func (transport *Transport) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
   115  	for i := range c.elems {
   116  		c.elems[i] = nil
   117  	}
   118  	c.elems = c.elems[:0]
   119  	transport.pool.outboundElementsContainer.Put(c)
   120  }
   121  
   122  func (transport *Transport) GetMessageBuffer() *[MaxMessageSize]byte {
   123  	return transport.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
   124  }
   125  
   126  func (transport *Transport) PutMessageBuffer(msg *[MaxMessageSize]byte) {
   127  	transport.pool.messageBuffers.Put(msg)
   128  }
   129  
   130  func (transport *Transport) GetInboundElement() *QueueInboundElement {
   131  	return transport.pool.inboundElements.Get().(*QueueInboundElement)
   132  }
   133  
   134  func (transport *Transport) PutInboundElement(elem *QueueInboundElement) {
   135  	elem.clearPointers()
   136  	transport.pool.inboundElements.Put(elem)
   137  }
   138  
   139  func (transport *Transport) GetOutboundElement() *QueueOutboundElement {
   140  	return transport.pool.outboundElements.Get().(*QueueOutboundElement)
   141  }
   142  
   143  func (transport *Transport) PutOutboundElement(elem *QueueOutboundElement) {
   144  	elem.clearPointers()
   145  	transport.pool.outboundElements.Put(elem)
   146  }