github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/utils/pool/pool.go (about)

     1  package pool
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"io"
     7  	"math"
     8  	"math/bits"
     9  	"net"
    10  	"net/http/httputil"
    11  	"sync"
    12  	"unsafe"
    13  
    14  	"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
    15  	"golang.org/x/exp/constraints"
    16  )
    17  
    // MaxSegmentSize defaults to math.MaxUint16 (65535). It is an exported
// variable, so other packages may change it; how it is consumed is not
// visible in this file.
var MaxSegmentSize int = math.MaxUint16
    19  
    // Pool hands out reusable byte slices and *bytes.Buffer values.
type Pool interface {
	// GetBytes returns a byte slice able to hold size bytes.
	GetBytes(size int) []byte
	// PutBytes hands b back for reuse.
	PutBytes(b []byte)

	// GetBuffer returns a *bytes.Buffer.
	GetBuffer() *bytes.Buffer
	// PutBuffer hands b back for reuse.
	PutBuffer(b *bytes.Buffer)
}
    27  
    // DefaultSize is the default buffer size in bytes: 16 KiB.
const DefaultSize = 16 << 10
    29  
    30  var DefaultPool Pool = &pool{}
    31  
    32  func GetBytes[T constraints.Integer](size T) []byte { return DefaultPool.GetBytes(int(size)) }
    33  func PutBytes(b []byte)                             { DefaultPool.PutBytes(b) }
    34  func GetBuffer() *bytes.Buffer                      { return DefaultPool.GetBuffer() }
    35  func PutBuffer(b *bytes.Buffer)                     { DefaultPool.PutBuffer(b) }
    36  
    37  var _ httputil.BufferPool = (*ReverseProxyBuffer)(nil)
    38  
    39  type ReverseProxyBuffer struct{}
    40  
    41  func (ReverseProxyBuffer) Get() []byte  { return GetBytes(DefaultSize) }
    42  func (ReverseProxyBuffer) Put(b []byte) { PutBytes(b) }
    43  
    44  var poolMap syncmap.SyncMap[int, *sync.Pool]
    45  
    46  type pool struct{}
    47  
    48  func buffPool(size int) *sync.Pool {
    49  	if v, ok := poolMap.Load(size); ok {
    50  		return v
    51  	}
    52  
    53  	p := &sync.Pool{New: func() any { return make([]byte, size) }}
    54  	poolMap.Store(size, p)
    55  	return p
    56  }
    57  
    58  func (pool) GetBytes(size int) []byte {
    59  	if size == 0 {
    60  		return nil
    61  	}
    62  
    63  	l := bits.Len(uint(size)) - 1
    64  	if size != 1<<l {
    65  		size = 1 << (l + 1)
    66  	}
    67  	return buffPool(size).Get().([]byte)
    68  }
    69  
    70  func (pool) PutBytes(b []byte) {
    71  	if len(b) == 0 {
    72  		return
    73  	}
    74  
    75  	l := bits.Len(uint(len(b))) - 1
    76  	buffPool(1 << l).Put(b) //lint:ignore SA6002 ignore temporarily
    77  }
    78  
    79  var bufpool = sync.Pool{New: func() any {
    80  	buffer := bytes.NewBuffer(make([]byte, DefaultSize))
    81  	buffer.Reset()
    82  	return buffer
    83  }}
    84  
    85  func (pool) GetBuffer() *bytes.Buffer { return bufpool.Get().(*bytes.Buffer) }
    86  func (pool) PutBuffer(b *bytes.Buffer) {
    87  	if b != nil {
    88  		b.Reset()
    89  		bufpool.Put(b)
    90  	}
    91  }
    92  
    93  type MultipleBuffer []*Buffer
    94  
    95  func (m MultipleBuffer) Free() {
    96  	for _, v := range m {
    97  		v.Free()
    98  	}
    99  }
   100  
   101  type MultipleBytes []*Bytes
   102  
   103  func (m MultipleBytes) Free() {
   104  	for _, v := range m {
   105  		v.Free()
   106  	}
   107  }
   108  
   // Bytes is a window [start, end) over a backing slice buf, which may come
