github.com/grailbio/base@v0.0.11/recordio/recordiozstd/recordiozstd.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package recordiozstd
     6  
     7  import (
     8  	"strconv"
     9  	"sync"
    10  
    11  	"github.com/grailbio/base/compress/zstd"
    12  	"github.com/grailbio/base/recordio"
    13  	"github.com/grailbio/base/recordio/recordioiov"
    14  )
    15  
    // Name is the name under which Init registers the zstd transformer with
// recordio.
const Name = "zstd"

// parseConfig turns a zstd transformer config string into a compression
// level. An empty config yields -1 (NOTE(review): presumably "use the
// library default" in compress/zstd — confirm). Any other config is parsed
// with strconv.Atoi, whose value and error are returned unchanged.
func parseConfig(config string) (int, error) {
	if config == "" {
		return -1, nil
	}
	return strconv.Atoi(config)
}
    26  
    27  var tmpBufPool = sync.Pool{New: func() interface{} { return &[]byte{} }}
    28  
    29  // As of 2018-03, zstd.{Compress,Decompress} is much faster than
    30  // io.{Reader,Writer}-based implementations, even though the former incurs extra
    31  // copying.
    32  //
    33  // Reader/Writer impl:
    34  // BenchmarkWrite-56             20         116151712 ns/op
    35  // BenchmarkRead-56              30          45302918 ns/op
    36  //
    37  // Compress/Decompress impl:
    38  // BenchmarkWrite-56    	      50	  30034396 ns/op
    39  // BenchmarkRead-56    	      50	  23871334 ns/op
    40  func flattenIov(in [][]byte) []byte {
    41  	totalBytes := recordioiov.TotalBytes(in)
    42  
    43  	// storing only pointers in sync.Pool per https://github.com/golang/go/issues/16323
    44  	slicePtr := tmpBufPool.Get().(*[]byte)
    45  	tmp := recordioiov.Slice(*slicePtr, totalBytes)
    46  	n := 0
    47  	for _, inbuf := range in {
    48  		copy(tmp[n:], inbuf)
    49  		n += len(inbuf)
    50  	}
    51  	return tmp
    52  }
    53  
    54  func zstdCompress(level int, scratch []byte, in [][]byte) ([]byte, error) {
    55  	if len(in) == 0 {
    56  		return zstd.CompressLevel(scratch, nil, level)
    57  	}
    58  	if len(in) == 1 {
    59  		return zstd.CompressLevel(scratch, in[0], level)
    60  	}
    61  	tmp := flattenIov(in)
    62  	d, err := zstd.CompressLevel(scratch, tmp, level)
    63  	tmpBufPool.Put(&tmp)
    64  	return d, err
    65  }
    66  
    67  func zstdUncompress(scratch []byte, in [][]byte) ([]byte, error) {
    68  	if len(in) == 0 {
    69  		return zstd.Decompress(scratch, nil)
    70  	}
    71  	if len(in) == 1 {
    72  		return zstd.Decompress(scratch, in[0])
    73  	}
    74  	tmp := flattenIov(in)
    75  	d, err := zstd.Decompress(scratch, tmp)
    76  	tmpBufPool.Put(&tmp)
    77  	return d, err
    78  }
    79  
    80  var once = sync.Once{}
    81  
    82  // Init installs the zstd transformer in recordio.  It can be called multiple
    83  // times, but 2nd and later calls have no effect.
    84  func Init() {
    85  	once.Do(func() {
    86  		recordio.RegisterTransformer(
    87  			Name,
    88  			func(config string) (recordio.TransformFunc, error) {
    89  				level, err := parseConfig(config)
    90  				if err != nil {
    91  					return nil, err
    92  				}
    93  				return func(scratch []byte, in [][]byte) ([]byte, error) {
    94  					return zstdCompress(level, scratch, in)
    95  				}, nil
    96  			},
    97  			func(string) (recordio.TransformFunc, error) {
    98  				return zstdUncompress, nil
    99  			})
   100  	})
   101  }