github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/go.mongodb.org/mongo-driver/x/mongo/driver/compression.go (about) 1 // Copyright (C) MongoDB, Inc. 2017-present. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); you may 4 // not use this file except in compliance with the License. You may obtain 5 // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7 package driver 8 9 import ( 10 "bytes" 11 "compress/zlib" 12 "fmt" 13 "io" 14 "sync" 15 16 "github.com/golang/snappy" 17 "github.com/klauspost/compress/zstd" 18 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" 19 ) 20 21 // CompressionOpts holds settings for how to compress a payload 22 type CompressionOpts struct { 23 Compressor wiremessage.CompressorID 24 ZlibLevel int 25 ZstdLevel int 26 UncompressedSize int32 27 } 28 29 var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder 30 31 func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) { 32 if v, ok := zstdEncoders.Load(level); ok { 33 return v.(*zstd.Encoder), nil 34 } 35 encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) 36 if err != nil { 37 return nil, err 38 } 39 zstdEncoders.Store(level, encoder) 40 return encoder, nil 41 } 42 43 var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder 44 45 func getZlibEncoder(level int) (*zlibEncoder, error) { 46 if v, ok := zlibEncoders.Load(level); ok { 47 return v.(*zlibEncoder), nil 48 } 49 writer, err := zlib.NewWriterLevel(nil, level) 50 if err != nil { 51 return nil, err 52 } 53 encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)} 54 zlibEncoders.Store(level, encoder) 55 56 return encoder, nil 57 } 58 59 type zlibEncoder struct { 60 mu sync.Mutex 61 writer *zlib.Writer 62 buf *bytes.Buffer 63 } 64 65 func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) { 66 e.mu.Lock() 67 defer e.mu.Unlock() 68 69 e.buf.Reset() 70 e.writer.Reset(e.buf) 71 72 _, err := e.writer.Write(src) 73 if err != nil { 74 return nil, err 75 } 76 err = e.writer.Close() 77 if err != nil { 78 return nil, err 79 } 80 dst = append(dst[:0], e.buf.Bytes()...) 81 return dst, nil 82 } 83 84 // CompressPayload takes a byte slice and compresses it according to the options passed 85 func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) { 86 switch opts.Compressor { 87 case wiremessage.CompressorNoOp: 88 return in, nil 89 case wiremessage.CompressorSnappy: 90 return snappy.Encode(nil, in), nil 91 case wiremessage.CompressorZLib: 92 encoder, err := getZlibEncoder(opts.ZlibLevel) 93 if err != nil { 94 return nil, err 95 } 96 return encoder.Encode(nil, in) 97 case wiremessage.CompressorZstd: 98 encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel)) 99 if err != nil { 100 return nil, err 101 } 102 return encoder.EncodeAll(in, nil), nil 103 default: 104 return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) 105 } 106 } 107 108 // DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed 109 func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) { 110 switch opts.Compressor { 111 case wiremessage.CompressorNoOp: 112 return in, nil 113 case wiremessage.CompressorSnappy: 114 l, err := snappy.DecodedLen(in) 115 if err != nil { 116 return nil, fmt.Errorf("decoding compressed length %w", err) 117 } else if int32(l) != opts.UncompressedSize { 118 return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) 119 } 120 uncompressed = make([]byte, opts.UncompressedSize) 121 return snappy.Decode(uncompressed, in) 122 case wiremessage.CompressorZLib: 123 r, err := zlib.NewReader(bytes.NewReader(in)) 124 if err != nil { 125 return nil, err 126 } 127 defer func() { 128 err = r.Close() 129 }() 130 uncompressed = make([]byte, opts.UncompressedSize) 131 _, err = io.ReadFull(r, uncompressed) 132 if err != nil { 133 return nil, err 134 } 135 return uncompressed, nil 136 case wiremessage.CompressorZstd: 137 r, err := zstd.NewReader(bytes.NewBuffer(in)) 138 if err != nil { 139 return nil, err 140 } 141 defer r.Close() 142 uncompressed = make([]byte, opts.UncompressedSize) 143 _, err = io.ReadFull(r, uncompressed) 144 if err != nil { 145 return nil, err 146 } 147 return uncompressed, nil 148 default: 149 return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor) 150 } 151 }