package gpt_bpe

import (
	"bufio"
	"bytes"
	"encoding/binary"
	"encoding/json"
	"io"
	"log"
	"math"
	"regexp"
	"regexp/syntax"
	"sort"
	"strconv"
	"strings"
	"sync"
	"unicode"

	"github.com/pkg/errors"

	lru "github.com/hashicorp/golang-lru"
	"github.com/wbrown/gpt_bpe/resources"
	"github.com/wbrown/gpt_bpe/types"
)

// Capacity of the BPE word -> tokens LRU cache.
const BPE_LRU_SZ = 16384

// Initial capacity of the rune accumulator used by the word splitter.
const RUNEBUF_SZ = 16384

// Buffer size for the word channel used during streaming tokenization.
const WORDCHAN_SZ = 4096

// defaultPadTokenString is injected into the vocabulary when a model
// does not declare a pad token of its own.
const defaultPadTokenString = "[PAD]"

// Token and Tokens alias the shared types package so callers can use
// them without importing types directly.
type Token = types.Token
type Tokens = types.Tokens

// TypedTwoTierCache is a placeholder; only the filler field is present
// in this file.
type TypedTwoTierCache struct {
	// Filler
	filler int
}

// GPTEncoder holds everything needed to tokenize text for one
// vocabulary: the string<->token tables, the BPE merge ranks, compiled
// split/punctuation/specials regexes, special-token lookup trees, and
// an ARC LRU cache of previously tokenized words.
type GPTEncoder struct {
	Encoder               map[string]Token    // decoded string -> token id
	Decoder               map[Token][]byte    // token id -> raw bytes
	BpeRanks              map[GPTPair]float64 // merge priority per bigram; lower merges first
	TokenMerges           map[TokenPair]Token // token-level view of the merge table
	BytesEncoder          *map[byte]Token     // byte-fallback table; nil when vocab has no 0xNN entries
	unitrim               []int               // per-token UTF-8 continuation-byte debt (see makeUnitrimArr)
	pattern               *regexp.Regexp      // word-splitting regex
	puncPat               *regexp.Regexp
	specialsPat           *regexp.Regexp
	byteToRune            [256]rune    // GPT-2 byte -> remapped codepoint table
	runeToByte            map[rune]byte // inverse of byteToRune
	Specials              map[string]Tokens
	SpecialsTree          *RuneNode // rune trie over the specials keys
	Cache                 *lru.ARCCache
	TwoTierCache          *TypedTwoTierCache
	PuncRunes             []rune
	Normalizer            *strings.Replacer
	DecodeExtra           *strings.Replacer
	BosToken              Token
	EosToken              Token
	PadToken              Token
	ignoreMerges          bool // if set, whole-word vocab hits bypass BPE merging
	encloseEosBos         bool
	encloseBos            bool
	encloseEos            bool
	prefixSpace           bool
	lowerCase             bool
	endOfWord             string // suffix appended to the final symbol of each word (e.g. CLIP)
	replacements          map[string]string
	runeBufSz             int
	wordChanSz            int
	LruHits               int
	LruMisses             int
	LruEvictions          int
	LruSize               int
	SplitterThreads       int
	VocabId               string
	tokenizerClass        string
	normalizerStringMap   map[string]string
	regexWordSplitterTree *RegexNode
	wordSplitterMap       [][]int
}

// GPTPair is a bigram of adjacent decoded-string symbols.
type GPTPair struct {
	Left  string
	Right string
}

// TokenPair is a bigram of adjacent token ids.
type TokenPair struct {
	Left  Token
	Right Token
}

// BGERank pairs a bigram with its BPE merge rank.
type BGERank struct {
	rank   float64
	bigram GPTPair
}

// BGERanks implements sort.Interface ordered by ascending rank.
type BGERanks []BGERank

func (bs BGERanks) Len() int {
	return len(bs)
}

func (bs BGERanks) Swap(i, j int) {
	bs[i], bs[j] = bs[j], bs[i]
}

func (bs BGERanks) Less(i, j int) bool {
	return bs[i].rank < bs[j].rank
}

// SPLIT_REGEX is the default GPT-2 style pre-tokenization pattern,
// used when the model resources do not supply their own split regex.
const SPLIT_REGEX = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L" +
	"}+| ?\\p{N}+| ?[^\\s\\p{L" +
	"}\\p{N}]+|\\s+(\\S){0}|\\s+"
const PUNC_REGEX = "\\p{L}[.!?;]\\p{L}"
const REGEX_ERROR = "gpt_bpe: Fatal error compiling regular expression: %v"

// Built-in vocabulary identifiers accepted by NewEncoder.
const VOCAB_ID_GPT2 = "gpt2-tokenizer"
const VOCAB_ID_PILE = "pile-tokenizer"
const VOCAB_ID_CLIP = "clip-tokenizer"
const VOCAB_ID_NERDSTASH_V1 = "nerdstash_v1-tokenizer"
const VOCAB_ID_NERDSTASH_V2 = "nerdstash_v2-tokenizer"
const VOCAB_ID_LLAMA = "llama-tokenizer"
const VOCAB_ID_LLAMA_3 = "llama3-tokenizer"
const VOCAB_ID_MISTRAL = "mistral-tokenizer"

// NewGPT2Encoder returns an encoder for the GPT-2 vocabulary.
// NOTE(review): like all the convenience constructors below, it
// discards the error from NewEncoder and dereferences the result, so a
// failed resource load would panic — TODO confirm this is intended.
func NewGPT2Encoder() GPTEncoder {
	encoder, _ := NewEncoder(VOCAB_ID_GPT2)
	return *encoder
}

// NewPileEncoder returns an encoder for the Pile vocabulary.
func NewPileEncoder() GPTEncoder {
	encoder, _ := NewEncoder(VOCAB_ID_PILE)
	return *encoder
}

// NewCLIPEncoder returns an encoder for the CLIP vocabulary.
func NewCLIPEncoder() GPTEncoder {
	encoder, _ := NewEncoder(VOCAB_ID_CLIP)
	return *encoder
}

// NewNerdstashV1Encoder returns an encoder for the Nerdstash v1 vocabulary.
func NewNerdstashV1Encoder() GPTEncoder {
	encoder, _ := NewEncoder(VOCAB_ID_NERDSTASH_V1)
	return *encoder
}

// NewNerdstashV2Encoder returns an encoder for the Nerdstash v2 vocabulary.
func NewNerdstashV2Encoder() GPTEncoder {
	encoder, _ := NewEncoder(VOCAB_ID_NERDSTASH_V2)
	return *encoder
}

// NewLlama2Encoder returns an encoder for the Llama vocabulary.
func NewLlama2Encoder() GPTEncoder {
	encoder, _ := NewEncoder(VOCAB_ID_LLAMA)
	return *encoder
}

// NewLlama3Encoder returns an encoder for the Llama 3 vocabulary.
func NewLlama3Encoder() GPTEncoder {
	encoder, _ := NewEncoder(VOCAB_ID_LLAMA_3)
	return *encoder
}

// NewMistralEncoder returns an encoder for the Mistral vocabulary.
func NewMistralEncoder() GPTEncoder {
	encoder, _ := NewEncoder(VOCAB_ID_MISTRAL)
	return *encoder
}

