github.com/segmentio/kafka-go@v0.4.48-0.20240318174348-3f6244eb34fd/protocol/decode.go (about) 1 package protocol 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "fmt" 7 "hash/crc32" 8 "io" 9 "io/ioutil" 10 "math" 11 "reflect" 12 "sync" 13 "sync/atomic" 14 ) 15 16 type discarder interface { 17 Discard(int) (int, error) 18 } 19 20 type decoder struct { 21 reader io.Reader 22 remain int 23 buffer [8]byte 24 err error 25 table *crc32.Table 26 crc32 uint32 27 } 28 29 func (d *decoder) Reset(r io.Reader, n int) { 30 d.reader = r 31 d.remain = n 32 d.buffer = [8]byte{} 33 d.err = nil 34 d.table = nil 35 d.crc32 = 0 36 } 37 38 func (d *decoder) Read(b []byte) (int, error) { 39 if d.err != nil { 40 return 0, d.err 41 } 42 if d.remain == 0 { 43 return 0, io.EOF 44 } 45 if len(b) > d.remain { 46 b = b[:d.remain] 47 } 48 n, err := d.reader.Read(b) 49 if n > 0 && d.table != nil { 50 d.crc32 = crc32.Update(d.crc32, d.table, b[:n]) 51 } 52 d.remain -= n 53 return n, err 54 } 55 56 func (d *decoder) ReadByte() (byte, error) { 57 c := d.readByte() 58 return c, d.err 59 } 60 61 func (d *decoder) done() bool { 62 return d.remain == 0 || d.err != nil 63 } 64 65 func (d *decoder) setCRC(table *crc32.Table) { 66 d.table, d.crc32 = table, 0 67 } 68 69 func (d *decoder) decodeBool(v value) { 70 v.setBool(d.readBool()) 71 } 72 73 func (d *decoder) decodeInt8(v value) { 74 v.setInt8(d.readInt8()) 75 } 76 77 func (d *decoder) decodeInt16(v value) { 78 v.setInt16(d.readInt16()) 79 } 80 81 func (d *decoder) decodeInt32(v value) { 82 v.setInt32(d.readInt32()) 83 } 84 85 func (d *decoder) decodeInt64(v value) { 86 v.setInt64(d.readInt64()) 87 } 88 89 func (d *decoder) decodeFloat64(v value) { 90 v.setFloat64(d.readFloat64()) 91 } 92 93 func (d *decoder) decodeString(v value) { 94 v.setString(d.readString()) 95 } 96 97 func (d *decoder) decodeCompactString(v value) { 98 v.setString(d.readCompactString()) 99 } 100 101 func (d *decoder) decodeBytes(v value) { 102 v.setBytes(d.readBytes()) 103 } 104 105 func (d *decoder) decodeCompactBytes(v value) { 106 v.setBytes(d.readCompactBytes()) 107 } 108 109 func (d *decoder) decodeArray(v value, elemType reflect.Type, decodeElem decodeFunc) { 110 if n := d.readInt32(); n < 0 { 111 v.setArray(array{}) 112 } else { 113 a := makeArray(elemType, int(n)) 114 for i := 0; i < int(n) && d.remain > 0; i++ { 115 decodeElem(d, a.index(i)) 116 } 117 v.setArray(a) 118 } 119 } 120 121 func (d *decoder) decodeCompactArray(v value, elemType reflect.Type, decodeElem decodeFunc) { 122 if n := d.readUnsignedVarInt(); n < 1 { 123 v.setArray(array{}) 124 } else { 125 a := makeArray(elemType, int(n-1)) 126 for i := 0; i < int(n-1) && d.remain > 0; i++ { 127 decodeElem(d, a.index(i)) 128 } 129 v.setArray(a) 130 } 131 } 132 133 func (d *decoder) discardAll() { 134 d.discard(d.remain) 135 } 136 137 func (d *decoder) discard(n int) { 138 if n > d.remain { 139 n = d.remain 140 } 141 var err error 142 if r, _ := d.reader.(discarder); r != nil { 143 n, err = r.Discard(n) 144 d.remain -= n 145 } else { 146 _, err = io.Copy(ioutil.Discard, d) 147 } 148 d.setError(err) 149 } 150 151 func (d *decoder) read(n int) []byte { 152 b := make([]byte, n) 153 n, err := io.ReadFull(d, b) 154 b = b[:n] 155 d.setError(err) 156 return b 157 } 158 159 func (d *decoder) writeTo(w io.Writer, n int) { 160 limit := d.remain 161 if n < limit { 162 d.remain = n 163 } 164 c, err := io.Copy(w, d) 165 if int(c) < n && err == nil { 166 err = io.ErrUnexpectedEOF 167 } 168 d.remain = limit - int(c) 169 d.setError(err) 170 } 171 172 func (d *decoder) setError(err error) { 173 if d.err == nil && err != nil { 174 d.err = err 175 d.discardAll() 176 } 177 } 178 179 func (d *decoder) readFull(b []byte) bool { 180 n, err := io.ReadFull(d, b) 181 d.setError(err) 182 return n == len(b) 183 } 184 185 func (d *decoder) readByte() byte { 186 if d.readFull(d.buffer[:1]) { 187 return d.buffer[0] 188 } 189 return 0 190 } 191 192 func (d *decoder) readBool() bool { 193 return d.readByte() != 0 194 } 195 196 func (d *decoder) readInt8() int8 { 197 if d.readFull(d.buffer[:1]) { 198 return readInt8(d.buffer[:1]) 199 } 200 return 0 201 } 202 203 func (d *decoder) readInt16() int16 { 204 if d.readFull(d.buffer[:2]) { 205 return readInt16(d.buffer[:2]) 206 } 207 return 0 208 } 209 210 func (d *decoder) readInt32() int32 { 211 if d.readFull(d.buffer[:4]) { 212 return readInt32(d.buffer[:4]) 213 } 214 return 0 215 } 216 217 func (d *decoder) readInt64() int64 { 218 if d.readFull(d.buffer[:8]) { 219 return readInt64(d.buffer[:8]) 220 } 221 return 0 222 } 223 224 func (d *decoder) readFloat64() float64 { 225 if d.readFull(d.buffer[:8]) { 226 return readFloat64(d.buffer[:8]) 227 } 228 return 0 229 } 230 231 func (d *decoder) readString() string { 232 if n := d.readInt16(); n < 0 { 233 return "" 234 } else { 235 return bytesToString(d.read(int(n))) 236 } 237 } 238 239 func (d *decoder) readVarString() string { 240 if n := d.readVarInt(); n < 0 { 241 return "" 242 } else { 243 return bytesToString(d.read(int(n))) 244 } 245 } 246 247 func (d *decoder) readCompactString() string { 248 if n := d.readUnsignedVarInt(); n < 1 { 249 return "" 250 } else { 251 return bytesToString(d.read(int(n - 1))) 252 } 253 } 254 255 func (d *decoder) readBytes() []byte { 256 if n := d.readInt32(); n < 0 { 257 return nil 258 } else { 259 return d.read(int(n)) 260 } 261 } 262 263 func (d *decoder) readVarBytes() []byte { 264 if n := d.readVarInt(); n < 0 { 265 return nil 266 } else { 267 return d.read(int(n)) 268 } 269 } 270 271 func (d *decoder) readCompactBytes() []byte { 272 if n := d.readUnsignedVarInt(); n < 1 { 273 return nil 274 } else { 275 return d.read(int(n - 1)) 276 } 277 } 278 279 func (d *decoder) readVarInt() int64 { 280 n := 11 // varints are at most 11 bytes 281 282 if n > d.remain { 283 n = d.remain 284 } 285 286 x := uint64(0) 287 s := uint(0) 288 289 for n > 0 { 290 b := d.readByte() 291 292 if (b & 0x80) == 0 { 293 x |= uint64(b) << s 294 return int64(x>>1) ^ -(int64(x) & 1) 295 } 296 297 x |= uint64(b&0x7f) << s 298 s += 7 299 n-- 300 } 301 302 d.setError(fmt.Errorf("cannot decode varint from input stream")) 303 return 0 304 } 305 306 func (d *decoder) readUnsignedVarInt() uint64 { 307 n := 11 // varints are at most 11 bytes 308 309 if n > d.remain { 310 n = d.remain 311 } 312 313 x := uint64(0) 314 s := uint(0) 315 316 for n > 0 { 317 b := d.readByte() 318 319 if (b & 0x80) == 0 { 320 x |= uint64(b) << s 321 return x 322 } 323 324 x |= uint64(b&0x7f) << s 325 s += 7 326 n-- 327 } 328 329 d.setError(fmt.Errorf("cannot decode unsigned varint from input stream")) 330 return 0 331 } 332 333 type decodeFunc func(*decoder, value) 334 335 var ( 336 _ io.Reader = (*decoder)(nil) 337 _ io.ByteReader = (*decoder)(nil) 338 339 readerFrom = reflect.TypeOf((*io.ReaderFrom)(nil)).Elem() 340 ) 341 342 func decodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc { 343 if reflect.PtrTo(typ).Implements(readerFrom) { 344 return readerDecodeFuncOf(typ) 345 } 346 switch typ.Kind() { 347 case reflect.Bool: 348 return (*decoder).decodeBool 349 case reflect.Int8: 350 return (*decoder).decodeInt8 351 case reflect.Int16: 352 return (*decoder).decodeInt16 353 case reflect.Int32: 354 return (*decoder).decodeInt32 355 case reflect.Int64: 356 return (*decoder).decodeInt64 357 case reflect.Float64: 358 return (*decoder).decodeFloat64 359 case reflect.String: 360 return stringDecodeFuncOf(flexible, tag) 361 case reflect.Struct: 362 return structDecodeFuncOf(typ, version, flexible) 363 case reflect.Slice: 364 if typ.Elem().Kind() == reflect.Uint8 { // []byte 365 return bytesDecodeFuncOf(flexible, tag) 366 } 367 return arrayDecodeFuncOf(typ, version, flexible, tag) 368 default: 369 panic("unsupported type: " + typ.String()) 370 } 371 } 372 373 func stringDecodeFuncOf(flexible bool, tag structTag) decodeFunc { 374 if flexible { 375 // In flexible messages, all strings are compact 376 return (*decoder).decodeCompactString 377 } 378 return (*decoder).decodeString 379 } 380 381 func bytesDecodeFuncOf(flexible bool, tag structTag) decodeFunc { 382 if flexible { 383 // In flexible messages, all arrays are compact 384 return (*decoder).decodeCompactBytes 385 } 386 return (*decoder).decodeBytes 387 } 388 389 func structDecodeFuncOf(typ reflect.Type, version int16, flexible bool) decodeFunc { 390 type field struct { 391 decode decodeFunc 392 index index 393 tagID int 394 } 395 396 var fields []field 397 taggedFields := map[int]*field{} 398 399 forEachStructField(typ, func(typ reflect.Type, index index, tag string) { 400 forEachStructTag(tag, func(tag structTag) bool { 401 if tag.MinVersion <= version && version <= tag.MaxVersion { 402 f := field{ 403 decode: decodeFuncOf(typ, version, flexible, tag), 404 index: index, 405 tagID: tag.TagID, 406 } 407 408 if tag.TagID < -1 { 409 // Normal required field 410 fields = append(fields, f) 411 } else { 412 // Optional tagged field (flexible messages only) 413 taggedFields[tag.TagID] = &f 414 } 415 return false 416 } 417 return true 418 }) 419 }) 420 421 return func(d *decoder, v value) { 422 for i := range fields { 423 f := &fields[i] 424 f.decode(d, v.fieldByIndex(f.index)) 425 } 426 427 if flexible { 428 // See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields 429 // for details of tag buffers in "flexible" messages. 430 n := int(d.readUnsignedVarInt()) 431 432 for i := 0; i < n; i++ { 433 tagID := int(d.readUnsignedVarInt()) 434 size := int(d.readUnsignedVarInt()) 435 436 f, ok := taggedFields[tagID] 437 if ok { 438 f.decode(d, v.fieldByIndex(f.index)) 439 } else { 440 d.read(size) 441 } 442 } 443 } 444 } 445 } 446 447 func arrayDecodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc { 448 elemType := typ.Elem() 449 elemFunc := decodeFuncOf(elemType, version, flexible, tag) 450 if flexible { 451 // In flexible messages, all arrays are compact 452 return func(d *decoder, v value) { d.decodeCompactArray(v, elemType, elemFunc) } 453 } 454 455 return func(d *decoder, v value) { d.decodeArray(v, elemType, elemFunc) } 456 } 457 458 func readerDecodeFuncOf(typ reflect.Type) decodeFunc { 459 typ = reflect.PtrTo(typ) 460 return func(d *decoder, v value) { 461 if d.err == nil { 462 _, err := v.iface(typ).(io.ReaderFrom).ReadFrom(d) 463 if err != nil { 464 d.setError(err) 465 } 466 } 467 } 468 } 469 470 func readInt8(b []byte) int8 { 471 return int8(b[0]) 472 } 473 474 func readInt16(b []byte) int16 { 475 return int16(binary.BigEndian.Uint16(b)) 476 } 477 478 func readInt32(b []byte) int32 { 479 return int32(binary.BigEndian.Uint32(b)) 480 } 481 482 func readInt64(b []byte) int64 { 483 return int64(binary.BigEndian.Uint64(b)) 484 } 485 486 func readFloat64(b []byte) float64 { 487 return math.Float64frombits(binary.BigEndian.Uint64(b)) 488 } 489 490 func Unmarshal(data []byte, version int16, value interface{}) error { 491 typ := elemTypeOf(value) 492 cache, _ := unmarshalers.Load().(map[versionedType]decodeFunc) 493 key := versionedType{typ: typ, version: version} 494 decode := cache[key] 495 496 if decode == nil { 497 decode = decodeFuncOf(reflect.TypeOf(value).Elem(), version, false, structTag{ 498 MinVersion: -1, 499 MaxVersion: -1, 500 TagID: -2, 501 Compact: true, 502 Nullable: true, 503 }) 504 505 newCache := make(map[versionedType]decodeFunc, len(cache)+1) 506 newCache[key] = decode 507 508 for typ, fun := range cache { 509 newCache[typ] = fun 510 } 511 512 unmarshalers.Store(newCache) 513 } 514 515 d, _ := decoders.Get().(*decoder) 516 if d == nil { 517 d = &decoder{reader: bytes.NewReader(nil)} 518 } 519 520 d.remain = len(data) 521 r, _ := d.reader.(*bytes.Reader) 522 r.Reset(data) 523 524 defer func() { 525 r.Reset(nil) 526 d.Reset(r, 0) 527 decoders.Put(d) 528 }() 529 530 decode(d, valueOf(value)) 531 return dontExpectEOF(d.err) 532 } 533 534 var ( 535 decoders sync.Pool // *decoder 536 unmarshalers atomic.Value // map[versionedType]decodeFunc 537 )