github.com/3JoB/go-json@v0.10.4/internal/encoder/compiler.go (about) 1 package encoder 2 3 import ( 4 "context" 5 "encoding" 6 "encoding/json" 7 "sync/atomic" 8 "unsafe" 9 10 "github.com/3JoB/go-reflect" 11 12 "github.com/3JoB/go-json/internal/errors" 13 "github.com/3JoB/go-json/internal/runtime" 14 ) 15 16 type marshalerContext interface { 17 MarshalJSON(context.Context) ([]byte, error) 18 } 19 20 var ( 21 marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() 22 marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem() 23 marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() 24 jsonNumberType = reflect.TypeOf(json.Number("")) 25 cachedOpcodeSets []*OpcodeSet 26 cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet 27 typeAddr *runtime.TypeAddr 28 ) 29 30 func init() { 31 typeAddr = runtime.AnalyzeTypeAddr() 32 if typeAddr == nil { 33 typeAddr = &runtime.TypeAddr{} 34 } 35 cachedOpcodeSets = make([]*OpcodeSet, typeAddr.AddrRange>>typeAddr.AddrShift+1) 36 } 37 38 func loadOpcodeMap() map[uintptr]*OpcodeSet { 39 p := atomic.LoadPointer(&cachedOpcodeMap) 40 return *(*map[uintptr]*OpcodeSet)(unsafe.Pointer(&p)) 41 } 42 43 func storeOpcodeSet(typ uintptr, set *OpcodeSet, m map[uintptr]*OpcodeSet) { 44 newOpcodeMap := make(map[uintptr]*OpcodeSet, len(m)+1) 45 newOpcodeMap[typ] = set 46 47 for k, v := range m { 48 newOpcodeMap[k] = v 49 } 50 51 atomic.StorePointer(&cachedOpcodeMap, *(*unsafe.Pointer)(unsafe.Pointer(&newOpcodeMap))) 52 } 53 54 func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) { 55 opcodeMap := loadOpcodeMap() 56 if codeSet, exists := opcodeMap[typeptr]; exists { 57 return codeSet, nil 58 } 59 codeSet, err := newCompiler().compile(typeptr) 60 if err != nil { 61 return nil, err 62 } 63 storeOpcodeSet(typeptr, codeSet, opcodeMap) 64 return codeSet, nil 65 } 66 67 func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*OpcodeSet, error) { 68 if (ctx.Option.Flag & ContextOption) == 0 { 69 return codeSet, nil 70 } 71 query := FieldQueryFromContext(ctx.Option.Context) 72 if query == nil { 73 return codeSet, nil 74 } 75 ctx.Option.Flag |= FieldQueryOption 76 cacheCodeSet := codeSet.getQueryCache(query.Hash()) 77 if cacheCodeSet != nil { 78 return cacheCodeSet, nil 79 } 80 queryCodeSet, err := newCompiler().codeToOpcodeSet(codeSet.Type, codeSet.Code.Filter(query)) 81 if err != nil { 82 return nil, err 83 } 84 codeSet.setQueryCache(query.Hash(), queryCodeSet) 85 return queryCodeSet, nil 86 } 87 88 type Compiler struct { 89 structTypeToCode map[uintptr]*StructCode 90 } 91 92 func newCompiler() *Compiler { 93 return &Compiler{ 94 structTypeToCode: map[uintptr]*StructCode{}, 95 } 96 } 97 98 func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) { 99 // noescape trick for header.typ ( reflect.*rtype ) 100 typ := *(**runtime.Type)(unsafe.Pointer(&typeptr)) 101 code, err := c.typeToCode(typ) 102 if err != nil { 103 return nil, err 104 } 105 return c.codeToOpcodeSet(typ, code) 106 } 107 108 func (c *Compiler) codeToOpcodeSet(typ *runtime.Type, code Code) (*OpcodeSet, error) { 109 noescapeKeyCode := c.codeToOpcode(&compileContext{ 110 structTypeToCodes: map[uintptr]Opcodes{}, 111 recursiveCodes: &Opcodes{}, 112 }, typ, code) 113 if err := noescapeKeyCode.Validate(); err != nil { 114 return nil, err 115 } 116 escapeKeyCode := c.codeToOpcode(&compileContext{ 117 structTypeToCodes: map[uintptr]Opcodes{}, 118 recursiveCodes: &Opcodes{}, 119 escapeKey: true, 120 }, typ, code) 121 noescapeKeyCode = copyOpcode(noescapeKeyCode) 122 escapeKeyCode = copyOpcode(escapeKeyCode) 123 setTotalLengthToInterfaceOp(noescapeKeyCode) 124 setTotalLengthToInterfaceOp(escapeKeyCode) 125 interfaceNoescapeKeyCode := copyToInterfaceOpcode(noescapeKeyCode) 126 interfaceEscapeKeyCode := copyToInterfaceOpcode(escapeKeyCode) 127 codeLength := noescapeKeyCode.TotalLength() 128 return &OpcodeSet{ 129 Type: typ, 130 NoescapeKeyCode: noescapeKeyCode, 131 EscapeKeyCode: escapeKeyCode, 132 InterfaceNoescapeKeyCode: interfaceNoescapeKeyCode, 133 InterfaceEscapeKeyCode: interfaceEscapeKeyCode, 134 CodeLength: codeLength, 135 EndCode: ToEndCode(interfaceNoescapeKeyCode), 136 Code: code, 137 QueryCache: map[string]*OpcodeSet{}, 138 }, nil 139 } 140 141 func (c *Compiler) typeToCode(typ *runtime.Type) (Code, error) { 142 switch { 143 case c.implementsMarshalJSON(typ): 144 return c.marshalJSONCode(typ) 145 case c.implementsMarshalText(typ): 146 return c.marshalTextCode(typ) 147 } 148 149 isPtr := false 150 orgType := typ 151 if typ.Kind() == reflect.Ptr { 152 typ = typ.Elem() 153 isPtr = true 154 } 155 switch { 156 case c.implementsMarshalJSON(typ): 157 return c.marshalJSONCode(orgType) 158 case c.implementsMarshalText(typ): 159 return c.marshalTextCode(orgType) 160 } 161 switch typ.Kind() { 162 case reflect.Slice: 163 elem := typ.Elem() 164 if elem.Kind() == reflect.Uint8 { 165 p := runtime.PtrTo(elem) 166 if !c.implementsMarshalJSONType(p) && !p.Implements(reflect.ToRT(marshalTextType)) { 167 return c.bytesCode(typ, isPtr) 168 } 169 } 170 return c.sliceCode(typ) 171 case reflect.Map: 172 if isPtr { 173 return c.ptrCode(runtime.PtrTo(typ)) 174 } 175 return c.mapCode(typ) 176 case reflect.Struct: 177 return c.structCode(typ, isPtr) 178 case reflect.Int: 179 return c.intCode(typ, isPtr) 180 case reflect.Int8: 181 return c.int8Code(typ, isPtr) 182 case reflect.Int16: 183 return c.int16Code(typ, isPtr) 184 case reflect.Int32: 185 return c.int32Code(typ, isPtr) 186 case reflect.Int64: 187 return c.int64Code(typ, isPtr) 188 case reflect.Uint, reflect.Uintptr: 189 return c.uintCode(typ, isPtr) 190 case reflect.Uint8: 191 return c.uint8Code(typ, isPtr) 192 case reflect.Uint16: 193 return c.uint16Code(typ, isPtr) 194 case reflect.Uint32: 195 return c.uint32Code(typ, isPtr) 196 case reflect.Uint64: 197 return c.uint64Code(typ, isPtr) 198 case reflect.Float32: 199 return c.float32Code(typ, isPtr) 200 case reflect.Float64: 201 return c.float64Code(typ, isPtr) 202 case reflect.String: 203 return c.stringCode(typ, isPtr) 204 case reflect.Bool: 205 return c.boolCode(typ, isPtr) 206 case reflect.Interface: 207 return c.interfaceCode(typ, isPtr) 208 default: 209 if isPtr && typ.Implements(reflect.ToRT(marshalTextType)) { 210 typ = orgType 211 } 212 return c.typeToCodeWithPtr(typ, isPtr) 213 } 214 } 215 216 func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error) { 217 switch { 218 case c.implementsMarshalJSON(typ): 219 return c.marshalJSONCode(typ) 220 case c.implementsMarshalText(typ): 221 return c.marshalTextCode(typ) 222 } 223 switch typ.Kind() { 224 case reflect.Ptr: 225 return c.ptrCode(typ) 226 case reflect.Slice: 227 elem := typ.Elem() 228 if elem.Kind() == reflect.Uint8 { 229 p := runtime.PtrTo(elem) 230 if !c.implementsMarshalJSONType(p) && !p.Implements(reflect.ToRT(marshalTextType)) { 231 return c.bytesCode(typ, false) 232 } 233 } 234 return c.sliceCode(typ) 235 case reflect.Array: 236 return c.arrayCode(typ) 237 case reflect.Map: 238 return c.mapCode(typ) 239 case reflect.Struct: 240 return c.structCode(typ, isPtr) 241 case reflect.Interface: 242 return c.interfaceCode(typ, false) 243 case reflect.Int: 244 return c.intCode(typ, false) 245 case reflect.Int8: 246 return c.int8Code(typ, false) 247 case reflect.Int16: 248 return c.int16Code(typ, false) 249 case reflect.Int32: 250 return c.int32Code(typ, false) 251 case reflect.Int64: 252 return c.int64Code(typ, false) 253 case reflect.Uint: 254 return c.uintCode(typ, false) 255 case reflect.Uint8: 256 return c.uint8Code(typ, false) 257 case reflect.Uint16: 258 return c.uint16Code(typ, false) 259 case reflect.Uint32: 260 return c.uint32Code(typ, false) 261 case reflect.Uint64: 262 return c.uint64Code(typ, false) 263 case reflect.Uintptr: 264 return c.uintCode(typ, false) 265 case reflect.Float32: 266 return c.float32Code(typ, false) 267 case reflect.Float64: 268 return c.float64Code(typ, false) 269 case reflect.String: 270 return c.stringCode(typ, false) 271 case reflect.Bool: 272 return c.boolCode(typ, false) 273 } 274 return nil, &errors.UnsupportedTypeError{Type: reflect.ToT(runtime.RType2Type(typ))} 275 } 276 277 const intSize = 32 << (^uint(0) >> 63) 278 279 //nolint:unparam 280 func (c *Compiler) intCode(typ *runtime.Type, isPtr bool) (*IntCode, error) { 281 return &IntCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil 282 } 283 284 //nolint:unparam 285 func (c *Compiler) int8Code(typ *runtime.Type, isPtr bool) (*IntCode, error) { 286 return &IntCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil 287 } 288 289 //nolint:unparam 290 func (c *Compiler) int16Code(typ *runtime.Type, isPtr bool) (*IntCode, error) { 291 return &IntCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil 292 } 293 294 //nolint:unparam 295 func (c *Compiler) int32Code(typ *runtime.Type, isPtr bool) (*IntCode, error) { 296 return &IntCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil 297 } 298 299 //nolint:unparam 300 func (c *Compiler) int64Code(typ *runtime.Type, isPtr bool) (*IntCode, error) { 301 return &IntCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil 302 } 303 304 //nolint:unparam 305 func (c *Compiler) uintCode(typ *runtime.Type, isPtr bool) (*UintCode, error) { 306 return &UintCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil 307 } 308 309 //nolint:unparam 310 func (c *Compiler) uint8Code(typ *runtime.Type, isPtr bool) (*UintCode, error) { 311 return &UintCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil 312 } 313 314 //nolint:unparam 315 func (c *Compiler) uint16Code(typ *runtime.Type, isPtr bool) (*UintCode, error) { 316 return &UintCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil 317 } 318 319 //nolint:unparam 320 func (c *Compiler) uint32Code(typ *runtime.Type, isPtr bool) (*UintCode, error) { 321 return &UintCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil 322 } 323 324 //nolint:unparam 325 func (c *Compiler) uint64Code(typ *runtime.Type, isPtr bool) (*UintCode, error) { 326 return &UintCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil 327 } 328 329 //nolint:unparam 330 func (c *Compiler) float32Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) { 331 return &FloatCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil 332 } 333 334 //nolint:unparam 335 func (c *Compiler) float64Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) { 336 return &FloatCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil 337 } 338 339 //nolint:unparam 340 func (c *Compiler) stringCode(typ *runtime.Type, isPtr bool) (*StringCode, error) { 341 return &StringCode{typ: typ, isPtr: isPtr}, nil 342 } 343 344 //nolint:unparam 345 func (c *Compiler) boolCode(typ *runtime.Type, isPtr bool) (*BoolCode, error) { 346 return &BoolCode{typ: typ, isPtr: isPtr}, nil 347 } 348 349 //nolint:unparam 350 func (c *Compiler) intStringCode(typ *runtime.Type) (*IntCode, error) { 351 return &IntCode{typ: typ, bitSize: intSize, isString: true}, nil 352 } 353 354 //nolint:unparam 355 func (c *Compiler) int8StringCode(typ *runtime.Type) (*IntCode, error) { 356 return &IntCode{typ: typ, bitSize: 8, isString: true}, nil 357 } 358 359 //nolint:unparam 360 func (c *Compiler) int16StringCode(typ *runtime.Type) (*IntCode, error) { 361 return &IntCode{typ: typ, bitSize: 16, isString: true}, nil 362 } 363 364 //nolint:unparam 365 func (c *Compiler) int32StringCode(typ *runtime.Type) (*IntCode, error) { 366 return &IntCode{typ: typ, bitSize: 32, isString: true}, nil 367 } 368 369 //nolint:unparam 370 func (c *Compiler) int64StringCode(typ *runtime.Type) (*IntCode, error) { 371 return &IntCode{typ: typ, bitSize: 64, isString: true}, nil 372 } 373 374 //nolint:unparam 375 func (c *Compiler) uintStringCode(typ *runtime.Type) (*UintCode, error) { 376 return &UintCode{typ: typ, bitSize: intSize, isString: true}, nil 377 } 378 379 //nolint:unparam 380 func (c *Compiler) uint8StringCode(typ *runtime.Type) (*UintCode, error) { 381 return &UintCode{typ: typ, bitSize: 8, isString: true}, nil 382 } 383 384 //nolint:unparam 385 func (c *Compiler) uint16StringCode(typ *runtime.Type) (*UintCode, error) { 386 return &UintCode{typ: typ, bitSize: 16, isString: true}, nil 387 } 388 389 //nolint:unparam 390 func (c *Compiler) uint32StringCode(typ *runtime.Type) (*UintCode, error) { 391 return &UintCode{typ: typ, bitSize: 32, isString: true}, nil 392 } 393 394 //nolint:unparam 395 func (c *Compiler) uint64StringCode(typ *runtime.Type) (*UintCode, error) { 396 return &UintCode{typ: typ, bitSize: 64, isString: true}, nil 397 } 398 399 //nolint:unparam 400 func (c *Compiler) bytesCode(typ *runtime.Type, isPtr bool) (*BytesCode, error) { 401 return &BytesCode{typ: typ, isPtr: isPtr}, nil 402 } 403 404 //nolint:unparam 405 func (c *Compiler) interfaceCode(typ *runtime.Type, isPtr bool) (*InterfaceCode, error) { 406 return &InterfaceCode{typ: typ, isPtr: isPtr}, nil 407 } 408 409 //nolint:unparam 410 func (c *Compiler) marshalJSONCode(typ *runtime.Type) (*MarshalJSONCode, error) { 411 return &MarshalJSONCode{ 412 typ: typ, 413 isAddrForMarshaler: c.isPtrMarshalJSONType(typ), 414 isNilableType: c.isNilableType(typ), 415 isMarshalerContext: typ.Implements(reflect.ToRT(marshalJSONContextType)) || runtime.PtrTo(typ).Implements(reflect.ToRT(marshalJSONContextType)), 416 }, nil 417 } 418 419 //nolint:unparam 420 func (c *Compiler) marshalTextCode(typ *runtime.Type) (*MarshalTextCode, error) { 421 return &MarshalTextCode{ 422 typ: typ, 423 isAddrForMarshaler: c.isPtrMarshalTextType(typ), 424 isNilableType: c.isNilableType(typ), 425 }, nil 426 } 427 428 func (c *Compiler) ptrCode(typ *runtime.Type) (*PtrCode, error) { 429 code, err := c.typeToCodeWithPtr(typ.Elem(), true) 430 if err != nil { 431 return nil, err 432 } 433 ptr, ok := code.(*PtrCode) 434 if ok { 435 return &PtrCode{typ: typ, value: ptr.value, ptrNum: ptr.ptrNum + 1}, nil 436 } 437 return &PtrCode{typ: typ, value: code, ptrNum: 1}, nil 438 } 439 440 func (c *Compiler) sliceCode(typ *runtime.Type) (*SliceCode, error) { 441 elem := typ.Elem() 442 code, err := c.listElemCode(elem) 443 if err != nil { 444 return nil, err 445 } 446 if code.Kind() == CodeKindStruct { 447 structCode := code.(*StructCode) 448 structCode.enableIndirect() 449 } 450 return &SliceCode{typ: typ, value: code}, nil 451 } 452 453 func (c *Compiler) arrayCode(typ *runtime.Type) (*ArrayCode, error) { 454 elem := typ.Elem() 455 code, err := c.listElemCode(elem) 456 if err != nil { 457 return nil, err 458 } 459 if code.Kind() == CodeKindStruct { 460 structCode := code.(*StructCode) 461 structCode.enableIndirect() 462 } 463 return &ArrayCode{typ: typ, value: code}, nil 464 } 465 466 func (c *Compiler) mapCode(typ *runtime.Type) (*MapCode, error) { 467 keyCode, err := c.mapKeyCode(typ.Key()) 468 if err != nil { 469 return nil, err 470 } 471 valueCode, err := c.mapValueCode(typ.Elem()) 472 if err != nil { 473 return nil, err 474 } 475 if valueCode.Kind() == CodeKindStruct { 476 structCode := valueCode.(*StructCode) 477 structCode.enableIndirect() 478 } 479 return &MapCode{typ: typ, key: keyCode, value: valueCode}, nil 480 } 481 482 func (c *Compiler) listElemCode(typ *runtime.Type) (Code, error) { 483 switch { 484 case c.isPtrMarshalJSONType(typ): 485 return c.marshalJSONCode(typ) 486 case !typ.Implements(reflect.ToRT(marshalTextType)) && runtime.PtrTo(typ).Implements(reflect.ToRT(marshalTextType)): 487 return c.marshalTextCode(typ) 488 case typ.Kind() == reflect.Map: 489 return c.ptrCode(runtime.PtrTo(typ)) 490 default: 491 // isPtr was originally used to indicate whether the type of top level is pointer. 492 // However, since the slice/array element is a specification that can get the pointer address, explicitly set isPtr to true. 493 // See here for related issues: https://github.com/3JoB/go-json/issues/370 494 code, err := c.typeToCodeWithPtr(typ, true) 495 if err != nil { 496 return nil, err 497 } 498 ptr, ok := code.(*PtrCode) 499 if ok { 500 if ptr.value.Kind() == CodeKindMap { 501 ptr.ptrNum++ 502 } 503 } 504 return code, nil 505 } 506 } 507 508 func (c *Compiler) mapKeyCode(typ *runtime.Type) (Code, error) { 509 switch { 510 case c.implementsMarshalText(typ): 511 return c.marshalTextCode(typ) 512 } 513 switch typ.Kind() { 514 case reflect.Ptr: 515 return c.ptrCode(typ) 516 case reflect.String: 517 return c.stringCode(typ, false) 518 case reflect.Int: 519 return c.intStringCode(typ) 520 case reflect.Int8: 521 return c.int8StringCode(typ) 522 case reflect.Int16: 523 return c.int16StringCode(typ) 524 case reflect.Int32: 525 return c.int32StringCode(typ) 526 case reflect.Int64: 527 return c.int64StringCode(typ) 528 case reflect.Uint: 529 return c.uintStringCode(typ) 530 case reflect.Uint8: 531 return c.uint8StringCode(typ) 532 case reflect.Uint16: 533 return c.uint16StringCode(typ) 534 case reflect.Uint32: 535 return c.uint32StringCode(typ) 536 case reflect.Uint64: 537 return c.uint64StringCode(typ) 538 case reflect.Uintptr: 539 return c.uintStringCode(typ) 540 } 541 return nil, &errors.UnsupportedTypeError{Type: reflect.ToT(runtime.RType2Type(typ))} 542 } 543 544 func (c *Compiler) mapValueCode(typ *runtime.Type) (Code, error) { 545 switch typ.Kind() { 546 case reflect.Map: 547 return c.ptrCode(runtime.PtrTo(typ)) 548 default: 549 code, err := c.typeToCodeWithPtr(typ, false) 550 if err != nil { 551 return nil, err 552 } 553 ptr, ok := code.(*PtrCode) 554 if ok { 555 if ptr.value.Kind() == CodeKindMap { 556 ptr.ptrNum++ 557 } 558 } 559 return code, nil 560 } 561 } 562 563 func (c *Compiler) structCode(typ *runtime.Type, isPtr bool) (*StructCode, error) { 564 typeptr := uintptr(unsafe.Pointer(typ)) 565 if code, exists := c.structTypeToCode[typeptr]; exists { 566 derefCode := *code 567 derefCode.isRecursive = true 568 return &derefCode, nil 569 } 570 indirect := runtime.IfaceIndir(typ) 571 code := &StructCode{typ: typ, isPtr: isPtr, isIndirect: indirect} 572 c.structTypeToCode[typeptr] = code 573 574 fieldNum := typ.NumField() 575 tags := c.typeToStructTags(typ) 576 fields := []*StructFieldCode{} 577 for i, tag := range tags { 578 isOnlyOneFirstField := i == 0 && fieldNum == 1 579 field, err := c.structFieldCode(code, tag, isPtr, isOnlyOneFirstField) 580 if err != nil { 581 return nil, err 582 } 583 if field.isAnonymous { 584 structCode := field.getAnonymousStruct() 585 if structCode != nil { 586 structCode.removeFieldsByTags(tags) 587 if c.isAssignableIndirect(field, isPtr) { 588 if indirect { 589 structCode.isIndirect = true 590 } else { 591 structCode.isIndirect = false 592 } 593 } 594 } 595 } else { 596 structCode := field.getStruct() 597 if structCode != nil { 598 if indirect { 599 // if parent is indirect type, set child indirect property to true 600 structCode.isIndirect = true 601 } else { 602 // if parent is not indirect type, set child indirect property to false. 603 // but if parent's indirect is false and isPtr is true, then indirect must be true. 604 // Do this only if indirectConversion is enabled at the end of compileStruct. 605 structCode.isIndirect = false 606 } 607 } 608 } 609 fields = append(fields, field) 610 } 611 fieldMap := c.getFieldMap(fields) 612 duplicatedFieldMap := c.getDuplicatedFieldMap(fieldMap) 613 code.fields = c.filteredDuplicatedFields(fields, duplicatedFieldMap) 614 if !code.disableIndirectConversion && !indirect && isPtr { 615 code.enableIndirect() 616 } 617 delete(c.structTypeToCode, typeptr) 618 return code, nil 619 } 620 621 func toElemType(t *runtime.Type) *runtime.Type { 622 for t.Kind() == reflect.Ptr { 623 t = t.Elem() 624 } 625 return t 626 } 627 628 func (c *Compiler) structFieldCode(structCode *StructCode, tag *runtime.StructTag, isPtr, isOnlyOneFirstField bool) (*StructFieldCode, error) { 629 field := tag.Field 630 fieldType := runtime.Type2RType(field.Type) 631 isIndirectSpecialCase := isPtr && isOnlyOneFirstField 632 fieldCode := &StructFieldCode{ 633 typ: fieldType, 634 key: tag.Key, 635 tag: tag, 636 offset: field.Offset, 637 isAnonymous: field.Anonymous && !tag.IsTaggedKey && toElemType(fieldType).Kind() == reflect.Struct, 638 isTaggedKey: tag.IsTaggedKey, 639 isNilableType: c.isNilableType(fieldType), 640 isNilCheck: true, 641 } 642 switch { 643 case c.isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(fieldType, isIndirectSpecialCase): 644 code, err := c.marshalJSONCode(fieldType) 645 if err != nil { 646 return nil, err 647 } 648 fieldCode.value = code 649 fieldCode.isAddrForMarshaler = true 650 fieldCode.isNilCheck = false 651 structCode.isIndirect = false 652 structCode.disableIndirectConversion = true 653 case c.isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(fieldType, isIndirectSpecialCase): 654 code, err := c.marshalTextCode(fieldType) 655 if err != nil { 656 return nil, err 657 } 658 fieldCode.value = code 659 fieldCode.isAddrForMarshaler = true 660 fieldCode.isNilCheck = false 661 structCode.isIndirect = false 662 structCode.disableIndirectConversion = true 663 case isPtr && c.isPtrMarshalJSONType(fieldType): 664 // *struct{ field T } 665 // func (*T) MarshalJSON() ([]byte, error) 666 code, err := c.marshalJSONCode(fieldType) 667 if err != nil { 668 return nil, err 669 } 670 fieldCode.value = code 671 fieldCode.isAddrForMarshaler = true 672 fieldCode.isNilCheck = false 673 case isPtr && c.isPtrMarshalTextType(fieldType): 674 // *struct{ field T } 675 // func (*T) MarshalText() ([]byte, error) 676 code, err := c.marshalTextCode(fieldType) 677 if err != nil { 678 return nil, err 679 } 680 fieldCode.value = code 681 fieldCode.isAddrForMarshaler = true 682 fieldCode.isNilCheck = false 683 default: 684 code, err := c.typeToCodeWithPtr(fieldType, isPtr) 685 if err != nil { 686 return nil, err 687 } 688 switch code.Kind() { 689 case CodeKindPtr, CodeKindInterface: 690 fieldCode.isNextOpPtrType = true 691 } 692 fieldCode.value = code 693 } 694 return fieldCode, nil 695 } 696 697 func (c *Compiler) isAssignableIndirect(fieldCode *StructFieldCode, isPtr bool) bool { 698 if isPtr { 699 return false 700 } 701 codeType := fieldCode.value.Kind() 702 if codeType == CodeKindMarshalJSON { 703 return false 704 } 705 if codeType == CodeKindMarshalText { 706 return false 707 } 708 return true 709 } 710 711 func (c *Compiler) getFieldMap(fields []*StructFieldCode) map[string][]*StructFieldCode { 712 fieldMap := map[string][]*StructFieldCode{} 713 for _, field := range fields { 714 if field.isAnonymous { 715 for k, v := range c.getAnonymousFieldMap(field) { 716 fieldMap[k] = append(fieldMap[k], v...) 717 } 718 continue 719 } 720 fieldMap[field.key] = append(fieldMap[field.key], field) 721 } 722 return fieldMap 723 } 724 725 func (c *Compiler) getAnonymousFieldMap(field *StructFieldCode) map[string][]*StructFieldCode { 726 fieldMap := map[string][]*StructFieldCode{} 727 structCode := field.getAnonymousStruct() 728 if structCode == nil || structCode.isRecursive { 729 fieldMap[field.key] = append(fieldMap[field.key], field) 730 return fieldMap 731 } 732 for k, v := range c.getFieldMapFromAnonymousParent(structCode.fields) { 733 fieldMap[k] = append(fieldMap[k], v...) 734 } 735 return fieldMap 736 } 737 738 func (c *Compiler) getFieldMapFromAnonymousParent(fields []*StructFieldCode) map[string][]*StructFieldCode { 739 fieldMap := map[string][]*StructFieldCode{} 740 for _, field := range fields { 741 if field.isAnonymous { 742 for k, v := range c.getAnonymousFieldMap(field) { 743 // Do not handle tagged key when embedding more than once 744 for _, vv := range v { 745 vv.isTaggedKey = false 746 } 747 fieldMap[k] = append(fieldMap[k], v...) 748 } 749 continue 750 } 751 fieldMap[field.key] = append(fieldMap[field.key], field) 752 } 753 return fieldMap 754 } 755 756 func (c *Compiler) getDuplicatedFieldMap(fieldMap map[string][]*StructFieldCode) map[*StructFieldCode]struct{} { 757 duplicatedFieldMap := map[*StructFieldCode]struct{}{} 758 for _, fields := range fieldMap { 759 if len(fields) == 1 { 760 continue 761 } 762 if c.isTaggedKeyOnly(fields) { 763 for _, field := range fields { 764 if field.isTaggedKey { 765 continue 766 } 767 duplicatedFieldMap[field] = struct{}{} 768 } 769 } else { 770 for _, field := range fields { 771 duplicatedFieldMap[field] = struct{}{} 772 } 773 } 774 } 775 return duplicatedFieldMap 776 } 777 778 func (c *Compiler) filteredDuplicatedFields(fields []*StructFieldCode, duplicatedFieldMap map[*StructFieldCode]struct{}) []*StructFieldCode { 779 filteredFields := make([]*StructFieldCode, 0, len(fields)) 780 for _, field := range fields { 781 if field.isAnonymous { 782 structCode := field.getAnonymousStruct() 783 if structCode != nil && !structCode.isRecursive { 784 structCode.fields = c.filteredDuplicatedFields(structCode.fields, duplicatedFieldMap) 785 if len(structCode.fields) > 0 { 786 filteredFields = append(filteredFields, field) 787 } 788 continue 789 } 790 } 791 if _, exists := duplicatedFieldMap[field]; exists { 792 continue 793 } 794 filteredFields = append(filteredFields, field) 795 } 796 return filteredFields 797 } 798 799 func (c *Compiler) isTaggedKeyOnly(fields []*StructFieldCode) bool { 800 var taggedKeyFieldCount int 801 for _, field := range fields { 802 if field.isTaggedKey { 803 taggedKeyFieldCount++ 804 } 805 } 806 return taggedKeyFieldCount == 1 807 } 808 809 func (c *Compiler) typeToStructTags(typ *runtime.Type) runtime.StructTags { 810 tags := runtime.StructTags{} 811 fieldNum := typ.NumField() 812 for i := 0; i < fieldNum; i++ { 813 field := typ.Field(i) 814 if runtime.IsIgnoredStructField(field) { 815 continue 816 } 817 tags = append(tags, runtime.StructTagFromField(field)) 818 } 819 return tags 820 } 821 822 // *struct{ field T } => struct { field *T } 823 // func (*T) MarshalJSON() ([]byte, error) 824 func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool { 825 return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalJSONType(typ) 826 } 827 828 // *struct{ field T } => struct { field *T } 829 // func (*T) MarshalText() ([]byte, error) 830 func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool { 831 return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalTextType(typ) 832 } 833 834 func (c *Compiler) implementsMarshalJSON(typ *runtime.Type) bool { 835 if !c.implementsMarshalJSONType(typ) { 836 return false 837 } 838 if typ.Kind() != reflect.Ptr { 839 return true 840 } 841 // type kind is reflect.Ptr 842 if !c.implementsMarshalJSONType(typ.Elem()) { 843 return true 844 } 845 // needs to dereference 846 return false 847 } 848 849 func (c *Compiler) implementsMarshalText(typ *runtime.Type) bool { 850 if !typ.Implements(reflect.ToRT(marshalTextType)) { 851 return false 852 } 853 if typ.Kind() != reflect.Ptr { 854 return true 855 } 856 // type kind is reflect.Ptr 857 if !typ.Elem().Implements(reflect.ToRT(marshalTextType)) { 858 return true 859 } 860 // needs to dereference 861 return false 862 } 863 864 func (c *Compiler) isNilableType(typ *runtime.Type) bool { 865 if !runtime.IfaceIndir(typ) { 866 return true 867 } 868 switch typ.Kind() { 869 case reflect.Ptr: 870 return true 871 case reflect.Map: 872 return true 873 case reflect.Func: 874 return true 875 default: 876 return false 877 } 878 } 879 880 func (c *Compiler) implementsMarshalJSONType(typ *runtime.Type) bool { 881 return typ.Implements(reflect.ToRT(marshalJSONType)) || typ.Implements(reflect.ToRT(marshalJSONContextType)) 882 } 883 884 func (c *Compiler) isPtrMarshalJSONType(typ *runtime.Type) bool { 885 return !c.implementsMarshalJSONType(typ) && c.implementsMarshalJSONType(runtime.PtrTo(typ)) 886 } 887 888 func (c *Compiler) isPtrMarshalTextType(typ *runtime.Type) bool { 889 return !typ.Implements(reflect.ToRT(marshalTextType)) && runtime.PtrTo(typ).Implements(reflect.ToRT(marshalTextType)) 890 } 891 892 func (c *Compiler) codeToOpcode(ctx *compileContext, typ *runtime.Type, code Code) *Opcode { 893 codes := code.ToOpcode(ctx) 894 codes.Last().Next = newEndOp(ctx, typ) 895 c.linkRecursiveCode(ctx) 896 return codes.First() 897 } 898 899 func (c *Compiler) linkRecursiveCode(ctx *compileContext) { 900 recursiveCodes := map[uintptr]*CompiledCode{} 901 for _, recursive := range *ctx.recursiveCodes { 902 typeptr := uintptr(unsafe.Pointer(recursive.Type)) 903 codes := ctx.structTypeToCodes[typeptr] 904 if recursiveCode, ok := recursiveCodes[typeptr]; ok { 905 *recursive.Jmp = *recursiveCode 906 continue 907 } 908 909 code := copyOpcode(codes.First()) 910 code.Op = code.Op.PtrHeadToHead() 911 lastCode := newEndOp(&compileContext{}, recursive.Type) 912 lastCode.Op = OpRecursiveEnd 913 914 // OpRecursiveEnd must set before call TotalLength 915 code.End.Next = lastCode 916 917 totalLength := code.TotalLength() 918 919 // Idx, ElemIdx, Length must set after call TotalLength 920 lastCode.Idx = uint32((totalLength + 1) * uintptrSize) 921 lastCode.ElemIdx = lastCode.Idx + uintptrSize 922 lastCode.Length = lastCode.Idx + 2*uintptrSize 923 924 // extend length to alloc slot for elemIdx + length 925 curTotalLength := uintptr(recursive.TotalLength()) + 3 926 nextTotalLength := uintptr(totalLength) + 3 927 928 compiled := recursive.Jmp 929 compiled.Code = code 930 compiled.CurLen = curTotalLength 931 compiled.NextLen = nextTotalLength 932 compiled.Linked = true 933 934 recursiveCodes[typeptr] = compiled 935 } 936 }