github.com/trim21/go-phpserialize@v0.0.22-0.20240301204449-2fca0319b3f0/internal/decoder/compile.go (about)

     1  package decoder
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"strings"
     7  	"sync/atomic"
     8  	"unicode"
     9  	"unsafe"
    10  
    11  	"github.com/trim21/go-phpserialize/internal/runtime"
    12  )
    13  
    14  var (
    15  	typeAddr         *runtime.TypeAddr
    16  	cachedDecoderMap unsafe.Pointer // map[uintptr]decoder
    17  	cachedDecoder    []Decoder
    18  )
    19  
    20  func init() {
    21  	typeAddr = runtime.AnalyzeTypeAddr()
    22  	if typeAddr == nil {
    23  		typeAddr = &runtime.TypeAddr{}
    24  	}
    25  	cachedDecoder = make([]Decoder, typeAddr.AddrRange>>typeAddr.AddrShift+1)
    26  }
    27  
    28  func loadDecoderMap() map[uintptr]Decoder {
    29  	p := atomic.LoadPointer(&cachedDecoderMap)
    30  	return *(*map[uintptr]Decoder)(unsafe.Pointer(&p))
    31  }
    32  
    33  func storeDecoder(typ uintptr, dec Decoder, m map[uintptr]Decoder) {
    34  	newDecoderMap := make(map[uintptr]Decoder, len(m)+1)
    35  	newDecoderMap[typ] = dec
    36  
    37  	for k, v := range m {
    38  		newDecoderMap[k] = v
    39  	}
    40  
    41  	atomic.StorePointer(&cachedDecoderMap, *(*unsafe.Pointer)(unsafe.Pointer(&newDecoderMap)))
    42  }
    43  
    44  func compileToGetDecoderSlowPath(typeptr uintptr, typ *runtime.Type) (Decoder, error) {
    45  	decoderMap := loadDecoderMap()
    46  	if dec, exists := decoderMap[typeptr]; exists {
    47  		return dec, nil
    48  	}
    49  
    50  	dec, err := compileHead(typ, map[uintptr]Decoder{})
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  	storeDecoder(typeptr, dec, decoderMap)
    55  	return dec, nil
    56  }
    57  
    58  func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
    59  	switch {
    60  	case runtime.PtrTo(typ).Implements(unmarshalPHPType):
    61  		return newUnmarshalTextDecoder(runtime.PtrTo(typ), "", ""), nil
    62  	}
    63  	return compile(typ.Elem(), "", "", structTypeToDecoder)
    64  }
    65  
    66  func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
    67  	switch {
    68  	case runtime.PtrTo(typ).Implements(unmarshalPHPType):
    69  		return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil
    70  	}
    71  
    72  	switch typ.Kind() {
    73  	case reflect.Ptr:
    74  		return compilePtr(typ, structName, fieldName, structTypeToDecoder)
    75  	case reflect.Struct:
    76  		return compileStruct(typ, structName, fieldName, structTypeToDecoder)
    77  	case reflect.Slice:
    78  		elem := typ.Elem()
    79  		if elem.Kind() == reflect.Uint8 {
    80  			return compileBytes(elem, structName, fieldName)
    81  		}
    82  		return compileSlice(typ, structName, fieldName, structTypeToDecoder)
    83  	case reflect.Array:
    84  		return compileArray(typ, structName, fieldName, structTypeToDecoder)
    85  	case reflect.Map:
    86  		return compileMap(typ, structName, fieldName, structTypeToDecoder)
    87  	case reflect.Interface:
    88  		return compileInterface(typ, structName, fieldName)
    89  	case reflect.Uintptr:
    90  		return compileUint(typ, structName, fieldName)
    91  	case reflect.Int:
    92  		return compileInt(typ, structName, fieldName)
    93  	case reflect.Int8:
    94  		return compileInt8(typ, structName, fieldName)
    95  	case reflect.Int16:
    96  		return compileInt16(typ, structName, fieldName)
    97  	case reflect.Int32:
    98  		return compileInt32(typ, structName, fieldName)
    99  	case reflect.Int64:
   100  		return compileInt64(typ, structName, fieldName)
   101  	case reflect.Uint:
   102  		return compileUint(typ, structName, fieldName)
   103  	case reflect.Uint8:
   104  		return compileUint8(typ, structName, fieldName)
   105  	case reflect.Uint16:
   106  		return compileUint16(typ, structName, fieldName)
   107  	case reflect.Uint32:
   108  		return compileUint32(typ, structName, fieldName)
   109  	case reflect.Uint64:
   110  		return compileUint64(typ, structName, fieldName)
   111  	case reflect.String:
   112  		return compileString(typ, structName, fieldName)
   113  	case reflect.Bool:
   114  		return compileBool(structName, fieldName)
   115  	case reflect.Float32:
   116  		return compileFloat32(structName, fieldName)
   117  	case reflect.Float64:
   118  		return compileFloat64(structName, fieldName)
   119  	}
   120  	return newInvalidDecoder(typ, structName, fieldName), nil
   121  }
   122  
   123  func isStringTagSupportedType(typ *runtime.Type) bool {
   124  	switch {
   125  	case runtime.PtrTo(typ).Implements(unmarshalPHPType):
   126  		return false
   127  	}
   128  	switch typ.Kind() {
   129  	case reflect.Map:
   130  		return false
   131  	case reflect.Slice:
   132  		return false
   133  	case reflect.Array:
   134  		return false
   135  	case reflect.Struct:
   136  		return false
   137  	case reflect.Interface:
   138  		return false
   139  	}
   140  	return true
   141  }
   142  
   143  func compileMapKey(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
   144  	if runtime.PtrTo(typ).Implements(unmarshalPHPType) {
   145  		return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil
   146  	}
   147  	if typ.Kind() == reflect.String {
   148  		return newStringDecoder(structName, fieldName), nil
   149  	}
   150  	dec, err := compile(typ, structName, fieldName, structTypeToDecoder)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  
   155  	for {
   156  		switch t := dec.(type) {
   157  		case *stringDecoder, *interfaceDecoder:
   158  			return dec, nil
   159  		case *boolDecoder, *intDecoder, *uintDecoder:
   160  			return dec, nil
   161  		case *ptrDecoder:
   162  			dec = t.dec
   163  		default:
   164  			return newInvalidDecoder(typ, structName, fieldName), nil
   165  		}
   166  	}
   167  }
   168  
   169  func compilePtr(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
   170  	dec, err := compile(typ.Elem(), structName, fieldName, structTypeToDecoder)
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  	return newPtrDecoder(dec, typ.Elem(), structName, fieldName)
   175  }
   176  
   177  func compileInt(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   178  	return newIntDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v int64) {
   179  		*(*int)(p) = int(v)
   180  	}), nil
   181  }
   182  
   183  func compileInt8(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   184  	return newIntDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v int64) {
   185  		*(*int8)(p) = int8(v)
   186  	}), nil
   187  }
   188  
   189  func compileInt16(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   190  	return newIntDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v int64) {
   191  		*(*int16)(p) = int16(v)
   192  	}), nil
   193  }
   194  
   195  func compileInt32(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   196  	return newIntDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v int64) {
   197  		*(*int32)(p) = int32(v)
   198  	}), nil
   199  }
   200  
   201  func compileInt64(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   202  	return newIntDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v int64) {
   203  		*(*int64)(p) = v
   204  	}), nil
   205  }
   206  
   207  func compileUint(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   208  	return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) {
   209  		*(*uint)(p) = uint(v)
   210  	}), nil
   211  }
   212  
   213  func compileUint8(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   214  	return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) {
   215  		*(*uint8)(p) = uint8(v)
   216  	}), nil
   217  }
   218  
   219  func compileUint16(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   220  	return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) {
   221  		*(*uint16)(p) = uint16(v)
   222  	}), nil
   223  }
   224  
   225  func compileUint32(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   226  	return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) {
   227  		*(*uint32)(p) = uint32(v)
   228  	}), nil
   229  }
   230  
   231  func compileUint64(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   232  	return newUintDecoder(typ, structName, fieldName, func(p unsafe.Pointer, v uint64) {
   233  		*(*uint64)(p) = v
   234  	}), nil
   235  }
   236  
   237  func compileFloat32(structName, fieldName string) (Decoder, error) {
   238  	return newFloatDecoder(structName, fieldName, func(p unsafe.Pointer, v float64) {
   239  		*(*float32)(p) = float32(v)
   240  	}), nil
   241  }
   242  
   243  func compileFloat64(structName, fieldName string) (Decoder, error) {
   244  	return newFloatDecoder(structName, fieldName, func(p unsafe.Pointer, v float64) {
   245  		*(*float64)(p) = v
   246  	}), nil
   247  }
   248  
   249  func compileString(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   250  	return newStringDecoder(structName, fieldName), nil
   251  }
   252  
   253  func compileBool(structName, fieldName string) (Decoder, error) {
   254  	return newBoolDecoder(structName, fieldName), nil
   255  }
   256  
   257  func compileBytes(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   258  	return newBytesDecoder(typ, structName, fieldName), nil
   259  }
   260  
   261  func compileSlice(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
   262  	elem := typ.Elem()
   263  	decoder, err := compile(elem, structName, fieldName, structTypeToDecoder)
   264  	if err != nil {
   265  		return nil, err
   266  	}
   267  	return newSliceDecoder(decoder, elem, elem.Size(), structName, fieldName), nil
   268  }
   269  
   270  func compileArray(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
   271  	elem := typ.Elem()
   272  	decoder, err := compile(elem, structName, fieldName, structTypeToDecoder)
   273  	if err != nil {
   274  		return nil, err
   275  	}
   276  	return newArrayDecoder(decoder, elem, typ.Len(), structName, fieldName), nil
   277  }
   278  
   279  func compileMap(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
   280  	keyDec, err := compileMapKey(typ.Key(), structName, fieldName, structTypeToDecoder)
   281  	if err != nil {
   282  		return nil, err
   283  	}
   284  	valueDec, err := compile(typ.Elem(), structName, fieldName, structTypeToDecoder)
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  	return newMapDecoder(typ, typ.Key(), keyDec, typ.Elem(), valueDec, structName, fieldName), nil
   289  }
   290  
   291  func compileInterface(typ *runtime.Type, structName, fieldName string) (Decoder, error) {
   292  	return newInterfaceDecoder(typ, structName, fieldName), nil
   293  }
   294  
   295  func typeToStructTags(typ *runtime.Type) runtime.StructTags {
   296  	tags := runtime.StructTags{}
   297  	fieldNum := typ.NumField()
   298  	for i := 0; i < fieldNum; i++ {
   299  		field := typ.Field(i)
   300  		if runtime.IsIgnoredStructField(field) {
   301  			continue
   302  		}
   303  		tags = append(tags, runtime.StructTagFromField(field))
   304  	}
   305  	return tags
   306  }
   307  
   308  func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
   309  	fieldNum := typ.NumField()
   310  	fieldMap := map[string]*structFieldSet{}
   311  	typeptr := uintptr(unsafe.Pointer(typ))
   312  	if dec, exists := structTypeToDecoder[typeptr]; exists {
   313  		return dec, nil
   314  	}
   315  	structDec := newStructDecoder(structName, fieldName, fieldMap)
   316  	structTypeToDecoder[typeptr] = structDec
   317  	structName = typ.Name()
   318  	tags := typeToStructTags(typ)
   319  	allFields := []*structFieldSet{}
   320  	for i := 0; i < fieldNum; i++ {
   321  		field := typ.Field(i)
   322  		if runtime.IsIgnoredStructField(field) {
   323  			continue
   324  		}
   325  		isUnexportedField := unicode.IsLower([]rune(field.Name)[0])
   326  		tag := runtime.StructTagFromField(field)
   327  		dec, err := compile(runtime.Type2RType(field.Type), structName, field.Name, structTypeToDecoder)
   328  		if err != nil {
   329  			return nil, err
   330  		}
   331  		if field.Anonymous && !tag.IsTaggedKey {
   332  			if stDec, ok := dec.(*structDecoder); ok {
   333  				if runtime.Type2RType(field.Type) == typ {
   334  					// recursive definition
   335  					continue
   336  				}
   337  				for k, v := range stDec.fieldMap {
   338  					if tags.ExistsKey(k) {
   339  						continue
   340  					}
   341  					fieldSet := &structFieldSet{
   342  						dec:         v.dec,
   343  						offset:      field.Offset + v.offset,
   344  						isTaggedKey: v.isTaggedKey,
   345  						key:         k,
   346  						keyLen:      int64(len(k)),
   347  					}
   348  					allFields = append(allFields, fieldSet)
   349  				}
   350  			} else if pdec, ok := dec.(*ptrDecoder); ok {
   351  				contentDec := pdec.contentDecoder()
   352  				if pdec.typ == typ {
   353  					// recursive definition
   354  					continue
   355  				}
   356  				var fieldSetErr error
   357  				if isUnexportedField {
   358  					fieldSetErr = fmt.Errorf(
   359  						"php: cannot set embedded pointer to unexported struct: %v",
   360  						field.Type.Elem(),
   361  					)
   362  				}
   363  				if dec, ok := contentDec.(*structDecoder); ok {
   364  					for k, v := range dec.fieldMap {
   365  						if tags.ExistsKey(k) {
   366  							continue
   367  						}
   368  						fieldSet := &structFieldSet{
   369  							dec:         newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec),
   370  							offset:      field.Offset,
   371  							isTaggedKey: v.isTaggedKey,
   372  							key:         k,
   373  							keyLen:      int64(len(k)),
   374  							err:         fieldSetErr,
   375  						}
   376  						allFields = append(allFields, fieldSet)
   377  					}
   378  				}
   379  			} else {
   380  				fieldSet := &structFieldSet{
   381  					dec:         dec,
   382  					offset:      field.Offset,
   383  					isTaggedKey: tag.IsTaggedKey,
   384  					key:         field.Name,
   385  					keyLen:      int64(len(field.Name)),
   386  				}
   387  				allFields = append(allFields, fieldSet)
   388  			}
   389  		} else {
   390  			if tag.IsString && isStringTagSupportedType(runtime.Type2RType(field.Type)) {
   391  				dec, err = newWrappedStringDecoder(runtime.Type2RType(field.Type), dec, structName, field.Name)
   392  				if err != nil {
   393  					return nil, err
   394  				}
   395  			}
   396  			var key string
   397  			if tag.Key != "" {
   398  				key = tag.Key
   399  			} else {
   400  				key = field.Name
   401  			}
   402  			fieldSet := &structFieldSet{
   403  				dec:         dec,
   404  				offset:      field.Offset,
   405  				isTaggedKey: tag.IsTaggedKey,
   406  				key:         key,
   407  				keyLen:      int64(len(key)),
   408  			}
   409  			allFields = append(allFields, fieldSet)
   410  		}
   411  	}
   412  	for _, set := range filterDuplicatedFields(allFields) {
   413  		fieldMap[set.key] = set
   414  		lower := strings.ToLower(set.key)
   415  		if _, exists := fieldMap[lower]; !exists {
   416  			// first win
   417  			fieldMap[lower] = set
   418  		}
   419  	}
   420  	delete(structTypeToDecoder, typeptr)
   421  	structDec.tryOptimize()
   422  	return structDec, nil
   423  }
   424  
   425  func filterDuplicatedFields(allFields []*structFieldSet) []*structFieldSet {
   426  	fieldMap := map[string][]*structFieldSet{}
   427  	for _, field := range allFields {
   428  		fieldMap[field.key] = append(fieldMap[field.key], field)
   429  	}
   430  	duplicatedFieldMap := map[string]struct{}{}
   431  	for k, sets := range fieldMap {
   432  		sets = filterFieldSets(sets)
   433  		if len(sets) != 1 {
   434  			duplicatedFieldMap[k] = struct{}{}
   435  		}
   436  	}
   437  
   438  	filtered := make([]*structFieldSet, 0, len(allFields))
   439  	for _, field := range allFields {
   440  		if _, exists := duplicatedFieldMap[field.key]; exists {
   441  			continue
   442  		}
   443  		filtered = append(filtered, field)
   444  	}
   445  	return filtered
   446  }
   447  
   448  func filterFieldSets(sets []*structFieldSet) []*structFieldSet {
   449  	if len(sets) == 1 {
   450  		return sets
   451  	}
   452  	filtered := make([]*structFieldSet, 0, len(sets))
   453  	for _, set := range sets {
   454  		if set.isTaggedKey {
   455  			filtered = append(filtered, set)
   456  		}
   457  	}
   458  	return filtered
   459  }