github.com/segmentio/kafka-go@v0.4.48-0.20240318174348-3f6244eb34fd/protocol/encode.go (about)

     1  package protocol
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"hash/crc32"
     8  	"io"
     9  	"math"
    10  	"reflect"
    11  	"sync"
    12  	"sync/atomic"
    13  )
    14  
// encoder writes values to an underlying io.Writer, optionally maintaining a
// running CRC32 checksum of the bytes it writes.
//
// Fields:
//   - writer: destination of every byte produced by the encoder.
//   - err:    sticky error; once set, Write returns it without writing.
//   - table:  CRC32 table used by update; nil disables checksum computation.
//   - crc32:  running checksum of the bytes written since setCRC was called.
//   - buffer: scratch space used to stage fixed-size integers, varints and
//     chunks of strings before they are passed to Write.
type encoder struct {
	writer io.Writer
	err    error
	table  *crc32.Table
	crc32  uint32
	buffer [32]byte
}
    22  
// encoderChecksum wraps a reader so that every byte read through it is also
// fed to the checksum of the associated encoder (see Read).
type encoderChecksum struct {
	reader  io.Reader
	encoder *encoder
}
    27  
    28  func (e *encoderChecksum) Read(b []byte) (int, error) {
    29  	n, err := e.reader.Read(b)
    30  	if n > 0 {
    31  		e.encoder.update(b[:n])
    32  	}
    33  	return n, err
    34  }
    35  
    36  func (e *encoder) Reset(w io.Writer) {
    37  	e.writer = w
    38  	e.err = nil
    39  	e.table = nil
    40  	e.crc32 = 0
    41  	e.buffer = [32]byte{}
    42  }
    43  
    44  func (e *encoder) ReadFrom(r io.Reader) (int64, error) {
    45  	if e.table != nil {
    46  		r = &encoderChecksum{
    47  			reader:  r,
    48  			encoder: e,
    49  		}
    50  	}
    51  	return io.Copy(e.writer, r)
    52  }
    53  
    54  func (e *encoder) Write(b []byte) (int, error) {
    55  	if e.err != nil {
    56  		return 0, e.err
    57  	}
    58  	n, err := e.writer.Write(b)
    59  	if n > 0 {
    60  		e.update(b[:n])
    61  	}
    62  	if err != nil {
    63  		e.err = err
    64  	}
    65  	return n, err
    66  }
    67  
    68  func (e *encoder) WriteByte(b byte) error {
    69  	e.buffer[0] = b
    70  	_, err := e.Write(e.buffer[:1])
    71  	return err
    72  }
    73  
    74  func (e *encoder) WriteString(s string) (int, error) {
    75  	// This implementation is an optimization to avoid the heap allocation that
    76  	// would occur when converting the string to a []byte to call crc32.Update.
    77  	//
    78  	// Strings are rarely long in the kafka protocol, so the use of a 32 byte
    79  	// buffer is a good comprise between keeping the encoder value small and
    80  	// limiting the number of calls to Write.
    81  	//
    82  	// We introduced this optimization because memory profiles on the benchmarks
    83  	// showed that most heap allocations were caused by this code path.
    84  	n := 0
    85  
    86  	for len(s) != 0 {
    87  		c := copy(e.buffer[:], s)
    88  		w, err := e.Write(e.buffer[:c])
    89  		n += w
    90  		if err != nil {
    91  			return n, err
    92  		}
    93  		s = s[c:]
    94  	}
    95  
    96  	return n, nil
    97  }
    98  
    99  func (e *encoder) setCRC(table *crc32.Table) {
   100  	e.table, e.crc32 = table, 0
   101  }
   102  
   103  func (e *encoder) update(b []byte) {
   104  	if e.table != nil {
   105  		e.crc32 = crc32.Update(e.crc32, e.table, b)
   106  	}
   107  }
   108  
   109  func (e *encoder) encodeBool(v value) {
   110  	b := int8(0)
   111  	if v.bool() {
   112  		b = 1
   113  	}
   114  	e.writeInt8(b)
   115  }
   116  
   117  func (e *encoder) encodeInt8(v value) {
   118  	e.writeInt8(v.int8())
   119  }
   120  
   121  func (e *encoder) encodeInt16(v value) {
   122  	e.writeInt16(v.int16())
   123  }
   124  
   125  func (e *encoder) encodeInt32(v value) {
   126  	e.writeInt32(v.int32())
   127  }
   128  
   129  func (e *encoder) encodeInt64(v value) {
   130  	e.writeInt64(v.int64())
   131  }
   132  
   133  func (e *encoder) encodeFloat64(v value) {
   134  	e.writeFloat64(v.float64())
   135  }
   136  
   137  func (e *encoder) encodeString(v value) {
   138  	e.writeString(v.string())
   139  }
   140  
   141  func (e *encoder) encodeCompactString(v value) {
   142  	e.writeCompactString(v.string())
   143  }
   144  
   145  func (e *encoder) encodeNullString(v value) {
   146  	e.writeNullString(v.string())
   147  }
   148  
   149  func (e *encoder) encodeCompactNullString(v value) {
   150  	e.writeCompactNullString(v.string())
   151  }
   152  
   153  func (e *encoder) encodeBytes(v value) {
   154  	e.writeBytes(v.bytes())
   155  }
   156  
   157  func (e *encoder) encodeCompactBytes(v value) {
   158  	e.writeCompactBytes(v.bytes())
   159  }
   160  
   161  func (e *encoder) encodeNullBytes(v value) {
   162  	e.writeNullBytes(v.bytes())
   163  }
   164  
   165  func (e *encoder) encodeCompactNullBytes(v value) {
   166  	e.writeCompactNullBytes(v.bytes())
   167  }
   168  
   169  func (e *encoder) encodeArray(v value, elemType reflect.Type, encodeElem encodeFunc) {
   170  	a := v.array(elemType)
   171  	n := a.length()
   172  	e.writeInt32(int32(n))
   173  
   174  	for i := 0; i < n; i++ {
   175  		encodeElem(e, a.index(i))
   176  	}
   177  }
   178  
   179  func (e *encoder) encodeCompactArray(v value, elemType reflect.Type, encodeElem encodeFunc) {
   180  	a := v.array(elemType)
   181  	n := a.length()
   182  	e.writeUnsignedVarInt(uint64(n + 1))
   183  
   184  	for i := 0; i < n; i++ {
   185  		encodeElem(e, a.index(i))
   186  	}
   187  }
   188  
   189  func (e *encoder) encodeNullArray(v value, elemType reflect.Type, encodeElem encodeFunc) {
   190  	a := v.array(elemType)
   191  	if a.isNil() {
   192  		e.writeInt32(-1)
   193  		return
   194  	}
   195  
   196  	n := a.length()
   197  	e.writeInt32(int32(n))
   198  
   199  	for i := 0; i < n; i++ {
   200  		encodeElem(e, a.index(i))
   201  	}
   202  }
   203  
   204  func (e *encoder) encodeCompactNullArray(v value, elemType reflect.Type, encodeElem encodeFunc) {
   205  	a := v.array(elemType)
   206  	if a.isNil() {
   207  		e.writeUnsignedVarInt(0)
   208  		return
   209  	}
   210  
   211  	n := a.length()
   212  	e.writeUnsignedVarInt(uint64(n + 1))
   213  	for i := 0; i < n; i++ {
   214  		encodeElem(e, a.index(i))
   215  	}
   216  }
   217  
   218  func (e *encoder) writeInt8(i int8) {
   219  	writeInt8(e.buffer[:1], i)
   220  	e.Write(e.buffer[:1])
   221  }
   222  
   223  func (e *encoder) writeInt16(i int16) {
   224  	writeInt16(e.buffer[:2], i)
   225  	e.Write(e.buffer[:2])
   226  }
   227  
   228  func (e *encoder) writeInt32(i int32) {
   229  	writeInt32(e.buffer[:4], i)
   230  	e.Write(e.buffer[:4])
   231  }
   232  
   233  func (e *encoder) writeInt64(i int64) {
   234  	writeInt64(e.buffer[:8], i)
   235  	e.Write(e.buffer[:8])
   236  }
   237  
   238  func (e *encoder) writeFloat64(f float64) {
   239  	writeFloat64(e.buffer[:8], f)
   240  	e.Write(e.buffer[:8])
   241  }
   242  
   243  func (e *encoder) writeString(s string) {
   244  	e.writeInt16(int16(len(s)))
   245  	e.WriteString(s)
   246  }
   247  
   248  func (e *encoder) writeVarString(s string) {
   249  	e.writeVarInt(int64(len(s)))
   250  	e.WriteString(s)
   251  }
   252  
   253  func (e *encoder) writeCompactString(s string) {
   254  	e.writeUnsignedVarInt(uint64(len(s)) + 1)
   255  	e.WriteString(s)
   256  }
   257  
   258  func (e *encoder) writeNullString(s string) {
   259  	if s == "" {
   260  		e.writeInt16(-1)
   261  	} else {
   262  		e.writeInt16(int16(len(s)))
   263  		e.WriteString(s)
   264  	}
   265  }
   266  
   267  func (e *encoder) writeCompactNullString(s string) {
   268  	if s == "" {
   269  		e.writeUnsignedVarInt(0)
   270  	} else {
   271  		e.writeUnsignedVarInt(uint64(len(s)) + 1)
   272  		e.WriteString(s)
   273  	}
   274  }
   275  
   276  func (e *encoder) writeBytes(b []byte) {
   277  	e.writeInt32(int32(len(b)))
   278  	e.Write(b)
   279  }
   280  
   281  func (e *encoder) writeCompactBytes(b []byte) {
   282  	e.writeUnsignedVarInt(uint64(len(b)) + 1)
   283  	e.Write(b)
   284  }
   285  
   286  func (e *encoder) writeNullBytes(b []byte) {
   287  	if b == nil {
   288  		e.writeInt32(-1)
   289  	} else {
   290  		e.writeInt32(int32(len(b)))
   291  		e.Write(b)
   292  	}
   293  }
   294  
   295  func (e *encoder) writeVarNullBytes(b []byte) {
   296  	if b == nil {
   297  		e.writeVarInt(-1)
   298  	} else {
   299  		e.writeVarInt(int64(len(b)))
   300  		e.Write(b)
   301  	}
   302  }
   303  
   304  func (e *encoder) writeCompactNullBytes(b []byte) {
   305  	if b == nil {
   306  		e.writeUnsignedVarInt(0)
   307  	} else {
   308  		e.writeUnsignedVarInt(uint64(len(b)) + 1)
   309  		e.Write(b)
   310  	}
   311  }
   312  
   313  func (e *encoder) writeNullBytesFrom(b Bytes) error {
   314  	if b == nil {
   315  		e.writeInt32(-1)
   316  		return nil
   317  	} else {
   318  		size := int64(b.Len())
   319  		e.writeInt32(int32(size))
   320  		n, err := io.Copy(e, b)
   321  		if err == nil && n != size {
   322  			err = fmt.Errorf("size of nullable bytes does not match the number of bytes that were written (size=%d, written=%d): %w", size, n, io.ErrUnexpectedEOF)
   323  		}
   324  		return err
   325  	}
   326  }
   327  
   328  func (e *encoder) writeVarNullBytesFrom(b Bytes) error {
   329  	if b == nil {
   330  		e.writeVarInt(-1)
   331  		return nil
   332  	} else {
   333  		size := int64(b.Len())
   334  		e.writeVarInt(size)
   335  		n, err := io.Copy(e, b)
   336  		if err == nil && n != size {
   337  			err = fmt.Errorf("size of nullable bytes does not match the number of bytes that were written (size=%d, written=%d): %w", size, n, io.ErrUnexpectedEOF)
   338  		}
   339  		return err
   340  	}
   341  }
   342  
   343  func (e *encoder) writeVarInt(i int64) {
   344  	e.writeUnsignedVarInt(uint64((i << 1) ^ (i >> 63)))
   345  }
   346  
   347  func (e *encoder) writeUnsignedVarInt(i uint64) {
   348  	b := e.buffer[:]
   349  	n := 0
   350  
   351  	for i >= 0x80 && n < len(b) {
   352  		b[n] = byte(i) | 0x80
   353  		i >>= 7
   354  		n++
   355  	}
   356  
   357  	if n < len(b) {
   358  		b[n] = byte(i)
   359  		n++
   360  	}
   361  
   362  	e.Write(b[:n])
   363  }
   364  
