github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/gpt_bpe.go (about)

     1  package gpt_bpe
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"encoding/binary"
     7  	"encoding/json"
     8  	"io"
     9  	"log"
    10  	"math"
    11  	"regexp"
    12  	"regexp/syntax"
    13  	"sort"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"unicode"
    18  
    19  	"github.com/pkg/errors"
    20  
    21  	lru "github.com/hashicorp/golang-lru"
    22  	"github.com/wbrown/gpt_bpe/resources"
    23  	"github.com/wbrown/gpt_bpe/types"
    24  )
    25  
// BPE_LRU_SZ is the capacity of the per-encoder ARC cache of BPE results
// (also recorded in GPTEncoder.LruSize).
const BPE_LRU_SZ = 16384

// RUNEBUF_SZ is the default value for GPTEncoder.runeBufSz — presumably
// the rune buffer size used by the streaming splitter (used outside this
// view).
const RUNEBUF_SZ = 16384

// WORDCHAN_SZ is the default value for GPTEncoder.wordChanSz — presumably
// the word channel buffer size used by the streaming splitter (used
// outside this view).
const WORDCHAN_SZ = 4096

// defaultPadTokenString is injected into the vocabulary by NewEncoder
// when the model config declares no pad token.
const defaultPadTokenString = "[PAD]"
    30  
// Token and Tokens are aliases for the shared types package, letting
// callers use gpt_bpe.Token / gpt_bpe.Tokens without importing types.
type Token = types.Token
type Tokens = types.Tokens
    33  
// TypedTwoTierCache is a placeholder type referenced by
// GPTEncoder.TwoTierCache; only a filler field is defined so far.
type TypedTwoTierCache struct {
	// Filler
	filler int
}
    38  
// GPTEncoder holds everything required to tokenize and detokenize text
// for one vocabulary: the BPE vocabulary and merge tables, compiled
// splitting regexes, special-token data, normalization replacers, and
// LRU-cache bookkeeping. Instances are built by NewEncoder and deep
// copied by Clone.
type GPTEncoder struct {
	Encoder               map[string]Token    // decoded token string -> token id
	Decoder               map[Token][]byte    // token id -> raw byte sequence
	BpeRanks              map[GPTPair]float64 // merge rank per bigram; lower merges first
	TokenMerges           map[TokenPair]Token // (left, right) token ids -> merged token id
	BytesEncoder          *map[byte]Token     // byte-level fallback tokens; nil if vocab has none
	unitrim               []int               // per-token UTF-8 continuation-byte debt (see makeUnitrimArr)
	pattern               *regexp.Regexp      // word-splitting regex (SPLIT_REGEX or model-provided)
	puncPat               *regexp.Regexp      // PUNC_REGEX, compiled
	specialsPat           *regexp.Regexp      // alternation matching any special-token literal
	byteToRune            [256]rune           // byte -> codepoint table (see makeByteTranslationTables)
	runeToByte            map[rune]byte       // inverse of byteToRune
	Specials              map[string]Tokens   // special token string -> its token sequence
	SpecialsTree          *RuneNode           // rune trie over Specials (built by UpdateSpecialsTree)
	Cache                 *lru.ARCCache       // word -> Tokens BPE result cache (see ToBPE)
	TwoTierCache          *TypedTwoTierCache  // reserved; not populated by NewEncoder
	PuncRunes             []rune              // runes given special punctuation handling (from special_config)
	Normalizer            *strings.Replacer   // input normalization replacements (from special_config)
	DecodeExtra           *strings.Replacer   // extra replacements applied on decode (from special_config)
	BosToken              Token               // beginning-of-sequence token id
	EosToken              Token               // end-of-sequence token id
	PadToken              Token               // padding token id
	ignoreMerges          bool                // whole-word vocab hits bypass BPE merges (see ToBPE)
	encloseEosBos         bool                // from special_config EncloseEosBos
	encloseBos            bool                // from hfConfig AddBosToken
	encloseEos            bool                // from hfConfig AddEosToken
	prefixSpace           bool                // from special_config PrefixSpace
	lowerCase             bool                // from special_config LowerCase
	endOfWord             string              // suffix appended to each word's final symbol (see ToBPE)
	replacements          map[string]string   // literal replacements, e.g. "\n" -> "</s>"
	runeBufSz             int                 // defaults to RUNEBUF_SZ
	wordChanSz            int                 // defaults to WORDCHAN_SZ
	LruHits               int                 // cache statistics, updated by ToBPE
	LruMisses             int
	LruEvictions          int
	LruSize               int
	SplitterThreads       int    // worker count (NewEncoder sets 2); consumer is outside this view
	VocabId               string // resolved vocabulary/model identifier
	tokenizerClass        string // HuggingFace tokenizer_class value
	normalizerStringMap   map[string]string // raw mapping backing Normalizer
	regexWordSplitterTree *RegexNode        // optional compiled splitter tree; nil after NewEncoder
	wordSplitterMap       [][]int           // optional splitter state; nil after NewEncoder
}
    82  
// GPTPair is a bigram of adjacent decoded-token strings, used as the key
// into the BPE merge-rank table.
type GPTPair struct {
	Left  string
	Right string
}

// TokenPair is a bigram of token ids, used as the key into
// GPTEncoder.TokenMerges.
type TokenPair struct {
	Left  Token
	Right Token
}

// BGERank couples a bigram with its BPE merge rank so candidate merges
// can be ordered (lower rank merges first).
type BGERank struct {
	rank   float64
	bigram GPTPair
}

// BGERanks is a slice of BGERank that sorts by ascending rank via
// sort.Interface.
type BGERanks []BGERank
    99  
// Len implements sort.Interface.
func (bs BGERanks) Len() int {
	return len(bs)
}

// Swap implements sort.Interface.
func (bs BGERanks) Swap(i, j int) {
	bs[i], bs[j] = bs[j], bs[i]
}

// Less implements sort.Interface, ordering by ascending BPE rank.
func (bs BGERanks) Less(i, j int) bool {
	return bs[i].rank < bs[j].rank
}
   111  
// SPLIT_REGEX is the default GPT-2-style word-splitting pattern, used by
// NewEncoder when the model resources provide no split regex of their
// own. It is broken across string literals purely for source readability.
const SPLIT_REGEX = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L" +
	"}+| ?\\p{N}+| ?[^\\s\\p{L" +
	"}\\p{N}]+|\\s+(\\S){0}|\\s+"

// PUNC_REGEX matches a letter, one of ".!?;", then a letter.
const PUNC_REGEX = "\\p{L}[.!?;]\\p{L}"

// REGEX_ERROR is the log.Fatalf format used when a regex fails to compile.
const REGEX_ERROR = "gpt_bpe: Fatal error compiling regular expression: %v"
   117  
