github.com/deanMdreon/kafka-go@v0.4.32/compress/snappy/xerial.go (about)

     1  package snappy
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"io"
     7  
     8  	"github.com/klauspost/compress/snappy"
     9  )
    10  
    11  const defaultBufferSize = 32 * 1024
    12  
// xerialReader implements io.Reader (and io.WriterTo) over a stream of
// xerial-framed, snappy-encoded data. The framing is optional: if the stream
// does not begin with the xerial magic header, the rest of the underlying
// stream is read in full and handled as a single chunk. That chunk is decoded
// with decode when it is set, and forwarded unchanged when decode is nil.
//
// Fields:
//   - reader: the underlying source of input bytes.
//   - header: the first 16 bytes of the stream. It is kept across chunks so
//     that every chunk is read in the same (framed or unframed) mode.
//   - input: raw bytes of the chunk currently being decoded.
//   - output: decoded bytes not yet returned to the caller.
//   - offset: index in output of the next byte to return.
//   - nbytes: total bytes consumed from reader; zero means the header has not
//     been read yet.
//   - decode: decoder applied to each chunk. Its output size is predicted with
//     snappy.DecodedLen, so it is expected to be a snappy block decoder. nil
//     disables decoding (pass-through).
type xerialReader struct {
	reader io.Reader
	header [16]byte
	input  []byte
	output []byte
	offset int64
	nbytes int64
	decode func([]byte, []byte) ([]byte, error)
}
    25  
    26  func (x *xerialReader) Reset(r io.Reader) {
    27  	x.reader = r
    28  	x.input = x.input[:0]
    29  	x.output = x.output[:0]
    30  	x.header = [16]byte{}
    31  	x.offset = 0
    32  	x.nbytes = 0
    33  }
    34  
    35  func (x *xerialReader) Read(b []byte) (int, error) {
    36  	for {
    37  		if x.offset < int64(len(x.output)) {
    38  			n := copy(b, x.output[x.offset:])
    39  			x.offset += int64(n)
    40  			return n, nil
    41  		}
    42  
    43  		n, err := x.readChunk(b)
    44  		if err != nil {
    45  			return 0, err
    46  		}
    47  		if n > 0 {
    48  			return n, nil
    49  		}
    50  	}
    51  }
    52  
    53  func (x *xerialReader) WriteTo(w io.Writer) (int64, error) {
    54  	wn := int64(0)
    55  
    56  	for {
    57  		for x.offset < int64(len(x.output)) {
    58  			n, err := w.Write(x.output[x.offset:])
    59  			wn += int64(n)
    60  			x.offset += int64(n)
    61  			if err != nil {
    62  				return wn, err
    63  			}
    64  		}
    65  
    66  		if _, err := x.readChunk(nil); err != nil {
    67  			if err == io.EOF {
    68  				err = nil
    69  			}
    70  			return wn, err
    71  		}
    72  	}
    73  }
    74  
    75  func (x *xerialReader) readChunk(dst []byte) (int, error) {
    76  	x.output = x.output[:0]
    77  	x.offset = 0
    78  	prefix := 0
    79  
    80  	if x.nbytes == 0 {
    81  		n, err := x.readFull(x.header[:])
    82  		if err != nil && n == 0 {
    83  			return 0, err
    84  		}
    85  		prefix = n
    86  	}
    87  
    88  	if isXerialHeader(x.header[:]) {
    89  		if cap(x.input) < 4 {
    90  			x.input = make([]byte, 4, defaultBufferSize)
    91  		} else {
    92  			x.input = x.input[:4]
    93  		}
    94  
    95  		_, err := x.readFull(x.input)
    96  		if err != nil {
    97  			return 0, err
    98  		}
    99  
   100  		frame := int(binary.BigEndian.Uint32(x.input))
   101  		if cap(x.input) < frame {
   102  			x.input = make([]byte, frame, align(frame, defaultBufferSize))
   103  		} else {
   104  			x.input = x.input[:frame]
   105  		}
   106  
   107  		if _, err := x.readFull(x.input); err != nil {
   108  			return 0, err
   109  		}
   110  	} else {
   111  		if cap(x.input) == 0 {
   112  			x.input = make([]byte, 0, defaultBufferSize)
   113  		} else {
   114  			x.input = x.input[:0]
   115  		}
   116  
   117  		if prefix > 0 {
   118  			x.input = append(x.input, x.header[:prefix]...)
   119  		}
   120  
   121  		for {
   122  			if len(x.input) == cap(x.input) {
   123  				b := make([]byte, len(x.input), 2*cap(x.input))
   124  				copy(b, x.input)
   125  				x.input = b
   126  			}
   127  
   128  			n, err := x.read(x.input[len(x.input):cap(x.input)])
   129  			x.input = x.input[:len(x.input)+n]
   130  			if err != nil {
   131  				if err == io.EOF && len(x.input) > 0 {
   132  					break
   133  				}
   134  				return 0, err
   135  			}
   136  		}
   137  	}
   138  
   139  	var n int
   140  	var err error
   141  
   142  	if x.decode == nil {
   143  		x.output, x.input, err = x.input, x.output, nil
   144  	} else if n, err = snappy.DecodedLen(x.input); n <= len(dst) && err == nil {
   145  		// If the output buffer is large enough to hold the decode value,
   146  		// write it there directly instead of using the intermediary output
   147  		// buffer.
   148  		_, err = x.decode(dst, x.input)
   149  	} else {
   150  		var b []byte
   151  		n = 0
   152  		b, err = x.decode(x.output[:cap(x.output)], x.input)
   153  		if err == nil {
   154  			x.output = b
   155  		}
   156  	}
   157  
   158  	return n, err
   159  }
   160  
   161  func (x *xerialReader) read(b []byte) (int, error) {
   162  	n, err := x.reader.Read(b)
   163  	x.nbytes += int64(n)
   164  	return n, err
   165  }
   166  
   167  func (x *xerialReader) readFull(b []byte) (int, error) {
   168  	n, err := io.ReadFull(x.reader, b)
   169  	x.nbytes += int64(n)
   170  	return n, err
   171  }
   172  
