github.com/goccy/go-json@v0.10.3-0.20240509105655-5e2ae3f23c1d/internal/decoder/compile.go (about)

     1  package decoder
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"reflect"
     7  	"strings"
     8  	"sync/atomic"
     9  	"unicode"
    10  	"unsafe"
    11  
    12  	"github.com/goccy/go-json/internal/runtime"
    13  )
    14  
    15  var (
    16  	jsonNumberType   = reflect.TypeOf(json.Number(""))
    17  	typeAddr         *runtime.TypeAddr
    18  	cachedDecoderMap unsafe.Pointer // map[uintptr]decoder
    19  	cachedDecoder    []Decoder
    20  )
    21  
    22  func init() {
    23  	typeAddr = runtime.AnalyzeTypeAddr()
    24  	if typeAddr == nil {
    25  		typeAddr = &runtime.TypeAddr{}
    26  	}
    27  	cachedDecoder = make([]Decoder, typeAddr.AddrRange>>typeAddr.AddrShift+1)
    28  }
    29  
    30  func loadDecoderMap() map[uintptr]Decoder {
    31  	p := atomic.LoadPointer(&cachedDecoderMap)
    32  	return *(*map[uintptr]Decoder)(unsafe.Pointer(&p))
    33  }
    34  
    35  func storeDecoder(typ uintptr, dec Decoder, m map[uintptr]Decoder) {
    36  	newDecoderMap := make(map[uintptr]Decoder, len(m)+1)
    37  	newDecoderMap[typ] = dec
    38  
    39  	for k, v := range m {
    40  		newDecoderMap[k] = v
    41  	}
    42  
    43  	atomic.StorePointer(&cachedDecoderMap, *(*unsafe.Pointer)(unsafe.Pointer(&newDecoderMap)))
    44  }
    45  
    46  func compileToGetDecoderSlowPath(typeptr uintptr, typ *runtime.Type) (Decoder, error) {
    47  	decoderMap := loadDecoderMap()
    48  	if dec, exists := decoderMap[typeptr]; exists {
    49  		return dec, nil
    50  	}
    51  
    52  	dec, err := compileHead(typ, map[uintptr]Decoder{})
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  	storeDecoder(typeptr, dec, decoderMap)
    57  	return dec, nil
    58  }
    59  
    60  func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
    61  	switch {
    62  	case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
    63  		return newUnmarshalJSONDecoder(runtime.PtrTo(typ), "", ""), nil
    64  	case runtime.PtrTo(typ).Implements(unmarshalTextType):
    65  		return newUnmarshalTextDecoder(runtime.PtrTo(typ), "", ""), nil
    66  	}
    67  	return compile(typ.Elem(), "", "", structTypeToDecoder)
    68  }
    69  
    70  func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
    71  	switch {
    72  	case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
    73  		return newUnmarshalJSONDecoder(runtime.PtrTo(typ), structName, fieldName), nil
    74  	case runtime.PtrTo(typ).Implements(unmarshalTextType):
    75  		return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil
    76  	}
    77  
    78  	switch typ.Kind() {
    79  	case reflect.Ptr:
    80  		return compilePtr(typ, structName, fieldName, structTypeToDecoder)
    81  	case reflect.Struct:
    82  		return compileStruct(typ, structName, fieldName, structTypeToDecoder)
    83  	case reflect.Slice:
    84  		elem := typ.Elem()
    85  		if elem.Kind() == reflect.Uint8 {
    86  			return compileBytes(elem, structName, fieldName)
    87  		}
    88  		return compileSlice(typ, structName, fieldName, structTypeToDecoder)
    89  	case reflect.Array:
    90  		return compileArray(typ, structName, fieldName, structTypeToDecoder)
    91  	case reflect.Map:
    92  		return compileMap(typ, structName, fieldName, structTypeToDecoder)
    93  	case reflect.Interface:
    94  		return compileInterface(typ, structName, fieldName)
    95  	case reflect.Uintptr:
    96  		return compileUint(typ, structName, fieldName)
    97  	case reflect.Int:
    98  		return compileInt(typ, structName, fieldName)
    99  	case reflect.Int8:
   100  		return compileInt8(typ, structName, fieldName)
   101  	case reflect.Int16:
   102  		return compileInt16(typ, structName, fieldName)
   103  	case reflect.Int32:
   104  		return compileInt32(typ, structName, fieldName)
   105  	case reflect.Int64:
   106  		return compileInt64(typ, structName, fieldName)
   107  	case reflect.Uint:
   108  		return compileUint(typ, structName, fieldName)
   109  	case reflect.Uint8:
   110  		return compileUint8(typ, structName, fieldName)
   111  	case reflect.Uint16:
   112  		return compileUint16(typ, structName, fieldName)
   113  	case reflect.Uint32:
   114  		return compileUint32(typ, structName, fieldName)
   115  	case reflect.Uint64:
   116  		return compileUint64(typ, structName, fieldName)
   117  	case reflect.String:
   118  		return compileString(typ, structName, fieldName)
   119  	case reflect.Bool:
   120  		return compileBool(structName, fieldName)
   121  	case reflect.Float32:
   122  		return compileFloat32(structName, fieldName)
   123  	case reflect.Float64:
   124  		return compileFloat64(structName, fieldName)
   125  	case reflect.Func:
   126  		return compileFunc(typ, structName, fieldName)
   127  	}
   128  	return newInvalidDecoder(typ, structName, fieldName), nil
   129  }
   130  
   131  func isStringTagSupportedType(typ *runtime.Type) bool {
   132  	switch {
   133  	case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
   134  		return false
   135  	case runtime.PtrTo(typ).Implements(unmarshalTextType):
   136  		return false
   137  	}
   138  	switch typ.Kind() {
   139  	case reflect.Map:
   140  		return false
   141  	case reflect.Slice:
   142  		return false
   143  	case reflect.Array:
   144  		return false
   145  	case reflect.Struct:
   146  		return false
   147  	case reflect.Interface:
   148  		return false
   149  	}
   150  	return true
   151  }
   152  
   153  func compileMapKey(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
   154  	if runtime.PtrTo(typ).Implements(unmarshalTextType) {
   155  		return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil
   156  	}
   157  	if typ.Kind() == reflect.String {
   158  		return newStringDecoder(structName, fieldName), nil
   159  	}
   160  	dec, err := compile(typ, structName, fieldName, structTypeToDecoder)
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  	for {
   165  		switch t := dec.(type) {
   166  		case *stringDecoder, *interfaceDecoder:
   167  			return dec, nil
   168  		case *boolDecoder, *intDecoder, *uintDecoder, *numberDecoder:
   169  			return newWrappedStringDecoder(typ, dec, structName, fieldName), nil
   170  		case *ptrDecoder:
   171  			dec = t.dec
   172  		default:
   173  			return newInvalidDecoder(typ, structName, fieldName), nil
   174  		}
   175  	}
   176  }
   177  
   178  func compilePtr(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
   179  	dec, err := compile(typ.Elem(), structName, fieldName, structTypeToDecoder)
   180  	if err != nil {
   181  		return nil, err
   182  	}
   183  	return newPtrDecoder(dec, typ.Elem(), structName, fieldName), nil
   184  }
   185  
   186  func compileInt(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   187  	return newIntDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v int64) {
   188  		*(*int)(p) = int(v)
   189  	}), nil
   190  }
   191  
   192  func compileInt8(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   193  	return newIntDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v int64) {
   194  		*(*int8)(p) = int8(v)
   195  	}), nil
   196  }
   197  
   198  func compileInt16(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   199  	return newIntDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v int64) {
   200  		*(*int16)(p) = int16(v)
   201  	}), nil
   202  }
   203  
   204  func compileInt32(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   205  	return newIntDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v int64) {
   206  		*(*int32)(p) = int32(v)
   207  	}), nil
   208  }
   209  
   210  func compileInt64(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   211  	return newIntDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v int64) {
   212  		*(*int64)(p) = v
   213  	}), nil
   214  }
   215  
   216  func compileUint(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   217  	return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) {
   218  		*(*uint)(p) = uint(v)
   219  	}), nil
   220  }
   221  
   222  func compileUint8(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   223  	return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) {
   224  		*(*uint8)(p) = uint8(v)
   225  	}), nil
   226  }
   227  
   228  func compileUint16(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   229  	return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) {
   230  		*(*uint16)(p) = uint16(v)
   231  	}), nil
   232  }
   233  
   234  func compileUint32(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   235  	return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) {
   236  		*(*uint32)(p) = uint32(v)
   237  	}), nil
   238  }
   239  
   240  func compileUint64(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   241  	return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) {
   242  		*(*uint64)(p) = v
   243  	}), nil
   244  }
   245  
   246  func compileFloat32(structName, fieldName string) (Decoder, error) {
   247  	return newFloatDecoder(structName, fieldName, func(p unsafe.Pointer, v float64) {
   248  		*(*float32)(p) = float32(v)
   249  	}), nil
   250  }
   251  
   252  func compileFloat64(structName, fieldName string) (Decoder, error) {
   253  	return newFloatDecoder(structName, fieldName, func(p unsafe.Pointer, v float64) {
   254  		*(*float64)(p) = v
   255  	}), nil
   256  }
   257  
   258  func compileString(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   259  	if typ == runtime.Type2RType(jsonNumberType) {
   260  		return newNumberDecoder(structName, fieldName, func(p unsafe.Pointer, v json.Number) {
   261  			*(*json.Number)(p) = v
   262  		}), nil
   263  	}
   264  	return newStringDecoder(structName, fieldName), nil
   265  }
   266  
   267  func compileBool(structName, fieldName string) (Decoder, error) {
   268  	return newBoolDecoder(structName, fieldName), nil
   269  }
   270  
   271  func compileBytes(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   272  	return newBytesDecoder(typ, structName, fieldName), nil
   273  }
   274  
   275  func compileSlice(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
   276  	elem := typ.Elem()
   277  	decoder, err := compile(elem, structName, fieldName, structTypeToDecoder)
   278  	if err != nil {
   279  		return nil, err
   280  	}
   281  	return newSliceDecoder(decoder, elem, elem.Size(), structName, fieldName), nil
   282  }
   283  
   284  func compileArray(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
   285  	elem := typ.Elem()
   286  	decoder, err := compile(elem, structName, fieldName, structTypeToDecoder)
   287  	if err != nil {
   288  		return nil, err
   289  	}
   290  	return newArrayDecoder(decoder, elem, typ.Len(), structName, fieldName), nil
   291  }
   292  
   293  func compileMap(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
   294  	keyDec, err := compileMapKey(typ.Key(), structName, fieldName, structTypeToDecoder)
   295  	if err != nil {
   296  		return nil, err
   297  	}
   298  	valueDec, err := compile(typ.Elem(), structName, fieldName, structTypeToDecoder)
   299  	if err != nil {
   300  		return nil, err
   301  	}
   302  	return newMapDecoder(typ, typ.Key(), keyDec, typ.Elem(), valueDec, structName, fieldName), nil
   303  }
   304  
   305  func compileInterface(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   306  	return newInterfaceDecoder(typ, structName, fieldName), nil
   307  }
   308  
   309  func compileFunc(typ *runtime.Type, strutName, fieldName string) (Decoder, error) {
   310  	return newFuncDecoder(typ, strutName, fieldName), nil
   311  }
   312  
   313  func typeToStructTags(typ *runtime.Type) runtime.StructTags {
   314  	tags := runtime.StructTags{}
   315  	fieldNum := typ.NumField()
   316  	for i := 0; i < fieldNum; i++ {
   317  		field := typ.Field(i)
   318  		if runtime.IsIgnoredStructField(field) {
   319  			continue
   320  		}
   321  		tags = append(tags, runtime.StructTagFromField(field))
   322  	}
   323  	return tags
   324  }
   325  
   326  func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
   327  	fieldNum := typ.NumField()
   328  	fieldMap := map[string]*structFieldSet{}
   329  	typeptr := uintptr(unsafe.Pointer(typ))
   330  	if dec, exists := structTypeToDecoder[typeptr]; exists {
   331  		return dec, nil
   332  	}
   333  	structDec := newStructDecoder(structName, fieldName, fieldMap)
   334  	structTypeToDecoder[typeptr] = structDec
   335  	structName = typ.Name()
   336  	tags := typeToStructTags(typ)
   337  	allFields := []*structFieldSet{}
   338  	for i := 0; i < fieldNum; i++ {
   339  		field := typ.Field(i)
   340  		if runtime.IsIgnoredStructField(field) {
   341  			continue
   342  		}
   343  		isUnexportedField := unicode.IsLower([]rune(field.Name)[0])
   344  		tag := runtime.StructTagFromField(field)
   345  		dec, err := compile(runtime.Type2RType(field.Type), structName, field.Name, structTypeToDecoder)
   346  		if err != nil {
   347  			return nil, err
   348  		}
   349  		if field.Anonymous && !tag.IsTaggedKey {
   350  			if stDec, ok := dec.(*structDecoder); ok {
   351  				if runtime.Type2RType(field.Type) == typ {
   352  					// recursive definition
   353  					continue
   354  				}
   355  				for k, v := range stDec.fieldMap {
   356  					if tags.ExistsKey(k) {
   357  						continue
   358  					}
   359  					fieldSet := &structFieldSet{
   360  						dec:         v.dec,
   361  						offset:      field.Offset + v.offset,
   362  						isTaggedKey: v.isTaggedKey,
   363  						key:         k,
   364  						keyLen:      int64(len(k)),
   365  					}
   366  					allFields = append(allFields, fieldSet)
   367  				}
   368  			} else if pdec, ok := dec.(*ptrDecoder); ok {
   369  				contentDec := pdec.contentDecoder()
   370  				if pdec.typ == typ {
   371  					// recursive definition
   372  					continue
   373  				}
   374  				var fieldSetErr error
   375  				if isUnexportedField {
   376  					fieldSetErr = fmt.Errorf(
   377  						"json: cannot set embedded pointer to unexported struct: %v",
   378  						field.Type.Elem(),
   379  					)
   380  				}
   381  				if dec, ok := contentDec.(*structDecoder); ok {
   382  					for k, v := range dec.fieldMap {
   383  						if tags.ExistsKey(k) {
   384  							continue
   385  						}
   386  						fieldSet := &structFieldSet{
   387  							dec:         newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec),
   388  							offset:      field.Offset,
   389  							isTaggedKey: v.isTaggedKey,
   390  							key:         k,
   391  							keyLen:      int64(len(k)),
   392  							err:         fieldSetErr,
   393  						}
   394  						allFields = append(allFields, fieldSet)
   395  					}
   396  				} else {
   397  					fieldSet := &structFieldSet{
   398  						dec:         pdec,
   399  						offset:      field.Offset,
   400  						isTaggedKey: tag.IsTaggedKey,
   401  						key:         field.Name,
   402  						keyLen:      int64(len(field.Name)),
   403  					}
   404  					allFields = append(allFields, fieldSet)
   405  				}
   406  			} else {
   407  				fieldSet := &structFieldSet{
   408  					dec:         dec,
   409  					offset:      field.Offset,
   410  					isTaggedKey: tag.IsTaggedKey,
   411  					key:         field.Name,
   412  					keyLen:      int64(len(field.Name)),
   413  				}
   414  				allFields = append(allFields, fieldSet)
   415  			}
   416  		} else {
   417  			if tag.IsString && isStringTagSupportedType(runtime.Type2RType(field.Type)) {
   418  				dec = newWrappedStringDecoder(runtime.Type2RType(field.Type), dec, structName, field.Name)
   419  			}
   420  			var key string
   421  			if tag.Key != "" {
   422  				key = tag.Key
   423  			} else {
   424  				key = field.Name
   425  			}
   426  			fieldSet := &structFieldSet{
   427  				dec:         dec,
   428  				offset:      field.Offset,
   429  				isTaggedKey: tag.IsTaggedKey,
   430  				key:         key,
   431  				keyLen:      int64(len(key)),
   432  			}
   433  			allFields = append(allFields, fieldSet)
   434  		}
   435  	}
   436  	for _, set := range filterDuplicatedFields(allFields) {
   437  		fieldMap[set.key] = set
   438  		lower := strings.ToLower(set.key)
   439  		if _, exists := fieldMap[lower]; !exists {
   440  			// first win
   441  			fieldMap[lower] = set
   442  		}
   443  	}
   444  	delete(structTypeToDecoder, typeptr)
   445  	structDec.tryOptimize()
   446  	return structDec, nil
   447  }
   448  
   449  func filterDuplicatedFields(allFields []*structFieldSet) []*structFieldSet {
   450  	fieldMap := map[string][]*structFieldSet{}
   451  	for _, field := range allFields {
   452  		fieldMap[field.key] = append(fieldMap[field.key], field)
   453  	}
   454  	duplicatedFieldMap := map[string]struct{}{}
   455  	for k, sets := range fieldMap {
   456  		sets = filterFieldSets(sets)
   457  		if len(sets) != 1 {
   458  			duplicatedFieldMap[k] = struct{}{}
   459  		}
   460  	}
   461  
   462  	filtered := make([]*structFieldSet, 0, len(allFields))
   463  	for _, field := range allFields {
   464  		if _, exists := duplicatedFieldMap[field.key]; exists {
   465  			continue
   466  		}
   467  		filtered = append(filtered, field)
   468  	}
   469  	return filtered
   470  }
   471  
   472  func filterFieldSets(sets []*structFieldSet) []*structFieldSet {
   473  	if len(sets) == 1 {
   474  		return sets
   475  	}
   476  	filtered := make([]*structFieldSet, 0, len(sets))
   477  	for _, set := range sets {
   478  		if set.isTaggedKey {
   479  			filtered = append(filtered, set)
   480  		}
   481  	}
   482  	return filtered
   483  }
   484  
   485  func implementsUnmarshalJSONType(typ *runtime.Type) bool {
   486  	return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType)
   487  }