// Clone returns a deep-ish copy of the encoder: the mutable maps,
// slices, and the LRU cache are duplicated so the clone can be used
// independently, while immutable shared members (Replacers, compiled
// regexes, the splitter tree) are shared by reference.
func (encoder *GPTEncoder) Clone() *GPTEncoder {
	// Shallow copy everything first
	clone := *encoder
	// Fresh, empty cache for the clone.
	clone.Cache, _ = lru.NewARC(BPE_LRU_SZ)
	// Copy our maps
	clone.Encoder = make(map[string]Token)
	for k, v := range encoder.Encoder {
		clone.Encoder[k] = v
	}
	clone.Decoder = make(map[Token][]byte)
	for k, v := range encoder.Decoder {
		clone.Decoder[k] = v
	}
	clone.BpeRanks = make(map[GPTPair]float64)
	for k, v := range encoder.BpeRanks {
		clone.BpeRanks[k] = v
	}
	clone.TokenMerges = make(map[TokenPair]Token)
	for k, v := range encoder.TokenMerges {
		clone.TokenMerges[k] = v
	}
	if encoder.BytesEncoder != nil {
		encoderCopy := make(map[byte]Token)
		for k, v := range *encoder.BytesEncoder {
			encoderCopy[k] = v
		}
		clone.BytesEncoder = &encoderCopy
	}
	clone.unitrim = make([]int, len(encoder.unitrim))
	copy(clone.unitrim, encoder.unitrim)
	clone.PuncRunes = make([]rune, len(encoder.PuncRunes))
	copy(clone.PuncRunes, encoder.PuncRunes)
	clone.Normalizer = encoder.Normalizer
	clone.normalizerStringMap = encoder.normalizerStringMap
	clone.DecodeExtra = encoder.DecodeExtra
	clone.Specials = make(map[string]Tokens)
	for k, v := range encoder.Specials {
		clone.Specials[k] = v
	}
	clone.UpdateSpecialsTree()
	// NOTE(review): no new map is allocated for clone.replacements, so
	// after the shallow copy above it still aliases
	// encoder.replacements; this loop writes the same entries back into
	// the shared map rather than detaching it — confirm whether a deep
	// copy was intended.
	for k, v := range encoder.replacements {
		clone.replacements[k] = v
	}
	clone.runeBufSz = encoder.runeBufSz
	clone.regexWordSplitterTree = encoder.regexWordSplitterTree
	clone.wordSplitterMap = encoder.wordSplitterMap
	return &clone
}

// NewEncoder
// Returns a GPTEncoder with the tokenizer data loaded for that vocabulary
// id.
func NewEncoder(vocabId string) (*GPTEncoder, error) {
	log.Printf("Loading encoder for vocab id: %s\n", vocabId)
	hfConfig, resourcesPtr, vocabErr := resources.ResolveVocabId(
		vocabId, "",
	)

	if vocabErr != nil {
		return nil, vocabErr
	} else if hfConfig == nil {
		// We should never get this error, but just in case, we return an
		// error if we can't find the config.
		return nil, errors.Errorf(
			"Can't load encoder for vocab id: %s",
			vocabId,
		)
	} else if resourcesPtr == nil {
		return nil, errors.Errorf(
			"Can't load resources for vocab id: %s",
			vocabId,
		)
	}
	rsrcs := *resourcesPtr

	// Prefer the model's own id for reporting if it declares one.
	if hfConfig.ModelId != nil {
		vocabId = *hfConfig.ModelId
	}

	// Defaults used when the model ships no special_config.json.
	specialConfig := resources.SpecialConfig{
		PuncRunes:     nil,
		Normalizer:    nil,
		EncloseEosBos: false,
		PrefixSpace:   true,
		LowerCase:     false,
		EndOfWord:     "",
		DecodeExtra:   nil,
		SplitRegex:    nil,
	}
	if special, ok := (rsrcs)["special_config.json"]; ok {
		if special.Data != nil {
			if json.Unmarshal(*special.Data, &specialConfig) != nil {
				log.Fatal("Error unmarshalling special_config.json")
			}
		}
	}

	// Sometimes we have a split regex that's provided by the model's
	// tokenizer config.
	if specialConfig.SplitRegex == nil {
		splitRegexPtr := rsrcs.ResolveSplitRegex()
		if splitRegexPtr != nil {
			// Use our default split regex if we can't find one.
			specialConfig.SplitRegex = splitRegexPtr
		}
	}

	// These are the runes that are considered punctuation and have
	// special handling.
	puncRunes := make([]rune, 0)
	if specialConfig.PuncRunes != nil {
		for _, r := range specialConfig.PuncRunes {
			// Only the first byte of each configured entry is used.
			puncRunes = append(puncRunes, rune((*r)[0]))
		}
	}

	// Create a replacer for normalizing text.
	normalizer := strings.NewReplacer()
	norms := make([]string, 0)
	normsMap := make(map[string]string)
	if specialConfig.Normalizer != nil {

		for k, v := range *specialConfig.Normalizer {
			norms = append(norms, k, v)
			normsMap[k] = v
		}
		normalizer = strings.NewReplacer(norms...)
	}

	// Create a replacer for extra decoding. This is used to decode
	// special tokens that are not in the encoder.
	decodeExtra := strings.NewReplacer()
	if specialConfig.DecodeExtra != nil {
		decode := make([]string, 0)
		for k, v := range *specialConfig.DecodeExtra {
			decode = append(decode, k, v)
		}
		decodeExtra = strings.NewReplacer(decode...)
	}

	// Build the bytes to unicode tables.
	bytesUnicode, unicodeBytes := makeByteTranslationTables()

	// Read encoder mappings.
	vocab, err := rsrcs.GetVocab(hfConfig)
	if err != nil {
		return nil, err
	}
	encoderTokens := make(map[string]Token)
	for k, v := range vocab {
		encoderTokens[k] = Token(v)
	}

	// Build the unitrim array. This is used to trim token sequences
	// to valid UTF-8 boundaries.
	unitrimArr := makeUnitrimArr(encoderTokens)

	// Go through the encoder mappings for possible byte runes
	// and also generate reverse mappings.
	bytesEncoder := make(map[byte]Token)
	tokensEncoder := make(map[Token][]byte)
	for text, token := range encoderTokens {
		// Vocab entries of the literal form "0xNN" are raw byte tokens.
		if strings.HasPrefix(text, "0x") && len(text) == 4 {
			// Convert the hex string to a byte
			byteValue, err := strconv.ParseUint(text[2:], 16, 8)
			if err != nil {
				panic(err)
			}
			tokensEncoder[token] = []byte{byte(byteValue)}
			bytesEncoder[byte(byteValue)] = token
			delete(encoderTokens, text)
		} else {
			tokensEncoder[token] = []byte(text)
		}
	}
	// A nil BytesEncoder signals "no byte-fallback tokens" downstream.
	bytesEncoderPtr := &bytesEncoder
	if len(bytesEncoder) == 0 {
		bytesEncoderPtr = nil
	}

	// Read merge table into BpeRanks
	bpeRanks := make(map[GPTPair]float64)
	rscBpeRanks, err := resources.GetMergesAsBpeRank(&rsrcs)
	if err != nil {
		return nil, err
	}
	// Convert rscBpeRanks to bpeRanks (map[GPTPair]float64)
	for k, v := range rscBpeRanks {
		bpeRanks[GPTPair{k.Left, k.Right}] = v
	}

	// Build our TokenMerges. These are used to merge tokens together
	// based on the BPE merge table.
	tokenMerges := make(map[TokenPair]Token)
	for pair := range bpeRanks {
		tokenMerges[TokenPair{
			encoderTokens[pair.Left],
			encoderTokens[pair.Right]}] =
			encoderTokens[pair.Left+pair.Right]
	}

	// Handle special tokens. Special tokens are removed from input before
	// tokenization, so we need to search for them before we tokenize.
	specialsRegexTokens := make([]string, 0)
	specials := make(map[string]Tokens)
	specialsArr := make([]string, 0)

	if specialsTxt, ok := rsrcs["specials.txt"]; ok {
		// One special token per line.
		specialsBuffer := bytes.NewBuffer(*specialsTxt.Data)
		specialsScanner := bufio.NewScanner(specialsBuffer)
		for specialsScanner.Scan() {
			specialToken := specialsScanner.Text()
			if specialToken == "" {
				continue
			}
			specials[specialToken] = Tokens{encoderTokens[specialToken]}
			specialsArr = append(specialsArr, specialToken)
			quotedToken := regexp.QuoteMeta(specialToken)
			specialsRegexTokens = append(
				specialsRegexTokens, quotedToken,
			)
		}
	} else if specialsJson, ok := rsrcs["specials.json"]; ok {
		// JSON map of name -> token string; dedupe the values.
		specialsData := make(map[string]string)
		seenSpecials := make(map[string]bool)
		if specialErr := json.Unmarshal(
			*specialsJson.Data,
			&specialsData,
		); specialErr != nil {
			return nil, specialErr
		}
		for _, v := range specialsData {
			if _, seen := seenSpecials[v]; !seen {
				seenSpecials[v] = true
				specials[v] = Tokens{encoderTokens[v]}
				specialsArr = append(specialsArr, v)
				quotedToken := regexp.QuoteMeta(v)
				specialsRegexTokens = append(
					specialsRegexTokens, quotedToken,
				)
			}
		}
	}
	specialsRegex := strings.Join(specialsRegexTokens, "|")

	// Now compile our regexes.
	specialsPat, err := regexp.Compile(specialsRegex)
	if err != nil {
		log.Fatalf(REGEX_ERROR, err)
	}

	var pat *regexp.Regexp
	if specialConfig.SplitRegex != nil {
		pat, err = regexp.Compile(*specialConfig.SplitRegex)
	} else {
		pat, err = regexp.Compile(SPLIT_REGEX)
	}
	if err != nil {
		log.Fatalf(REGEX_ERROR, err)
	}
	puncPat, err := regexp.Compile(PUNC_REGEX)
	if err != nil {
		log.Fatalf(REGEX_ERROR, err)
	}

	cache, _ := lru.NewARC(BPE_LRU_SZ)

	// Fairseq-style newline replacement, if configured.
	replacements := make(map[string]string)
	if hfConfig.NewLineMode != nil && *hfConfig.NewLineMode == "s" {
		replacements["\n"] = "</s>"
	}

	if specialConfig.EncloseEosBos {
		bosBool := true
		eosBool := true
		hfConfig.AddBosToken = &bosBool
		hfConfig.AddEosToken = &eosBool
	}

	// Add in default pad token if not already set
	padTokenNotFound := hfConfig.PadTokenStr == nil ||
		*hfConfig.PadTokenStr == ""
	if padTokenNotFound {
		// Attempt to resolve from specials
		for k := range specials {
			if strings.Contains(k, "pad") {
				hfConfig.PadTokenStr = &k
				padTokenNotFound = false
				break
			}
		}
		// Inject the pad token into the encoder to uintmax32,
		// throw an error if vocab is larger than uintmax32
		if len(encoderTokens) >= math.MaxUint32 {
			log.Fatalf(
				"Vocab size of %d is larger than uint32 max of %d. "+
					"Please specify a pad token in the vocab file.",
				len(encoderTokens), math.MaxUint32,
			)
		}
		if padTokenNotFound {
			padToken := defaultPadTokenString
			if len(encoderTokens) >= math.MaxUint16 {
				encoderTokens[padToken] = math.MaxUint32
			} else {
				encoderTokens[padToken] = math.MaxUint16
			}
			hfConfig.PadTokenStr = &padToken
		}
	}

	// Create the encoder
	encoder := &GPTEncoder{
		Encoder:               encoderTokens,
		Decoder:               tokensEncoder,
		BpeRanks:              bpeRanks,
		TokenMerges:           tokenMerges,
		BytesEncoder:          bytesEncoderPtr,
		unitrim:               unitrimArr,
		pattern:               pat,
		puncPat:               puncPat,
		specialsPat:           specialsPat,
		byteToRune:            bytesUnicode,
		runeToByte:            unicodeBytes,
		Specials:              specials,
		SpecialsTree:          nil,
		Cache:                 cache,
		PuncRunes:             puncRunes,
		Normalizer:            normalizer,
		DecodeExtra:           decodeExtra,
		BosToken:              encoderTokens[*hfConfig.BosTokenStr],
		EosToken:              encoderTokens[*hfConfig.EosTokenStr],
		PadToken:              encoderTokens[*hfConfig.PadTokenStr],
		ignoreMerges:          *hfConfig.IgnoreMerges,
		encloseEosBos:         specialConfig.EncloseEosBos,
		encloseBos:            *hfConfig.AddBosToken,
		encloseEos:            *hfConfig.AddEosToken,
		prefixSpace:           specialConfig.PrefixSpace,
		lowerCase:             specialConfig.LowerCase,
		endOfWord:             specialConfig.EndOfWord,
		replacements:          replacements,
		runeBufSz:             RUNEBUF_SZ,
		wordChanSz:            WORDCHAN_SZ,
		LruHits:               0,
		LruMisses:             0,
		LruEvictions:          0,
		LruSize:               BPE_LRU_SZ,
		SplitterThreads:       2,
		VocabId:               vocabId,
		tokenizerClass:        *hfConfig.TokenizerClass,
		normalizerStringMap:   normsMap,
		regexWordSplitterTree: nil,
		wordSplitterMap:       nil,
	}
	encoder.UpdateSpecialsTree()
	return encoder, nil
}

