github.com/bytedance/sonic@v1.11.7-0.20240517092252-d2edb31b167b/ast/decode.go (about)

     1  /*
     2   * Copyright 2022 ByteDance Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package ast
    18  
    19  import (
    20      `encoding/base64`
    21      `runtime`
    22      `strconv`
    23      `unsafe`
    24  
    25      `github.com/bytedance/sonic/internal/native/types`
    26      `github.com/bytedance/sonic/internal/rt`
    27  )
    28  
    29  const _blankCharsMask = (1 << ' ') | (1 << '\t') | (1 << '\r') | (1 << '\n')
    30  
    31  const (
    32      bytesNull   = "null"
    33      bytesTrue   = "true"
    34      bytesFalse  = "false"
    35      bytesObject = "{}"
    36      bytesArray  = "[]"
    37  )
    38  
    39  func isSpace(c byte) bool {
    40      return (int(1<<c) & _blankCharsMask) != 0
    41  }
    42  
    43  //go:nocheckptr
    44  func skipBlank(src string, pos int) int {
    45      se := uintptr(rt.IndexChar(src, len(src)))
    46      sp := uintptr(rt.IndexChar(src, pos))
    47  
    48      for sp < se {
    49          if !isSpace(*(*byte)(unsafe.Pointer(sp))) {
    50              break
    51          }
    52          sp += 1
    53      }
    54      if sp >= se {
    55          return -int(types.ERR_EOF)
    56      }
    57      runtime.KeepAlive(src)
    58      return int(sp - uintptr(rt.IndexChar(src, 0)))
    59  }
    60  
    61  func decodeNull(src string, pos int) (ret int) {
    62      ret = pos + 4
    63      if ret > len(src) {
    64          return -int(types.ERR_EOF)
    65      }
    66      if src[pos:ret] == bytesNull {
    67          return ret
    68      } else {
    69          return -int(types.ERR_INVALID_CHAR)
    70      }
    71  }
    72  
    73  func decodeTrue(src string, pos int) (ret int) {
    74      ret = pos + 4
    75      if ret > len(src) {
    76          return -int(types.ERR_EOF)
    77      }
    78      if src[pos:ret] == bytesTrue {
    79          return ret
    80      } else {
    81          return -int(types.ERR_INVALID_CHAR)
    82      }
    83  
    84  }
    85  
    86  func decodeFalse(src string, pos int) (ret int) {
    87      ret = pos + 5
    88      if ret > len(src) {
    89          return -int(types.ERR_EOF)
    90      }
    91      if src[pos:ret] == bytesFalse {
    92          return ret
    93      }
    94      return -int(types.ERR_INVALID_CHAR)
    95  }
    96  
    97  //go:nocheckptr
    98  func decodeString(src string, pos int) (ret int, v string) {
    99      ret, ep := skipString(src, pos)
   100      if ep == -1 {
   101          (*rt.GoString)(unsafe.Pointer(&v)).Ptr = rt.IndexChar(src, pos+1)
   102          (*rt.GoString)(unsafe.Pointer(&v)).Len = ret - pos - 2
   103          return ret, v
   104      }
   105  
   106      vv, ok := unquoteBytes(rt.Str2Mem(src[pos:ret]))
   107      if !ok {
   108          return -int(types.ERR_INVALID_CHAR), ""
   109      }
   110  
   111      runtime.KeepAlive(src)
   112      return ret, rt.Mem2Str(vv)
   113  }
   114  
   115  func decodeBinary(src string, pos int) (ret int, v []byte) {
   116      var vv string
   117      ret, vv = decodeString(src, pos)
   118      if ret < 0 {
   119          return ret, nil
   120      }
   121      var err error
   122      v, err = base64.StdEncoding.DecodeString(vv)
   123      if err != nil {
   124          return -int(types.ERR_INVALID_CHAR), nil
   125      }
   126      return ret, v
   127  }
   128  
   129  func isDigit(c byte) bool {
   130      return c >= '0' && c <= '9'
   131  }
   132  
   133  //go:nocheckptr
   134  func decodeInt64(src string, pos int) (ret int, v int64, err error) {
   135      sp := uintptr(rt.IndexChar(src, pos))
   136      ss := uintptr(sp)
   137      se := uintptr(rt.IndexChar(src, len(src)))
   138      if uintptr(sp) >= se {
   139          return -int(types.ERR_EOF), 0, nil
   140      }
   141  
   142      if c := *(*byte)(unsafe.Pointer(sp)); c == '-' {
   143          sp += 1
   144      }
   145      if sp == se {
   146          return -int(types.ERR_EOF), 0, nil
   147      }
   148  
   149      for ; sp < se; sp += uintptr(1) {
   150          if !isDigit(*(*byte)(unsafe.Pointer(sp))) {
   151              break
   152          }
   153      }
   154  
   155      if sp < se {
   156          if c := *(*byte)(unsafe.Pointer(sp)); c == '.' || c == 'e' || c == 'E' {
   157              return -int(types.ERR_INVALID_NUMBER_FMT), 0, nil
   158          }
   159      }
   160  
   161      var vv string
   162      ret = int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr))
   163      (*rt.GoString)(unsafe.Pointer(&vv)).Ptr = unsafe.Pointer(ss)
   164      (*rt.GoString)(unsafe.Pointer(&vv)).Len = ret - pos
   165  
   166      v, err = strconv.ParseInt(vv, 10, 64)
   167      if err != nil {
   168          //NOTICE: allow overflow here
   169          if err.(*strconv.NumError).Err == strconv.ErrRange {
   170              return ret, 0, err
   171          }
   172          return -int(types.ERR_INVALID_CHAR), 0, err
   173      }
   174  
   175      runtime.KeepAlive(src)
   176      return ret, v, nil
   177  }
   178  
   179  func isNumberChars(c byte) bool {
   180      return (c >= '0' && c <= '9') || c == '+' || c == '-' || c == 'e' || c == 'E' || c == '.'
   181  }
   182  
   183  //go:nocheckptr
   184  func decodeFloat64(src string, pos int) (ret int, v float64, err error) {
   185      sp := uintptr(rt.IndexChar(src, pos))
   186      ss := uintptr(sp)
   187      se := uintptr(rt.IndexChar(src, len(src)))
   188      if uintptr(sp) >= se {
   189          return -int(types.ERR_EOF), 0, nil
   190      }
   191  
   192      if c := *(*byte)(unsafe.Pointer(sp)); c == '-' {
   193          sp += 1
   194      }
   195      if sp == se {
   196          return -int(types.ERR_EOF), 0, nil
   197      }
   198  
   199      for ; sp < se; sp += uintptr(1) {
   200          if !isNumberChars(*(*byte)(unsafe.Pointer(sp))) {
   201              break
   202          }
   203      }
   204  
   205      var vv string
   206      ret = int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr))
   207      (*rt.GoString)(unsafe.Pointer(&vv)).Ptr = unsafe.Pointer(ss)
   208      (*rt.GoString)(unsafe.Pointer(&vv)).Len = ret - pos
   209  
   210      v, err = strconv.ParseFloat(vv, 64)
   211      if err != nil {
   212          //NOTICE: allow overflow here
   213          if err.(*strconv.NumError).Err == strconv.ErrRange {
   214              return ret, 0, err
   215          }
   216          return -int(types.ERR_INVALID_CHAR), 0, err
   217      }
   218  
   219      runtime.KeepAlive(src)
   220      return ret, v, nil
   221  }
   222  
   223  func decodeValue(src string, pos int, skipnum bool) (ret int, v types.JsonState) {
   224      pos = skipBlank(src, pos)
   225      if pos < 0 {
   226          return pos, types.JsonState{Vt: types.ValueType(pos)}
   227      }
   228      switch c := src[pos]; c {
   229      case 'n':
   230          ret = decodeNull(src, pos)
   231          if ret < 0 {
   232              return ret, types.JsonState{Vt: types.ValueType(ret)}
   233          }
   234          return ret, types.JsonState{Vt: types.V_NULL}
   235      case '"':
   236          var ep int
   237          ret, ep = skipString(src, pos)
   238          if ret < 0 {
   239              return ret, types.JsonState{Vt: types.ValueType(ret)}
   240          }
   241          return ret, types.JsonState{Vt: types.V_STRING, Iv: int64(pos + 1), Ep: ep}
   242      case '{':
   243          return pos + 1, types.JsonState{Vt: types.V_OBJECT}
   244      case '[':
   245          return pos + 1, types.JsonState{Vt: types.V_ARRAY}
   246      case 't':
   247          ret = decodeTrue(src, pos)
   248          if ret < 0 {
   249              return ret, types.JsonState{Vt: types.ValueType(ret)}
   250          }
   251          return ret, types.JsonState{Vt: types.V_TRUE}
   252      case 'f':
   253          ret = decodeFalse(src, pos)
   254          if ret < 0 {
   255              return ret, types.JsonState{Vt: types.ValueType(ret)}
   256          }
   257          return ret, types.JsonState{Vt: types.V_FALSE}
   258      case '-', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
   259          if skipnum {
   260              ret = skipNumber(src, pos)
   261              if ret >= 0 {
   262                  return ret, types.JsonState{Vt: types.V_DOUBLE, Iv: 0, Ep: pos}
   263              } else {
   264                  return ret, types.JsonState{Vt: types.ValueType(ret)}
   265              }
   266          } else {
   267              var iv int64
   268              ret, iv, _ = decodeInt64(src, pos)
   269              if ret >= 0 {
   270                  return ret, types.JsonState{Vt: types.V_INTEGER, Iv: iv, Ep: pos}
   271              } else if ret != -int(types.ERR_INVALID_NUMBER_FMT) {
   272                  return ret, types.JsonState{Vt: types.ValueType(ret)}
   273              }
   274              var fv float64
   275              ret, fv, _ = decodeFloat64(src, pos)
   276              if ret >= 0 {
   277                  return ret, types.JsonState{Vt: types.V_DOUBLE, Dv: fv, Ep: pos}
   278              } else {
   279                  return ret, types.JsonState{Vt: types.ValueType(ret)}
   280              }
   281          }
   282          
   283      default:
   284          return -int(types.ERR_INVALID_CHAR), types.JsonState{Vt:-types.ValueType(types.ERR_INVALID_CHAR)}
   285      }
   286  }
   287  
   288  //go:nocheckptr
   289  func skipNumber(src string, pos int) (ret int) {
   290      sp := uintptr(rt.IndexChar(src, pos))
   291      se := uintptr(rt.IndexChar(src, len(src)))
   292      if uintptr(sp) >= se {
   293          return -int(types.ERR_EOF)
   294      }
   295  
   296      if c := *(*byte)(unsafe.Pointer(sp)); c == '-' {
   297          sp += 1
   298      }
   299      ss := sp
   300  
   301      var pointer bool
   302      var exponent bool
   303      var lastIsDigit bool
   304      var nextNeedDigit = true
   305  
   306      for ; sp < se; sp += uintptr(1) {
   307          c := *(*byte)(unsafe.Pointer(sp))
   308          if isDigit(c) {
   309              lastIsDigit = true
   310              nextNeedDigit = false
   311              continue
   312          } else if nextNeedDigit {
   313              return -int(types.ERR_INVALID_CHAR)
   314          } else if c == '.' {
   315              if !lastIsDigit || pointer || exponent || sp == ss {
   316                  return -int(types.ERR_INVALID_CHAR)
   317              }
   318              pointer = true
   319              lastIsDigit = false
   320              nextNeedDigit = true
   321              continue
   322          } else if c == 'e' || c == 'E' {
   323              if !lastIsDigit || exponent {
   324                  return -int(types.ERR_INVALID_CHAR)
   325              }
   326              if sp == se-1 {
   327                  return -int(types.ERR_EOF)
   328              }
   329              exponent = true
   330              lastIsDigit = false
   331              nextNeedDigit = false
   332              continue
   333          } else if c == '-' || c == '+' {
   334              if prev := *(*byte)(unsafe.Pointer(sp - 1)); prev != 'e' && prev != 'E' {
   335                  return -int(types.ERR_INVALID_CHAR)
   336              }
   337              lastIsDigit = false
   338              nextNeedDigit = true
   339              continue
   340          } else {
   341              break
   342          }
   343      }
   344  
   345      if nextNeedDigit {
   346          return -int(types.ERR_EOF)
   347      }
   348  
   349      runtime.KeepAlive(src)
   350      return int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr))
   351  }
   352  
   353  //go:nocheckptr
   354  func skipString(src string, pos int) (ret int, ep int) {
   355      if pos+1 >= len(src) {
   356          return -int(types.ERR_EOF), -1
   357      }
   358  
   359      sp := uintptr(rt.IndexChar(src, pos))
   360      se := uintptr(rt.IndexChar(src, len(src)))
   361  
   362      // not start with quote
   363      if *(*byte)(unsafe.Pointer(sp)) != '"' {
   364          return -int(types.ERR_INVALID_CHAR), -1
   365      }
   366      sp += 1
   367  
   368      ep = -1
   369      for sp < se {
   370          c := *(*byte)(unsafe.Pointer(sp))
   371          if c == '\\' {
   372              if ep == -1 {
   373                  ep = int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr))
   374              }
   375              sp += 2
   376              continue
   377          }
   378          sp += 1
   379          if c == '"' {
   380              return int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr)), ep
   381          }
   382      }
   383  
   384      runtime.KeepAlive(src)
   385      // not found the closed quote until EOF
   386      return -int(types.ERR_EOF), -1
   387  }
   388  
   389  //go:nocheckptr
   390  func skipPair(src string, pos int, lchar byte, rchar byte) (ret int) {
   391      if pos+1 >= len(src) {
   392          return -int(types.ERR_EOF)
   393      }
   394  
   395      sp := uintptr(rt.IndexChar(src, pos))
   396      se := uintptr(rt.IndexChar(src, len(src)))
   397  
   398      if *(*byte)(unsafe.Pointer(sp)) != lchar {
   399          return -int(types.ERR_INVALID_CHAR)
   400      }
   401  
   402      sp += 1
   403      nbrace := 1
   404      inquote := false
   405  
   406      for sp < se {
   407          c := *(*byte)(unsafe.Pointer(sp))
   408          if c == '\\' {
   409              sp += 2
   410              continue
   411          } else if c == '"' {
   412              inquote = !inquote
   413          } else if c == lchar {
   414              if !inquote {
   415                  nbrace += 1
   416              }
   417          } else if c == rchar {
   418              if !inquote {
   419                  nbrace -= 1
   420                  if nbrace == 0 {
   421                      sp += 1
   422                      break
   423                  }
   424              }
   425          }
   426          sp += 1
   427      }
   428  
   429      if nbrace != 0 {
   430          return -int(types.ERR_INVALID_CHAR)
   431      }
   432  
   433      runtime.KeepAlive(src)
   434      return int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr))
   435  }
   436  
   437  func skipValueFast(src string, pos int) (ret int, start int) {
   438      pos = skipBlank(src, pos)
   439      if pos < 0 {
   440          return pos, -1
   441      }
   442      switch c := src[pos]; c {
   443      case 'n':
   444          ret = decodeNull(src, pos)
   445      case '"':
   446          ret, _ = skipString(src, pos)
   447      case '{':
   448          ret = skipPair(src, pos, '{', '}')
   449      case '[':
   450          ret = skipPair(src, pos, '[', ']')
   451      case 't':
   452          ret = decodeTrue(src, pos)
   453      case 'f':
   454          ret = decodeFalse(src, pos)
   455      case '-', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
   456          ret = skipNumber(src, pos)
   457      default:
   458          ret = -int(types.ERR_INVALID_CHAR)
   459      }
   460      return ret, pos
   461  }
   462  
   463  func skipValue(src string, pos int) (ret int, start int) {
   464      pos = skipBlank(src, pos)
   465      if pos < 0 {
   466          return pos, -1
   467      }
   468      switch c := src[pos]; c {
   469      case 'n':
   470          ret = decodeNull(src, pos)
   471      case '"':
   472          ret, _ = skipString(src, pos)
   473      case '{':
   474          ret, _ = skipObject(src, pos)
   475      case '[':
   476          ret, _ = skipArray(src, pos)
   477      case 't':
   478          ret = decodeTrue(src, pos)
   479      case 'f':
   480          ret = decodeFalse(src, pos)
   481      case '-', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
   482          ret = skipNumber(src, pos)
   483      default:
   484          ret = -int(types.ERR_INVALID_CHAR)
   485      }
   486      return ret, pos
   487  }
   488  
   489  func skipObject(src string, pos int) (ret int, start int) {
   490      start = skipBlank(src, pos)
   491      if start < 0 {
   492          return start, -1
   493      }
   494  
   495      if src[start] != '{' {
   496          return -int(types.ERR_INVALID_CHAR), -1
   497      }
   498  
   499      pos = start + 1
   500      pos = skipBlank(src, pos)
   501      if pos < 0 {
   502          return pos, -1
   503      }
   504      if src[pos] == '}' {
   505          return pos + 1, start
   506      }
   507  
   508      for {
   509          pos, _ = skipString(src, pos)
   510          if pos < 0 {
   511              return pos, -1
   512          }
   513  
   514          pos = skipBlank(src, pos)
   515          if pos < 0 {
   516              return pos, -1
   517          }
   518          if src[pos] != ':' {
   519              return -int(types.ERR_INVALID_CHAR), -1
   520          }
   521  
   522          pos++
   523          pos, _ = skipValue(src, pos)
   524          if pos < 0 {
   525              return pos, -1
   526          }
   527  
   528          pos = skipBlank(src, pos)
   529          if pos < 0 {
   530              return pos, -1
   531          }
   532          if src[pos] == '}' {
   533              return pos + 1, start
   534          }
   535          if src[pos] != ',' {
   536              return -int(types.ERR_INVALID_CHAR), -1
   537          }
   538  
   539          pos++
   540          pos = skipBlank(src, pos)
   541          if pos < 0 {
   542              return pos, -1
   543          }
   544  
   545      }
   546  }
   547  
   548  func skipArray(src string, pos int) (ret int, start int) {
   549      start = skipBlank(src, pos)
   550      if start < 0 {
   551          return start, -1
   552      }
   553  
   554      if src[start] != '[' {
   555          return -int(types.ERR_INVALID_CHAR), -1
   556      }
   557  
   558      pos = start + 1
   559      pos = skipBlank(src, pos)
   560      if pos < 0 {
   561          return pos, -1
   562      }
   563      if src[pos] == ']' {
   564          return pos + 1, start
   565      }
   566  
   567      for {
   568          pos, _ = skipValue(src, pos)
   569          if pos < 0 {
   570              return pos, -1
   571          }
   572  
   573          pos = skipBlank(src, pos)
   574          if pos < 0 {
   575              return pos, -1
   576          }
   577          if src[pos] == ']' {
   578              return pos + 1, start
   579          }
   580          if src[pos] != ',' {
   581              return -int(types.ERR_INVALID_CHAR), -1
   582          }
   583          pos++
   584      }
   585  }
   586  
   587  // DecodeString decodes a JSON string from pos and return golang string.
   588  //   - needEsc indicates if to unescaped escaping chars
   589  //   - hasEsc tells if the returned string has escaping chars
   590  //   - validStr enables validating UTF8 charset
   591  //
   592  func _DecodeString(src string, pos int, needEsc bool, validStr bool) (v string, ret int, hasEsc bool) {
   593      p := NewParserObj(src)
   594      p.p = pos
   595      switch val := p.decodeValue(); val.Vt {
   596      case types.V_STRING:
   597          str := p.s[val.Iv : p.p-1]
   598          if validStr && !validate_utf8(str) {
   599             return "", -int(types.ERR_INVALID_UTF8), false
   600          }
   601          /* fast path: no escape sequence */
   602          if val.Ep == -1 {
   603              return str, p.p, false
   604          } else if !needEsc {
   605              return str, p.p, true
   606          }
   607          /* unquote the string */
   608          out, err := unquote(str)
   609          /* check for errors */
   610          if err != 0 {
   611              return "", -int(err), true
   612          } else {
   613              return out, p.p, true
   614          }
   615      default:
   616          return "", -int(_ERR_UNSUPPORT_TYPE), false
   617      }
   618  }