github.com/hack0072008/kafka-go@v1.0.1/protocol/decode.go (about)

     1  package protocol
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"hash/crc32"
     8  	"io"
     9  	"io/ioutil"
    10  	"reflect"
    11  	"sync"
    12  	"sync/atomic"
    13  )
    14  
    15  type discarder interface {
    16  	Discard(int) (int, error)
    17  }
    18  
// decoder reads protocol primitives from reader while never consuming more
// than remain bytes through Read. The first error recorded via setError is
// kept in err and makes every later Read return it.
type decoder struct {
	// reader is the underlying byte source.
	reader io.Reader
	// remain is the number of bytes still allowed to be read; Read returns
	// io.EOF once it reaches zero.
	remain int
	// buffer is scratch space for fixed-size reads (readByte, readIntN).
	buffer [8]byte
	// err is the sticky first error; once non-nil, Read returns it.
	err error
	// table, when non-nil, makes Read fold returned bytes into crc32.
	table *crc32.Table
	// crc32 is the running checksum since the last setCRC call.
	crc32 uint32
}
    27  
    28  func (d *decoder) Reset(r io.Reader, n int) {
    29  	d.reader = r
    30  	d.remain = n
    31  	d.buffer = [8]byte{}
    32  	d.err = nil
    33  	d.table = nil
    34  	d.crc32 = 0
    35  }
    36  
    37  func (d *decoder) Read(b []byte) (int, error) {
    38  	if d.err != nil {
    39  		return 0, d.err
    40  	}
    41  	if d.remain == 0 {
    42  		return 0, io.EOF
    43  	}
    44  	if len(b) > d.remain {
    45  		b = b[:d.remain]
    46  	}
    47  	n, err := d.reader.Read(b)
    48  	if n > 0 && d.table != nil {
    49  		d.crc32 = crc32.Update(d.crc32, d.table, b[:n])
    50  	}
    51  	d.remain -= n
    52  	return n, err
    53  }
    54  
    55  func (d *decoder) ReadByte() (byte, error) {
    56  	c := d.readByte()
    57  	return c, d.err
    58  }
    59  
    60  func (d *decoder) done() bool {
    61  	return d.remain == 0 || d.err != nil
    62  }
    63  
    64  func (d *decoder) setCRC(table *crc32.Table) {
    65  	d.table, d.crc32 = table, 0
    66  }
    67  
    68  func (d *decoder) decodeBool(v value) {
    69  	v.setBool(d.readBool())
    70  }
    71  
    72  func (d *decoder) decodeInt8(v value) {
    73  	v.setInt8(d.readInt8())
    74  }
    75  
    76  func (d *decoder) decodeInt16(v value) {
    77  	v.setInt16(d.readInt16())
    78  }
    79  
    80  func (d *decoder) decodeInt32(v value) {
    81  	v.setInt32(d.readInt32())
    82  }
    83  
    84  func (d *decoder) decodeInt64(v value) {
    85  	v.setInt64(d.readInt64())
    86  }
    87  
    88  func (d *decoder) decodeString(v value) {
    89  	v.setString(d.readString())
    90  }
    91  
    92  func (d *decoder) decodeCompactString(v value) {
    93  	v.setString(d.readCompactString())
    94  }
    95  
    96  func (d *decoder) decodeBytes(v value) {
    97  	v.setBytes(d.readBytes())
    98  }
    99  
   100  func (d *decoder) decodeCompactBytes(v value) {
   101  	v.setBytes(d.readCompactBytes())
   102  }
   103  
   104  func (d *decoder) decodeArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
   105  	if n := d.readInt32(); n < 0 {
   106  		v.setArray(array{})
   107  	} else {
   108  		a := makeArray(elemType, int(n))
   109  		for i := 0; i < int(n) && d.remain > 0; i++ {
   110  			decodeElem(d, a.index(i))
   111  		}
   112  		v.setArray(a)
   113  	}
   114  }
   115  
   116  func (d *decoder) decodeCompactArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
   117  	if n := d.readUnsignedVarInt(); n < 1 {
   118  		v.setArray(array{})
   119  	} else {
   120  		a := makeArray(elemType, int(n-1))
   121  		for i := 0; i < int(n-1) && d.remain > 0; i++ {
   122  			decodeElem(d, a.index(i))
   123  		}
   124  		v.setArray(a)
   125  	}
   126  }
   127  
   128  func (d *decoder) discardAll() {
   129  	d.discard(d.remain)
   130  }
   131  
   132  func (d *decoder) discard(n int) {
   133  	if n > d.remain {
   134  		n = d.remain
   135  	}
   136  	var err error
   137  	if r, _ := d.reader.(discarder); r != nil {
   138  		n, err = r.Discard(n)
   139  		d.remain -= n
   140  	} else {
   141  		_, err = io.Copy(ioutil.Discard, d)
   142  	}
   143  	d.setError(err)
   144  }
   145  
   146  func (d *decoder) read(n int) []byte {
   147  	b := make([]byte, n)
   148  	n, err := io.ReadFull(d, b)
   149  	b = b[:n]
   150  	d.setError(err)
   151  	return b
   152  }
   153  
   154  func (d *decoder) writeTo(w io.Writer, n int) {
   155  	limit := d.remain
   156  	if n < limit {
   157  		d.remain = n
   158  	}
   159  	c, err := io.Copy(w, d)
   160  	if int(c) < n && err == nil {
   161  		err = io.ErrUnexpectedEOF
   162  	}
   163  	d.remain = limit - int(c)
   164  	d.setError(err)
   165  }
   166  
   167  func (d *decoder) setError(err error) {
   168  	if d.err == nil && err != nil {
   169  		d.err = err
   170  		d.discardAll()
   171  	}
   172  }
   173  
   174  func (d *decoder) readFull(b []byte) bool {
   175  	n, err := io.ReadFull(d, b)
   176  	d.setError(err)
   177  	return n == len(b)
   178  }
   179  
   180  func (d *decoder) readByte() byte {
   181  	if d.readFull(d.buffer[:1]) {
   182  		return d.buffer[0]
   183  	}
   184  	return 0
   185  }
   186  
   187  func (d *decoder) readBool() bool {
   188  	return d.readByte() != 0
   189  }
   190  
   191  func (d *decoder) readInt8() int8 {
   192  	if d.readFull(d.buffer[:1]) {
   193  		return readInt8(d.buffer[:1])
   194  	}
   195  	return 0
   196  }
   197  
   198  func (d *decoder) readInt16() int16 {
   199  	if d.readFull(d.buffer[:2]) {
   200  		return readInt16(d.buffer[:2])
   201  	}
   202  	return 0
   203  }
   204  
   205  func (d *decoder) readInt32() int32 {
   206  	if d.readFull(d.buffer[:4]) {
   207  		return readInt32(d.buffer[:4])
   208  	}
   209  	return 0
   210  }
   211  
   212  func (d *decoder) readInt64() int64 {
   213  	if d.readFull(d.buffer[:8]) {
   214  		return readInt64(d.buffer[:8])
   215  	}
   216  	return 0
   217  }
   218  
   219  func (d *decoder) readString() string {
   220  	if n := d.readInt16(); n < 0 {
   221  		return ""
   222  	} else {
   223  		return bytesToString(d.read(int(n)))
   224  	}
   225  }
   226  
   227  func (d *decoder) readVarString() string {
   228  	if n := d.readVarInt(); n < 0 {
   229  		return ""
   230  	} else {
   231  		return bytesToString(d.read(int(n)))
   232  	}
   233  }
   234  
   235  func (d *decoder) readCompactString() string {
   236  	if n := d.readUnsignedVarInt(); n < 1 {
   237  		return ""
   238  	} else {
   239  		return bytesToString(d.read(int(n - 1)))
   240  	}
   241  }
   242  
   243  func (d *decoder) readBytes() []byte {
   244  	if n := d.readInt32(); n < 0 {
   245  		return nil
   246  	} else {
   247  		return d.read(int(n))
   248  	}
   249  }
   250  
   251  func (d *decoder) readBytesTo(w io.Writer) bool {
   252  	if n := d.readInt32(); n < 0 {
   253  		return false
   254  	} else {
   255  		d.writeTo(w, int(n))
   256  		return d.err == nil
   257  	}
   258  }
   259  
   260  func (d *decoder) readVarBytes() []byte {
   261  	if n := d.readVarInt(); n < 0 {
   262  		return nil
   263  	} else {
   264  		return d.read(int(n))
   265  	}
   266  }
   267  
   268  func (d *decoder) readVarBytesTo(w io.Writer) bool {
   269  	if n := d.readVarInt(); n < 0 {
   270  		return false
   271  	} else {
   272  		d.writeTo(w, int(n))
   273  		return d.err == nil
   274  	}
   275  }
   276  
   277  func (d *decoder) readCompactBytes() []byte {
   278  	if n := d.readUnsignedVarInt(); n < 1 {
   279  		return nil
   280  	} else {
   281  		return d.read(int(n - 1))
   282  	}
   283  }
   284  
   285  func (d *decoder) readCompactBytesTo(w io.Writer) bool {
   286  	if n := d.readUnsignedVarInt(); n < 1 {
   287  		return false
   288  	} else {
   289  		d.writeTo(w, int(n-1))
   290  		return d.err == nil
   291  	}
   292  }
   293  
   294  func (d *decoder) readVarInt() int64 {
   295  	n := 11 // varints are at most 11 bytes
   296  
   297  	if n > d.remain {
   298  		n = d.remain
   299  	}
   300  
   301  	x := uint64(0)
   302  	s := uint(0)
   303  
   304  	for n > 0 {
   305  		b := d.readByte()
   306  
   307  		if (b & 0x80) == 0 {
   308  			x |= uint64(b) << s
   309  			return int64(x>>1) ^ -(int64(x) & 1)
   310  		}
   311  
   312  		x |= uint64(b&0x7f) << s
   313  		s += 7
   314  		n--
   315  	}
   316  
   317  	d.setError(fmt.Errorf("cannot decode varint from input stream"))
   318  	return 0
   319  }
   320  
   321  func (d *decoder) readUnsignedVarInt() uint64 {
   322  	n := 11 // varints are at most 11 bytes
   323  
   324  	if n > d.remain {
   325  		n = d.remain
   326  	}
   327  
   328  	x := uint64(0)
   329  	s := uint(0)
   330  
   331  	for n > 0 {
   332  		b := d.readByte()
   333  
   334  		if (b & 0x80) == 0 {
   335  			x |= uint64(b) << s
   336  			return x
   337  		}
   338  
   339  		x |= uint64(b&0x7f) << s
   340  		s += 7
   341  		n--
   342  	}
   343  
   344  	d.setError(fmt.Errorf("cannot decode unsigned varint from input stream"))
   345  	return 0
   346  }
   347  