// UpdateSpecialsTree rebuilds the rune trie over the keys of the
// Specials map; it must be called whenever Specials changes.
func (encoder *GPTEncoder) UpdateSpecialsTree() {
	// Turn the keys of the specials map into a slice
	idx := 0
	specialsArr := make([]string, len(encoder.Specials))
	for k := range encoder.Specials {
		specialsArr[idx] = k
		idx++
	}
	encoder.SpecialsTree = CreateRuneTree(specialsArr)
}

// makeByteTranslationTables creates lookup tables for interconverting
// between runes in decoded token strings and the UTF-8 byte sequences
// that they encode.
func makeByteTranslationTables() ([256]rune, map[rune]byte) {
	// GPT2's BPE implementation reinterprets UTF-8-encoded bytes as
	// Unicode codepoints, but remaps the 68 code points
	// corresponding to control, format, and space-separator characters
	// (i.e. Unicode character categories Cc, Cf, and Zs)
	// in the range [0, 255] to sequential codepoints in [256, 323],
	// which happens to contain no characters from those three categories.
	// For example, the byte \x00 is mapped to codepoint 256, and the final
	// affected byte \xAD is mapped to codepoint 323.
	// The remapped bytes are sequential even though the original bytes
	// are not. The original bytes' codepoint interpretations all fall
	// in the following ranges:
	//  - [\x00, \x20] ('NUL' to 'SPACE'; up to right before '!'),
	//  - [\x7F, \xA0] ('DELETE' to 'NO-BREAK SPACE'; between '~' and '¡')
	//  - \xAD exactly ('SOFT HYPHEN')
	// Refer to "src/encoder.py" in the openai/gpt-2 repository for
	// more detail.

	byteDecoderMap := make(map[rune]byte, 256)
	var byteEncoderLUT [256]rune

	for i, relocated := rune(0), rune(256); i < 256; i++ {
		relocatedByte := i
		// Bytes in the three ranges above get the next sequential
		// codepoint starting at 256; all others map to themselves.
		if i < '!' || i > '~' && i < '¡' || i == '\xAD' {
			relocatedByte = relocated
			relocated++
		}
		byteEncoderLUT[i] = relocatedByte
		byteDecoderMap[relocatedByte] = byte(i)
	}

	return byteEncoderLUT, byteDecoderMap
}