// Built-in vocabulary identifiers accepted by NewEncoder and used by the
// convenience New*Encoder constructors below.
const VOCAB_ID_GPT2 = "gpt2-tokenizer"
const VOCAB_ID_PILE = "pile-tokenizer"
const VOCAB_ID_CLIP = "clip-tokenizer"
const VOCAB_ID_NERDSTASH_V1 = "nerdstash_v1-tokenizer"
const VOCAB_ID_NERDSTASH_V2 = "nerdstash_v2-tokenizer"
const VOCAB_ID_LLAMA = "llama-tokenizer"
const VOCAB_ID_LLAMA_3 = "llama3-tokenizer"
const VOCAB_ID_MISTRAL = "mistral-tokenizer"
   126  
   127  func NewGPT2Encoder() GPTEncoder {
   128  	encoder, _ := NewEncoder(VOCAB_ID_GPT2)
   129  	return *encoder
   130  }
   131  
   132  func NewPileEncoder() GPTEncoder {
   133  	encoder, _ := NewEncoder(VOCAB_ID_PILE)
   134  	return *encoder
   135  }
   136  
   137  func NewCLIPEncoder() GPTEncoder {
   138  	encoder, _ := NewEncoder(VOCAB_ID_CLIP)
   139  	return *encoder
   140  }
   141  
   142  func NewNerdstashV1Encoder() GPTEncoder {
   143  	encoder, _ := NewEncoder(VOCAB_ID_NERDSTASH_V1)
   144  	return *encoder
   145  }
   146  
   147  func NewNerdstashV2Encoder() GPTEncoder {
   148  	encoder, _ := NewEncoder(VOCAB_ID_NERDSTASH_V2)
   149  	return *encoder
   150  }
   151  
   152  func NewLlama2Encoder() GPTEncoder {
   153  	encoder, _ := NewEncoder(VOCAB_ID_LLAMA)
   154  	return *encoder
   155  }
   156  
   157  func NewLlama3Encoder() GPTEncoder {
   158  	encoder, _ := NewEncoder(VOCAB_ID_LLAMA_3)
   159  	return *encoder
   160  }
   161  
   162  func NewMistralEncoder() GPTEncoder {
   163  	encoder, _ := NewEncoder(VOCAB_ID_MISTRAL)
   164  	return *encoder
   165  }
   166  
   167  func (encoder *GPTEncoder) Clone() *GPTEncoder {
   168  	// Shallow copy everything first
   169  	clone := *encoder
   170  	clone.Cache, _ = lru.NewARC(BPE_LRU_SZ)
   171  	// Copy our maps
   172  	clone.Encoder = make(map[string]Token)
   173  	for k, v := range encoder.Encoder {
   174  		clone.Encoder[k] = v
   175  	}
   176  	clone.Decoder = make(map[Token][]byte)
   177  	for k, v := range encoder.Decoder {
   178  		clone.Decoder[k] = v
   179  	}
   180  	clone.BpeRanks = make(map[GPTPair]float64)
   181  	for k, v := range encoder.BpeRanks {
   182  		clone.BpeRanks[k] = v
   183  	}
   184  	clone.TokenMerges = make(map[TokenPair]Token)
   185  	for k, v := range encoder.TokenMerges {
   186  		clone.TokenMerges[k] = v
   187  	}
   188  	if encoder.BytesEncoder != nil {
   189  		encoderCopy := make(map[byte]Token)
   190  		for k, v := range *encoder.BytesEncoder {
   191  			encoderCopy[k] = v
   192  		}
   193  		clone.BytesEncoder = &encoderCopy
   194  	}
   195  	clone.unitrim = make([]int, len(encoder.unitrim))
   196  	copy(clone.unitrim, encoder.unitrim)
   197  	clone.PuncRunes = make([]rune, len(encoder.PuncRunes))
   198  	copy(clone.PuncRunes, encoder.PuncRunes)
   199  	clone.Normalizer = encoder.Normalizer
   200  	clone.normalizerStringMap = encoder.normalizerStringMap
   201  	clone.DecodeExtra = encoder.DecodeExtra
   202  	clone.Specials = make(map[string]Tokens)
   203  	for k, v := range encoder.Specials {
   204  		clone.Specials[k] = v
   205  	}
   206  	clone.UpdateSpecialsTree()
   207  	for k, v := range encoder.replacements {
   208  		clone.replacements[k] = v
   209  	}
   210  	clone.runeBufSz = encoder.runeBufSz
   211  	clone.regexWordSplitterTree = encoder.regexWordSplitterTree
   212  	clone.wordSplitterMap = encoder.wordSplitterMap
   213  	return &clone
   214  }
   215  