// xerialWriter implements io.Writer (and io.ReaderFrom) and produces an output
// stream that is optionally xerial-framed. Written bytes are buffered and only
// reach the underlying writer on Flush.
//
// When framed is set:
//   - the stream begins with the 16-byte xerial header;
//   - every flushed block is prefixed with its 4-byte big-endian length;
//   - the buffer is flushed automatically whenever fewer than 1024 bytes of
//     spare capacity remain.
//
// Unframed output stays buffered until Flush is called explicitly.
//
// Fields:
//   - writer: destination of the output stream.
//   - header: scratch space for the stream header and frame lengths.
//   - input: bytes written but not yet flushed.
//   - output: reusable buffer for encoded data.
//   - nbytes: total bytes written to writer; zero means the stream header has
//     not been emitted yet.
//   - framed: whether to emit xerial framing.
//   - encode: applied to the buffered bytes on Flush. This is presumably snappy
//     block encoding; verify at the construction site. nil writes the bytes
//     unchanged.
type xerialWriter struct {
	writer io.Writer
	header [16]byte
	input  []byte
	output []byte
	nbytes int64
	framed bool
	encode func([]byte, []byte) []byte
}
   184  
   185  func (x *xerialWriter) Reset(w io.Writer) {
   186  	x.writer = w
   187  	x.input = x.input[:0]
   188  	x.output = x.output[:0]
   189  	x.nbytes = 0
   190  }
   191  
   192  func (x *xerialWriter) ReadFrom(r io.Reader) (int64, error) {
   193  	wn := int64(0)
   194  
   195  	if cap(x.input) == 0 {
   196  		x.input = make([]byte, 0, defaultBufferSize)
   197  	}
   198  
   199  	for {
   200  		if x.full() {
   201  			x.grow()
   202  		}
   203  
   204  		n, err := r.Read(x.input[len(x.input):cap(x.input)])
   205  		wn += int64(n)
   206  		x.input = x.input[:len(x.input)+n]
   207  
   208  		if x.fullEnough() {
   209  			if err := x.Flush(); err != nil {
   210  				return wn, err
   211  			}
   212  		}
   213  
   214  		if err != nil {
   215  			if err == io.EOF {
   216  				err = nil
   217  			}
   218  			return wn, err
   219  		}
   220  	}
   221  }
   222  
   223  func (x *xerialWriter) Write(b []byte) (int, error) {
   224  	wn := 0
   225  
   226  	if cap(x.input) == 0 {
   227  		x.input = make([]byte, 0, defaultBufferSize)
   228  	}
   229  
   230  	for len(b) > 0 {
   231  		if x.full() {
   232  			x.grow()
   233  		}
   234  
   235  		n := copy(x.input[len(x.input):cap(x.input)], b)
   236  		b = b[n:]
   237  		wn += n
   238  		x.input = x.input[:len(x.input)+n]
   239  
   240  		if x.fullEnough() {
   241  			if err := x.Flush(); err != nil {
   242  				return wn, err
   243  			}
   244  		}
   245  	}
   246  
   247  	return wn, nil
   248  }
   249  
   250  func (x *xerialWriter) Flush() error {
   251  	if len(x.input) == 0 {
   252  		return nil
   253  	}
   254  
   255  	var b []byte
   256  	if x.encode == nil {
   257  		b = x.input
   258  	} else {
   259  		x.output = x.encode(x.output[:cap(x.output)], x.input)
   260  		b = x.output
   261  	}
   262  
   263  	x.input = x.input[:0]
   264  	x.output = x.output[:0]
   265  
   266  	if x.framed && x.nbytes == 0 {
   267  		writeXerialHeader(x.header[:])
   268  		_, err := x.write(x.header[:])
   269  		if err != nil {
   270  			return err
   271  		}
   272  	}
   273  
   274  	if x.framed {
   275  		writeXerialFrame(x.header[:4], len(b))
   276  		_, err := x.write(x.header[:4])
   277  		if err != nil {
   278  			return err
   279  		}
   280  	}
   281  
   282  	_, err := x.write(b)
   283  	return err
   284  }
   285  
   286  func (x *xerialWriter) write(b []byte) (int, error) {
   287  	n, err := x.writer.Write(b)
   288  	x.nbytes += int64(n)
   289  	return n, err
   290  }
   291  
   292  func (x *xerialWriter) full() bool {
   293  	return len(x.input) == cap(x.input)
   294  }
   295  
   296  func (x *xerialWriter) fullEnough() bool {
   297  	return x.framed && (cap(x.input)-len(x.input)) < 1024
   298  }
   299  
   300  func (x *xerialWriter) grow() {
   301  	tmp := make([]byte, len(x.input), 2*cap(x.input))
   302  	copy(tmp, x.input)
   303  	x.input = tmp
   304  }
   305  
   306  func align(n, a int) int {
   307  	if (n % a) == 0 {
   308  		return n
   309  	}
   310  	return ((n / a) + 1) * a
   311  }
   312  
   313  var (
   314  	xerialHeader      = [...]byte{130, 83, 78, 65, 80, 80, 89, 0}
   315  	xerialVersionInfo = [...]byte{0, 0, 0, 1, 0, 0, 0, 1}
   316  )
   317  
   318  func isXerialHeader(src []byte) bool {
   319  	return len(src) >= 16 && bytes.Equal(src[:8], xerialHeader[:])
   320  }
   321  
   322  func writeXerialHeader(b []byte) {
   323  	copy(b[:8], xerialHeader[:])
   324  	copy(b[8:], xerialVersionInfo[:])
   325  }
   326  
   327  func writeXerialFrame(b []byte, n int) {
   328  	binary.BigEndian.PutUint32(b, uint32(n))
   329  }