// makeUnitrimArr creates a lookup table for trimming token sequences
// to valid UTF-8 boundaries. It replaces unitrim.json files generated
// in advance.
func makeUnitrimArr(encoderMap map[string]Token) []int {
	// In order to check how many UTF-8 continuation bytes are missing from
	// each individual token, the decoded token strings need to be translated
	// to UTF-8.
	_, byteDecoderMap := makeByteTranslationTables()

	// This function returns the following LUT, representing either
	// how many continuation bytes are needed following a given token,
	// or how many continuation bytes a given token fulfills.
	// Positive entries require that many more continuation bytes to follow;
	// negative entries fulfill that many continuation bytes.
	debtLUT := make([]int, len(encoderMap))

	// Continuation byte requirements are defined by the UTF-8 standard
	// and can be determined from bit patterns of each byte. We make a
	// LUT of bit patterns to make this calculation faster.
	// Only the 5 most significant bits are relevant.
	var byteDebtLUT [32]int8
	for b := 0; b <= 0b11110; b++ {
		// According to UTF-8 variable-length binary encoding:
		if (b & 0b10000) == 0 {
			// All 7-bit ASCII characters have the bit pattern 0xxxxxxx
			//  - They are self-contained, and require no continuation
			//  - They are the only characters encoded with a single byte
			byteDebtLUT[b] = 0
		} else if (b & 0b11100) == 0b11000 {
			// All 2-byte characters start with a 110xxxxx byte
			//  - These add +1 continuation byte debt
			byteDebtLUT[b] = 1
		} else if (b & 0b11110) == 0b11100 {
			// All 3-byte characters start with a 1110xxxx byte
			//  - These add +2 continuation byte debt
			byteDebtLUT[b] = 2
		} else if (b & 0b11110) == 0b11110 {
			// All 4-byte characters start with a 11110xxx byte
			//  - These add +3 continuation byte debt
			//  - No valid Unicode starts with 11111xxx, so the last
			//    0 should be redundant, but some tokenizers include
			//    such bytes in their vocabularies regardless.
			byteDebtLUT[b] = 3
		} else if (b & 0b11000) == 0b10000 {
			// All continuation characters start with a 10xxxxxx byte
			//- These satisfy (-) 1 continuation byte debt
			byteDebtLUT[b] = -1
		}
	}
	// NOTE(review): the loop stops at 0b11110, so byteDebtLUT[0b11111]
	// keeps the zero value — bytes 0xF8-0xFF are treated as debt-free.

	// Calculate the debtLUT entries for each token ID
	for decodedToken, token := range encoderMap {
		tokenDebt := 0
		minTokenDebt := 0

		// Decode each Unicode codepoint into a UTF-8 byte
		codepoints := []rune(decodedToken)
		utf8Bytes := make([]byte, len(codepoints))
		for i, c := range codepoints {
			utf8Bytes[i] = byteDecoderMap[c]
		}

		// Keep track of continuation byte requirements
		// between each UTF-8 byte.
		for _, b := range utf8Bytes {
			b >>= 3 // trim to relevant bits
			byteDebt := int(byteDebtLUT[b])
			if byteDebt < 0 {
				// Continuation bytes are tracked relative to the bytes
				// preceding them
				tokenDebt += byteDebt
			} else {
				// Starting bytes have no relation to bytes preceding them
				tokenDebt = byteDebt
			}

			if tokenDebt < 0 {
				minTokenDebt = tokenDebt
			} else if tokenDebt == 0 {
				// If the beginning of the string satisfies continuation
				// byte debt, don't forget that just to track less-important
				// information about self-contained byte sequences that follow.
				// Do overwrite it if it ends with fresh debt.
				// NB: if a token both satisfies continuation byte debt
				// and then begins new debt, only the latter can be tracked.
				// This is a limitation of the LUT entries being single
				// integers rather than pairs of integers.
				tokenDebt = minTokenDebt
			}
		}
		debtLUT[token] = tokenDebt
	}

	return debtLUT
}

// PreallocBGERanks is a fixed-capacity, sorted (by rank) collection of
// BGERank values that never reallocates after construction.
type PreallocBGERanks struct {
	data []BGERank
	len  int
}

// NewPreallocBGERanks returns an empty PreallocBGERanks with room for
// capacity entries.
func NewPreallocBGERanks(capacity int) *PreallocBGERanks {
	return &PreallocBGERanks{
		data: make([]BGERank, capacity),
		len:  0,
	}
}

// InsertSorted inserts v in rank order, skipping exact duplicates.
// NOTE(review): when the buffer is full the value is silently dropped
// (see the inline comment) — callers must size capacity accordingly.
func (p *PreallocBGERanks) InsertSorted(v BGERank) {
	// Binary search
	i := sort.Search(
		p.len, func(i int) bool {
			return p.data[i].rank >= v.rank
		},
	)

	// Check for exact duplicate using full BGERank comparison
	if i < p.len && p.data[i].rank == v.rank && p.data[i].bigram == v.bigram {
		return
	}

	// Ensure we have space
	if p.len >= len(p.data) {
		return // or could panic/grow if needed
	}

	// Shift and insert
	if i < p.len {
		copy(p.data[i+1:p.len+1], p.data[i:p.len])
	}
	p.data[i] = v
	p.len++
}

// findBestPair scans the adjacent bigrams of word and returns the one
// with the lowest BPE rank; found is false when no bigram appears in
// bpeRanks at all.
func findBestPair(word []string, bpeRanks map[GPTPair]float64) (
	BGERank,
	bool,
) {
	var bestRank BGERank
	bestRank.rank = math.Inf(1)
	found := false

	prev := word[0]
	pair := GPTPair{}
	wordLen := len(word) // Calculate once

	for idx := 1; idx < wordLen; idx++ {
		present := word[idx]
		pair.Left = prev
		pair.Right = present
		if rank, ok := bpeRanks[pair]; ok {
			if rank < bestRank.rank {
				bestRank = BGERank{rank, pair}
				found = true
			}
		}
		prev = present
	}
	return bestRank, found
}

// Standard version with proper duplicate checking
// insertSortedNoDups inserts v into the rank-sorted slice data,
// returning the (possibly reallocated) slice; exact duplicates are
// ignored.
func insertSortedNoDups(data BGERanks, v BGERank) BGERanks {
	// Fast path: append to end if it's greater than all existing elements
	if len(data) == 0 || data[len(data)-1].rank < v.rank {
		return append(data, v)
	}

	i := sort.Search(
		len(data), func(i int) bool {
			return data[i].rank >= v.rank
		},
	)

	// Check for exact duplicate using full BGERank comparison
	if i < len(data) && data[i].rank == v.rank && data[i].bigram == v.bigram {
		return data
	}

	// Use optimized insertAt
	if len(data) == cap(data) {
		// Grow slice with extra space
		newCap := cap(data) * 2
		if newCap == 0 {
			newCap = 4
		}
		newData := make([]BGERank, len(data), newCap)
		copy(newData, data)
		data = newData
	}

	// Extend length by one
	data = data[:len(data)+1]

	// Shift elements in a single operation
	if i < len(data)-1 {
		copy(data[i+1:], data[i:len(data)-1])
	}

	// Insert new element
	data[i] = v
	return data
}

// getPairs returns the distinct adjacent bigrams of word, in first-seen
// order.
func getPairs(word []string) []GPTPair {
	pairsSet := make(map[GPTPair]bool, len(word))
	pairs := make([]GPTPair, len(word))
	begin := 1
	prev := word[0]
	ct := 0
	for idx := begin; idx < len(word); idx++ {
		present := word[idx]
		pair := GPTPair{prev, present}
		if _, ok := pairsSet[pair]; !ok {
			// len(pairsSet) is the next free slot: the set only grows
			// on the line below, after this write.
			pairs[len(pairsSet)] = pair
			ct++
		}
		pairsSet[pair] = true
		prev = present
	}
	return pairs[0:ct]
}

// getRankedPairs
// Accepts a slice of strings and returns a slice of BGERanks, sorted by
// their rank.
func (encoder *GPTEncoder) getRankedPairs(word []string) BGERanks {
	rankedPairs := make(BGERanks, 0, len(word))
	begin := 1
	prev := word[0]
	for idx := begin; idx < len(word); idx++ {
		present := word[idx]
		pair := GPTPair{prev, present}
		bpe, ok := encoder.BpeRanks[pair]
		if !ok {
			// Unknown pairs sort last.
			bpe = math.Inf(1)
		}
		rankedPairs = insertSortedNoDups(
			rankedPairs,
			BGERank{bpe, pair},
		)
		prev = present
	}
	return rankedPairs
}

// rankPairs
// Accepts a slice of GPTPair and returns a slice of BGERanks, sorted by
// their rank.
func (encoder *GPTEncoder) rankPairs(pairs []GPTPair) BGERanks {
	rankedPairs := make(BGERanks, 0)
	for idx := range pairs {
		bpe, ok := encoder.BpeRanks[pairs[idx]]
		if !ok {
			// Unknown pairs sort last.
			bpe = math.Inf(1)
		}
		rankedPairs = insertSortedNoDups(
			rankedPairs,
			BGERank{bpe, pairs[idx]},
		)
	}
	sort.Sort(rankedPairs)
	return rankedPairs
}