// from the pool. The accessors return views into buf, not copies.
type Bytes struct {
	once  sync.Once // ensures buf is handed back to the pool at most once
	buf   []byte
	start int
	end   int
}

// Bytes returns the current window buf[start:end] without copying.
func (b *Bytes) Bytes() []byte { return b.buf[b.start:b.end] }

// String returns a copy of the current window as a string.
func (b *Bytes) String() string { return string(b.Bytes()) }

// After returns the current window with its first index bytes skipped.
func (b *Bytes) After(index int) []byte { return b.buf[b.start+index : b.end] }

// Refactor moves the window to [start, end) and returns b for chaining.
//
// end is applied only when 0 <= end <= len(buf). start is applied only when
// 0 <= start <= the resulting end. Out-of-range values are ignored, and the
// window is kept well-formed (start <= end) so later slicing cannot panic.
func (b *Bytes) Refactor(start, end int) *Bytes {
	if end >= 0 && end <= len(b.buf) {
		b.end = end
	}
	// Check against b.end rather than the raw argument. If end was rejected
	// above, start <= end could still put start past the real end.
	if start >= 0 && start <= b.end {
		b.start = start
	}
	if b.start > b.end {
		// end moved below an old start that was not replaced: collapse to empty.
		b.start = b.end
	}
	return b
}

// Copy overwrites the current window with p and shrinks the window to the
// number of bytes copied (at most the previous window length).
func (b *Bytes) Copy(p []byte) *Bytes {
	b.end = b.start + copy(b.Bytes(), p)
	return b
}

// Len reports the length of the current window.
func (b *Bytes) Len() int { return b.end - b.start }

// ReadFrom performs a single Read from c into the current window. It then
// shrinks the window to exactly the bytes read.
//
// Unlike the io.ReaderFrom contract, this does not loop until EOF.
func (b *Bytes) ReadFrom(c io.Reader) (int64, error) {
	n, err := c.Read(b.Bytes())
	// An io.Reader may return n > 0 together with an error (e.g. io.EOF); those
	// bytes are valid and kept. end is offset by start because the read
	// targeted buf[start:].
	if n > 0 || err == nil {
		b.end = b.start + n
	}
	return int64(n), err
}