// encodeFunc is the signature of the functions that write a value to an
// encoder; encodeFuncOf constructs one for a given Go type.
type encodeFunc func(*encoder, value)
   366  
var (
	// Compile-time assertions that *encoder implements these io interfaces.
	_ io.ReaderFrom   = (*encoder)(nil)
	_ io.Writer       = (*encoder)(nil)
	_ io.ByteWriter   = (*encoder)(nil)
	_ io.StringWriter = (*encoder)(nil)

	// writerTo is the reflect.Type of the io.WriterTo interface; encodeFuncOf
	// uses it to detect types whose pointer implements WriteTo.
	writerTo = reflect.TypeOf((*io.WriterTo)(nil)).Elem()
)
   375  
   376  func encodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) encodeFunc {
   377  	if reflect.PtrTo(typ).Implements(writerTo) {
   378  		return writerEncodeFuncOf(typ)
   379  	}
   380  	switch typ.Kind() {
   381  	case reflect.Bool:
   382  		return (*encoder).encodeBool
   383  	case reflect.Int8:
   384  		return (*encoder).encodeInt8
   385  	case reflect.Int16:
   386  		return (*encoder).encodeInt16
   387  	case reflect.Int32:
   388  		return (*encoder).encodeInt32
   389  	case reflect.Int64:
   390  		return (*encoder).encodeInt64
   391  	case reflect.Float64:
   392  		return (*encoder).encodeFloat64
   393  	case reflect.String:
   394  		return stringEncodeFuncOf(flexible, tag)
   395  	case reflect.Struct:
   396  		return structEncodeFuncOf(typ, version, flexible)
   397  	case reflect.Slice:
   398  		if typ.Elem().Kind() == reflect.Uint8 { // []byte
   399  			return bytesEncodeFuncOf(flexible, tag)
   400  		}
   401  		return arrayEncodeFuncOf(typ, version, flexible, tag)
   402  	default:
   403  		panic("unsupported type: " + typ.String())
   404  	}
   405  }
   406  
   407  func stringEncodeFuncOf(flexible bool, tag structTag) encodeFunc {
   408  	switch {
   409  	case flexible && tag.Nullable:
   410  		// In flexible messages, all strings are compact
   411  		return (*encoder).encodeCompactNullString
   412  	case flexible:
   413  		// In flexible messages, all strings are compact
   414  		return (*encoder).encodeCompactString
   415  	case tag.Nullable:
   416  		return (*encoder).encodeNullString
   417  	default:
   418  		return (*encoder).encodeString
   419  	}
   420  }
   421  
   422  func bytesEncodeFuncOf(flexible bool, tag structTag) encodeFunc {
   423  	switch {
   424  	case flexible && tag.Nullable:
   425  		// In flexible messages, all arrays are compact
   426  		return (*encoder).encodeCompactNullBytes
   427  	case flexible:
   428  		// In flexible messages, all arrays are compact
   429  		return (*encoder).encodeCompactBytes
   430  	case tag.Nullable:
   431  		return (*encoder).encodeNullBytes
   432  	default:
   433  		return (*encoder).encodeBytes
   434  	}
   435  }
   436  
   437  func structEncodeFuncOf(typ reflect.Type, version int16, flexible bool) encodeFunc {
   438  	type field struct {
   439  		encode encodeFunc
   440  		index  index
   441  		tagID  int
   442  	}
   443  
   444  	var fields []field
   445  	var taggedFields []field
   446  
   447  	forEachStructField(typ, func(typ reflect.Type, index index, tag string) {
   448  		if typ.Size() != 0 { // skip struct{}
   449  			forEachStructTag(tag, func(tag structTag) bool {
   450  				if tag.MinVersion <= version && version <= tag.MaxVersion {
   451  					f := field{
   452  						encode: encodeFuncOf(typ, version, flexible, tag),
   453  						index:  index,
   454  						tagID:  tag.TagID,
   455  					}
   456  
   457  					if tag.TagID < -1 {
   458  						// Normal required field
   459  						fields = append(fields, f)
   460  					} else {
   461  						// Optional tagged field (flexible messages only)
   462  						taggedFields = append(taggedFields, f)
   463  					}
   464  					return false
   465  				}
   466  				return true
   467  			})
   468  		}
   469  	})
   470  
   471  	return func(e *encoder, v value) {
   472  		for i := range fields {
   473  			f := &fields[i]
   474  			f.encode(e, v.fieldByIndex(f.index))
   475  		}
   476  
   477  		if flexible {
   478  			// See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
   479  			// for details of tag buffers in "flexible" messages.
   480  			e.writeUnsignedVarInt(uint64(len(taggedFields)))
   481  
   482  			for i := range taggedFields {
   483  				f := &taggedFields[i]
   484  				e.writeUnsignedVarInt(uint64(f.tagID))
   485  
   486  				buf := &bytes.Buffer{}
   487  				se := &encoder{writer: buf}
   488  				f.encode(se, v.fieldByIndex(f.index))
   489  				e.writeUnsignedVarInt(uint64(buf.Len()))
   490  				e.Write(buf.Bytes())
   491  			}
   492  		}
   493  	}
   494  }
   495  
   496  func arrayEncodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) encodeFunc {
   497  	elemType := typ.Elem()
   498  	elemFunc := encodeFuncOf(elemType, version, flexible, tag)
   499  	switch {
   500  	case flexible && tag.Nullable:
   501  		// In flexible messages, all arrays are compact
   502  		return func(e *encoder, v value) { e.encodeCompactNullArray(v, elemType, elemFunc) }
   503  	case flexible:
   504  		// In flexible messages, all arrays are compact
   505  		return func(e *encoder, v value) { e.encodeCompactArray(v, elemType, elemFunc) }
   506  	case tag.Nullable:
   507  		return func(e *encoder, v value) { e.encodeNullArray(v, elemType, elemFunc) }
   508  	default:
   509  		return func(e *encoder, v value) { e.encodeArray(v, elemType, elemFunc) }
   510  	}
   511  }
   512  
   513  func writerEncodeFuncOf(typ reflect.Type) encodeFunc {
   514  	typ = reflect.PtrTo(typ)
   515  	return func(e *encoder, v value) {
   516  		// Optimization to write directly into the buffer when the encoder
   517  		// does no need to compute a crc32 checksum.
   518  		w := io.Writer(e)
   519  		if e.table == nil {
   520  			w = e.writer
   521  		}
   522  		_, err := v.iface(typ).(io.WriterTo).WriteTo(w)
   523  		if err != nil {
   524  			e.err = err
   525  		}
   526  	}
   527  }
   528  
   529  func writeInt8(b []byte, i int8) {
   530  	b[0] = byte(i)
   531  }
   532  
   533  func writeInt16(b []byte, i int16) {
   534  	binary.BigEndian.PutUint16(b, uint16(i))
   535  }
   536  
   537  func writeInt32(b []byte, i int32) {
   538  	binary.BigEndian.PutUint32(b, uint32(i))
   539  }
   540  
   541  func writeInt64(b []byte, i int64) {
   542  	binary.BigEndian.PutUint64(b, uint64(i))
   543  }
   544  
   545  func writeFloat64(b []byte, f float64) {
   546  	binary.BigEndian.PutUint64(b, math.Float64bits(f))
   547  }
   548  
   549  func Marshal(version int16, value interface{}) ([]byte, error) {
   550  	typ := typeOf(value)
   551  	cache, _ := marshalers.Load().(map[versionedType]encodeFunc)
   552  	key := versionedType{typ: typ, version: version}
   553  	encode := cache[key]
   554  
   555  	if encode == nil {
   556  		encode = encodeFuncOf(reflect.TypeOf(value), version, false, structTag{
   557  			MinVersion: -1,
   558  			MaxVersion: -1,
   559  			TagID:      -2,
   560  			Compact:    true,
   561  			Nullable:   true,
   562  		})
   563  
   564  		newCache := make(map[versionedType]encodeFunc, len(cache)+1)
   565  		newCache[key] = encode
   566  
   567  		for typ, fun := range cache {
   568  			newCache[typ] = fun
   569  		}
   570  
   571  		marshalers.Store(newCache)
   572  	}
   573  
   574  	e, _ := encoders.Get().(*encoder)
   575  	if e == nil {
   576  		e = &encoder{writer: new(bytes.Buffer)}
   577  	}
   578  
   579  	b, _ := e.writer.(*bytes.Buffer)
   580  	defer func() {
   581  		b.Reset()
   582  		e.Reset(b)
   583  		encoders.Put(e)
   584  	}()
   585  
   586  	encode(e, nonAddressableValueOf(value))
   587  
   588  	if e.err != nil {
   589  		return nil, e.err
   590  	}
   591  
   592  	buf := b.Bytes()
   593  	out := make([]byte, len(buf))
   594  	copy(out, buf)
   595  	return out, nil
   596  }
   597  
// versionedType is the key of the encode function cache used by Marshal: it
// pairs the type of the marshaled value (as returned by typeOf) with the
// protocol version it is encoded for.
type versionedType struct {
	typ     _type
	version int16
}
   602  
// Package-level state used by Marshal: encoders recycles encoder values
// between calls, and marshalers holds the cache of encode functions, which
// Marshal replaces with an updated copy whenever it adds an entry.
var (
	encoders   sync.Pool    // *encoder
	marshalers atomic.Value // map[versionedType]encodeFunc
)