// decodeFunc decodes one value from d into v. Instances are produced by
// decodeFuncOf and cached per (type, version) by Unmarshal.
type decodeFunc func(*decoder, value)
   349  
var (
	// Compile-time checks that *decoder satisfies these reader interfaces.
	_ io.Reader     = (*decoder)(nil)
	_ io.ByteReader = (*decoder)(nil)

	// readerFrom is the reflect.Type of io.ReaderFrom. decodeFuncOf uses it
	// to detect types whose pointer decodes itself.
	readerFrom = reflect.TypeOf((*io.ReaderFrom)(nil)).Elem()
)
   356  
   357  func decodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
   358  	if reflect.PtrTo(typ).Implements(readerFrom) {
   359  		return readerDecodeFuncOf(typ)
   360  	}
   361  	switch typ.Kind() {
   362  	case reflect.Bool:
   363  		return (*decoder).decodeBool
   364  	case reflect.Int8:
   365  		return (*decoder).decodeInt8
   366  	case reflect.Int16:
   367  		return (*decoder).decodeInt16
   368  	case reflect.Int32:
   369  		return (*decoder).decodeInt32
   370  	case reflect.Int64:
   371  		return (*decoder).decodeInt64
   372  	case reflect.String:
   373  		return stringDecodeFuncOf(flexible, tag)
   374  	case reflect.Struct:
   375  		return structDecodeFuncOf(typ, version, flexible)
   376  	case reflect.Slice:
   377  		if typ.Elem().Kind() == reflect.Uint8 { // []byte
   378  			return bytesDecodeFuncOf(flexible, tag)
   379  		}
   380  		return arrayDecodeFuncOf(typ, version, flexible, tag)
   381  	default:
   382  		panic("unsupported type: " + typ.String())
   383  	}
   384  }
   385  
   386  func stringDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
   387  	if flexible {
   388  		// In flexible messages, all strings are compact
   389  		return (*decoder).decodeCompactString
   390  	}
   391  	return (*decoder).decodeString
   392  }
   393  
   394  func bytesDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
   395  	if flexible {
   396  		// In flexible messages, all arrays are compact
   397  		return (*decoder).decodeCompactBytes
   398  	}
   399  	return (*decoder).decodeBytes
   400  }
   401  
   402  func structDecodeFuncOf(typ reflect.Type, version int16, flexible bool) decodeFunc {
   403  	type field struct {
   404  		decode decodeFunc
   405  		index  index
   406  		tagID  int
   407  	}
   408  
   409  	var fields []field
   410  	taggedFields := map[int]*field{}
   411  
   412  	forEachStructField(typ, func(typ reflect.Type, index index, tag string) {
   413  		forEachStructTag(tag, func(tag structTag) bool {
   414  			if tag.MinVersion <= version && version <= tag.MaxVersion {
   415  				f := field{
   416  					decode: decodeFuncOf(typ, version, flexible, tag),
   417  					index:  index,
   418  					tagID:  tag.TagID,
   419  				}
   420  
   421  				if tag.TagID < -1 {
   422  					// Normal required field
   423  					fields = append(fields, f)
   424  				} else {
   425  					// Optional tagged field (flexible messages only)
   426  					taggedFields[tag.TagID] = &f
   427  				}
   428  				return false
   429  			}
   430  			return true
   431  		})
   432  	})
   433  
   434  	return func(d *decoder, v value) {
   435  		for i := range fields {
   436  			f := &fields[i]
   437  			f.decode(d, v.fieldByIndex(f.index))
   438  		}
   439  
   440  		if flexible {
   441  			// See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
   442  			// for details of tag buffers in "flexible" messages.
   443  			n := int(d.readUnsignedVarInt())
   444  
   445  			for i := 0; i < n; i++ {
   446  				tagID := int(d.readUnsignedVarInt())
   447  				size := int(d.readUnsignedVarInt())
   448  
   449  				f, ok := taggedFields[tagID]
   450  				if ok {
   451  					f.decode(d, v.fieldByIndex(f.index))
   452  				} else {
   453  					d.read(size)
   454  				}
   455  			}
   456  		}
   457  	}
   458  }
   459  
   460  func arrayDecodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
   461  	elemType := typ.Elem()
   462  	elemFunc := decodeFuncOf(elemType, version, flexible, tag)
   463  	if flexible {
   464  		// In flexible messages, all arrays are compact
   465  		return func(d *decoder, v value) { d.decodeCompactArray(v, elemType, elemFunc) }
   466  	}
   467  
   468  	return func(d *decoder, v value) { d.decodeArray(v, elemType, elemFunc) }
   469  }
   470  
   471  func readerDecodeFuncOf(typ reflect.Type) decodeFunc {
   472  	typ = reflect.PtrTo(typ)
   473  	return func(d *decoder, v value) {
   474  		if d.err == nil {
   475  			_, err := v.iface(typ).(io.ReaderFrom).ReadFrom(d)
   476  			if err != nil {
   477  				d.setError(err)
   478  			}
   479  		}
   480  	}
   481  }
   482  
   483  func readInt8(b []byte) int8 {
   484  	return int8(b[0])
   485  }
   486  
   487  func readInt16(b []byte) int16 {
   488  	return int16(binary.BigEndian.Uint16(b))
   489  }
   490  
   491  func readInt32(b []byte) int32 {
   492  	return int32(binary.BigEndian.Uint32(b))
   493  }
   494  
   495  func readInt64(b []byte) int64 {
   496  	return int64(binary.BigEndian.Uint64(b))
   497  }
   498  
   499  func Unmarshal(data []byte, version int16, value interface{}) error {
   500  	typ := elemTypeOf(value)
   501  	cache, _ := unmarshalers.Load().(map[versionedType]decodeFunc)
   502  	key := versionedType{typ: typ, version: version}
   503  	decode := cache[key]
   504  
   505  	if decode == nil {
   506  		decode = decodeFuncOf(reflect.TypeOf(value).Elem(), version, false, structTag{
   507  			MinVersion: -1,
   508  			MaxVersion: -1,
   509  			TagID:      -2,
   510  			Compact:    true,
   511  			Nullable:   true,
   512  		})
   513  
   514  		newCache := make(map[versionedType]decodeFunc, len(cache)+1)
   515  		newCache[key] = decode
   516  
   517  		for typ, fun := range cache {
   518  			newCache[typ] = fun
   519  		}
   520  
   521  		unmarshalers.Store(newCache)
   522  	}
   523  
   524  	d, _ := decoders.Get().(*decoder)
   525  	if d == nil {
   526  		d = &decoder{reader: bytes.NewReader(nil)}
   527  	}
   528  
   529  	d.remain = len(data)
   530  	r, _ := d.reader.(*bytes.Reader)
   531  	r.Reset(data)
   532  
   533  	defer func() {
   534  		r.Reset(nil)
   535  		d.Reset(r, 0)
   536  		decoders.Put(d)
   537  	}()
   538  
   539  	decode(d, valueOf(value))
   540  	return dontExpectEOF(d.err)
   541  }
   542  
var (
	// decoders recycles *decoder values (backed by a *bytes.Reader) between
	// Unmarshal calls.
	decoders sync.Pool // *decoder
	// unmarshalers holds a copy-on-write cache. Unmarshal loads it without
	// locking and stores a fresh map whenever it adds an entry.
	unmarshalers atomic.Value // map[versionedType]decodeFunc
)