// minPair
// Accepts a slice of GPTPair and returns the pair with the lowest BPE rank.
func (encoder *GPTEncoder) minPair(pairs []GPTPair) (retPair GPTPair) {
	rankedPairs := encoder.rankPairs(pairs)
	if len(rankedPairs) > 0 {
		retPair = rankedPairs[0].bigram
	}
	// Returns the zero-value GPTPair when pairs is empty.
	return retPair
}

// pos finds the index of the first occurrence of seek in word past index i.
// Returns -1 when seek does not occur at or after i.
func pos(word []string, seek string, i int) int {
	for j, v := range word[i:] {
		if seek == v {
			return j + i
		}
	}
	return -1
}

// findAllStringIndex returns a set of indexes of all occurrences of substr in
// string. Each entry is a [start, end) byte-offset pair; matches are
// non-overlapping because the scan resumes past each match.
func findAllStringIndex(text string, substr string) [][]int {
	var indexes [][]int
	for i := 0; i < len(text); {
		j := strings.Index(text[i:], substr)
		if j < 0 {
			break
		}
		indexes = append(indexes, []int{i + j, i + j + len(substr)})
		i += j + len(substr)
	}
	return indexes
}

// findAllStringsIndexes returns a set of indexes of all occurrences of strings,
// which are substrings of text removing all overlaps.
// NOTE(review): the code simply concatenates per-substring results;
// overlaps BETWEEN different substrings are not removed here — confirm
// whether callers rely on that.
func findAllStringsIndexes(text string, strings []string) [][]int {
	var indexes [][]int
	for _, substr := range strings {
		indexes = append(indexes, findAllStringIndex(text, substr)...)
	}
	return indexes
}

// wordBufferPool recycles pairs of scratch string slices for ToBPE so
// the merge loop allocates nothing on the hot path.
var wordBufferPool = sync.Pool{
	New: func() interface{} {
		s1 := make([]string, 0, 256)
		s2 := make([]string, 0, 256)
		return &[2][]string{s1, s2} // Return pair of buffers
	},
}

// ToBPE
// Given pre-split text, perform bigram ranking and merges, and returns Tokens
// Add at package level - reusable buffers for common operations
func (encoder *GPTEncoder) ToBPE(text string) Tokens {
	// Serve repeated words straight from the ARC cache.
	if lookup, ok := encoder.Cache.Get(text); ok {
		encoder.LruHits++
		return lookup.(Tokens)
	}
	encoder.LruMisses++

	// Early return for ignoreMerges case
	if encoder.ignoreMerges {
		if token, ok := encoder.Encoder[text]; ok {
			encoder.Cache.Add(text, Tokens{token})
			return Tokens{token}
		}
	}

	// Get word buffer from pool
	bufsPtr := wordBufferPool.Get().(*[2][]string)
	word := (*bufsPtr)[0][:0]
	newWord := (*bufsPtr)[1][:0]
	defer wordBufferPool.Put(bufsPtr)

	// Split into single-character symbols; strings.Split with "" yields
	// one element per rune.
	word = append(word, strings.Split(text, "")...)
	if len(word) > 0 {
		// Mark the word boundary (e.g. CLIP's "</w>" suffix).
		word[len(word)-1] = word[len(word)-1] + encoder.endOfWord
	}

	// Single character optimization
	if len(word) == 1 {
		var tokens Tokens
		if token, ok := encoder.Encoder[word[0]]; ok {
			tokens = Tokens{token}
		} else if encoder.BytesEncoder != nil {
			// Fall back to emitting one byte token per UTF-8 byte.
			tokens = make(Tokens, 0, len(word[0]))
			runeBytes := []byte(word[0])
			for _, b := range runeBytes {
				tokens = append(tokens, (*encoder.BytesEncoder)[b])
			}
		} else {
			// NOTE(review): missing vocab entry with no byte fallback
			// yields the zero Token here — confirm that is intended.
			tokens = Tokens{encoder.Encoder[word[0]]}
		}
		encoder.Cache.Add(text, tokens)
		return tokens
	}

	// Main merge loop using findBestPair
	for {
		bestRank, found := findBestPair(word, encoder.BpeRanks)
		if !found {
			break
		}

		// Reset newWord for reuse
		newWord = newWord[:0]
		first := bestRank.bigram.Left
		second := bestRank.bigram.Right

		// Rebuild the word, merging every occurrence of (first, second).
		for i := 0; i < len(word); {
			j := pos(word, first, i)
			if j == -1 {
				newWord = append(newWord, word[i:]...)
				break
			}
			newWord = append(newWord, word[i:j]...)
			i = j

			if word[i] == first && i < len(word)-1 && word[i+1] == second {
				newWord = append(newWord, first+second)
				i += 2
			} else {
				newWord = append(newWord, word[i])
				i += 1
			}
		}

		// Swap buffers rather than copying.
		word, newWord = newWord, word

		if len(word) == 1 {
			break
		}
	}

	// Final encoding
	tokens := make(Tokens, 0, len(word))
	for _, token := range word {
		if lookup, ok := encoder.Encoder[token]; ok {
			tokens = append(tokens, lookup)
		} else if encoder.BytesEncoder != nil {
			// Unknown symbol: emit its bytes via the byte-fallback table.
			runeBytes := []byte(token)
			for _, b := range runeBytes {
				tokens = append(tokens, (*encoder.BytesEncoder)[b])
			}
		}
	}

	encoder.Cache.Add(text, tokens)
	return tokens
}

// getSpecials groups the special-token strings by length for matching.
// NOTE(review): the map key is len(k) — the BYTE length — while the
// values are rune slices; for non-ASCII specials these differ. Confirm
// callers expect byte lengths here.
func (encoder *GPTEncoder) getSpecials() map[int][][]rune {
	lenMap := make(map[int][][]rune)
	for k := range encoder.Specials {
		keyLen := len(k)
		keyRunes := []rune(k)
		if entry, ok := lenMap[keyLen]; ok {
			lenMap[keyLen] = append(entry, keyRunes)
		} else {
			lenMap[keyLen] = [][]rune{keyRunes}
		}
	}
	return lenMap
}

// splitWords applies replacements and normalization to text, splits it
// into words with the encoder's split regex, and (optionally) appends
// the special token held by specialsNode as the final word.
func (encoder *GPTEncoder) splitWords(
	text string,
	specialToken bool, specialsNode *RuneNode,
) []*string {
	// Some things such as KoboldAI have a 'replacement' rule, where
	// they replace tokens such as `\n` with `</s>` for Fairseq
	// handling.
	for replaced, replacement := range encoder.replacements {
		text = strings.ReplaceAll(text, replaced, replacement)
	}
	text = encoder.Normalizer.Replace(text)

	idxes := encoder.pattern.FindAllStringIndex(text, -1)
	words := make([]*string, 0, len(idxes)+1)
	for idx := range idxes {
		word := text[idxes[idx][0]:idxes[idx][1]]
		if encoder.lowerCase {
			word = strings.ToLower(word)
		}

		if !encoder.prefixSpace {
			word = strings.TrimSpace(word)
		}

		if len(word) > 0 {
			words = append(words, &word)
		}
	}

	// Finally, if we have a special token, we cap it off.
	if specialToken {
		runeString := string(specialsNode.runes)
		words = append(words, &runeString)
	}
	return words
}

// NextRuneFunc supplies the next rune of the input stream, mirroring
// the (rune, size, error) shape of bufio.Reader.ReadRune.
type NextRuneFunc func() (rune, int, error)

// WordCallback receives each batch of split words; a nil slice marks
// end of input.
type WordCallback func([]string)

