github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/go.mongodb.org/mongo-driver/x/mongo/driver/compression.go (about)

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package driver
     8  
     9  import (
    10  	"bytes"
    11  	"compress/zlib"
    12  	"fmt"
    13  	"io"
    14  	"sync"
    15  
    16  	"github.com/golang/snappy"
    17  	"github.com/klauspost/compress/zstd"
    18  	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
    19  )
    20  
    21  // CompressionOpts holds settings for how to compress a payload
    22  type CompressionOpts struct {
    23  	Compressor       wiremessage.CompressorID
    24  	ZlibLevel        int
    25  	ZstdLevel        int
    26  	UncompressedSize int32
    27  }
    28  
    29  var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder
    30  
    31  func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
    32  	if v, ok := zstdEncoders.Load(level); ok {
    33  		return v.(*zstd.Encoder), nil
    34  	}
    35  	encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  	zstdEncoders.Store(level, encoder)
    40  	return encoder, nil
    41  }
    42  
    43  var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder
    44  
    45  func getZlibEncoder(level int) (*zlibEncoder, error) {
    46  	if v, ok := zlibEncoders.Load(level); ok {
    47  		return v.(*zlibEncoder), nil
    48  	}
    49  	writer, err := zlib.NewWriterLevel(nil, level)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)}
    54  	zlibEncoders.Store(level, encoder)
    55  
    56  	return encoder, nil
    57  }
    58  
    59  type zlibEncoder struct {
    60  	mu     sync.Mutex
    61  	writer *zlib.Writer
    62  	buf    *bytes.Buffer
    63  }
    64  
    65  func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
    66  	e.mu.Lock()
    67  	defer e.mu.Unlock()
    68  
    69  	e.buf.Reset()
    70  	e.writer.Reset(e.buf)
    71  
    72  	_, err := e.writer.Write(src)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	err = e.writer.Close()
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	dst = append(dst[:0], e.buf.Bytes()...)
    81  	return dst, nil
    82  }
    83  
    84  // CompressPayload takes a byte slice and compresses it according to the options passed
    85  func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
    86  	switch opts.Compressor {
    87  	case wiremessage.CompressorNoOp:
    88  		return in, nil
    89  	case wiremessage.CompressorSnappy:
    90  		return snappy.Encode(nil, in), nil
    91  	case wiremessage.CompressorZLib:
    92  		encoder, err := getZlibEncoder(opts.ZlibLevel)
    93  		if err != nil {
    94  			return nil, err
    95  		}
    96  		return encoder.Encode(nil, in)
    97  	case wiremessage.CompressorZstd:
    98  		encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel))
    99  		if err != nil {
   100  			return nil, err
   101  		}
   102  		return encoder.EncodeAll(in, nil), nil
   103  	default:
   104  		return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
   105  	}
   106  }
   107  
   108  // DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed
   109  func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
   110  	switch opts.Compressor {
   111  	case wiremessage.CompressorNoOp:
   112  		return in, nil
   113  	case wiremessage.CompressorSnappy:
   114  		l, err := snappy.DecodedLen(in)
   115  		if err != nil {
   116  			return nil, fmt.Errorf("decoding compressed length %w", err)
   117  		} else if int32(l) != opts.UncompressedSize {
   118  			return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l)
   119  		}
   120  		uncompressed = make([]byte, opts.UncompressedSize)
   121  		return snappy.Decode(uncompressed, in)
   122  	case wiremessage.CompressorZLib:
   123  		r, err := zlib.NewReader(bytes.NewReader(in))
   124  		if err != nil {
   125  			return nil, err
   126  		}
   127  		defer func() {
   128  			err = r.Close()
   129  		}()
   130  		uncompressed = make([]byte, opts.UncompressedSize)
   131  		_, err = io.ReadFull(r, uncompressed)
   132  		if err != nil {
   133  			return nil, err
   134  		}
   135  		return uncompressed, nil
   136  	case wiremessage.CompressorZstd:
   137  		r, err := zstd.NewReader(bytes.NewBuffer(in))
   138  		if err != nil {
   139  			return nil, err
   140  		}
   141  		defer r.Close()
   142  		uncompressed = make([]byte, opts.UncompressedSize)
   143  		_, err = io.ReadFull(r, uncompressed)
   144  		if err != nil {
   145  			return nil, err
   146  		}
   147  		return uncompressed, nil
   148  	default:
   149  		return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
   150  	}
   151  }