github.com/trim21/go-phpserialize@v0.0.22-0.20240301204449-2fca0319b3f0/internal/decoder/struct.go (about)

     1  package decoder
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"math/bits"
     7  	"runtime"
     8  	"sort"
     9  	"strings"
    10  	"unicode"
    11  	"unicode/utf16"
    12  	"unsafe"
    13  
    14  	"github.com/trim21/go-phpserialize/internal/errors"
    15  )
    16  
// structFieldSet holds what the struct decoder needs to decode one field.
//
// Fields, as used in this file:
//   - dec: decoder invoked for the field's value by (*structDecoder).Decode.
//   - offset: added to the struct's base pointer to address the field.
//   - isTaggedKey: not read in this file (presumably marks keys that come
//     from a struct tag; confirm against the compiler that fills it in).
//   - fieldIdx: index shared by keys that are equal ignoring case; assigned
//     by (*structDecoder).tryOptimize.
//   - key: not read in this file (presumably the field's key name).
//   - keyLen: compared against the length of an incoming key by the bitmap
//     key decoders to reject keys that are only a prefix of this field's key.
//   - err: when non-nil, Decode returns it instead of decoding the field.
type structFieldSet struct {
	dec         Decoder
	offset      uintptr
	isTaggedKey bool
	fieldIdx    int
	key         string
	keyLen      int64
	err         error
}
    26  
// structDecoder decodes a serialized array/object into a Go struct.
//
// Fields, as used in this file:
//   - fieldMap: key name -> field set; consulted directly by decodeKey and
//     used by tryOptimize to build the bitmap lookup.
//   - fieldUniqueNameNum: number of distinct lowercased key names, set by
//     tryOptimize.
//   - stringDecoder: reads the raw key bytes for decodeKey.
//   - structName, fieldName: stored as given to newStructDecoder, which also
//     passes them to newStringDecoder.
//   - isTriedOptimize: set by tryOptimize when it gives up on the bitmap
//     lookup, so later calls return early.
//   - keyBitmapUint8: read by decodeKeyByBitmapUint8; nothing in this file
//     populates it (the code that would is commented out in tryOptimize).
//   - keyBitmapUint16: one [256]uint16 row per key byte position; bit i of
//     row[j][c] is set when sortedFieldSets[i]'s lowercased key has byte c at
//     position j.
//   - sortedFieldSets: field sets ordered by lowercased key; indexed by the
//     bit position found in the bitmap.
//   - keyDecoder: strategy used by Decode to read one key; decodeKey by
//     default, replaced by decodeKeyByBitmapUint16 when tryOptimize succeeds.
type structDecoder struct {
	fieldMap           map[string]*structFieldSet
	fieldUniqueNameNum int
	stringDecoder      *stringDecoder
	structName         string
	fieldName          string
	isTriedOptimize    bool
	keyBitmapUint8     [][256]uint8
	keyBitmapUint16    [][256]uint16
	sortedFieldSets    []*structFieldSet
	keyDecoder         func(*structDecoder, []byte, int64) (int64, *structFieldSet, error)
	// keyStreamDecoder   func(*structDecoder, *Stream) (*structFieldSet, string, error)
}
    40  
    41  var (
    42  	largeToSmallTable [256]byte
    43  )
    44  
    45  func init() {
    46  	for i := 0; i < 256; i++ {
    47  		c := i
    48  		if 'A' <= c && c <= 'Z' {
    49  			c += 'a' - 'A'
    50  		}
    51  		largeToSmallTable[i] = byte(c)
    52  	}
    53  }
    54  
    55  func newStructDecoder(structName, fieldName string, fieldMap map[string]*structFieldSet) *structDecoder {
    56  	return &structDecoder{
    57  		fieldMap:      fieldMap,
    58  		stringDecoder: newStringDecoder(structName, fieldName),
    59  		structName:    structName,
    60  		fieldName:     fieldName,
    61  		keyDecoder:    decodeKey,
    62  		// keyFStreamDecoder: decodeKeyStream,
    63  	}
    64  }
    65  
