github.com/streamdal/segmentio-kafka-go@v0.4.47-streamdal/message_reader.go (about)

     1  package kafka
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"log"
     9  )
    10  
    11  type readBytesFunc func(*bufio.Reader, int, int) (int, error)
    12  
    13  // messageSetReader processes the messages encoded into a fetch response.
    14  // The response may contain a mix of Record Batches (newer format) and Messages
    15  // (older format).
    16  type messageSetReader struct {
    17  	*readerStack      // used for decompressing compressed messages and record batches
    18  	empty        bool // if true, short circuits messageSetReader methods
    19  	debug        bool // enable debug log messages
    20  	// How many bytes are expected to remain in the response.
    21  	//
    22  	// This is used to detect truncation of the response.
    23  	lengthRemain int
    24  
    25  	decompressed *bytes.Buffer
    26  }
    27  
    28  type readerStack struct {
    29  	reader *bufio.Reader
    30  	remain int
    31  	base   int64
    32  	parent *readerStack
    33  	count  int            // how many messages left in the current message set
    34  	header messagesHeader // the current header for a subset of messages within the set.
    35  }
    36  
    37  // messagesHeader describes a set of records. there may be many messagesHeader's in a message set.
    38  type messagesHeader struct {
    39  	firstOffset int64
    40  	length      int32
    41  	crc         int32
    42  	magic       int8
    43  	// v1 composes attributes specific to v0 and v1 message headers
    44  	v1 struct {
    45  		attributes int8
    46  		timestamp  int64
    47  	}
    48  	// v2 composes attributes specific to v2 message headers
    49  	v2 struct {
    50  		leaderEpoch     int32
    51  		attributes      int16
    52  		lastOffsetDelta int32
    53  		firstTimestamp  int64
    54  		lastTimestamp   int64
    55  		producerID      int64
    56  		producerEpoch   int16
    57  		baseSequence    int32
    58  		count           int32
    59  	}
    60  }
    61  
    62  func (h messagesHeader) compression() (codec CompressionCodec, err error) {
    63  	const compressionCodecMask = 0x07
    64  	var code int8
    65  	switch h.magic {
    66  	case 0, 1:
    67  		code = h.v1.attributes & compressionCodecMask
    68  	case 2:
    69  		code = int8(h.v2.attributes & compressionCodecMask)
    70  	default:
    71  		err = h.badMagic()
    72  		return
    73  	}
    74  	if code != 0 {
    75  		codec, err = resolveCodec(code)
    76  	}
    77  	return
    78  }
    79  
    80  func (h messagesHeader) badMagic() error {
    81  	return fmt.Errorf("unsupported magic byte %d in header", h.magic)
    82  }
    83  
    84  func newMessageSetReader(reader *bufio.Reader, remain int) (*messageSetReader, error) {
    85  	res := &messageSetReader{
    86  		readerStack: &readerStack{
    87  			reader: reader,
    88  			remain: remain,
    89  		},
    90  		decompressed: acquireBuffer(),
    91  	}
    92  	err := res.readHeader()
    93  	return res, err
    94  }
    95  
    96  func (r *messageSetReader) remaining() (remain int) {
    97  	if r.empty {
    98  		return 0
    99  	}
   100  	for s := r.readerStack; s != nil; s = s.parent {
   101  		remain += s.remain
   102  	}
   103  	return
   104  }
   105  
   106  func (r *messageSetReader) discard() (err error) {
   107  	switch {
   108  	case r.empty:
   109  	case r.readerStack == nil:
   110  	default:
   111  		// rewind up to the top-most reader b/c it's the only one that's doing
   112  		// actual i/o.  the rest are byte buffers that have been pushed on the stack
   113  		// while reading compressed message sets.
   114  		for r.parent != nil {
   115  			r.readerStack = r.parent
   116  		}
   117  		err = r.discardN(r.remain)
   118  	}
   119  	return
   120  }
   121  
   122  func (r *messageSetReader) readMessage(min int64, key readBytesFunc, val readBytesFunc) (
   123  	offset int64, lastOffset int64, timestamp int64, headers []Header, err error) {
   124  
   125  	if r.empty {
   126  		err = RequestTimedOut
   127  		return
   128  	}
   129  	if err = r.readHeader(); err != nil {
   130  		return
   131  	}
   132  	switch r.header.magic {
   133  	case 0, 1:
   134  		offset, timestamp, headers, err = r.readMessageV1(min, key, val)
   135  		// Set an invalid value so that it can be ignored
   136  		lastOffset = -1
   137  	case 2:
   138  		offset, lastOffset, timestamp, headers, err = r.readMessageV2(min, key, val)
   139  	default:
   140  		err = r.header.badMagic()
   141  	}
   142  	return
   143  }
   144  
   145  func (r *messageSetReader) readMessageV1(min int64, key readBytesFunc, val readBytesFunc) (
   146  	offset int64, timestamp int64, headers []Header, err error) {
   147  
   148  	for r.readerStack != nil {
   149  		if r.remain == 0 {
   150  			r.readerStack = r.parent
   151  			continue
   152  		}
   153  		if err = r.readHeader(); err != nil {
   154  			return
   155  		}
   156  		offset = r.header.firstOffset
   157  		timestamp = r.header.v1.timestamp
   158  		var codec CompressionCodec
   159  		if codec, err = r.header.compression(); err != nil {
   160  			return
   161  		}
   162  		if r.debug {
   163  			r.log("Reading with codec=%T", codec)
   164  		}
   165  		if codec != nil {
   166  			// discard next four bytes...will be -1 to indicate null key
   167  			if err = r.discardN(4); err != nil {
   168  				return
   169  			}
   170  
   171  			// read and decompress the contained message set.
   172  			r.decompressed.Reset()
   173  			if err = r.readBytesWith(func(br *bufio.Reader, sz int, n int) (remain int, err error) {
   174  				// x4 as a guess that the average compression ratio is near 75%
   175  				r.decompressed.Grow(4 * n)
   176  				limitReader := io.LimitedReader{R: br, N: int64(n)}
   177  				codecReader := codec.NewReader(&limitReader)
   178  				_, err = r.decompressed.ReadFrom(codecReader)
   179  				remain = sz - (n - int(limitReader.N))
   180  				codecReader.Close()
   181  				return
   182  			}); err != nil {
   183  				return
   184  			}
   185  
   186  			// the compressed message's offset will be equal to the offset of
   187  			// the last message in the set.  within the compressed set, the
   188  			// offsets will be relative, so we have to scan through them to
   189  			// get the base offset.  for example, if there are four compressed
   190  			// messages at offsets 10-13, then the container message will have
   191  			// offset 13 and the contained messages will be 0,1,2,3.  the base
   192  			// offset for the container, then is 13-3=10.
   193  			if offset, err = extractOffset(offset, r.decompressed.Bytes()); err != nil {
   194  				return
   195  			}
   196  
   197  			// mark the outer message as being read
   198  			r.markRead()
   199  
   200  			// then push the decompressed bytes onto the stack.
   201  			r.readerStack = &readerStack{
   202  				// Allocate a buffer of size 0, which gets capped at 16 bytes
   203  				// by the bufio package. We are already reading buffered data
   204  				// here, no need to reserve another 4KB buffer.
   205  				reader: bufio.NewReaderSize(r.decompressed, 0),
   206  				remain: r.decompressed.Len(),
   207  				base:   offset,
   208  				parent: r.readerStack,
   209  			}
   210  			continue
   211  		}
   212  
   213  		// adjust the offset in case we're reading compressed messages.  the
   214  		// base will be zero otherwise.
   215  		offset += r.base
   216  
   217  		// When the messages are compressed kafka may return messages at an
   218  		// earlier offset than the one that was requested, it's the client's
   219  		// responsibility to ignore those.
   220  		//
   221  		// At this point, the message header has been read, so discarding
   222  		// the rest of the message means we have to discard the key, and then
   223  		// the value. Each of those are preceded by a 4-byte length. Discarding
   224  		// them is then reading that length variable and then discarding that
   225  		// amount.
   226  		if offset < min {
   227  			// discard the key
   228  			if err = r.discardBytes(); err != nil {
   229  				return
   230  			}
   231  			// discard the value
   232  			if err = r.discardBytes(); err != nil {
   233  				return
   234  			}
   235  			// since we have fully consumed the message, mark as read
   236  			r.markRead()
   237  			continue
   238  		}
   239  		if err = r.readBytesWith(key); err != nil {
   240  			return
   241  		}
   242  		if err = r.readBytesWith(val); err != nil {
   243  			return
   244  		}
   245  		r.markRead()
   246  		return
   247  	}
   248  	err = errShortRead
   249  	return
   250  }
   251  
   252  func (r *messageSetReader) readMessageV2(_ int64, key readBytesFunc, val readBytesFunc) (
   253  	offset int64, lastOffset int64, timestamp int64, headers []Header, err error) {
   254  	if err = r.readHeader(); err != nil {
   255  		return
   256  	}
   257  	if r.count == int(r.header.v2.count) { // first time reading this set, so check for compression headers.
   258  		var codec CompressionCodec
   259  		if codec, err = r.header.compression(); err != nil {
   260  			return
   261  		}
   262  		if codec != nil {
   263  			batchRemain := int(r.header.length - 49) // TODO: document this magic number
   264  			if batchRemain > r.remain {
   265  				err = errShortRead
   266  				return
   267  			}
   268  			if batchRemain < 0 {
   269  				err = fmt.Errorf("batch remain < 0 (%d)", batchRemain)
   270  				return
   271  			}
   272  			r.decompressed.Reset()
   273  			// x4 as a guess that the average compression ratio is near 75%
   274  			r.decompressed.Grow(4 * batchRemain)
   275  			limitReader := io.LimitedReader{R: r.reader, N: int64(batchRemain)}
   276  			codecReader := codec.NewReader(&limitReader)
   277  			_, err = r.decompressed.ReadFrom(codecReader)
   278  			codecReader.Close()
   279  			if err != nil {
   280  				return
   281  			}
   282  			r.remain -= batchRemain - int(limitReader.N)
   283  			r.readerStack = &readerStack{
   284  				reader: bufio.NewReaderSize(r.decompressed, 0), // the new stack reads from the decompressed buffer
   285  				remain: r.decompressed.Len(),
   286  				base:   -1, // base is unused here
   287  				parent: r.readerStack,
   288  				header: r.header,
   289  				count:  r.count,
   290  			}
   291  			// all of the messages in this set are in the decompressed set just pushed onto the reader
   292  			// stack. here we set the parent count to 0 so that when the child set is exhausted, the
   293  			// reader will then try to read the header of the next message set
   294  			r.readerStack.parent.count = 0
   295  		}
   296  	}
   297  	remainBefore := r.remain
   298  	var length int64
   299  	if err = r.readVarInt(&length); err != nil {
   300  		return
   301  	}
   302  	lengthOfLength := remainBefore - r.remain
   303  	var attrs int8
   304  	if err = r.readInt8(&attrs); err != nil {
   305  		return
   306  	}
   307  	var timestampDelta int64
   308  	if err = r.readVarInt(&timestampDelta); err != nil {
   309  		return
   310  	}
   311  	timestamp = r.header.v2.firstTimestamp + timestampDelta
   312  	var offsetDelta int64
   313  	if err = r.readVarInt(&offsetDelta); err != nil {
   314  		return
   315  	}
   316  	offset = r.header.firstOffset + offsetDelta
   317  	if err = r.runFunc(key); err != nil {
   318  		return
   319  	}
   320  	if err = r.runFunc(val); err != nil {
   321  		return
   322  	}
   323  	var headerCount int64
   324  	if err = r.readVarInt(&headerCount); err != nil {
   325  		return
   326  	}
   327  	if headerCount > 0 {
   328  		headers = make([]Header, headerCount)
   329  		for i := range headers {
   330  			if err = r.readMessageHeader(&headers[i]); err != nil {
   331  				return
   332  			}
   333  		}
   334  	}
   335  	lastOffset = r.header.firstOffset + int64(r.header.v2.lastOffsetDelta)
   336  	r.lengthRemain -= int(length) + lengthOfLength
   337  	r.markRead()
   338  	return
   339  }
   340  
   341  func (r *messageSetReader) discardBytes() (err error) {
   342  	r.remain, err = discardBytes(r.reader, r.remain)
   343  	return
   344  }
   345  
   346  func (r *messageSetReader) discardN(sz int) (err error) {
   347  	r.remain, err = discardN(r.reader, r.remain, sz)
   348  	return
   349  }
   350  
   351  func (r *messageSetReader) markRead() {
   352  	if r.count == 0 {
   353  		panic("markRead: negative count")
   354  	}
   355  	r.count--
   356  	r.unwindStack()
   357  	if r.debug {
   358  		r.log("Mark read remain=%d", r.remain)
   359  	}
   360  }
   361  
   362  func (r *messageSetReader) unwindStack() {
   363  	for r.count == 0 {
   364  		if r.remain == 0 {
   365  			if r.parent != nil {
   366  				if r.debug {
   367  					r.log("Popped reader stack")
   368  				}
   369  				r.readerStack = r.parent
   370  				continue
   371  			}
   372  		}
   373  		break
   374  	}
   375  }
   376  
   377  func (r *messageSetReader) readMessageHeader(header *Header) (err error) {
   378  	var keyLen int64
   379  	if err = r.readVarInt(&keyLen); err != nil {
   380  		return
   381  	}
   382  	if header.Key, err = r.readNewString(int(keyLen)); err != nil {
   383  		return
   384  	}
   385  	var valLen int64
   386  	if err = r.readVarInt(&valLen); err != nil {
   387  		return
   388  	}
   389  	if header.Value, err = r.readNewBytes(int(valLen)); err != nil {
   390  		return
   391  	}
   392  	return nil
   393  }
   394  
   395  func (r *messageSetReader) runFunc(rbFunc readBytesFunc) (err error) {
   396  	var length int64
   397  	if err = r.readVarInt(&length); err != nil {
   398  		return
   399  	}
   400  	if r.remain, err = rbFunc(r.reader, r.remain, int(length)); err != nil {
   401  		return
   402  	}
   403  	return
   404  }
   405  
   406  func (r *messageSetReader) readHeader() (err error) {
   407  	if r.count > 0 {
   408  		// currently reading a set of messages, no need to read a header until they are exhausted.
   409  		return
   410  	}
   411  	r.header = messagesHeader{}
   412  	if err = r.readInt64(&r.header.firstOffset); err != nil {
   413  		return
   414  	}
   415  	if err = r.readInt32(&r.header.length); err != nil {
   416  		return
   417  	}
   418  	var crcOrLeaderEpoch int32
   419  	if err = r.readInt32(&crcOrLeaderEpoch); err != nil {
   420  		return
   421  	}
   422  	if err = r.readInt8(&r.header.magic); err != nil {
   423  		return
   424  	}
   425  	switch r.header.magic {
   426  	case 0:
   427  		r.header.crc = crcOrLeaderEpoch
   428  		if err = r.readInt8(&r.header.v1.attributes); err != nil {
   429  			return
   430  		}
   431  		r.count = 1
   432  		// Set arbitrary non-zero length so that we always assume the
   433  		// message is truncated since bytes remain.
   434  		r.lengthRemain = 1
   435  		if r.debug {
   436  			r.log("Read v0 header with offset=%d len=%d magic=%d attributes=%d", r.header.firstOffset, r.header.length, r.header.magic, r.header.v1.attributes)
   437  		}
   438  	case 1:
   439  		r.header.crc = crcOrLeaderEpoch
   440  		if err = r.readInt8(&r.header.v1.attributes); err != nil {
   441  			return
   442  		}
   443  		if err = r.readInt64(&r.header.v1.timestamp); err != nil {
   444  			return
   445  		}
   446  		r.count = 1
   447  		// Set arbitrary non-zero length so that we always assume the
   448  		// message is truncated since bytes remain.
   449  		r.lengthRemain = 1
   450  		if r.debug {
   451  			r.log("Read v1 header with remain=%d offset=%d magic=%d and attributes=%d", r.remain, r.header.firstOffset, r.header.magic, r.header.v1.attributes)
   452  		}
   453  	case 2:
   454  		r.header.v2.leaderEpoch = crcOrLeaderEpoch
   455  		if err = r.readInt32(&r.header.crc); err != nil {
   456  			return
   457  		}
   458  		if err = r.readInt16(&r.header.v2.attributes); err != nil {
   459  			return
   460  		}
   461  		if err = r.readInt32(&r.header.v2.lastOffsetDelta); err != nil {
   462  			return
   463  		}
   464  		if err = r.readInt64(&r.header.v2.firstTimestamp); err != nil {
   465  			return
   466  		}
   467  		if err = r.readInt64(&r.header.v2.lastTimestamp); err != nil {
   468  			return
   469  		}
   470  		if err = r.readInt64(&r.header.v2.producerID); err != nil {
   471  			return
   472  		}
   473  		if err = r.readInt16(&r.header.v2.producerEpoch); err != nil {
   474  			return
   475  		}
   476  		if err = r.readInt32(&r.header.v2.baseSequence); err != nil {
   477  			return
   478  		}
   479  		if err = r.readInt32(&r.header.v2.count); err != nil {
   480  			return
   481  		}
   482  		r.count = int(r.header.v2.count)
   483  		// Subtracts the header bytes from the length
   484  		r.lengthRemain = int(r.header.length) - 49
   485  		if r.debug {
   486  			r.log("Read v2 header with count=%d offset=%d len=%d magic=%d attributes=%d", r.count, r.header.firstOffset, r.header.length, r.header.magic, r.header.v2.attributes)
   487  		}
   488  	default:
   489  		err = r.header.badMagic()
   490  		return
   491  	}
   492  	return
   493  }
   494  
   495  func (r *messageSetReader) readNewBytes(len int) (res []byte, err error) {
   496  	res, r.remain, err = readNewBytes(r.reader, r.remain, len)
   497  	return
   498  }
   499  
   500  func (r *messageSetReader) readNewString(len int) (res string, err error) {
   501  	res, r.remain, err = readNewString(r.reader, r.remain, len)
   502  	return
   503  }
   504  
   505  func (r *messageSetReader) readInt8(val *int8) (err error) {
   506  	r.remain, err = readInt8(r.reader, r.remain, val)
   507  	return
   508  }
   509  
   510  func (r *messageSetReader) readInt16(val *int16) (err error) {
   511  	r.remain, err = readInt16(r.reader, r.remain, val)
   512  	return
   513  }
   514  
   515  func (r *messageSetReader) readInt32(val *int32) (err error) {
   516  	r.remain, err = readInt32(r.reader, r.remain, val)
   517  	return
   518  }
   519  
   520  func (r *messageSetReader) readInt64(val *int64) (err error) {
   521  	r.remain, err = readInt64(r.reader, r.remain, val)
   522  	return
   523  }
   524  
   525  func (r *messageSetReader) readVarInt(val *int64) (err error) {
   526  	r.remain, err = readVarInt(r.reader, r.remain, val)
   527  	return
   528  }
   529  
   530  func (r *messageSetReader) readBytesWith(fn readBytesFunc) (err error) {
   531  	r.remain, err = readBytesWith(r.reader, r.remain, fn)
   532  	return
   533  }
   534  
   535  func (r *messageSetReader) log(msg string, args ...interface{}) {
   536  	log.Printf("[DEBUG] "+msg, args...)
   537  }
   538  
   539  func extractOffset(base int64, msgSet []byte) (offset int64, err error) {
   540  	r, remain := bufio.NewReader(bytes.NewReader(msgSet)), len(msgSet)
   541  	for remain > 0 {
   542  		if remain, err = readInt64(r, remain, &offset); err != nil {
   543  			return
   544  		}
   545  		var sz int32
   546  		if remain, err = readInt32(r, remain, &sz); err != nil {
   547  			return
   548  		}
   549  		if remain, err = discardN(r, remain, int(sz)); err != nil {
   550  			return
   551  		}
   552  	}
   553  	offset = base - offset
   554  	return
   555  }