github.com/goshafaq/sonic@v0.0.0-20231026082336-871835fb94c6/native/utf8.h (about)

     1  /*
     2   * Copyright (C) 2019 Yaoyuan <ibireme@gmail.com>.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   * Copyright 2018-2023 The simdjson authors
    17   *
    18   * Licensed under the Apache License, Version 2.0 (the "License");
    19   * you may not use this file except in compliance with the License.
    20   * You may obtain a copy of the License at
    21  
    22   *     http://www.apache.org/licenses/LICENSE-2.0
    23  
    24   * Unless required by applicable law or agreed to in writing, software
    25   * distributed under the License is distributed on an "AS IS" BASIS,
    26   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    27   * See the License for the specific language governing permissions and
    28   * limitations under the License.
    29   * 
    30   * This file may have been modified by ByteDance authors. All ByteDance
    31   * Modifications are Copyright 2022 ByteDance Authors.
    32   */
    33  
    34  #pragma once
    35  
    36  #include "native.h"
    37  #include "utils.h"
    38  #include "test/xassert.h"
    39  #include "test/xprintf.h"
    40  
    41  static inline ssize_t valid_utf8_4byte(uint32_t ubin) {
    42      /*
    43       Each unicode code point is encoded as 1 to 4 bytes in UTF-8 encoding,
    44       we use 4-byte mask and pattern value to validate UTF-8 byte sequence,
    45       this requires the input data to have 4-byte zero padding.
    46       ---------------------------------------------------
    47       1 byte
    48       unicode range [U+0000, U+007F]
    49       unicode min   [.......0]
    50       unicode max   [.1111111]
    51       bit pattern   [0.......]
    52       ---------------------------------------------------
    53       2 byte
    54       unicode range [U+0080, U+07FF]
    55       unicode min   [......10 ..000000]
    56       unicode max   [...11111 ..111111]
    57       bit require   [...xxxx. ........] (1E 00)
    58       bit mask      [xxx..... xx......] (E0 C0)
    59       bit pattern   [110..... 10......] (C0 80)
    60       // 1101 0100 10110000
    61       // 0001 1110
    62       ---------------------------------------------------
    63       3 byte
    64       unicode range [U+0800, U+FFFF]
    65       unicode min   [........ ..100000 ..000000]
    66       unicode max   [....1111 ..111111 ..111111]
    67       bit require   [....xxxx ..x..... ........] (0F 20 00)
    68       bit mask      [xxxx.... xx...... xx......] (F0 C0 C0)
    69       bit pattern   [1110.... 10...... 10......] (E0 80 80)
    70       ---------------------------------------------------
    71       3 byte invalid (reserved for surrogate halves)
    72       unicode range [U+D800, U+DFFF]
    73       unicode min   [....1101 ..100000 ..000000]
    74       unicode max   [....1101 ..111111 ..111111]
    75       bit mask      [....xxxx ..x..... ........] (0F 20 00)
    76       bit pattern   [....1101 ..1..... ........] (0D 20 00)
    77       ---------------------------------------------------
    78       4 byte
    79       unicode range [U+10000, U+10FFFF]
    80       unicode min   [........ ...10000 ..000000 ..000000]
    81       unicode max   [.....100 ..001111 ..111111 ..111111]
    82       bit err0      [.....100 ........ ........ ........] (04 00 00 00)
    83       bit err1      [.....011 ..110000 ........ ........] (03 30 00 00)
    84       bit require   [.....xxx ..xx.... ........ ........] (07 30 00 00)
    85       bit mask      [xxxxx... xx...... xx...... xx......] (F8 C0 C0 C0)
    86       bit pattern   [11110... 10...... 10...... 10......] (F0 80 80 80)
    87       ---------------------------------------------------
    88       */
    89      const uint32_t b2_mask = 0x0000C0E0UL;
    90      const uint32_t b2_patt = 0x000080C0UL;
    91      const uint32_t b2_requ = 0x0000001EUL;
    92      const uint32_t b3_mask = 0x00C0C0F0UL;
    93      const uint32_t b3_patt = 0x008080E0UL;
    94      const uint32_t b3_requ = 0x0000200FUL;
    95      const uint32_t b3_erro = 0x0000200DUL;
    96      const uint32_t b4_mask = 0xC0C0C0F8UL;
    97      const uint32_t b4_patt = 0x808080F0UL;
    98      const uint32_t b4_requ = 0x00003007UL;
    99      const uint32_t b4_err0 = 0x00000004UL;
   100      const uint32_t b4_err1 = 0x00003003UL;
   101  
   102  #define is_valid_seq_2(uni) ( \
   103      ((uni & b2_mask) == b2_patt) && \
   104      ((uni & b2_requ)) \
   105  )
   106      
   107  #define is_valid_seq_3(uni) ( \
   108      ((uni & b3_mask) == b3_patt) && \
   109      ((tmp = (uni & b3_requ))) && \
   110      ((tmp != b3_erro)) \
   111  )
   112      
   113  #define is_valid_seq_4(uni) ( \
   114      ((uni & b4_mask) == b4_patt) && \
   115      ((tmp = (uni & b4_requ))) && \
   116      ((tmp & b4_err0) == 0 || (tmp & b4_err1) == 0) \
   117  )
   118      uint32_t tmp = 0;
   119     
   120      if (is_valid_seq_3(ubin)) return 3;
   121      if (is_valid_seq_2(ubin)) return 2;
   122      if (is_valid_seq_4(ubin)) return 4;
   123      return 0;
   124  }
   125  
   126  static always_inline long write_error(int pos, StateMachine *m, size_t msize) {
   127      if (m->sp >= msize) {
   128          return -1;
   129      }
   130      m->vt[m->sp++] = pos;
   131      return 0;
   132  }
   133  
   134  // scalar code, error position should excesss 4096
   135  static always_inline long validate_utf8_with_errors(const char *src, long len, long *p, StateMachine *m) {
   136      const char* start = src + *p;
   137      const char* end = src + len;
   138      while (start < end - 3) {
   139          uint32_t u = (*(uint32_t*)(start));
   140          if ((unsigned)(*start) < 0x80) {
   141              start += 1;
   142              continue;
   143          }
   144          size_t n = valid_utf8_4byte(u);
   145          if (n != 0) { // valid utf
   146              start += n;
   147              continue;
   148          }
   149          long err = write_error(start - src, m, MAX_RECURSE);
   150          if (err) {
   151              *p = start - src;
   152              return err;
   153          }
   154          start += 1;
   155      }
   156      while (start < end) {
   157          if ((unsigned)(*start) < 0x80) {
   158              start += 1;
   159              continue;
   160          }
   161          uint32_t u = 0;
   162          memcpy_p4(&u, start, end - start);
   163          size_t n = valid_utf8_4byte(u);
   164          if (n != 0) { // valid utf
   165              start += n;
   166              continue;
   167          }
   168          long err = write_error(start - src, m, MAX_RECURSE);
   169          if (err) {
   170              *p = start - src;
   171              return err;
   172          }
   173          start += 1;
   174      }
   175      *p = start - src;
   176      return 0;
   177  }
   178  
   179  // validate_utf8_errors returns zero if valid, otherwise, the error position.
   180  static always_inline long validate_utf8_errors(const GoString* s) {
   181      const char* start = s->buf;
   182      const char* end = s->buf + s->len;
   183      while (start < end - 3) {
   184          uint32_t u = (*(uint32_t*)(start));
   185          if ((unsigned)(*start) < 0x80) {
   186              start += 1;
   187              continue;
   188          }
   189          size_t n = valid_utf8_4byte(u);
   190          if (n == 0) { // invalid utf
   191              return -(start - s->buf) - 1;
   192          }
   193          start += n;
   194      }
   195      while (start < end) {
   196          if ((unsigned)(*start) < 0x80) {
   197              start += 1;
   198              continue;
   199          }
   200          uint32_t u = 0;
   201          memcpy_p4(&u, start, end - start);
   202          size_t n = valid_utf8_4byte(u);
   203          if (n == 0) { // invalid utf
   204              return -(start - s->buf) - 1;
   205          }
   206          start += n;
   207      }
   208      return 0;
   209  }
   210  
   211  // SIMD implementation
   212  #if USE_AVX2
   213  
   214      static always_inline __m256i simd256_shr(const __m256i input, const int shift) {
   215          __m256i shifted = _mm256_srli_epi16(input, shift);
   216          __m256i mask = _mm256_set1_epi8(0xFFu >> shift);
   217          return _mm256_and_si256(shifted, mask);
   218      }
   219  
   220  #define simd256_prev(input, prev, N) _mm256_alignr_epi8(input, _mm256_permute2x128_si256(prev, input, 0x21), 16 - (N));
   221  
   222      static always_inline __m256i must_be_2_3_continuation(const __m256i prev2, const __m256i prev3) {
   223          __m256i is_third_byte  = _mm256_subs_epu8(prev2, _mm256_set1_epi8(0b11100000u-1)); // Only 111_____ will be > 0
   224          __m256i is_fourth_byte = _mm256_subs_epu8(prev3, _mm256_set1_epi8(0b11110000u-1)); // Only 1111____ will be > 0
   225          // Caller requires a bool (all 1's). All values resulting from the subtraction will be <= 64, so signed comparison is fine.
   226          __m256i or = _mm256_or_si256(is_third_byte, is_fourth_byte);
   227          return _mm256_cmpgt_epi8(or, _mm256_set1_epi8(0));;
   228      }
   229  
   230      static always_inline __m256i simd256_lookup16(const __m256i input, const uint8_t* table) {
   231          return _mm256_shuffle_epi8(_mm256_setr_epi8(table[0], table[1], table[2], table[3], table[4], table[5], table[6], table[7], table[8], table[9], table[10], table[11], table[12], table[13], table[14], table[15], table[0], table[1], table[2], table[3], table[4], table[5], table[6], table[7], table[8], table[9], table[10], table[11], table[12], table[13], table[14], table[15]), input);
   232      }
   233  
   234    //
   235    // Return nonzero if there are incomplete multibyte characters at the end of the block:
   236    // e.g. if there is a 4-byte character, but it's 3 bytes from the end.
   237    //
   238        static always_inline  __m256i is_incomplete(const __m256i input) {
   239      // If the previous input's last 3 bytes match this, they're too short (they ended at EOF):
   240      // ... 1111____ 111_____ 11______
   241        const uint8_t tab[32] = {
   242        255, 255, 255, 255, 255, 255, 255, 255,
   243        255, 255, 255, 255, 255, 255, 255, 255,
   244        255, 255, 255, 255, 255, 255, 255, 255,
   245        255, 255, 255, 255, 255, 0b11110000u-1, 0b11100000u-1, 0b11000000u-1};
   246          const __m256i max_value = _mm256_loadu_si256((const __m256i_u *)(&tab[0]));
   247          return _mm256_subs_epu8(input, max_value);
   248      }
   249  
   250    static always_inline __m256i check_special_cases(const __m256i input, const __m256i prev1) {
   251      // Bit 0 = Too Short (lead byte/ASCII followed by lead byte/ASCII)
   252      // Bit 1 = Too Long (ASCII followed by continuation)
   253      // Bit 2 = Overlong 3-byte
   254      // Bit 4 = Surrogate
   255      // Bit 5 = Overlong 2-byte
   256      // Bit 7 = Two Continuations
   257       const uint8_t TOO_SHORT   = 1<<0; // 11______ 0_______
   258                                                  // 11______ 11______
   259       const uint8_t TOO_LONG    = 1<<1; // 0_______ 10______
   260       const uint8_t OVERLONG_3  = 1<<2; // 11100000 100_____
   261       const uint8_t SURROGATE   = 1<<4; // 11101101 101_____
   262       const uint8_t OVERLONG_2  = 1<<5; // 1100000_ 10______
   263       const uint8_t TWO_CONTS   = 1<<7; // 10______ 10______
   264       const uint8_t TOO_LARGE   = 1<<3; // 11110100 1001____
   265                                                  // 11110100 101_____
   266                                                  // 11110101 1001____
   267                                                  // 11110101 101_____
   268                                                  // 1111011_ 1001____
   269                                                  // 1111011_ 101_____
   270                                                  // 11111___ 1001____
   271                                                  // 11111___ 101_____
   272       const uint8_t TOO_LARGE_1000 = 1<<6;
   273                                                  // 11110101 1000____
   274                                                  // 1111011_ 1000____
   275                                                  // 11111___ 1000____
   276       const uint8_t OVERLONG_4  = 1<<6; // 11110000 1000____
   277  
   278      const __m256i prev1_shr4 = simd256_shr(prev1, 4);
   279      static const uint8_t tab1[16] = {
   280                // 0_______ ________ <ASCII in byte 1>
   281        TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
   282        TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
   283        // 10______ ________ <continuation in byte 1>
   284        TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS,
   285        // 1100____ ________ <two byte lead in byte 1>
   286        TOO_SHORT | OVERLONG_2,
   287        // 1101____ ________ <two byte lead in byte 1>
   288        TOO_SHORT,
   289        // 1110____ ________ <three byte lead in byte 1>
   290        TOO_SHORT | OVERLONG_3 | SURROGATE,
   291        // 1111____ ________ <four+ byte lead in byte 1>
   292        TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4,
   293      };
   294      __m256i byte_1_high = simd256_lookup16(prev1_shr4, tab1);
   295      
   296  
   297      const uint8_t CARRY = TOO_SHORT | TOO_LONG | TWO_CONTS; // These all have ____ in byte 1 .
   298      __m256i prev1_low = _mm256_and_si256(prev1, _mm256_set1_epi8(0x0F));
   299      static const uint8_t tab2[16] = {
   300        // ____0000 ________
   301        CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4,
   302        // ____0001 ________
   303        CARRY | OVERLONG_2,
   304        // ____001_ ________
   305        CARRY,
   306        CARRY,
   307  
   308        // ____0100 ________
   309        CARRY | TOO_LARGE,
   310        // ____0101 ________
   311        CARRY | TOO_LARGE | TOO_LARGE_1000,
   312        // ____011_ ________
   313        CARRY | TOO_LARGE | TOO_LARGE_1000,
   314        CARRY | TOO_LARGE | TOO_LARGE_1000,
   315  
   316        // ____1___ ________
   317        CARRY | TOO_LARGE | TOO_LARGE_1000,
   318        CARRY | TOO_LARGE | TOO_LARGE_1000,
   319        CARRY | TOO_LARGE | TOO_LARGE_1000,
   320        CARRY | TOO_LARGE | TOO_LARGE_1000,
   321        CARRY | TOO_LARGE | TOO_LARGE_1000,
   322        // ____1101 ________
   323        CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE,
   324        CARRY | TOO_LARGE | TOO_LARGE_1000,
   325        CARRY | TOO_LARGE | TOO_LARGE_1000
   326      };
   327      __m256i byte_1_low = simd256_lookup16(prev1_low, tab2);
   328      
   329  
   330      const __m256i input_shr4 = simd256_shr(input, 4);
   331      static const uint8_t tab3[16] = {
   332        // ________ 0_______ <ASCII in byte 2>
   333        TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
   334        TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
   335  
   336        // ________ 1000____
   337        TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4,
   338        // ________ 1001____
   339        TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE,
   340        // ________ 101_____
   341        TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE  | TOO_LARGE,
   342        TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE  | TOO_LARGE,
   343  
   344        // ________ 11______
   345        TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT
   346      };
   347      __m256i byte_2_high = simd256_lookup16(input_shr4, tab3);
   348       
   349  
   350      return _mm256_and_si256(_mm256_and_si256(byte_1_high, byte_1_low), byte_2_high);
   351    }
   352  
   353      static always_inline __m256i check_multibyte_lengths(const __m256i input, const __m256i prev_input, const __m256i sc) {
   354      __m256i prev2 = simd256_prev(input, prev_input, 2);
   355      __m256i prev3 = simd256_prev(input, prev_input, 3);
   356      
   357      
   358      __m256i must23 = must_be_2_3_continuation(prev2, prev3);
   359      
   360      __m256i must23_80 = _mm256_and_si256(must23, _mm256_set1_epi8(0x80));
   361      
   362      return _mm256_xor_si256(must23_80, sc);
   363    }
   364  
   365  
   366      // Check whether the current bytes are valid UTF-8.
   367      static always_inline __m256i check_utf8_bytes(const __m256i input, const __m256i prev_input) {
   368          // Flip prev1...prev3 so we can easily determine if they are 2+, 3+ or 4+ lead bytes
   369          // (2, 3, 4-byte leads become large positive numbers instead of small negative numbers)
   370          __m256i prev1 = simd256_prev(input, prev_input, 1);
   371          __m256i sc    = check_special_cases(input, prev1);
   372          __m256i ret  = check_multibyte_lengths(input, prev_input, sc);
   373          return ret;
   374      }
   375  
   376      static always_inline bool is_ascii(const __m256i input) {
   377        return _mm256_movemask_epi8(input) == 0;
   378      }
   379  
   380      typedef struct {
   381          // If this is nonzero, there has been a UTF-8 error.
   382          __m256i error;
   383          // The last input we received
   384          __m256i prev_input_block;
   385          // Whether the last input we received was incomplete (used for ASCII fast path)
   386          __m256i prev_incomplete;
   387      } utf8_checker;
   388  
   389      static always_inline void utf8_checker_init(utf8_checker* checker) {
   390          checker->error = _mm256_setzero_si256();
   391          checker->prev_input_block = _mm256_setzero_si256();
   392          checker->prev_incomplete = _mm256_setzero_si256();
   393      }
   394      
   395      static always_inline bool check_error(utf8_checker* checker) {
   396          return !_mm256_testz_si256(checker->error, checker->error);
   397      }
   398  
   399      static always_inline void check64_utf(utf8_checker* checker, const uint8_t* start) {
   400          __m256i input = _mm256_loadu_si256((__m256i*)start);
   401          __m256i input2 = _mm256_loadu_si256((__m256i*)(start + 32));
   402          // check utf-8 chars
   403          __m256i error1 = check_utf8_bytes(input, checker->prev_input_block);
   404          __m256i error2 = check_utf8_bytes(input2, input);
   405          checker->error = _mm256_or_si256(checker->error, _mm256_or_si256(error1, error2));
   406          checker->prev_input_block = input2;
   407          checker->prev_incomplete = is_incomplete(input2);
   408      }
   409  
   410      static always_inline void check64(utf8_checker* checker, const uint8_t* start) {
   411          // fast path for contiguous ASCII
   412          __m256i input = _mm256_loadu_si256((__m256i*)start);
   413          __m256i input2 = _mm256_loadu_si256((__m256i*)(start + 32));
   414          __m256i reducer = _mm256_or_si256(input, input2);
   415          // check utf-8
   416          if (likely(is_ascii(reducer))) {
   417              checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
   418              return;
   419          }
   420          check64_utf(checker, start);
   421      }
   422  
   423      static always_inline void check128(utf8_checker* checker, const uint8_t* start) {
   424          // fast path for contiguous ASCII
   425          __m256i input = _mm256_loadu_si256((__m256i*)start);
   426          __m256i input2 = _mm256_loadu_si256((__m256i*)(start + 32));
   427          __m256i input3 = _mm256_loadu_si256((__m256i*)(start + 64));
   428          __m256i input4 = _mm256_loadu_si256((__m256i*)(start + 96));
   429          
   430          __m256i reducer1 = _mm256_or_si256(input, input2);
   431          __m256i reducer2 = _mm256_or_si256(input3, input4);
   432          __m256i reducer  = _mm256_or_si256(reducer1, reducer2);
   433  
   434          // full 128 bytes are ascii
   435          if (likely(is_ascii(reducer))) {
   436              checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
   437              return;
   438          }
   439  
   440          // frist 64 bytes is ascii, next 64 bytes must be utf8
   441          if (likely(is_ascii(reducer1))) {
   442              checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
   443              check64_utf(checker, start + 64);
   444              return;
   445          }
   446  
   447          // frist 64 bytes has utf8, next 64 bytes 
   448          check64_utf(checker, start);
   449          if (unlikely(is_ascii(reducer2))) {
   450              checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
   451          } else {
   452              check64_utf(checker, start + 64);
   453          }
   454      }
   455  
   456      static always_inline void check_eof(utf8_checker* checker) {
   457          checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete);
   458      }
   459  
   460      static always_inline void check_remain(utf8_checker* checker, const uint8_t* start, const uint8_t* end) {
   461          uint8_t buffer[64] = {0};
   462          int i = 0;
   463          while (start < end) {
   464              buffer[i++] = *(start++);
   465          };
   466          check64(checker, buffer);
   467          check_eof(checker);
   468      }
   469  
   470      static always_inline long validate_utf8_avx2(const GoString* s) {
   471          xassert(s->buf != NULL || s->len != 0);
   472          const uint8_t* start = (const uint8_t*)(s->buf);
   473          const uint8_t* end   = (const uint8_t*)(s->buf + s->len);
   474          /* check eof */
   475          if (s->len == 0) {
   476              return 0;
   477          }
   478          utf8_checker checker;
   479          utf8_checker_init(&checker);
   480          while (start < (end - 128)) {
   481              check128(&checker, start);
   482              if (check_error(&checker)) {
   483              }
   484              start += 128;
   485          };
   486          while (start < end - 64) {
   487              check64(&checker, start);
   488              start += 64;
   489          }
   490          check_remain(&checker, start, end);
   491          return check_error(&checker) ? -1 : 0;
   492      }
   493  #endif