github.com/night-codes/go-json@v0.9.15/internal/encoder/compiler.go (about) 1 package encoder 2 3 import ( 4 "context" 5 "encoding" 6 "encoding/json" 7 "reflect" 8 "sync/atomic" 9 "unsafe" 10 11 "github.com/night-codes/go-json/internal/errors" 12 "github.com/night-codes/go-json/internal/runtime" 13 ) 14 15 type marshalerContext interface { 16 MarshalJSON(context.Context) ([]byte, error) 17 } 18 19 var ( 20 marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() 21 marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem() 22 marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() 23 jsonNumberType = reflect.TypeOf(json.Number("")) 24 cachedOpcodeSets []*OpcodeSet 25 cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet 26 typeAddr *runtime.TypeAddr 27 ) 28 29 func init() { 30 typeAddr = runtime.AnalyzeTypeAddr() 31 if typeAddr == nil { 32 typeAddr = &runtime.TypeAddr{} 33 } 34 cachedOpcodeSets = make([]*OpcodeSet, typeAddr.AddrRange>>typeAddr.AddrShift+1) 35 } 36 37 func loadOpcodeMap() map[uintptr]*OpcodeSet { 38 p := atomic.LoadPointer(&cachedOpcodeMap) 39 return *(*map[uintptr]*OpcodeSet)(unsafe.Pointer(&p)) 40 } 41 42 func storeOpcodeSet(typ uintptr, set *OpcodeSet, m map[uintptr]*OpcodeSet) { 43 newOpcodeMap := make(map[uintptr]*OpcodeSet, len(m)+1) 44 newOpcodeMap[typ] = set 45 46 for k, v := range m { 47 newOpcodeMap[k] = v 48 } 49 50 atomic.StorePointer(&cachedOpcodeMap, *(*unsafe.Pointer)(unsafe.Pointer(&newOpcodeMap))) 51 } 52 53 func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) { 54 opcodeMap := loadOpcodeMap() 55 if codeSet, exists := opcodeMap[typeptr]; exists { 56 return codeSet, nil 57 } 58 codeSet, err := newCompiler().compile(typeptr) 59 if err != nil { 60 return nil, err 61 } 62 storeOpcodeSet(typeptr, codeSet, opcodeMap) 63 return codeSet, nil 64 } 65 66 func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*OpcodeSet, error) { 67 if (ctx.Option.Flag & ContextOption) == 0 { 68 return codeSet, nil 69 } 70 query := FieldQueryFromContext(ctx.Option.Context) 71 if query == nil { 72 return codeSet, nil 73 } 74 ctx.Option.Flag |= FieldQueryOption 75 cacheCodeSet := codeSet.getQueryCache(query.Hash()) 76 if cacheCodeSet != nil { 77 return cacheCodeSet, nil 78 } 79 queryCodeSet, err := newCompiler().codeToOpcodeSet(codeSet.Type, codeSet.Code.Filter(query)) 80 if err != nil { 81 return nil, err 82 } 83 codeSet.setQueryCache(query.Hash(), queryCodeSet) 84 return queryCodeSet, nil 85 } 86 87 type Compiler struct { 88 structTypeToCode map[uintptr]*StructCode 89 } 90 91 func newCompiler() *Compiler { 92 return &Compiler{ 93 structTypeToCode: map[uintptr]*StructCode{}, 94 } 95 } 96 97 func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) { 98 // noescape trick for header.typ ( reflect.*rtype ) 99 typ := *(**runtime.Type)(unsafe.Pointer(&typeptr)) 100 code, err := c.typeToCode(typ) 101 if err != nil { 102 return nil, err 103 } 104 return c.codeToOpcodeSet(typ, code) 105 } 106 107 func (c *Compiler) codeToOpcodeSet(typ *runtime.Type, code Code) (*OpcodeSet, error) { 108 noescapeKeyCode := c.codeToOpcode(&compileContext{ 109 structTypeToCodes: map[uintptr]Opcodes{}, 110 recursiveCodes: &Opcodes{}, 111 }, typ, code) 112 if err := noescapeKeyCode.Validate(); err != nil { 113 return nil, err 114 } 115 escapeKeyCode := c.codeToOpcode(&compileContext{ 116 structTypeToCodes: map[uintptr]Opcodes{}, 117 recursiveCodes: &Opcodes{}, 118 escapeKey: true, 119 }, typ, code) 120 noescapeKeyCode = copyOpcode(noescapeKeyCode) 121 escapeKeyCode = copyOpcode(escapeKeyCode) 122 setTotalLengthToInterfaceOp(noescapeKeyCode) 123 setTotalLengthToInterfaceOp(escapeKeyCode) 124 interfaceNoescapeKeyCode := copyToInterfaceOpcode(noescapeKeyCode) 125 interfaceEscapeKeyCode := copyToInterfaceOpcode(escapeKeyCode) 126 codeLength := noescapeKeyCode.TotalLength() 127 return &OpcodeSet{ 128 Type: typ, 129 NoescapeKeyCode: noescapeKeyCode, 130 EscapeKeyCode: escapeKeyCode, 131 InterfaceNoescapeKeyCode: interfaceNoescapeKeyCode, 132 InterfaceEscapeKeyCode: interfaceEscapeKeyCode, 133 CodeLength: codeLength, 134 EndCode: ToEndCode(interfaceNoescapeKeyCode), 135 Code: code, 136 QueryCache: map[string]*OpcodeSet{}, 137 }, nil 138 } 139 140 func (c *Compiler) typeToCode(typ *runtime.Type) (Code, error) { 141 switch { 142 case c.implementsMarshalJSON(typ): 143 return c.marshalJSONCode(typ) 144 case c.implementsMarshalText(typ): 145 return c.marshalTextCode(typ) 146 } 147 148 isPtr := false 149 orgType := typ 150 if typ.Kind() == reflect.Ptr { 151 typ = typ.Elem() 152 isPtr = true 153 } 154 switch { 155 case c.implementsMarshalJSON(typ): 156 return c.marshalJSONCode(orgType) 157 case c.implementsMarshalText(typ): 158 return c.marshalTextCode(orgType) 159 } 160 switch typ.Kind() { 161 case reflect.Slice: 162 elem := typ.Elem() 163 if elem.Kind() == reflect.Uint8 { 164 p := runtime.PtrTo(elem) 165 if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { 166 return c.bytesCode(typ, isPtr) 167 } 168 } 169 return c.sliceCode(typ) 170 case reflect.Map: 171 if isPtr { 172 return c.ptrCode(runtime.PtrTo(typ)) 173 } 174 return c.mapCode(typ) 175 case reflect.Struct: 176 return c.structCode(typ, isPtr) 177 case reflect.Int: 178 return c.intCode(typ, isPtr) 179 case reflect.Int8: 180 return c.int8Code(typ, isPtr) 181 case reflect.Int16: 182 return c.int16Code(typ, isPtr) 183 case reflect.Int32: 184 return c.int32Code(typ, isPtr) 185 case reflect.Int64: 186 return c.int64Code(typ, isPtr) 187 case reflect.Uint, reflect.Uintptr: 188 return c.uintCode(typ, isPtr) 189 case reflect.Uint8: 190 return c.uint8Code(typ, isPtr) 191 case reflect.Uint16: 192 return c.uint16Code(typ, isPtr) 193 case reflect.Uint32: 194 return c.uint32Code(typ, isPtr) 195 case reflect.Uint64: 196 return c.uint64Code(typ, isPtr) 197 case reflect.Float32: 198 return c.float32Code(typ, isPtr) 199 case reflect.Float64: 200 return c.float64Code(typ, isPtr) 201 case reflect.String: 202 return c.stringCode(typ, isPtr) 203 case reflect.Bool: 204 return c.boolCode(typ, isPtr) 205 case reflect.Interface: 206 return c.interfaceCode(typ, isPtr) 207 default: 208 if isPtr && typ.Implements(marshalTextType) { 209 typ = orgType 210 } 211 return c.typeToCodeWithPtr(typ, isPtr) 212 } 213 } 214 215 func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error) { 216 switch { 217 case c.implementsMarshalJSON(typ): 218 return c.marshalJSONCode(typ) 219 case c.implementsMarshalText(typ): 220 return c.marshalTextCode(typ) 221 } 222 switch typ.Kind() { 223 case reflect.Ptr: 224 return c.ptrCode(typ) 225 case reflect.Slice: 226 elem := typ.Elem() 227 if elem.Kind() == reflect.Uint8 { 228 p := runtime.PtrTo(elem) 229 if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { 230 return c.bytesCode(typ, false) 231 } 232 } 233 return c.sliceCode(typ) 234 case reflect.Array: 235 return c.arrayCode(typ) 236 case reflect.Map: 237 return c.mapCode(typ) 238 case reflect.Struct: 239 return c.structCode(typ, isPtr) 240 case reflect.Interface: 241 return c.interfaceCode(typ, false) 242 case reflect.Int: 243 return c.intCode(typ, false) 244 case reflect.Int8: 245 return c.int8Code(typ, false) 246 case reflect.Int16: 247 return c.int16Code(typ, false) 248 case reflect.Int32: 249 return c.int32Code(typ, false) 250 case reflect.Int64: 251 return c.int64Code(typ, false) 252 case reflect.Uint: 253 return c.uintCode(typ, false) 254 case reflect.Uint8: 255 return c.uint8Code(typ, false) 256 case reflect.Uint16: 257 return c.uint16Code(typ, false) 258 case reflect.Uint32: 259 return c.uint32Code(typ, false) 260 case reflect.Uint64: 261 return c.uint64Code(typ, false) 262 case reflect.Uintptr: 263 return c.uintCode(typ, false) 264 case reflect.Float32: 265 return c.float32Code(typ, false) 266 case reflect.Float64: 267 return c.float64Code(typ, false) 268 case reflect.String: 269 return c.stringCode(typ, false) 270 case reflect.Bool: 271 return c.boolCode(typ, false) 272 } 273 return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)} 274 } 275 276 const intSize = 32 << (^uint(0) >> 63) 277 278 //nolint:unparam 279 func (c *Compiler) intCode(typ *runtime.Type, isPtr bool) (*IntCode, error) { 280 return &IntCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil 281 } 282 283 //nolint:unparam 284 func (c *Compiler) int8Code(typ *runtime.Type, isPtr bool) (*IntCode, error) { 285 return &IntCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil 286 } 287 288 //nolint:unparam 289 func (c *Compiler) int16Code(typ *runtime.Type, isPtr bool) (*IntCode, error) { 290 return &IntCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil 291 } 292 293 //nolint:unparam 294 func (c *Compiler) int32Code(typ *runtime.Type, isPtr bool) (*IntCode, error) { 295 return &IntCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil 296 } 297 298 //nolint:unparam 299 func (c *Compiler) int64Code(typ *runtime.Type, isPtr bool) (*IntCode, error) { 300 return &IntCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil 301 } 302 303 //nolint:unparam 304 func (c *Compiler) uintCode(typ *runtime.Type, isPtr bool) (*UintCode, error) { 305 return &UintCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil 306 } 307 308 //nolint:unparam 309 func (c *Compiler) uint8Code(typ *runtime.Type, isPtr bool) (*UintCode, error) { 310 return &UintCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil 311 } 312 313 //nolint:unparam 314 func (c *Compiler) uint16Code(typ *runtime.Type, isPtr bool) (*UintCode, error) { 315 return &UintCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil 316 } 317 318 //nolint:unparam 319 func (c *Compiler) uint32Code(typ *runtime.Type, isPtr bool) (*UintCode, error) { 320 return &UintCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil 321 } 322 323 //nolint:unparam 324 func (c *Compiler) uint64Code(typ *runtime.Type, isPtr bool) (*UintCode, error) { 325 return &UintCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil 326 } 327 328 //nolint:unparam 329 func (c *Compiler) float32Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) { 330 return &FloatCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil 331 } 332 333 //nolint:unparam 334 func (c *Compiler) float64Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) { 335 return &FloatCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil 336 } 337 338 //nolint:unparam 339 func (c *Compiler) stringCode(typ *runtime.Type, isPtr bool) (*StringCode, error) { 340 return &StringCode{typ: typ, isPtr: isPtr}, nil 341 } 342 343 //nolint:unparam 344 func (c *Compiler) boolCode(typ *runtime.Type, isPtr bool) (*BoolCode, error) { 345 return &BoolCode{typ: typ, isPtr: isPtr}, nil 346 } 347 348 //nolint:unparam 349 func (c *Compiler) intStringCode(typ *runtime.Type) (*IntCode, error) { 350 return &IntCode{typ: typ, bitSize: intSize, isString: true}, nil 351 } 352 353 //nolint:unparam 354 func (c *Compiler) int8StringCode(typ *runtime.Type) (*IntCode, error) { 355 return &IntCode{typ: typ, bitSize: 8, isString: true}, nil 356 } 357 358 //nolint:unparam 359 func (c *Compiler) int16StringCode(typ *runtime.Type) (*IntCode, error) { 360 return &IntCode{typ: typ, bitSize: 16, isString: true}, nil 361 } 362 363 //nolint:unparam 364 func (c *Compiler) int32StringCode(typ *runtime.Type) (*IntCode, error) { 365 return &IntCode{typ: typ, bitSize: 32, isString: true}, nil 366 } 367 368 //nolint:unparam 369 func (c *Compiler) int64StringCode(typ *runtime.Type) (*IntCode, error) { 370 return &IntCode{typ: typ, bitSize: 64, isString: true}, nil 371 } 372 373 //nolint:unparam 374 func (c *Compiler) uintStringCode(typ *runtime.Type) (*UintCode, error) { 375 return &UintCode{typ: typ, bitSize: intSize, isString: true}, nil 376 } 377 378 //nolint:unparam 379 func (c *Compiler) uint8StringCode(typ *runtime.Type) (*UintCode, error) { 380 return &UintCode{typ: typ, bitSize: 8, isString: true}, nil 381 } 382 383 //nolint:unparam 384 func (c *Compiler) uint16StringCode(typ *runtime.Type) (*UintCode, error) { 385 return &UintCode{typ: typ, bitSize: 16, isString: true}, nil 386 } 387 388 //nolint:unparam 389 func (c *Compiler) uint32StringCode(typ *runtime.Type) (*UintCode, error) { 390 return &UintCode{typ: typ, bitSize: 32, isString: true}, nil 391 } 392 393 //nolint:unparam 394 func (c *Compiler) uint64StringCode(typ *runtime.Type) (*UintCode, error) { 395 return &UintCode{typ: typ, bitSize: 64, isString: true}, nil 396 } 397 398 //nolint:unparam 399 func (c *Compiler) bytesCode(typ *runtime.Type, isPtr bool) (*BytesCode, error) { 400 return &BytesCode{typ: typ, isPtr: isPtr}, nil 401 } 402 403 //nolint:unparam 404 func (c *Compiler) interfaceCode(typ *runtime.Type, isPtr bool) (*InterfaceCode, error) { 405 return &InterfaceCode{typ: typ, isPtr: isPtr}, nil 406 } 407 408 //nolint:unparam 409 func (c *Compiler) marshalJSONCode(typ *runtime.Type) (*MarshalJSONCode, error) { 410 return &MarshalJSONCode{ 411 typ: typ, 412 isAddrForMarshaler: c.isPtrMarshalJSONType(typ), 413 isNilableType: c.isNilableType(typ), 414 isMarshalerContext: typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType), 415 }, nil 416 } 417 418 //nolint:unparam 419 func (c *Compiler) marshalTextCode(typ *runtime.Type) (*MarshalTextCode, error) { 420 return &MarshalTextCode{ 421 typ: typ, 422 isAddrForMarshaler: c.isPtrMarshalTextType(typ), 423 isNilableType: c.isNilableType(typ), 424 }, nil 425 } 426 427 func (c *Compiler) ptrCode(typ *runtime.Type) (*PtrCode, error) { 428 code, err := c.typeToCodeWithPtr(typ.Elem(), true) 429 if err != nil { 430 return nil, err 431 } 432 ptr, ok := code.(*PtrCode) 433 if ok { 434 return &PtrCode{typ: typ, value: ptr.value, ptrNum: ptr.ptrNum + 1}, nil 435 } 436 return &PtrCode{typ: typ, value: code, ptrNum: 1}, nil 437 } 438 439 func (c *Compiler) sliceCode(typ *runtime.Type) (*SliceCode, error) { 440 elem := typ.Elem() 441 code, err := c.listElemCode(elem) 442 if err != nil { 443 return nil, err 444 } 445 if code.Kind() == CodeKindStruct { 446 structCode := code.(*StructCode) 447 structCode.enableIndirect() 448 } 449 return &SliceCode{typ: typ, value: code}, nil 450 } 451 452 func (c *Compiler) arrayCode(typ *runtime.Type) (*ArrayCode, error) { 453 elem := typ.Elem() 454 code, err := c.listElemCode(elem) 455 if err != nil { 456 return nil, err 457 } 458 if code.Kind() == CodeKindStruct { 459 structCode := code.(*StructCode) 460 structCode.enableIndirect() 461 } 462 return &ArrayCode{typ: typ, value: code}, nil 463 } 464 465 func (c *Compiler) mapCode(typ *runtime.Type) (*MapCode, error) { 466 keyCode, err := c.mapKeyCode(typ.Key()) 467 if err != nil { 468 return nil, err 469 } 470 valueCode, err := c.mapValueCode(typ.Elem()) 471 if err != nil { 472 return nil, err 473 } 474 if valueCode.Kind() == CodeKindStruct { 475 structCode := valueCode.(*StructCode) 476 structCode.enableIndirect() 477 } 478 return &MapCode{typ: typ, key: keyCode, value: valueCode}, nil 479 } 480 481 func (c *Compiler) listElemCode(typ *runtime.Type) (Code, error) { 482 switch { 483 case c.isPtrMarshalJSONType(typ): 484 return c.marshalJSONCode(typ) 485 case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType): 486 return c.marshalTextCode(typ) 487 case typ.Kind() == reflect.Map: 488 return c.ptrCode(runtime.PtrTo(typ)) 489 default: 490 // isPtr was originally used to indicate whether the type of top level is pointer. 491 // However, since the slice/array element is a specification that can get the pointer address, explicitly set isPtr to true. 492 // See here for related issues: https://github.com/night-codes/go-json/issues/370 493 code, err := c.typeToCodeWithPtr(typ, true) 494 if err != nil { 495 return nil, err 496 } 497 ptr, ok := code.(*PtrCode) 498 if ok { 499 if ptr.value.Kind() == CodeKindMap { 500 ptr.ptrNum++ 501 } 502 } 503 return code, nil 504 } 505 } 506 507 func (c *Compiler) mapKeyCode(typ *runtime.Type) (Code, error) { 508 switch { 509 case c.implementsMarshalJSON(typ): 510 return c.marshalJSONCode(typ) 511 case c.implementsMarshalText(typ): 512 return c.marshalTextCode(typ) 513 } 514 switch typ.Kind() { 515 case reflect.Ptr: 516 return c.ptrCode(typ) 517 case reflect.String: 518 return c.stringCode(typ, false) 519 case reflect.Int: 520 return c.intStringCode(typ) 521 case reflect.Int8: 522 return c.int8StringCode(typ) 523 case reflect.Int16: 524 return c.int16StringCode(typ) 525 case reflect.Int32: 526 return c.int32StringCode(typ) 527 case reflect.Int64: 528 return c.int64StringCode(typ) 529 case reflect.Uint: 530 return c.uintStringCode(typ) 531 case reflect.Uint8: 532 return c.uint8StringCode(typ) 533 case reflect.Uint16: 534 return c.uint16StringCode(typ) 535 case reflect.Uint32: 536 return c.uint32StringCode(typ) 537 case reflect.Uint64: 538 return c.uint64StringCode(typ) 539 case reflect.Uintptr: 540 return c.uintStringCode(typ) 541 } 542 return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)} 543 } 544 545 func (c *Compiler) mapValueCode(typ *runtime.Type) (Code, error) { 546 switch typ.Kind() { 547 case reflect.Map: 548 return c.ptrCode(runtime.PtrTo(typ)) 549 default: 550 code, err := c.typeToCodeWithPtr(typ, false) 551 if err != nil { 552 return nil, err 553 } 554 ptr, ok := code.(*PtrCode) 555 if ok { 556 if ptr.value.Kind() == CodeKindMap { 557 ptr.ptrNum++ 558 } 559 } 560 return code, nil 561 } 562 } 563 564 func (c *Compiler) structCode(typ *runtime.Type, isPtr bool) (*StructCode, error) { 565 typeptr := uintptr(unsafe.Pointer(typ)) 566 if code, exists := c.structTypeToCode[typeptr]; exists { 567 derefCode := *code 568 derefCode.isRecursive = true 569 return &derefCode, nil 570 } 571 indirect := runtime.IfaceIndir(typ) 572 code := &StructCode{typ: typ, isPtr: isPtr, isIndirect: indirect} 573 c.structTypeToCode[typeptr] = code 574 575 fieldNum := typ.NumField() 576 tags := c.typeToStructTags(typ) 577 fields := []*StructFieldCode{} 578 for i, tag := range tags { 579 isOnlyOneFirstField := i == 0 && fieldNum == 1 580 field, err := c.structFieldCode(code, tag, isPtr, isOnlyOneFirstField) 581 if err != nil { 582 return nil, err 583 } 584 if field.isAnonymous { 585 structCode := field.getAnonymousStruct() 586 if structCode != nil { 587 structCode.removeFieldsByTags(tags) 588 if c.isAssignableIndirect(field, isPtr) { 589 if indirect { 590 structCode.isIndirect = true 591 } else { 592 structCode.isIndirect = false 593 } 594 } 595 } 596 } else { 597 structCode := field.getStruct() 598 if structCode != nil { 599 if indirect { 600 // if parent is indirect type, set child indirect property to true 601 structCode.isIndirect = true 602 } else { 603 // if parent is not indirect type, set child indirect property to false. 604 // but if parent's indirect is false and isPtr is true, then indirect must be true. 605 // Do this only if indirectConversion is enabled at the end of compileStruct. 606 structCode.isIndirect = false 607 } 608 } 609 } 610 fields = append(fields, field) 611 } 612 fieldMap := c.getFieldMap(fields) 613 duplicatedFieldMap := c.getDuplicatedFieldMap(fieldMap) 614 code.fields = c.filteredDuplicatedFields(fields, duplicatedFieldMap) 615 if !code.disableIndirectConversion && !indirect && isPtr { 616 code.enableIndirect() 617 } 618 delete(c.structTypeToCode, typeptr) 619 return code, nil 620 } 621 622 func (c *Compiler) structFieldCode(structCode *StructCode, tag *runtime.StructTag, isPtr, isOnlyOneFirstField bool) (*StructFieldCode, error) { 623 field := tag.Field 624 fieldType := runtime.Type2RType(field.Type) 625 isIndirectSpecialCase := isPtr && isOnlyOneFirstField 626 fieldCode := &StructFieldCode{ 627 typ: fieldType, 628 key: tag.Key, 629 tag: tag, 630 offset: field.Offset, 631 isAnonymous: field.Anonymous && !tag.IsTaggedKey, 632 isTaggedKey: tag.IsTaggedKey, 633 isNilableType: c.isNilableType(fieldType), 634 isNilCheck: true, 635 } 636 switch { 637 case c.isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(fieldType, isIndirectSpecialCase): 638 code, err := c.marshalJSONCode(fieldType) 639 if err != nil { 640 return nil, err 641 } 642 fieldCode.value = code 643 fieldCode.isAddrForMarshaler = true 644 fieldCode.isNilCheck = false 645 structCode.isIndirect = false 646 structCode.disableIndirectConversion = true 647 case c.isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(fieldType, isIndirectSpecialCase): 648 code, err := c.marshalTextCode(fieldType) 649 if err != nil { 650 return nil, err 651 } 652 fieldCode.value = code 653 fieldCode.isAddrForMarshaler = true 654 fieldCode.isNilCheck = false 655 structCode.isIndirect = false 656 structCode.disableIndirectConversion = true 657 case isPtr && c.isPtrMarshalJSONType(fieldType): 658 // *struct{ field T } 659 // func (*T) MarshalJSON() ([]byte, error) 660 code, err := c.marshalJSONCode(fieldType) 661 if err != nil { 662 return nil, err 663 } 664 fieldCode.value = code 665 fieldCode.isAddrForMarshaler = true 666 fieldCode.isNilCheck = false 667 case isPtr && c.isPtrMarshalTextType(fieldType): 668 // *struct{ field T } 669 // func (*T) MarshalText() ([]byte, error) 670 code, err := c.marshalTextCode(fieldType) 671 if err != nil { 672 return nil, err 673 } 674 fieldCode.value = code 675 fieldCode.isAddrForMarshaler = true 676 fieldCode.isNilCheck = false 677 default: 678 code, err := c.typeToCodeWithPtr(fieldType, isPtr) 679 if err != nil { 680 return nil, err 681 } 682 switch code.Kind() { 683 case CodeKindPtr, CodeKindInterface: 684 fieldCode.isNextOpPtrType = true 685 } 686 fieldCode.value = code 687 } 688 return fieldCode, nil 689 } 690 691 func (c *Compiler) isAssignableIndirect(fieldCode *StructFieldCode, isPtr bool) bool { 692 if isPtr { 693 return false 694 } 695 codeType := fieldCode.value.Kind() 696 if codeType == CodeKindMarshalJSON { 697 return false 698 } 699 if codeType == CodeKindMarshalText { 700 return false 701 } 702 return true 703 } 704 705 func (c *Compiler) getFieldMap(fields []*StructFieldCode) map[string][]*StructFieldCode { 706 fieldMap := map[string][]*StructFieldCode{} 707 for _, field := range fields { 708 if field.isAnonymous { 709 for k, v := range c.getAnonymousFieldMap(field) { 710 fieldMap[k] = append(fieldMap[k], v...) 711 } 712 continue 713 } 714 fieldMap[field.key] = append(fieldMap[field.key], field) 715 } 716 return fieldMap 717 } 718 719 func (c *Compiler) getAnonymousFieldMap(field *StructFieldCode) map[string][]*StructFieldCode { 720 fieldMap := map[string][]*StructFieldCode{} 721 structCode := field.getAnonymousStruct() 722 if structCode == nil || structCode.isRecursive { 723 fieldMap[field.key] = append(fieldMap[field.key], field) 724 return fieldMap 725 } 726 for k, v := range c.getFieldMapFromAnonymousParent(structCode.fields) { 727 fieldMap[k] = append(fieldMap[k], v...) 728 } 729 return fieldMap 730 } 731 732 func (c *Compiler) getFieldMapFromAnonymousParent(fields []*StructFieldCode) map[string][]*StructFieldCode { 733 fieldMap := map[string][]*StructFieldCode{} 734 for _, field := range fields { 735 if field.isAnonymous { 736 for k, v := range c.getAnonymousFieldMap(field) { 737 // Do not handle tagged key when embedding more than once 738 for _, vv := range v { 739 vv.isTaggedKey = false 740 } 741 fieldMap[k] = append(fieldMap[k], v...) 742 } 743 continue 744 } 745 fieldMap[field.key] = append(fieldMap[field.key], field) 746 } 747 return fieldMap 748 } 749 750 func (c *Compiler) getDuplicatedFieldMap(fieldMap map[string][]*StructFieldCode) map[*StructFieldCode]struct{} { 751 duplicatedFieldMap := map[*StructFieldCode]struct{}{} 752 for _, fields := range fieldMap { 753 if len(fields) == 1 { 754 continue 755 } 756 if c.isTaggedKeyOnly(fields) { 757 for _, field := range fields { 758 if field.isTaggedKey { 759 continue 760 } 761 duplicatedFieldMap[field] = struct{}{} 762 } 763 } else { 764 for _, field := range fields { 765 duplicatedFieldMap[field] = struct{}{} 766 } 767 } 768 } 769 return duplicatedFieldMap 770 } 771 772 func (c *Compiler) filteredDuplicatedFields(fields []*StructFieldCode, duplicatedFieldMap map[*StructFieldCode]struct{}) []*StructFieldCode { 773 filteredFields := make([]*StructFieldCode, 0, len(fields)) 774 for _, field := range fields { 775 if field.isAnonymous { 776 structCode := field.getAnonymousStruct() 777 if structCode != nil && !structCode.isRecursive { 778 structCode.fields = c.filteredDuplicatedFields(structCode.fields, duplicatedFieldMap) 779 if len(structCode.fields) > 0 { 780 filteredFields = append(filteredFields, field) 781 } 782 continue 783 } 784 } 785 if _, exists := duplicatedFieldMap[field]; exists { 786 continue 787 } 788 filteredFields = append(filteredFields, field) 789 } 790 return filteredFields 791 } 792 793 func (c *Compiler) isTaggedKeyOnly(fields []*StructFieldCode) bool { 794 var taggedKeyFieldCount int 795 for _, field := range fields { 796 if field.isTaggedKey { 797 taggedKeyFieldCount++ 798 } 799 } 800 return taggedKeyFieldCount == 1 801 } 802 803 func (c *Compiler) typeToStructTags(typ *runtime.Type) runtime.StructTags { 804 tags := runtime.StructTags{} 805 fieldNum := typ.NumField() 806 for i := 0; i < fieldNum; i++ { 807 field := typ.Field(i) 808 if runtime.IsIgnoredStructField(field) { 809 continue 810 } 811 tags = append(tags, runtime.StructTagFromField(field)) 812 } 813 return tags 814 } 815 816 // *struct{ field T } => struct { field *T } 817 // func (*T) MarshalJSON() ([]byte, error) 818 func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool { 819 return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalJSONType(typ) 820 } 821 822 // *struct{ field T } => struct { field *T } 823 // func (*T) MarshalText() ([]byte, error) 824 func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool { 825 return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalTextType(typ) 826 } 827 828 func (c *Compiler) implementsMarshalJSON(typ *runtime.Type) bool { 829 if !c.implementsMarshalJSONType(typ) { 830 return false 831 } 832 if typ.Kind() != reflect.Ptr { 833 return true 834 } 835 // type kind is reflect.Ptr 836 if !c.implementsMarshalJSONType(typ.Elem()) { 837 return true 838 } 839 // needs to dereference 840 return false 841 } 842 843 func (c *Compiler) implementsMarshalText(typ *runtime.Type) bool { 844 if !typ.Implements(marshalTextType) { 845 return false 846 } 847 if typ.Kind() != reflect.Ptr { 848 return true 849 } 850 // type kind is reflect.Ptr 851 if !typ.Elem().Implements(marshalTextType) { 852 return true 853 } 854 // needs to dereference 855 return false 856 } 857 858 func (c *Compiler) isNilableType(typ *runtime.Type) bool { 859 if !runtime.IfaceIndir(typ) { 860 return true 861 } 862 switch typ.Kind() { 863 case reflect.Ptr: 864 return true 865 case reflect.Map: 866 return true 867 case reflect.Func: 868 return true 869 default: 870 return false 871 } 872 } 873 874 func (c *Compiler) implementsMarshalJSONType(typ *runtime.Type) bool { 875 return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType) 876 } 877 878 func (c *Compiler) isPtrMarshalJSONType(typ *runtime.Type) bool { 879 return !c.implementsMarshalJSONType(typ) && c.implementsMarshalJSONType(runtime.PtrTo(typ)) 880 } 881 882 func (c *Compiler) isPtrMarshalTextType(typ *runtime.Type) bool { 883 return !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType) 884 } 885 886 func (c *Compiler) codeToOpcode(ctx *compileContext, typ *runtime.Type, code Code) *Opcode { 887 codes := code.ToOpcode(ctx) 888 codes.Last().Next = newEndOp(ctx, typ) 889 c.linkRecursiveCode(ctx) 890 return codes.First() 891 } 892 893 func (c *Compiler) linkRecursiveCode(ctx *compileContext) { 894 recursiveCodes := map[uintptr]*CompiledCode{} 895 for _, recursive := range *ctx.recursiveCodes { 896 typeptr := uintptr(unsafe.Pointer(recursive.Type)) 897 codes := ctx.structTypeToCodes[typeptr] 898 if recursiveCode, ok := recursiveCodes[typeptr]; ok { 899 *recursive.Jmp = *recursiveCode 900 continue 901 } 902 903 code := copyOpcode(codes.First()) 904 code.Op = code.Op.PtrHeadToHead() 905 lastCode := newEndOp(&compileContext{}, recursive.Type) 906 lastCode.Op = OpRecursiveEnd 907 908 // OpRecursiveEnd must set before call TotalLength 909 code.End.Next = lastCode 910 911 totalLength := code.TotalLength() 912 913 // Idx, ElemIdx, Length must set after call TotalLength 914 lastCode.Idx = uint32((totalLength + 1) * uintptrSize) 915 lastCode.ElemIdx = lastCode.Idx + uintptrSize 916 lastCode.Length = lastCode.Idx + 2*uintptrSize 917 918 // extend length to alloc slot for elemIdx + length 919 curTotalLength := uintptr(recursive.TotalLength()) + 3 920 nextTotalLength := uintptr(totalLength) + 3 921 922 compiled := recursive.Jmp 923 compiled.Code = code 924 compiled.CurLen = curTotalLength 925 compiled.NextLen = nextTotalLength 926 compiled.Linked = true 927 928 recursiveCodes[typeptr] = compiled 929 } 930 }