// NewEncoder
// Returns a GPTEncoder with the tokenizer data loaded for that vocabulary
// id.
//
// It resolves the model's resources (vocab, merges, special tokens and
// any special_config.json) via the resources package, compiles the
// splitting / punctuation / specials regexes, builds the byte<->rune
// translation tables and the unitrim LUT, and wires everything into a
// ready-to-use GPTEncoder. A regex compile failure or an oversized
// vocabulary is fatal (log.Fatal*) rather than returned as an error.
func NewEncoder(vocabId string) (*GPTEncoder, error) {
	log.Printf("Loading encoder for vocab id: %s\n", vocabId)
	hfConfig, resourcesPtr, vocabErr := resources.ResolveVocabId(
		vocabId, "",
	)

	if vocabErr != nil {
		return nil, vocabErr
	} else if hfConfig == nil {
		// We should never get this error, but just in case, we return an
		// error if we can't find the config.
		return nil, errors.Errorf(
			"Can't load encoder for vocab id: %s",
			vocabId,
		)
	} else if resourcesPtr == nil {
		return nil, errors.Errorf(
			"Can't load resources for vocab id: %s",
			vocabId,
		)
	}
	rsrcs := *resourcesPtr

	// Prefer the resolved model id as the canonical vocab id.
	if hfConfig.ModelId != nil {
		vocabId = *hfConfig.ModelId
	}

	// Defaults used when the model ships no special_config.json.
	specialConfig := resources.SpecialConfig{
		PuncRunes:     nil,
		Normalizer:    nil,
		EncloseEosBos: false,
		PrefixSpace:   true,
		LowerCase:     false,
		EndOfWord:     "",
		DecodeExtra:   nil,
		SplitRegex:    nil,
	}
	if special, ok := (rsrcs)["special_config.json"]; ok {
		if special.Data != nil {
			if json.Unmarshal(*special.Data, &specialConfig) != nil {
				log.Fatal("Error unmarshalling special_config.json")
			}
		}
	}

	// Sometimes we have a split regex that's provided by the model's
	// tokenizer config.
	if specialConfig.SplitRegex == nil {
		splitRegexPtr := rsrcs.ResolveSplitRegex()
		if splitRegexPtr != nil {
			// Use our default split regex if we can't find one.
			specialConfig.SplitRegex = splitRegexPtr
		}
	}

	// These are the runes that are considered punctuation and have
	// special handling.
	puncRunes := make([]rune, 0)
	if specialConfig.PuncRunes != nil {
		for _, r := range specialConfig.PuncRunes {
			// Only the first byte of each configured entry is used as
			// the punctuation rune.
			puncRunes = append(puncRunes, rune((*r)[0]))
		}
	}

	// Create a replacer for normalizing text.
	normalizer := strings.NewReplacer()
	norms := make([]string, 0)
	normsMap := make(map[string]string)
	if specialConfig.Normalizer != nil {

		for k, v := range *specialConfig.Normalizer {
			norms = append(norms, k, v)
			normsMap[k] = v
		}
		normalizer = strings.NewReplacer(norms...)
	}

	// Create a replacer for extra decoding. This is used to decode
	// special tokens that are not in the encoder.
	decodeExtra := strings.NewReplacer()
	if specialConfig.DecodeExtra != nil {
		decode := make([]string, 0)
		for k, v := range *specialConfig.DecodeExtra {
			decode = append(decode, k, v)
		}
		decodeExtra = strings.NewReplacer(decode...)
	}

	// Build the bytes to unicode tables.
	bytesUnicode, unicodeBytes := makeByteTranslationTables()

	// Read encoder mappings.
	vocab, err := rsrcs.GetVocab(hfConfig)
	if err != nil {
		return nil, err
	}
	encoderTokens := make(map[string]Token)
	for k, v := range vocab {
		encoderTokens[k] = Token(v)
	}

	// Build the unitrim array. This is used to trim token sequences
	// to valid UTF-8 boundaries. NOTE: this must happen before the pad
	// token is injected below, since the injected pad id lies far
	// outside the dense id range the unitrim array is indexed by.
	unitrimArr := makeUnitrimArr(encoderTokens)

	// Go through the encoder mappings for possible byte runes
	// and also generate reverse mappings.
	bytesEncoder := make(map[byte]Token)
	tokensEncoder := make(map[Token][]byte)
	for text, token := range encoderTokens {
		if strings.HasPrefix(text, "0x") && len(text) == 4 {
			// Convert the hex string to a byte
			byteValue, err := strconv.ParseUint(text[2:], 16, 8)
			if err != nil {
				panic(err)
			}
			tokensEncoder[token] = []byte{byte(byteValue)}
			bytesEncoder[byte(byteValue)] = token
			// Deleting during range iteration is well-defined in Go;
			// byte-level entries are removed from the string-keyed map.
			delete(encoderTokens, text)
		} else {
			tokensEncoder[token] = []byte(text)
		}
	}
	// A nil BytesEncoder signals that this vocab has no byte-level
	// fallback tokens.
	bytesEncoderPtr := &bytesEncoder
	if len(bytesEncoder) == 0 {
		bytesEncoderPtr = nil
	}

	// Read merge table into BpeRanks
	bpeRanks := make(map[GPTPair]float64)
	rscBpeRanks, err := resources.GetMergesAsBpeRank(&rsrcs)
	if err != nil {
		return nil, err
	}
	// Convert rscBpeRanks to bpeRanks (map[GPTPair]float64)
	for k, v := range rscBpeRanks {
		bpeRanks[GPTPair{k.Left, k.Right}] = v
	}

	// Build our TokenMerges. These are used to merge tokens together
	// based on the BPE merge table.
	tokenMerges := make(map[TokenPair]Token)
	for pair := range bpeRanks {
		tokenMerges[TokenPair{
			encoderTokens[pair.Left],
			encoderTokens[pair.Right]}] =
			encoderTokens[pair.Left+pair.Right]
	}

	// Handle special tokens. Special tokens are removed from input before
	// tokenization, so we need to search for them before we tokenize.
	specialsRegexTokens := make([]string, 0)
	specials := make(map[string]Tokens)
	// NOTE(review): specialsArr is filled in both branches below but is
	// never read afterwards; it appears to be dead state.
	specialsArr := make([]string, 0)

	if specialsTxt, ok := rsrcs["specials.txt"]; ok {
		// Plain-text specials: one token literal per non-empty line.
		specialsBuffer := bytes.NewBuffer(*specialsTxt.Data)
		specialsScanner := bufio.NewScanner(specialsBuffer)
		for specialsScanner.Scan() {
			specialToken := specialsScanner.Text()
			if specialToken == "" {
				continue
			}
			specials[specialToken] = Tokens{encoderTokens[specialToken]}
			specialsArr = append(specialsArr, specialToken)
			quotedToken := regexp.QuoteMeta(specialToken)
			specialsRegexTokens = append(
				specialsRegexTokens, quotedToken,
			)
		}
	} else if specialsJson, ok := rsrcs["specials.json"]; ok {
		// JSON specials: a name -> token-literal map; values are
		// deduplicated since several names may share one literal.
		specialsData := make(map[string]string)
		seenSpecials := make(map[string]bool)
		if specialErr := json.Unmarshal(
			*specialsJson.Data,
			&specialsData,
		); specialErr != nil {
			return nil, specialErr
		}
		for _, v := range specialsData {
			if _, seen := seenSpecials[v]; !seen {
				seenSpecials[v] = true
				specials[v] = Tokens{encoderTokens[v]}
				specialsArr = append(specialsArr, v)
				quotedToken := regexp.QuoteMeta(v)
				specialsRegexTokens = append(
					specialsRegexTokens, quotedToken,
				)
			}
		}
	}
	specialsRegex := strings.Join(specialsRegexTokens, "|")

	// Now compile our regexes.
	specialsPat, err := regexp.Compile(specialsRegex)
	if err != nil {
		log.Fatalf(REGEX_ERROR, err)
	}

	var pat *regexp.Regexp
	if specialConfig.SplitRegex != nil {
		pat, err = regexp.Compile(*specialConfig.SplitRegex)
	} else {
		pat, err = regexp.Compile(SPLIT_REGEX)
	}
	if err != nil {
		log.Fatalf(REGEX_ERROR, err)
	}
	puncPat, err := regexp.Compile(PUNC_REGEX)
	if err != nil {
		log.Fatalf(REGEX_ERROR, err)
	}

	cache, _ := lru.NewARC(BPE_LRU_SZ)

	// Newline-mode "s" models replace literal newlines with "</s>".
	replacements := make(map[string]string)
	if hfConfig.NewLineMode != nil && *hfConfig.NewLineMode == "s" {
		replacements["\n"] = "</s>"
	}

	if specialConfig.EncloseEosBos {
		bosBool := true
		eosBool := true
		hfConfig.AddBosToken = &bosBool
		hfConfig.AddEosToken = &eosBool
	}

	// Add in default pad token if not already set
	padTokenNotFound := hfConfig.PadTokenStr == nil ||
		*hfConfig.PadTokenStr == ""
	if padTokenNotFound {
		// Attempt to resolve from specials
		for k := range specials {
			if strings.Contains(k, "pad") {
				// Taking the address of the loop variable is safe here:
				// we break immediately, so k still holds the match.
				hfConfig.PadTokenStr = &k
				padTokenNotFound = false
				break
			}
		}
		// Inject the pad token into the encoder to uintmax32,
		// throw an error if vocab is larger than uintmax32
		if len(encoderTokens) >= math.MaxUint32 {
			log.Fatalf(
				"Vocab size of %d is larger than uint32 max of %d. "+
					"Please specify a pad token in the vocab file.",
				len(encoderTokens), math.MaxUint32,
			)
		}
		if padTokenNotFound {
			// Give the injected pad token an id far outside the dense
			// vocab range: uint32 max for large vocabs, uint16 max
			// otherwise.
			padToken := defaultPadTokenString
			if len(encoderTokens) >= math.MaxUint16 {
				encoderTokens[padToken] = math.MaxUint32
			} else {
				encoderTokens[padToken] = math.MaxUint16
			}
			hfConfig.PadTokenStr = &padToken
		}
	}

	// Create the encoder
	encoder := &GPTEncoder{
		Encoder:               encoderTokens,
		Decoder:               tokensEncoder,
		BpeRanks:              bpeRanks,
		TokenMerges:           tokenMerges,
		BytesEncoder:          bytesEncoderPtr,
		unitrim:               unitrimArr,
		pattern:               pat,
		puncPat:               puncPat,
		specialsPat:           specialsPat,
		byteToRune:            bytesUnicode,
		runeToByte:            unicodeBytes,
		Specials:              specials,
		SpecialsTree:          nil,
		Cache:                 cache,
		PuncRunes:             puncRunes,
		Normalizer:            normalizer,
		DecodeExtra:           decodeExtra,
		BosToken:              encoderTokens[*hfConfig.BosTokenStr],
		EosToken:              encoderTokens[*hfConfig.EosTokenStr],
		PadToken:              encoderTokens[*hfConfig.PadTokenStr],
		ignoreMerges:          *hfConfig.IgnoreMerges,
		encloseEosBos:         specialConfig.EncloseEosBos,
		encloseBos:            *hfConfig.AddBosToken,
		encloseEos:            *hfConfig.AddEosToken,
		prefixSpace:           specialConfig.PrefixSpace,
		lowerCase:             specialConfig.LowerCase,
		endOfWord:             specialConfig.EndOfWord,
		replacements:          replacements,
		runeBufSz:             RUNEBUF_SZ,
		wordChanSz:            WORDCHAN_SZ,
		LruHits:               0,
		LruMisses:             0,
		LruEvictions:          0,
		LruSize:               BPE_LRU_SZ,
		SplitterThreads:       2,
		VocabId:               vocabId,
		tokenizerClass:        *hfConfig.TokenizerClass,
		normalizerStringMap:   normsMap,
		regexWordSplitterTree: nil,
		wordSplitterMap:       nil,
	}
	encoder.UpdateSpecialsTree()
	return encoder, nil
}
   524  
   525  func (encoder *GPTEncoder) UpdateSpecialsTree() {
   526  	// Turn the keys of the specials map into a slice
   527  	idx := 0
   528  	specialsArr := make([]string, len(encoder.Specials))
   529  	for k := range encoder.Specials {
   530  		specialsArr[idx] = k
   531  		idx++
   532  	}
   533  	encoder.SpecialsTree = CreateRuneTree(specialsArr)
   534  }
   535  
   536  // makeByteTranslationTables creates lookup tables for interconverting
   537  // between runes in decoded token strings and the UTF-8 byte sequences
   538  // that they encode.
   539  func makeByteTranslationTables() ([256]rune, map[rune]byte) {
   540  	// GPT2's BPE implementation reinterprets UTF-8-encoded bytes as
   541  	// Unicode codepoints, but remaps the 68 code points
   542  	// corresponding to control, format, and space-separator characters
   543  	// (i.e. Unicode character categories Cc, Cf, and Zs)
   544  	// in the range [0, 255] to sequential codepoints in [256, 323],
   545  	// which happens to contain no characters from those three categories.
   546  	// For example, the byte \x00 is mapped to codepoint 256, and the final
   547  	// affected byte \xAD is mapped to codepoint 323.
   548  	// The remapped bytes are sequential even though the original bytes
   549  	// are not. The original bytes' codepoint interpretations all fall
   550  	// in the following ranges:
   551  	// - [\x00, \x20] ('NUL' to 'SPACE'; up to right before '!'),
   552  	// - [\x7F, \xA0] ('DELETE' to 'NO-BREAK SPACE'; between '~' and '¡')
   553  	// - \xAD exactly ('SOFT HYPHEN')
   554  	// Refer to "src/encoder.py" in the openai/gpt-2 repository for
   555  	// more detail.
   556  
   557  	byteDecoderMap := make(map[rune]byte, 256)
   558  	var byteEncoderLUT [256]rune
   559  
   560  	for i, relocated := rune(0), rune(256); i < 256; i++ {
   561  		relocatedByte := i
   562  		if i < '!' || i > '~' && i < '¡' || i == '\xAD' {
   563  			relocatedByte = relocated
   564  			relocated++
   565  		}
   566  		byteEncoderLUT[i] = relocatedByte
   567  		byteDecoderMap[relocatedByte] = byte(i)
   568  	}
   569  
   570  	return byteEncoderLUT, byteDecoderMap
   571  }
   572  