// Limits checked by (*structDecoder).tryOptimize before it builds the bitmap
// key lookup: allowOptimizeMaxKeyLen is the longest lowercased key (in bytes)
// and allowOptimizeMaxFieldLen the largest number of distinct lowercased keys
// for which the lookup is built. The bitmap cells are uint16, one bit per key.
const (
	allowOptimizeMaxKeyLen   = 64
	allowOptimizeMaxFieldLen = 16
)
    70  
    71  func (d *structDecoder) tryOptimize() {
    72  	fieldUniqueNameMap := map[string]int{}
    73  	fieldIdx := -1
    74  	for k, v := range d.fieldMap {
    75  		lower := strings.ToLower(k)
    76  		idx, exists := fieldUniqueNameMap[lower]
    77  		if exists {
    78  			v.fieldIdx = idx
    79  		} else {
    80  			fieldIdx++
    81  			v.fieldIdx = fieldIdx
    82  		}
    83  		fieldUniqueNameMap[lower] = fieldIdx
    84  	}
    85  	d.fieldUniqueNameNum = len(fieldUniqueNameMap)
    86  
    87  	if d.isTriedOptimize {
    88  		return
    89  	}
    90  	fieldMap := map[string]*structFieldSet{}
    91  	conflicted := map[string]struct{}{}
    92  	for k, v := range d.fieldMap {
    93  		key := strings.ToLower(k)
    94  		if key != k {
    95  			// already exists same key (e.g. Hello and HELLO has same lower case key
    96  			if _, exists := conflicted[key]; exists {
    97  				d.isTriedOptimize = true
    98  				return
    99  			}
   100  			conflicted[key] = struct{}{}
   101  		}
   102  		if field, exists := fieldMap[key]; exists {
   103  			if field != v {
   104  				d.isTriedOptimize = true
   105  				return
   106  			}
   107  		}
   108  		fieldMap[key] = v
   109  	}
   110  
   111  	if len(fieldMap) > allowOptimizeMaxFieldLen {
   112  		d.isTriedOptimize = true
   113  		return
   114  	}
   115  
   116  	var maxKeyLen int
   117  	sortedKeys := []string{}
   118  	for key := range fieldMap {
   119  		keyLen := len(key)
   120  		if keyLen > allowOptimizeMaxKeyLen {
   121  			d.isTriedOptimize = true
   122  			return
   123  		}
   124  		if maxKeyLen < keyLen {
   125  			maxKeyLen = keyLen
   126  		}
   127  		sortedKeys = append(sortedKeys, key)
   128  	}
   129  	sort.Strings(sortedKeys)
   130  
   131  	// By allocating one extra capacity than `maxKeyLen`,
   132  	// it is possible to avoid the process of comparing the index of the key with the length of the bitmap each time.
   133  	bitmapLen := maxKeyLen + 1
   134  	// TODO:  this
   135  	// if len(sortedKeys) <= 8 {
   136  	// 	keyBitmap := make([][256]uint8, bitmapLen)
   137  	// 	for i, key := range sortedKeys {
   138  	// 		for j := 0; j < len(key); j++ {
   139  	// 			c := key[j]
   140  	// 			keyBitmap[j][c] |= (1 << uint(i))
   141  	// 		}
   142  	// 		d.sortedFieldSets = append(d.sortedFieldSets, fieldMap[key])
   143  	// 	}
   144  	// 	d.keyBitmapUint8 = keyBitmap
   145  	// d.keyDecoder = decodeKeyByBitmapUint8
   146  	// d.keyStreamDecoder = decodeKeyByBitmapUint8Stream
   147  	// } else {
   148  	keyBitmap := make([][256]uint16, bitmapLen)
   149  	for i, key := range sortedKeys {
   150  		for j := 0; j < len(key); j++ {
   151  			c := key[j]
   152  			keyBitmap[j][c] |= (1 << uint(i))
   153  		}
   154  		d.sortedFieldSets = append(d.sortedFieldSets, fieldMap[key])
   155  	}
   156  	d.keyBitmapUint16 = keyBitmap
   157  	d.keyDecoder = decodeKeyByBitmapUint16
   158  	// }
   159  }
   160  
   161  // decode from '\uXXXX'
   162  func decodeKeyCharByUnicodeRune(buf []byte, cursor int64) ([]byte, int64) {
   163  	const defaultOffset = 4
   164  	const surrogateOffset = 6
   165  
   166  	r := unicodeToRune(buf[cursor : cursor+defaultOffset])
   167  	if utf16.IsSurrogate(r) {
   168  		cursor += defaultOffset
   169  		if cursor+surrogateOffset >= int64(len(buf)) || buf[cursor] != '\\' || buf[cursor+1] != 'u' {
   170  			return []byte(string(unicode.ReplacementChar)), cursor + defaultOffset - 1
   171  		}
   172  		cursor += 2
   173  		r2 := unicodeToRune(buf[cursor : cursor+defaultOffset])
   174  		if r := utf16.DecodeRune(r, r2); r != unicode.ReplacementChar {
   175  			return []byte(string(r)), cursor + defaultOffset - 1
   176  		}
   177  	}
   178  	return []byte(string(r)), cursor + defaultOffset - 1
   179  }
   180  
   181  func decodeKeyCharByEscapedChar(buf []byte, cursor int64) ([]byte, int64) {
   182  	c := buf[cursor]
   183  	cursor++
   184  	switch c {
   185  	case '"':
   186  		return []byte{'"'}, cursor
   187  	case '\\':
   188  		return []byte{'\\'}, cursor
   189  	case '/':
   190  		return []byte{'/'}, cursor
   191  	case 'b':
   192  		return []byte{'\b'}, cursor
   193  	case 'f':
   194  		return []byte{'\f'}, cursor
   195  	case 'n':
   196  		return []byte{'\n'}, cursor
   197  	case 'r':
   198  		return []byte{'\r'}, cursor
   199  	case 't':
   200  		return []byte{'\t'}, cursor
   201  	case 'u':
   202  		return decodeKeyCharByUnicodeRune(buf, cursor)
   203  	}
   204  	return nil, cursor
   205  }
   206  
   207  // TODO: not finished
   208  func decodeKeyByBitmapUint8(d *structDecoder, buf []byte, cursor int64) (int64, *structFieldSet, error) {
   209  	var (
   210  		curBit uint8 = math.MaxUint8
   211  	)
   212  	b := (*sliceHeader)(unsafe.Pointer(&buf)).data
   213  	for {
   214  		switch char(b, cursor) {
   215  		case 'i':
   216  			// array with int key, should we skip or just omit?
   217  
   218  		// case '"':
   219  		case 's':
   220  			cursor++
   221  			c := char(b, cursor)
   222  			if c != ':' {
   223  				return 0, nil, errors.ErrSyntax(fmt.Sprintf("unexpected chat (%c) before str length", c), cursor)
   224  			}
   225  
   226  			cursor++
   227  			sLen, end, err := readLength(buf, cursor)
   228  			if err != nil {
   229  				return 0, nil, err
   230  			}
   231  			cursor = end
   232  
   233  			c = char(b, cursor)
   234  			if c != ':' {
   235  				return 0, nil, errors.ErrSyntax(fmt.Sprintf("unexpected chat (%c) before str length", c), cursor)
   236  			}
   237  
   238  			runtime.KeepAlive(sLen)
   239  			cursor++
   240  			c = char(b, cursor)
   241  			switch c {
   242  			case '"':
   243  				cursor++
   244  				return cursor, nil, nil
   245  			case nul:
   246  				return 0, nil, errors.ErrUnexpectedEnd("string", cursor)
   247  			}
   248  			keyIdx := 0
   249  			bitmap := d.keyBitmapUint8
   250  			start := cursor
   251  			for {
   252  				c := char(b, cursor)
   253  				switch c {
   254  				case '"':
   255  					fieldSetIndex := bits.TrailingZeros8(curBit)
   256  					field := d.sortedFieldSets[fieldSetIndex]
   257  					keyLen := cursor - start
   258  					cursor++
   259  					if keyLen < field.keyLen {
   260  						// early match
   261  						return cursor, nil, nil
   262  					}
   263  					return cursor, field, nil
   264  				case nul:
   265  					return 0, nil, errors.ErrUnexpectedEnd("string", cursor)
   266  				case '\\':
   267  					cursor++
   268  					chars, nextCursor := decodeKeyCharByEscapedChar(buf, cursor)
   269  					for _, c := range chars {
   270  						curBit &= bitmap[keyIdx][largeToSmallTable[c]]
   271  						if curBit == 0 {
   272  							return decodeKeyNotFound(b, cursor)
   273  						}
   274  						keyIdx++
   275  					}
   276  					cursor = nextCursor
   277  				default:
   278  					curBit &= bitmap[keyIdx][largeToSmallTable[c]]
   279  					if curBit == 0 {
   280  						return decodeKeyNotFound(b, cursor)
   281  					}
   282  					keyIdx++
   283  				}
   284  				cursor++
   285  			}
   286  		default:
   287  			return cursor, nil, errors.ErrInvalidBeginningOfValue(char(b, cursor), cursor)
   288  		}
   289  	}
   290  }
   291  
   292  func decodeKeyByBitmapUint16(d *structDecoder, buf []byte, cursor int64) (int64, *structFieldSet, error) {
   293  	var (
   294  		curBit uint16 = math.MaxUint16
   295  	)
   296  	b := (*sliceHeader)(unsafe.Pointer(&buf)).data
   297  
   298  	switch char(b, cursor) {
   299  	case 'i':
   300  	// TODO array with int key
   301  	// array with int key, should we skip or just omit?
   302  	case 's':
   303  		cursor++
   304  		sLen, end, err := readLength(buf, cursor)
   305  		if err != nil {
   306  			return 0, nil, err
   307  		}
   308  		cursor = end
   309  		cursor++ // '"'
   310  
   311  		keyIdx := 0
   312  		bitmap := d.keyBitmapUint16
   313  		start := cursor
   314  
   315  		if char(b, start+sLen) != '"' {
   316  			return 0, nil, errors.ErrExpected("string should be quoted", cursor)
   317  		}
   318  
   319  		if char(b, start+sLen+1) != ';' {
   320  			return 0, nil, errors.ErrExpected("string end with semi", cursor)
   321  		}
   322  
   323  		for i := start; i < start+sLen; i++ {
   324  			cursor = i
   325  			c := char(b, cursor)
   326  			curBit &= bitmap[keyIdx][largeToSmallTable[c]]
   327  			if curBit == 0 {
   328  				return decodeKeyNotFound(b, cursor)
   329  			}
   330  			keyIdx++
   331  		}
   332  
   333  		fieldSetIndex := bits.TrailingZeros16(curBit)
   334  		field := d.sortedFieldSets[fieldSetIndex]
   335  		cursor++
   336  		if sLen < field.keyLen {
   337  			// early match
   338  			return cursor, nil, nil
   339  		}
   340  		cursor++ // '"'
   341  		cursor++ // ';'
   342  		return cursor, field, nil
   343  	}
   344  
   345  	return cursor, nil, errors.ErrInvalidBeginningOfValue(char(b, cursor), cursor)
   346  }
   347  
   348  func decodeKeyNotFound(b unsafe.Pointer, cursor int64) (int64, *structFieldSet, error) {
   349  	for {
   350  		cursor++
   351  		switch char(b, cursor) {
   352  		case '"':
   353  			cursor += 2
   354  			return cursor, nil, nil
   355  		case nul:
   356  			return 0, nil, errors.ErrUnexpectedEnd("string", cursor)
   357  		}
   358  	}
   359  }
   360  
   361  func decodeKey(d *structDecoder, buf []byte, cursor int64) (int64, *structFieldSet, error) {
   362  	key, c, err := d.stringDecoder.decodeByte(buf, cursor)
   363  	if err != nil {
   364  		return 0, nil, err
   365  	}
   366  	cursor = c
   367  	k := *(*string)(unsafe.Pointer(&key))
   368  	field, exists := d.fieldMap[k]
   369  	if !exists {
   370  		return cursor, nil, nil
   371  	}
   372  	return cursor, field, nil
   373  }
   374  
   375  func (d *structDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.Pointer) (int64, error) {
   376  	buf := ctx.Buf
   377  	depth++
   378  	if depth > maxDecodeNestingDepth {
   379  		return 0, errors.ErrExceededMaxDepth(buf[cursor], cursor)
   380  	}
   381  	buflen := int64(len(buf))
   382  	b := (*sliceHeader)(unsafe.Pointer(&buf)).data
   383  	switch char(b, cursor) {
   384  	case 'N':
   385  		if err := validateNull(buf, cursor); err != nil {
   386  			return 0, err
   387  		}
   388  		cursor += 2
   389  		return cursor, nil
   390  	case 'O':
   391  		// O:8:"stdClass":1:{s:1:"a";s:1:"q";}
   392  		end, err := skipClassName(buf, cursor)
   393  		if err != nil {
   394  			return cursor, err
   395  		}
   396  		cursor = end
   397  		fallthrough
   398  	case 'a':
   399  		cursor++
   400  		if buf[cursor] != ':' {
   401  			return 0, errors.ErrInvalidBeginningOfValue(char(b, cursor), cursor)
   402  		}
   403  	default:
   404  		return 0, errors.ErrInvalidBeginningOfValue(char(b, cursor), cursor)
   405  	}
   406  
   407  	// skip  :${length}:
   408  	end, err := skipLengthWithBothColon(buf, cursor)
   409  	if err != nil {
   410  		return cursor, err
   411  	}
   412  	cursor = end
   413  	if buf[cursor] != '{' {
   414  		return 0, errors.ErrInvalidBeginningOfArray(char(b, cursor), cursor)
   415  	}
   416  
   417  	cursor++
   418  	if buf[cursor] == '}' {
   419  		cursor++
   420  		return cursor, nil
   421  	}
   422  
   423  	for {
   424  		c, field, err := d.keyDecoder(d, buf, cursor)
   425  		if err != nil {
   426  			return 0, err
   427  		}
   428  
   429  		cursor = c
   430  
   431  		// cursor++
   432  		if cursor >= buflen {
   433  			return 0, errors.ErrExpected("object value after colon", cursor)
   434  		}
   435  		if field != nil {
   436  			if field.err != nil {
   437  				return 0, field.err
   438  			}
   439  			c, err := field.dec.Decode(ctx, cursor, depth, unsafe.Pointer(uintptr(p)+field.offset))
   440  			if err != nil {
   441  				return 0, err
   442  			}
   443  			cursor = c
   444  		} else {
   445  			c, err := skipValue(buf, cursor, depth)
   446  			if err != nil {
   447  				return 0, err
   448  			}
   449  			cursor = c
   450  		}
   451  
   452  		if char(b, cursor) == '}' {
   453  			cursor++
   454  			return cursor, nil
   455  		}
   456  	}
   457  }