github.com/goshafaq/sonic@v0.0.0-20231026082336-871835fb94c6/native/utf8.h (about) 1 /* 2 * Copyright (C) 2019 Yaoyuan <ibireme@gmail.com>. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * 16 * Copyright 2018-2023 The simdjson authors 17 * 18 * Licensed under the Apache License, Version 2.0 (the "License"); 19 * you may not use this file except in compliance with the License. 20 * You may obtain a copy of the License at 21 22 * http://www.apache.org/licenses/LICENSE-2.0 23 24 * Unless required by applicable law or agreed to in writing, software 25 * distributed under the License is distributed on an "AS IS" BASIS, 26 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 27 * See the License for the specific language governing permissions and 28 * limitations under the License. 29 * 30 * This file may have been modified by ByteDance authors. All ByteDance 31 * Modifications are Copyright 2022 ByteDance Authors. 32 */ 33 34 #pragma once 35 36 #include "native.h" 37 #include "utils.h" 38 #include "test/xassert.h" 39 #include "test/xprintf.h" 40 41 static inline ssize_t valid_utf8_4byte(uint32_t ubin) { 42 /* 43 Each unicode code point is encoded as 1 to 4 bytes in UTF-8 encoding, 44 we use 4-byte mask and pattern value to validate UTF-8 byte sequence, 45 this requires the input data to have 4-byte zero padding. 46 --------------------------------------------------- 47 1 byte 48 unicode range [U+0000, U+007F] 49 unicode min [.......0] 50 unicode max [.1111111] 51 bit pattern [0.......] 52 --------------------------------------------------- 53 2 byte 54 unicode range [U+0080, U+07FF] 55 unicode min [......10 ..000000] 56 unicode max [...11111 ..111111] 57 bit require [...xxxx. ........] (1E 00) 58 bit mask [xxx..... xx......] (E0 C0) 59 bit pattern [110..... 10......] (C0 80) 60 // 1101 0100 10110000 61 // 0001 1110 62 --------------------------------------------------- 63 3 byte 64 unicode range [U+0800, U+FFFF] 65 unicode min [........ ..100000 ..000000] 66 unicode max [....1111 ..111111 ..111111] 67 bit require [....xxxx ..x..... ........] (0F 20 00) 68 bit mask [xxxx.... xx...... xx......] (F0 C0 C0) 69 bit pattern [1110.... 10...... 10......] (E0 80 80) 70 --------------------------------------------------- 71 3 byte invalid (reserved for surrogate halves) 72 unicode range [U+D800, U+DFFF] 73 unicode min [....1101 ..100000 ..000000] 74 unicode max [....1101 ..111111 ..111111] 75 bit mask [....xxxx ..x..... ........] (0F 20 00) 76 bit pattern [....1101 ..1..... ........] (0D 20 00) 77 --------------------------------------------------- 78 4 byte 79 unicode range [U+10000, U+10FFFF] 80 unicode min [........ ...10000 ..000000 ..000000] 81 unicode max [.....100 ..001111 ..111111 ..111111] 82 bit err0 [.....100 ........ ........ ........] (04 00 00 00) 83 bit err1 [.....011 ..110000 ........ ........] (03 30 00 00) 84 bit require [.....xxx ..xx.... ........ ........] (07 30 00 00) 85 bit mask [xxxxx... xx...... xx...... xx......] (F8 C0 C0 C0) 86 bit pattern [11110... 10...... 10...... 10......] (F0 80 80 80) 87 --------------------------------------------------- 88 */ 89 const uint32_t b2_mask = 0x0000C0E0UL; 90 const uint32_t b2_patt = 0x000080C0UL; 91 const uint32_t b2_requ = 0x0000001EUL; 92 const uint32_t b3_mask = 0x00C0C0F0UL; 93 const uint32_t b3_patt = 0x008080E0UL; 94 const uint32_t b3_requ = 0x0000200FUL; 95 const uint32_t b3_erro = 0x0000200DUL; 96 const uint32_t b4_mask = 0xC0C0C0F8UL; 97 const uint32_t b4_patt = 0x808080F0UL; 98 const uint32_t b4_requ = 0x00003007UL; 99 const uint32_t b4_err0 = 0x00000004UL; 100 const uint32_t b4_err1 = 0x00003003UL; 101 102 #define is_valid_seq_2(uni) ( \ 103 ((uni & b2_mask) == b2_patt) && \ 104 ((uni & b2_requ)) \ 105 ) 106 107 #define is_valid_seq_3(uni) ( \ 108 ((uni & b3_mask) == b3_patt) && \ 109 ((tmp = (uni & b3_requ))) && \ 110 ((tmp != b3_erro)) \ 111 ) 112 113 #define is_valid_seq_4(uni) ( \ 114 ((uni & b4_mask) == b4_patt) && \ 115 ((tmp = (uni & b4_requ))) && \ 116 ((tmp & b4_err0) == 0 || (tmp & b4_err1) == 0) \ 117 ) 118 uint32_t tmp = 0; 119 120 if (is_valid_seq_3(ubin)) return 3; 121 if (is_valid_seq_2(ubin)) return 2; 122 if (is_valid_seq_4(ubin)) return 4; 123 return 0; 124 } 125 126 static always_inline long write_error(int pos, StateMachine *m, size_t msize) { 127 if (m->sp >= msize) { 128 return -1; 129 } 130 m->vt[m->sp++] = pos; 131 return 0; 132 } 133 134 // scalar code, error position should excesss 4096 135 static always_inline long validate_utf8_with_errors(const char *src, long len, long *p, StateMachine *m) { 136 const char* start = src + *p; 137 const char* end = src + len; 138 while (start < end - 3) { 139 uint32_t u = (*(uint32_t*)(start)); 140 if ((unsigned)(*start) < 0x80) { 141 start += 1; 142 continue; 143 } 144 size_t n = valid_utf8_4byte(u); 145 if (n != 0) { // valid utf 146 start += n; 147 continue; 148 } 149 long err = write_error(start - src, m, MAX_RECURSE); 150 if (err) { 151 *p = start - src; 152 return err; 153 } 154 start += 1; 155 } 156 while (start < end) { 157 if ((unsigned)(*start) < 0x80) { 158 start += 1; 159 continue; 160 } 161 uint32_t u = 0; 162 memcpy_p4(&u, start, end - start); 163 size_t n = valid_utf8_4byte(u); 164 if (n != 0) { // valid utf 165 start += n; 166 continue; 167 } 168 long err = write_error(start - src, m, MAX_RECURSE); 169 if (err) { 170 *p = start - src; 171 return err; 172 } 173 start += 1; 174 } 175 *p = start - src; 176 return 0; 177 } 178 179 // validate_utf8_errors returns zero if valid, otherwise, the error position. 180 static always_inline long validate_utf8_errors(const GoString* s) { 181 const char* start = s->buf; 182 const char* end = s->buf + s->len; 183 while (start < end - 3) { 184 uint32_t u = (*(uint32_t*)(start)); 185 if ((unsigned)(*start) < 0x80) { 186 start += 1; 187 continue; 188 } 189 size_t n = valid_utf8_4byte(u); 190 if (n == 0) { // invalid utf 191 return -(start - s->buf) - 1; 192 } 193 start += n; 194 } 195 while (start < end) { 196 if ((unsigned)(*start) < 0x80) { 197 start += 1; 198 continue; 199 } 200 uint32_t u = 0; 201 memcpy_p4(&u, start, end - start); 202 size_t n = valid_utf8_4byte(u); 203 if (n == 0) { // invalid utf 204 return -(start - s->buf) - 1; 205 } 206 start += n; 207 } 208 return 0; 209 } 210 211 // SIMD implementation 212 #if USE_AVX2 213 214 static always_inline __m256i simd256_shr(const __m256i input, const int shift) { 215 __m256i shifted = _mm256_srli_epi16(input, shift); 216 __m256i mask = _mm256_set1_epi8(0xFFu >> shift); 217 return _mm256_and_si256(shifted, mask); 218 } 219 220 #define simd256_prev(input, prev, N) _mm256_alignr_epi8(input, _mm256_permute2x128_si256(prev, input, 0x21), 16 - (N)); 221 222 static always_inline __m256i must_be_2_3_continuation(const __m256i prev2, const __m256i prev3) { 223 __m256i is_third_byte = _mm256_subs_epu8(prev2, _mm256_set1_epi8(0b11100000u-1)); // Only 111_____ will be > 0 224 __m256i is_fourth_byte = _mm256_subs_epu8(prev3, _mm256_set1_epi8(0b11110000u-1)); // Only 1111____ will be > 0 225 // Caller requires a bool (all 1's). All values resulting from the subtraction will be <= 64, so signed comparison is fine. 226 __m256i or = _mm256_or_si256(is_third_byte, is_fourth_byte); 227 return _mm256_cmpgt_epi8(or, _mm256_set1_epi8(0));; 228 } 229 230 static always_inline __m256i simd256_lookup16(const __m256i input, const uint8_t* table) { 231 return _mm256_shuffle_epi8(_mm256_setr_epi8(table[0], table[1], table[2], table[3], table[4], table[5], table[6], table[7], table[8], table[9], table[10], table[11], table[12], table[13], table[14], table[15], table[0], table[1], table[2], table[3], table[4], table[5], table[6], table[7], table[8], table[9], table[10], table[11], table[12], table[13], table[14], table[15]), input); 232 } 233 234 // 235 // Return nonzero if there are incomplete multibyte characters at the end of the block: 236 // e.g. if there is a 4-byte character, but it's 3 bytes from the end. 237 // 238 static always_inline __m256i is_incomplete(const __m256i input) { 239 // If the previous input's last 3 bytes match this, they're too short (they ended at EOF): 240 // ... 1111____ 111_____ 11______ 241 const uint8_t tab[32] = { 242 255, 255, 255, 255, 255, 255, 255, 255, 243 255, 255, 255, 255, 255, 255, 255, 255, 244 255, 255, 255, 255, 255, 255, 255, 255, 245 255, 255, 255, 255, 255, 0b11110000u-1, 0b11100000u-1, 0b11000000u-1}; 246 const __m256i max_value = _mm256_loadu_si256((const __m256i_u *)(&tab[0])); 247 return _mm256_subs_epu8(input, max_value); 248 } 249 250 static always_inline __m256i check_special_cases(const __m256i input, const __m256i prev1) { 251 // Bit 0 = Too Short (lead byte/ASCII followed by lead byte/ASCII) 252 // Bit 1 = Too Long (ASCII followed by continuation) 253 // Bit 2 = Overlong 3-byte 254 // Bit 4 = Surrogate 255 // Bit 5 = Overlong 2-byte 256 // Bit 7 = Two Continuations 257 const uint8_t TOO_SHORT = 1<<0; // 11______ 0_______ 258 // 11______ 11______ 259 const uint8_t TOO_LONG = 1<<1; // 0_______ 10______ 260 const uint8_t OVERLONG_3 = 1<<2; // 11100000 100_____ 261 const uint8_t SURROGATE = 1<<4; // 11101101 101_____ 262 const uint8_t OVERLONG_2 = 1<<5; // 1100000_ 10______ 263 const uint8_t TWO_CONTS = 1<<7; // 10______ 10______ 264 const uint8_t TOO_LARGE = 1<<3; // 11110100 1001____ 265 // 11110100 101_____ 266 // 11110101 1001____ 267 // 11110101 101_____ 268 // 1111011_ 1001____ 269 // 1111011_ 101_____ 270 // 11111___ 1001____ 271 // 11111___ 101_____ 272 const uint8_t TOO_LARGE_1000 = 1<<6; 273 // 11110101 1000____ 274 // 1111011_ 1000____ 275 // 11111___ 1000____ 276 const uint8_t OVERLONG_4 = 1<<6; // 11110000 1000____ 277 278 const __m256i prev1_shr4 = simd256_shr(prev1, 4); 279 static const uint8_t tab1[16] = { 280 // 0_______ ________ <ASCII in byte 1> 281 TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG, 282 TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG, 283 // 10______ ________ <continuation in byte 1> 284 TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS, 285 // 1100____ ________ <two byte lead in byte 1> 286 TOO_SHORT | OVERLONG_2, 287 // 1101____ ________ <two byte lead in byte 1> 288 TOO_SHORT, 289 // 1110____ ________ <three byte lead in byte 1> 290 TOO_SHORT | OVERLONG_3 | SURROGATE, 291 // 1111____ ________ <four+ byte lead in byte 1> 292 TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4, 293 }; 294 __m256i byte_1_high = simd256_lookup16(prev1_shr4, tab1); 295 296 297 const uint8_t CARRY = TOO_SHORT | TOO_LONG | TWO_CONTS; // These all have ____ in byte 1 . 298 __m256i prev1_low = _mm256_and_si256(prev1, _mm256_set1_epi8(0x0F)); 299 static const uint8_t tab2[16] = { 300 // ____0000 ________ 301 CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4, 302 // ____0001 ________ 303 CARRY | OVERLONG_2, 304 // ____001_ ________ 305 CARRY, 306 CARRY, 307 308 // ____0100 ________ 309 CARRY | TOO_LARGE, 310 // ____0101 ________ 311 CARRY | TOO_LARGE | TOO_LARGE_1000, 312 // ____011_ ________ 313 CARRY | TOO_LARGE | TOO_LARGE_1000, 314 CARRY | TOO_LARGE | TOO_LARGE_1000, 315 316 // ____1___ ________ 317 CARRY | TOO_LARGE | TOO_LARGE_1000, 318 CARRY | TOO_LARGE | TOO_LARGE_1000, 319 CARRY | TOO_LARGE | TOO_LARGE_1000, 320 CARRY | TOO_LARGE | TOO_LARGE_1000, 321 CARRY | TOO_LARGE | TOO_LARGE_1000, 322 // ____1101 ________ 323 CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE, 324 CARRY | TOO_LARGE | TOO_LARGE_1000, 325 CARRY | TOO_LARGE | TOO_LARGE_1000 326 }; 327 __m256i byte_1_low = simd256_lookup16(prev1_low, tab2); 328 329 330 const __m256i input_shr4 = simd256_shr(input, 4); 331 static const uint8_t tab3[16] = { 332 // ________ 0_______ <ASCII in byte 2> 333 TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, 334 TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, 335 336 // ________ 1000____ 337 TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4, 338 // ________ 1001____ 339 TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE, 340 // ________ 101_____ 341 TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, 342 TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, 343 344 // ________ 11______ 345 TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT 346 }; 347 __m256i byte_2_high = simd256_lookup16(input_shr4, tab3); 348 349 350 return _mm256_and_si256(_mm256_and_si256(byte_1_high, byte_1_low), byte_2_high); 351 } 352 353 static always_inline __m256i check_multibyte_lengths(const __m256i input, const __m256i prev_input, const __m256i sc) { 354 __m256i prev2 = simd256_prev(input, prev_input, 2); 355 __m256i prev3 = simd256_prev(input, prev_input, 3); 356 357 358 __m256i must23 = must_be_2_3_continuation(prev2, prev3); 359 360 __m256i must23_80 = _mm256_and_si256(must23, _mm256_set1_epi8(0x80)); 361 362 return _mm256_xor_si256(must23_80, sc); 363 } 364 365 366 // Check whether the current bytes are valid UTF-8. 367 static always_inline __m256i check_utf8_bytes(const __m256i input, const __m256i prev_input) { 368 // Flip prev1...prev3 so we can easily determine if they are 2+, 3+ or 4+ lead bytes 369 // (2, 3, 4-byte leads become large positive numbers instead of small negative numbers) 370 __m256i prev1 = simd256_prev(input, prev_input, 1); 371 __m256i sc = check_special_cases(input, prev1); 372 __m256i ret = check_multibyte_lengths(input, prev_input, sc); 373 return ret; 374 } 375 376 static always_inline bool is_ascii(const __m256i input) { 377 return _mm256_movemask_epi8(input) == 0; 378 } 379 380 typedef struct { 381 // If this is nonzero, there has been a UTF-8 error. 382 __m256i error; 383 // The last input we received 384 __m256i prev_input_block; 385 // Whether the last input we received was incomplete (used for ASCII fast path) 386 __m256i prev_incomplete; 387 } utf8_checker; 388 389 static always_inline void utf8_checker_init(utf8_checker* checker) { 390 checker->error = _mm256_setzero_si256(); 391 checker->prev_input_block = _mm256_setzero_si256(); 392 checker->prev_incomplete = _mm256_setzero_si256(); 393 } 394 395 static always_inline bool check_error(utf8_checker* checker) { 396 return !_mm256_testz_si256(checker->error, checker->error); 397 } 398 399 static always_inline void check64_utf(utf8_checker* checker, const uint8_t* start) { 400 __m256i input = _mm256_loadu_si256((__m256i*)start); 401 __m256i input2 = _mm256_loadu_si256((__m256i*)(start + 32)); 402 // check utf-8 chars 403 __m256i error1 = check_utf8_bytes(input, checker->prev_input_block); 404 __m256i error2 = check_utf8_bytes(input2, input); 405 checker->error = _mm256_or_si256(checker->error, _mm256_or_si256(error1, error2)); 406 checker->prev_input_block = input2; 407 checker->prev_incomplete = is_incomplete(input2); 408 } 409 410 static always_inline void check64(utf8_checker* checker, const uint8_t* start) { 411 // fast path for contiguous ASCII 412 __m256i input = _mm256_loadu_si256((__m256i*)start); 413 __m256i input2 = _mm256_loadu_si256((__m256i*)(start + 32)); 414 __m256i reducer = _mm256_or_si256(input, input2); 415 // check utf-8 416 if (likely(is_ascii(reducer))) { 417 checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete); 418 return; 419 } 420 check64_utf(checker, start); 421 } 422 423 static always_inline void check128(utf8_checker* checker, const uint8_t* start) { 424 // fast path for contiguous ASCII 425 __m256i input = _mm256_loadu_si256((__m256i*)start); 426 __m256i input2 = _mm256_loadu_si256((__m256i*)(start + 32)); 427 __m256i input3 = _mm256_loadu_si256((__m256i*)(start + 64)); 428 __m256i input4 = _mm256_loadu_si256((__m256i*)(start + 96)); 429 430 __m256i reducer1 = _mm256_or_si256(input, input2); 431 __m256i reducer2 = _mm256_or_si256(input3, input4); 432 __m256i reducer = _mm256_or_si256(reducer1, reducer2); 433 434 // full 128 bytes are ascii 435 if (likely(is_ascii(reducer))) { 436 checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete); 437 return; 438 } 439 440 // frist 64 bytes is ascii, next 64 bytes must be utf8 441 if (likely(is_ascii(reducer1))) { 442 checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete); 443 check64_utf(checker, start + 64); 444 return; 445 } 446 447 // frist 64 bytes has utf8, next 64 bytes 448 check64_utf(checker, start); 449 if (unlikely(is_ascii(reducer2))) { 450 checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete); 451 } else { 452 check64_utf(checker, start + 64); 453 } 454 } 455 456 static always_inline void check_eof(utf8_checker* checker) { 457 checker->error = _mm256_or_si256(checker->error, checker->prev_incomplete); 458 } 459 460 static always_inline void check_remain(utf8_checker* checker, const uint8_t* start, const uint8_t* end) { 461 uint8_t buffer[64] = {0}; 462 int i = 0; 463 while (start < end) { 464 buffer[i++] = *(start++); 465 }; 466 check64(checker, buffer); 467 check_eof(checker); 468 } 469 470 static always_inline long validate_utf8_avx2(const GoString* s) { 471 xassert(s->buf != NULL || s->len != 0); 472 const uint8_t* start = (const uint8_t*)(s->buf); 473 const uint8_t* end = (const uint8_t*)(s->buf + s->len); 474 /* check eof */ 475 if (s->len == 0) { 476 return 0; 477 } 478 utf8_checker checker; 479 utf8_checker_init(&checker); 480 while (start < (end - 128)) { 481 check128(&checker, start); 482 if (check_error(&checker)) { 483 } 484 start += 128; 485 }; 486 while (start < end - 64) { 487 check64(&checker, start); 488 start += 64; 489 } 490 check_remain(&checker, start, end); 491 return check_error(&checker) ? -1 : 0; 492 } 493 #endif