github.com/segmentio/kafka-go@v0.4.48-0.20240318174348-3f6244eb34fd/protocol/decode.go (about)

     1  package protocol
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"hash/crc32"
     8  	"io"
     9  	"io/ioutil"
    10  	"math"
    11  	"reflect"
    12  	"sync"
    13  	"sync/atomic"
    14  )
    15  
    16  type discarder interface {
    17  	Discard(int) (int, error)
    18  }
    19  
    20  type decoder struct {
    21  	reader io.Reader
    22  	remain int
    23  	buffer [8]byte
    24  	err    error
    25  	table  *crc32.Table
    26  	crc32  uint32
    27  }
    28  
    29  func (d *decoder) Reset(r io.Reader, n int) {
    30  	d.reader = r
    31  	d.remain = n
    32  	d.buffer = [8]byte{}
    33  	d.err = nil
    34  	d.table = nil
    35  	d.crc32 = 0
    36  }
    37  
    38  func (d *decoder) Read(b []byte) (int, error) {
    39  	if d.err != nil {
    40  		return 0, d.err
    41  	}
    42  	if d.remain == 0 {
    43  		return 0, io.EOF
    44  	}
    45  	if len(b) > d.remain {
    46  		b = b[:d.remain]
    47  	}
    48  	n, err := d.reader.Read(b)
    49  	if n > 0 && d.table != nil {
    50  		d.crc32 = crc32.Update(d.crc32, d.table, b[:n])
    51  	}
    52  	d.remain -= n
    53  	return n, err
    54  }
    55  
    56  func (d *decoder) ReadByte() (byte, error) {
    57  	c := d.readByte()
    58  	return c, d.err
    59  }
    60  
    61  func (d *decoder) done() bool {
    62  	return d.remain == 0 || d.err != nil
    63  }
    64  
    65  func (d *decoder) setCRC(table *crc32.Table) {
    66  	d.table, d.crc32 = table, 0
    67  }
    68  
    69  func (d *decoder) decodeBool(v value) {
    70  	v.setBool(d.readBool())
    71  }
    72  
    73  func (d *decoder) decodeInt8(v value) {
    74  	v.setInt8(d.readInt8())
    75  }
    76  
    77  func (d *decoder) decodeInt16(v value) {
    78  	v.setInt16(d.readInt16())
    79  }
    80  
    81  func (d *decoder) decodeInt32(v value) {
    82  	v.setInt32(d.readInt32())
    83  }
    84  
    85  func (d *decoder) decodeInt64(v value) {
    86  	v.setInt64(d.readInt64())
    87  }
    88  
    89  func (d *decoder) decodeFloat64(v value) {
    90  	v.setFloat64(d.readFloat64())
    91  }
    92  
    93  func (d *decoder) decodeString(v value) {
    94  	v.setString(d.readString())
    95  }
    96  
    97  func (d *decoder) decodeCompactString(v value) {
    98  	v.setString(d.readCompactString())
    99  }
   100  
   101  func (d *decoder) decodeBytes(v value) {
   102  	v.setBytes(d.readBytes())
   103  }
   104  
   105  func (d *decoder) decodeCompactBytes(v value) {
   106  	v.setBytes(d.readCompactBytes())
   107  }
   108  
   109  func (d *decoder) decodeArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
   110  	if n := d.readInt32(); n < 0 {
   111  		v.setArray(array{})
   112  	} else {
   113  		a := makeArray(elemType, int(n))
   114  		for i := 0; i < int(n) && d.remain > 0; i++ {
   115  			decodeElem(d, a.index(i))
   116  		}
   117  		v.setArray(a)
   118  	}
   119  }
   120  
   121  func (d *decoder) decodeCompactArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
   122  	if n := d.readUnsignedVarInt(); n < 1 {
   123  		v.setArray(array{})
   124  	} else {
   125  		a := makeArray(elemType, int(n-1))
   126  		for i := 0; i < int(n-1) && d.remain > 0; i++ {
   127  			decodeElem(d, a.index(i))
   128  		}
   129  		v.setArray(a)
   130  	}
   131  }
   132  
   133  func (d *decoder) discardAll() {
   134  	d.discard(d.remain)
   135  }
   136  
   137  func (d *decoder) discard(n int) {
   138  	if n > d.remain {
   139  		n = d.remain
   140  	}
   141  	var err error
   142  	if r, _ := d.reader.(discarder); r != nil {
   143  		n, err = r.Discard(n)
   144  		d.remain -= n
   145  	} else {
   146  		_, err = io.Copy(ioutil.Discard, d)
   147  	}
   148  	d.setError(err)
   149  }
   150  
   151  func (d *decoder) read(n int) []byte {
   152  	b := make([]byte, n)
   153  	n, err := io.ReadFull(d, b)
   154  	b = b[:n]
   155  	d.setError(err)
   156  	return b
   157  }
   158  
   159  func (d *decoder) writeTo(w io.Writer, n int) {
   160  	limit := d.remain
   161  	if n < limit {
   162  		d.remain = n
   163  	}
   164  	c, err := io.Copy(w, d)
   165  	if int(c) < n && err == nil {
   166  		err = io.ErrUnexpectedEOF
   167  	}
   168  	d.remain = limit - int(c)
   169  	d.setError(err)
   170  }
   171  
   172  func (d *decoder) setError(err error) {
   173  	if d.err == nil && err != nil {
   174  		d.err = err
   175  		d.discardAll()
   176  	}
   177  }
   178  
   179  func (d *decoder) readFull(b []byte) bool {
   180  	n, err := io.ReadFull(d, b)
   181  	d.setError(err)
   182  	return n == len(b)
   183  }
   184  
   185  func (d *decoder) readByte() byte {
   186  	if d.readFull(d.buffer[:1]) {
   187  		return d.buffer[0]
   188  	}
   189  	return 0
   190  }
   191  
   192  func (d *decoder) readBool() bool {
   193  	return d.readByte() != 0
   194  }
   195  
   196  func (d *decoder) readInt8() int8 {
   197  	if d.readFull(d.buffer[:1]) {
   198  		return readInt8(d.buffer[:1])
   199  	}
   200  	return 0
   201  }
   202  
   203  func (d *decoder) readInt16() int16 {
   204  	if d.readFull(d.buffer[:2]) {
   205  		return readInt16(d.buffer[:2])
   206  	}
   207  	return 0
   208  }
   209  
   210  func (d *decoder) readInt32() int32 {
   211  	if d.readFull(d.buffer[:4]) {
   212  		return readInt32(d.buffer[:4])
   213  	}
   214  	return 0
   215  }
   216  
   217  func (d *decoder) readInt64() int64 {
   218  	if d.readFull(d.buffer[:8]) {
   219  		return readInt64(d.buffer[:8])
   220  	}
   221  	return 0
   222  }
   223  
   224  func (d *decoder) readFloat64() float64 {
   225  	if d.readFull(d.buffer[:8]) {
   226  		return readFloat64(d.buffer[:8])
   227  	}
   228  	return 0
   229  }
   230  
   231  func (d *decoder) readString() string {
   232  	if n := d.readInt16(); n < 0 {
   233  		return ""
   234  	} else {
   235  		return bytesToString(d.read(int(n)))
   236  	}
   237  }
   238  
   239  func (d *decoder) readVarString() string {
   240  	if n := d.readVarInt(); n < 0 {
   241  		return ""
   242  	} else {
   243  		return bytesToString(d.read(int(n)))
   244  	}
   245  }
   246  
   247  func (d *decoder) readCompactString() string {
   248  	if n := d.readUnsignedVarInt(); n < 1 {
   249  		return ""
   250  	} else {
   251  		return bytesToString(d.read(int(n - 1)))
   252  	}
   253  }
   254  
   255  func (d *decoder) readBytes() []byte {
   256  	if n := d.readInt32(); n < 0 {
   257  		return nil
   258  	} else {
   259  		return d.read(int(n))
   260  	}
   261  }
   262  
   263  func (d *decoder) readVarBytes() []byte {
   264  	if n := d.readVarInt(); n < 0 {
   265  		return nil
   266  	} else {
   267  		return d.read(int(n))
   268  	}
   269  }
   270  
   271  func (d *decoder) readCompactBytes() []byte {
   272  	if n := d.readUnsignedVarInt(); n < 1 {
   273  		return nil
   274  	} else {
   275  		return d.read(int(n - 1))
   276  	}
   277  }
   278  
   279  func (d *decoder) readVarInt() int64 {
   280  	n := 11 // varints are at most 11 bytes
   281  
   282  	if n > d.remain {
   283  		n = d.remain
   284  	}
   285  
   286  	x := uint64(0)
   287  	s := uint(0)
   288  
   289  	for n > 0 {
   290  		b := d.readByte()
   291  
   292  		if (b & 0x80) == 0 {
   293  			x |= uint64(b) << s
   294  			return int64(x>>1) ^ -(int64(x) & 1)
   295  		}
   296  
   297  		x |= uint64(b&0x7f) << s
   298  		s += 7
   299  		n--
   300  	}
   301  
   302  	d.setError(fmt.Errorf("cannot decode varint from input stream"))
   303  	return 0
   304  }
   305  
   306  func (d *decoder) readUnsignedVarInt() uint64 {
   307  	n := 11 // varints are at most 11 bytes
   308  
   309  	if n > d.remain {
   310  		n = d.remain
   311  	}
   312  
   313  	x := uint64(0)
   314  	s := uint(0)
   315  
   316  	for n > 0 {
   317  		b := d.readByte()
   318  
   319  		if (b & 0x80) == 0 {
   320  			x |= uint64(b) << s
   321  			return x
   322  		}
   323  
   324  		x |= uint64(b&0x7f) << s
   325  		s += 7
   326  		n--
   327  	}
   328  
   329  	d.setError(fmt.Errorf("cannot decode unsigned varint from input stream"))
   330  	return 0
   331  }
   332  
   333  type decodeFunc func(*decoder, value)
   334  
   335  var (
   336  	_ io.Reader     = (*decoder)(nil)
   337  	_ io.ByteReader = (*decoder)(nil)
   338  
   339  	readerFrom = reflect.TypeOf((*io.ReaderFrom)(nil)).Elem()
   340  )
   341  
   342  func decodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
   343  	if reflect.PtrTo(typ).Implements(readerFrom) {
   344  		return readerDecodeFuncOf(typ)
   345  	}
   346  	switch typ.Kind() {
   347  	case reflect.Bool:
   348  		return (*decoder).decodeBool
   349  	case reflect.Int8:
   350  		return (*decoder).decodeInt8
   351  	case reflect.Int16:
   352  		return (*decoder).decodeInt16
   353  	case reflect.Int32:
   354  		return (*decoder).decodeInt32
   355  	case reflect.Int64:
   356  		return (*decoder).decodeInt64
   357  	case reflect.Float64:
   358  		return (*decoder).decodeFloat64
   359  	case reflect.String:
   360  		return stringDecodeFuncOf(flexible, tag)
   361  	case reflect.Struct:
   362  		return structDecodeFuncOf(typ, version, flexible)
   363  	case reflect.Slice:
   364  		if typ.Elem().Kind() == reflect.Uint8 { // []byte
   365  			return bytesDecodeFuncOf(flexible, tag)
   366  		}
   367  		return arrayDecodeFuncOf(typ, version, flexible, tag)
   368  	default:
   369  		panic("unsupported type: " + typ.String())
   370  	}
   371  }
   372  
   373  func stringDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
   374  	if flexible {
   375  		// In flexible messages, all strings are compact
   376  		return (*decoder).decodeCompactString
   377  	}
   378  	return (*decoder).decodeString
   379  }
   380  
   381  func bytesDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
   382  	if flexible {
   383  		// In flexible messages, all arrays are compact
   384  		return (*decoder).decodeCompactBytes
   385  	}
   386  	return (*decoder).decodeBytes
   387  }
   388  
   389  func structDecodeFuncOf(typ reflect.Type, version int16, flexible bool) decodeFunc {
   390  	type field struct {
   391  		decode decodeFunc
   392  		index  index
   393  		tagID  int
   394  	}
   395  
   396  	var fields []field
   397  	taggedFields := map[int]*field{}
   398  
   399  	forEachStructField(typ, func(typ reflect.Type, index index, tag string) {
   400  		forEachStructTag(tag, func(tag structTag) bool {
   401  			if tag.MinVersion <= version && version <= tag.MaxVersion {
   402  				f := field{
   403  					decode: decodeFuncOf(typ, version, flexible, tag),
   404  					index:  index,
   405  					tagID:  tag.TagID,
   406  				}
   407  
   408  				if tag.TagID < -1 {
   409  					// Normal required field
   410  					fields = append(fields, f)
   411  				} else {
   412  					// Optional tagged field (flexible messages only)
   413  					taggedFields[tag.TagID] = &f
   414  				}
   415  				return false
   416  			}
   417  			return true
   418  		})
   419  	})
   420  
   421  	return func(d *decoder, v value) {
   422  		for i := range fields {
   423  			f := &fields[i]
   424  			f.decode(d, v.fieldByIndex(f.index))
   425  		}
   426  
   427  		if flexible {
   428  			// See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
   429  			// for details of tag buffers in "flexible" messages.
   430  			n := int(d.readUnsignedVarInt())
   431  
   432  			for i := 0; i < n; i++ {
   433  				tagID := int(d.readUnsignedVarInt())
   434  				size := int(d.readUnsignedVarInt())
   435  
   436  				f, ok := taggedFields[tagID]
   437  				if ok {
   438  					f.decode(d, v.fieldByIndex(f.index))
   439  				} else {
   440  					d.read(size)
   441  				}
   442  			}
   443  		}
   444  	}
   445  }
   446  
   447  func arrayDecodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
   448  	elemType := typ.Elem()
   449  	elemFunc := decodeFuncOf(elemType, version, flexible, tag)
   450  	if flexible {
   451  		// In flexible messages, all arrays are compact
   452  		return func(d *decoder, v value) { d.decodeCompactArray(v, elemType, elemFunc) }
   453  	}
   454  
   455  	return func(d *decoder, v value) { d.decodeArray(v, elemType, elemFunc) }
   456  }
   457  
   458  func readerDecodeFuncOf(typ reflect.Type) decodeFunc {
   459  	typ = reflect.PtrTo(typ)
   460  	return func(d *decoder, v value) {
   461  		if d.err == nil {
   462  			_, err := v.iface(typ).(io.ReaderFrom).ReadFrom(d)
   463  			if err != nil {
   464  				d.setError(err)
   465  			}
   466  		}
   467  	}
   468  }
   469  
   470  func readInt8(b []byte) int8 {
   471  	return int8(b[0])
   472  }
   473  
   474  func readInt16(b []byte) int16 {
   475  	return int16(binary.BigEndian.Uint16(b))
   476  }
   477  
   478  func readInt32(b []byte) int32 {
   479  	return int32(binary.BigEndian.Uint32(b))
   480  }
   481  
   482  func readInt64(b []byte) int64 {
   483  	return int64(binary.BigEndian.Uint64(b))
   484  }
   485  
   486  func readFloat64(b []byte) float64 {
   487  	return math.Float64frombits(binary.BigEndian.Uint64(b))
   488  }
   489  
   490  func Unmarshal(data []byte, version int16, value interface{}) error {
   491  	typ := elemTypeOf(value)
   492  	cache, _ := unmarshalers.Load().(map[versionedType]decodeFunc)
   493  	key := versionedType{typ: typ, version: version}
   494  	decode := cache[key]
   495  
   496  	if decode == nil {
   497  		decode = decodeFuncOf(reflect.TypeOf(value).Elem(), version, false, structTag{
   498  			MinVersion: -1,
   499  			MaxVersion: -1,
   500  			TagID:      -2,
   501  			Compact:    true,
   502  			Nullable:   true,
   503  		})
   504  
   505  		newCache := make(map[versionedType]decodeFunc, len(cache)+1)
   506  		newCache[key] = decode
   507  
   508  		for typ, fun := range cache {
   509  			newCache[typ] = fun
   510  		}
   511  
   512  		unmarshalers.Store(newCache)
   513  	}
   514  
   515  	d, _ := decoders.Get().(*decoder)
   516  	if d == nil {
   517  		d = &decoder{reader: bytes.NewReader(nil)}
   518  	}
   519  
   520  	d.remain = len(data)
   521  	r, _ := d.reader.(*bytes.Reader)
   522  	r.Reset(data)
   523  
   524  	defer func() {
   525  		r.Reset(nil)
   526  		d.Reset(r, 0)
   527  		decoders.Put(d)
   528  	}()
   529  
   530  	decode(d, valueOf(value))
   531  	return dontExpectEOF(d.err)
   532  }
   533  
   534  var (
   535  	decoders     sync.Pool    // *decoder
   536  	unmarshalers atomic.Value // map[versionedType]decodeFunc
   537  )