github.com/segmentio/kafka-go@v0.4.48-0.20240318174348-3f6244eb34fd/compress/zstd/zstd.go (about)

     1  // Package zstd implements Zstandard compression.
     2  package zstd
     3  
     4  import (
     5  	"io"
     6  	"sync"
     7  
     8  	"github.com/klauspost/compress/zstd"
     9  )
    10  
// Codec is the implementation of a compress.Codec which supports creating
// readers and writers for kafka messages compressed with zstd.
//
// Writers created by a Codec recycle their encoders through the codec's own
// encoderPool, while readers share a single package-level decoder pool.
type Codec struct {
	// The compression level configured on writers created by the codec.
	//
	// Defaults to 3 when left at its zero value.
	Level int

	// Encoders released by closed writers, picked up again by NewWriter.
	encoderPool sync.Pool // *zstd.Encoder
}
    21  
    22  // Code implements the compress.Codec interface.
    23  func (c *Codec) Code() int8 { return 4 }
    24  
    25  // Name implements the compress.Codec interface.
    26  func (c *Codec) Name() string { return "zstd" }
    27  
    28  // NewReader implements the compress.Codec interface.
    29  func (c *Codec) NewReader(r io.Reader) io.ReadCloser {
    30  	p := new(reader)
    31  	if p.dec, _ = decoderPool.Get().(*zstd.Decoder); p.dec != nil {
    32  		p.dec.Reset(r)
    33  	} else {
    34  		z, err := zstd.NewReader(r,
    35  			zstd.WithDecoderConcurrency(1),
    36  		)
    37  		if err != nil {
    38  			p.err = err
    39  		} else {
    40  			p.dec = z
    41  		}
    42  	}
    43  	return p
    44  }
    45  
    46  func (c *Codec) level() int {
    47  	if c.Level != 0 {
    48  		return c.Level
    49  	}
    50  	return 3
    51  }
    52  
    53  func (c *Codec) zstdLevel() zstd.EncoderLevel {
    54  	return zstd.EncoderLevelFromZstd(c.level())
    55  }
    56  
// decoderPool holds idle decoders shared by all readers of the package:
// Codec.NewReader takes decoders from it and reader.Close puts them back.
var decoderPool sync.Pool // *zstd.Decoder
    58  
// reader is the io.ReadCloser returned by Codec.NewReader.
//
// dec is the decoder in use; it is nil when the decoder could not be created
// and is set back to nil by Close. err, when non-nil, is returned by every
// Read and WriteTo call: it holds either the decoder creation error or
// io.ErrClosedPipe once the reader has been closed.
type reader struct {
	dec *zstd.Decoder
	err error
}
    63  
    64  // Close implements the io.Closer interface.
    65  func (r *reader) Close() error {
    66  	if r.dec != nil {
    67  		r.dec.Reset(devNull{}) // don't retain the underlying reader
    68  		decoderPool.Put(r.dec)
    69  		r.dec = nil
    70  		r.err = io.ErrClosedPipe
    71  	}
    72  	return nil
    73  }
    74  
    75  // Read implements the io.Reader interface.
    76  func (r *reader) Read(p []byte) (int, error) {
    77  	if r.err != nil {
    78  		return 0, r.err
    79  	}
    80  	if r.dec == nil {
    81  		return 0, io.EOF
    82  	}
    83  	return r.dec.Read(p)
    84  }
    85  
    86  // WriteTo implements the io.WriterTo interface.
    87  func (r *reader) WriteTo(w io.Writer) (int64, error) {
    88  	if r.err != nil {
    89  		return 0, r.err
    90  	}
    91  	if r.dec == nil {
    92  		return 0, io.ErrClosedPipe
    93  	}
    94  	return r.dec.WriteTo(w)
    95  }
    96  
    97  // NewWriter implements the compress.Codec interface.
    98  func (c *Codec) NewWriter(w io.Writer) io.WriteCloser {
    99  	p := new(writer)
   100  	if enc, _ := c.encoderPool.Get().(*zstd.Encoder); enc == nil {
   101  		z, err := zstd.NewWriter(w,
   102  			zstd.WithEncoderLevel(c.zstdLevel()),
   103  			zstd.WithEncoderConcurrency(1),
   104  			zstd.WithZeroFrames(true),
   105  		)
   106  		if err != nil {
   107  			p.err = err
   108  		} else {
   109  			p.enc = z
   110  		}
   111  	} else {
   112  		p.enc = enc
   113  		p.enc.Reset(w)
   114  	}
   115  	p.c = c
   116  	return p
   117  }
   118  
// writer is the io.WriteCloser returned by Codec.NewWriter.
//
// c is the codec that created the writer; its encoderPool receives enc when
// the writer is closed. enc is the encoder in use; it is nil when the encoder
// could not be created and is set back to nil by Close. err, when non-nil, is
// returned by every Write and ReadFrom call: it holds either the encoder
// creation error or the error produced while closing the encoder.
type writer struct {
	c   *Codec
	enc *zstd.Encoder
	err error
}
   124  
   125  // Close implements the io.Closer interface.
   126  func (w *writer) Close() error {
   127  	if w.enc != nil {
   128  		// Close needs to be called to write the end of stream marker and flush
   129  		// the buffers. The zstd package documents that the encoder is re-usable
   130  		// after being closed.
   131  		err := w.enc.Close()
   132  		if err != nil {
   133  			w.err = err
   134  		}
   135  		w.enc.Reset(devNull{}) // don't retain the underlying writer
   136  		w.c.encoderPool.Put(w.enc)
   137  		w.enc = nil
   138  		return err
   139  	}
   140  	return w.err
   141  }
   142  
   143  // WriteTo implements the io.WriterTo interface.
   144  func (w *writer) Write(p []byte) (int, error) {
   145  	if w.err != nil {
   146  		return 0, w.err
   147  	}
   148  	if w.enc == nil {
   149  		return 0, io.ErrClosedPipe
   150  	}
   151  	return w.enc.Write(p)
   152  }
   153  
   154  // ReadFrom implements the io.ReaderFrom interface.
   155  func (w *writer) ReadFrom(r io.Reader) (int64, error) {
   156  	if w.err != nil {
   157  		return 0, w.err
   158  	}
   159  	if w.enc == nil {
   160  		return 0, io.ErrClosedPipe
   161  	}
   162  	return w.enc.ReadFrom(r)
   163  }
   164  
   165  type devNull struct{}
   166  
   167  func (devNull) Read([]byte) (int, error)  { return 0, io.EOF }
   168  func (devNull) Write([]byte) (int, error) { return 0, nil }