// makeUnitrimArr creates a lookup table for trimming token sequences
// to valid UTF-8 boundaries. It replaces unitrim.json files generated
// in advance.
//
// NOTE(review): debtLUT is sized to len(encoderMap) and indexed directly
// by token id, so this assumes token ids are dense in
// [0, len(encoderMap)) — confirm for vocabularies with id gaps, which
// would panic here with an index out of range.
func makeUnitrimArr(encoderMap map[string]Token) []int {
	// In order to check how many UTF-8 continuation bytes are missing from
	// each individual token, the decoded token strings need to be translated
	// to UTF-8.
	_, byteDecoderMap := makeByteTranslationTables()

	// This function returns the following LUT, representing either
	// how many continuation bytes are needed following a given token,
	// or how many continuation bytes a given token fulfills.
	// Positive entries require that many more continuation bytes to follow;
	// negative entries fulfill that many continuation bytes.
	debtLUT := make([]int, len(encoderMap))

	// Continuation byte requirements are defined by the UTF-8 standard
	// and can be determined from bit patterns of each byte. We make a
	// LUT of bit patterns to make this calculation faster.
	// Only the 5 most significant bits are relevant.
	var byteDebtLUT [32]int8
	for b := 0; b <= 0b11110; b++ {
		// According to UTF-8 variable-length binary encoding:
		if (b & 0b10000) == 0 {
			// All 7-bit ASCII characters have the bit pattern 0xxxxxxx
			// - They are self-contained, and require no continuation
			// - They are the only characters encoded with a single byte
			byteDebtLUT[b] = 0
		} else if (b & 0b11100) == 0b11000 {
			// All 2-byte characters start with a 110xxxxx byte
			// - These add +1 continuation byte debt
			byteDebtLUT[b] = 1
		} else if (b & 0b11110) == 0b11100 {
			// All 3-byte characters start with a 1110xxxx byte
			// - These add +2 continuation byte debt
			byteDebtLUT[b] = 2
		} else if (b & 0b11110) == 0b11110 {
			// All 4-byte characters start with a 11110xxx byte
			// - These add +3 continuation byte debt
			// - No valid Unicode starts with 11111xxx, so the last
			//   0 should be redundant, but some tokenizers include
			//   such bytes in their vocabularies regardless.
			byteDebtLUT[b] = 3
		} else if (b & 0b11000) == 0b10000 {
			// All continuation characters start with a 10xxxxxx byte
			//- These satisfy (-) 1 continuation byte debt
			byteDebtLUT[b] = -1
		}
	}

	// Calculate the debtLUT entries for each token ID
	for decodedToken, token := range encoderMap {
		tokenDebt := 0
		minTokenDebt := 0

		// Decode each Unicode codepoint into a UTF-8 byte
		// (codepoints missing from byteDecoderMap decode to 0x00,
		// which counts as a self-contained ASCII byte below).
		codepoints := []rune(decodedToken)
		utf8Bytes := make([]byte, len(codepoints))
		for i, c := range codepoints {
			utf8Bytes[i] = byteDecoderMap[c]
		}

		// Keep track of continuation byte requirements
		// between each UTF-8 byte.
		for _, b := range utf8Bytes {
			b >>= 3 // trim to relevant bits
			byteDebt := int(byteDebtLUT[b])
			if byteDebt < 0 {
				// Continuation bytes are tracked relative to the bytes
				// preceding them
				tokenDebt += byteDebt
			} else {
				// Starting bytes have no relation to bytes preceding them
				tokenDebt = byteDebt
			}

			if tokenDebt < 0 {
				minTokenDebt = tokenDebt
			} else if tokenDebt == 0 {
				// If the beginning of the string satisfies continuation
				// byte debt, don't forget that just to track less-important
				// information about self-contained byte sequences that follow.
				// Do overwrite it if it ends with fresh debt.
				// NB: if a token both satisfies continuation byte debt
				// and then begins new debt, only the latter can be tracked.
				// This is a limitation of the LUT entries being single
				// integers rather than pairs of integers.
				tokenDebt = minTokenDebt
			}
		}
		debtLUT[token] = tokenDebt
	}

	return debtLUT
}
   668  
// PreallocBGERanks is a fixed-capacity container of BGERank values kept
// sorted by rank, designed to avoid reallocation on the BPE hot path.
type PreallocBGERanks struct {
	data []BGERank // backing storage, allocated once at full capacity
	len  int       // number of live elements at the front of data
}
   673  
   674  func NewPreallocBGERanks(capacity int) *PreallocBGERanks {
   675  	return &PreallocBGERanks{
   676  		data: make([]BGERank, capacity),
   677  		len:  0,
   678  	}
   679  }
   680  
   681  func (p *PreallocBGERanks) InsertSorted(v BGERank) {
   682  	// Binary search
   683  	i := sort.Search(
   684  		p.len, func(i int) bool {
   685  			return p.data[i].rank >= v.rank
   686  		},
   687  	)
   688  
   689  	// Check for exact duplicate using full BGERank comparison
   690  	if i < p.len && p.data[i].rank == v.rank && p.data[i].bigram == v.bigram {
   691  		return
   692  	}
   693  
   694  	// Ensure we have space
   695  	if p.len >= len(p.data) {
   696  		return // or could panic/grow if needed
   697  	}
   698  
   699  	// Shift and insert
   700  	if i < p.len {
   701  		copy(p.data[i+1:p.len+1], p.data[i:p.len])
   702  	}
   703  	p.data[i] = v
   704  	p.len++
   705  }
   706  
   707  func findBestPair(word []string, bpeRanks map[GPTPair]float64) (
   708  	BGERank,
   709  	bool,
   710  ) {
   711  	var bestRank BGERank
   712  	bestRank.rank = math.Inf(1)
   713  	found := false
   714  
   715  	prev := word[0]
   716  	pair := GPTPair{}
   717  	wordLen := len(word) // Calculate once
   718  
   719  	for idx := 1; idx < wordLen; idx++ {
   720  		present := word[idx]
   721  		pair.Left = prev
   722  		pair.Right = present
   723  		if rank, ok := bpeRanks[pair]; ok {
   724  			if rank < bestRank.rank {
   725  				bestRank = BGERank{rank, pair}
   726  				found = true
   727  			}
   728  		}
   729  		prev = present
   730  	}
   731  	return bestRank, found
   732  }
   733  
   734  // Standard version with proper duplicate checking
   735  func insertSortedNoDups(data BGERanks, v BGERank) BGERanks {
   736  	// Fast path: append to end if it's greater than all existing elements
   737  	if len(data) == 0 || data[len(data)-1].rank < v.rank {
   738  		return append(data, v)
   739  	}
   740  
   741  	i := sort.Search(
   742  		len(data), func(i int) bool {
   743  			return data[i].rank >= v.rank
   744  		},
   745  	)
   746  
   747  	// Check for exact duplicate using full BGERank comparison
   748  	if i < len(data) && data[i].rank == v.rank && data[i].bigram == v.bigram {
   749  		return data
   750  	}
   751  
   752  	// Use optimized insertAt
   753  	if len(data) == cap(data) {
   754  		// Grow slice with extra space
   755  		newCap := cap(data) * 2
   756  		if newCap == 0 {
   757  			newCap = 4
   758  		}
   759  		newData := make([]BGERank, len(data), newCap)
   760  		copy(newData, data)
   761  		data = newData
   762  	}
   763  
   764  	// Extend length by one
   765  	data = data[:len(data)+1]
   766  
   767  	// Shift elements in a single operation
   768  	if i < len(data)-1 {
   769  		copy(data[i+1:], data[i:len(data)-1])
   770  	}
   771  
   772  	// Insert new element
   773  	data[i] = v
   774  	return data
   775  }
   776  
   777  func getPairs(word []string) []GPTPair {
   778  	pairsSet := make(map[GPTPair]bool, len(word))
   779  	pairs := make([]GPTPair, len(word))
   780  	begin := 1
   781  	prev := word[0]
   782  	ct := 0
   783  	for idx := begin; idx < len(word); idx++ {
   784  		present := word[idx]
   785  		pair := GPTPair{prev, present}
   786  		if _, ok := pairsSet[pair]; !ok {
   787  			pairs[len(pairsSet)] = pair
   788  			ct++
   789  		}
   790  		pairsSet[pair] = true
   791  		prev = present
   792  	}
   793  	return pairs[0:ct]
   794  }
   795  
   796  // getRankedPairs
   797  // Accepts a slice of strings and returns a slice of BGERanks, sorted by
   798  // their rank.
   799  func (encoder *GPTEncoder) getRankedPairs(word []string) BGERanks {
   800  	rankedPairs := make(BGERanks, 0, len(word))
   801  	begin := 1
   802  	prev := word[0]
   803  	for idx := begin; idx < len(word); idx++ {
   804  		present := word[idx]
   805  		pair := GPTPair{prev, present}
   806  		bpe, ok := encoder.BpeRanks[pair]
   807  		if !ok {
   808  			bpe = math.Inf(1)
   809  		}
   810  		rankedPairs = insertSortedNoDups(
   811  			rankedPairs,
   812  			BGERank{bpe, pair},
   813  		)
   814  		prev = present
   815  	}
   816  	return rankedPairs
   817  }
   818  
   819  // rankPairs
   820  // Accepts a slice of GPTPair and returns a slice of BGERanks, sorted by
   821  // their rank.
   822  func (encoder *GPTEncoder) rankPairs(pairs []GPTPair) BGERanks {
   823  	rankedPairs := make(BGERanks, 0)
   824  	for idx := range pairs {
   825  		bpe, ok := encoder.BpeRanks[pairs[idx]]
   826  		if !ok {
   827  			bpe = math.Inf(1)
   828  		}
   829  		rankedPairs = insertSortedNoDups(
   830  			rankedPairs,
   831  			BGERank{bpe, pairs[idx]},
   832  		)
   833  	}
   834  	sort.Sort(rankedPairs)
   835  	return rankedPairs
   836  }
   837  
   838  // minPair
   839  // Accepts a slice of GPTPair and returns the pair with the lowest BPE rank.
   840  func (encoder *GPTEncoder) minPair(pairs []GPTPair) (retPair GPTPair) {
   841  	rankedPairs := encoder.rankPairs(pairs)
   842  	if len(rankedPairs) > 0 {
   843  		retPair = rankedPairs[0].bigram
   844  	}
   845  	return retPair
   846  }
   847  
   848  // pos finds the index of the first occurrence of seek in word past index i.
   849  func pos(word []string, seek string, i int) int {
   850  	for j, v := range word[i:] {
   851  		if seek == v {
   852  			return j + i
   853  		}
   854  	}
   855  	return -1
   856  }
   857  
   858  // findAllStringIndex returns a set of indexes of all occurrences of substr in
   859  // string.
   860  func findAllStringIndex(text string, substr string) [][]int {
   861  	var indexes [][]int
   862  	for i := 0; i < len(text); {
   863  		j := strings.Index(text[i:], substr)
   864  		if j < 0 {
   865  			break
   866  		}
   867  		indexes = append(indexes, []int{i + j, i + j + len(substr)})
   868  		i += j + len(substr)
   869  	}
   870  	return indexes
   871  }
   872  
