github.com/goshafaq/sonic@v0.0.0-20231026082336-871835fb94c6/ast/decode.go (about)

     1  /*
     2   * Copyright 2022 ByteDance Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package ast
    18  
    19  import (
    20  	"encoding/base64"
    21  	"runtime"
    22  	"strconv"
    23  	"unsafe"
    24  
    25  	"github.com/goshafaq/sonic/internal/native/types"
    26  	"github.com/goshafaq/sonic/internal/rt"
    27  )
    28  
    29  const _blankCharsMask = (1 << ' ') | (1 << '\t') | (1 << '\r') | (1 << '\n')
    30  
    31  const (
    32  	bytesNull   = "null"
    33  	bytesTrue   = "true"
    34  	bytesFalse  = "false"
    35  	bytesObject = "{}"
    36  	bytesArray  = "[]"
    37  )
    38  
    39  func isSpace(c byte) bool {
    40  	return (int(1<<c) & _blankCharsMask) != 0
    41  }
    42  
    43  //go:nocheckptr
    44  func skipBlank(src string, pos int) int {
    45  	se := uintptr(rt.IndexChar(src, len(src)))
    46  	sp := uintptr(rt.IndexChar(src, pos))
    47  
    48  	for sp < se {
    49  		if !isSpace(*(*byte)(unsafe.Pointer(sp))) {
    50  			break
    51  		}
    52  		sp += 1
    53  	}
    54  	if sp >= se {
    55  		return -int(types.ERR_EOF)
    56  	}
    57  	runtime.KeepAlive(src)
    58  	return int(sp - uintptr(rt.IndexChar(src, 0)))
    59  }
    60  
    61  func decodeNull(src string, pos int) (ret int) {
    62  	ret = pos + 4
    63  	if ret > len(src) {
    64  		return -int(types.ERR_EOF)
    65  	}
    66  	if src[pos:ret] == bytesNull {
    67  		return ret
    68  	} else {
    69  		return -int(types.ERR_INVALID_CHAR)
    70  	}
    71  }
    72  
    73  func decodeTrue(src string, pos int) (ret int) {
    74  	ret = pos + 4
    75  	if ret > len(src) {
    76  		return -int(types.ERR_EOF)
    77  	}
    78  	if src[pos:ret] == bytesTrue {
    79  		return ret
    80  	} else {
    81  		return -int(types.ERR_INVALID_CHAR)
    82  	}
    83  
    84  }
    85  
    86  func decodeFalse(src string, pos int) (ret int) {
    87  	ret = pos + 5
    88  	if ret > len(src) {
    89  		return -int(types.ERR_EOF)
    90  	}
    91  	if src[pos:ret] == bytesFalse {
    92  		return ret
    93  	}
    94  	return -int(types.ERR_INVALID_CHAR)
    95  }
    96  
    97  //go:nocheckptr
    98  func decodeString(src string, pos int) (ret int, v string) {
    99  	ret, ep := skipString(src, pos)
   100  	if ep == -1 {
   101  		(*rt.GoString)(unsafe.Pointer(&v)).Ptr = rt.IndexChar(src, pos+1)
   102  		(*rt.GoString)(unsafe.Pointer(&v)).Len = ret - pos - 2
   103  		return ret, v
   104  	}
   105  
   106  	vv, ok := unquoteBytes(rt.Str2Mem(src[pos:ret]))
   107  	if !ok {
   108  		return -int(types.ERR_INVALID_CHAR), ""
   109  	}
   110  
   111  	runtime.KeepAlive(src)
   112  	return ret, rt.Mem2Str(vv)
   113  }
   114  
   115  func decodeBinary(src string, pos int) (ret int, v []byte) {
   116  	var vv string
   117  	ret, vv = decodeString(src, pos)
   118  	if ret < 0 {
   119  		return ret, nil
   120  	}
   121  	var err error
   122  	v, err = base64.StdEncoding.DecodeString(vv)
   123  	if err != nil {
   124  		return -int(types.ERR_INVALID_CHAR), nil
   125  	}
   126  	return ret, v
   127  }
   128  
   129  func isDigit(c byte) bool {
   130  	return c >= '0' && c <= '9'
   131  }
   132  
   133  //go:nocheckptr
   134  func decodeInt64(src string, pos int) (ret int, v int64, err error) {
   135  	sp := uintptr(rt.IndexChar(src, pos))
   136  	ss := uintptr(sp)
   137  	se := uintptr(rt.IndexChar(src, len(src)))
   138  	if uintptr(sp) >= se {
   139  		return -int(types.ERR_EOF), 0, nil
   140  	}
   141  
   142  	if c := *(*byte)(unsafe.Pointer(sp)); c == '-' {
   143  		sp += 1
   144  	}
   145  	if sp == se {
   146  		return -int(types.ERR_EOF), 0, nil
   147  	}
   148  
   149  	for ; sp < se; sp += uintptr(1) {
   150  		if !isDigit(*(*byte)(unsafe.Pointer(sp))) {
   151  			break
   152  		}
   153  	}
   154  
   155  	if sp < se {
   156  		if c := *(*byte)(unsafe.Pointer(sp)); c == '.' || c == 'e' || c == 'E' {
   157  			return -int(types.ERR_INVALID_NUMBER_FMT), 0, nil
   158  		}
   159  	}
   160  
   161  	var vv string
   162  	ret = int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr))
   163  	(*rt.GoString)(unsafe.Pointer(&vv)).Ptr = unsafe.Pointer(ss)
   164  	(*rt.GoString)(unsafe.Pointer(&vv)).Len = ret - pos
   165  
   166  	v, err = strconv.ParseInt(vv, 10, 64)
   167  	if err != nil {
   168  		//NOTICE: allow overflow here
   169  		if err.(*strconv.NumError).Err == strconv.ErrRange {
   170  			return ret, 0, err
   171  		}
   172  		return -int(types.ERR_INVALID_CHAR), 0, err
   173  	}
   174  
   175  	runtime.KeepAlive(src)
   176  	return ret, v, nil
   177  }
   178  
   179  func isNumberChars(c byte) bool {
   180  	return (c >= '0' && c <= '9') || c == '+' || c == '-' || c == 'e' || c == 'E' || c == '.'
   181  }
   182  
   183  //go:nocheckptr
   184  func decodeFloat64(src string, pos int) (ret int, v float64, err error) {
   185  	sp := uintptr(rt.IndexChar(src, pos))
   186  	ss := uintptr(sp)
   187  	se := uintptr(rt.IndexChar(src, len(src)))
   188  	if uintptr(sp) >= se {
   189  		return -int(types.ERR_EOF), 0, nil
   190  	}
   191  
   192  	if c := *(*byte)(unsafe.Pointer(sp)); c == '-' {
   193  		sp += 1
   194  	}
   195  	if sp == se {
   196  		return -int(types.ERR_EOF), 0, nil
   197  	}
   198  
   199  	for ; sp < se; sp += uintptr(1) {
   200  		if !isNumberChars(*(*byte)(unsafe.Pointer(sp))) {
   201  			break
   202  		}
   203  	}
   204  
   205  	var vv string
   206  	ret = int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr))
   207  	(*rt.GoString)(unsafe.Pointer(&vv)).Ptr = unsafe.Pointer(ss)
   208  	(*rt.GoString)(unsafe.Pointer(&vv)).Len = ret - pos
   209  
   210  	v, err = strconv.ParseFloat(vv, 64)
   211  	if err != nil {
   212  		//NOTICE: allow overflow here
   213  		if err.(*strconv.NumError).Err == strconv.ErrRange {
   214  			return ret, 0, err
   215  		}
   216  		return -int(types.ERR_INVALID_CHAR), 0, err
   217  	}
   218  
   219  	runtime.KeepAlive(src)
   220  	return ret, v, nil
   221  }
   222  
   223  func decodeValue(src string, pos int, skipnum bool) (ret int, v types.JsonState) {
   224  	pos = skipBlank(src, pos)
   225  	if pos < 0 {
   226  		return pos, types.JsonState{Vt: types.ValueType(pos)}
   227  	}
   228  	switch c := src[pos]; c {
   229  	case 'n':
   230  		ret = decodeNull(src, pos)
   231  		if ret < 0 {
   232  			return ret, types.JsonState{Vt: types.ValueType(ret)}
   233  		}
   234  		return ret, types.JsonState{Vt: types.V_NULL}
   235  	case '"':
   236  		var ep int
   237  		ret, ep = skipString(src, pos)
   238  		if ret < 0 {
   239  			return ret, types.JsonState{Vt: types.ValueType(ret)}
   240  		}
   241  		return ret, types.JsonState{Vt: types.V_STRING, Iv: int64(pos + 1), Ep: ep}
   242  	case '{':
   243  		return pos + 1, types.JsonState{Vt: types.V_OBJECT}
   244  	case '[':
   245  		return pos + 1, types.JsonState{Vt: types.V_ARRAY}
   246  	case 't':
   247  		ret = decodeTrue(src, pos)
   248  		if ret < 0 {
   249  			return ret, types.JsonState{Vt: types.ValueType(ret)}
   250  		}
   251  		return ret, types.JsonState{Vt: types.V_TRUE}
   252  	case 'f':
   253  		ret = decodeFalse(src, pos)
   254  		if ret < 0 {
   255  			return ret, types.JsonState{Vt: types.ValueType(ret)}
   256  		}
   257  		return ret, types.JsonState{Vt: types.V_FALSE}
   258  	case '-', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
   259  		if skipnum {
   260  			ret = skipNumber(src, pos)
   261  			if ret >= 0 {
   262  				return ret, types.JsonState{Vt: types.V_DOUBLE, Iv: 0, Ep: pos}
   263  			} else {
   264  				return ret, types.JsonState{Vt: types.ValueType(ret)}
   265  			}
   266  		} else {
   267  			var iv int64
   268  			ret, iv, _ = decodeInt64(src, pos)
   269  			if ret >= 0 {
   270  				return ret, types.JsonState{Vt: types.V_INTEGER, Iv: iv, Ep: pos}
   271  			} else if ret != -int(types.ERR_INVALID_NUMBER_FMT) {
   272  				return ret, types.JsonState{Vt: types.ValueType(ret)}
   273  			}
   274  			var fv float64
   275  			ret, fv, _ = decodeFloat64(src, pos)
   276  			if ret >= 0 {
   277  				return ret, types.JsonState{Vt: types.V_DOUBLE, Dv: fv, Ep: pos}
   278  			} else {
   279  				return ret, types.JsonState{Vt: types.ValueType(ret)}
   280  			}
   281  		}
   282  
   283  	default:
   284  		return -int(types.ERR_INVALID_CHAR), types.JsonState{Vt: -types.ValueType(types.ERR_INVALID_CHAR)}
   285  	}
   286  }
   287  
   288  //go:nocheckptr
   289  func skipNumber(src string, pos int) (ret int) {
   290  	sp := uintptr(rt.IndexChar(src, pos))
   291  	se := uintptr(rt.IndexChar(src, len(src)))
   292  	if uintptr(sp) >= se {
   293  		return -int(types.ERR_EOF)
   294  	}
   295  
   296  	if c := *(*byte)(unsafe.Pointer(sp)); c == '-' {
   297  		sp += 1
   298  	}
   299  	ss := sp
   300  
   301  	var pointer bool
   302  	var exponent bool
   303  	var lastIsDigit bool
   304  	var nextNeedDigit = true
   305  
   306  	for ; sp < se; sp += uintptr(1) {
   307  		c := *(*byte)(unsafe.Pointer(sp))
   308  		if isDigit(c) {
   309  			lastIsDigit = true
   310  			nextNeedDigit = false
   311  			continue
   312  		} else if nextNeedDigit {
   313  			return -int(types.ERR_INVALID_CHAR)
   314  		} else if c == '.' {
   315  			if !lastIsDigit || pointer || exponent || sp == ss {
   316  				return -int(types.ERR_INVALID_CHAR)
   317  			}
   318  			pointer = true
   319  			lastIsDigit = false
   320  			nextNeedDigit = true
   321  			continue
   322  		} else if c == 'e' || c == 'E' {
   323  			if !lastIsDigit || exponent {
   324  				return -int(types.ERR_INVALID_CHAR)
   325  			}
   326  			if sp == se-1 {
   327  				return -int(types.ERR_EOF)
   328  			}
   329  			exponent = true
   330  			lastIsDigit = false
   331  			nextNeedDigit = false
   332  			continue
   333  		} else if c == '-' || c == '+' {
   334  			if prev := *(*byte)(unsafe.Pointer(sp - 1)); prev != 'e' && prev != 'E' {
   335  				return -int(types.ERR_INVALID_CHAR)
   336  			}
   337  			lastIsDigit = false
   338  			nextNeedDigit = true
   339  			continue
   340  		} else {
   341  			break
   342  		}
   343  	}
   344  
   345  	if nextNeedDigit {
   346  		return -int(types.ERR_EOF)
   347  	}
   348  
   349  	runtime.KeepAlive(src)
   350  	return int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr))
   351  }
   352  
   353  //go:nocheckptr
   354  func skipString(src string, pos int) (ret int, ep int) {
   355  	if pos+1 >= len(src) {
   356  		return -int(types.ERR_EOF), -1
   357  	}
   358  
   359  	sp := uintptr(rt.IndexChar(src, pos))
   360  	se := uintptr(rt.IndexChar(src, len(src)))
   361  
   362  	// not start with quote
   363  	if *(*byte)(unsafe.Pointer(sp)) != '"' {
   364  		return -int(types.ERR_INVALID_CHAR), -1
   365  	}
   366  	sp += 1
   367  
   368  	ep = -1
   369  	for sp < se {
   370  		c := *(*byte)(unsafe.Pointer(sp))
   371  		if c == '\\' {
   372  			if ep == -1 {
   373  				ep = int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr))
   374  			}
   375  			sp += 2
   376  			continue
   377  		}
   378  		sp += 1
   379  		if c == '"' {
   380  			return int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr)), ep
   381  		}
   382  	}
   383  
   384  	runtime.KeepAlive(src)
   385  	// not found the closed quote until EOF
   386  	return -int(types.ERR_EOF), -1
   387  }
   388  
   389  //go:nocheckptr
   390  func skipPair(src string, pos int, lchar byte, rchar byte) (ret int) {
   391  	if pos+1 >= len(src) {
   392  		return -int(types.ERR_EOF)
   393  	}
   394  
   395  	sp := uintptr(rt.IndexChar(src, pos))
   396  	se := uintptr(rt.IndexChar(src, len(src)))
   397  
   398  	if *(*byte)(unsafe.Pointer(sp)) != lchar {
   399  		return -int(types.ERR_INVALID_CHAR)
   400  	}
   401  
   402  	sp += 1
   403  	nbrace := 1
   404  	inquote := false
   405  
   406  	for sp < se {
   407  		c := *(*byte)(unsafe.Pointer(sp))
   408  		if c == '\\' {
   409  			sp += 2
   410  			continue
   411  		} else if c == '"' {
   412  			inquote = !inquote
   413  		} else if c == lchar {
   414  			if !inquote {
   415  				nbrace += 1
   416  			}
   417  		} else if c == rchar {
   418  			if !inquote {
   419  				nbrace -= 1
   420  				if nbrace == 0 {
   421  					sp += 1
   422  					break
   423  				}
   424  			}
   425  		}
   426  		sp += 1
   427  	}
   428  
   429  	if nbrace != 0 {
   430  		return -int(types.ERR_INVALID_CHAR)
   431  	}
   432  
   433  	runtime.KeepAlive(src)
   434  	return int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr))
   435  }
   436  
   437  func skipValueFast(src string, pos int) (ret int, start int) {
   438  	pos = skipBlank(src, pos)
   439  	if pos < 0 {
   440  		return pos, -1
   441  	}
   442  	switch c := src[pos]; c {
   443  	case 'n':
   444  		ret = decodeNull(src, pos)
   445  	case '"':
   446  		ret, _ = skipString(src, pos)
   447  	case '{':
   448  		ret = skipPair(src, pos, '{', '}')
   449  	case '[':
   450  		ret = skipPair(src, pos, '[', ']')
   451  	case 't':
   452  		ret = decodeTrue(src, pos)
   453  	case 'f':
   454  		ret = decodeFalse(src, pos)
   455  	case '-', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
   456  		ret = skipNumber(src, pos)
   457  	default:
   458  		ret = -int(types.ERR_INVALID_CHAR)
   459  	}
   460  	return ret, pos
   461  }
   462  
   463  func skipValue(src string, pos int) (ret int, start int) {
   464  	pos = skipBlank(src, pos)
   465  	if pos < 0 {
   466  		return pos, -1
   467  	}
   468  	switch c := src[pos]; c {
   469  	case 'n':
   470  		ret = decodeNull(src, pos)
   471  	case '"':
   472  		ret, _ = skipString(src, pos)
   473  	case '{':
   474  		ret, _ = skipObject(src, pos)
   475  	case '[':
   476  		ret, _ = skipArray(src, pos)
   477  	case 't':
   478  		ret = decodeTrue(src, pos)
   479  	case 'f':
   480  		ret = decodeFalse(src, pos)
   481  	case '-', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
   482  		ret = skipNumber(src, pos)
   483  	default:
   484  		ret = -int(types.ERR_INVALID_CHAR)
   485  	}
   486  	return ret, pos
   487  }
   488  
   489  func skipObject(src string, pos int) (ret int, start int) {
   490  	start = skipBlank(src, pos)
   491  	if start < 0 {
   492  		return start, -1
   493  	}
   494  
   495  	if src[start] != '{' {
   496  		return -int(types.ERR_INVALID_CHAR), -1
   497  	}
   498  
   499  	pos = start + 1
   500  	pos = skipBlank(src, pos)
   501  	if pos < 0 {
   502  		return pos, -1
   503  	}
   504  	if src[pos] == '}' {
   505  		return pos + 1, start
   506  	}
   507  
   508  	for {
   509  		pos, _ = skipString(src, pos)
   510  		if pos < 0 {
   511  			return pos, -1
   512  		}
   513  
   514  		pos = skipBlank(src, pos)
   515  		if pos < 0 {
   516  			return pos, -1
   517  		}
   518  		if src[pos] != ':' {
   519  			return -int(types.ERR_INVALID_CHAR), -1
   520  		}
   521  
   522  		pos++
   523  		pos, _ = skipValue(src, pos)
   524  		if pos < 0 {
   525  			return pos, -1
   526  		}
   527  
   528  		pos = skipBlank(src, pos)
   529  		if pos < 0 {
   530  			return pos, -1
   531  		}
   532  		if src[pos] == '}' {
   533  			return pos + 1, start
   534  		}
   535  		if src[pos] != ',' {
   536  			return -int(types.ERR_INVALID_CHAR), -1
   537  		}
   538  
   539  		pos++
   540  		pos = skipBlank(src, pos)
   541  		if pos < 0 {
   542  			return pos, -1
   543  		}
   544  
   545  	}
   546  }
   547  
   548  func skipArray(src string, pos int) (ret int, start int) {
   549  	start = skipBlank(src, pos)
   550  	if start < 0 {
   551  		return start, -1
   552  	}
   553  
   554  	if src[start] != '[' {
   555  		return -int(types.ERR_INVALID_CHAR), -1
   556  	}
   557  
   558  	pos = start + 1
   559  	pos = skipBlank(src, pos)
   560  	if pos < 0 {
   561  		return pos, -1
   562  	}
   563  	if src[pos] == ']' {
   564  		return pos + 1, start
   565  	}
   566  
   567  	for {
   568  		pos, _ = skipValue(src, pos)
   569  		if pos < 0 {
   570  			return pos, -1
   571  		}
   572  
   573  		pos = skipBlank(src, pos)
   574  		if pos < 0 {
   575  			return pos, -1
   576  		}
   577  		if src[pos] == ']' {
   578  			return pos + 1, start
   579  		}
   580  		if src[pos] != ',' {
   581  			return -int(types.ERR_INVALID_CHAR), -1
   582  		}
   583  		pos++
   584  	}
   585  }