github.com/anacrolix/torrent@v1.61.0/bencode/decode.go (about) 1 package bencode 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "io" 8 "math/big" 9 "reflect" 10 "runtime" 11 "strconv" 12 "sync" 13 ) 14 15 // The default bencode string length limit. This is a poor attempt to prevent excessive memory 16 // allocation when parsing, but also leaves the window open to implement a better solution. 17 const DefaultDecodeMaxStrLen = 1<<27 - 1 // ~128MiB 18 19 type MaxStrLen = int64 20 21 type Decoder struct { 22 // Maximum parsed bencode string length. Defaults to DefaultMaxStrLen if zero. 23 MaxStrLen MaxStrLen 24 25 r interface { 26 io.ByteScanner 27 io.Reader 28 } 29 // Sum of bytes used to Decode values. 30 Offset int64 31 buf bytes.Buffer 32 } 33 34 func (d *Decoder) Decode(v interface{}) (err error) { 35 defer func() { 36 if err != nil { 37 return 38 } 39 r := recover() 40 if r == nil { 41 return 42 } 43 _, ok := r.(runtime.Error) 44 if ok { 45 panic(r) 46 } 47 if err, ok = r.(error); !ok { 48 panic(r) 49 } 50 // Errors thrown from deeper in parsing are unexpected. At value boundaries, errors should 51 // be returned directly (at least until all the panic nonsense is removed entirely). 52 if err == io.EOF { 53 err = io.ErrUnexpectedEOF 54 } 55 }() 56 57 pv := reflect.ValueOf(v) 58 if pv.Kind() != reflect.Ptr || pv.IsNil() { 59 return &UnmarshalInvalidArgError{reflect.TypeOf(v)} 60 } 61 62 ok, err := d.parseValue(pv.Elem()) 63 if err != nil { 64 return 65 } 66 if !ok { 67 d.throwSyntaxError(d.Offset-1, errors.New("unexpected 'e'")) 68 } 69 return 70 } 71 72 // Check for EOF in the decoder input stream. Used to assert the input ends on a clean message 73 // boundary. 74 func (d *Decoder) ReadEOF() error { 75 _, err := d.r.ReadByte() 76 if err == nil { 77 err := d.r.UnreadByte() 78 if err != nil { 79 panic(err) 80 } 81 return errors.New("expected EOF") 82 } 83 if err == io.EOF { 84 return nil 85 } 86 return fmt.Errorf("expected EOF, got %w", err) 87 } 88 89 func checkForUnexpectedEOF(err error, offset int64) { 90 if err == io.EOF { 91 panic(&SyntaxError{ 92 Offset: offset, 93 What: io.ErrUnexpectedEOF, 94 }) 95 } 96 } 97 98 func (d *Decoder) readByte() byte { 99 b, err := d.r.ReadByte() 100 if err != nil { 101 checkForUnexpectedEOF(err, d.Offset) 102 panic(err) 103 } 104 105 d.Offset++ 106 return b 107 } 108 109 // reads data writing it to 'd.buf' until 'sep' byte is encountered, 'sep' byte 110 // is consumed, but not included into the 'd.buf' 111 func (d *Decoder) readUntil(sep byte) { 112 for { 113 b := d.readByte() 114 if b == sep { 115 return 116 } 117 d.buf.WriteByte(b) 118 } 119 } 120 121 func checkForIntParseError(err error, offset int64) { 122 if err != nil { 123 panic(&SyntaxError{ 124 Offset: offset, 125 What: err, 126 }) 127 } 128 } 129 130 func (d *Decoder) throwSyntaxError(offset int64, err error) { 131 panic(&SyntaxError{ 132 Offset: offset, 133 What: err, 134 }) 135 } 136 137 // Assume the 'i' is already consumed. Read and validate the rest of an int into the buffer. 138 func (d *Decoder) readInt() error { 139 // start := d.Offset - 1 140 d.readUntil('e') 141 if err := d.checkBufferedInt(); err != nil { 142 return err 143 } 144 // if d.buf.Len() == 0 { 145 // panic(&SyntaxError{ 146 // Offset: start, 147 // What: errors.New("empty integer value"), 148 // }) 149 // } 150 return nil 151 } 152 153 // called when 'i' was consumed, for the integer type in v. 154 func (d *Decoder) parseInt(v reflect.Value) error { 155 start := d.Offset - 1 156 157 if err := d.readInt(); err != nil { 158 return err 159 } 160 s := bytesAsString(d.buf.Bytes()) 161 162 switch v.Kind() { 163 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 164 n, err := strconv.ParseInt(s, 10, 64) 165 checkForIntParseError(err, start) 166 167 if v.OverflowInt(n) { 168 return &UnmarshalTypeError{ 169 BencodeTypeName: "int", 170 UnmarshalTargetType: v.Type(), 171 } 172 } 173 v.SetInt(n) 174 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 175 n, err := strconv.ParseUint(s, 10, 64) 176 checkForIntParseError(err, start) 177 178 if v.OverflowUint(n) { 179 return &UnmarshalTypeError{ 180 BencodeTypeName: "int", 181 UnmarshalTargetType: v.Type(), 182 } 183 } 184 v.SetUint(n) 185 case reflect.Bool: 186 v.SetBool(s != "0") 187 default: 188 return &UnmarshalTypeError{ 189 BencodeTypeName: "int", 190 UnmarshalTargetType: v.Type(), 191 } 192 } 193 d.buf.Reset() 194 return nil 195 } 196 197 func (d *Decoder) checkBufferedInt() error { 198 b := d.buf.Bytes() 199 if len(b) <= 1 { 200 return nil 201 } 202 if b[0] == '-' { 203 b = b[1:] 204 } 205 if b[0] < '1' || b[0] > '9' { 206 return errors.New("invalid leading digit") 207 } 208 return nil 209 } 210 211 func (d *Decoder) parseStringLength() (int, error) { 212 // We should have already consumed the first byte of the length into the Decoder buf. 213 start := d.Offset - 1 214 d.readUntil(':') 215 if err := d.checkBufferedInt(); err != nil { 216 return 0, err 217 } 218 // Really the limit should be the uint size for the platform. But we can't pass in an allocator, 219 // or limit total memory use in Go, the best we might hope to do is limit the size of a single 220 // decoded value (by reading it in in-place and then operating on a view). 221 length, err := strconv.ParseInt(bytesAsString(d.buf.Bytes()), 10, 0) 222 checkForIntParseError(err, start) 223 if int64(length) > d.getMaxStrLen() { 224 err = fmt.Errorf("parsed string length %v exceeds limit (%v)", length, DefaultDecodeMaxStrLen) 225 } 226 d.buf.Reset() 227 return int(length), err 228 } 229 230 func (d *Decoder) parseString(v reflect.Value) error { 231 length, err := d.parseStringLength() 232 if err != nil { 233 return err 234 } 235 defer d.buf.Reset() 236 read := func(b []byte) { 237 n, err := io.ReadFull(d.r, b) 238 d.Offset += int64(n) 239 if err != nil { 240 checkForUnexpectedEOF(err, d.Offset) 241 panic(&SyntaxError{ 242 Offset: d.Offset, 243 What: errors.New("unexpected I/O error: " + err.Error()), 244 }) 245 } 246 } 247 248 switch v.Kind() { 249 case reflect.String: 250 b := make([]byte, length) 251 read(b) 252 v.SetString(bytesAsString(b)) 253 return nil 254 case reflect.Slice: 255 if v.Type().Elem().Kind() != reflect.Uint8 { 256 break 257 } 258 b := make([]byte, length) 259 read(b) 260 v.SetBytes(b) 261 return nil 262 case reflect.Array: 263 if v.Type().Elem().Kind() != reflect.Uint8 { 264 break 265 } 266 d.buf.Grow(length) 267 b := d.buf.Bytes()[:length] 268 read(b) 269 reflect.Copy(v, reflect.ValueOf(b)) 270 return nil 271 case reflect.Bool: 272 d.buf.Grow(length) 273 b := d.buf.Bytes()[:length] 274 read(b) 275 x, err := strconv.ParseBool(bytesAsString(b)) 276 if err != nil { 277 x = length != 0 278 } 279 v.SetBool(x) 280 return nil 281 } 282 // Can't move this into default clause because some cases above fail through to here after 283 // additional checks. 284 d.buf.Grow(length) 285 read(d.buf.Bytes()[:length]) 286 // I believe we return here to support "ignore_unmarshal_type_error". 287 return &UnmarshalTypeError{ 288 BencodeTypeName: "string", 289 UnmarshalTargetType: v.Type(), 290 } 291 } 292 293 // Info for parsing a dict value. 294 type dictField struct { 295 Type reflect.Type 296 Get func(value reflect.Value) func(reflect.Value) 297 Tags tag 298 } 299 300 // Returns specifics for parsing a dict field value. 301 func getDictField(dict reflect.Type, key reflect.Value) (_ dictField, err error) { 302 // get valuev as a map value or as a struct field 303 switch k := dict.Kind(); k { 304 case reflect.Map: 305 return dictField{ 306 Type: dict.Elem(), 307 Get: func(mapValue reflect.Value) func(reflect.Value) { 308 return func(value reflect.Value) { 309 if mapValue.IsNil() { 310 mapValue.Set(reflect.MakeMap(dict)) 311 } 312 // Assigns the value into the map. 313 mapValue.SetMapIndex(key, value) 314 } 315 }, 316 }, nil 317 case reflect.Struct: 318 if key.Kind() != reflect.String { 319 // This doesn't make sense for structs. They have to use strings. If they didn't they 320 // should at least have things that convert to strings trivially and somehow much the 321 // bencode tag. 322 panic(key) 323 } 324 return getStructFieldForKey(dict, key.String()), nil 325 // if sf.r.PkgPath != "" { 326 // panic(&UnmarshalFieldError{ 327 // First: key, 328 // Type: dict.Type(), 329 // Field: sf.r, 330 // }) 331 // } 332 default: 333 err = fmt.Errorf("can't assign bencode dict items into a %v", k) 334 return 335 } 336 } 337 338 var ( 339 structFieldsMu sync.Mutex 340 structFields = map[reflect.Type]map[string]dictField{} 341 ) 342 343 func parseStructFields(struct_ reflect.Type, each func(key string, df dictField)) { 344 for _i, n := 0, struct_.NumField(); _i < n; _i++ { 345 i := _i 346 f := struct_.Field(i) 347 if f.Anonymous { 348 t := f.Type 349 if t.Kind() == reflect.Ptr { 350 t = t.Elem() 351 } 352 parseStructFields(t, func(key string, df dictField) { 353 innerGet := df.Get 354 df.Get = func(value reflect.Value) func(reflect.Value) { 355 anonPtr := value.Field(i) 356 if anonPtr.Kind() == reflect.Ptr && anonPtr.IsNil() { 357 anonPtr.Set(reflect.New(f.Type.Elem())) 358 anonPtr = anonPtr.Elem() 359 } 360 return innerGet(anonPtr) 361 } 362 each(key, df) 363 }) 364 continue 365 } 366 tagStr := f.Tag.Get("bencode") 367 if tagStr == "-" { 368 continue 369 } 370 tag := parseTag(tagStr) 371 key := tag.Key() 372 if key == "" { 373 key = f.Name 374 } 375 each(key, dictField{f.Type, func(value reflect.Value) func(reflect.Value) { 376 return value.Field(i).Set 377 }, tag}) 378 } 379 } 380 381 func saveStructFields(struct_ reflect.Type) { 382 m := make(map[string]dictField) 383 parseStructFields(struct_, func(key string, sf dictField) { 384 m[key] = sf 385 }) 386 structFields[struct_] = m 387 } 388 389 func getStructFieldForKey(struct_ reflect.Type, key string) (f dictField) { 390 structFieldsMu.Lock() 391 if _, ok := structFields[struct_]; !ok { 392 saveStructFields(struct_) 393 } 394 f, ok := structFields[struct_][key] 395 structFieldsMu.Unlock() 396 if !ok { 397 var discard interface{} 398 return dictField{ 399 Type: reflect.TypeOf(discard), 400 Get: func(reflect.Value) func(reflect.Value) { return func(reflect.Value) {} }, 401 Tags: nil, 402 } 403 } 404 return 405 } 406 407 var structKeyType = reflect.TypeFor[string]() 408 409 func keyType(v reflect.Value) reflect.Type { 410 switch v.Kind() { 411 case reflect.Map: 412 return v.Type().Key() 413 case reflect.Struct: 414 return structKeyType 415 default: 416 return nil 417 } 418 } 419 420 func (d *Decoder) parseDict(v reflect.Value) error { 421 // At this point 'd' byte was consumed, now read key/value pairs. 422 423 // The key type does not need to be a string for maps. 424 keyType := keyType(v) 425 if keyType == nil { 426 return fmt.Errorf("cannot parse dicts into %v", v.Type()) 427 } 428 for { 429 keyValue := reflect.New(keyType).Elem() 430 ok, err := d.parseValue(keyValue) 431 if err != nil { 432 return fmt.Errorf("error parsing dict key: %w", err) 433 } 434 if !ok { 435 return nil 436 } 437 438 df, err := getDictField(v.Type(), keyValue) 439 if err != nil { 440 return fmt.Errorf("parsing bencode dict into %v: %w", v.Type(), err) 441 } 442 443 // now we need to actually parse it 444 if df.Type == nil { 445 // Discard the value, there's nowhere to put it. 446 var if_ interface{} 447 if_, ok = d.parseValueInterface() 448 if if_ == nil { 449 return fmt.Errorf("error parsing value for key %q", keyValue) 450 } 451 if !ok { 452 return fmt.Errorf("missing value for key %q", keyValue) 453 } 454 continue 455 } 456 setValue := reflect.New(df.Type).Elem() 457 // log.Printf("parsing into %v", setValue.Type()) 458 ok, err = d.parseValue(setValue) 459 if err != nil { 460 var target *UnmarshalTypeError 461 if !(errors.As(err, &target) && df.Tags.IgnoreUnmarshalTypeError()) { 462 return fmt.Errorf("parsing value for key %q: %w", keyValue, err) 463 } 464 } 465 if !ok { 466 return fmt.Errorf("missing value for key %q", keyValue) 467 } 468 df.Get(v)(setValue) 469 } 470 } 471 472 func (d *Decoder) parseList(v reflect.Value) error { 473 switch v.Kind() { 474 default: 475 // If the list is a singleton of the expected type, use that value. See 476 // https://github.com/anacrolix/torrent/issues/297. 477 l := reflect.New(reflect.SliceOf(v.Type())) 478 if err := d.parseList(l.Elem()); err != nil { 479 return err 480 } 481 if l.Elem().Len() != 1 { 482 return &UnmarshalTypeError{ 483 BencodeTypeName: "list", 484 UnmarshalTargetType: v.Type(), 485 } 486 } 487 v.Set(l.Elem().Index(0)) 488 return nil 489 case reflect.Array, reflect.Slice: 490 // We can work with this. Normal case, fallthrough. 491 } 492 493 i := 0 494 for ; ; i++ { 495 if v.Kind() == reflect.Slice && i >= v.Len() { 496 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem()))) 497 } 498 499 if i < v.Len() { 500 ok, err := d.parseValue(v.Index(i)) 501 if err != nil { 502 return err 503 } 504 if !ok { 505 break 506 } 507 } else { 508 _, ok := d.parseValueInterface() 509 if !ok { 510 break 511 } 512 } 513 } 514 515 if i < v.Len() { 516 if v.Kind() == reflect.Array { 517 z := reflect.Zero(v.Type().Elem()) 518 for n := v.Len(); i < n; i++ { 519 v.Index(i).Set(z) 520 } 521 } else { 522 v.SetLen(i) 523 } 524 } 525 526 if i == 0 && v.Kind() == reflect.Slice { 527 v.Set(reflect.MakeSlice(v.Type(), 0, 0)) 528 } 529 return nil 530 } 531 532 func (d *Decoder) readOneValue() bool { 533 b, err := d.r.ReadByte() 534 if err != nil { 535 panic(err) 536 } 537 if b == 'e' { 538 d.r.UnreadByte() 539 return false 540 } else { 541 d.Offset++ 542 d.buf.WriteByte(b) 543 } 544 545 switch b { 546 case 'd', 'l': 547 // read until there is nothing to read 548 for d.readOneValue() { 549 } 550 // consume 'e' as well 551 b = d.readByte() 552 d.buf.WriteByte(b) 553 case 'i': 554 d.readUntil('e') 555 d.buf.WriteString("e") 556 default: 557 if b >= '0' && b <= '9' { 558 start := d.buf.Len() - 1 559 d.readUntil(':') 560 length, err := strconv.ParseInt(bytesAsString(d.buf.Bytes()[start:]), 10, 64) 561 checkForIntParseError(err, d.Offset-1) 562 563 d.buf.WriteString(":") 564 n, err := io.CopyN(&d.buf, d.r, length) 565 d.Offset += n 566 if err != nil { 567 checkForUnexpectedEOF(err, d.Offset) 568 panic(&SyntaxError{ 569 Offset: d.Offset, 570 What: errors.New("unexpected I/O error: " + err.Error()), 571 }) 572 } 573 break 574 } 575 576 d.raiseUnknownValueType(b, d.Offset-1) 577 } 578 579 return true 580 } 581 582 func (d *Decoder) parseUnmarshaler(v reflect.Value) bool { 583 if !v.Type().Implements(unmarshalerType) { 584 if v.Addr().Type().Implements(unmarshalerType) { 585 v = v.Addr() 586 } else { 587 return false 588 } 589 } 590 d.buf.Reset() 591 if !d.readOneValue() { 592 return false 593 } 594 m := v.Interface().(Unmarshaler) 595 err := m.UnmarshalBencode(d.buf.Bytes()) 596 if err != nil { 597 panic(&UnmarshalerError{v.Type(), err}) 598 } 599 return true 600 } 601 602 // Returns true if there was a value and it's now stored in 'v'. Otherwise, there was an end symbol 603 // ("e") and no value was stored. 604 func (d *Decoder) parseValue(v reflect.Value) (bool, error) { 605 // we support one level of indirection at the moment 606 if v.Kind() == reflect.Ptr { 607 // if the pointer is nil, allocate a new element of the type it 608 // points to 609 if v.IsNil() { 610 v.Set(reflect.New(v.Type().Elem())) 611 } 612 v = v.Elem() 613 } 614 615 if d.parseUnmarshaler(v) { 616 return true, nil 617 } 618 619 // common case: interface{} 620 if v.Kind() == reflect.Interface && v.NumMethod() == 0 { 621 iface, _ := d.parseValueInterface() 622 v.Set(reflect.ValueOf(iface)) 623 return true, nil 624 } 625 626 b, err := d.r.ReadByte() 627 if err != nil { 628 return false, err 629 } 630 d.Offset++ 631 632 switch b { 633 case 'e': 634 return false, nil 635 case 'd': 636 return true, d.parseDict(v) 637 case 'l': 638 return true, d.parseList(v) 639 case 'i': 640 return true, d.parseInt(v) 641 default: 642 if b >= '0' && b <= '9' { 643 // It's a string. 644 d.buf.Reset() 645 // Write the first digit of the length to the buffer. 646 d.buf.WriteByte(b) 647 return true, d.parseString(v) 648 } 649 650 d.raiseUnknownValueType(b, d.Offset-1) 651 } 652 panic("unreachable") 653 } 654 655 // An unknown bencode type character was encountered. 656 func (d *Decoder) raiseUnknownValueType(b byte, offset int64) { 657 panic(&SyntaxError{ 658 Offset: offset, 659 What: fmt.Errorf("unknown value type %+q", b), 660 }) 661 } 662 663 func (d *Decoder) parseValueInterface() (interface{}, bool) { 664 b, err := d.r.ReadByte() 665 if err != nil { 666 panic(err) 667 } 668 d.Offset++ 669 670 switch b { 671 case 'e': 672 return nil, false 673 case 'd': 674 return d.parseDictInterface(), true 675 case 'l': 676 return d.parseListInterface(), true 677 case 'i': 678 return d.parseIntInterface(), true 679 default: 680 if b >= '0' && b <= '9' { 681 // string 682 // append first digit of the length to the buffer 683 d.buf.WriteByte(b) 684 return d.parseStringInterface(), true 685 } 686 687 d.raiseUnknownValueType(b, d.Offset-1) 688 panic("unreachable") 689 } 690 } 691 692 // Called after 'i', for an arbitrary integer size. 693 func (d *Decoder) parseIntInterface() (ret interface{}) { 694 start := d.Offset - 1 695 696 if err := d.readInt(); err != nil { 697 panic(err) 698 } 699 n, err := strconv.ParseInt(d.buf.String(), 10, 64) 700 if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange { 701 i := new(big.Int) 702 _, ok := i.SetString(d.buf.String(), 10) 703 if !ok { 704 panic(&SyntaxError{ 705 Offset: start, 706 What: errors.New("failed to parse integer"), 707 }) 708 } 709 ret = i 710 } else { 711 checkForIntParseError(err, start) 712 ret = n 713 } 714 715 d.buf.Reset() 716 return 717 } 718 719 func (d *Decoder) readBytes(length int) []byte { 720 b, err := io.ReadAll(io.LimitReader(d.r, int64(length))) 721 if err != nil { 722 panic(err) 723 } 724 if len(b) != length { 725 panic(fmt.Errorf("read %v bytes expected %v", len(b), length)) 726 } 727 return b 728 } 729 730 func (d *Decoder) parseStringInterface() string { 731 length, err := d.parseStringLength() 732 if err != nil { 733 panic(err) 734 } 735 b := d.readBytes(int(length)) 736 d.Offset += int64(len(b)) 737 if err != nil { 738 panic(&SyntaxError{Offset: d.Offset, What: err}) 739 } 740 return bytesAsString(b) 741 } 742 743 func (d *Decoder) parseDictInterface() interface{} { 744 dict := make(map[string]interface{}) 745 var lastKey string 746 lastKeyOk := false 747 for { 748 start := d.Offset 749 keyi, ok := d.parseValueInterface() 750 if !ok { 751 break 752 } 753 754 key, ok := keyi.(string) 755 if !ok { 756 panic(&SyntaxError{ 757 Offset: d.Offset, 758 What: errors.New("non-string key in a dict"), 759 }) 760 } 761 if lastKeyOk && key <= lastKey { 762 d.throwSyntaxError(start, fmt.Errorf("dict keys unsorted: %q <= %q", key, lastKey)) 763 } 764 start = d.Offset 765 valuei, ok := d.parseValueInterface() 766 if !ok { 767 d.throwSyntaxError(start, fmt.Errorf("dict elem missing value [key=%v]", key)) 768 } 769 770 lastKey = key 771 lastKeyOk = true 772 dict[key] = valuei 773 } 774 return dict 775 } 776 777 func (d *Decoder) parseListInterface() (list []interface{}) { 778 list = []interface{}{} 779 valuei, ok := d.parseValueInterface() 780 for ok { 781 list = append(list, valuei) 782 valuei, ok = d.parseValueInterface() 783 } 784 return 785 } 786 787 func (d *Decoder) getMaxStrLen() int64 { 788 if d.MaxStrLen == 0 { 789 return DefaultDecodeMaxStrLen 790 } 791 return d.MaxStrLen 792 }