// findAllStringsIndexes returns the [start, end) offsets of every
// occurrence of each given substring within text, concatenated in the
// order the substrings are listed.
// NOTE(review): the original comment claimed overlaps are removed, but
// nothing here deduplicates or merges ranges from different substrings —
// verify that callers tolerate (or themselves remove) overlaps.
// Also note the parameter named "strings" shadows the strings package
// inside this function.
func findAllStringsIndexes(text string, strings []string) [][]int {
	var indexes [][]int
	for _, substr := range strings {
		indexes = append(indexes, findAllStringIndex(text, substr)...)
	}
	return indexes
}
   882  
// wordBufferPool recycles pairs of string slices used by ToBPE as its
// word / newWord scratch buffers, avoiding two allocations per call.
var wordBufferPool = sync.Pool{
	New: func() interface{} {
		s1 := make([]string, 0, 256)
		s2 := make([]string, 0, 256)
		return &[2][]string{s1, s2} // Return pair of buffers
	},
}
   890  
// ToBPE
// Given pre-split text, perform bigram ranking and merges, and return the
// resulting Tokens. Results are memoized in the encoder's ARC cache, and
// scratch buffers are drawn from wordBufferPool.
func (encoder *GPTEncoder) ToBPE(text string) Tokens {
	// Fast path: previously tokenized words come straight from the cache.
	if lookup, ok := encoder.Cache.Get(text); ok {
		encoder.LruHits++
		return lookup.(Tokens)
	}
	encoder.LruMisses++

	// Early return for ignoreMerges case
	if encoder.ignoreMerges {
		if token, ok := encoder.Encoder[text]; ok {
			encoder.Cache.Add(text, Tokens{token})
			return Tokens{token}
		}
	}

	// Get word buffer from pool
	bufsPtr := wordBufferPool.Get().(*[2][]string)
	word := (*bufsPtr)[0][:0]
	newWord := (*bufsPtr)[1][:0]
	defer wordBufferPool.Put(bufsPtr)

	// Break the text into per-rune symbols and tag the final symbol with
	// the encoder's end-of-word marker (which may be the empty string).
	word = append(word, strings.Split(text, "")...)
	if len(word) > 0 {
		word[len(word)-1] = word[len(word)-1] + encoder.endOfWord
	}

	// Single character optimization
	if len(word) == 1 {
		var tokens Tokens
		if token, ok := encoder.Encoder[word[0]]; ok {
			tokens = Tokens{token}
		} else if encoder.BytesEncoder != nil {
			// Unknown symbol: fall back to byte-level tokens.
			tokens = make(Tokens, 0, len(word[0]))
			runeBytes := []byte(word[0])
			for _, b := range runeBytes {
				tokens = append(tokens, (*encoder.BytesEncoder)[b])
			}
		} else {
			// No byte fallback: a vocabulary miss yields the
			// zero-value token.
			tokens = Tokens{encoder.Encoder[word[0]]}
		}
		encoder.Cache.Add(text, tokens)
		return tokens
	}

	// Main merge loop using findBestPair
	for {
		// Pick the adjacent symbol pair with the best (lowest) BPE rank.
		bestRank, found := findBestPair(word, encoder.BpeRanks)
		if !found {
			break
		}

		// Reset newWord for reuse
		newWord = newWord[:0]
		first := bestRank.bigram.Left
		second := bestRank.bigram.Right

		// Rebuild the word, merging every (first, second) occurrence.
		for i := 0; i < len(word); {
			j := pos(word, first, i)
			if j == -1 {
				newWord = append(newWord, word[i:]...)
				break
			}
			newWord = append(newWord, word[i:j]...)
			i = j

			if word[i] == first && i < len(word)-1 && word[i+1] == second {
				newWord = append(newWord, first+second)
				i += 2
			} else {
				newWord = append(newWord, word[i])
				i += 1
			}
		}

		// Swap buffers instead of copying.
		word, newWord = newWord, word

		if len(word) == 1 {
			break
		}
	}

	// Final encoding
	tokens := make(Tokens, 0, len(word))
	for _, token := range word {
		if lookup, ok := encoder.Encoder[token]; ok {
			tokens = append(tokens, lookup)
		} else if encoder.BytesEncoder != nil {
			// Symbols absent from the vocabulary decompose into bytes.
			runeBytes := []byte(token)
			for _, b := range runeBytes {
				tokens = append(tokens, (*encoder.BytesEncoder)[b])
			}
		}
	}

	encoder.Cache.Add(text, tokens)
	return tokens
}
   991  
   992  func (encoder *GPTEncoder) getSpecials() map[int][][]rune {
   993  	lenMap := make(map[int][][]rune)
   994  	for k := range encoder.Specials {
   995  		keyLen := len(k)
   996  		keyRunes := []rune(k)
   997  		if entry, ok := lenMap[keyLen]; ok {
   998  			lenMap[keyLen] = append(entry, keyRunes)
   999  		} else {
  1000  			lenMap[keyLen] = [][]rune{keyRunes}
  1001  		}
  1002  	}
  1003  	return lenMap
  1004  }
  1005  
// splitWords splits text into regex-matched words after applying the
// encoder's replacement table and normalizer, honoring the lowerCase and
// prefixSpace settings. When specialToken is set, the special token held
// in specialsNode is appended verbatim as the final word.
func (encoder *GPTEncoder) splitWords(
	text string,
	specialToken bool, specialsNode *RuneNode,
) []*string {
	// Some things such as KoboldAI have a 'replacement' rule, where
	// they replace tokens such as `\n` with `</s>` for Fairseq
	// handling.
	for replaced, replacement := range encoder.replacements {
		text = strings.ReplaceAll(text, replaced, replacement)
	}
	text = encoder.Normalizer.Replace(text)

	idxes := encoder.pattern.FindAllStringIndex(text, -1)
	words := make([]*string, 0, len(idxes)+1)
	for idx := range idxes {
		word := text[idxes[idx][0]:idxes[idx][1]]
		if encoder.lowerCase {
			word = strings.ToLower(word)
		}

		if !encoder.prefixSpace {
			word = strings.TrimSpace(word)
		}

		// Words trimmed down to nothing are dropped entirely.
		if len(word) > 0 {
			words = append(words, &word)
		}
	}

	// Finally, if we have a special token, we cap it off.
	if specialToken {
		runeString := string(specialsNode.runes)
		words = append(words, &runeString)
	}
	return words
}
  1042  
