github.com/apache/arrow/go/v14@v14.0.1/parquet/compress/zstd.go (about) 1 // Licensed to the Apache Software Foundation (ASF) under one 2 // or more contributor license agreements. See the NOTICE file 3 // distributed with this work for additional information 4 // regarding copyright ownership. The ASF licenses this file 5 // to you under the Apache License, Version 2.0 (the 6 // "License"); you may not use this file except in compliance 7 // with the License. You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 17 package compress 18 19 import ( 20 "io" 21 "sync" 22 23 "github.com/apache/arrow/go/v14/parquet/internal/debug" 24 "github.com/klauspost/compress/zstd" 25 ) 26 27 type zstdCodec struct{} 28 29 type zstdcloser struct { 30 *zstd.Decoder 31 } 32 33 var ( 34 enc *zstd.Encoder 35 dec *zstd.Decoder 36 initEncoder sync.Once 37 initDecoder sync.Once 38 ) 39 40 func getencoder() *zstd.Encoder { 41 initEncoder.Do(func() { 42 enc, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true)) 43 }) 44 return enc 45 } 46 47 func getdecoder() *zstd.Decoder { 48 initDecoder.Do(func() { 49 dec, _ = zstd.NewReader(nil) 50 }) 51 return dec 52 } 53 54 func (zstdCodec) Decode(dst, src []byte) []byte { 55 dst, err := getdecoder().DecodeAll(src, dst[:0]) 56 if err != nil { 57 panic(err) 58 } 59 return dst 60 } 61 62 func (z *zstdcloser) Close() error { 63 z.Decoder.Close() 64 return nil 65 } 66 67 func (zstdCodec) NewReader(r io.Reader) io.ReadCloser { 68 ret, _ := zstd.NewReader(r) 69 return &zstdcloser{ret} 70 } 71 72 func (zstdCodec) NewWriter(w io.Writer) io.WriteCloser { 73 ret, _ := zstd.NewWriter(w) 74 return ret 75 } 76 77 func (zstdCodec) NewWriterLevel(w io.Writer, level int) (io.WriteCloser, error) { 78 var compressLevel zstd.EncoderLevel 79 if level == DefaultCompressionLevel { 80 compressLevel = zstd.SpeedDefault 81 } else { 82 compressLevel = zstd.EncoderLevelFromZstd(level) 83 } 84 return zstd.NewWriter(w, zstd.WithEncoderLevel(compressLevel)) 85 } 86 87 func (z zstdCodec) Encode(dst, src []byte) []byte { 88 return getencoder().EncodeAll(src, dst[:0]) 89 } 90 91 func (z zstdCodec) EncodeLevel(dst, src []byte, level int) []byte { 92 compressLevel := zstd.EncoderLevelFromZstd(level) 93 if level == DefaultCompressionLevel { 94 compressLevel = zstd.SpeedDefault 95 } 96 enc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true), zstd.WithEncoderLevel(compressLevel)) 97 return enc.EncodeAll(src, dst[:0]) 98 } 99 100 // from zstd.h, ZSTD_COMPRESSBOUND 101 func (zstdCodec) CompressBound(len int64) int64 { 102 debug.Assert(len > 0, "len for zstd CompressBound should be > 0") 103 extra := ((128 << 10) - len) >> 11 104 if len >= (128 << 10) { 105 extra = 0 106 } 107 return len + (len >> 8) + extra 108 } 109 110 func init() { 111 codecs[Codecs.Zstd] = zstdCodec{} 112 }