github.com/hoveychen/kafka-go@v0.4.42/compress/snappy/xerial.go (about)

     1  package snappy
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"io"
     8  
     9  	"github.com/klauspost/compress/snappy"
    10  )
    11  
    12  const defaultBufferSize = 32 * 1024
    13  
// An implementation of io.Reader which consumes a stream of xerial-framed
// snappy-encoded data. The framing is optional, if no framing is detected
// the reader will simply forward the bytes from its underlying stream.
type xerialReader struct {
	// reader is the underlying source of the (possibly framed) bytes.
	reader io.Reader
	// header receives the first 16 bytes of the stream; readChunk checks it
	// to decide whether the stream is xerial-framed.
	header [16]byte
	// input is scratch space holding the bytes of the chunk being processed.
	input  []byte
	// output holds processed bytes that have not been handed out yet.
	output []byte
	// offset is the index in output of the next byte to hand out.
	offset int64
	// nbytes counts the bytes consumed from reader since the last Reset;
	// readChunk only reads the header while it is still zero.
	nbytes int64
	// decode is called as decode(dst, src) on each chunk; when nil, chunks
	// are forwarded without decoding.
	decode func([]byte, []byte) ([]byte, error)
}
    26  
    27  func (x *xerialReader) Reset(r io.Reader) {
    28  	x.reader = r
    29  	x.input = x.input[:0]
    30  	x.output = x.output[:0]
    31  	x.header = [16]byte{}
    32  	x.offset = 0
    33  	x.nbytes = 0
    34  }
    35  
    36  func (x *xerialReader) Read(b []byte) (int, error) {
    37  	for {
    38  		if x.offset < int64(len(x.output)) {
    39  			n := copy(b, x.output[x.offset:])
    40  			x.offset += int64(n)
    41  			return n, nil
    42  		}
    43  
    44  		n, err := x.readChunk(b)
    45  		if err != nil {
    46  			return 0, err
    47  		}
    48  		if n > 0 {
    49  			return n, nil
    50  		}
    51  	}
    52  }
    53  
    54  func (x *xerialReader) WriteTo(w io.Writer) (int64, error) {
    55  	wn := int64(0)
    56  
    57  	for {
    58  		for x.offset < int64(len(x.output)) {
    59  			n, err := w.Write(x.output[x.offset:])
    60  			wn += int64(n)
    61  			x.offset += int64(n)
    62  			if err != nil {
    63  				return wn, err
    64  			}
    65  		}
    66  
    67  		if _, err := x.readChunk(nil); err != nil {
    68  			if errors.Is(err, io.EOF) {
    69  				err = nil
    70  			}
    71  			return wn, err
    72  		}
    73  	}
    74  }
    75  
    76  func (x *xerialReader) readChunk(dst []byte) (int, error) {
    77  	x.output = x.output[:0]
    78  	x.offset = 0
    79  	prefix := 0
    80  
    81  	if x.nbytes == 0 {
    82  		n, err := x.readFull(x.header[:])
    83  		if err != nil && n == 0 {
    84  			return 0, err
    85  		}
    86  		prefix = n
    87  	}
    88  
    89  	if isXerialHeader(x.header[:]) {
    90  		if cap(x.input) < 4 {
    91  			x.input = make([]byte, 4, defaultBufferSize)
    92  		} else {
    93  			x.input = x.input[:4]
    94  		}
    95  
    96  		_, err := x.readFull(x.input)
    97  		if err != nil {
    98  			return 0, err
    99  		}
   100  
   101  		frame := int(binary.BigEndian.Uint32(x.input))
   102  		if cap(x.input) < frame {
   103  			x.input = make([]byte, frame, align(frame, defaultBufferSize))
   104  		} else {
   105  			x.input = x.input[:frame]
   106  		}
   107  
   108  		if _, err := x.readFull(x.input); err != nil {
   109  			return 0, err
   110  		}
   111  	} else {
   112  		if cap(x.input) == 0 {
   113  			x.input = make([]byte, 0, defaultBufferSize)
   114  		} else {
   115  			x.input = x.input[:0]
   116  		}
   117  
   118  		if prefix > 0 {
   119  			x.input = append(x.input, x.header[:prefix]...)
   120  		}
   121  
   122  		for {
   123  			if len(x.input) == cap(x.input) {
   124  				b := make([]byte, len(x.input), 2*cap(x.input))
   125  				copy(b, x.input)
   126  				x.input = b
   127  			}
   128  
   129  			n, err := x.read(x.input[len(x.input):cap(x.input)])
   130  			x.input = x.input[:len(x.input)+n]
   131  			if err != nil {
   132  				if errors.Is(err, io.EOF) && len(x.input) > 0 {
   133  					break
   134  				}
   135  				return 0, err
   136  			}
   137  		}
   138  	}
   139  
   140  	var n int
   141  	var err error
   142  
   143  	if x.decode == nil {
   144  		x.output, x.input, err = x.input, x.output, nil
   145  	} else if n, err = snappy.DecodedLen(x.input); n <= len(dst) && err == nil {
   146  		// If the output buffer is large enough to hold the decode value,
   147  		// write it there directly instead of using the intermediary output
   148  		// buffer.
   149  		_, err = x.decode(dst, x.input)
   150  	} else {
   151  		var b []byte
   152  		n = 0
   153  		b, err = x.decode(x.output[:cap(x.output)], x.input)
   154  		if err == nil {
   155  			x.output = b
   156  		}
   157  	}
   158  
   159  	return n, err
   160  }
   161  
   162  func (x *xerialReader) read(b []byte) (int, error) {
   163  	n, err := x.reader.Read(b)
   164  	x.nbytes += int64(n)
   165  	return n, err
   166  }
   167  
   168  func (x *xerialReader) readFull(b []byte) (int, error) {
   169  	n, err := io.ReadFull(x.reader, b)
   170  	x.nbytes += int64(n)
   171  	return n, err
   172  }
   173  