// NextRuneFunc supplies the next rune of an input stream, following the
// (rune, size, error) contract of io.RuneReader.ReadRune.
type NextRuneFunc func() (rune, int, error)
// WordCallback receives a batch of split-out words.
type WordCallback func([]string)
  1045  
// makeWordSplitter builds a closure that pulls runes from nextRuneFunc,
// splits them into words line by line, and delivers the words in batches
// through wordCallback; completeCallback fires once the input is fully
// consumed. Special tokens and rune-level replacements are recognized via
// the encoder's SpecialsTree during accumulation so that the regex
// splitter never sees (or splits) them.
func (encoder *GPTEncoder) makeWordSplitter(
	nextRuneFunc NextRuneFunc,
	wordCallback WordCallback,
	completeCallback func(),
) func() {
	// Lazily compile the split regex into the word-splitter tree and its
	// path map on first use; both are cached on the encoder.
	if encoder.regexWordSplitterTree == nil {
		regexString := encoder.pattern.String()
		if regexString == "" {
			regexString = SPLIT_REGEX
		}
		regexAST, err := syntax.Parse(regexString, syntax.Perl)
		if err != nil {
			panic(err)
		}
		regexAST.Simplify()
		encoder.regexWordSplitterTree = CreateRegexTree(regexAST)
		encoder.wordSplitterMap = encoder.regexWordSplitterTree.GeneratePathMap()
	}

	// How many words we send on each callback.
	const batchSize = 256
	workQueue := make(chan []string, encoder.SplitterThreads*2)
	wg := sync.WaitGroup{}
	wg.Add(1)

	// Single consumer goroutine that processes batches
	go func() {
		defer wg.Done()
		for batch := range workQueue {
			wordCallback(batch)
		}
	}()

	return func() {
		specialsRuneRoot := encoder.SpecialsTree
		runeAccumulator := make([]rune, 0, encoder.runeBufSz)
		wordBatch := make([]string, 0, batchSize)
		specialToken := false
		specialsCandidates := make(RuneNodes, 0, 16)
		var candidateNode *RuneNode

		// Define a function to flush the batch once it is full
		flushBatch := func() {
			if len(wordBatch) > 0 {
				// Copy the batch to prevent race conditions
				batch := make([]string, len(wordBatch))
				copy(batch, wordBatch)
				workQueue <- batch
				wordBatch = wordBatch[:0]
			}
		}

		// appendBatch appends a batch of words to the wordBatch and flushes
		// the batch if it is full.
		appendBatch := func(words []string, forceFlush bool) {
			if len(words) == 0 && (!forceFlush || len(wordBatch) == 0) {
				return
			}
			// If we are appending words, we need to process them
			for _, word := range words {
				if encoder.lowerCase {
					word = strings.ToLower(word)
				}
				if !encoder.prefixSpace {
					word = strings.TrimSpace(word)
				}

				// After every word, we append it to the wordBatch
				// We also check if the wordBatch is full and flush it
				if len(word) > 0 {
					wordBatch = append(wordBatch, word)
					if len(wordBatch) >= batchSize {
						flushBatch()
					}
				}
			}
			// forceFlush forces the batch to be flushed.
			// Useful for ensuring that the last batch is flushed.
			if forceFlush && len(wordBatch) > 0 {
				// If we are forcing a flush, we flush the batch after processing
				// the words
				for i, word := range wordBatch {
					if encoder.lowerCase {
						word = strings.ToLower(word)
					}
					if !encoder.prefixSpace {
						word = strings.TrimSpace(word)
					}
					wordBatch[i] = word
				}

				flushBatch()
			}
		}

		// processLine regex-splits one accumulated line into words and,
		// when the line was terminated by a special token, re-appends that
		// token as its own unsplit word.
		processLine := func(
			line []rune,
			special bool,
			node *RuneNode,
		) {
			// Find all words by using the regexWordSplitterTree
			matches := encoder.regexWordSplitterTree.EvaluateRegexTree(
				line, encoder.wordSplitterMap,
			)
			for _, word := range matches {
				if encoder.lowerCase {
					word = strings.ToLower(word)
				}
				if !encoder.prefixSpace {
					word = strings.TrimSpace(word)
				}
				appendBatch([]string{word}, false)
			}

			// Re-add the special token if it was removed
			// This is done after the regex splitting to ensure that the special
			// token is not split by the regex
			if special && node != nil {
				special := string(node.runes)
				appendBatch([]string{special}, false)
			}
		}

		// Apply replacements defined in the runetree
		checkAndReplaceNode := func() {
			// Swap the just-matched suffix of the accumulator for the
			// node's replacement runes, then reset all matching state.
			matchLen := len(candidateNode.runes)
			accTruncIdx := len(runeAccumulator) - matchLen
			runeAccumulator = append(
				runeAccumulator[:accTruncIdx],
				*candidateNode.replacement...,
			)
			specialsCandidates = specialsCandidates[:0]
			candidateNode = specialsRuneRoot
			specialToken = false
		}
		// We repeatedly call the nextRuneFunc until it returns an error or other break
		// condition. This fills the runeAccumulator with runes until we have a full line.
		for {
			// Collect runes until newline or special token
			for {
				r, size, err := nextRuneFunc()
				if size == 0 || err != nil {
					break
				}

				runeAccumulator = append(runeAccumulator, r)

				if r == '\n' {
					break
				}

				// Conduct replacement and special token checks
				candidateNode = specialsCandidates.evaluate(r)
				if candidateNode != nil {
					if candidateNode.replacement != nil {
						checkAndReplaceNode()
					} else if candidateNode.terminal {
						specialToken = true
						break
					}
				}

				// Also start a fresh match from the tree root at this rune.
				candidateNode = specialsRuneRoot.evaluate(r)
				if candidateNode != nil {
					specialsCandidates = append(
						specialsCandidates,
						candidateNode,
					)
					if candidateNode.replacement != nil {
						checkAndReplaceNode()
					} else if candidateNode.terminal {
						specialToken = true
						break
					}
				}
			}

			// If we have no runes, we are done
			if len(runeAccumulator) == 0 {
				appendBatch(nil, true)
				wordCallback(nil)
				break
			}

			// Apply replacements and normalization
			if specialToken && candidateNode != nil {
				// Strip the special token's runes off the line;
				// processLine will re-append it unsplit.
				runeAccumulator =
					runeAccumulator[:len(runeAccumulator)-len(
						candidateNode.runes,
					)]
			}
			if len(encoder.replacements) > 0 {
				runeAccumulator = replaceRunes(
					runeAccumulator, encoder.replacements,
				)
			}

			if encoder.Normalizer != nil {
				if encoder.normalizerStringMap != nil && len(encoder.normalizerStringMap) > 0 {
					runeAccumulator = replaceRunes(
						runeAccumulator, encoder.normalizerStringMap,
					)
				}
			}
			// If we don't recognize the regex, we default to using the regex package
			processLine(
				runeAccumulator, specialToken,
				candidateNode,
			)
			runeAccumulator = runeAccumulator[:0]
			candidateNode = specialsRuneRoot
			specialToken = false
			specialsCandidates = specialsCandidates[:0]
		}

		// Close the work queue and wait for all workers to finish
		close(workQueue)
		wg.Wait()
		completeCallback()
	}
}
  1267  
