github.com/apache/arrow/go/v7@v7.0.1/parquet/compress/zstd.go (about)

     1  // Licensed to the Apache Software Foundation (ASF) under one
     2  // or more contributor license agreements.  See the NOTICE file
     3  // distributed with this work for additional information
     4  // regarding copyright ownership.  The ASF licenses this file
     5  // to you under the Apache License, Version 2.0 (the
     6  // "License"); you may not use this file except in compliance
     7  // with the License.  You may obtain a copy of the License at
     8  //
     9  // http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package compress
    18  
    19  import (
    20  	"io"
    21  	"sync"
    22  
    23  	"github.com/apache/arrow/go/v7/parquet/internal/debug"
    24  	"github.com/klauspost/compress/zstd"
    25  )
    26  
    27  type zstdCodec struct{}
    28  
    29  type zstdcloser struct {
    30  	*zstd.Decoder
    31  }
    32  
    33  var (
    34  	enc         *zstd.Encoder
    35  	dec         *zstd.Decoder
    36  	initEncoder sync.Once
    37  	initDecoder sync.Once
    38  )
    39  
    40  func getencoder() *zstd.Encoder {
    41  	initEncoder.Do(func() {
    42  		enc, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true))
    43  	})
    44  	return enc
    45  }
    46  
    47  func getdecoder() *zstd.Decoder {
    48  	initDecoder.Do(func() {
    49  		dec, _ = zstd.NewReader(nil)
    50  	})
    51  	return dec
    52  }
    53  
    54  func (zstdCodec) Decode(dst, src []byte) []byte {
    55  	dst, err := getdecoder().DecodeAll(src, dst[:0])
    56  	if err != nil {
    57  		panic(err)
    58  	}
    59  	return dst
    60  }
    61  
    62  func (z *zstdcloser) Close() error {
    63  	z.Decoder.Close()
    64  	return nil
    65  }
    66  
    67  func (zstdCodec) NewReader(r io.Reader) io.ReadCloser {
    68  	ret, _ := zstd.NewReader(r)
    69  	return &zstdcloser{ret}
    70  }
    71  
    72  func (zstdCodec) NewWriter(w io.Writer) io.WriteCloser {
    73  	ret, _ := zstd.NewWriter(w)
    74  	return ret
    75  }
    76  
    77  func (zstdCodec) NewWriterLevel(w io.Writer, level int) (io.WriteCloser, error) {
    78  	var compressLevel zstd.EncoderLevel
    79  	if level == DefaultCompressionLevel {
    80  		compressLevel = zstd.SpeedDefault
    81  	} else {
    82  		compressLevel = zstd.EncoderLevelFromZstd(level)
    83  	}
    84  	return zstd.NewWriter(w, zstd.WithEncoderLevel(compressLevel))
    85  }
    86  
    87  func (z zstdCodec) Encode(dst, src []byte) []byte {
    88  	return getencoder().EncodeAll(src, dst[:0])
    89  }
    90  
    91  func (z zstdCodec) EncodeLevel(dst, src []byte, level int) []byte {
    92  	compressLevel := zstd.EncoderLevelFromZstd(level)
    93  	if level == DefaultCompressionLevel {
    94  		compressLevel = zstd.SpeedDefault
    95  	}
    96  	enc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true), zstd.WithEncoderLevel(compressLevel))
    97  	return enc.EncodeAll(src, dst[:0])
    98  }
    99  
   100  // from zstd.h, ZSTD_COMPRESSBOUND
   101  func (zstdCodec) CompressBound(len int64) int64 {
   102  	debug.Assert(len > 0, "len for zstd CompressBound should be > 0")
   103  	extra := ((128 << 10) - len) >> 11
   104  	if len >= (128 << 10) {
   105  		extra = 0
   106  	}
   107  	return len + (len >> 8) + extra
   108  }
   109  
   110  func init() {
   111  	codecs[Codecs.Zstd] = zstdCodec{}
   112  }