github.com/deanMdreon/kafka-go@v0.4.32/protocol/decode.go (about) 1 package protocol 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "fmt" 7 "hash/crc32" 8 "io" 9 "io/ioutil" 10 "reflect" 11 "sync" 12 "sync/atomic" 13 ) 14 15 type discarder interface { 16 Discard(int) (int, error) 17 } 18 19 type decoder struct { 20 reader io.Reader 21 remain int 22 buffer [8]byte 23 err error 24 table *crc32.Table 25 crc32 uint32 26 } 27 28 func (d *decoder) Reset(r io.Reader, n int) { 29 d.reader = r 30 d.remain = n 31 d.buffer = [8]byte{} 32 d.err = nil 33 d.table = nil 34 d.crc32 = 0 35 } 36 37 func (d *decoder) Read(b []byte) (int, error) { 38 if d.err != nil { 39 return 0, d.err 40 } 41 if d.remain == 0 { 42 return 0, io.EOF 43 } 44 if len(b) > d.remain { 45 b = b[:d.remain] 46 } 47 n, err := d.reader.Read(b) 48 if n > 0 && d.table != nil { 49 d.crc32 = crc32.Update(d.crc32, d.table, b[:n]) 50 } 51 d.remain -= n 52 return n, err 53 } 54 55 func (d *decoder) ReadByte() (byte, error) { 56 c := d.readByte() 57 return c, d.err 58 } 59 60 func (d *decoder) done() bool { 61 return d.remain == 0 || d.err != nil 62 } 63 64 func (d *decoder) setCRC(table *crc32.Table) { 65 d.table, d.crc32 = table, 0 66 } 67 68 func (d *decoder) decodeBool(v value) { 69 v.setBool(d.readBool()) 70 } 71 72 func (d *decoder) decodeInt8(v value) { 73 v.setInt8(d.readInt8()) 74 } 75 76 func (d *decoder) decodeInt16(v value) { 77 v.setInt16(d.readInt16()) 78 } 79 80 func (d *decoder) decodeInt32(v value) { 81 v.setInt32(d.readInt32()) 82 } 83 84 func (d *decoder) decodeInt64(v value) { 85 v.setInt64(d.readInt64()) 86 } 87 88 func (d *decoder) decodeString(v value) { 89 v.setString(d.readString()) 90 } 91 92 func (d *decoder) decodeCompactString(v value) { 93 v.setString(d.readCompactString()) 94 } 95 96 func (d *decoder) decodeBytes(v value) { 97 v.setBytes(d.readBytes()) 98 } 99 100 func (d *decoder) decodeCompactBytes(v value) { 101 v.setBytes(d.readCompactBytes()) 102 } 103 104 func (d *decoder) decodeArray(v value, elemType reflect.Type, decodeElem decodeFunc) { 105 if n := d.readInt32(); n < 0 { 106 v.setArray(array{}) 107 } else { 108 a := makeArray(elemType, int(n)) 109 for i := 0; i < int(n) && d.remain > 0; i++ { 110 decodeElem(d, a.index(i)) 111 } 112 v.setArray(a) 113 } 114 } 115 116 func (d *decoder) decodeCompactArray(v value, elemType reflect.Type, decodeElem decodeFunc) { 117 if n := d.readUnsignedVarInt(); n < 1 { 118 v.setArray(array{}) 119 } else { 120 a := makeArray(elemType, int(n-1)) 121 for i := 0; i < int(n-1) && d.remain > 0; i++ { 122 decodeElem(d, a.index(i)) 123 } 124 v.setArray(a) 125 } 126 } 127 128 func (d *decoder) discardAll() { 129 d.discard(d.remain) 130 } 131 132 func (d *decoder) discard(n int) { 133 if n > d.remain { 134 n = d.remain 135 } 136 var err error 137 if r, _ := d.reader.(discarder); r != nil { 138 n, err = r.Discard(n) 139 d.remain -= n 140 } else { 141 _, err = io.Copy(ioutil.Discard, d) 142 } 143 d.setError(err) 144 } 145 146 func (d *decoder) read(n int) []byte { 147 b := make([]byte, n) 148 n, err := io.ReadFull(d, b) 149 b = b[:n] 150 d.setError(err) 151 return b 152 } 153 154 func (d *decoder) writeTo(w io.Writer, n int) { 155 limit := d.remain 156 if n < limit { 157 d.remain = n 158 } 159 c, err := io.Copy(w, d) 160 if int(c) < n && err == nil { 161 err = io.ErrUnexpectedEOF 162 } 163 d.remain = limit - int(c) 164 d.setError(err) 165 } 166 167 func (d *decoder) setError(err error) { 168 if d.err == nil && err != nil { 169 d.err = err 170 d.discardAll() 171 } 172 } 173 174 func (d *decoder) readFull(b []byte) bool { 175 n, err := io.ReadFull(d, b) 176 d.setError(err) 177 return n == len(b) 178 } 179 180 func (d *decoder) readByte() byte { 181 if d.readFull(d.buffer[:1]) { 182 return d.buffer[0] 183 } 184 return 0 185 } 186 187 func (d *decoder) readBool() bool { 188 return d.readByte() != 0 189 } 190 191 func (d *decoder) readInt8() int8 { 192 if d.readFull(d.buffer[:1]) { 193 return readInt8(d.buffer[:1]) 194 } 195 return 0 196 } 197 198 func (d *decoder) readInt16() int16 { 199 if d.readFull(d.buffer[:2]) { 200 return readInt16(d.buffer[:2]) 201 } 202 return 0 203 } 204 205 func (d *decoder) readInt32() int32 { 206 if d.readFull(d.buffer[:4]) { 207 return readInt32(d.buffer[:4]) 208 } 209 return 0 210 } 211 212 func (d *decoder) readInt64() int64 { 213 if d.readFull(d.buffer[:8]) { 214 return readInt64(d.buffer[:8]) 215 } 216 return 0 217 } 218 219 func (d *decoder) readString() string { 220 if n := d.readInt16(); n < 0 { 221 return "" 222 } else { 223 return bytesToString(d.read(int(n))) 224 } 225 } 226 227 func (d *decoder) readVarString() string { 228 if n := d.readVarInt(); n < 0 { 229 return "" 230 } else { 231 return bytesToString(d.read(int(n))) 232 } 233 } 234 235 func (d *decoder) readCompactString() string { 236 if n := d.readUnsignedVarInt(); n < 1 { 237 return "" 238 } else { 239 return bytesToString(d.read(int(n - 1))) 240 } 241 } 242 243 func (d *decoder) readBytes() []byte { 244 if n := d.readInt32(); n < 0 { 245 return nil 246 } else { 247 return d.read(int(n)) 248 } 249 } 250 251 func (d *decoder) readBytesTo(w io.Writer) bool { 252 if n := d.readInt32(); n < 0 { 253 return false 254 } else { 255 d.writeTo(w, int(n)) 256 return d.err == nil 257 } 258 } 259 260 func (d *decoder) readVarBytes() []byte { 261 if n := d.readVarInt(); n < 0 { 262 return nil 263 } else { 264 return d.read(int(n)) 265 } 266 } 267 268 func (d *decoder) readVarBytesTo(w io.Writer) bool { 269 if n := d.readVarInt(); n < 0 { 270 return false 271 } else { 272 d.writeTo(w, int(n)) 273 return d.err == nil 274 } 275 } 276 277 func (d *decoder) readCompactBytes() []byte { 278 if n := d.readUnsignedVarInt(); n < 1 { 279 return nil 280 } else { 281 return d.read(int(n - 1)) 282 } 283 } 284 285 func (d *decoder) readCompactBytesTo(w io.Writer) bool { 286 if n := d.readUnsignedVarInt(); n < 1 { 287 return false 288 } else { 289 d.writeTo(w, int(n-1)) 290 return d.err == nil 291 } 292 } 293 294 func (d *decoder) readVarInt() int64 { 295 n := 11 // varints are at most 11 bytes 296 297 if n > d.remain { 298 n = d.remain 299 } 300 301 x := uint64(0) 302 s := uint(0) 303 304 for n > 0 { 305 b := d.readByte() 306 307 if (b & 0x80) == 0 { 308 x |= uint64(b) << s 309 return int64(x>>1) ^ -(int64(x) & 1) 310 } 311 312 x |= uint64(b&0x7f) << s 313 s += 7 314 n-- 315 } 316 317 d.setError(fmt.Errorf("cannot decode varint from input stream")) 318 return 0 319 } 320 321 func (d *decoder) readUnsignedVarInt() uint64 { 322 n := 11 // varints are at most 11 bytes 323 324 if n > d.remain { 325 n = d.remain 326 } 327 328 x := uint64(0) 329 s := uint(0) 330 331 for n > 0 { 332 b := d.readByte() 333 334 if (b & 0x80) == 0 { 335 x |= uint64(b) << s 336 return x 337 } 338 339 x |= uint64(b&0x7f) << s 340 s += 7 341 n-- 342 } 343 344 d.setError(fmt.Errorf("cannot decode unsigned varint from input stream")) 345 return 0 346 } 347 348 type decodeFunc func(*decoder, value) 349 350 var ( 351 _ io.Reader = (*decoder)(nil) 352 _ io.ByteReader = (*decoder)(nil) 353 354 readerFrom = reflect.TypeOf((*io.ReaderFrom)(nil)).Elem() 355 ) 356 357 func decodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc { 358 if reflect.PtrTo(typ).Implements(readerFrom) { 359 return readerDecodeFuncOf(typ) 360 } 361 switch typ.Kind() { 362 case reflect.Bool: 363 return (*decoder).decodeBool 364 case reflect.Int8: 365 return (*decoder).decodeInt8 366 case reflect.Int16: 367 return (*decoder).decodeInt16 368 case reflect.Int32: 369 return (*decoder).decodeInt32 370 case reflect.Int64: 371 return (*decoder).decodeInt64 372 case reflect.String: 373 return stringDecodeFuncOf(flexible, tag) 374 case reflect.Struct: 375 return structDecodeFuncOf(typ, version, flexible) 376 case reflect.Slice: 377 if typ.Elem().Kind() == reflect.Uint8 { // []byte 378 return bytesDecodeFuncOf(flexible, tag) 379 } 380 return arrayDecodeFuncOf(typ, version, flexible, tag) 381 default: 382 panic("unsupported type: " + typ.String()) 383 } 384 } 385 386 func stringDecodeFuncOf(flexible bool, tag structTag) decodeFunc { 387 if flexible { 388 // In flexible messages, all strings are compact 389 return (*decoder).decodeCompactString 390 } 391 return (*decoder).decodeString 392 } 393 394 func bytesDecodeFuncOf(flexible bool, tag structTag) decodeFunc { 395 if flexible { 396 // In flexible messages, all arrays are compact 397 return (*decoder).decodeCompactBytes 398 } 399 return (*decoder).decodeBytes 400 } 401 402 func structDecodeFuncOf(typ reflect.Type, version int16, flexible bool) decodeFunc { 403 type field struct { 404 decode decodeFunc 405 index index 406 tagID int 407 } 408 409 var fields []field 410 taggedFields := map[int]*field{} 411 412 forEachStructField(typ, func(typ reflect.Type, index index, tag string) { 413 forEachStructTag(tag, func(tag structTag) bool { 414 if tag.MinVersion <= version && version <= tag.MaxVersion { 415 f := field{ 416 decode: decodeFuncOf(typ, version, flexible, tag), 417 index: index, 418 tagID: tag.TagID, 419 } 420 421 if tag.TagID < -1 { 422 // Normal required field 423 fields = append(fields, f) 424 } else { 425 // Optional tagged field (flexible messages only) 426 taggedFields[tag.TagID] = &f 427 } 428 return false 429 } 430 return true 431 }) 432 }) 433 434 return func(d *decoder, v value) { 435 for i := range fields { 436 f := &fields[i] 437 f.decode(d, v.fieldByIndex(f.index)) 438 } 439 440 if flexible { 441 // See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields 442 // for details of tag buffers in "flexible" messages. 443 n := int(d.readUnsignedVarInt()) 444 445 for i := 0; i < n; i++ { 446 tagID := int(d.readUnsignedVarInt()) 447 size := int(d.readUnsignedVarInt()) 448 449 f, ok := taggedFields[tagID] 450 if ok { 451 f.decode(d, v.fieldByIndex(f.index)) 452 } else { 453 d.read(size) 454 } 455 } 456 } 457 } 458 } 459 460 func arrayDecodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc { 461 elemType := typ.Elem() 462 elemFunc := decodeFuncOf(elemType, version, flexible, tag) 463 if flexible { 464 // In flexible messages, all arrays are compact 465 return func(d *decoder, v value) { d.decodeCompactArray(v, elemType, elemFunc) } 466 } 467 468 return func(d *decoder, v value) { d.decodeArray(v, elemType, elemFunc) } 469 } 470 471 func readerDecodeFuncOf(typ reflect.Type) decodeFunc { 472 typ = reflect.PtrTo(typ) 473 return func(d *decoder, v value) { 474 if d.err == nil { 475 _, err := v.iface(typ).(io.ReaderFrom).ReadFrom(d) 476 if err != nil { 477 d.setError(err) 478 } 479 } 480 } 481 } 482 483 func readInt8(b []byte) int8 { 484 return int8(b[0]) 485 } 486 487 func readInt16(b []byte) int16 { 488 return int16(binary.BigEndian.Uint16(b)) 489 } 490 491 func readInt32(b []byte) int32 { 492 return int32(binary.BigEndian.Uint32(b)) 493 } 494 495 func readInt64(b []byte) int64 { 496 return int64(binary.BigEndian.Uint64(b)) 497 } 498 499 func Unmarshal(data []byte, version int16, value interface{}) error { 500 typ := elemTypeOf(value) 501 cache, _ := unmarshalers.Load().(map[versionedType]decodeFunc) 502 key := versionedType{typ: typ, version: version} 503 decode := cache[key] 504 505 if decode == nil { 506 decode = decodeFuncOf(reflect.TypeOf(value).Elem(), version, false, structTag{ 507 MinVersion: -1, 508 MaxVersion: -1, 509 TagID: -2, 510 Compact: true, 511 Nullable: true, 512 }) 513 514 newCache := make(map[versionedType]decodeFunc, len(cache)+1) 515 newCache[key] = decode 516 517 for typ, fun := range cache { 518 newCache[typ] = fun 519 } 520 521 unmarshalers.Store(newCache) 522 } 523 524 d, _ := decoders.Get().(*decoder) 525 if d == nil { 526 d = &decoder{reader: bytes.NewReader(nil)} 527 } 528 529 d.remain = len(data) 530 r, _ := d.reader.(*bytes.Reader) 531 r.Reset(data) 532 533 defer func() { 534 r.Reset(nil) 535 d.Reset(r, 0) 536 decoders.Put(d) 537 }() 538 539 decode(d, valueOf(value)) 540 return dontExpectEOF(d.err) 541 } 542 543 var ( 544 decoders sync.Pool // *decoder 545 unmarshalers atomic.Value // map[versionedType]decodeFunc 546 )