github.com/grailbio/base@v0.0.11/recordio/recordiozstd/recordiozstd.go (about) 1 // Copyright 2018 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache 2.0 3 // license that can be found in the LICENSE file. 4 5 package recordiozstd 6 7 import ( 8 "strconv" 9 "sync" 10 11 "github.com/grailbio/base/compress/zstd" 12 "github.com/grailbio/base/recordio" 13 "github.com/grailbio/base/recordio/recordioiov" 14 ) 15 16 // Name is the registered name of the zstd transformer. 17 const Name = "zstd" 18 19 func parseConfig(config string) (level int, err error) { 20 level = -1 21 if config != "" { 22 level, err = strconv.Atoi(config) 23 } 24 return 25 } 26 27 var tmpBufPool = sync.Pool{New: func() interface{} { return &[]byte{} }} 28 29 // As of 2018-03, zstd.{Compress,Decompress} is much faster than 30 // io.{Reader,Writer}-based implementations, even though the former incurs extra 31 // copying. 32 // 33 // Reader/Writer impl: 34 // BenchmarkWrite-56 20 116151712 ns/op 35 // BenchmarkRead-56 30 45302918 ns/op 36 // 37 // Compress/Decompress impl: 38 // BenchmarkWrite-56 50 30034396 ns/op 39 // BenchmarkRead-56 50 23871334 ns/op 40 func flattenIov(in [][]byte) []byte { 41 totalBytes := recordioiov.TotalBytes(in) 42 43 // storing only pointers in sync.Pool per https://github.com/golang/go/issues/16323 44 slicePtr := tmpBufPool.Get().(*[]byte) 45 tmp := recordioiov.Slice(*slicePtr, totalBytes) 46 n := 0 47 for _, inbuf := range in { 48 copy(tmp[n:], inbuf) 49 n += len(inbuf) 50 } 51 return tmp 52 } 53 54 func zstdCompress(level int, scratch []byte, in [][]byte) ([]byte, error) { 55 if len(in) == 0 { 56 return zstd.CompressLevel(scratch, nil, level) 57 } 58 if len(in) == 1 { 59 return zstd.CompressLevel(scratch, in[0], level) 60 } 61 tmp := flattenIov(in) 62 d, err := zstd.CompressLevel(scratch, tmp, level) 63 tmpBufPool.Put(&tmp) 64 return d, err 65 } 66 67 func zstdUncompress(scratch []byte, in [][]byte) ([]byte, error) { 68 if len(in) == 0 { 69 return zstd.Decompress(scratch, nil) 70 } 71 if len(in) == 1 { 72 return zstd.Decompress(scratch, in[0]) 73 } 74 tmp := flattenIov(in) 75 d, err := zstd.Decompress(scratch, tmp) 76 tmpBufPool.Put(&tmp) 77 return d, err 78 } 79 80 var once = sync.Once{} 81 82 // Init installs the zstd transformer in recordio. It can be called multiple 83 // times, but 2nd and later calls have no effect. 84 func Init() { 85 once.Do(func() { 86 recordio.RegisterTransformer( 87 Name, 88 func(config string) (recordio.TransformFunc, error) { 89 level, err := parseConfig(config) 90 if err != nil { 91 return nil, err 92 } 93 return func(scratch []byte, in [][]byte) ([]byte, error) { 94 return zstdCompress(level, scratch, in) 95 }, nil 96 }, 97 func(string) (recordio.TransformFunc, error) { 98 return zstdUncompress, nil 99 }) 100 }) 101 }