git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/toml/encode.go (about) 1 package toml 2 3 import ( 4 "bufio" 5 "encoding" 6 "encoding/json" 7 "errors" 8 "fmt" 9 "io" 10 "math" 11 "reflect" 12 "sort" 13 "strconv" 14 "strings" 15 "time" 16 17 "git.sr.ht/~pingoo/stdx/toml/internal" 18 ) 19 20 type tomlEncodeError struct{ error } 21 22 var ( 23 errArrayNilElement = errors.New("toml: cannot encode array with nil element") 24 errNonString = errors.New("toml: cannot encode a map with non-string key type") 25 errNoKey = errors.New("toml: top-level values must be Go maps or structs") 26 errAnything = errors.New("") // used in testing 27 ) 28 29 var dblQuotedReplacer = strings.NewReplacer( 30 "\"", "\\\"", 31 "\\", "\\\\", 32 "\x00", `\u0000`, 33 "\x01", `\u0001`, 34 "\x02", `\u0002`, 35 "\x03", `\u0003`, 36 "\x04", `\u0004`, 37 "\x05", `\u0005`, 38 "\x06", `\u0006`, 39 "\x07", `\u0007`, 40 "\b", `\b`, 41 "\t", `\t`, 42 "\n", `\n`, 43 "\x0b", `\u000b`, 44 "\f", `\f`, 45 "\r", `\r`, 46 "\x0e", `\u000e`, 47 "\x0f", `\u000f`, 48 "\x10", `\u0010`, 49 "\x11", `\u0011`, 50 "\x12", `\u0012`, 51 "\x13", `\u0013`, 52 "\x14", `\u0014`, 53 "\x15", `\u0015`, 54 "\x16", `\u0016`, 55 "\x17", `\u0017`, 56 "\x18", `\u0018`, 57 "\x19", `\u0019`, 58 "\x1a", `\u001a`, 59 "\x1b", `\u001b`, 60 "\x1c", `\u001c`, 61 "\x1d", `\u001d`, 62 "\x1e", `\u001e`, 63 "\x1f", `\u001f`, 64 "\x7f", `\u007f`, 65 ) 66 67 var ( 68 marshalToml = reflect.TypeOf((*Marshaler)(nil)).Elem() 69 marshalText = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() 70 timeType = reflect.TypeOf((*time.Time)(nil)).Elem() 71 ) 72 73 // Marshaler is the interface implemented by types that can marshal themselves 74 // into valid TOML. 75 type Marshaler interface { 76 MarshalTOML() ([]byte, error) 77 } 78 79 // Encoder encodes a Go to a TOML document. 80 // 81 // The mapping between Go values and TOML values should be precisely the same as 82 // for the Decode* functions. 83 // 84 // time.Time is encoded as a RFC 3339 string, and time.Duration as its string 85 // representation. 86 // 87 // The toml.Marshaler and encoder.TextMarshaler interfaces are supported to 88 // encoding the value as custom TOML. 89 // 90 // If you want to write arbitrary binary data then you will need to use 91 // something like base64 since TOML does not have any binary types. 92 // 93 // When encoding TOML hashes (Go maps or structs), keys without any sub-hashes 94 // are encoded first. 95 // 96 // Go maps will be sorted alphabetically by key for deterministic output. 97 // 98 // The toml struct tag can be used to provide the key name; if omitted the 99 // struct field name will be used. If the "omitempty" option is present the 100 // following value will be skipped: 101 // 102 // - arrays, slices, maps, and string with len of 0 103 // - struct with all zero values 104 // - bool false 105 // 106 // If omitzero is given all int and float types with a value of 0 will be 107 // skipped. 108 // 109 // Encoding Go values without a corresponding TOML representation will return an 110 // error. Examples of this includes maps with non-string keys, slices with nil 111 // elements, embedded non-struct types, and nested slices containing maps or 112 // structs. (e.g. [][]map[string]string is not allowed but []map[string]string 113 // is okay, as is []map[string][]string). 114 // 115 // NOTE: only exported keys are encoded due to the use of reflection. Unexported 116 // keys are silently discarded. 117 type Encoder struct { 118 // String to use for a single indentation level; default is two spaces. 119 Indent string 120 121 w *bufio.Writer 122 hasWritten bool // written any output to w yet? 123 } 124 125 // NewEncoder create a new Encoder. 126 func NewEncoder(w io.Writer) *Encoder { 127 return &Encoder{ 128 w: bufio.NewWriter(w), 129 Indent: " ", 130 } 131 } 132 133 // Encode writes a TOML representation of the Go value to the Encoder's writer. 134 // 135 // An error is returned if the value given cannot be encoded to a valid TOML 136 // document. 137 func (enc *Encoder) Encode(v interface{}) error { 138 rv := eindirect(reflect.ValueOf(v)) 139 if err := enc.safeEncode(Key([]string{}), rv); err != nil { 140 return err 141 } 142 return enc.w.Flush() 143 } 144 145 func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) { 146 defer func() { 147 if r := recover(); r != nil { 148 if terr, ok := r.(tomlEncodeError); ok { 149 err = terr.error 150 return 151 } 152 panic(r) 153 } 154 }() 155 enc.encode(key, rv) 156 return nil 157 } 158 159 func (enc *Encoder) encode(key Key, rv reflect.Value) { 160 // If we can marshal the type to text, then we use that. This prevents the 161 // encoder for handling these types as generic structs (or whatever the 162 // underlying type of a TextMarshaler is). 163 switch { 164 case isMarshaler(rv): 165 enc.writeKeyValue(key, rv, false) 166 return 167 case rv.Type() == primitiveType: // TODO: #76 would make this superfluous after implemented. 168 enc.encode(key, reflect.ValueOf(rv.Interface().(Primitive).undecoded)) 169 return 170 } 171 172 k := rv.Kind() 173 switch k { 174 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, 175 reflect.Int64, 176 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, 177 reflect.Uint64, 178 reflect.Float32, reflect.Float64, reflect.String, reflect.Bool: 179 enc.writeKeyValue(key, rv, false) 180 case reflect.Array, reflect.Slice: 181 if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) { 182 enc.eArrayOfTables(key, rv) 183 } else { 184 enc.writeKeyValue(key, rv, false) 185 } 186 case reflect.Interface: 187 if rv.IsNil() { 188 return 189 } 190 enc.encode(key, rv.Elem()) 191 case reflect.Map: 192 if rv.IsNil() { 193 return 194 } 195 enc.eTable(key, rv) 196 case reflect.Ptr: 197 if rv.IsNil() { 198 return 199 } 200 enc.encode(key, rv.Elem()) 201 case reflect.Struct: 202 enc.eTable(key, rv) 203 default: 204 encPanic(fmt.Errorf("unsupported type for key '%s': %s", key, k)) 205 } 206 } 207 208 // eElement encodes any value that can be an array element. 209 func (enc *Encoder) eElement(rv reflect.Value) { 210 switch v := rv.Interface().(type) { 211 case time.Time: // Using TextMarshaler adds extra quotes, which we don't want. 212 format := time.RFC3339Nano 213 switch v.Location() { 214 case internal.LocalDatetime: 215 format = "2006-01-02T15:04:05.999999999" 216 case internal.LocalDate: 217 format = "2006-01-02" 218 case internal.LocalTime: 219 format = "15:04:05.999999999" 220 } 221 switch v.Location() { 222 default: 223 enc.wf(v.Format(format)) 224 case internal.LocalDatetime, internal.LocalDate, internal.LocalTime: 225 enc.wf(v.In(time.UTC).Format(format)) 226 } 227 return 228 case Marshaler: 229 s, err := v.MarshalTOML() 230 if err != nil { 231 encPanic(err) 232 } 233 if s == nil { 234 encPanic(errors.New("MarshalTOML returned nil and no error")) 235 } 236 enc.w.Write(s) 237 return 238 case encoding.TextMarshaler: 239 s, err := v.MarshalText() 240 if err != nil { 241 encPanic(err) 242 } 243 if s == nil { 244 encPanic(errors.New("MarshalText returned nil and no error")) 245 } 246 enc.writeQuoted(string(s)) 247 return 248 case time.Duration: 249 enc.writeQuoted(v.String()) 250 return 251 case json.Number: 252 n, _ := rv.Interface().(json.Number) 253 254 if n == "" { /// Useful zero value. 255 enc.w.WriteByte('0') 256 return 257 } else if v, err := n.Int64(); err == nil { 258 enc.eElement(reflect.ValueOf(v)) 259 return 260 } else if v, err := n.Float64(); err == nil { 261 enc.eElement(reflect.ValueOf(v)) 262 return 263 } 264 encPanic(fmt.Errorf("unable to convert %q to int64 or float64", n)) 265 } 266 267 switch rv.Kind() { 268 case reflect.Ptr: 269 enc.eElement(rv.Elem()) 270 return 271 case reflect.String: 272 enc.writeQuoted(rv.String()) 273 case reflect.Bool: 274 enc.wf(strconv.FormatBool(rv.Bool())) 275 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 276 enc.wf(strconv.FormatInt(rv.Int(), 10)) 277 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 278 enc.wf(strconv.FormatUint(rv.Uint(), 10)) 279 case reflect.Float32: 280 f := rv.Float() 281 if math.IsNaN(f) { 282 enc.wf("nan") 283 } else if math.IsInf(f, 0) { 284 enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)]) 285 } else { 286 enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 32))) 287 } 288 case reflect.Float64: 289 f := rv.Float() 290 if math.IsNaN(f) { 291 enc.wf("nan") 292 } else if math.IsInf(f, 0) { 293 enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)]) 294 } else { 295 enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 64))) 296 } 297 case reflect.Array, reflect.Slice: 298 enc.eArrayOrSliceElement(rv) 299 case reflect.Struct: 300 enc.eStruct(nil, rv, true) 301 case reflect.Map: 302 enc.eMap(nil, rv, true) 303 case reflect.Interface: 304 enc.eElement(rv.Elem()) 305 default: 306 encPanic(fmt.Errorf("unexpected type: %T", rv.Interface())) 307 } 308 } 309 310 // By the TOML spec, all floats must have a decimal with at least one number on 311 // either side. 312 func floatAddDecimal(fstr string) string { 313 if !strings.Contains(fstr, ".") { 314 return fstr + ".0" 315 } 316 return fstr 317 } 318 319 func (enc *Encoder) writeQuoted(s string) { 320 enc.wf("\"%s\"", dblQuotedReplacer.Replace(s)) 321 } 322 323 func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) { 324 length := rv.Len() 325 enc.wf("[") 326 for i := 0; i < length; i++ { 327 elem := eindirect(rv.Index(i)) 328 enc.eElement(elem) 329 if i != length-1 { 330 enc.wf(", ") 331 } 332 } 333 enc.wf("]") 334 } 335 336 func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) { 337 if len(key) == 0 { 338 encPanic(errNoKey) 339 } 340 for i := 0; i < rv.Len(); i++ { 341 trv := eindirect(rv.Index(i)) 342 if isNil(trv) { 343 continue 344 } 345 enc.newline() 346 enc.wf("%s[[%s]]", enc.indentStr(key), key) 347 enc.newline() 348 enc.eMapOrStruct(key, trv, false) 349 } 350 } 351 352 func (enc *Encoder) eTable(key Key, rv reflect.Value) { 353 if len(key) == 1 { 354 // Output an extra newline between top-level tables. 355 // (The newline isn't written if nothing else has been written though.) 356 enc.newline() 357 } 358 if len(key) > 0 { 359 enc.wf("%s[%s]", enc.indentStr(key), key) 360 enc.newline() 361 } 362 enc.eMapOrStruct(key, rv, false) 363 } 364 365 func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) { 366 switch rv.Kind() { 367 case reflect.Map: 368 enc.eMap(key, rv, inline) 369 case reflect.Struct: 370 enc.eStruct(key, rv, inline) 371 default: 372 // Should never happen? 373 panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String()) 374 } 375 } 376 377 func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) { 378 rt := rv.Type() 379 if rt.Key().Kind() != reflect.String { 380 encPanic(errNonString) 381 } 382 383 // Sort keys so that we have deterministic output. And write keys directly 384 // underneath this key first, before writing sub-structs or sub-maps. 385 var mapKeysDirect, mapKeysSub []string 386 for _, mapKey := range rv.MapKeys() { 387 k := mapKey.String() 388 if typeIsTable(tomlTypeOfGo(eindirect(rv.MapIndex(mapKey)))) { 389 mapKeysSub = append(mapKeysSub, k) 390 } else { 391 mapKeysDirect = append(mapKeysDirect, k) 392 } 393 } 394 395 var writeMapKeys = func(mapKeys []string, trailC bool) { 396 sort.Strings(mapKeys) 397 for i, mapKey := range mapKeys { 398 val := eindirect(rv.MapIndex(reflect.ValueOf(mapKey))) 399 if isNil(val) { 400 continue 401 } 402 403 if inline { 404 enc.writeKeyValue(Key{mapKey}, val, true) 405 if trailC || i != len(mapKeys)-1 { 406 enc.wf(", ") 407 } 408 } else { 409 enc.encode(key.add(mapKey), val) 410 } 411 } 412 } 413 414 if inline { 415 enc.wf("{") 416 } 417 writeMapKeys(mapKeysDirect, len(mapKeysSub) > 0) 418 writeMapKeys(mapKeysSub, false) 419 if inline { 420 enc.wf("}") 421 } 422 } 423 424 const is32Bit = (32 << (^uint(0) >> 63)) == 32 425 426 func pointerTo(t reflect.Type) reflect.Type { 427 if t.Kind() == reflect.Ptr { 428 return pointerTo(t.Elem()) 429 } 430 return t 431 } 432 433 func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) { 434 // Write keys for fields directly under this key first, because if we write 435 // a field that creates a new table then all keys under it will be in that 436 // table (not the one we're writing here). 437 // 438 // Fields is a [][]int: for fieldsDirect this always has one entry (the 439 // struct index). For fieldsSub it contains two entries: the parent field 440 // index from tv, and the field indexes for the fields of the sub. 441 var ( 442 rt = rv.Type() 443 fieldsDirect, fieldsSub [][]int 444 addFields func(rt reflect.Type, rv reflect.Value, start []int) 445 ) 446 addFields = func(rt reflect.Type, rv reflect.Value, start []int) { 447 for i := 0; i < rt.NumField(); i++ { 448 f := rt.Field(i) 449 isEmbed := f.Anonymous && pointerTo(f.Type).Kind() == reflect.Struct 450 if f.PkgPath != "" && !isEmbed { /// Skip unexported fields. 451 continue 452 } 453 opts := getOptions(f.Tag) 454 if opts.skip { 455 continue 456 } 457 458 frv := eindirect(rv.Field(i)) 459 460 // Treat anonymous struct fields with tag names as though they are 461 // not anonymous, like encoding/json does. 462 // 463 // Non-struct anonymous fields use the normal encoding logic. 464 if isEmbed { 465 if getOptions(f.Tag).name == "" && frv.Kind() == reflect.Struct { 466 addFields(frv.Type(), frv, append(start, f.Index...)) 467 continue 468 } 469 } 470 471 if typeIsTable(tomlTypeOfGo(frv)) { 472 fieldsSub = append(fieldsSub, append(start, f.Index...)) 473 } else { 474 // Copy so it works correct on 32bit archs; not clear why this 475 // is needed. See #314, and https://www.reddit.com/r/golang/comments/pnx8v4 476 // This also works fine on 64bit, but 32bit archs are somewhat 477 // rare and this is a wee bit faster. 478 if is32Bit { 479 copyStart := make([]int, len(start)) 480 copy(copyStart, start) 481 fieldsDirect = append(fieldsDirect, append(copyStart, f.Index...)) 482 } else { 483 fieldsDirect = append(fieldsDirect, append(start, f.Index...)) 484 } 485 } 486 } 487 } 488 addFields(rt, rv, nil) 489 490 writeFields := func(fields [][]int) { 491 for _, fieldIndex := range fields { 492 fieldType := rt.FieldByIndex(fieldIndex) 493 fieldVal := eindirect(rv.FieldByIndex(fieldIndex)) 494 495 if isNil(fieldVal) { /// Don't write anything for nil fields. 496 continue 497 } 498 499 opts := getOptions(fieldType.Tag) 500 if opts.skip { 501 continue 502 } 503 keyName := fieldType.Name 504 if opts.name != "" { 505 keyName = opts.name 506 } 507 508 if opts.omitempty && enc.isEmpty(fieldVal) { 509 continue 510 } 511 if opts.omitzero && isZero(fieldVal) { 512 continue 513 } 514 515 if inline { 516 enc.writeKeyValue(Key{keyName}, fieldVal, true) 517 if fieldIndex[0] != len(fields)-1 { 518 enc.wf(", ") 519 } 520 } else { 521 enc.encode(key.add(keyName), fieldVal) 522 } 523 } 524 } 525 526 if inline { 527 enc.wf("{") 528 } 529 writeFields(fieldsDirect) 530 writeFields(fieldsSub) 531 if inline { 532 enc.wf("}") 533 } 534 } 535 536 // tomlTypeOfGo returns the TOML type name of the Go value's type. 537 // 538 // It is used to determine whether the types of array elements are mixed (which 539 // is forbidden). If the Go value is nil, then it is illegal for it to be an 540 // array element, and valueIsNil is returned as true. 541 // 542 // The type may be `nil`, which means no concrete TOML type could be found. 543 func tomlTypeOfGo(rv reflect.Value) tomlType { 544 if isNil(rv) || !rv.IsValid() { 545 return nil 546 } 547 548 if rv.Kind() == reflect.Struct { 549 if rv.Type() == timeType { 550 return tomlDatetime 551 } 552 if isMarshaler(rv) { 553 return tomlString 554 } 555 return tomlHash 556 } 557 558 if isMarshaler(rv) { 559 return tomlString 560 } 561 562 switch rv.Kind() { 563 case reflect.Bool: 564 return tomlBool 565 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, 566 reflect.Int64, 567 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, 568 reflect.Uint64: 569 return tomlInteger 570 case reflect.Float32, reflect.Float64: 571 return tomlFloat 572 case reflect.Array, reflect.Slice: 573 if isTableArray(rv) { 574 return tomlArrayHash 575 } 576 return tomlArray 577 case reflect.Ptr, reflect.Interface: 578 return tomlTypeOfGo(rv.Elem()) 579 case reflect.String: 580 return tomlString 581 case reflect.Map: 582 return tomlHash 583 default: 584 encPanic(errors.New("unsupported type: " + rv.Kind().String())) 585 panic("unreachable") 586 } 587 } 588 589 func isMarshaler(rv reflect.Value) bool { 590 return rv.Type().Implements(marshalText) || rv.Type().Implements(marshalToml) 591 } 592 593 // isTableArray reports if all entries in the array or slice are a table. 594 func isTableArray(arr reflect.Value) bool { 595 if isNil(arr) || !arr.IsValid() || arr.Len() == 0 { 596 return false 597 } 598 599 ret := true 600 for i := 0; i < arr.Len(); i++ { 601 tt := tomlTypeOfGo(eindirect(arr.Index(i))) 602 // Don't allow nil. 603 if tt == nil { 604 encPanic(errArrayNilElement) 605 } 606 607 if ret && !typeEqual(tomlHash, tt) { 608 ret = false 609 } 610 } 611 return ret 612 } 613 614 type tagOptions struct { 615 skip bool // "-" 616 name string 617 omitempty bool 618 omitzero bool 619 } 620 621 func getOptions(tag reflect.StructTag) tagOptions { 622 t := tag.Get("toml") 623 if t == "-" { 624 return tagOptions{skip: true} 625 } 626 var opts tagOptions 627 parts := strings.Split(t, ",") 628 opts.name = parts[0] 629 for _, s := range parts[1:] { 630 switch s { 631 case "omitempty": 632 opts.omitempty = true 633 case "omitzero": 634 opts.omitzero = true 635 } 636 } 637 return opts 638 } 639 640 func isZero(rv reflect.Value) bool { 641 switch rv.Kind() { 642 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 643 return rv.Int() == 0 644 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 645 return rv.Uint() == 0 646 case reflect.Float32, reflect.Float64: 647 return rv.Float() == 0.0 648 } 649 return false 650 } 651 652 func (enc *Encoder) isEmpty(rv reflect.Value) bool { 653 switch rv.Kind() { 654 case reflect.Array, reflect.Slice, reflect.Map, reflect.String: 655 return rv.Len() == 0 656 case reflect.Struct: 657 if rv.Type().Comparable() { 658 return reflect.Zero(rv.Type()).Interface() == rv.Interface() 659 } 660 case reflect.Bool: 661 return !rv.Bool() 662 } 663 return false 664 } 665 666 func (enc *Encoder) newline() { 667 if enc.hasWritten { 668 enc.wf("\n") 669 } 670 } 671 672 // Write a key/value pair: 673 // 674 // key = <any value> 675 // 676 // This is also used for "k = v" in inline tables; so something like this will 677 // be written in three calls: 678 // 679 // ┌────────────────────┐ 680 // │ ┌───┐ ┌─────┐│ 681 // v v v v vv 682 // key = {k = v, k2 = v2} 683 func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) { 684 if len(key) == 0 { 685 encPanic(errNoKey) 686 } 687 enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1)) 688 enc.eElement(val) 689 if !inline { 690 enc.newline() 691 } 692 } 693 694 func (enc *Encoder) wf(format string, v ...interface{}) { 695 _, err := fmt.Fprintf(enc.w, format, v...) 696 if err != nil { 697 encPanic(err) 698 } 699 enc.hasWritten = true 700 } 701 702 func (enc *Encoder) indentStr(key Key) string { 703 return strings.Repeat(enc.Indent, len(key)-1) 704 } 705 706 func encPanic(err error) { 707 panic(tomlEncodeError{err}) 708 } 709 710 // Resolve any level of pointers to the actual value (e.g. **string → string). 711 func eindirect(v reflect.Value) reflect.Value { 712 if v.Kind() != reflect.Ptr && v.Kind() != reflect.Interface { 713 if isMarshaler(v) { 714 return v 715 } 716 if v.CanAddr() { /// Special case for marshalers; see #358. 717 if pv := v.Addr(); isMarshaler(pv) { 718 return pv 719 } 720 } 721 return v 722 } 723 724 if v.IsNil() { 725 return v 726 } 727 728 return eindirect(v.Elem()) 729 } 730 731 func isNil(rv reflect.Value) bool { 732 switch rv.Kind() { 733 case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: 734 return rv.IsNil() 735 default: 736 return false 737 } 738 }