github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/gpt_bpe_test.go (about) 1 package gpt_bpe 2 3 import ( 4 "bufio" 5 "encoding/base64" 6 "encoding/hex" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "io/ioutil" 11 "log" 12 "os" 13 "regexp" 14 "runtime" 15 "strings" 16 "testing" 17 "time" 18 19 "github.com/wbrown/gpt_bpe/types" 20 21 "github.com/stretchr/testify/assert" 22 "github.com/ulikunitz/xz" 23 "github.com/wbrown/gpt_bpe/resources" 24 ) 25 26 var clipEncoder GPTEncoder 27 var gpt2Encoder GPTEncoder 28 var pileEncoder GPTEncoder 29 var nerdstashV2Encoder GPTEncoder 30 var llama2Encoder GPTEncoder 31 var llama3Encoder GPTEncoder 32 var mistralEncoder GPTEncoder 33 var corpus string 34 var clipCorpus string 35 var largeCorpus *string 36 37 // var corpus2 string 38 var gpt2Encoded *Tokens 39 var pileEncoded *Tokens 40 var clipEncoded *Tokens 41 var nerdstashEncoded *Tokens 42 var llama2Encoded *Tokens 43 var llama3Encoded *Tokens 44 var mistralEncoded *Tokens 45 var unicodeTrimTests []*Tokens 46 47 var benchmarkPrefix string 48 var encoders map[string]*GPTEncoder 49 50 const largeCorpusPath = "resources/wiki.train.raw.xz" 51 52 func handleRead(path string) []byte { 53 if textBytes, err := os.ReadFile(path); err != nil { 54 log.Fatalf("Error opening `%s`: %v", path, err) 55 } else { 56 return textBytes 57 } 58 return nil 59 } 60 61 func loadUnicodeTrimTests(path string) []*Tokens { 62 tests := make([]*Tokens, 0) 63 fileBlob := string(handleRead(path)) 64 fileLines := strings.Split(fileBlob, "\n") 65 for idx := range fileLines { 66 line := fileLines[idx] 67 if len(line) == 0 { 68 continue 69 } 70 unicodeTrimTest := make(Tokens, 0) 71 if err := json.Unmarshal( 72 []byte(line), 73 &unicodeTrimTest, 74 ); err != nil { 75 log.Fatalf("Error unmarshaling `%s`: %v", path, err) 76 } 77 tests = append(tests, &unicodeTrimTest) 78 } 79 return tests 80 } 81 82 func Chunks(s string, chunkSize int) []string { 83 if len(s) == 0 { 84 return nil 85 } 86 if chunkSize >= len(s) { 87 return []string{s} 88 } 89 var chunks []string = make([]string, 0, (len(s)-1)/chunkSize+1) 90 currentLen := 0 91 currentStart := 0 92 for i := range s { 93 if currentLen == chunkSize { 94 chunks = append(chunks, s[currentStart:i]) 95 currentLen = 0 96 currentStart = i 97 } 98 currentLen++ 99 } 100 chunks = append(chunks, s[currentStart:]) 101 return chunks 102 } 103 104 func getStringBounds( 105 i int, 106 output string, 107 decoded string, 108 ) ( 109 left int, 110 right int, 111 ) { 112 if i < 20 { 113 left = 0 114 } else { 115 left = i - 20 116 } 117 if len(output) < len(decoded) { 118 right = len(output) 119 } else { 120 right = len(decoded) 121 } 122 if i+20 < right { 123 right = i + 20 124 } 125 return left, right 126 } 127 128 func init() { 129 isBench := isRunningBenchmarkTest() 130 131 // These will all be null 132 encoders = map[string]*GPTEncoder{ 133 "gpt2-tokenizer": &gpt2Encoder, 134 "pile-tokenizer": &pileEncoder, 135 "clip-tokenizer": &clipEncoder, 136 "nerdstash_v2-tokenizer": &nerdstashV2Encoder, 137 "llama-tokenizer": &llama2Encoder, 138 "llama3-tokenizer": &llama3Encoder, 139 "mistral-tokenizer": &mistralEncoder, 140 } 141 142 // Load all encoders up front, as this is desirable for benchmarking 143 if isBench { 144 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 145 pileEncoder = *CacheLoadEncoder("pile-tokenizer") 146 clipEncoder = *CacheLoadEncoder("clip-tokenizer") 147 nerdstashV2Encoder = *CacheLoadEncoder("nerdstash_v2-tokenizer") 148 llama2Encoder = *CacheLoadEncoder("llama-tokenizer") 149 llama3Encoder = *CacheLoadEncoder("llama3-tokenizer") 150 mistralEncoder = *CacheLoadEncoder("mistral-tokenizer") 151 } 152 153 textBytes := handleRead("resources/frankenstein.txt") 154 clipBytes := handleRead("resources/frankenstein_clip.txt") 155 corpus = string(textBytes) 156 clipCorpus = string(clipBytes) 157 unicodeTrimTests = loadUnicodeTrimTests("resources/trim_tests.jsonl") 158 var err error 159 _, err = GetLargeCorpus() 160 if err != nil { 161 log.Fatalf("Error opening `%s`: %v", largeCorpusPath, err) 162 } 163 } 164 165 // isRunningBenchmarkTest 166 // Check if we're running a benchmark test 167 // We assume a benchmark test is defined as a test 168 // that begins with a BENCHMARK_PREFIX. This is 169 // by default "Benchmark", but can be configured 170 // by the environment variable BENCHMARK_PREFIX 171 func isRunningBenchmarkTest() bool { 172 prefix, ok := os.LookupEnv("BENCHMARK_PREFIX") 173 if ok { 174 benchmarkPrefix = prefix 175 } else { 176 benchmarkPrefix = "Benchmark" 177 } 178 179 for _, arg := range os.Args { 180 parsedArg, err := regexp.Compile(arg) 181 if err != nil { 182 log.Fatalf("Failed parsing CLI args using regexp") 183 } 184 prefix, _ = parsedArg.LiteralPrefix() 185 if strings.HasPrefix(prefix, benchmarkPrefix) { 186 log.Println("Running benchmark test, so loading encoders up front...") 187 return true 188 } 189 } 190 return false 191 } 192 193 // CacheLoadEncoder 194 // Loads an encoder, but only once 195 func CacheLoadEncoder(vocabId string) *GPTEncoder { 196 if encoders[vocabId].Encoder == nil { 197 encoder, err := NewEncoder(vocabId) 198 if err != nil { 199 log.Fatalf("Error loading encoder `%s`: ", vocabId) 200 } 201 202 // Cache the encoder for later use 203 encoders[vocabId] = encoder 204 return encoder 205 } 206 return encoders[vocabId] 207 } 208 209 func GetLargeCorpus() (*string, error) { 210 if largeCorpus == nil { 211 var err error 212 largeCorpus, err = DecompressXZ(largeCorpusPath) 213 if err != nil { 214 return nil, err 215 } 216 } 217 return largeCorpus, nil 218 } 219 220 func TestMain(m *testing.M) { 221 m.Run() 222 } 223 224 type TrimTest struct { 225 Input string 226 Direction TrimDirection 227 Limit uint 228 Expected string 229 } 230 231 const sent1 = "This is test sentence 1. This is test sentence 2. This is test sentence 3." 232 const sent2 = "\nThis is test sentence 4.\nThis is test sentence 5.\nThis is test sentence 6.\n" 233 const hindiSentence = "व्याकरण शास्त्रीय परिभाषाएँ : डॉ. पर्णदत्त सिंह द्वारा हिंदी पीडीऍफ़ पुस्तक" 234 const jpSentence = "「そんな心構えで、本当に俺の『未練』を果たせるのか? 知ってのとおり、俺の『未練』は『<|rubycover|>相川渦波<|rubystart|>おまえ<|rubyend|>の成長を最後まで見届けること』だ。……言っとくが、俺は年季が入ってる上に、拗らせに拗らせた元神学者。俺の『大いなる<|rubycover|>救世主<|rubystart|>マグナ・メサイア<|rubyend|>』の『理想』は高いぞ? 少なくとも、この『血陸』を止められないようじゃ、任せ切れないな」\n<|mtsentence|><|mtsenglish|>Please check if the meat is being roasted at the right heat.<|mtsjapanese|>焼き肉の火加減を見なさい。<|mtsentenceend|>\n<|mtvocab|><|mtvjapanese|>[ぶんけんがく] 文献学<|mtvenglish|>(n) philology<|mtvocabend|>" 235 236 var TrimSentencesTests = []TrimTest{ 237 {sent1, TrimTop, 10, 238 " This is test sentence 3."}, 239 {sent1, TrimTop, 20, 240 " This is test sentence 2. This is test sentence 3."}, 241 {sent1, TrimTop, 30, 242 sent1}, 243 {sent2, TrimTop, 10, 244 "\nThis is test sentence 6.\n"}, 245 {sent2, TrimTop, 18, 246 "\nThis is test sentence 5.\nThis is test sentence 6.\n"}, 247 {sent2, TrimTop, 30, 248 sent2}, 249 {sent1, TrimBottom, 10, 250 "This is test sentence 1."}, 251 {sent1, TrimBottom, 20, 252 "This is test sentence 1. This is test sentence 2."}, 253 {sent1, TrimBottom, 30, 254 sent1}, 255 {sent2, TrimBottom, 10, 256 "\nThis is test sentence 4.\n"}, 257 {sent2, TrimBottom, 18, 258 "\nThis is test sentence 4.\nThis is test sentence 5.\n"}, 259 {sent2, TrimBottom, 30, 260 sent2}, 261 } 262 263 func TestHFResolution(t *testing.T) { 264 _, err := NewEncoder("EleutherAI/gpt-j-6B") 265 if err != nil { 266 t.Error(err) 267 } 268 _, err = NewEncoder("nonexist/nonexist") 269 if err == nil { 270 t.Error(errors.New("failed to return error on non-existent model")) 271 } 272 } 273 274 func TestHFTokenzier(t *testing.T) { 275 enc, err := NewEncoder("EleutherAI/gpt-j-6B") 276 if err != nil { 277 t.Error(err) 278 } 279 sent := "The fox jumped over the hare." 280 hfTokens := enc.Encode(&sent) 281 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 282 gptTokens := gpt2Encoder.Encode(&sent) 283 assert.Equal(t, hfTokens, gptTokens) 284 } 285 286 func TestFairSeqTokenizer(t *testing.T) { 287 enc, err := NewEncoder("KoboldAI/fairseq-dense-2.7B") 288 if err != nil { 289 t.Error(err) 290 return 291 } 292 sent := "The fox jumped over the hare.\nThe turtle is faster than the hare." 293 tokens := Tokens{464, 21831, 11687, 625, 262, 387, 260, 25970, 82, 29, 294 464, 28699, 318, 5443, 621, 262, 387, 260, 13} 295 fsTokens := enc.Encode(&sent) 296 assert.Equal(t, *fsTokens, tokens) 297 } 298 299 var TrimNewLinesTests = append( 300 TrimSentencesTests[3:5], TrimSentencesTests[9:11]..., 301 ) 302 303 func TestGPTEncoder_TrimIncompleteSentence(t *testing.T) { 304 testStr := "This is a test. He says, \"This is an unterminated quote. She says, this is actually terminated.\" This is awesome! This is incomplete " 305 expected := "This is a test. He says, \"This is an unterminated quote. She says, this is actually terminated.\" This is awesome!" 306 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 307 trimmed, _ := gpt2Encoder.TrimIncompleteSentence(gpt2Encoder.Encode(&testStr)) 308 output := gpt2Encoder.Decode(trimmed) 309 if expected != output { 310 t.Error("output != expected; output := ", expected) 311 } 312 } 313 314 func TestGPTEncoder_TrimTokens(t *testing.T) { 315 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 316 for testIdx := range unicodeTrimTests { 317 assert.NotEqual( 318 t, len( 319 *gpt2Encoder.TrimTokens( 320 unicodeTrimTests[testIdx], 321 ), 322 ), 323 len(*unicodeTrimTests[testIdx]), 324 ) 325 } 326 } 327 328 func TestGPTEncoder_TrimNewlines(t *testing.T) { 329 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 330 for testIdx := range TrimNewLinesTests { 331 test := TrimNewLinesTests[testIdx] 332 res, err := gpt2Encoder.TrimNewlines( 333 gpt2Encoder.Encode(&test.Input), 334 test.Direction, test.Limit, 335 ) 336 if err != nil { 337 t.Error("TrimNewlines: error:", err) 338 } 339 decodeRes := gpt2Encoder.Decode(res) 340 if decodeRes != test.Expected { 341 t.Error( 342 "TrimNewlines: expected '" + test.Expected + "' got '" + 343 decodeRes + "'", 344 ) 345 } 346 } 347 } 348 349 func TestGPTEncoder_TrimSentences(t *testing.T) { 350 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 351 for testIdx := range TrimSentencesTests { 352 test := TrimSentencesTests[testIdx] 353 res, err := gpt2Encoder.TrimSentences( 354 gpt2Encoder.Encode(&test.Input), 355 test.Direction, test.Limit, 356 ) 357 if err != nil { 358 t.Error("TrimSentences: error:", err) 359 } 360 decodeRes := gpt2Encoder.Decode(res) 361 if decodeRes != test.Expected { 362 t.Error( 363 "TrimSentences: expected '" + test.Expected + "' got '" + 364 decodeRes + "'", 365 ) 366 } 367 } 368 } 369 370 type SplitTest struct { 371 Input string 372 Expected []string 373 } 374 375 var SplitTests = []SplitTest{ 376 {"we'll go jump in a lake.", 377 []string{"we", "'ll", " go", " jump", " in", " a", " lake", 378 "."}}, 379 {"multiple encoded spaces.", 380 []string{"multiple", " ", "encoded", " spaces", "."}}, 381 {"Capitalized Words Are Cool", 382 []string{"Capitalized", " Words", " Are", " Cool"}}, 383 {"we'LL test irregular cApitalizatioN.", 384 []string{"we", "'", "LL", " test", " irregular", 385 " cApitalizatioN", "."}}, 386 {"multilines\nare awesome", 387 []string{"multilines", "\n", "are", " awesome"}}, 388 {"\nstarting with multilines\nis awesome", 389 []string{"\n", "starting", " with", " multilines", 390 "\n", "is", " awesome"}}, 391 {"we'll go jump<|endoftext|> in a lake.", 392 []string{"we", "'ll", " go", " jump", "<|endoftext|>", 393 " in", " a", " lake", "."}}, 394 {"we'll go jump<|end\noftext|> in a lake.", 395 []string{"we", "'ll", " go", " jump", "<|", "end", "\n", 396 "oftext", "|>", " in", " a", " lake", "."}}, 397 } 398 399 func TestGPTEncoder_Split(t *testing.T) { 400 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 401 402 for testIdx := range SplitTests { 403 test := SplitTests[testIdx] 404 assert.Equal(t, test.Expected, *(gpt2Encoder.SplitWords(&test.Input))) 405 } 406 } 407 408 func DecompressXZ(path string) (*string, error) { 409 corpusHandle, err := os.Open(path) 410 if err != nil { 411 return nil, err 412 } 413 defer corpusHandle.Close() 414 decompressorHandle, err := xz.NewReader(corpusHandle) 415 if err != nil { 416 return nil, err 417 } 418 decompressed, err := ioutil.ReadAll(decompressorHandle) 419 if err != nil { 420 return nil, err 421 } 422 decompressedStr := string(decompressed) 423 return &decompressedStr, nil 424 } 425 426 func OpenXZStream(path string) (*xz.Reader, *os.File, error) { 427 corpusHandle, err := os.Open(path) 428 if err != nil { 429 return nil, nil, err 430 } 431 decompressorHandle, err := xz.NewReader(corpusHandle) 432 if err != nil { 433 return nil, nil, err 434 } 435 return decompressorHandle, corpusHandle, nil 436 } 437 438 type EncoderTest struct { 439 Input string 440 GPT2Expected Tokens 441 PileExpected Tokens 442 CLIPExpected Tokens 443 NerdstashExpected Tokens 444 } 445 446 var GPTEncoderTests = []EncoderTest{ 447 {"… …", 448 Tokens{1399, 3926}, 449 Tokens{2866, 8139}, 450 Tokens{49406, 959, 959, 49407}, 451 Tokens{49289, 5512}}, 452 {"<|endoftext|>", 453 Tokens{50256}, 454 Tokens{0}, 455 Tokens{49406, 49407, 49407}, 456 Tokens{3}}, 457 {" <|endoftext|>\n<|endoftext|>foo", 458 Tokens{220, 50256, 198, 50256, 21943}, 459 Tokens{209, 0, 187, 0, 12110}, 460 Tokens{49406, 49407, 49407, 23435, 49407}, 461 Tokens{49209, 3, 85, 3, 49225, 3292}}, 462 {" <|padding|>test", 463 Tokens{220, 50257, 9288}, 464 Tokens{209, 1, 2566}, 465 Tokens{49406, 27, 347, 3798, 796, 91, 285, 1628, 49407}, 466 Tokens{3252, 49376, 42545, 49376, 49405, 10180}, 467 }, 468 } 469 470 func TestGPTEncoder_Encode(t *testing.T) { 471 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 472 473 // This test is to check if the GPTEncoder is able to encode the tokens correctly 474 start := time.Now() 475 tokenCt := len(*gpt2Encoder.Encode(&corpus)) 476 duration := time.Since(start) 477 t.Logf( 478 "%v bytes into %v tokens over %v", 479 len(corpus), tokenCt, duration, 480 ) 481 for testIdx := range GPTEncoderTests { 482 tokensPtr := *gpt2Encoder.Encode( 483 &(GPTEncoderTests[testIdx].Input), 484 ) 485 assert.Equal(t, tokensPtr, GPTEncoderTests[testIdx].GPT2Expected) 486 } 487 } 488 489 func TestGPTEncode(t *testing.T) { 490 // This test is to check if the GPTEncoder is able to encode the tokens correctly 491 strin := "The quick brown fox jumps over the lazy dog." 492 expected := Tokens{464, 21831, 11687, 625, 262, 387, 260, 25970, 82, 29, 464, 28699, 318, 5443, 621, 262, 387, 260, 13} 493 encoded := gpt2Encoder.Encode(&strin) 494 fmt.Printf("Encoded: with commas:") 495 for _, token := range *encoded { 496 fmt.Printf("%v, ", token) 497 } 498 assert.Equal(t, *encoded, expected) 499 } 500 501 func TestGPTEncoder_StreamingEncode(t *testing.T) { 502 // This test is to check if the GPTEncoder is able to encode the tokens 503 // correctly 504 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 505 start := time.Now() 506 corpusRunes := strings.NewReader(*largeCorpus) 507 // Set our profiler up 508 profileHandle, _ := os.Create("streaming.prof") 509 defer profileHandle.Close() 510 runtime.GC() 511 //pprof.StartCPUProfile(profileHandle) 512 nextTokens := gpt2Encoder.StreamingEncode(corpusRunes) 513 tokenCt := 0 514 for { 515 tokens := nextTokens(16384) 516 if tokens == nil { 517 break 518 } 519 tokenCt += len(*tokens) 520 } 521 duration := time.Since(start) 522 //pprof.StopCPUProfile() 523 t.Logf( 524 "streaming encode: %d tokens/sec", 525 int64(float64(tokenCt)/duration.Seconds()), 526 ) 527 } 528 529 func TestCLIPEncoder_Encode(t *testing.T) { 530 clipEncoder = *CacheLoadEncoder("clip-tokenizer") 531 532 // This test is to check if the CLIPEncoder is able to encode the tokens correctly 533 start := time.Now() 534 tokenCt := len(*clipEncoder.Encode(&corpus)) 535 duration := time.Since(start) 536 t.Logf( 537 "%v bytes into %v tokens over %v", 538 len(corpus), tokenCt, duration, 539 ) 540 for testIdx := range GPTEncoderTests { 541 testStr := GPTEncoderTests[testIdx].Input 542 tokensPtr := *clipEncoder.Encode(&testStr) 543 assert.Equal(t, GPTEncoderTests[testIdx].CLIPExpected, tokensPtr) 544 } 545 } 546 547 func TestPileEncoder_Encode(t *testing.T) { 548 pileEncoder = *CacheLoadEncoder("pile-tokenizer") 549 550 // This test is to check if the PileEncoder is able to encode the tokens correctly 551 start := time.Now() 552 tokenCt := len(*pileEncoder.Encode(&corpus)) 553 duration := time.Since(start) 554 t.Logf( 555 "%v bytes into %v tokens over %v", 556 len(corpus), tokenCt, duration, 557 ) 558 for testIdx := range GPTEncoderTests { 559 tokensPtr := *pileEncoder.Encode( 560 &(GPTEncoderTests[testIdx].Input), 561 ) 562 assert.Equal(t, GPTEncoderTests[testIdx].PileExpected, tokensPtr) 563 } 564 } 565 566 func TestNerdstashEncoder_Encode(t *testing.T) { 567 // This test is to check if the NerdstashEncoder is able to encode the tokens correctly 568 start := time.Now() 569 nerdstashV2Encoder = *CacheLoadEncoder("nerdstash_v2-tokenizer") 570 571 tokenCt := len(*nerdstashV2Encoder.Encode(&corpus)) 572 duration := time.Since(start) 573 t.Logf( 574 "%v bytes into %v tokens over %v", 575 len(corpus), tokenCt, duration, 576 ) 577 for testIdx := range GPTEncoderTests { 578 tokensPtr := *nerdstashV2Encoder.Encode( 579 &(GPTEncoderTests[testIdx].Input), 580 ) 581 assert.Equal(t, GPTEncoderTests[testIdx].NerdstashExpected, tokensPtr) 582 } 583 } 584 585 func TestNerdstashEncoder_EncodeSpaces(t *testing.T) { 586 nerdstashV2Encoder = *CacheLoadEncoder("nerdstash_v2-tokenizer") 587 588 // This test is to check if the NerdstashEncoder is able to encode spaces correctly 589 testString := " 12 => '',\n" 590 expected := Tokens{16, 124, 125, 10631, 1695, 49231, 85} 591 encoded := nerdstashV2Encoder.Encode(&testString) 592 assert.Equal(t, expected, *encoded) 593 } 594 595 func TestNerdstashEncoder_Encode2(t *testing.T) { 596 nerdstashV2Encoder = *CacheLoadEncoder("nerdstash_v2-tokenizer") 597 598 // read the jsonl test file in 599 testFile, err := os.Open("resources/subset.jsonl") 600 if err != nil { 601 t.Error(err) 602 } 603 defer testFile.Close() 604 scanner := bufio.NewScanner(testFile) 605 scanner.Split(bufio.ScanLines) 606 type testLineStruct struct { 607 Text *string `json:"text"` 608 Hex *string `json:"hex"` 609 Encoded Tokens `json:"encoded"` 610 } 611 612 passCt := 0 613 failCt := 0 614 615 for scanner.Scan() { 616 jsonLine := scanner.Text() 617 testLine := testLineStruct{} 618 err := json.Unmarshal([]byte(jsonLine), &testLine) 619 if err != nil { 620 t.Error(err) 621 } 622 expected := testLine.Encoded 623 var inputStr string 624 if testLine.Hex != nil { 625 inputBytes, hexErr := hex.DecodeString(*testLine.Hex) 626 if hexErr != nil { 627 t.Error(hexErr) 628 } 629 inputStr = string(inputBytes) 630 } else { 631 inputStr = *testLine.Text 632 } 633 // encode the string 634 encoded := nerdstashV2Encoder.Encode(&inputStr) 635 // check that the encoded string is the same as the expected 636 if !assert.Equal(t, expected, *encoded) { 637 t.Logf("failure on input: `%v`", inputStr) 638 expectedRepr := []string{} 639 for _, token := range expected { 640 expectedRepr = append( 641 expectedRepr, 642 string(nerdstashV2Encoder.Decoder[token]), 643 ) 644 } 645 actualRepr := []string{} 646 for _, token := range *encoded { 647 actualRepr = append( 648 actualRepr, 649 string(nerdstashV2Encoder.Decoder[token]), 650 ) 651 } 652 t.Logf("expected: |%s", strings.Join(expectedRepr, "|")) 653 t.Logf("actual: |%s", strings.Join(actualRepr, "|")) 654 failCt += 1 655 } else { 656 passCt += 1 657 } 658 } 659 t.Logf("pass: %v, fail: %v", passCt, failCt) 660 } 661 662 func TestNerdstashEncoder_Decode(t *testing.T) { 663 nerdstashV2Encoder = *CacheLoadEncoder("nerdstash_v2-tokenizer") 664 665 // This test is to check if the NerdstashEncoder is able to decode the tokens correctly 666 for testIdx := range GPTEncoderTests { 667 decodedStr := nerdstashV2Encoder.Decode( 668 &(GPTEncoderTests[testIdx].NerdstashExpected), 669 ) 670 assert.Equal(t, GPTEncoderTests[testIdx].Input, decodedStr) 671 } 672 } 673 674 func TestGPTEncoder_Decode2(t *testing.T) { 675 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 676 677 // This test is to check if the GPTEncoder is able to decode the tokens correctly from a base64 encoded string 678 gpt2EncodedCorpus := "NrGIEOQBRzFfAQEBCAE5GeADPCFGAQhdBgFhBkcHXwEBATM5HgGilUYBpAdDEaUheR8iAQEBmgSnbyQpRgHIjaYBiSQYLfoHYwHogg0A0AHsGFUmpgEGAcd0qApjAzwa7hscAeHAYwEGAbYRB3UiAax0PQPjAgoXpgEGAZgE6G2gAWMExy5GAb5szQdGAXUBAR2gAVQBRgG8CdYBYbCgAe4QAxg/NA0AdyoiAZMGOXL8AWlmAQGgFXknNlIGAdADLiciAT4B6lk=" 679 decodedCorpus := "frying whatever they touched with a sizzled smell that fills the air along with a shower of sparks that land harmlessly elsewhere and a few stray drops that drip from fingers burned black as charcoal.The shock waves from the blasts cause many nearby trees to topple as the earth shakes and trembles underfoot from the power unleashed by each blast that destroys anything that was struck by it that wasn't shielded by heavy metal plates." 680 if binTokens, err := base64.StdEncoding.DecodeString(gpt2EncodedCorpus); err != nil { 681 log.Println("ERROR:", err) 682 } else { 683 tokens := types.TokensFromBin(&binTokens) 684 tokens, err = gpt2Encoder.TrimIncompleteSentence(tokens) 685 if err != nil { 686 t.Error(err) 687 } 688 assert.Equal(t, gpt2Encoder.Decode(tokens), decodedCorpus) 689 } 690 } 691 692 func TestGPTEncoder_Decode(t *testing.T) { 693 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 694 695 // This test is to check if the GPTEncoder is able to decode the tokens correctly 696 if gpt2Encoded == nil { 697 corpEncoded := gpt2Encoder.Encode(&corpus) 698 gpt2Encoded = corpEncoded 699 } 700 start := time.Now() 701 decoded := gpt2Encoder.Decode(gpt2Encoded) 702 duration := time.Since(start) 703 tokenNumBytes := len(decoded) 704 t.Logf( 705 "%v tokens into %v bytes over %v\n", 706 len(*gpt2Encoded), tokenNumBytes, duration, 707 ) 708 assert.Equal(t, corpus, decoded) 709 } 710 711 // BUG: CLIP TOKENIZER has a bug that causes 'the to be split into 712 // "'t<w>he<w>" instead of "'<w>the<w>". This causes the 713 // clipCorpus to be different from the corpus. This is a bug in 714 // the CLIP tokenizer from huggingface that was used to generate 715 // the clipCorpus. The decoded corpus is correct in this test. 716 // We stop the test right before the bug. 717 func TestCLIPEncoder_Decode(t *testing.T) { 718 clipEncoder = *CacheLoadEncoder("clip-tokenizer") 719 720 if clipEncoded == nil { 721 corpEncoded := clipEncoder.Encode(&corpus) 722 clipEncoded = corpEncoded 723 } 724 start := time.Now() 725 decoded := clipEncoder.Decode(clipEncoded) 726 duration := time.Since(start) 727 tokenNumBytes := len(decoded) 728 idxToStop := 229550 729 t.Logf( 730 "%v tokens into %v bytes over %v\n", len(*clipEncoded), tokenNumBytes, 731 duration, 732 ) 733 for idx := range clipCorpus { 734 if idx > idxToStop { 735 break 736 } 737 738 if clipCorpus[idx] != decoded[idx] { 739 t.Errorf( 740 "idx: %d, clipCorpus: %v, decoded: %v\n", idx, 741 clipCorpus[idx], decoded[idx], 742 ) 743 break 744 } 745 } 746 // assert.Equal(t, clipCorpus, decoded) 747 } 748 749 func TestPileEncoder_Decode(t *testing.T) { 750 pileEncoder = *CacheLoadEncoder("pile-tokenizer") 751 752 // This test is to check if the PileEncoder is able to decode the tokens correctly 753 if pileEncoded == nil { 754 corpEncoded := pileEncoder.Encode(&corpus) 755 pileEncoded = corpEncoded 756 } 757 start := time.Now() 758 decoded := pileEncoder.Decode(pileEncoded) 759 duration := time.Since(start) 760 tokenNumBytes := len(decoded) 761 t.Logf( 762 "%v tokens into %v bytes over %v\n", 763 len(*pileEncoded), tokenNumBytes, duration, 764 ) 765 range_data := corpus 766 if len(corpus) > len(decoded) { 767 range_data = decoded 768 } 769 if len(corpus) != len(decoded) { 770 t.Errorf(fmt.Sprintf("%v != %v", len(corpus), len(decoded))) 771 } 772 for idx := range range_data { 773 if corpus[idx] != decoded[idx] { 774 t.Errorf( 775 "%v != %v", clipCorpus[idx-20:idx+20], 776 decoded[idx-20:idx+20], 777 ) 778 return 779 } 780 } 781 } 782 783 func TestGPTEncoder_TokensReady(t *testing.T) { 784 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 785 786 // This test is to check if the TokensReady function is able to determine if the tokens are ready for context 787 multiTokenAsterism := "⁂" 788 tokens := gpt2Encoder.Encode(&multiTokenAsterism) 789 fmt.Printf("Tokens: %v, len: %v\n", tokens, len(*tokens)) 790 var idx int 791 for idx = range *tokens { 792 tokenSlice := (*tokens)[0 : idx+1] 793 fmt.Printf("TokenSlice: %v, len: %v\n", tokenSlice, len(tokenSlice)) 794 if gpt2Encoder.TokensReady(&tokenSlice) { 795 break 796 } 797 } 798 if idx < len(*tokens)-1 { 799 t.Errorf( 800 "Expected TokensReady on idx: %d for `%s`", idx, 801 multiTokenAsterism, 802 ) 803 } 804 } 805 806 func TestGPTEncoder_TokensReadyContext(t *testing.T) { 807 pileEncoder = *CacheLoadEncoder("pile-tokenizer") 808 809 // This test is to check if the TokensReady function is able to determine if the tokens are ready for context 810 var tokens Tokens 811 badContext, err := os.ReadFile("resources/badcontext.json") 812 if err != nil { 813 t.Errorf("Could not read badcontext.json: %v", err) 814 } 815 unmarshalErr := json.Unmarshal(badContext, &tokens) 816 if unmarshalErr != nil { 817 t.Errorf("Could not unmarshal badcontext.json: %v", unmarshalErr) 818 } 819 if !pileEncoder.TokensReady(&tokens) { 820 t.Errorf("Expected TokensReady to be true for badcontext.json") 821 } 822 } 823 824 func TestUnitrimFunctionality(t *testing.T) { 825 // This test is to check if the makeUnitrimArr function is able to generate the unitrim array correctly 826 for _, tokenizer := range []string{"clip-tokenizer", "gpt2-tokenizer", "pile-tokenizer"} { 827 encoderFile := fmt.Sprintf( 828 "resources/data/%s/encoder.json", tokenizer, 829 ) 830 unitrimFile := fmt.Sprintf( 831 "resources/data/%s/unitrim.json", tokenizer, 832 ) 833 834 // make sure the files exist 835 if _, err := os.Stat(encoderFile); os.IsNotExist(err) { 836 t.Errorf("Could not find file %s\n", encoderFile) 837 } 838 if _, err := os.Stat(unitrimFile); os.IsNotExist(err) { 839 t.Errorf("Could not find file %s\n", unitrimFile) 840 } 841 842 // read in the Encoder and unitrim files 843 encoderBytes, err := os.ReadFile(encoderFile) 844 if err != nil { 845 t.Errorf("Could not read Encoder file: %v\n", err) 846 } 847 // unmarshal the Encoder file 848 var encoder map[string]Token 849 err = json.Unmarshal(encoderBytes, &encoder) 850 if err != nil { 851 t.Errorf("Could not unmarshal Encoder file: %v\n", err) 852 } 853 854 // read in the unitrim file 855 unitrimBytes, err := os.ReadFile(unitrimFile) 856 if err != nil { 857 t.Errorf("Could not read unitrim file: %v\n", err) 858 } 859 // unmarshal the unitrim file 860 var unitrim []int 861 err = json.Unmarshal(unitrimBytes, &unitrim) 862 if err != nil { 863 t.Errorf("Could not unmarshal unitrim file: %v\n", err) 864 } 865 866 // get generated array for unitrim with the makeUnitrimArr function 867 generatedArray := makeUnitrimArr(encoder) 868 869 // check that the generated array is the same as the unitrim array 870 fmt.Printf( 871 "Generated array length: %d, unitrim array length: %d\n", 872 len(generatedArray), len(unitrim), 873 ) 874 if len(generatedArray) != len(unitrim) { 875 t.Errorf("Generated array and unitrim array are not the same length\n") 876 } 877 878 for i := range generatedArray { 879 if generatedArray[i] != unitrim[i] { 880 fmt.Printf( 881 "Generated array: %v and unitrim array: %v at index %d are not the same\n", 882 generatedArray[i], unitrim[i], i, 883 ) 884 fmt.Printf( 885 "mismatched unicode is: %c\n", rune(generatedArray[i]), 886 ) 887 t.Errorf("Generated array and unitrim array are not the same\n") 888 } 889 } 890 891 fmt.Printf("Length and contents of generated array and unitrim array are the same\n") 892 } 893 } 894 895 func TestLlamaEncoder_Encode(t *testing.T) { 896 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 897 898 // This test is to check if the encoder is able to encode a basic string 899 start := time.Now() 900 tokenCt := len(*gpt2Encoder.Encode(&corpus)) 901 duration := time.Since(start) 902 t.Logf( 903 "%v bytes into %v tokens over %v", 904 len(corpus), tokenCt, duration, 905 ) 906 for testIdx := range GPTEncoderTests { 907 tokensPtr := *gpt2Encoder.Encode( 908 &(GPTEncoderTests[testIdx].Input), 909 ) 910 assert.Equal(t, tokensPtr, GPTEncoderTests[testIdx].GPT2Expected) 911 } 912 } 913 914 func TestLlamaTwoEncoder_Encode(t *testing.T) { 915 llama2Encoder = *CacheLoadEncoder("llama-tokenizer") 916 917 // This test is to check if the encoder is able to encode a basic string 918 testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." 919 llamaTokens := llama2Encoder.Encode(&testString) 920 assert.Equal( 921 t, llamaTokens, 922 &Tokens{1576, 1701, 29916, 12500, 287, 975, 278, 447, 276, 29889, 13, 1576, 260, 4227, 280, 338, 8473, 1135, 278, 447, 276, 29889}, 923 ) 924 } 925 926 func TestLlamaTwoTokenizerDecode(t *testing.T) { 927 llama2Encoder = *CacheLoadEncoder("llama-tokenizer") 928 929 // This test is to check if the decoder is able to decode the tokens correctly 930 outputString := "<s>The fox jumped over the hare.\nThe turtle is faster than the hare." 931 llamaTokens := Tokens{1, 1576, 1701, 29916, 12500, 287, 975, 278, 447, 276, 29889, 13, 1576, 260, 4227, 280, 338, 8473, 1135, 278, 447, 276, 29889} 932 output := llama2Encoder.Decode(&llamaTokens) 933 assert.Equal(t, outputString, output) 934 } 935 936 func TestLlamaTwoEncodeDecode(t *testing.T) { 937 llama2Encoder = *CacheLoadEncoder("llama-tokenizer") 938 939 // This test is to check if the encoder is able to encode and decode a basic string 940 testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." 941 outputString := "The fox jumped over the hare.\nThe turtle is faster than the hare." 942 llamaTokens := llama2Encoder.Encode(&testString) 943 output := llama2Encoder.Decode(llamaTokens) 944 assert.Equal(t, outputString, output) 945 } 946 947 // This is Mistral tokenizer V1, associated with 7b instruct https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2 948 func TestMistralEncoder_Encode(t *testing.T) { 949 mistralEncoder = *CacheLoadEncoder("mistral-tokenizer") 950 951 // This test is to check if the encoder is able to encode a basic string 952 testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." 953 mistralTokens := mistralEncoder.Encode(&testString) 954 assert.Equal( 955 t, mistralTokens, 956 &Tokens{1, 415, 285, 1142, 14949, 754, 272, 295, 492, 28723, 13, 1014, 261, 3525, 291, 349, 9556, 821, 272, 295, 492, 28723}, 957 ) 958 } 959 960 func TestMistralTokenizerDecode(t *testing.T) { 961 mistralEncoder = *CacheLoadEncoder("mistral-tokenizer") 962 963 // This test is to check if the decoder is able to decode the tokens correctly 964 outputString := "<s> The fox jumped over the hare.\nThe turtle is faster than the hare." 965 mistralTokens := Tokens{1, 415, 285, 1142, 14949, 754, 272, 295, 492, 28723, 13, 1014, 261, 3525, 291, 349, 9556, 821, 272, 295, 492, 28723} 966 output := mistralEncoder.Decode(&mistralTokens) 967 assert.Equal(t, outputString, output) 968 } 969 970 func TestMistralEncodeDecode(t *testing.T) { 971 mistralEncoder = *CacheLoadEncoder("mistral-tokenizer") 972 973 // This test is to check if the encoder is able to encode and decode a basic string 974 testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." 975 outputString := "<s> The fox jumped over the hare.\nThe turtle is faster than the hare." 976 mistralTokens := mistralEncoder.Encode(&testString) 977 output := mistralEncoder.Decode(mistralTokens) 978 assert.Equal(t, outputString, output) 979 } 980 981 func TestMistralEncodeDecodeFrankenstein(t *testing.T) { 982 mistralEncoder = *CacheLoadEncoder("mistral-tokenizer") 983 984 // This test is to check if the encoder is able to encode and decode the Frankenstein corpus 985 frankensteinCorpus := "resources/frankenstein.txt" 986 frankensteinText, err := os.ReadFile(frankensteinCorpus) 987 if err != nil { 988 t.Errorf("Error reading Frankenstein corpus: %v", err) 989 } 990 frankensteinString := string(frankensteinText) 991 mistralTokens := mistralEncoder.Encode(&frankensteinString) 992 frankensteinString = "<s>" + frankensteinString 993 output := mistralEncoder.Decode(mistralTokens) 994 for i := 0; i < len(output); i++ { 995 if output[i] != frankensteinString[i] { 996 t.Errorf( 997 "Mismatch at around index %d Expected: %v, Actual: %v", i, 998 string(frankensteinString[i]), string(output[i]), 999 ) 1000 break 1001 } 1002 } 1003 } 1004 1005 func TestMistralEncodeDecode_Emojis(t *testing.T) { 1006 mistralEncoder = *CacheLoadEncoder("mistral-tokenizer") 1007 1008 // This test is to check if the encoder is able to encode and decode emojis 1009 // Requires the ability to properly handle byte tokens in the encoder 1010 testString := "expensive 😦 padding ⁂ padding" 1011 tokens := mistralEncoder.Encode(&testString) 1012 output := mistralEncoder.Decode(tokens) 1013 testString = "<s>" + testString 1014 assert.Equal(t, testString, output) 1015 } 1016 1017 func TestMistralEncodeDecode_LargeCorpus(t *testing.T) { 1018 mistralEncoder = *CacheLoadEncoder("mistral-tokenizer") 1019 1020 // This test is to check if the encoder is able to encode and decode a large corpus 1021 referenceFile := "resources/test_references/753.txt" 1022 referenceBin := "resources/test_references/753_mistralv1.bin" 1023 referenceText, err := os.ReadFile(referenceFile) 1024 if err != nil { 1025 t.Errorf("Error reading reference file: %v", err) 1026 } 1027 referenceString := string(referenceText) 1028 // Need to decode the reference bin file 1029 referenceBinData, err := os.ReadFile(referenceBin) 1030 if err != nil { 1031 t.Errorf("Error reading reference bin file: %v", err) 1032 } 1033 referenceTokens := types.TokensFromBin32(&referenceBinData) 1034 // Encode the reference string 1035 mistralTokens := mistralEncoder.Encode(&referenceString) 1036 for i := 0; i < len(*mistralTokens); i++ { 1037 if (*mistralTokens)[i] != (*referenceTokens)[i] { 1038 t.Errorf( 1039 "Mismatch at around index %d Expected: %v, Actual: %v", i, 1040 (*referenceTokens)[i], (*mistralTokens)[i], 1041 ) 1042 } 1043 } 1044 assert.Equal(t, mistralTokens, referenceTokens) 1045 // Decode the tokens to check if the decoded string is the same as the reference string 1046 output := mistralEncoder.Decode(mistralTokens) 1047 referenceString = "<s>" + referenceString 1048 for i := 0; i < len(output); i++ { 1049 if output[i] != referenceString[i] { 1050 fmt.Printf("Mismatch at around index %d\n", i) 1051 fmt.Printf("Expected: %s\n", referenceString[i-20:i+20]) 1052 fmt.Printf("Actual: %s\n", output[i-20:i+20]) 1053 t.Errorf( 1054 "Mismatch at around index %d Expected: %s, Actual: %s", i, 1055 string(referenceString[i]), string(output[i]), 1056 ) 1057 break 1058 } 1059 } 1060 assert.Equal(t, referenceString, output) 1061 } 1062 1063 func TestLlama3Encoder_Encode(t *testing.T) { 1064 llama3Encoder = *CacheLoadEncoder("llama3-tokenizer") 1065 1066 // This test is to check if the encoder is able to encode a basic string 1067 testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." 1068 llamaTokens := llama3Encoder.Encode(&testString) 1069 fmt.Printf("Llama3 tokens: %v\n", llamaTokens) 1070 assert.Equal( 1071 t, llamaTokens, 1072 &Tokens{128000, 791, 39935, 27096, 927, 279, 96018, 627, 791, 37189, 374, 10819, 1109, 279, 96018, 13, 128001}, 1073 ) 1074 } 1075 1076 func TestLlama3TokenizerDecode(t *testing.T) { 1077 llama3Encoder = *CacheLoadEncoder("llama3-tokenizer") 1078 1079 // This test is to check if the decoder is able to decode the tokens correctly 1080 outputString := "<|begin_of_text|>The fox jumped over the hare.\nThe turtle is faster than the hare.<|end_of_text|>" 1081 llamaTokens := Tokens{128000, 791, 39935, 27096, 927, 279, 96018, 627, 791, 37189, 374, 10819, 1109, 279, 96018, 13, 128001} 1082 output := llama3Encoder.Decode(&llamaTokens) 1083 assert.Equal(t, outputString, output) 1084 } 1085 1086 func TestLlama3EncodeDecode(t *testing.T) { 1087 llama3Encoder = *CacheLoadEncoder("llama3-tokenizer") 1088 1089 // This test is to check if the encoder is able to encode and decode a basic string 1090 testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." 1091 outputString := "<|begin_of_text|>The fox jumped over the hare.\nThe turtle is faster than the hare.<|end_of_text|>" 1092 llamaTokens := llama3Encoder.Encode(&testString) 1093 output := llama3Encoder.Decode(llamaTokens) 1094 assert.Equal(t, outputString, output) 1095 } 1096 1097 func TestLlama3EncodeDecode_Merges(t *testing.T) { 1098 llama3Encoder = *CacheLoadEncoder("llama3-tokenizer") 1099 1100 // This test is to check if the encoder is able to merge the tokens correctly 1101 // If it fails, the merge function in the streaming_encode does not check for invalid merge pairs correctly 1102 testString := "Ah! Cornelius Agrippa! My dear Victor, d" 1103 outputString := "<|begin_of_text|>Ah! Cornelius Agrippa! My dear Victor, d<|end_of_text|>" 1104 llamaTokens := llama3Encoder.Encode(&testString) 1105 fmt.Printf("Llama3 tokens: %v\n", llamaTokens) 1106 output := llama3Encoder.Decode(llamaTokens) 1107 assert.Equal(t, outputString, output) 1108 } 1109 func TestLlama3Merge(t *testing.T) { 1110 llama3Encoder = *CacheLoadEncoder("llama3-tokenizer") 1111 1112 // This test is to check if the encoder is able to merge the tokens correctly 1113 // If it fails, the merge function in the streaming_encode does not check for invalid merge pairs correctly 1114 //testString := "Description\ndescription\n Description\n description" 1115 testString := "1234" 1116 llamaTokens := llama3Encoder.Encode(&testString) 1117 decodedTokens := make([]string, len(*llamaTokens)) 1118 for i := 0; i < len(*llamaTokens); i++ { 1119 decodedTokens[i] = string(llama3Encoder.Decoder[(*llamaTokens)[i]]) 1120 } 1121 1122 if len(*llamaTokens) != 4 { 1123 t.Errorf("Expected 4 tokens, got %d", len(*llamaTokens)) 1124 } 1125 1126 if decodedTokens[1] != "123" && decodedTokens[2] != "4" { 1127 t.Errorf( 1128 "Expected 123|4, got %s|%s", 1129 decodedTokens[1], decodedTokens[2], 1130 ) 1131 } 1132 } 1133 1134 func TestLlama3EncodeDecode_LargeCorpus(t *testing.T) { 1135 llama3Encoder = *CacheLoadEncoder("llama3-tokenizer") 1136 1137 // This test is to check if the encoder is able to encode and decode a large corpus 1138 referenceFile := "resources/test_references/753.txt" 1139 referenceBin := "resources/test_references/753_llama3.bin" 1140 referenceText, err := os.ReadFile(referenceFile) 1141 if err != nil { 1142 t.Errorf("Error reading reference file: %v", err) 1143 } 1144 referenceString := string(referenceText) 1145 // Need to decode the reference bin file 1146 referenceBinData, err := os.ReadFile(referenceBin) 1147 if err != nil { 1148 t.Errorf("Error reading reference bin file: %v", err) 1149 } 1150 referenceTokens := types.TokensFromBin32(&referenceBinData) 1151 // Encode the reference string 1152 llamaTokens := llama3Encoder.Encode(&referenceString) 1153 for i := 0; i < len(*llamaTokens); i++ { 1154 if (*llamaTokens)[i] != (*referenceTokens)[i] { 1155 t.Errorf( 1156 "Mismatch at around index %d Expected: %v, Actual: %v", i, 1157 (*referenceTokens)[i], (*llamaTokens)[i], 1158 ) 1159 } 1160 } 1161 // Check that the encoded tokens are the same as the reference tokens 1162 assert.Equal(t, llamaTokens, referenceTokens) 1163 // Decode the tokens 1164 output := llama3Encoder.Decode(llamaTokens) 1165 refDecoded := llama3Encoder.Decode(referenceTokens) 1166 // Check that the decoded string is the same as the reference string 1167 for i := 0; i < len(output); i++ { 1168 if output[i] != refDecoded[i] { 1169 left, right := getStringBounds(i, output, refDecoded) 1170 fmt.Printf("Mismatch at around index %d\n", i) 1171 fmt.Printf("Expected: %s\n", refDecoded[left:right]) 1172 fmt.Printf("Actual: %s\n", output[left:right]) 1173 break 1174 } 1175 } 1176 } 1177 1178 func TestLlama3EncodeDecodeFrankenstein(t *testing.T) { 1179 llama3Encoder = *CacheLoadEncoder("llama3-tokenizer") 1180 1181 // This test is to check if the encoder is able to encode and decode the Frankenstein corpus 1182 frankensteinCorpus := "resources/frankenstein.txt" 1183 frankensteinText, err := os.ReadFile(frankensteinCorpus) 1184 if err != nil { 1185 t.Errorf("Error reading Frankenstein corpus: %v", err) 1186 } 1187 frankensteinString := string(frankensteinText) 1188 llamaTokens := llama3Encoder.Encode(&frankensteinString) 1189 output := llama3Encoder.Decode(llamaTokens) 1190 frankensteinString = "<|begin_of_text|>" + frankensteinString + "<|end_of_text|>" 1191 for i := 0; i < len(output); i++ { 1192 if output[i] != frankensteinString[i] { 1193 left, right := getStringBounds(i, output, frankensteinString) 1194 fmt.Printf("Mismatch at around index %d\n", i) 1195 fmt.Printf("Expected: %s\n", frankensteinString[left:right]) 1196 fmt.Printf("Actual: %s\n", output[left:right]) 1197 break 1198 } 1199 } 1200 assert.Equal(t, frankensteinString, output) 1201 } 1202 1203 func TestReadTokenizerConfig(t *testing.T) { 1204 // This test is to check if the encoder is able to read the tokenizer_config.json file 1205 // json with eos, bos, pad as strings 1206 jsonStr := `{"eos_token": "TC", "bos_token": "TD", "pad_token": "TE"}` //cooresponds to 6669, 10989, 5428 in pythia vocab 1207 1208 //download filler model 1209 modelId := "EleutherAI/pythia-70m" 1210 destPath := "./TestReadTokenizerConfig" 1211 destPathPTR := &destPath 1212 defer os.RemoveAll(destPath) 1213 rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN") 1214 os.MkdirAll(destPath, 0755) 1215 _, rsrcErr := resources.ResolveResources( 1216 modelId, destPathPTR, 1217 resources.RESOURCE_MODEL, rsrcType, hfApiToken, 1218 ) 1219 if rsrcErr != nil { 1220 t.Errorf("Error downloading model resources: %s", rsrcErr) 1221 } 1222 1223 // replace tokenizer_config.json with jsonStr 1224 tokenizerConfigPath := destPath + "/tokenizer_config.json" 1225 err := os.WriteFile(tokenizerConfigPath, []byte(jsonStr), 0644) 1226 if err != nil { 1227 t.Errorf("Error writing to tokenizer_config.json: %v", err) 1228 } 1229 1230 // read tokenizer config by encoding a string 1231 encoder, err := NewEncoder(destPath) 1232 if err != nil { 1233 t.Errorf("Error creating encoder: %v", err) 1234 } 1235 1236 // check that the tokens are correct 1237 assert.Equal(t, encoder.EosToken, Token(6669)) 1238 assert.Equal(t, encoder.BosToken, Token(10989)) 1239 assert.Equal(t, encoder.PadToken, Token(5428)) 1240 1241 // Finish the test, allow defered cleanup 1242 fmt.Println("All Exists - Looks good.") 1243 } 1244 1245 func TestGPT2DefaultPadding(t *testing.T) { 1246 gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer") 1247 1248 // GPT2 defines a padding token, we test if it properly gets this token 1249 // corresponds to <|padding|> in the vocab 1250 assert.Equal(t, gpt2Encoder.PadToken, Token(50257)) 1251 assert.Equal(t, gpt2Encoder.Encoder["<|padding|>"], Token(50257)) 1252 } 1253 1254 func TestPilePadding(t *testing.T) { 1255 pileEncoder = *CacheLoadEncoder("pile-tokenizer") 1256 1257 // Pile defines a padding token, we test if it properly gets this token 1258 // corresponds to <|padding|> in the vocab 1259 assert.Equal(t, pileEncoder.PadToken, Token(1)) 1260 assert.Equal(t, pileEncoder.Encoder["<|padding|>"], Token(1)) 1261 } 1262 1263 func TestClipPadding(t *testing.T) { 1264 clipEncoder = *CacheLoadEncoder("clip-tokenizer") 1265 1266 // CLIP defines a padding token, we test if it properly gets this token 1267 // corresponds to <|endoftext|> in the vocab 1268 assert.Equal(t, clipEncoder.PadToken, Token(49407)) 1269 assert.Equal(t, clipEncoder.Encoder["<|endoftext|>"], Token(49407)) 1270 } 1271 1272 func TestNerdstashPadding(t *testing.T) { 1273 nerdstashV2Encoder = *CacheLoadEncoder("nerdstash_v2-tokenizer") 1274 1275 // Nerdstash defines a padding token, we test if it properly gets this token 1276 // corresponds to <|pad|> in the vocab 1277 assert.Equal(t, nerdstashV2Encoder.PadToken, Token(0)) 1278 assert.Equal(t, nerdstashV2Encoder.Encoder["<|pad|>"], Token(0)) 1279 } 1280 1281 func TestLlamaPadding(t *testing.T) { 1282 llama2Encoder = *CacheLoadEncoder("llama-tokenizer") 1283 1284 // Llama doesn't define a padding token, we test if it properly defaults to 1285 // [PAD] as 65535 1286 assert.Equal(t, llama2Encoder.PadToken, Token(65535)) 1287 assert.Equal(t, llama2Encoder.Encoder["[PAD]"], Token(65535)) 1288 } 1289 1290 func TestMistralPadding(t *testing.T) { 1291 mistralEncoder = *CacheLoadEncoder("mistral-tokenizer") 1292 1293 // Mistral doesn't define a padding token, we test if it properly defaults to 1294 // [PAD] as 65535 1295 assert.Equal(t, mistralEncoder.PadToken, Token(65535)) 1296 assert.Equal(t, mistralEncoder.Encoder["[PAD]"], Token(65535)) 1297 } 1298 1299 func TestLlama3Padding(t *testing.T) { 1300 llama3Encoder = *CacheLoadEncoder("llama3-tokenizer") 1301 1302 // Llama doesn't define a padding token, we test if it properly defaults to 1303 // [PAD] as 4294967295 due to the uint32 max value 1304 assert.Equal(t, llama3Encoder.PadToken, Token(4294967295)) 1305 assert.Equal(t, llama3Encoder.Encoder["[PAD]"], Token(4294967295)) 1306 } 1307 1308 func TestGPTDecoder_Decode(t *testing.T) { 1309 // TBD 1310 } 1311 1312 func TestRankPairs(t *testing.T) { 1313 } 1314 1315 func downloadModel(modelId string, destPath string) error { 1316 // Download the model 1317 destPathPTR := &destPath 1318 rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN") 1319 os.MkdirAll(destPath, 0755) 1320 _, rsrcErr := resources.ResolveResources( 1321 modelId, destPathPTR, 1322 resources.RESOURCE_MODEL, rsrcType, hfApiToken, 1323 ) 1324 if rsrcErr != nil { 1325 return rsrcErr 1326 } 1327 return nil 1328 } 1329 1330 func assertFileExists(t *testing.T, filePath string) { 1331 if _, err := os.Stat(filePath); err != nil && os.IsNotExist(err) { 1332 t.Errorf("File does not exist: %s", filePath) 1333 } else if err != nil { 1334 t.Errorf("Error checking file: %v", err) 1335 } 1336 1337 } 1338 1339 func TestModelDownload(t *testing.T) { 1340 // Download the model 1341 modelId := "gpt2" 1342 destPath := "./TestModelDownload" 1343 err := downloadModel(modelId, destPath) 1344 if err != nil { 1345 os.RemoveAll(destPath) 1346 t.Errorf("Error downloading model: %v", err) 1347 } 1348 defer os.RemoveAll(destPath) 1349 1350 // Check that the model files are there 1351 // We want to check for the presence of the following files: 1352 // config.json, pytorch_model.bin, 1353 // tokenizer.json, vocab.json 1354 1355 // Check for config.json 1356 configPath := destPath + "/config.json" 1357 assertFileExists(t, configPath) 1358 1359 // Check for pytorch_model.bin 1360 modelPath := destPath + "/pytorch_model.bin" 1361 assertFileExists(t, modelPath) 1362 1363 // Check for tokenizer.json 1364 tokenizerConfigPath := destPath + "/tokenizer.json" 1365 assertFileExists(t, tokenizerConfigPath) 1366 1367 // Check for vocab.json 1368 vocabPath := destPath + "/vocab.json" 1369 assertFileExists(t, vocabPath) 1370 1371 // Finish the test, allow defered cleanup 1372 fmt.Println("All Exists - Looks good.") 1373 } 1374 1375 func TestPythiaRemoteDownloadTokenizer(t *testing.T) { 1376 // Tests the ability to download a tokenizer from a remote model 1377 // and use it to encode and decode strings 1378 modelId := "EleutherAI/pythia-70m" 1379 destPath := "./TestPythiaRemoteDownloadTokenizer" 1380 defer os.RemoveAll(destPath) 1381 encoderPythia, err := NewEncoder(modelId) 1382 if err != nil { 1383 t.Errorf("Error creating encoder: %v", err) 1384 } 1385 1386 // Attempt to tokenize 1387 testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." 1388 1389 // Encode the string 1390 encoded := encoderPythia.Encode(&testString) 1391 // Check that the encoded string is the same as the expected - Reference from python's transformers lib 1392 expected := Tokens{510, 30013, 16780, 689, 253, 419, 250, 15, 187, 510, 45993, 310, 7938, 685, 253, 419, 250, 15} 1393 if !assert.Equal(t, expected, *encoded) { 1394 t.Errorf("Expected: %v\nActual: %v", expected, *encoded) 1395 } 1396 } 1397 1398 func TestLlama3RemoteDownloadTokenizer(t *testing.T) { 1399 // Tests the ability to download a tokenizer from a remote model 1400 // and use it to encode and decode strings 1401 modelId := "Groq/Llama-3-Groq-8B-Tool-Use" // Original Llama3 model is gated 1402 destPath := "./TestLlama3RemoteDownloadTokenizer" 1403 defer os.RemoveAll(destPath) 1404 encoderLlama3, err := NewEncoder(modelId) 1405 if err != nil { 1406 t.Errorf("Error creating encoder: %v", err) 1407 } 1408 1409 // Attempt to tokenize 1410 testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." 1411 1412 // Encode the string 1413 encoded := encoderLlama3.Encode(&testString) 1414 // Check that the encoded string is the same as the expected - 128009 is 128001 in the original Llama3 model 1415 expected := Tokens{128000, 791, 39935, 27096, 927, 279, 96018, 627, 791, 37189, 374, 10819, 1109, 279, 96018, 13, 128009} 1416 if !assert.Equal(t, expected, *encoded) { 1417 t.Errorf("Expected: %v\nActual: %v", expected, *encoded) 1418 } 1419 } 1420 1421 func TestMistralRemoteDownloadTokenizer(t *testing.T) { 1422 // Tests the ability to download a tokenizer from a remote model 1423 // and use it to encode and decode strings 1424 modelId := "openaccess-ai-collective/tiny-mistral" 1425 //destPath := "./TestMistralRemoteDownloadTokenizer" 1426 //defer os.RemoveAll(destPath) 1427 encoderMistral, err := NewEncoder(modelId) 1428 if err != nil { 1429 t.Errorf("Error creating encoder: %v", err) 1430 } 1431 1432 // Attempt to tokenize 1433 testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." 1434 1435 // Encode the string 1436 encoded := encoderMistral.Encode(&testString) 1437 // Check that the encoded string is the same as the expected - Reference from python's transformers lib 1438 expected := Tokens{1, 1014, 285, 1142, 14949, 754, 272, 295, 492, 28723, 13, 1014, 261, 3525, 291, 349, 9556, 821, 272, 295, 492, 28723} 1439 if !assert.Equal(t, expected, *encoded) { 1440 t.Errorf("Expected: %v\nActual: %v", expected, *encoded) 1441 } 1442 } 1443 1444 func TestModelDownloadPythia(t *testing.T) { 1445 // Pythia uses a slightly different file structure, where 1446 // the vocab.json and merges.txt files are stored in the 1447 // tokenizer.json file. We want to check if we are able to 1448 // download the model and extract the vocab.json and merges.txt 1449 modelId := "EleutherAI/pythia-70m" 1450 destPath := "./TestModelDownloadPythia" 1451 err := downloadModel(modelId, destPath) 1452 if err != nil { 1453 os.RemoveAll(destPath) 1454 t.Errorf("Error downloading model: %v", err) 1455 } 1456 1457 // Check that the model files are there 1458 // We want to check for the presence of the following files: 1459 // config.json, pytorch_model.bin, 1460 // tokenizer.json, vocab.json 1461 1462 // Check for additional metadata files 1463 metaFiles := []string{"tokenizer.json", "vocab.json", "config.json", "pytorch_model.bin"} 1464 for _, metaFile := range metaFiles { 1465 metaPath := destPath + "/" + metaFile 1466 assertFileExists(t, metaPath) 1467 } 1468 1469 // Finish the test, allow defered cleanup 1470 fmt.Println("All Exists - Looks good.") 1471 } 1472 1473 func TestModelDownloadPythiaSharded(t *testing.T) { 1474 // This tests the model downloader's ability 1475 // to download a sharded model. 1476 1477 modelId := "EleutherAI/pythia-6.9b-deduped" 1478 destPath := "./TestModelDownloadPythiaSharded" 1479 err := downloadModel(modelId, destPath) 1480 if err != nil { 1481 os.RemoveAll(destPath) 1482 t.Errorf("Error downloading model: %v", err) 1483 } 1484 defer os.RemoveAll(destPath) 1485 1486 // Check that the model files are there 1487 // We want to check for the presence of the following files: 1488 // pytorch_model-00001-of-00002.bin, pytorch_model-00002-of-00002.bin, 1489 // pytorch_model.bin.index.json 1490 1491 // Check for pytorch_model-00001-of-00002.bin 1492 model1Path := destPath + "/pytorch_model-00001-of-00002.bin" 1493 assertFileExists(t, model1Path) 1494 1495 // Check for pytorch_model-00002-of-00002.bin 1496 model2Path := destPath + "/pytorch_model-00002-of-00002.bin" 1497 assertFileExists(t, model2Path) 1498 1499 // Check for pytorch_model.bin.index.json 1500 shardconfigPath := destPath + "/pytorch_model.bin.index.json" 1501 assertFileExists(t, shardconfigPath) 1502 1503 // Finish the test, allow defered cleanup 1504 fmt.Println("All Exists - Looks good.") 1505 1506 } 1507 1508 func TestModelDownloadLlama(t *testing.T) { 1509 // Pythia uses a slightly different file structure, where 1510 // the vocab.json and merges.txt files are stored in the 1511 // tokenizer.json file. We want to check if we are able to 1512 // download the model and extract the vocab.json and merges.txt 1513 modelId := "Maykeye/TinyLLama-v0" 1514 destPath := "./TestModelDownloadLlama" 1515 err := downloadModel(modelId, destPath) 1516 if err != nil { 1517 os.RemoveAll(destPath) 1518 t.Errorf("Error downloading model: %v", err) 1519 } 1520 defer os.RemoveAll(destPath) 1521 1522 // Check that the model files are there 1523 // We want to check for the presence of the following files: 1524 // config.json, pytorch_model.bin, 1525 // tokenizer.model, vocab.json 1526 1527 // Check for pytorch_model.bin 1528 singleModelPattern := regexp.MustCompile(`pytorch_model\.bin$`) 1529 re, err := regexp.Compile(`-(\d+)-of-(\d+)\.bin$`) 1530 if err != nil { 1531 t.Errorf("Error compiling regex: %s", err) 1532 } 1533 1534 //check all files in the directory against the pattern 1535 files, err := ioutil.ReadDir(destPath) 1536 if err != nil { 1537 t.Errorf("Error reading directory: %s", err) 1538 } 1539 found := false 1540 1541 for _, file := range files { 1542 if singleModelPattern.MatchString(file.Name()) { 1543 found = true 1544 break 1545 } 1546 1547 matches := re.FindStringSubmatch(file.Name()) 1548 if len(matches) > 2 { 1549 if strings.Compare(matches[1], matches[2]) == 0 { 1550 found = true 1551 break 1552 } 1553 } 1554 } 1555 if !found { 1556 t.Errorf("pytorch_model.bin does not exist or was not found") 1557 } 1558 1559 // Check for additional metadata files 1560 metaFiles := []string{"tokenizer.model", "vocab.json", "config.json"} 1561 for _, metaFile := range metaFiles { 1562 metaPath := destPath + "/" + metaFile 1563 assertFileExists(t, metaPath) 1564 } 1565 1566 // Finish the test, allow defered cleanup 1567 fmt.Println("All Exists - Looks good.") 1568 } 1569 1570 func TestModelDownloadMistral(t *testing.T) { 1571 // Download a downstream mistral model due to mistral being gated 1572 modelId := "openaccess-ai-collective/tiny-mistral" 1573 destPath := "./TestModelDownloadMistral" 1574 err := downloadModel(modelId, destPath) 1575 if err != nil { 1576 os.RemoveAll(destPath) 1577 t.Errorf("Error downloading model: %v", err) 1578 } 1579 defer os.RemoveAll(destPath) 1580 1581 // Check that the model files are there 1582 // We want to check for the presence of the following files: 1583 // config.json, pytorch_model.bin, 1584 // tokenizer.model 1585 1586 // Check for additional metadata files 1587 metaFiles := []string{"tokenizer.model", "config.json", "pytorch_model.bin"} 1588 for _, metaFile := range metaFiles { 1589 metaPath := destPath + "/" + metaFile 1590 assertFileExists(t, metaPath) 1591 } 1592 1593 // Finish the test, allow defered cleanup 1594 fmt.Println("All Exists - Looks good.") 1595 } 1596 1597 func TestModelDownloadFairseq(t *testing.T) { 1598 // Koboldai's fairseq models are stored in a different format 1599 // it has merges and vocab but no tokenizer.json 1600 modelId := "KoboldAI/fairseq-dense-355M" 1601 destPath := "./TestModelDownloadFairseq" 1602 1603 // Download the model 1604 err := downloadModel(modelId, destPath) 1605 if err != nil { 1606 os.RemoveAll(destPath) 1607 t.Errorf("Error downloading model: %v", err) 1608 } 1609 defer os.RemoveAll(destPath) 1610 1611 // Check that the model files are there 1612 // We want to check for the presence of the following files: 1613 // vocab, config. merges, pytorch_model 1614 1615 // Check for additional metadata files 1616 metaFiles := []string{"vocab.json", "config.json", "pytorch_model.bin", "merges.txt"} 1617 for _, metaFile := range metaFiles { 1618 metaPath := destPath + "/" + metaFile 1619 assertFileExists(t, metaPath) 1620 } 1621 1622 // Finish the test, allow defered cleanup 1623 fmt.Println("All Exists - Looks good (Fairseq Download).") 1624 }