// makeWordSplitter builds and returns a closure that streams runes
// from nextRuneFunc, splits them into words with the encoder's regex
// tree, and delivers word batches to wordCallback via a consumer
// goroutine.
func (encoder *GPTEncoder) makeWordSplitter(
	nextRuneFunc NextRuneFunc,
	wordCallback WordCallback,
	completeCallback func(),
) func() {
	// Lazily compile the split regex into the custom regex tree used
	// for streaming evaluation.
	if encoder.regexWordSplitterTree == nil {
		regexString := encoder.pattern.String()
		if regexString == "" {
			regexString = SPLIT_REGEX
		}
		regexAST, err := syntax.Parse(regexString, syntax.Perl)
		if err != nil {
			panic(err)
		}
		regexAST.Simplify()
		encoder.regexWordSplitterTree = CreateRegexTree(regexAST)
		encoder.wordSplitterMap = encoder.regexWordSplitterTree.GeneratePathMap()
	}

	// How many words we send on each callback.
1066 const batchSize = 256 1067 workQueue := make(chan []string, encoder.SplitterThreads*2) 1068 wg := sync.WaitGroup{} 1069 wg.Add(1) 1070 1071 // Single consumer goroutine that processes batches 1072 go func() { 1073 defer wg.Done() 1074 for batch := range workQueue { 1075 wordCallback(batch) 1076 } 1077 }() 1078 1079 return func() { 1080 specialsRuneRoot := encoder.SpecialsTree 1081 runeAccumulator := make([]rune, 0, encoder.runeBufSz) 1082 wordBatch := make([]string, 0, batchSize) 1083 specialToken := false 1084 specialsCandidates := make(RuneNodes, 0, 16) 1085 var candidateNode *RuneNode 1086 1087 // Define a function to flush the batch once it is full 1088 flushBatch := func() { 1089 if len(wordBatch) > 0 { 1090 // Copy the batch to prevent race conditions 1091 batch := make([]string, len(wordBatch)) 1092 copy(batch, wordBatch) 1093 workQueue <- batch 1094 wordBatch = wordBatch[:0] 1095 } 1096 } 1097 1098 // appendBatch appends a batch of words to the wordBatch and flushes 1099 // the batch if it is full. 1100 appendBatch := func(words []string, forceFlush bool) { 1101 if len(words) == 0 && (!forceFlush || len(wordBatch) == 0) { 1102 return 1103 } 1104 // If we are appending words, we need to process them 1105 for _, word := range words { 1106 if encoder.lowerCase { 1107 word = strings.ToLower(word) 1108 } 1109 if !encoder.prefixSpace { 1110 word = strings.TrimSpace(word) 1111 } 1112 1113 // After every word, we append it to the wordBatch 1114 // We also check if the wordBatch is full and flush it 1115 if len(word) > 0 { 1116 wordBatch = append(wordBatch, word) 1117 if len(wordBatch) >= batchSize { 1118 flushBatch() 1119 } 1120 } 1121 } 1122 // forceFlush forces the batch to be flushed. 1123 // Useful for ensuring that the last batch is flushed. 
			if forceFlush && len(wordBatch) > 0 {
				// If we are forcing a flush, we flush the batch after
				// processing the words.
				// NOTE(review): words already in wordBatch were
				// lowercased/trimmed when appended, so this second pass
				// is idempotent (redundant but harmless).
				for i, word := range wordBatch {
					if encoder.lowerCase {
						word = strings.ToLower(word)
					}
					if !encoder.prefixSpace {
						word = strings.TrimSpace(word)
					}
					wordBatch[i] = word
				}

				flushBatch()
			}
		}

		// processLine splits one accumulated line into words via the
		// regex tree and appends them to the current batch, re-adding a
		// trailing special token when one terminated the line.
		processLine := func(
			line []rune,
			special bool,
			node *RuneNode,
		) {
			// Find all words by using the regexWordSplitterTree
			matches := encoder.regexWordSplitterTree.EvaluateRegexTree(
				line, encoder.wordSplitterMap,
			)
			for _, word := range matches {
				if encoder.lowerCase {
					word = strings.ToLower(word)
				}
				if !encoder.prefixSpace {
					word = strings.TrimSpace(word)
				}
				appendBatch([]string{word}, false)
			}

			// Re-add the special token if it was removed
			// This is done after the regex splitting to ensure that the special
			// token is not split by the regex
			if special && node != nil {
				special := string(node.runes)
				appendBatch([]string{special}, false)
			}
		}

		// Apply replacements defined in the runetree:
		// checkAndReplaceNode rewrites the tail of runeAccumulator with
		// the candidate node's replacement text and resets the
		// special-token matching state.
		checkAndReplaceNode := func() {
			matchLen := len(candidateNode.runes)
			accTruncIdx := len(runeAccumulator) - matchLen
			runeAccumulator = append(
				runeAccumulator[:accTruncIdx],
				*candidateNode.replacement...,
			)
			specialsCandidates = specialsCandidates[:0]
			candidateNode = specialsRuneRoot
			specialToken = false
		}
		// We repeatedly call the nextRuneFunc until it returns an error or other break
		// condition. This fills the runeAccumulator with runes until we have a full line.
		// Outer loop: one iteration per accumulated line (or special
		// token). Terminates when the rune source is exhausted.
		for {
			// Collect runes until newline or special token
			for {
				r, size, err := nextRuneFunc()
				if size == 0 || err != nil {
					break
				}

				runeAccumulator = append(runeAccumulator, r)

				if r == '\n' {
					break
				}

				// Conduct replacement and special token checks:
				// first advance any in-flight partial matches...
				candidateNode = specialsCandidates.evaluate(r)
				if candidateNode != nil {
					if candidateNode.replacement != nil {
						checkAndReplaceNode()
					} else if candidateNode.terminal {
						specialToken = true
						break
					}
				}

				// ...then try to start a fresh match at this rune.
				candidateNode = specialsRuneRoot.evaluate(r)
				if candidateNode != nil {
					specialsCandidates = append(
						specialsCandidates,
						candidateNode,
					)
					if candidateNode.replacement != nil {
						checkAndReplaceNode()
					} else if candidateNode.terminal {
						specialToken = true
						break
					}
				}
			}

			// If we have no runes, we are done: force out the final
			// partial batch and signal the consumer.
			// NOTE(review): wordCallback(nil) is invoked directly here
			// while queued batches are still being drained by the
			// consumer goroutine — it may run before them; confirm
			// consumers treat an empty batch as a no-op.
			if len(runeAccumulator) == 0 {
				appendBatch(nil, true)
				wordCallback(nil)
				break
			}

			// Apply replacements and normalization. A terminal special
			// token was appended to the accumulator during collection;
			// strip it so the regex never sees it (processLine re-adds
			// it afterwards).
			if specialToken && candidateNode != nil {
				runeAccumulator =
					runeAccumulator[:len(runeAccumulator)-len(
						candidateNode.runes,
					)]
			}
			if len(encoder.replacements) > 0 {
				runeAccumulator = replaceRunes(
					runeAccumulator, encoder.replacements,
				)
			}

			if encoder.Normalizer != nil {
				if encoder.normalizerStringMap != nil && len(encoder.normalizerStringMap) > 0 {
					runeAccumulator = replaceRunes(
						runeAccumulator, encoder.normalizerStringMap,
					)
				}
			}
			// If we don't recognize the regex, we default to using the regex package
			processLine(
				runeAccumulator, specialToken,
				candidateNode,
			)
			// Reset per-line state for the next iteration.
			runeAccumulator = runeAccumulator[:0]
			candidateNode = specialsRuneRoot
			specialToken = false
			specialsCandidates = specialsCandidates[:0]
		}

		// Close the work queue and wait
		// for all workers to finish before signaling completion.
		close(workQueue)
		wg.Wait()
		completeCallback()
	}
}

// WordSplitter wraps makeWordSplitter around an io.RuneReader and returns
// a pull-style iterator: each call yields the next word, or nil once the
// input is exhausted. The splitter itself runs in a goroutine and feeds
// batches through the moreWords channel.
func (encoder *GPTEncoder) WordSplitter(reader io.RuneReader) func() *string {
	moreWords := make(chan []string, encoder.wordChanSz)
	wordSplitter := encoder.makeWordSplitter(
		reader.ReadRune,
		func(words []string) {
			if len(words) > 0 {
				moreWords <- words
			}
		},
		func() {
			close(moreWords)
		},
	)
	go wordSplitter()

	var wordsBuffer []string
	// idx is 1-based: wordsBuffer[idx-1] is the word the current call
	// returns.
	idx := 1

	return func() *string {
		var more bool
		if idx >= len(wordsBuffer) {
			// Current buffer exhausted; block for the next batch. A
			// closed channel (more == false) marks end of input.
			wordsBuffer, more = <-moreWords
			if !more {
				return nil
			}
			idx = 1
		} else {
			idx++
		}
		word := wordsBuffer[idx-1]
		return &word
	}
}