// ReadFull fills the current window from c using io.ReadFull. A short read
// that ends in io.EOF or io.ErrUnexpectedEOF is not reported as an error.
// The window is always shrunk to the bytes actually read, so no stale data
// from a previous use remains visible.
func (b *Bytes) ReadFull(c io.Reader) (int64, error) {
	n, err := io.ReadFull(c, b.Bytes())
	b.end = b.start + n
	if err == io.EOF || err == io.ErrUnexpectedEOF {
		err = nil
	}
	return int64(n), err
}
   157  
   158  func (b *Bytes) AsWriter() *Buffer {
   159  	b.start = 0
   160  	b.end = 0
   161  
   162  	return &Buffer{b}
   163  }
   164  
   165  func (b *Bytes) ReadFromPacket(pc net.PacketConn) (int, net.Addr, error) {
   166  	n, addr, err := pc.ReadFrom(b.Bytes())
   167  	if err != nil {
   168  		return n, addr, err
   169  	}
   170  
   171  	b.end = n
   172  
   173  	return n, addr, err
   174  }
   175  
   176  func (b *Bytes) Free() {
   177  	putBytesBuffer(b)
   178  }
   179  
   180  func NewBytesBuffer(b []byte) *Bytes { return &Bytes{sync.Once{}, b, 0, len(b)} }
   181  
   182  func GetBytesBuffer[T constraints.Integer](size T) *Bytes {
   183  	return &Bytes{sync.Once{},
   184  		GetBytes(size), 0, int(size)}
   185  }
   186  
   187  func putBytesBuffer(b *Bytes) { b.once.Do(func() { PutBytes(b.buf) }) }
   188  
   189  func GetBytesWriter[T constraints.Integer](size T) *Buffer {
   190  	b := &Bytes{sync.Once{},
   191  		GetBytes(size), 0, 0}
   192  	return &Buffer{b}
   193  }
   194  
   195  type Buffer struct {
   196  	b *Bytes
   197  }
   198  
   199  func NewBuffer(b []byte) *Buffer { return &Buffer{NewBytesBuffer(b)} }
   200  
   201  func (b *Buffer) freeSlice() []byte {
   202  	return b.b.buf[b.b.end:]
   203  }
   204  
   205  func (b *Buffer) WriteUint16(v uint16) {
   206  	if len(b.freeSlice()) < 2 {
   207  		return
   208  	}
   209  
   210  	binary.BigEndian.PutUint16(b.freeSlice(), v)
   211  	b.b.end += 2
   212  }
   213  
   214  func (b *Buffer) WriteLittleEndianUint16(v uint16) {
   215  	if len(b.freeSlice()) < 2 {
   216  		return
   217  	}
   218  
   219  	binary.LittleEndian.PutUint16(b.freeSlice(), v)
   220  	b.b.end += 2
   221  }
   222  
   223  func (b *Buffer) WriteUint32(v uint32) {
   224  	if len(b.freeSlice()) < 4 {
   225  		return
   226  	}
   227  
   228  	binary.BigEndian.PutUint32(b.freeSlice(), v)
   229  	b.b.end += 4
   230  }
   231  func (b *Buffer) WriteLittleEndianUint32(v uint32) {
   232  	if len(b.freeSlice()) < 4 {
   233  		return
   234  	}
   235  
   236  	binary.LittleEndian.PutUint32(b.freeSlice(), v)
   237  	b.b.end += 4
   238  }
   239  
   240  func (b *Buffer) WriteUint64(v uint64) {
   241  	if len(b.freeSlice()) < 8 {
   242  		return
   243  	}
   244  
   245  	binary.BigEndian.PutUint64(b.freeSlice(), v)
   246  	b.b.end += 8
   247  }
   248  
   249  func (b *Buffer) WriteLittleEndianUint64(v uint64) {
   250  	if len(b.freeSlice()) < 8 {
   251  		return
   252  	}
   253  
   254  	binary.LittleEndian.PutUint64(b.freeSlice(), v)
   255  	b.b.end += 8
   256  }
   257  
   258  func (b *Buffer) Write(bb []byte) (int, error) {
   259  	n := copy(b.freeSlice(), bb)
   260  	b.b.end += n
   261  	return n, nil
   262  }
   263  
   264  func (b *Buffer) Advance(i int) {
   265  	free := len(b.freeSlice())
   266  	if free < i {
   267  		b.b.end += free
   268  	} else {
   269  		b.b.end += i
   270  	}
   271  }
   272  
   273  func (b *Buffer) WriteString(s string) {
   274  	_, _ = b.Write(unsafe.Slice(unsafe.StringData(s), len(s)))
   275  }
   276  
   277  func (b *Buffer) WriteByte(v byte) error {
   278  	_, err := b.Write([]byte{v})
   279  	return err
   280  }
   281  
   282  func (b *Buffer) ReadFrom(c io.Reader) (int64, error) {
   283  	return b.b.ReadFrom(c)
   284  }
   285  
   286  func (b *Buffer) ReadFromPacket(pc net.PacketConn) (int, net.Addr, error) {
   287  	return b.b.ReadFromPacket(pc)
   288  }
   289  
   290  func (b *Buffer) Len() int      { return b.b.Len() }
   291  func (b *Buffer) Bytes() []byte { return b.b.Bytes() }
   292  
   293  func (b *Buffer) String() string { return b.b.String() }
   294  
   295  func (b *Buffer) Truncate(n int) {
   296  	if n <= 0 {
   297  		b.b.start = 0
   298  		b.b.end = 0
   299  		return
   300  	}
   301  
   302  	if n >= b.b.end {
   303  		return
   304  	}
   305  
   306  	b.b.end = n
   307  }
   308  
   309  func (b *Buffer) Discard(n int) []byte {
   310  	if n > b.b.end-b.b.start {
   311  		x := b.Bytes()
   312  		b.b.start = b.b.end
   313  		return x
   314  	}
   315  
   316  	x := b.b.buf[b.b.start : b.b.start+n]
   317  	b.b.start += n
   318  	return x
   319  }
   320  
   321  func (b *Buffer) Unwrap() *Bytes { return b.b }
   322  
   323  func (b *Buffer) Free() {
   324  	putBytesBuffer(b.b)
   325  }