github.com/QuangHoangHao/kafka-go@v0.4.36/protocol/decode.go (about)

     1  package protocol
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"hash/crc32"
     8  	"io"
     9  	"io/ioutil"
    10  	"reflect"
    11  	"sync"
    12  	"sync/atomic"
    13  )
    14  
    15  type discarder interface {
    16  	Discard(int) (int, error)
    17  }
    18  
    19  type decoder struct {
    20  	reader io.Reader
    21  	remain int
    22  	buffer [8]byte
    23  	err    error
    24  	table  *crc32.Table
    25  	crc32  uint32
    26  }
    27  
    28  func (d *decoder) Reset(r io.Reader, n int) {
    29  	d.reader = r
    30  	d.remain = n
    31  	d.buffer = [8]byte{}
    32  	d.err = nil
    33  	d.table = nil
    34  	d.crc32 = 0
    35  }
    36  
    37  func (d *decoder) Read(b []byte) (int, error) {
    38  	if d.err != nil {
    39  		return 0, d.err
    40  	}
    41  	if d.remain == 0 {
    42  		return 0, io.EOF
    43  	}
    44  	if len(b) > d.remain {
    45  		b = b[:d.remain]
    46  	}
    47  	n, err := d.reader.Read(b)
    48  	if n > 0 && d.table != nil {
    49  		d.crc32 = crc32.Update(d.crc32, d.table, b[:n])
    50  	}
    51  	d.remain -= n
    52  	return n, err
    53  }
    54  
    55  func (d *decoder) ReadByte() (byte, error) {
    56  	c := d.readByte()
    57  	return c, d.err
    58  }
    59  
    60  func (d *decoder) done() bool {
    61  	return d.remain == 0 || d.err != nil
    62  }
    63  
    64  func (d *decoder) setCRC(table *crc32.Table) {
    65  	d.table, d.crc32 = table, 0
    66  }
    67  
    68  func (d *decoder) decodeBool(v value) {
    69  	v.setBool(d.readBool())
    70  }
    71  
    72  func (d *decoder) decodeInt8(v value) {
    73  	v.setInt8(d.readInt8())
    74  }
    75  
    76  func (d *decoder) decodeInt16(v value) {
    77  	v.setInt16(d.readInt16())
    78  }
    79  
    80  func (d *decoder) decodeInt32(v value) {
    81  	v.setInt32(d.readInt32())
    82  }
    83  
    84  func (d *decoder) decodeInt64(v value) {
    85  	v.setInt64(d.readInt64())
    86  }
    87  
    88  func (d *decoder) decodeString(v value) {
    89  	v.setString(d.readString())
    90  }
    91  
    92  func (d *decoder) decodeCompactString(v value) {
    93  	v.setString(d.readCompactString())
    94  }
    95  
    96  func (d *decoder) decodeBytes(v value) {
    97  	v.setBytes(d.readBytes())
    98  }
    99  
   100  func (d *decoder) decodeCompactBytes(v value) {
   101  	v.setBytes(d.readCompactBytes())
   102  }
   103  
   104  func (d *decoder) decodeArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
   105  	if n := d.readInt32(); n < 0 {
   106  		v.setArray(array{})
   107  	} else {
   108  		a := makeArray(elemType, int(n))
   109  		for i := 0; i < int(n) && d.remain > 0; i++ {
   110  			decodeElem(d, a.index(i))
   111  		}
   112  		v.setArray(a)
   113  	}
   114  }
   115  
   116  func (d *decoder) decodeCompactArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
   117  	if n := d.readUnsignedVarInt(); n < 1 {
   118  		v.setArray(array{})
   119  	} else {
   120  		a := makeArray(elemType, int(n-1))
   121  		for i := 0; i < int(n-1) && d.remain > 0; i++ {
   122  			decodeElem(d, a.index(i))
   123  		}
   124  		v.setArray(a)
   125  	}
   126  }
   127  
   128  func (d *decoder) discardAll() {
   129  	d.discard(d.remain)
   130  }
   131  
   132  func (d *decoder) discard(n int) {
   133  	if n > d.remain {
   134  		n = d.remain
   135  	}
   136  	var err error
   137  	if r, _ := d.reader.(discarder); r != nil {
   138  		n, err = r.Discard(n)
   139  		d.remain -= n
   140  	} else {
   141  		_, err = io.Copy(ioutil.Discard, d)
   142  	}
   143  	d.setError(err)
   144  }
   145  
   146  func (d *decoder) read(n int) []byte {
   147  	b := make([]byte, n)
   148  	n, err := io.ReadFull(d, b)
   149  	b = b[:n]
   150  	d.setError(err)
   151  	return b
   152  }
   153  
   154  func (d *decoder) writeTo(w io.Writer, n int) {
   155  	limit := d.remain
   156  	if n < limit {
   157  		d.remain = n
   158  	}
   159  	c, err := io.Copy(w, d)
   160  	if int(c) < n && err == nil {
   161  		err = io.ErrUnexpectedEOF
   162  	}
   163  	d.remain = limit - int(c)
   164  	d.setError(err)
   165  }
   166  
   167  func (d *decoder) setError(err error) {
   168  	if d.err == nil && err != nil {
   169  		d.err = err
   170  		d.discardAll()
   171  	}
   172  }
   173  
   174  func (d *decoder) readFull(b []byte) bool {
   175  	n, err := io.ReadFull(d, b)
   176  	d.setError(err)
   177  	return n == len(b)
   178  }
   179  
   180  func (d *decoder) readByte() byte {
   181  	if d.readFull(d.buffer[:1]) {
   182  		return d.buffer[0]
   183  	}
   184  	return 0
   185  }
   186  
   187  func (d *decoder) readBool() bool {
   188  	return d.readByte() != 0
   189  }
   190  
   191  func (d *decoder) readInt8() int8 {
   192  	if d.readFull(d.buffer[:1]) {
   193  		return readInt8(d.buffer[:1])
   194  	}
   195  	return 0
   196  }
   197  
   198  func (d *decoder) readInt16() int16 {
   199  	if d.readFull(d.buffer[:2]) {
   200  		return readInt16(d.buffer[:2])
   201  	}
   202  	return 0
   203  }
   204  
   205  func (d *decoder) readInt32() int32 {
   206  	if d.readFull(d.buffer[:4]) {
   207  		return readInt32(d.buffer[:4])
   208  	}
   209  	return 0
   210  }
   211  
   212  func (d *decoder) readInt64() int64 {
   213  	if d.readFull(d.buffer[:8]) {
   214  		return readInt64(d.buffer[:8])
   215  	}
   216  	return 0
   217  }
   218  
   219  func (d *decoder) readString() string {
   220  	if n := d.readInt16(); n < 0 {
   221  		return ""
   222  	} else {
   223  		return bytesToString(d.read(int(n)))
   224  	}
   225  }
   226  
   227  func (d *decoder) readVarString() string {
   228  	if n := d.readVarInt(); n < 0 {
   229  		return ""
   230  	} else {
   231  		return bytesToString(d.read(int(n)))
   232  	}
   233  }
   234  
   235  func (d *decoder) readCompactString() string {
   236  	if n := d.readUnsignedVarInt(); n < 1 {
   237  		return ""
   238  	} else {
   239  		return bytesToString(d.read(int(n - 1)))
   240  	}
   241  }
   242  
   243  func (d *decoder) readBytes() []byte {
   244  	if n := d.readInt32(); n < 0 {
   245  		return nil
   246  	} else {
   247  		return d.read(int(n))
   248  	}
   249  }
   250  
   251  func (d *decoder) readVarBytes() []byte {
   252  	if n := d.readVarInt(); n < 0 {
   253  		return nil
   254  	} else {
   255  		return d.read(int(n))
   256  	}
   257  }
   258  
   259  func (d *decoder) readCompactBytes() []byte {
   260  	if n := d.readUnsignedVarInt(); n < 1 {
   261  		return nil
   262  	} else {
   263  		return d.read(int(n - 1))
   264  	}
   265  }
   266  
   267  func (d *decoder) readVarInt() int64 {
   268  	n := 11 // varints are at most 11 bytes
   269  
   270  	if n > d.remain {
   271  		n = d.remain
   272  	}
   273  
   274  	x := uint64(0)
   275  	s := uint(0)
   276  
   277  	for n > 0 {
   278  		b := d.readByte()
   279  
   280  		if (b & 0x80) == 0 {
   281  			x |= uint64(b) << s
   282  			return int64(x>>1) ^ -(int64(x) & 1)
   283  		}
   284  
   285  		x |= uint64(b&0x7f) << s
   286  		s += 7
   287  		n--
   288  	}
   289  
   290  	d.setError(fmt.Errorf("cannot decode varint from input stream"))
   291  	return 0
   292  }
   293  
   294  func (d *decoder) readUnsignedVarInt() uint64 {
   295  	n := 11 // varints are at most 11 bytes
   296  
   297  	if n > d.remain {
   298  		n = d.remain
   299  	}
   300  
   301  	x := uint64(0)
   302  	s := uint(0)
   303  
   304  	for n > 0 {
   305  		b := d.readByte()
   306  
   307  		if (b & 0x80) == 0 {
   308  			x |= uint64(b) << s
   309  			return x
   310  		}
   311  
   312  		x |= uint64(b&0x7f) << s
   313  		s += 7
   314  		n--
   315  	}
   316  
   317  	d.setError(fmt.Errorf("cannot decode unsigned varint from input stream"))
   318  	return 0
   319  }
   320  
   321  type decodeFunc func(*decoder, value)
   322  
   323  var (
   324  	_ io.Reader     = (*decoder)(nil)
   325  	_ io.ByteReader = (*decoder)(nil)
   326  
   327  	readerFrom = reflect.TypeOf((*io.ReaderFrom)(nil)).Elem()
   328  )
   329  
   330  func decodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
   331  	if reflect.PtrTo(typ).Implements(readerFrom) {
   332  		return readerDecodeFuncOf(typ)
   333  	}
   334  	switch typ.Kind() {
   335  	case reflect.Bool:
   336  		return (*decoder).decodeBool
   337  	case reflect.Int8:
   338  		return (*decoder).decodeInt8
   339  	case reflect.Int16:
   340  		return (*decoder).decodeInt16
   341  	case reflect.Int32:
   342  		return (*decoder).decodeInt32
   343  	case reflect.Int64:
   344  		return (*decoder).decodeInt64
   345  	case reflect.String:
   346  		return stringDecodeFuncOf(flexible, tag)
   347  	case reflect.Struct:
   348  		return structDecodeFuncOf(typ, version, flexible)
   349  	case reflect.Slice:
   350  		if typ.Elem().Kind() == reflect.Uint8 { // []byte
   351  			return bytesDecodeFuncOf(flexible, tag)
   352  		}
   353  		return arrayDecodeFuncOf(typ, version, flexible, tag)
   354  	default:
   355  		panic("unsupported type: " + typ.String())
   356  	}
   357  }
   358  
   359  func stringDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
   360  	if flexible {
   361  		// In flexible messages, all strings are compact
   362  		return (*decoder).decodeCompactString
   363  	}
   364  	return (*decoder).decodeString
   365  }
   366  
   367  func bytesDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
   368  	if flexible {
   369  		// In flexible messages, all arrays are compact
   370  		return (*decoder).decodeCompactBytes
   371  	}
   372  	return (*decoder).decodeBytes
   373  }
   374  
   375  func structDecodeFuncOf(typ reflect.Type, version int16, flexible bool) decodeFunc {
   376  	type field struct {
   377  		decode decodeFunc
   378  		index  index
   379  		tagID  int
   380  	}
   381  
   382  	var fields []field
   383  	taggedFields := map[int]*field{}
   384  
   385  	forEachStructField(typ, func(typ reflect.Type, index index, tag string) {
   386  		forEachStructTag(tag, func(tag structTag) bool {
   387  			if tag.MinVersion <= version && version <= tag.MaxVersion {
   388  				f := field{
   389  					decode: decodeFuncOf(typ, version, flexible, tag),
   390  					index:  index,
   391  					tagID:  tag.TagID,
   392  				}
   393  
   394  				if tag.TagID < -1 {
   395  					// Normal required field
   396  					fields = append(fields, f)
   397  				} else {
   398  					// Optional tagged field (flexible messages only)
   399  					taggedFields[tag.TagID] = &f
   400  				}
   401  				return false
   402  			}
   403  			return true
   404  		})
   405  	})
   406  
   407  	return func(d *decoder, v value) {
   408  		for i := range fields {
   409  			f := &fields[i]
   410  			f.decode(d, v.fieldByIndex(f.index))
   411  		}
   412  
   413  		if flexible {
   414  			// See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
   415  			// for details of tag buffers in "flexible" messages.
   416  			n := int(d.readUnsignedVarInt())
   417  
   418  			for i := 0; i < n; i++ {
   419  				tagID := int(d.readUnsignedVarInt())
   420  				size := int(d.readUnsignedVarInt())
   421  
   422  				f, ok := taggedFields[tagID]
   423  				if ok {
   424  					f.decode(d, v.fieldByIndex(f.index))
   425  				} else {
   426  					d.read(size)
   427  				}
   428  			}
   429  		}
   430  	}
   431  }
   432  
   433  func arrayDecodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
   434  	elemType := typ.Elem()
   435  	elemFunc := decodeFuncOf(elemType, version, flexible, tag)
   436  	if flexible {
   437  		// In flexible messages, all arrays are compact
   438  		return func(d *decoder, v value) { d.decodeCompactArray(v, elemType, elemFunc) }
   439  	}
   440  
   441  	return func(d *decoder, v value) { d.decodeArray(v, elemType, elemFunc) }
   442  }
   443  
   444  func readerDecodeFuncOf(typ reflect.Type) decodeFunc {
   445  	typ = reflect.PtrTo(typ)
   446  	return func(d *decoder, v value) {
   447  		if d.err == nil {
   448  			_, err := v.iface(typ).(io.ReaderFrom).ReadFrom(d)
   449  			if err != nil {
   450  				d.setError(err)
   451  			}
   452  		}
   453  	}
   454  }
   455  
   456  func readInt8(b []byte) int8 {
   457  	return int8(b[0])
   458  }
   459  
   460  func readInt16(b []byte) int16 {
   461  	return int16(binary.BigEndian.Uint16(b))
   462  }
   463  
   464  func readInt32(b []byte) int32 {
   465  	return int32(binary.BigEndian.Uint32(b))
   466  }
   467  
   468  func readInt64(b []byte) int64 {
   469  	return int64(binary.BigEndian.Uint64(b))
   470  }
   471  
   472  func Unmarshal(data []byte, version int16, value interface{}) error {
   473  	typ := elemTypeOf(value)
   474  	cache, _ := unmarshalers.Load().(map[versionedType]decodeFunc)
   475  	key := versionedType{typ: typ, version: version}
   476  	decode := cache[key]
   477  
   478  	if decode == nil {
   479  		decode = decodeFuncOf(reflect.TypeOf(value).Elem(), version, false, structTag{
   480  			MinVersion: -1,
   481  			MaxVersion: -1,
   482  			TagID:      -2,
   483  			Compact:    true,
   484  			Nullable:   true,
   485  		})
   486  
   487  		newCache := make(map[versionedType]decodeFunc, len(cache)+1)
   488  		newCache[key] = decode
   489  
   490  		for typ, fun := range cache {
   491  			newCache[typ] = fun
   492  		}
   493  
   494  		unmarshalers.Store(newCache)
   495  	}
   496  
   497  	d, _ := decoders.Get().(*decoder)
   498  	if d == nil {
   499  		d = &decoder{reader: bytes.NewReader(nil)}
   500  	}
   501  
   502  	d.remain = len(data)
   503  	r, _ := d.reader.(*bytes.Reader)
   504  	r.Reset(data)
   505  
   506  	defer func() {
   507  		r.Reset(nil)
   508  		d.Reset(r, 0)
   509  		decoders.Put(d)
   510  	}()
   511  
   512  	decode(d, valueOf(value))
   513  	return dontExpectEOF(d.err)
   514  }
   515  
   516  var (
   517  	decoders     sync.Pool    // *decoder
   518  	unmarshalers atomic.Value // map[versionedType]decodeFunc
   519  )