// WordSplitter returns an iterator-style closure that yields one word at a
// time from reader, or nil once the stream is exhausted. Splitting runs in
// a background goroutine that delivers word batches over a channel.
func (encoder *GPTEncoder) WordSplitter(reader io.RuneReader) func() *string {
	moreWords := make(chan []string, encoder.wordChanSz)
	wordSplitter := encoder.makeWordSplitter(
		reader.ReadRune,
		func(words []string) {
			if len(words) > 0 {
				moreWords <- words
			}
		},
		func() {
			close(moreWords)
		},
	)
	go wordSplitter()

	// idx is 1-based: wordsBuffer[idx-1] is the word handed out next.
	var wordsBuffer []string
	idx := 1

	return func() *string {
		var more bool
		if idx >= len(wordsBuffer) {
			// Current batch exhausted; block for the next one.
			wordsBuffer, more = <-moreWords
			if !more {
				return nil
			}
			idx = 1
		} else {
			idx++
		}
		word := wordsBuffer[idx-1]
		return &word
	}
}
  1301  
  1302  // Helper functions
  1303  func trimSpacesRunes(runes []rune) []rune {
  1304  	// Runespace trims leading and trailing spaces from a slice of runes
  1305  	// and returns the trimmed slice.
  1306  	start := 0
  1307  	end := len(runes)
  1308  	for start < end && unicode.IsSpace(runes[start]) {
  1309  		start++
  1310  	}
  1311  	for end > start && unicode.IsSpace(runes[end-1]) {
  1312  		end--
  1313  	}
  1314  	return runes[start:end]
  1315  }
  1316  
  1317  func toLowercaseRunes(runes []rune) []rune {
  1318  	// Runespace converts a slice of runes to lowercase and returns the
  1319  	// lowercase slice.
  1320  	for i := 0; i < len(runes); i++ {
  1321  		runes[i] = unicode.ToLower(runes[i])
  1322  	}
  1323  	return runes
  1324  }
  1325  
  1326  func replaceRunes(
  1327  	runes []rune,
  1328  	replacements map[string]string,
  1329  ) []rune {
  1330  	runeReplacements := make(map[string][]rune, len(replacements))
  1331  	for k, v := range replacements {
  1332  		runeReplacements[k] = []rune(v)
  1333  	}
  1334  
  1335  	// Iterate through runes
  1336  	for i := 0; i < len(runes); i++ {
  1337  		matchFound := false
  1338  
  1339  		// Iterate through replacements
  1340  		for k, v := range runeReplacements {
  1341  			if len(v) == 0 {
  1342  				continue
  1343  			}
  1344  			if runes[i] == []rune(k)[0] {
  1345  				matchFound = true
  1346  				if len(v) > 1 {
  1347  					// Try to get a slice of the runes to match the key, if it matches, replace it
  1348  					keySlice := runes[i : i+len(k)]
  1349  					for j := 0; j < len(keySlice); j++ {
  1350  						if keySlice[j] != []rune(k)[j] {
  1351  							matchFound = false
  1352  							break
  1353  						}
  1354  					}
  1355  					if matchFound {
  1356  						runes = append(runes[:i], []rune(v)...)
  1357  					}
  1358  
  1359  				} else {
  1360  					runes[i] = v[0]
  1361  				}
  1362  			}
  1363  		}
  1364  		if !matchFound {
  1365  			continue
  1366  		}
  1367  	}
  1368  	return runes
  1369  }
  1370  
  1371  // Excludes new line whitespaces. Thus is horizontal whitespace.
  1372  func isHorizontalWhitespace(r rune) bool {
  1373  	return r == ' ' || r == '\t' || r == '\r'
  1374  }
  1375  
  1376  func isSymbol(r rune) bool {
  1377  	return !unicode.IsLetter(r) && !unicode.IsNumber(r) && !isHorizontalWhitespace(r) && !isNewLine(r)
  1378  }
  1379  
  1380  func isNewLine(r rune) bool {
  1381  	// While \n is often considered a whitespace, we treat it as a symbol
  1382  	// to ensure it is always a separate token.
  1383  	return r == '\n'
  1384  }
  1385  
  1386  func (encoder *GPTEncoder) SplitWords(text *string) *[]string {
  1387  	words := make([]string, 0)
  1388  	nextWord := encoder.WordSplitter(strings.NewReader(*text))
  1389  	for {
  1390  		word := nextWord()
  1391  		if word == nil {
  1392  			break
  1393  		}
  1394  		words = append(words, *word)
  1395  	}
  1396  	return &words
  1397  }
  1398  
  1399  func (encoder *GPTEncoder) toUnicode(text *string) string {
  1400  	if encoder.BytesEncoder != nil {
  1401  		runes := []rune(*text)
  1402  		return string(runes)
  1403  	}
  1404  	textBytes := []byte(*text)
  1405  	outArr := make([]rune, len(*text))
  1406  	for idx := range textBytes {
  1407  		outArr[idx] = encoder.byteToRune[textBytes[idx]]
  1408  	}
  1409  	return string(outArr)
  1410  }
  1411  
  1412  func (encoder *GPTEncoder) encodeTokens(tokens *[]string) (encoded Tokens) {
  1413  	encoded = make(Tokens, len(*tokens))
  1414  	for idx := range *tokens {
  1415  		encoded[idx] = encoder.Encoder[(*tokens)[idx]]
  1416  	}
  1417  	return encoded
  1418  }
  1419  
// tokenAccumulatorPool recycles large token buffers used by
// StreamingEncode, amortizing the allocation across encodes.
var tokenAccumulatorPool = sync.Pool{
	New: func() interface{} {
		// Size based on typical artifact size - adjust if needed
		tokens := make(Tokens, 0, 65536)
		return &tokens
	},
}
  1427  
// StreamingEncode returns a closure that encodes the reader incrementally:
// each call yields up to desiredTokens Tokens, or nil once the stream is
// exhausted. BOS/EOS tokens are added per the encoder's enclose* flags,
// and adjacent-token merges from TokenMerges are applied across word
// boundaries unless ignoreMerges is set.
func (encoder *GPTEncoder) StreamingEncode(reader io.RuneReader) func(int) *Tokens {
	nextWord := encoder.WordSplitter(reader)

	// Get accumulator from pool
	accumulatorPtr := tokenAccumulatorPool.Get().(*Tokens)
	accumulator := (*accumulatorPtr)[:0] // Reset length but keep capacity

	if encoder.encloseEosBos || encoder.encloseBos {
		accumulator = append(accumulator, encoder.BosToken)
	}

	// Guards against appending EOS more than once at end-of-stream.
	eosReturned := false

	return func(desiredTokens int) *Tokens {
		for {
			// Serve from the accumulator while it holds a full chunk.
			if len(accumulator) >= desiredTokens {
				chunk := make(Tokens, desiredTokens)
				copy(chunk, accumulator[:desiredTokens])

				// Preserve capacity while shifting remaining tokens
				copy(accumulator, accumulator[desiredTokens:])
				accumulator = accumulator[:len(accumulator)-desiredTokens]
				return &chunk
			}

			word := nextWord()
			if word == nil {
				// End of input: optionally add EOS, then drain the rest.
				if (encoder.encloseEosBos || encoder.encloseEos) && !eosReturned {
					accumulator = append(
						accumulator, encoder.EosToken,
					)
					eosReturned = true
				}

				if len(accumulator) > 0 {
					chunk := make(Tokens, len(accumulator))
					copy(chunk, accumulator)
					accumulator = accumulator[:0]
					return &chunk
				}

				// Return accumulator to pool when done
				tokenAccumulatorPool.Put(accumulatorPtr)
				return nil
			}

			var encodedTokens Tokens
			if specialToken, isSpecial := encoder.Specials[*word]; isSpecial {
				// Special tokens bypass BPE and map straight to their id.
				encodedTokens = Tokens{
					encoder.Encoder[string(encoder.Decoder[specialToken[0]])],
				}
			} else {
				fragment := encoder.toUnicode(word)
				encodedTokens = encoder.ToBPE(fragment)
			}
			accumulator = append(accumulator, encodedTokens...)

			if encoder.ignoreMerges {
				continue
			}

			// Re-scan from just before the newly appended tokens, merging
			// adjacent pairs found in TokenMerges; step back one position
			// after each merge so cascading merges are caught.
			if offsetIdx := len(accumulator) - len(encodedTokens) - 1; offsetIdx >= 0 {
				idx := offsetIdx
				for idx < len(accumulator)-1 {
					pair := TokenPair{accumulator[idx], accumulator[idx+1]}
					if merged, ok := encoder.TokenMerges[pair]; ok && merged != 0 {
						before := accumulator[:idx]
						var after Tokens
						if idx+2 < len(accumulator) {
							after = accumulator[idx+2:]
						}
						accumulator = append(before, merged)
						accumulator = append(accumulator, after...)
						if idx > 0 {
							idx--
						}
					} else {
						idx++
					}
				}
			}
		}
	}
}
  1512  
  1513  func (encoder *GPTEncoder) EncodeReader(reader io.RuneReader) *Tokens {
  1514  	encoded := make(Tokens, 0, 4096)
  1515  	nextTokens := encoder.StreamingEncode(reader)
  1516  	for {
  1517  		tokens := nextTokens(4096)
  1518  		if tokens == nil {
  1519  			break
  1520  		}
  1521  		encoded = append(encoded, *tokens...)
  1522  	}
  1523  	return &encoded
  1524  }
  1525  
  1526  // EncodeBuffer takes a byte array and encodes it into Tokens in another
  1527  // byte array.
  1528  func (encoder *GPTEncoder) EncodeBuffer(buffer *[]byte) (
  1529  	*[]byte, uint64,
  1530  ) {
  1531  	runeReader := bytes.NewReader(*buffer)
  1532  	nextTokens := encoder.StreamingEncode(runeReader)
  1533  	buf := bytes.NewBuffer(make([]byte, 0, 4096))
  1534  	var count uint64 = 0
  1535  	for {
  1536  		tokens := nextTokens(2048)
  1537  		if tokens == nil {
  1538  			break
  1539  		}
  1540  		_ = binary.Write(buf, binary.LittleEndian, tokens)
  1541  		count += uint64(len(*tokens))
  1542  	}
  1543  	bufBytes := buf.Bytes()
  1544  	return &bufBytes, count
  1545  }
  1546  
  1547  // Encode encodes a string into a sequence of tokens.
  1548  func (encoder *GPTEncoder) Encode(text *string) *Tokens {
  1549  	// Temporary hack - inject a space token at the end of the accumulator for mistral-tokenizer
  1550  	if encoder.VocabId == VOCAB_ID_MISTRAL {
  1551  		*text = " " + *text
  1552  	}
  1553  	runeReader := strings.NewReader(*text)
  1554  
  1555  	return encoder.EncodeReader(runeReader)
  1556  }
  1557  
  1558  // Get
  1559  // Looks up text in the Encoder, and returns the Token representation of it. If
  1560  // the text is not found, then nil is returned.
  1561  func (encoder *GPTEncoder) Get(text string) *Token {
  1562  	if token, ok := encoder.Encoder[text]; !ok {
  1563  		return nil
  1564  	} else {
  1565  		return &token
  1566  	}
  1567  }
  1568  