// Helper functions

// trimSpacesRunes trims leading and trailing Unicode whitespace from a
// slice of runes and returns the trimmed sub-slice (no copy).
func trimSpacesRunes(runes []rune) []rune {
	start := 0
	end := len(runes)
	for start < end && unicode.IsSpace(runes[start]) {
		start++
	}
	for end > start && unicode.IsSpace(runes[end-1]) {
		end--
	}
	return runes[start:end]
}

func toLowercaseRunes(runes []rune) []rune {
	// toLowercaseRunes lowercases the slice in place and returns it.
1320 for i := 0; i < len(runes); i++ { 1321 runes[i] = unicode.ToLower(runes[i]) 1322 } 1323 return runes 1324 } 1325 1326 func replaceRunes( 1327 runes []rune, 1328 replacements map[string]string, 1329 ) []rune { 1330 runeReplacements := make(map[string][]rune, len(replacements)) 1331 for k, v := range replacements { 1332 runeReplacements[k] = []rune(v) 1333 } 1334 1335 // Iterate through runes 1336 for i := 0; i < len(runes); i++ { 1337 matchFound := false 1338 1339 // Iterate through replacements 1340 for k, v := range runeReplacements { 1341 if len(v) == 0 { 1342 continue 1343 } 1344 if runes[i] == []rune(k)[0] { 1345 matchFound = true 1346 if len(v) > 1 { 1347 // Try to get a slice of the runes to match the key, if it matches, replace it 1348 keySlice := runes[i : i+len(k)] 1349 for j := 0; j < len(keySlice); j++ { 1350 if keySlice[j] != []rune(k)[j] { 1351 matchFound = false 1352 break 1353 } 1354 } 1355 if matchFound { 1356 runes = append(runes[:i], []rune(v)...) 1357 } 1358 1359 } else { 1360 runes[i] = v[0] 1361 } 1362 } 1363 } 1364 if !matchFound { 1365 continue 1366 } 1367 } 1368 return runes 1369 } 1370 1371 // Excludes new line whitespaces. Thus is horizontal whitespace. 1372 func isHorizontalWhitespace(r rune) bool { 1373 return r == ' ' || r == '\t' || r == '\r' 1374 } 1375 1376 func isSymbol(r rune) bool { 1377 return !unicode.IsLetter(r) && !unicode.IsNumber(r) && !isHorizontalWhitespace(r) && !isNewLine(r) 1378 } 1379 1380 func isNewLine(r rune) bool { 1381 // While \n is often considered a whitespace, we treat it as a symbol 1382 // to ensure it is always a separate token. 
1383 return r == '\n' 1384 } 1385 1386 func (encoder *GPTEncoder) SplitWords(text *string) *[]string { 1387 words := make([]string, 0) 1388 nextWord := encoder.WordSplitter(strings.NewReader(*text)) 1389 for { 1390 word := nextWord() 1391 if word == nil { 1392 break 1393 } 1394 words = append(words, *word) 1395 } 1396 return &words 1397 } 1398 1399 func (encoder *GPTEncoder) toUnicode(text *string) string { 1400 if encoder.BytesEncoder != nil { 1401 runes := []rune(*text) 1402 return string(runes) 1403 } 1404 textBytes := []byte(*text) 1405 outArr := make([]rune, len(*text)) 1406 for idx := range textBytes { 1407 outArr[idx] = encoder.byteToRune[textBytes[idx]] 1408 } 1409 return string(outArr) 1410 } 1411 1412 func (encoder *GPTEncoder) encodeTokens(tokens *[]string) (encoded Tokens) { 1413 encoded = make(Tokens, len(*tokens)) 1414 for idx := range *tokens { 1415 encoded[idx] = encoder.Encoder[(*tokens)[idx]] 1416 } 1417 return encoded 1418 } 1419 1420 var tokenAccumulatorPool = sync.Pool{ 1421 New: func() interface{} { 1422 // Size based on typical artifact size - adjust if needed 1423 tokens := make(Tokens, 0, 65536) 1424 return &tokens 1425 }, 1426 } 1427 1428 func (encoder *GPTEncoder) StreamingEncode(reader io.RuneReader) func(int) *Tokens { 1429 nextWord := encoder.WordSplitter(reader) 1430 1431 // Get accumulator from pool 1432 accumulatorPtr := tokenAccumulatorPool.Get().(*Tokens) 1433 accumulator := (*accumulatorPtr)[:0] // Reset length but keep capacity 1434 1435 if encoder.encloseEosBos || encoder.encloseBos { 1436 accumulator = append(accumulator, encoder.BosToken) 1437 } 1438 1439 eosReturned := false 1440 1441 return func(desiredTokens int) *Tokens { 1442 for { 1443 if len(accumulator) >= desiredTokens { 1444 chunk := make(Tokens, desiredTokens) 1445 copy(chunk, accumulator[:desiredTokens]) 1446 1447 // Preserve capacity while shifting remaining tokens 1448 copy(accumulator, accumulator[desiredTokens:]) 1449 accumulator = 
					accumulator[:len(accumulator)-desiredTokens]
				return &chunk
			}

			word := nextWord()
			if word == nil {
				// Input exhausted: append EOS once if configured.
				if (encoder.encloseEosBos || encoder.encloseEos) && !eosReturned {
					accumulator = append(
						accumulator, encoder.EosToken,
					)
					eosReturned = true
				}

				// Drain whatever remains as a final short chunk.
				if len(accumulator) > 0 {
					chunk := make(Tokens, len(accumulator))
					copy(chunk, accumulator)
					accumulator = accumulator[:0]
					return &chunk
				}

				// Return accumulator to pool when done
				tokenAccumulatorPool.Put(accumulatorPtr)
				return nil
			}

			// Specials encode as their single dedicated token; everything
			// else goes through unicode mapping and BPE.
			var encodedTokens Tokens
			if specialToken, isSpecial := encoder.Specials[*word]; isSpecial {
				encodedTokens = Tokens{
					encoder.Encoder[string(encoder.Decoder[specialToken[0]])],
				}
			} else {
				fragment := encoder.toUnicode(word)
				encodedTokens = encoder.ToBPE(fragment)
			}
			accumulator = append(accumulator, encodedTokens...)

			if encoder.ignoreMerges {
				continue
			}

			// Merge pass: starting just before the newly appended
			// tokens, collapse any adjacent pair present in TokenMerges.
			// After a merge, step back one position since the merged
			// token may itself merge with its left neighbor.
			if offsetIdx := len(accumulator) - len(encodedTokens) - 1; offsetIdx >= 0 {
				idx := offsetIdx
				for idx < len(accumulator)-1 {
					pair := TokenPair{accumulator[idx], accumulator[idx+1]}
					if merged, ok := encoder.TokenMerges[pair]; ok && merged != 0 {
						before := accumulator[:idx]
						var after Tokens
						if idx+2 < len(accumulator) {
							after = accumulator[idx+2:]
						}
						accumulator = append(before, merged)
						accumulator = append(accumulator, after...)
						if idx > 0 {
							idx--
						}
					} else {
						idx++
					}
				}
			}
		}
	}
}

