github.com/segmentio/kafka-go@v0.4.48-0.20240318174348-3f6244eb34fd/message_reader.go (about)

     1  package kafka
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"log"
     9  )
    10  
// readBytesFunc consumes a length-prefixed byte field from a reader.
//
// As invoked in this file, the arguments are the reader to consume from, the
// number of bytes that remain in the enclosing payload, and the size in bytes
// of the field. The returned int is stored back as the new remaining count.
type readBytesFunc func(*bufio.Reader, int, int) (int, error)
    12  
// messageSetReader processes the messages encoded into a fetch response.
// The response may contain a mix of Record Batches (newer format) and Messages
// (older format).
//
// The embedded readerStack is the level currently being read: decompressing a
// message set pushes a new level on top of it, and exhausting a level pops
// back to its parent.
type messageSetReader struct {
	*readerStack      // used for decompressing compressed messages and record batches
	empty        bool // if true, short circuits messageSetReader methods
	debug        bool // enable debug log messages
	// How many bytes are expected to remain in the response.
	//
	// This is used to detect truncation of the response.
	lengthRemain int

	// Scratch buffer that receives decompressed message sets. It is obtained
	// from acquireBuffer when the reader is constructed and is reset before
	// every decompression.
	decompressed *bytes.Buffer
}
    27  
// readerStack is one level of the stack of readers used while decoding a
// message set. Levels are linked through parent; the level whose parent is nil
// is the one handed to newMessageSetReader, the others read from decompressed
// bytes.
type readerStack struct {
	// reader is the source of bytes for this level.
	reader *bufio.Reader
	// remain is the number of bytes left to consume from reader.
	remain int
	// base is added to the offsets of v0/v1 messages read at this level. It
	// is left at zero for uncompressed data and set to -1 (unused) for v2.
	base int64
	// parent is the level to return to once this one is exhausted.
	parent *readerStack
	count  int            // how many messages left in the current message set
	header messagesHeader // the current header for a subset of messages within the set.
}
    36  