// Decode Tokens back into a string, handling unicode.
func (encoder *GPTEncoder) Decode(encoded *Tokens) (text string) {
	// Check if we have an end of word token defined.
	convertEndOfWord := false
	if encoder.endOfWord != "" {
		convertEndOfWord = true
	}
	// Accumulate tokens until it is unicode complete.
	tokensAcc := make(Tokens, 0)
	runesAcc := make([]rune, 0)
	for i, token := range *encoded {
		tokensAcc = append(tokensAcc, token)
		bs := make([]byte, 0)
		// If we have a byte token and a byteEncoder, then we need to
		// accumulate until we have a full rune. If we are at the end of
		// the encoded tokens, then we need to decode the accumulated
		// tokens regardless.
		flagHoldForByte := encoder.IsByteToken(&token) &&
			encoder.IsLastTokenByte(&tokensAcc)

		if encoder.TokensReady(&tokensAcc) && (i == len(*encoded)-1 || !flagHoldForByte) {
			// Concatenate the raw byte sequences of every held token.
			for _, safeToken := range tokensAcc {
				if v, ok := encoder.Decoder[safeToken]; ok {
					bs = append(bs, v...)
				}
			}
			// Convert our bytearray to string, interpreting as UTF-8 and
			// then to 32-bit runes. If we don't have a BytesEncoder, then we
			// are using GPT BPE's byte encoding algorithm for Unicode.
			var runes = []rune(string(bs))
			var fragment string
			if encoder.BytesEncoder == nil {
				decoded := make([]byte, len(runes))
				// Convert our runes into 8-bit bytes using a 256-slot table.
				for runeIdx := range runes {
					decoded[runeIdx] = encoder.runeToByte[runes[runeIdx]]
				}
				fragment = string(decoded)
				runes = []rune(fragment)
			} else {
				fragment = string(bs)
				runes = []rune(fragment)
			}
			// Decode our final token representation into a Unicode string.
			if convertEndOfWord {
				// Replace the end-of-word marker with a trailing space,
				// except when only a lone apostrophe remains.
				if strings.HasSuffix(fragment, encoder.endOfWord) {
					runes = runes[:len(runes)-len(encoder.endOfWord)]
					if len(runes) == 1 && runes[0] == '\'' {
					} else {
						runes = append(runes, ' ')
					}
				}
				// A lone digit always gets a trailing space.
				if len(runes) == 1 &&
					unicode.IsNumber(runes[0]) {
					runes = append(runes, ' ')
				}
				// If we have a punctuation rune, and the previous rune is a
				// space, then we remove the space. This is to handle cases
				// like " ,".
				// NOTE(review): runes[0] assumes runes is non-empty here —
				// confirm a fragment equal to the end-of-word marker alone
				// cannot reach this branch.
				if len(runesAcc) > 1 && runeIsIn(
					runes[0],
					encoder.PuncRunes,
				) && unicode.IsSpace(
					runesAcc[len(
						runesAcc,
					)-1],
				) {
					runesAcc = runesAcc[:len(runesAcc)-1]
				}
			}
			runesAcc = append(runesAcc, runes...)
			tokensAcc = tokensAcc[:0]
		}
	}

	return string(runesAcc)
}
  1646  
  1647  // DecodeBuffer
  1648  // Decode Tokens from a byte array into a string.
  1649  func (encoder *GPTEncoder) DecodeBuffer(
  1650  	encoded *[]byte,
  1651  	useUint32 bool,
  1652  ) (text string) {
  1653  	// First convert our bytearray into uint32 `Token` array.
  1654  	var tokens *Tokens
  1655  	if useUint32 {
  1656  		tokens = types.TokensFromBin32(encoded)
  1657  	} else {
  1658  		tokens = types.TokensFromBin(encoded)
  1659  	}
  1660  	// Decode our tokens into a string.
  1661  	return encoder.Decode(tokens)
  1662  }
  1663  
// IsByteToken
// Determine if the token is a byte token. Always false when the encoder
// has no byte-level encoder.
func (encoder *GPTEncoder) IsByteToken(token *Token) bool {
	if encoder.BytesEncoder == nil {
		return false
	}
	// NOTE(review): `<=` also admits the id equal to len(BytesEncoder),
	// one past the last byte id if ids are 0-based — confirm whether `<`
	// was intended here.
	return int(*token) <= len(*encoder.BytesEncoder)
}
  1672  
  1673  // IsLastTokenByte
  1674  // Determine if the last token in the sequence is a byte token.
  1675  func (encoder *GPTEncoder) IsLastTokenByte(tokens *Tokens) bool {
  1676  	if encoder.BytesEncoder == nil || len(*tokens) == 0 {
  1677  		return false
  1678  	}
  1679  	return encoder.IsByteToken(&(*tokens)[len(*tokens)-1])
  1680  }
  1681  
// TokensReady
// Determine if the sequence of Tokens given is ready to be serialized
// to string, based on if the sequence will produce valid Unicode runes.
func (encoder *GPTEncoder) TokensReady(tokens *Tokens) bool {
	// Byte-level encoders decode any sequence; nothing to check.
	if encoder.BytesEncoder != nil {
		return true
	}
	// `need` tracks how much of the current rune is still outstanding per
	// the unitrim table; `good` is one past the last index at which the
	// sequence was complete.
	good := 0
	need := 0
	for tokenIdx := range *tokens {
		tok := (*tokens)[tokenIdx]
		var req int
		if int(tok) >= len(encoder.unitrim) {
			// Don't error out on tokens that we don't know about.
			req = 0
		} else {
			req = encoder.unitrim[(*tokens)[tokenIdx]]
		}

		if !(need+req < 0) {
			need += req
		}
		if req == 0 {
			// reset need to 0 to avoid being stuck when we have invalid
			// unicode being generated.
			need = 0
		}
		if need == 0 {
			good = tokenIdx + 1
		}
	}
	// Ready only if the sequence ends on a rune boundary.
	return good == len(*tokens)
}
  1715  
  1716  // TrimTokens
  1717  // Trims the given Tokens to tokens that produce valid unicode.
  1718  func (encoder *GPTEncoder) TrimTokens(tokens *Tokens) (trimmed *Tokens) {
  1719  	trimmed = tokens
  1720  	for {
  1721  		if len(*trimmed) == 0 {
  1722  			return trimmed
  1723  		}
  1724  		if encoder.TokensReady(trimmed) {
  1725  			return trimmed
  1726  		} else {
  1727  			newTrimmed := (*trimmed)[0 : len(*trimmed)-1]
  1728  			trimmed = &newTrimmed
  1729  		}
  1730  	}
  1731  }