// EncodeReader encodes the full contents of a RuneReader into Tokens,
// pulling 4096 tokens per StreamingEncode call.
func (encoder *GPTEncoder) EncodeReader(reader io.RuneReader) *Tokens {
	encoded := make(Tokens, 0, 4096)
	nextTokens := encoder.StreamingEncode(reader)
	for {
		tokens := nextTokens(4096)
		if tokens == nil {
			break
		}
		encoded = append(encoded, *tokens...)
1522 } 1523 return &encoded 1524 } 1525 1526 // EncodeBuffer takes a byte array and encodes it into Tokens in another 1527 // byte array. 1528 func (encoder *GPTEncoder) EncodeBuffer(buffer *[]byte) ( 1529 *[]byte, uint64, 1530 ) { 1531 runeReader := bytes.NewReader(*buffer) 1532 nextTokens := encoder.StreamingEncode(runeReader) 1533 buf := bytes.NewBuffer(make([]byte, 0, 4096)) 1534 var count uint64 = 0 1535 for { 1536 tokens := nextTokens(2048) 1537 if tokens == nil { 1538 break 1539 } 1540 _ = binary.Write(buf, binary.LittleEndian, tokens) 1541 count += uint64(len(*tokens)) 1542 } 1543 bufBytes := buf.Bytes() 1544 return &bufBytes, count 1545 } 1546 1547 // Encode encodes a string into a sequence of tokens. 1548 func (encoder *GPTEncoder) Encode(text *string) *Tokens { 1549 // Temporary hack - inject a space token at the end of the accumulator for mistral-tokenizer 1550 if encoder.VocabId == VOCAB_ID_MISTRAL { 1551 *text = " " + *text 1552 } 1553 runeReader := strings.NewReader(*text) 1554 1555 return encoder.EncodeReader(runeReader) 1556 } 1557 1558 // Get 1559 // Looks up text in the Encoder, and returns the Token representation of it. If 1560 // the text is not found, then nil is returned. 1561 func (encoder *GPTEncoder) Get(text string) *Token { 1562 if token, ok := encoder.Encoder[text]; !ok { 1563 return nil 1564 } else { 1565 return &token 1566 } 1567 } 1568 1569 // Decode Tokens back into a string, handling unicode. 1570 func (encoder *GPTEncoder) Decode(encoded *Tokens) (text string) { 1571 // Check if we have an end of word token defined. 1572 convertEndOfWord := false 1573 if encoder.endOfWord != "" { 1574 convertEndOfWord = true 1575 } 1576 // Accumulate tokens until it is unicode complete. 
	tokensAcc := make(Tokens, 0)
	runesAcc := make([]rune, 0)
	for i, token := range *encoded {
		tokensAcc = append(tokensAcc, token)
		bs := make([]byte, 0)
		// If we have a byte token and a byteEncoder, then we need to
		// accumulate until we have a full rune. If we are at the end of
		// the encoded tokens, then we need to decode the accumulated
		// tokens regardless.
		flagHoldForByte := encoder.IsByteToken(&token) &&
			encoder.IsLastTokenByte(&tokensAcc)

		if encoder.TokensReady(&tokensAcc) && (i == len(*encoded)-1 || !flagHoldForByte) {
			// Concatenate the byte representations of all pending
			// tokens; unknown tokens are silently skipped.
			for _, safeToken := range tokensAcc {
				if v, ok := encoder.Decoder[safeToken]; ok {
					bs = append(bs, v...)
				}
			}
			// Convert our bytearray to string, interpreting as UTF-8 and
			// then to 32-bit runes. If we don't have a BytesEncoder, then we
			// are using GPT BPE's byte encoding algorithm for Unicode.
			var runes = []rune(string(bs))
			var fragment string
			if encoder.BytesEncoder == nil {
				decoded := make([]byte, len(runes))
				// Convert our runes into 8-bit bytes using a 256-slot table.
				for runeIdx := range runes {
					decoded[runeIdx] = encoder.runeToByte[runes[runeIdx]]
				}
				fragment = string(decoded)
				runes = []rune(fragment)
			} else {
				fragment = string(bs)
				runes = []rune(fragment)
			}
			// Decode our final token representation into a Unicode string,
			// applying end-of-word cleanup when configured.
			if convertEndOfWord {
				if strings.HasSuffix(fragment, encoder.endOfWord) {
					// NOTE(review): this trims len(endOfWord) BYTES off a
					// RUNE slice; correct only while endOfWord is pure
					// ASCII — confirm against the tokenizer configs.
					runes = runes[:len(runes)-len(encoder.endOfWord)]
					// A lone apostrophe gets no trailing space; any
					// other end-of-word match appends one.
					if len(runes) == 1 && runes[0] == '\'' {
					} else {
						runes = append(runes, ' ')
					}
				}
				if len(runes) == 1 &&
					unicode.IsNumber(runes[0]) {
					runes = append(runes, ' ')
				}
				// If we have a punctuation rune, and the previous rune is a
				// space, then we remove the space. This is to handle cases
				// like " ,".
				if len(runesAcc) > 1 && runeIsIn(
					runes[0],
					encoder.PuncRunes,
				) && unicode.IsSpace(
					runesAcc[len(
						runesAcc,
					)-1],
				) {
					runesAcc = runesAcc[:len(runesAcc)-1]
				}
			}
			// Commit the decoded fragment and reset the pending tokens.
			runesAcc = append(runesAcc, runes...)
			tokensAcc = tokensAcc[:0]
		}
	}

	return string(runesAcc)
}

// DecodeBuffer
// Decode Tokens from a byte array into a string. useUint32 selects
// 32-bit token width when deserializing; otherwise the default width
// from types.TokensFromBin is used.
func (encoder *GPTEncoder) DecodeBuffer(
	encoded *[]byte,
	useUint32 bool,
) (text string) {
	// First convert our bytearray into uint32 `Token` array.
	var tokens *Tokens
	if useUint32 {
		tokens = types.TokensFromBin32(encoded)
	} else {
		tokens = types.TokensFromBin(encoded)
	}
	// Decode our tokens into a string.
	return encoder.Decode(tokens)
}

// IsByteToken
// Determine if the token is a byte token. Always false when the encoder
// has no BytesEncoder.
func (encoder *GPTEncoder) IsByteToken(token *Token) bool {
	if encoder.BytesEncoder == nil {
		return false
	}
	// NOTE(review): `<=` admits token id len(BytesEncoder) as a byte
	// token; if byte token ids run 0..len-1 this is off by one —
	// confirm against the byte-encoder construction.
	return int(*token) <= len(*encoder.BytesEncoder)
}

// IsLastTokenByte
// Determine if the last token in the sequence is a byte token.
func (encoder *GPTEncoder) IsLastTokenByte(tokens *Tokens) bool {
	if encoder.BytesEncoder == nil || len(*tokens) == 0 {
		return false
	}
	return encoder.IsByteToken(&(*tokens)[len(*tokens)-1])
}

// TokensReady
// Determine if the sequence of Tokens given is ready to be serialized
// to string, based on if the sequence will produce valid Unicode runes.
func (encoder *GPTEncoder) TokensReady(tokens *Tokens) bool {
	// With a byte-level BytesEncoder any prefix is decodable; readiness
	// only matters for the GPT-2 unicode byte-encoding scheme.
	if encoder.BytesEncoder != nil {
		return true
	}
	good := 0
	need := 0
	for tokenIdx := range *tokens {
		tok := (*tokens)[tokenIdx]
		var req int
		if int(tok) >= len(encoder.unitrim) {
			// Don't error out on tokens that we don't know about.
1696 req = 0 1697 } else { 1698 req = encoder.unitrim[(*tokens)[tokenIdx]] 1699 } 1700 1701 if !(need+req < 0) { 1702 need += req 1703 } 1704 if req == 0 { 1705 // reset need to 0 to avoid being stuck when we have invalid 1706 // unicode being generated. 1707 need = 0 1708 } 1709 if need == 0 { 1710 good = tokenIdx + 1 1711 } 1712 } 1713 return good == len(*tokens) 1714 } 1715 1716 // TrimTokens 1717 // Trims the given Tokens to tokens that produce valid unicode. 1718 func (encoder *GPTEncoder) TrimTokens(tokens *Tokens) (trimmed *Tokens) { 1719 trimmed = tokens 1720 for { 1721 if len(*trimmed) == 0 { 1722 return trimmed 1723 } 1724 if encoder.TokensReady(trimmed) { 1725 return trimmed 1726 } else { 1727 newTrimmed := (*trimmed)[0 : len(*trimmed)-1] 1728 trimmed = &newTrimmed 1729 } 1730 } 1731 }