// messagesHeader describes a set of records. there may be many messagesHeader's in a message set.
//
// The fields shared by every format sit at the top level; the v1 and v2
// sub-structs hold the fields that readHeader only populates for the matching
// magic value.
type messagesHeader struct {
	firstOffset int64
	length      int32
	crc         int32
	// magic selects the message format: 0 and 1 use the v1 fields, 2 uses the
	// v2 fields, and any other value is rejected by badMagic.
	magic int8
	// v1 composes attributes specific to v0 and v1 message headers
	v1 struct {
		attributes int8
		timestamp  int64
	}
	// v2 composes attributes specific to v2 message headers
	v2 struct {
		leaderEpoch     int32
		attributes      int16
		lastOffsetDelta int32
		firstTimestamp  int64
		lastTimestamp   int64
		producerID      int64
		producerEpoch   int16
		baseSequence    int32
		count           int32
	}
}
    61  
    62  func (h messagesHeader) compression() (codec CompressionCodec, err error) {
    63  	const compressionCodecMask = 0x07
    64  	var code int8
    65  	switch h.magic {
    66  	case 0, 1:
    67  		code = h.v1.attributes & compressionCodecMask
    68  	case 2:
    69  		code = int8(h.v2.attributes & compressionCodecMask)
    70  	default:
    71  		err = h.badMagic()
    72  		return
    73  	}
    74  	if code != 0 {
    75  		codec, err = resolveCodec(code)
    76  	}
    77  	return
    78  }
    79  
    80  func (h messagesHeader) badMagic() error {
    81  	return fmt.Errorf("unsupported magic byte %d in header", h.magic)
    82  }
    83  
    84  func newMessageSetReader(reader *bufio.Reader, remain int) (*messageSetReader, error) {
    85  	res := &messageSetReader{
    86  		readerStack: &readerStack{
    87  			reader: reader,
    88  			remain: remain,
    89  		},
    90  		decompressed: acquireBuffer(),
    91  	}
    92  	err := res.readHeader()
    93  	return res, err
    94  }
    95  
    96  func (r *messageSetReader) remaining() (remain int) {
    97  	if r.empty {
    98  		return 0
    99  	}
   100  	for s := r.readerStack; s != nil; s = s.parent {
   101  		remain += s.remain
   102  	}
   103  	return
   104  }
   105  
   106  func (r *messageSetReader) discard() (err error) {
   107  	switch {
   108  	case r.empty:
   109  	case r.readerStack == nil:
   110  	default:
   111  		// rewind up to the top-most reader b/c it's the only one that's doing
   112  		// actual i/o.  the rest are byte buffers that have been pushed on the stack
   113  		// while reading compressed message sets.
   114  		for r.parent != nil {
   115  			r.readerStack = r.parent
   116  		}
   117  		err = r.discardN(r.remain)
   118  	}
   119  	return
   120  }
   121  
   122  func (r *messageSetReader) readMessage(min int64, key readBytesFunc, val readBytesFunc) (
   123  	offset int64, lastOffset int64, timestamp int64, headers []Header, err error) {
   124  
   125  	if r.empty {
   126  		err = RequestTimedOut
   127  		return
   128  	}
   129  	if err = r.readHeader(); err != nil {
   130  		return
   131  	}
   132  	switch r.header.magic {
   133  	case 0, 1:
   134  		offset, timestamp, headers, err = r.readMessageV1(min, key, val)
   135  		// Set an invalid value so that it can be ignored
   136  		lastOffset = -1
   137  	case 2:
   138  		offset, lastOffset, timestamp, headers, err = r.readMessageV2(min, key, val)
   139  	default:
   140  		err = r.header.badMagic()
   141  	}
   142  	return
   143  }
   144  
   145  func (r *messageSetReader) readMessageV1(min int64, key readBytesFunc, val readBytesFunc) (
   146  	offset int64, timestamp int64, headers []Header, err error) {
   147  
   148  	for r.readerStack != nil {
   149  		if r.remain == 0 {
   150  			r.readerStack = r.parent
   151  			continue
   152  		}
   153  		if err = r.readHeader(); err != nil {
   154  			return
   155  		}
   156  		offset = r.header.firstOffset
   157  		timestamp = r.header.v1.timestamp
   158  		var codec CompressionCodec
   159  		if codec, err = r.header.compression(); err != nil {
   160  			return
   161  		}
   162  		if r.debug {
   163  			r.log("Reading with codec=%T", codec)
   164  		}
   165  		if codec != nil {
   166  			fmt.Printf("codec: %s\n", codec.Name())
   167  			// discard next four bytes...will be -1 to indicate null key
   168  			if err = r.discardN(4); err != nil {
   169  				return
   170  			}
   171  
   172  			// read and decompress the contained message set.
   173  			r.decompressed.Reset()
   174  			if err = r.readBytesWith(func(br *bufio.Reader, sz int, n int) (remain int, err error) {
   175  				// x4 as a guess that the average compression ratio is near 75%
   176  				r.decompressed.Grow(4 * n)
   177  				limitReader := io.LimitedReader{R: br, N: int64(n)}
   178  				codecReader := codec.NewReader(&limitReader)
   179  				_, err = r.decompressed.ReadFrom(codecReader)
   180  				remain = sz - (n - int(limitReader.N))
   181  				codecReader.Close()
   182  				return
   183  			}); err != nil {
   184  				return
   185  			}
   186  
   187  			// the compressed message's offset will be equal to the offset of
   188  			// the last message in the set.  within the compressed set, the
   189  			// offsets will be relative, so we have to scan through them to
   190  			// get the base offset.  for example, if there are four compressed
   191  			// messages at offsets 10-13, then the container message will have
   192  			// offset 13 and the contained messages will be 0,1,2,3.  the base
   193  			// offset for the container, then is 13-3=10.
   194  			if offset, err = extractOffset(offset, r.decompressed.Bytes()); err != nil {
   195  				return
   196  			}
   197  
   198  			// mark the outer message as being read
   199  			r.markRead()
   200  
   201  			// then push the decompressed bytes onto the stack.
   202  			r.readerStack = &readerStack{
   203  				// Allocate a buffer of size 0, which gets capped at 16 bytes
   204  				// by the bufio package. We are already reading buffered data
   205  				// here, no need to reserve another 4KB buffer.
   206  				reader: bufio.NewReaderSize(r.decompressed, 0),
   207  				remain: r.decompressed.Len(),
   208  				base:   offset,
   209  				parent: r.readerStack,
   210  			}
   211  			continue
   212  		}
   213  
   214  		// adjust the offset in case we're reading compressed messages.  the
   215  		// base will be zero otherwise.
   216  		offset += r.base
   217  
   218  		// When the messages are compressed kafka may return messages at an
   219  		// earlier offset than the one that was requested, it's the client's
   220  		// responsibility to ignore those.
   221  		//
   222  		// At this point, the message header has been read, so discarding
   223  		// the rest of the message means we have to discard the key, and then
   224  		// the value. Each of those are preceded by a 4-byte length. Discarding
   225  		// them is then reading that length variable and then discarding that
   226  		// amount.
   227  		if offset < min {
   228  			// discard the key
   229  			if err = r.discardBytes(); err != nil {
   230  				return
   231  			}
   232  			// discard the value
   233  			if err = r.discardBytes(); err != nil {
   234  				return
   235  			}
   236  			// since we have fully consumed the message, mark as read
   237  			r.markRead()
   238  			continue
   239  		}
   240  		if err = r.readBytesWith(key); err != nil {
   241  			return
   242  		}
   243  		if err = r.readBytesWith(val); err != nil {
   244  			return
   245  		}
   246  		r.markRead()
   247  		return
   248  	}
   249  	err = errShortRead
   250  	return
   251  }
   252  
   253  func (r *messageSetReader) readMessageV2(_ int64, key readBytesFunc, val readBytesFunc) (
   254  	offset int64, lastOffset int64, timestamp int64, headers []Header, err error) {
   255  	if err = r.readHeader(); err != nil {
   256  		return
   257  	}
   258  	if r.count == int(r.header.v2.count) { // first time reading this set, so check for compression headers.
   259  		var codec CompressionCodec
   260  		if codec, err = r.header.compression(); err != nil {
   261  			return
   262  		}
   263  		if codec != nil {
   264  			fmt.Printf("codec: %s\n", codec.Name())
   265  			batchRemain := int(r.header.length - 49) // TODO: document this magic number
   266  			if batchRemain > r.remain {
   267  				err = errShortRead
   268  				return
   269  			}
   270  			if batchRemain < 0 {
   271  				err = fmt.Errorf("batch remain < 0 (%d)", batchRemain)
   272  				return
   273  			}
   274  			r.decompressed.Reset()
   275  			// x4 as a guess that the average compression ratio is near 75%
   276  			r.decompressed.Grow(4 * batchRemain)
   277  			limitReader := io.LimitedReader{R: r.reader, N: int64(batchRemain)}
   278  			codecReader := codec.NewReader(&limitReader)
   279  			_, err = r.decompressed.ReadFrom(codecReader)
   280  			codecReader.Close()
   281  			if err != nil {
   282  				return
   283  			}
   284  			r.remain -= batchRemain - int(limitReader.N)
   285  			r.readerStack = &readerStack{
   286  				reader: bufio.NewReaderSize(r.decompressed, 0), // the new stack reads from the decompressed buffer
   287  				remain: r.decompressed.Len(),
   288  				base:   -1, // base is unused here
   289  				parent: r.readerStack,
   290  				header: r.header,
   291  				count:  r.count,
   292  			}
   293  			// all of the messages in this set are in the decompressed set just pushed onto the reader
   294  			// stack. here we set the parent count to 0 so that when the child set is exhausted, the
   295  			// reader will then try to read the header of the next message set
   296  			r.readerStack.parent.count = 0
   297  		}
   298  	}
   299  	remainBefore := r.remain
   300  	var length int64
   301  	if err = r.readVarInt(&length); err != nil {
   302  		return
   303  	}
   304  	lengthOfLength := remainBefore - r.remain
   305  	var attrs int8
   306  	if err = r.readInt8(&attrs); err != nil {
   307  		return
   308  	}
   309  	var timestampDelta int64
   310  	if err = r.readVarInt(&timestampDelta); err != nil {
   311  		return
   312  	}
   313  	timestamp = r.header.v2.firstTimestamp + timestampDelta
   314  	var offsetDelta int64
   315  	if err = r.readVarInt(&offsetDelta); err != nil {
   316  		return
   317  	}
   318  	offset = r.header.firstOffset + offsetDelta
   319  	if err = r.runFunc(key); err != nil {
   320  		return
   321  	}
   322  	if err = r.runFunc(val); err != nil {
   323  		return
   324  	}
   325  	var headerCount int64
   326  	if err = r.readVarInt(&headerCount); err != nil {
   327  		return
   328  	}
   329  	if headerCount > 0 {
   330  		headers = make([]Header, headerCount)
   331  		for i := range headers {
   332  			if err = r.readMessageHeader(&headers[i]); err != nil {
   333  				return
   334  			}
   335  		}
   336  	}
   337  	lastOffset = r.header.firstOffset + int64(r.header.v2.lastOffsetDelta)
   338  	r.lengthRemain -= int(length) + lengthOfLength
   339  	r.markRead()
   340  	return
   341  }
   342  
   343  func (r *messageSetReader) discardBytes() (err error) {
   344  	r.remain, err = discardBytes(r.reader, r.remain)
   345  	return
   346  }
   347  
   348  func (r *messageSetReader) discardN(sz int) (err error) {
   349  	r.remain, err = discardN(r.reader, r.remain, sz)
   350  	return
   351  }
   352  
   353  func (r *messageSetReader) markRead() {
   354  	if r.count == 0 {
   355  		panic("markRead: negative count")
   356  	}
   357  	r.count--
   358  	r.unwindStack()
   359  	if r.debug {
   360  		r.log("Mark read remain=%d", r.remain)
   361  	}
   362  }
   363  
   364  func (r *messageSetReader) unwindStack() {
   365  	for r.count == 0 {
   366  		if r.remain == 0 {
   367  			if r.parent != nil {
   368  				if r.debug {
   369  					r.log("Popped reader stack")
   370  				}
   371  				r.readerStack = r.parent
   372  				continue
   373  			}
   374  		}
   375  		break
   376  	}
   377  }
   378  
   379  func (r *messageSetReader) readMessageHeader(header *Header) (err error) {
   380  	var keyLen int64
   381  	if err = r.readVarInt(&keyLen); err != nil {
   382  		return
   383  	}
   384  	if header.Key, err = r.readNewString(int(keyLen)); err != nil {
   385  		return
   386  	}
   387  	var valLen int64
   388  	if err = r.readVarInt(&valLen); err != nil {
   389  		return
   390  	}
   391  	if header.Value, err = r.readNewBytes(int(valLen)); err != nil {
   392  		return
   393  	}
   394  	return nil
   395  }
   396  
   397  func (r *messageSetReader) runFunc(rbFunc readBytesFunc) (err error) {
   398  	var length int64
   399  	if err = r.readVarInt(&length); err != nil {
   400  		return
   401  	}
   402  	if r.remain, err = rbFunc(r.reader, r.remain, int(length)); err != nil {
   403  		return
   404  	}
   405  	return
   406  }
   407  
   408  func (r *messageSetReader) readHeader() (err error) {
   409  	if r.count > 0 {
   410  		// currently reading a set of messages, no need to read a header until they are exhausted.
   411  		return
   412  	}
   413  	r.header = messagesHeader{}
   414  	if err = r.readInt64(&r.header.firstOffset); err != nil {
   415  		return
   416  	}
   417  	if err = r.readInt32(&r.header.length); err != nil {
   418  		return
   419  	}
   420  	var crcOrLeaderEpoch int32
   421  	if err = r.readInt32(&crcOrLeaderEpoch); err != nil {
   422  		return
   423  	}
   424  	if err = r.readInt8(&r.header.magic); err != nil {
   425  		return
   426  	}
   427  	switch r.header.magic {
   428  	case 0:
   429  		r.header.crc = crcOrLeaderEpoch
   430  		if err = r.readInt8(&r.header.v1.attributes); err != nil {
   431  			return
   432  		}
   433  		r.count = 1
   434  		// Set arbitrary non-zero length so that we always assume the
   435  		// message is truncated since bytes remain.
   436  		r.lengthRemain = 1
   437  		if r.debug {
   438  			r.log("Read v0 header with offset=%d len=%d magic=%d attributes=%d", r.header.firstOffset, r.header.length, r.header.magic, r.header.v1.attributes)
   439  		}
   440  	case 1:
   441  		r.header.crc = crcOrLeaderEpoch
   442  		if err = r.readInt8(&r.header.v1.attributes); err != nil {
   443  			return
   444  		}
   445  		if err = r.readInt64(&r.header.v1.timestamp); err != nil {
   446  			return
   447  		}
   448  		r.count = 1
   449  		// Set arbitrary non-zero length so that we always assume the
   450  		// message is truncated since bytes remain.
   451  		r.lengthRemain = 1
   452  		if r.debug {
   453  			r.log("Read v1 header with remain=%d offset=%d magic=%d and attributes=%d", r.remain, r.header.firstOffset, r.header.magic, r.header.v1.attributes)
   454  		}
   455  	case 2:
   456  		r.header.v2.leaderEpoch = crcOrLeaderEpoch
   457  		if err = r.readInt32(&r.header.crc); err != nil {
   458  			return
   459  		}
   460  		if err = r.readInt16(&r.header.v2.attributes); err != nil {
   461  			return
   462  		}
   463  		if err = r.readInt32(&r.header.v2.lastOffsetDelta); err != nil {
   464  			return
   465  		}
   466  		if err = r.readInt64(&r.header.v2.firstTimestamp); err != nil {
   467  			return
   468  		}
   469  		if err = r.readInt64(&r.header.v2.lastTimestamp); err != nil {
   470  			return
   471  		}
   472  		if err = r.readInt64(&r.header.v2.producerID); err != nil {
   473  			return
   474  		}
   475  		if err = r.readInt16(&r.header.v2.producerEpoch); err != nil {
   476  			return
   477  		}
   478  		if err = r.readInt32(&r.header.v2.baseSequence); err != nil {
   479  			return
   480  		}
   481  		if err = r.readInt32(&r.header.v2.count); err != nil {
   482  			return
   483  		}
   484  		r.count = int(r.header.v2.count)
   485  		// Subtracts the header bytes from the length
   486  		r.lengthRemain = int(r.header.length) - 49
   487  		if r.debug {
   488  			r.log("Read v2 header with count=%d offset=%d len=%d magic=%d attributes=%d", r.count, r.header.firstOffset, r.header.length, r.header.magic, r.header.v2.attributes)
   489  		}
   490  	default:
   491  		err = r.header.badMagic()
   492  		return
   493  	}
   494  	return
   495  }
   496  
   497  func (r *messageSetReader) readNewBytes(len int) (res []byte, err error) {
   498  	res, r.remain, err = readNewBytes(r.reader, r.remain, len)
   499  	return
   500  }
   501  
   502  func (r *messageSetReader) readNewString(len int) (res string, err error) {
   503  	res, r.remain, err = readNewString(r.reader, r.remain, len)
   504  	return
   505  }
   506  
   507  func (r *messageSetReader) readInt8(val *int8) (err error) {
   508  	r.remain, err = readInt8(r.reader, r.remain, val)
   509  	return
   510  }
   511  
   512  func (r *messageSetReader) readInt16(val *int16) (err error) {
   513  	r.remain, err = readInt16(r.reader, r.remain, val)
   514  	return
   515  }
   516  
   517  func (r *messageSetReader) readInt32(val *int32) (err error) {
   518  	r.remain, err = readInt32(r.reader, r.remain, val)
   519  	return
   520  }
   521  
   522  func (r *messageSetReader) readInt64(val *int64) (err error) {
   523  	r.remain, err = readInt64(r.reader, r.remain, val)
   524  	return
   525  }
   526  
   527  func (r *messageSetReader) readVarInt(val *int64) (err error) {
   528  	r.remain, err = readVarInt(r.reader, r.remain, val)
   529  	return
   530  }
   531  
   532  func (r *messageSetReader) readBytesWith(fn readBytesFunc) (err error) {
   533  	r.remain, err = readBytesWith(r.reader, r.remain, fn)
   534  	return
   535  }
   536  
   537  func (r *messageSetReader) log(msg string, args ...interface{}) {
   538  	log.Printf("[DEBUG] "+msg, args...)
   539  }
   540  
   541  func extractOffset(base int64, msgSet []byte) (offset int64, err error) {
   542  	r, remain := bufio.NewReader(bytes.NewReader(msgSet)), len(msgSet)
   543  	for remain > 0 {
   544  		if remain, err = readInt64(r, remain, &offset); err != nil {
   545  			return
   546  		}
   547  		var sz int32
   548  		if remain, err = readInt32(r, remain, &sz); err != nil {
   549  			return
   550  		}
   551  		if remain, err = discardN(r, remain, int(sz)); err != nil {
   552  			return
   553  		}
   554  	}
   555  	offset = base - offset
   556  	return
   557  }