// An implementation of a snappy-encoded output stream with optional xerial
// framing. Written bytes are buffered and only encoded and written out on
// Flush (a framed writer also flushes by itself when its buffer is nearly
// full). When framed is set, the first flush emits the 16-byte xerial header,
// and every flush writes the buffer as one frame: a 4-byte big-endian length
// followed by the encoded bytes.
type xerialWriter struct {
	// writer is the destination of the encoded stream.
	writer io.Writer
	// header is scratch space for the xerial header and frame length prefix.
	header [16]byte
	// input holds the bytes written but not yet flushed.
	input  []byte
	// output is scratch space for the encoded form of input.
	output []byte
	// nbytes counts the bytes written to writer since the last Reset; the
	// xerial header is written while it is still zero.
	nbytes int64
	// framed selects xerial framing.
	framed bool
	// encode is called as encode(dst, src); when nil, buffered bytes are
	// written without encoding.
	encode func([]byte, []byte) []byte
}
   185  
   186  func (x *xerialWriter) Reset(w io.Writer) {
   187  	x.writer = w
   188  	x.input = x.input[:0]
   189  	x.output = x.output[:0]
   190  	x.nbytes = 0
   191  }
   192  
   193  func (x *xerialWriter) ReadFrom(r io.Reader) (int64, error) {
   194  	wn := int64(0)
   195  
   196  	if cap(x.input) == 0 {
   197  		x.input = make([]byte, 0, defaultBufferSize)
   198  	}
   199  
   200  	for {
   201  		if x.full() {
   202  			x.grow()
   203  		}
   204  
   205  		n, err := r.Read(x.input[len(x.input):cap(x.input)])
   206  		wn += int64(n)
   207  		x.input = x.input[:len(x.input)+n]
   208  
   209  		if x.fullEnough() {
   210  			if err := x.Flush(); err != nil {
   211  				return wn, err
   212  			}
   213  		}
   214  
   215  		if err != nil {
   216  			if errors.Is(err, io.EOF) {
   217  				err = nil
   218  			}
   219  			return wn, err
   220  		}
   221  	}
   222  }
   223  
   224  func (x *xerialWriter) Write(b []byte) (int, error) {
   225  	wn := 0
   226  
   227  	if cap(x.input) == 0 {
   228  		x.input = make([]byte, 0, defaultBufferSize)
   229  	}
   230  
   231  	for len(b) > 0 {
   232  		if x.full() {
   233  			x.grow()
   234  		}
   235  
   236  		n := copy(x.input[len(x.input):cap(x.input)], b)
   237  		b = b[n:]
   238  		wn += n
   239  		x.input = x.input[:len(x.input)+n]
   240  
   241  		if x.fullEnough() {
   242  			if err := x.Flush(); err != nil {
   243  				return wn, err
   244  			}
   245  		}
   246  	}
   247  
   248  	return wn, nil
   249  }
   250  
   251  func (x *xerialWriter) Flush() error {
   252  	if len(x.input) == 0 {
   253  		return nil
   254  	}
   255  
   256  	var b []byte
   257  	if x.encode == nil {
   258  		b = x.input
   259  	} else {
   260  		x.output = x.encode(x.output[:cap(x.output)], x.input)
   261  		b = x.output
   262  	}
   263  
   264  	x.input = x.input[:0]
   265  	x.output = x.output[:0]
   266  
   267  	if x.framed && x.nbytes == 0 {
   268  		writeXerialHeader(x.header[:])
   269  		_, err := x.write(x.header[:])
   270  		if err != nil {
   271  			return err
   272  		}
   273  	}
   274  
   275  	if x.framed {
   276  		writeXerialFrame(x.header[:4], len(b))
   277  		_, err := x.write(x.header[:4])
   278  		if err != nil {
   279  			return err
   280  		}
   281  	}
   282  
   283  	_, err := x.write(b)
   284  	return err
   285  }
   286  
   287  func (x *xerialWriter) write(b []byte) (int, error) {
   288  	n, err := x.writer.Write(b)
   289  	x.nbytes += int64(n)
   290  	return n, err
   291  }
   292  
   293  func (x *xerialWriter) full() bool {
   294  	return len(x.input) == cap(x.input)
   295  }
   296  
   297  func (x *xerialWriter) fullEnough() bool {
   298  	return x.framed && (cap(x.input)-len(x.input)) < 1024
   299  }
   300  
   301  func (x *xerialWriter) grow() {
   302  	tmp := make([]byte, len(x.input), 2*cap(x.input))
   303  	copy(tmp, x.input)
   304  	x.input = tmp
   305  }
   306  
   307  func align(n, a int) int {
   308  	if (n % a) == 0 {
   309  		return n
   310  	}
   311  	return ((n / a) + 1) * a
   312  }
   313  
   314  var (
   315  	xerialHeader      = [...]byte{130, 83, 78, 65, 80, 80, 89, 0}
   316  	xerialVersionInfo = [...]byte{0, 0, 0, 1, 0, 0, 0, 1}
   317  )
   318  
   319  func isXerialHeader(src []byte) bool {
   320  	return len(src) >= 16 && bytes.Equal(src[:8], xerialHeader[:])
   321  }
   322  
   323  func writeXerialHeader(b []byte) {
   324  	copy(b[:8], xerialHeader[:])
   325  	copy(b[8:], xerialVersionInfo[:])
   326  }
   327  
   328  func writeXerialFrame(b []byte, n int) {
   329  	binary.BigEndian.PutUint32(b, uint32(n))
   330  }