github.com/segmentio/parquet-go@v0.0.0-20230712180008-5d42db8f0d47/compress/compress.go (about)

     1  // Package compress provides the generic APIs implemented by parquet compression
     2  // codecs.
     3  //
     4  // https://github.com/apache/parquet-format/blob/master/Compression.md
     5  package compress
     6  
     7  import (
     8  	"bytes"
     9  	"io"
    10  	"sync"
    11  
    12  	"github.com/segmentio/parquet-go/format"
    13  )
    14  
    15  // The Codec interface represents parquet compression codecs implemented by the
    16  // compress sub-packages.
    17  //
    18  // Codec instances must be safe to use concurrently from multiple goroutines.
    19  type Codec interface {
    20  	// Returns a human-readable name for the codec.
    21  	String() string
    22  
    23  	// Returns the code of the compression codec in the parquet format.
    24  	CompressionCodec() format.CompressionCodec
    25  
    26  	// Writes the compressed version of src to dst and returns it.
    27  	//
    28  	// The method automatically reallocates the output buffer if its capacity
    29  	// was too small to hold the compressed data.
    30  	Encode(dst, src []byte) ([]byte, error)
    31  
    32  	// Writes the uncompressed version of src to dst and returns it.
    33  	//
    34  	// The method automatically reallocates the output buffer if its capacity
    35  	// was too small to hold the uncompressed data.
    36  	Decode(dst, src []byte) ([]byte, error)
    37  }
    38  
    39  type Reader interface {
    40  	io.ReadCloser
    41  	Reset(io.Reader) error
    42  }
    43  
    44  type Writer interface {
    45  	io.WriteCloser
    46  	Reset(io.Writer)
    47  }
    48  
// Compressor implements buffer-to-buffer compression on top of streaming
// Writer values, which it recycles through a pool instead of constructing a
// new one for every call to Encode.
//
// The zero value is ready to use: Encode creates a writer on demand when the
// pool has none to offer.
type Compressor struct {
	writers sync.Pool // holds the *writer values released by Encode
}
    52  
// writer is the unit pooled by Compressor: a streaming Writer paired with the
// buffer that collects its output.
type writer struct {
	// output receives the compressed bytes; Encode rebinds it to the caller's
	// destination slice at the start of each call and detaches it afterwards.
	output bytes.Buffer
	// writer is the streaming compressor that writes into output.
	writer Writer
}
    57  
    58  func (c *Compressor) Encode(dst, src []byte, newWriter func(io.Writer) (Writer, error)) ([]byte, error) {
    59  	w, _ := c.writers.Get().(*writer)
    60  	if w != nil {
    61  		w.output = *bytes.NewBuffer(dst[:0])
    62  		w.writer.Reset(&w.output)
    63  	} else {
    64  		w = new(writer)
    65  		w.output = *bytes.NewBuffer(dst[:0])
    66  		var err error
    67  		if w.writer, err = newWriter(&w.output); err != nil {
    68  			return dst, err
    69  		}
    70  	}
    71  
    72  	defer func() {
    73  		w.output = *bytes.NewBuffer(nil)
    74  		w.writer.Reset(io.Discard)
    75  		c.writers.Put(w)
    76  	}()
    77  
    78  	if _, err := w.writer.Write(src); err != nil {
    79  		return w.output.Bytes(), err
    80  	}
    81  	if err := w.writer.Close(); err != nil {
    82  		return w.output.Bytes(), err
    83  	}
    84  	return w.output.Bytes(), nil
    85  }
    86  
// Decompressor implements buffer-to-buffer decompression on top of streaming
// Reader values, which it recycles through a pool instead of constructing a
// new one for every call to Decode.
//
// The zero value is ready to use: Decode creates a reader on demand when the
// pool has none to offer.
type Decompressor struct {
	readers sync.Pool // holds the *reader values released by Decode
}
    90  
// reader is the unit pooled by Decompressor: a streaming Reader paired with
// the byte reader that feeds it the compressed input.
type reader struct {
	// input exposes the compressed bytes of the current Decode call; it is
	// reset to nil when the call completes.
	input bytes.Reader
	// reader is the streaming decompressor consuming input.
	reader Reader
}
    95  
    96  func (d *Decompressor) Decode(dst, src []byte, newReader func(io.Reader) (Reader, error)) ([]byte, error) {
    97  	r, _ := d.readers.Get().(*reader)
    98  	if r != nil {
    99  		r.input.Reset(src)
   100  		if err := r.reader.Reset(&r.input); err != nil {
   101  			return dst, err
   102  		}
   103  	} else {
   104  		r = new(reader)
   105  		r.input.Reset(src)
   106  		var err error
   107  		if r.reader, err = newReader(&r.input); err != nil {
   108  			return dst, err
   109  		}
   110  	}
   111  
   112  	defer func() {
   113  		r.input.Reset(nil)
   114  		if err := r.reader.Reset(nil); err == nil {
   115  			d.readers.Put(r)
   116  		}
   117  	}()
   118  
   119  	if cap(dst) == 0 {
   120  		dst = make([]byte, 0, 2*len(src))
   121  	} else {
   122  		dst = dst[:0]
   123  	}
   124  
   125  	for {
   126  		n, err := r.reader.Read(dst[len(dst):cap(dst)])
   127  		dst = dst[:len(dst)+n]
   128  
   129  		if err != nil {
   130  			if err == io.EOF {
   131  				err = nil
   132  			}
   133  			return dst, err
   134  		}
   135  
   136  		if len(dst) == cap(dst) {
   137  			tmp := make([]byte, len(dst), 2*len(dst))
   138  			copy(tmp, dst)
   139  			dst = tmp
   140  		}
   141  	}
   142  }