github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/gpt_bpe_test.go (about)

     1  package gpt_bpe
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/base64"
     6  	"encoding/hex"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"io/ioutil"
    11  	"log"
    12  	"os"
    13  	"regexp"
    14  	"runtime"
    15  	"strings"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/wbrown/gpt_bpe/types"
    20  
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/ulikunitz/xz"
    23  	"github.com/wbrown/gpt_bpe/resources"
    24  )
    25  
// Package-level tokenizer fixtures shared by every test.  Each encoder
// starts as its zero value and is populated lazily via CacheLoadEncoder
// (or eagerly in init when a benchmark run is detected).
var clipEncoder GPTEncoder
var gpt2Encoder GPTEncoder
var pileEncoder GPTEncoder
var nerdstashV2Encoder GPTEncoder
var llama2Encoder GPTEncoder
var llama3Encoder GPTEncoder
var mistralEncoder GPTEncoder

// Test corpora, read once in init.  largeCorpus is decompressed lazily
// through GetLargeCorpus.
var corpus string
var clipCorpus string
var largeCorpus *string

// var corpus2 string

// Cached corpus encodings, memoized by the decode tests on first use.
var gpt2Encoded *Tokens
var pileEncoded *Tokens
var clipEncoded *Tokens
var nerdstashEncoded *Tokens
var llama2Encoded *Tokens
var llama3Encoded *Tokens
var mistralEncoded *Tokens
var unicodeTrimTests []*Tokens

// benchmarkPrefix is the test-name prefix that marks a benchmark run;
// encoders maps vocab ids to their cached encoder instances.
var benchmarkPrefix string
var encoders map[string]*GPTEncoder

// largeCorpusPath is the xz-compressed corpus used by the streaming and
// benchmark tests.
const largeCorpusPath = "resources/wiki.train.raw.xz"
    51  
    52  func handleRead(path string) []byte {
    53  	if textBytes, err := os.ReadFile(path); err != nil {
    54  		log.Fatalf("Error opening `%s`: %v", path, err)
    55  	} else {
    56  		return textBytes
    57  	}
    58  	return nil
    59  }
    60  
    61  func loadUnicodeTrimTests(path string) []*Tokens {
    62  	tests := make([]*Tokens, 0)
    63  	fileBlob := string(handleRead(path))
    64  	fileLines := strings.Split(fileBlob, "\n")
    65  	for idx := range fileLines {
    66  		line := fileLines[idx]
    67  		if len(line) == 0 {
    68  			continue
    69  		}
    70  		unicodeTrimTest := make(Tokens, 0)
    71  		if err := json.Unmarshal(
    72  			[]byte(line),
    73  			&unicodeTrimTest,
    74  		); err != nil {
    75  			log.Fatalf("Error unmarshaling `%s`: %v", path, err)
    76  		}
    77  		tests = append(tests, &unicodeTrimTest)
    78  	}
    79  	return tests
    80  }
    81  
    82  func Chunks(s string, chunkSize int) []string {
    83  	if len(s) == 0 {
    84  		return nil
    85  	}
    86  	if chunkSize >= len(s) {
    87  		return []string{s}
    88  	}
    89  	var chunks []string = make([]string, 0, (len(s)-1)/chunkSize+1)
    90  	currentLen := 0
    91  	currentStart := 0
    92  	for i := range s {
    93  		if currentLen == chunkSize {
    94  			chunks = append(chunks, s[currentStart:i])
    95  			currentLen = 0
    96  			currentStart = i
    97  		}
    98  		currentLen++
    99  	}
   100  	chunks = append(chunks, s[currentStart:])
   101  	return chunks
   102  }
   103  
   104  func getStringBounds(
   105  	i int,
   106  	output string,
   107  	decoded string,
   108  ) (
   109  	left int,
   110  	right int,
   111  ) {
   112  	if i < 20 {
   113  		left = 0
   114  	} else {
   115  		left = i - 20
   116  	}
   117  	if len(output) < len(decoded) {
   118  		right = len(output)
   119  	} else {
   120  		right = len(decoded)
   121  	}
   122  	if i+20 < right {
   123  		right = i + 20
   124  	}
   125  	return left, right
   126  }
   127  
// init wires the encoders map to the package-level encoder variables,
// eagerly loads every tokenizer when a benchmark run is detected, and
// reads the corpora and trim fixtures used throughout the tests.  Any
// failure aborts the test binary.
func init() {
	isBench := isRunningBenchmarkTest()

	// These will all be null
	encoders = map[string]*GPTEncoder{
		"gpt2-tokenizer":         &gpt2Encoder,
		"pile-tokenizer":         &pileEncoder,
		"clip-tokenizer":         &clipEncoder,
		"nerdstash_v2-tokenizer": &nerdstashV2Encoder,
		"llama-tokenizer":        &llama2Encoder,
		"llama3-tokenizer":       &llama3Encoder,
		"mistral-tokenizer":      &mistralEncoder,
	}

	// Load all encoders up front, as this is desirable for benchmarking
	if isBench {
		gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
		pileEncoder = *CacheLoadEncoder("pile-tokenizer")
		clipEncoder = *CacheLoadEncoder("clip-tokenizer")
		nerdstashV2Encoder = *CacheLoadEncoder("nerdstash_v2-tokenizer")
		llama2Encoder = *CacheLoadEncoder("llama-tokenizer")
		llama3Encoder = *CacheLoadEncoder("llama3-tokenizer")
		mistralEncoder = *CacheLoadEncoder("mistral-tokenizer")
	}

	// Corpora are read unconditionally; handleRead and
	// loadUnicodeTrimTests abort on failure.
	textBytes := handleRead("resources/frankenstein.txt")
	clipBytes := handleRead("resources/frankenstein_clip.txt")
	corpus = string(textBytes)
	clipCorpus = string(clipBytes)
	unicodeTrimTests = loadUnicodeTrimTests("resources/trim_tests.jsonl")
	var err error
	// Decompress and memoize the large corpus once so every test and
	// benchmark can share the in-memory copy.
	_, err = GetLargeCorpus()
	if err != nil {
		log.Fatalf("Error opening `%s`: %v", largeCorpusPath, err)
	}
}
   164  
   165  // isRunningBenchmarkTest
   166  // Check if we're running a benchmark test
   167  // We assume a benchmark test is defined as a test
   168  // that begins with a BENCHMARK_PREFIX. This is
   169  // by default "Benchmark", but can be configured
   170  // by the environment variable BENCHMARK_PREFIX
   171  func isRunningBenchmarkTest() bool {
   172  	prefix, ok := os.LookupEnv("BENCHMARK_PREFIX")
   173  	if ok {
   174  		benchmarkPrefix = prefix
   175  	} else {
   176  		benchmarkPrefix = "Benchmark"
   177  	}
   178  
   179  	for _, arg := range os.Args {
   180  		parsedArg, err := regexp.Compile(arg)
   181  		if err != nil {
   182  			log.Fatalf("Failed parsing CLI args using regexp")
   183  		}
   184  		prefix, _ = parsedArg.LiteralPrefix()
   185  		if strings.HasPrefix(prefix, benchmarkPrefix) {
   186  			log.Println("Running benchmark test, so loading encoders up front...")
   187  			return true
   188  		}
   189  	}
   190  	return false
   191  }
   192  
   193  // CacheLoadEncoder
   194  // Loads an encoder, but only once
   195  func CacheLoadEncoder(vocabId string) *GPTEncoder {
   196  	if encoders[vocabId].Encoder == nil {
   197  		encoder, err := NewEncoder(vocabId)
   198  		if err != nil {
   199  			log.Fatalf("Error loading encoder `%s`: ", vocabId)
   200  		}
   201  
   202  		// Cache the encoder for later use
   203  		encoders[vocabId] = encoder
   204  		return encoder
   205  	}
   206  	return encoders[vocabId]
   207  }
   208  
   209  func GetLargeCorpus() (*string, error) {
   210  	if largeCorpus == nil {
   211  		var err error
   212  		largeCorpus, err = DecompressXZ(largeCorpusPath)
   213  		if err != nil {
   214  			return nil, err
   215  		}
   216  	}
   217  	return largeCorpus, nil
   218  }
   219  
// TestMain is the test entry point.  Returning normally lets the
// testing framework pass m.Run's result to os.Exit (Go 1.15+).
func TestMain(m *testing.M) {
	m.Run()
}
   223  
// TrimTest describes one trimming scenario: trim Input from Direction
// until it fits within Limit tokens, expecting Expected as the result.
type TrimTest struct {
	Input     string        // text to trim
	Direction TrimDirection // which end to trim from (TrimTop / TrimBottom)
	Limit     uint          // maximum token budget after trimming
	Expected  string        // text that should remain
}
   230  
// Fixture text for the trim tests: sent1 is sentence-delimited, sent2
// is newline-delimited.
const sent1 = "This is test sentence 1.  This is test sentence 2.  This is test sentence 3."
const sent2 = "\nThis is test sentence 4.\nThis is test sentence 5.\nThis is test sentence 6.\n"

// hindiSentence exercises Devanagari (multi-byte, combining) text.
const hindiSentence = "व्याकरण शास्त्रीय परिभाषाएँ : डॉ. पर्णदत्त सिंह द्वारा हिंदी पीडीऍफ़ पुस्तक"

// jpSentence mixes Japanese prose with special ruby/MT control tokens.
const jpSentence = "「そんな心構えで、本当に俺の『未練』を果たせるのか? 知ってのとおり、俺の『未練』は『<|rubycover|>相川渦波<|rubystart|>おまえ<|rubyend|>の成長を最後まで見届けること』だ。……言っとくが、俺は年季が入ってる上に、拗らせに拗らせた元神学者。俺の『大いなる<|rubycover|>救世主<|rubystart|>マグナ・メサイア<|rubyend|>』の『理想』は高いぞ? 少なくとも、この『血陸』を止められないようじゃ、任せ切れないな」\n<|mtsentence|><|mtsenglish|>Please check if the meat is being roasted at the right heat.<|mtsjapanese|>焼き肉の火加減を見なさい。<|mtsentenceend|>\n<|mtvocab|><|mtvjapanese|>[ぶんけんがく] 文献学<|mtvenglish|>(n) philology<|mtvocabend|>"
   235  
// TrimSentencesTests drives TestGPTEncoder_TrimSentences: each case
// trims its input from the given direction down to the token limit and
// states the text that should survive.
var TrimSentencesTests = []TrimTest{
	{sent1, TrimTop, 10,
		" This is test sentence 3."},
	{sent1, TrimTop, 20,
		" This is test sentence 2.  This is test sentence 3."},
	{sent1, TrimTop, 30,
		sent1},
	{sent2, TrimTop, 10,
		"\nThis is test sentence 6.\n"},
	{sent2, TrimTop, 18,
		"\nThis is test sentence 5.\nThis is test sentence 6.\n"},
	{sent2, TrimTop, 30,
		sent2},
	{sent1, TrimBottom, 10,
		"This is test sentence 1."},
	{sent1, TrimBottom, 20,
		"This is test sentence 1.  This is test sentence 2."},
	{sent1, TrimBottom, 30,
		sent1},
	{sent2, TrimBottom, 10,
		"\nThis is test sentence 4.\n"},
	{sent2, TrimBottom, 18,
		"\nThis is test sentence 4.\nThis is test sentence 5.\n"},
	{sent2, TrimBottom, 30,
		sent2},
}
   262  
   263  func TestHFResolution(t *testing.T) {
   264  	_, err := NewEncoder("EleutherAI/gpt-j-6B")
   265  	if err != nil {
   266  		t.Error(err)
   267  	}
   268  	_, err = NewEncoder("nonexist/nonexist")
   269  	if err == nil {
   270  		t.Error(errors.New("failed to return error on non-existent model"))
   271  	}
   272  }
   273  
   274  func TestHFTokenzier(t *testing.T) {
   275  	enc, err := NewEncoder("EleutherAI/gpt-j-6B")
   276  	if err != nil {
   277  		t.Error(err)
   278  	}
   279  	sent := "The fox jumped over the hare."
   280  	hfTokens := enc.Encode(&sent)
   281  	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
   282  	gptTokens := gpt2Encoder.Encode(&sent)
   283  	assert.Equal(t, hfTokens, gptTokens)
   284  }
   285  
   286  func TestFairSeqTokenizer(t *testing.T) {
   287  	enc, err := NewEncoder("KoboldAI/fairseq-dense-2.7B")
   288  	if err != nil {
   289  		t.Error(err)
   290  		return
   291  	}
   292  	sent := "The fox jumped over the hare.\nThe turtle is faster than the hare."
   293  	tokens := Tokens{464, 21831, 11687, 625, 262, 387, 260, 25970, 82, 29,
   294  		464, 28699, 318, 5443, 621, 262, 387, 260, 13}
   295  	fsTokens := enc.Encode(&sent)
   296  	assert.Equal(t, *fsTokens, tokens)
   297  }
   298  
   299  var TrimNewLinesTests = append(
   300  	TrimSentencesTests[3:5], TrimSentencesTests[9:11]...,
   301  )
   302  
   303  func TestGPTEncoder_TrimIncompleteSentence(t *testing.T) {
   304  	testStr := "This is a test. He says, \"This is an unterminated quote. She says, this is actually terminated.\" This is awesome! This is incomplete "
   305  	expected := "This is a test. He says, \"This is an unterminated quote. She says, this is actually terminated.\" This is awesome!"
   306  	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
   307  	trimmed, _ := gpt2Encoder.TrimIncompleteSentence(gpt2Encoder.Encode(&testStr))
   308  	output := gpt2Encoder.Decode(trimmed)
   309  	if expected != output {
   310  		t.Error("output != expected; output := ", expected)
   311  	}
   312  }
   313  
   314  func TestGPTEncoder_TrimTokens(t *testing.T) {
   315  	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
   316  	for testIdx := range unicodeTrimTests {
   317  		assert.NotEqual(
   318  			t, len(
   319  				*gpt2Encoder.TrimTokens(
   320  					unicodeTrimTests[testIdx],
   321  				),
   322  			),
   323  			len(*unicodeTrimTests[testIdx]),
   324  		)
   325  	}
   326  }
   327  
   328  func TestGPTEncoder_TrimNewlines(t *testing.T) {
   329  	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
   330  	for testIdx := range TrimNewLinesTests {
   331  		test := TrimNewLinesTests[testIdx]
   332  		res, err := gpt2Encoder.TrimNewlines(
   333  			gpt2Encoder.Encode(&test.Input),
   334  			test.Direction, test.Limit,
   335  		)
   336  		if err != nil {
   337  			t.Error("TrimNewlines: error:", err)
   338  		}
   339  		decodeRes := gpt2Encoder.Decode(res)
   340  		if decodeRes != test.Expected {
   341  			t.Error(
   342  				"TrimNewlines: expected '" + test.Expected + "' got '" +
   343  					decodeRes + "'",
   344  			)
   345  		}
   346  	}
   347  }
   348  
   349  func TestGPTEncoder_TrimSentences(t *testing.T) {
   350  	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
   351  	for testIdx := range TrimSentencesTests {
   352  		test := TrimSentencesTests[testIdx]
   353  		res, err := gpt2Encoder.TrimSentences(
   354  			gpt2Encoder.Encode(&test.Input),
   355  			test.Direction, test.Limit,
   356  		)
   357  		if err != nil {
   358  			t.Error("TrimSentences: error:", err)
   359  		}
   360  		decodeRes := gpt2Encoder.Decode(res)
   361  		if decodeRes != test.Expected {
   362  			t.Error(
   363  				"TrimSentences: expected '" + test.Expected + "' got '" +
   364  					decodeRes + "'",
   365  			)
   366  		}
   367  	}
   368  }
   369  
// SplitTest pairs an input string with the word fragments SplitWords is
// expected to produce for it.
type SplitTest struct {
	Input    string   // text to split
	Expected []string // expected word/space/newline fragments, in order
}
   374  
// SplitTests drives TestGPTEncoder_Split, covering contractions, runs
// of spaces, capitalization, newlines, and special-token boundaries
// (including a special token broken by an embedded newline).
var SplitTests = []SplitTest{
	{"we'll go jump in a lake.",
		[]string{"we", "'ll", " go", " jump", " in", " a", " lake",
			"."}},
	{"multiple  encoded spaces.",
		[]string{"multiple", "  ", "encoded", " spaces", "."}},
	{"Capitalized Words Are Cool",
		[]string{"Capitalized", " Words", " Are", " Cool"}},
	{"we'LL test irregular cApitalizatioN.",
		[]string{"we", "'", "LL", " test", " irregular",
			" cApitalizatioN", "."}},
	{"multilines\nare awesome",
		[]string{"multilines", "\n", "are", " awesome"}},
	{"\nstarting with multilines\nis awesome",
		[]string{"\n", "starting", " with", " multilines",
			"\n", "is", " awesome"}},
	{"we'll go jump<|endoftext|> in a lake.",
		[]string{"we", "'ll", " go", " jump", "<|endoftext|>",
			" in", " a", " lake", "."}},
	{"we'll go jump<|end\noftext|> in a lake.",
		[]string{"we", "'ll", " go", " jump", "<|", "end", "\n",
			"oftext", "|>", " in", " a", " lake", "."}},
}
   398  
   399  func TestGPTEncoder_Split(t *testing.T) {
   400  	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
   401  
   402  	for testIdx := range SplitTests {
   403  		test := SplitTests[testIdx]
   404  		assert.Equal(t, test.Expected, *(gpt2Encoder.SplitWords(&test.Input)))
   405  	}
   406  }
   407  
   408  func DecompressXZ(path string) (*string, error) {
   409  	corpusHandle, err := os.Open(path)
   410  	if err != nil {
   411  		return nil, err
   412  	}
   413  	defer corpusHandle.Close()
   414  	decompressorHandle, err := xz.NewReader(corpusHandle)
   415  	if err != nil {
   416  		return nil, err
   417  	}
   418  	decompressed, err := ioutil.ReadAll(decompressorHandle)
   419  	if err != nil {
   420  		return nil, err
   421  	}
   422  	decompressedStr := string(decompressed)
   423  	return &decompressedStr, nil
   424  }
   425  
   426  func OpenXZStream(path string) (*xz.Reader, *os.File, error) {
   427  	corpusHandle, err := os.Open(path)
   428  	if err != nil {
   429  		return nil, nil, err
   430  	}
   431  	decompressorHandle, err := xz.NewReader(corpusHandle)
   432  	if err != nil {
   433  		return nil, nil, err
   434  	}
   435  	return decompressorHandle, corpusHandle, nil
   436  }
   437  
// EncoderTest pairs one input string with its expected token IDs under
// each tokenizer exercised by the table-driven encode tests.
type EncoderTest struct {
	Input             string // text to encode
	GPT2Expected      Tokens // expected ids under the gpt2 tokenizer
	PileExpected      Tokens // expected ids under the pile tokenizer
	CLIPExpected      Tokens // expected ids under the clip tokenizer
	NerdstashExpected Tokens // expected ids under nerdstash v2
}
   445  
// GPTEncoderTests covers special-token handling across tokenizers:
// unicode ellipses, <|endoftext|>, mixed whitespace around special
// tokens, and <|padding|>.
var GPTEncoderTests = []EncoderTest{
	{"… …",
		Tokens{1399, 3926},
		Tokens{2866, 8139},
		Tokens{49406, 959, 959, 49407},
		Tokens{49289, 5512}},
	{"<|endoftext|>",
		Tokens{50256},
		Tokens{0},
		Tokens{49406, 49407, 49407},
		Tokens{3}},
	{" <|endoftext|>\n<|endoftext|>foo",
		Tokens{220, 50256, 198, 50256, 21943},
		Tokens{209, 0, 187, 0, 12110},
		Tokens{49406, 49407, 49407, 23435, 49407},
		Tokens{49209, 3, 85, 3, 49225, 3292}},
	{" <|padding|>test",
		Tokens{220, 50257, 9288},
		Tokens{209, 1, 2566},
		Tokens{49406, 27, 347, 3798, 796, 91, 285, 1628, 49407},
		Tokens{3252, 49376, 42545, 49376, 49405, 10180},
	},
}
   469  
   470  func TestGPTEncoder_Encode(t *testing.T) {
   471  	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
   472  
   473  	// This test is to check if the GPTEncoder is able to encode the tokens correctly
   474  	start := time.Now()
   475  	tokenCt := len(*gpt2Encoder.Encode(&corpus))
   476  	duration := time.Since(start)
   477  	t.Logf(
   478  		"%v bytes into %v tokens over %v",
   479  		len(corpus), tokenCt, duration,
   480  	)
   481  	for testIdx := range GPTEncoderTests {
   482  		tokensPtr := *gpt2Encoder.Encode(
   483  			&(GPTEncoderTests[testIdx].Input),
   484  		)
   485  		assert.Equal(t, tokensPtr, GPTEncoderTests[testIdx].GPT2Expected)
   486  	}
   487  }
   488  
   489  func TestGPTEncode(t *testing.T) {
   490  	// This test is to check if the GPTEncoder is able to encode the tokens correctly
   491  	strin := "The quick brown fox jumps over the lazy dog."
   492  	expected := Tokens{464, 21831, 11687, 625, 262, 387, 260, 25970, 82, 29, 464, 28699, 318, 5443, 621, 262, 387, 260, 13}
   493  	encoded := gpt2Encoder.Encode(&strin)
   494  	fmt.Printf("Encoded: with commas:")
   495  	for _, token := range *encoded {
   496  		fmt.Printf("%v, ", token)
   497  	}
   498  	assert.Equal(t, *encoded, expected)
   499  }
   500  
   501  func TestGPTEncoder_StreamingEncode(t *testing.T) {
   502  	// This test is to check if the GPTEncoder is able to encode the tokens
   503  	// correctly
   504  	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
   505  	start := time.Now()
   506  	corpusRunes := strings.NewReader(*largeCorpus)
   507  	// Set our profiler up
   508  	profileHandle, _ := os.Create("streaming.prof")
   509  	defer profileHandle.Close()
   510  	runtime.GC()
   511  	//pprof.StartCPUProfile(profileHandle)
   512  	nextTokens := gpt2Encoder.StreamingEncode(corpusRunes)
   513  	tokenCt := 0
   514  	for {
   515  		tokens := nextTokens(16384)
   516  		if tokens == nil {
   517  			break
   518  		}
   519  		tokenCt += len(*tokens)
   520  	}
   521  	duration := time.Since(start)
   522  	//pprof.StopCPUProfile()
   523  	t.Logf(
   524  		"streaming encode: %d tokens/sec",
   525  		int64(float64(tokenCt)/duration.Seconds()),
   526  	)
   527  }
   528  
   529  func TestCLIPEncoder_Encode(t *testing.T) {
   530  	clipEncoder = *CacheLoadEncoder("clip-tokenizer")
   531  
   532  	// This test is to check if the CLIPEncoder is able to encode the tokens correctly
   533  	start := time.Now()
   534  	tokenCt := len(*clipEncoder.Encode(&corpus))
   535  	duration := time.Since(start)
   536  	t.Logf(
   537  		"%v bytes into %v tokens over %v",
   538  		len(corpus), tokenCt, duration,
   539  	)
   540  	for testIdx := range GPTEncoderTests {
   541  		testStr := GPTEncoderTests[testIdx].Input
   542  		tokensPtr := *clipEncoder.Encode(&testStr)
   543  		assert.Equal(t, GPTEncoderTests[testIdx].CLIPExpected, tokensPtr)
   544  	}
   545  }
   546  
   547  func TestPileEncoder_Encode(t *testing.T) {
   548  	pileEncoder = *CacheLoadEncoder("pile-tokenizer")
   549  
   550  	// This test is to check if the PileEncoder is able to encode the tokens correctly
   551  	start := time.Now()
   552  	tokenCt := len(*pileEncoder.Encode(&corpus))
   553  	duration := time.Since(start)
   554  	t.Logf(
   555  		"%v bytes into %v tokens over %v",
   556  		len(corpus), tokenCt, duration,
   557  	)
   558  	for testIdx := range GPTEncoderTests {
   559  		tokensPtr := *pileEncoder.Encode(
   560  			&(GPTEncoderTests[testIdx].Input),
   561  		)
   562  		assert.Equal(t, GPTEncoderTests[testIdx].PileExpected, tokensPtr)
   563  	}
   564  }
   565  
   566  func TestNerdstashEncoder_Encode(t *testing.T) {
   567  	// This test is to check if the NerdstashEncoder is able to encode the tokens correctly
   568  	start := time.Now()
   569  	nerdstashV2Encoder = *CacheLoadEncoder("nerdstash_v2-tokenizer")
   570  
   571  	tokenCt := len(*nerdstashV2Encoder.Encode(&corpus))
   572  	duration := time.Since(start)
   573  	t.Logf(
   574  		"%v bytes into %v tokens over %v",
   575  		len(corpus), tokenCt, duration,
   576  	)
   577  	for testIdx := range GPTEncoderTests {
   578  		tokensPtr := *nerdstashV2Encoder.Encode(
   579  			&(GPTEncoderTests[testIdx].Input),
   580  		)
   581  		assert.Equal(t, GPTEncoderTests[testIdx].NerdstashExpected, tokensPtr)
   582  	}
   583  }
   584  
   585  func TestNerdstashEncoder_EncodeSpaces(t *testing.T) {
   586  	nerdstashV2Encoder = *CacheLoadEncoder("nerdstash_v2-tokenizer")
   587  
   588  	// This test is to check if the NerdstashEncoder is able to encode spaces correctly
   589  	testString := "        12 => '',\n"
   590  	expected := Tokens{16, 124, 125, 10631, 1695, 49231, 85}
   591  	encoded := nerdstashV2Encoder.Encode(&testString)
   592  	assert.Equal(t, expected, *encoded)
   593  }
   594  
   595  func TestNerdstashEncoder_Encode2(t *testing.T) {
   596  	nerdstashV2Encoder = *CacheLoadEncoder("nerdstash_v2-tokenizer")
   597  
   598  	// read the jsonl test file in
   599  	testFile, err := os.Open("resources/subset.jsonl")
   600  	if err != nil {
   601  		t.Error(err)
   602  	}
   603  	defer testFile.Close()
   604  	scanner := bufio.NewScanner(testFile)
   605  	scanner.Split(bufio.ScanLines)
   606  	type testLineStruct struct {
   607  		Text    *string `json:"text"`
   608  		Hex     *string `json:"hex"`
   609  		Encoded Tokens  `json:"encoded"`
   610  	}
   611  
   612  	passCt := 0
   613  	failCt := 0
   614  
   615  	for scanner.Scan() {
   616  		jsonLine := scanner.Text()
   617  		testLine := testLineStruct{}
   618  		err := json.Unmarshal([]byte(jsonLine), &testLine)
   619  		if err != nil {
   620  			t.Error(err)
   621  		}
   622  		expected := testLine.Encoded
   623  		var inputStr string
   624  		if testLine.Hex != nil {
   625  			inputBytes, hexErr := hex.DecodeString(*testLine.Hex)
   626  			if hexErr != nil {
   627  				t.Error(hexErr)
   628  			}
   629  			inputStr = string(inputBytes)
   630  		} else {
   631  			inputStr = *testLine.Text
   632  		}
   633  		// encode the string
   634  		encoded := nerdstashV2Encoder.Encode(&inputStr)
   635  		// check that the encoded string is the same as the expected
   636  		if !assert.Equal(t, expected, *encoded) {
   637  			t.Logf("failure on input: `%v`", inputStr)
   638  			expectedRepr := []string{}
   639  			for _, token := range expected {
   640  				expectedRepr = append(
   641  					expectedRepr,
   642  					string(nerdstashV2Encoder.Decoder[token]),
   643  				)
   644  			}
   645  			actualRepr := []string{}
   646  			for _, token := range *encoded {
   647  				actualRepr = append(
   648  					actualRepr,
   649  					string(nerdstashV2Encoder.Decoder[token]),
   650  				)
   651  			}
   652  			t.Logf("expected: |%s", strings.Join(expectedRepr, "|"))
   653  			t.Logf("actual:   |%s", strings.Join(actualRepr, "|"))
   654  			failCt += 1
   655  		} else {
   656  			passCt += 1
   657  		}
   658  	}
   659  	t.Logf("pass: %v, fail: %v", passCt, failCt)
   660  }
   661  
   662  func TestNerdstashEncoder_Decode(t *testing.T) {
   663  	nerdstashV2Encoder = *CacheLoadEncoder("nerdstash_v2-tokenizer")
   664  
   665  	// This test is to check if the NerdstashEncoder is able to decode the tokens correctly
   666  	for testIdx := range GPTEncoderTests {
   667  		decodedStr := nerdstashV2Encoder.Decode(
   668  			&(GPTEncoderTests[testIdx].NerdstashExpected),
   669  		)
   670  		assert.Equal(t, GPTEncoderTests[testIdx].Input, decodedStr)
   671  	}
   672  }
   673  
   674  func TestGPTEncoder_Decode2(t *testing.T) {
   675  	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
   676  
   677  	// This test is to check if the GPTEncoder is able to decode the tokens correctly from a base64 encoded string
   678  	gpt2EncodedCorpus := "NrGIEOQBRzFfAQEBCAE5GeADPCFGAQhdBgFhBkcHXwEBATM5HgGilUYBpAdDEaUheR8iAQEBmgSnbyQpRgHIjaYBiSQYLfoHYwHogg0A0AHsGFUmpgEGAcd0qApjAzwa7hscAeHAYwEGAbYRB3UiAax0PQPjAgoXpgEGAZgE6G2gAWMExy5GAb5szQdGAXUBAR2gAVQBRgG8CdYBYbCgAe4QAxg/NA0AdyoiAZMGOXL8AWlmAQGgFXknNlIGAdADLiciAT4B6lk="
   679  	decodedCorpus := "frying whatever they touched with a sizzled smell that fills the air along with a shower of sparks that land harmlessly elsewhere and a few stray drops that drip from fingers burned black as charcoal.The shock waves from the blasts cause many nearby trees to topple as the earth shakes and trembles underfoot from the power unleashed by each blast that destroys anything that was struck by it that wasn't shielded by heavy metal plates."
   680  	if binTokens, err := base64.StdEncoding.DecodeString(gpt2EncodedCorpus); err != nil {
   681  		log.Println("ERROR:", err)
   682  	} else {
   683  		tokens := types.TokensFromBin(&binTokens)
   684  		tokens, err = gpt2Encoder.TrimIncompleteSentence(tokens)
   685  		if err != nil {
   686  			t.Error(err)
   687  		}
   688  		assert.Equal(t, gpt2Encoder.Decode(tokens), decodedCorpus)
   689  	}
   690  }
   691  
   692  func TestGPTEncoder_Decode(t *testing.T) {
   693  	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
   694  
   695  	// This test is to check if the GPTEncoder is able to decode the tokens correctly
   696  	if gpt2Encoded == nil {
   697  		corpEncoded := gpt2Encoder.Encode(&corpus)
   698  		gpt2Encoded = corpEncoded
   699  	}
   700  	start := time.Now()
   701  	decoded := gpt2Encoder.Decode(gpt2Encoded)
   702  	duration := time.Since(start)
   703  	tokenNumBytes := len(decoded)
   704  	t.Logf(
   705  		"%v tokens into %v bytes over %v\n",
   706  		len(*gpt2Encoded), tokenNumBytes, duration,
   707  	)
   708  	assert.Equal(t, corpus, decoded)
   709  }
   710  
// BUG: CLIP TOKENIZER has a bug that causes 'the to be split into
// "'t<w>he<w>" instead of "'<w>the<w>".  This causes the
// clipCorpus to be different from the corpus.  This is a bug in
// the CLIP tokenizer from huggingface that was used to generate
// the clipCorpus. The decoded corpus is correct in this test.
// We stop the test right before the bug.
func TestCLIPEncoder_Decode(t *testing.T) {
	clipEncoder = *CacheLoadEncoder("clip-tokenizer")

	// Memoize the encoded corpus for any later test that needs it.
	if clipEncoded == nil {
		corpEncoded := clipEncoder.Encode(&corpus)
		clipEncoded = corpEncoded
	}
	start := time.Now()
	decoded := clipEncoder.Decode(clipEncoded)
	duration := time.Since(start)
	tokenNumBytes := len(decoded)
	// Byte offset just before the first divergence caused by the
	// upstream tokenizer bug described above.
	idxToStop := 229550
	t.Logf(
		"%v tokens into %v bytes over %v\n", len(*clipEncoded), tokenNumBytes,
		duration,
	)
	// Byte-wise comparison up to the known-bad offset; stop at the
	// first mismatch to keep the failure output readable.
	for idx := range clipCorpus {
		if idx > idxToStop {
			break
		}

		if clipCorpus[idx] != decoded[idx] {
			t.Errorf(
				"idx: %d, clipCorpus: %v, decoded: %v\n", idx,
				clipCorpus[idx], decoded[idx],
			)
			break
		}
	}
	// assert.Equal(t, clipCorpus, decoded)
}
   748  
   749  func TestPileEncoder_Decode(t *testing.T) {
   750  	pileEncoder = *CacheLoadEncoder("pile-tokenizer")
   751  
   752  	// This test is to check if the PileEncoder is able to decode the tokens correctly
   753  	if pileEncoded == nil {
   754  		corpEncoded := pileEncoder.Encode(&corpus)
   755  		pileEncoded = corpEncoded
   756  	}
   757  	start := time.Now()
   758  	decoded := pileEncoder.Decode(pileEncoded)
   759  	duration := time.Since(start)
   760  	tokenNumBytes := len(decoded)
   761  	t.Logf(
   762  		"%v tokens into %v bytes over %v\n",
   763  		len(*pileEncoded), tokenNumBytes, duration,
   764  	)
   765  	range_data := corpus
   766  	if len(corpus) > len(decoded) {
   767  		range_data = decoded
   768  	}
   769  	if len(corpus) != len(decoded) {
   770  		t.Errorf(fmt.Sprintf("%v != %v", len(corpus), len(decoded)))
   771  	}
   772  	for idx := range range_data {
   773  		if corpus[idx] != decoded[idx] {
   774  			t.Errorf(
   775  				"%v != %v", clipCorpus[idx-20:idx+20],
   776  				decoded[idx-20:idx+20],
   777  			)
   778  			return
   779  		}
   780  	}
   781  }
   782  
   783  func TestGPTEncoder_TokensReady(t *testing.T) {
   784  	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
   785  
   786  	// This test is to check if the TokensReady function is able to determine if the tokens are ready for context
   787  	multiTokenAsterism := "⁂"
   788  	tokens := gpt2Encoder.Encode(&multiTokenAsterism)
   789  	fmt.Printf("Tokens: %v, len: %v\n", tokens, len(*tokens))
   790  	var idx int
   791  	for idx = range *tokens {
   792  		tokenSlice := (*tokens)[0 : idx+1]
   793  		fmt.Printf("TokenSlice: %v, len: %v\n", tokenSlice, len(tokenSlice))
   794  		if gpt2Encoder.TokensReady(&tokenSlice) {
   795  			break
   796  		}
   797  	}
   798  	if idx < len(*tokens)-1 {
   799  		t.Errorf(
   800  			"Expected TokensReady on idx: %d for `%s`", idx,
   801  			multiTokenAsterism,
   802  		)
   803  	}
   804  }
   805  
   806  func TestGPTEncoder_TokensReadyContext(t *testing.T) {
   807  	pileEncoder = *CacheLoadEncoder("pile-tokenizer")
   808  
   809  	// This test is to check if the TokensReady function is able to determine if the tokens are ready for context
   810  	var tokens Tokens
   811  	badContext, err := os.ReadFile("resources/badcontext.json")
   812  	if err != nil {
   813  		t.Errorf("Could not read badcontext.json: %v", err)
   814  	}
   815  	unmarshalErr := json.Unmarshal(badContext, &tokens)
   816  	if unmarshalErr != nil {
   817  		t.Errorf("Could not unmarshal badcontext.json: %v", unmarshalErr)
   818  	}
   819  	if !pileEncoder.TokensReady(&tokens) {
   820  		t.Errorf("Expected TokensReady to be true for badcontext.json")
   821  	}
   822  }
   823  
   824  func TestUnitrimFunctionality(t *testing.T) {
   825  	// This test is to check if the makeUnitrimArr function is able to generate the unitrim array correctly
   826  	for _, tokenizer := range []string{"clip-tokenizer", "gpt2-tokenizer", "pile-tokenizer"} {
   827  		encoderFile := fmt.Sprintf(
   828  			"resources/data/%s/encoder.json", tokenizer,
   829  		)
   830  		unitrimFile := fmt.Sprintf(
   831  			"resources/data/%s/unitrim.json", tokenizer,
   832  		)
   833  
   834  		// make sure the files exist
   835  		if _, err := os.Stat(encoderFile); os.IsNotExist(err) {
   836  			t.Errorf("Could not find file %s\n", encoderFile)
   837  		}
   838  		if _, err := os.Stat(unitrimFile); os.IsNotExist(err) {
   839  			t.Errorf("Could not find file %s\n", unitrimFile)
   840  		}
   841  
   842  		// read in the Encoder and unitrim files
   843  		encoderBytes, err := os.ReadFile(encoderFile)
   844  		if err != nil {
   845  			t.Errorf("Could not read Encoder file: %v\n", err)
   846  		}
   847  		// unmarshal the Encoder file
   848  		var encoder map[string]Token
   849  		err = json.Unmarshal(encoderBytes, &encoder)
   850  		if err != nil {
   851  			t.Errorf("Could not unmarshal Encoder file: %v\n", err)
   852  		}
   853  
   854  		// read in the unitrim file
   855  		unitrimBytes, err := os.ReadFile(unitrimFile)
   856  		if err != nil {
   857  			t.Errorf("Could not read unitrim file: %v\n", err)
   858  		}
   859  		// unmarshal the unitrim file
   860  		var unitrim []int
   861  		err = json.Unmarshal(unitrimBytes, &unitrim)
   862  		if err != nil {
   863  			t.Errorf("Could not unmarshal unitrim file: %v\n", err)
   864  		}
   865  
   866  		// get generated array for unitrim with the makeUnitrimArr function
   867  		generatedArray := makeUnitrimArr(encoder)
   868  
   869  		// check that the generated array is the same as the unitrim array
   870  		fmt.Printf(
   871  			"Generated array length: %d, unitrim array length: %d\n",
   872  			len(generatedArray), len(unitrim),
   873  		)
   874  		if len(generatedArray) != len(unitrim) {
   875  			t.Errorf("Generated array and unitrim array are not the same length\n")
   876  		}
   877  
   878  		for i := range generatedArray {
   879  			if generatedArray[i] != unitrim[i] {
   880  				fmt.Printf(
   881  					"Generated array: %v and unitrim array: %v at index %d are not the same\n",
   882  					generatedArray[i], unitrim[i], i,
   883  				)
   884  				fmt.Printf(
   885  					"mismatched unicode is: %c\n", rune(generatedArray[i]),
   886  				)
   887  				t.Errorf("Generated array and unitrim array are not the same\n")
   888  			}
   889  		}
   890  
   891  		fmt.Printf("Length and contents of generated array and unitrim array are the same\n")
   892  	}
   893  }
   894  
// TestLlamaEncoder_Encode times encoding of the shared corpus and checks
// each GPTEncoderTests table entry against its expected GPT-2 tokens.
// NOTE(review): despite the "Llama" name, this test loads and exercises
// the gpt2-tokenizer — confirm whether the name is intentional.
func TestLlamaEncoder_Encode(t *testing.T) {
	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")

	// This test is to check if the encoder is able to encode a basic string
	start := time.Now()
	tokenCt := len(*gpt2Encoder.Encode(&corpus))
	duration := time.Since(start)
	t.Logf(
		"%v bytes into %v tokens over %v",
		len(corpus), tokenCt, duration,
	)
	// Each table entry pairs an input string with its expected GPT-2 tokens.
	for testIdx := range GPTEncoderTests {
		tokensPtr := *gpt2Encoder.Encode(
			&(GPTEncoderTests[testIdx].Input),
		)
		assert.Equal(t, tokensPtr, GPTEncoderTests[testIdx].GPT2Expected)
	}
}
   913  
   914  func TestLlamaTwoEncoder_Encode(t *testing.T) {
   915  	llama2Encoder = *CacheLoadEncoder("llama-tokenizer")
   916  
   917  	// This test is to check if the encoder is able to encode a basic string
   918  	testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
   919  	llamaTokens := llama2Encoder.Encode(&testString)
   920  	assert.Equal(
   921  		t, llamaTokens,
   922  		&Tokens{1576, 1701, 29916, 12500, 287, 975, 278, 447, 276, 29889, 13, 1576, 260, 4227, 280, 338, 8473, 1135, 278, 447, 276, 29889},
   923  	)
   924  }
   925  
   926  func TestLlamaTwoTokenizerDecode(t *testing.T) {
   927  	llama2Encoder = *CacheLoadEncoder("llama-tokenizer")
   928  
   929  	// This test is to check if the decoder is able to decode the tokens correctly
   930  	outputString := "<s>The fox jumped over the hare.\nThe turtle is faster than the hare."
   931  	llamaTokens := Tokens{1, 1576, 1701, 29916, 12500, 287, 975, 278, 447, 276, 29889, 13, 1576, 260, 4227, 280, 338, 8473, 1135, 278, 447, 276, 29889}
   932  	output := llama2Encoder.Decode(&llamaTokens)
   933  	assert.Equal(t, outputString, output)
   934  }
   935  
   936  func TestLlamaTwoEncodeDecode(t *testing.T) {
   937  	llama2Encoder = *CacheLoadEncoder("llama-tokenizer")
   938  
   939  	// This test is to check if the encoder is able to encode and decode a basic string
   940  	testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
   941  	outputString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
   942  	llamaTokens := llama2Encoder.Encode(&testString)
   943  	output := llama2Encoder.Decode(llamaTokens)
   944  	assert.Equal(t, outputString, output)
   945  }
   946  
   947  // This is Mistral tokenizer V1, associated with 7b instruct https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
   948  func TestMistralEncoder_Encode(t *testing.T) {
   949  	mistralEncoder = *CacheLoadEncoder("mistral-tokenizer")
   950  
   951  	// This test is to check if the encoder is able to encode a basic string
   952  	testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
   953  	mistralTokens := mistralEncoder.Encode(&testString)
   954  	assert.Equal(
   955  		t, mistralTokens,
   956  		&Tokens{1, 415, 285, 1142, 14949, 754, 272, 295, 492, 28723, 13, 1014, 261, 3525, 291, 349, 9556, 821, 272, 295, 492, 28723},
   957  	)
   958  }
   959  
   960  func TestMistralTokenizerDecode(t *testing.T) {
   961  	mistralEncoder = *CacheLoadEncoder("mistral-tokenizer")
   962  
   963  	// This test is to check if the decoder is able to decode the tokens correctly
   964  	outputString := "<s> The fox jumped over the hare.\nThe turtle is faster than the hare."
   965  	mistralTokens := Tokens{1, 415, 285, 1142, 14949, 754, 272, 295, 492, 28723, 13, 1014, 261, 3525, 291, 349, 9556, 821, 272, 295, 492, 28723}
   966  	output := mistralEncoder.Decode(&mistralTokens)
   967  	assert.Equal(t, outputString, output)
   968  }
   969  
   970  func TestMistralEncodeDecode(t *testing.T) {
   971  	mistralEncoder = *CacheLoadEncoder("mistral-tokenizer")
   972  
   973  	// This test is to check if the encoder is able to encode and decode a basic string
   974  	testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
   975  	outputString := "<s> The fox jumped over the hare.\nThe turtle is faster than the hare."
   976  	mistralTokens := mistralEncoder.Encode(&testString)
   977  	output := mistralEncoder.Decode(mistralTokens)
   978  	assert.Equal(t, outputString, output)
   979  }
   980  
   981  func TestMistralEncodeDecodeFrankenstein(t *testing.T) {
   982  	mistralEncoder = *CacheLoadEncoder("mistral-tokenizer")
   983  
   984  	// This test is to check if the encoder is able to encode and decode the Frankenstein corpus
   985  	frankensteinCorpus := "resources/frankenstein.txt"
   986  	frankensteinText, err := os.ReadFile(frankensteinCorpus)
   987  	if err != nil {
   988  		t.Errorf("Error reading Frankenstein corpus: %v", err)
   989  	}
   990  	frankensteinString := string(frankensteinText)
   991  	mistralTokens := mistralEncoder.Encode(&frankensteinString)
   992  	frankensteinString = "<s>" + frankensteinString
   993  	output := mistralEncoder.Decode(mistralTokens)
   994  	for i := 0; i < len(output); i++ {
   995  		if output[i] != frankensteinString[i] {
   996  			t.Errorf(
   997  				"Mismatch at around index %d Expected: %v, Actual: %v", i,
   998  				string(frankensteinString[i]), string(output[i]),
   999  			)
  1000  			break
  1001  		}
  1002  	}
  1003  }
  1004  
  1005  func TestMistralEncodeDecode_Emojis(t *testing.T) {
  1006  	mistralEncoder = *CacheLoadEncoder("mistral-tokenizer")
  1007  
  1008  	// This test is to check if the encoder is able to encode and decode emojis
  1009  	// Requires the ability to properly handle byte tokens in the encoder
  1010  	testString := "expensive 😦 padding ⁂ padding"
  1011  	tokens := mistralEncoder.Encode(&testString)
  1012  	output := mistralEncoder.Decode(tokens)
  1013  	testString = "<s>" + testString
  1014  	assert.Equal(t, testString, output)
  1015  }
  1016  
  1017  func TestMistralEncodeDecode_LargeCorpus(t *testing.T) {
  1018  	mistralEncoder = *CacheLoadEncoder("mistral-tokenizer")
  1019  
  1020  	// This test is to check if the encoder is able to encode and decode a large corpus
  1021  	referenceFile := "resources/test_references/753.txt"
  1022  	referenceBin := "resources/test_references/753_mistralv1.bin"
  1023  	referenceText, err := os.ReadFile(referenceFile)
  1024  	if err != nil {
  1025  		t.Errorf("Error reading reference file: %v", err)
  1026  	}
  1027  	referenceString := string(referenceText)
  1028  	// Need to decode the reference bin file
  1029  	referenceBinData, err := os.ReadFile(referenceBin)
  1030  	if err != nil {
  1031  		t.Errorf("Error reading reference bin file: %v", err)
  1032  	}
  1033  	referenceTokens := types.TokensFromBin32(&referenceBinData)
  1034  	// Encode the reference string
  1035  	mistralTokens := mistralEncoder.Encode(&referenceString)
  1036  	for i := 0; i < len(*mistralTokens); i++ {
  1037  		if (*mistralTokens)[i] != (*referenceTokens)[i] {
  1038  			t.Errorf(
  1039  				"Mismatch at around index %d Expected: %v, Actual: %v", i,
  1040  				(*referenceTokens)[i], (*mistralTokens)[i],
  1041  			)
  1042  		}
  1043  	}
  1044  	assert.Equal(t, mistralTokens, referenceTokens)
  1045  	// Decode the tokens to check if the decoded string is the same as the reference string
  1046  	output := mistralEncoder.Decode(mistralTokens)
  1047  	referenceString = "<s>" + referenceString
  1048  	for i := 0; i < len(output); i++ {
  1049  		if output[i] != referenceString[i] {
  1050  			fmt.Printf("Mismatch at around index %d\n", i)
  1051  			fmt.Printf("Expected: %s\n", referenceString[i-20:i+20])
  1052  			fmt.Printf("Actual: %s\n", output[i-20:i+20])
  1053  			t.Errorf(
  1054  				"Mismatch at around index %d Expected: %s, Actual: %s", i,
  1055  				string(referenceString[i]), string(output[i]),
  1056  			)
  1057  			break
  1058  		}
  1059  	}
  1060  	assert.Equal(t, referenceString, output)
  1061  }
  1062  
  1063  func TestLlama3Encoder_Encode(t *testing.T) {
  1064  	llama3Encoder = *CacheLoadEncoder("llama3-tokenizer")
  1065  
  1066  	// This test is to check if the encoder is able to encode a basic string
  1067  	testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
  1068  	llamaTokens := llama3Encoder.Encode(&testString)
  1069  	fmt.Printf("Llama3 tokens: %v\n", llamaTokens)
  1070  	assert.Equal(
  1071  		t, llamaTokens,
  1072  		&Tokens{128000, 791, 39935, 27096, 927, 279, 96018, 627, 791, 37189, 374, 10819, 1109, 279, 96018, 13, 128001},
  1073  	)
  1074  }
  1075  
  1076  func TestLlama3TokenizerDecode(t *testing.T) {
  1077  	llama3Encoder = *CacheLoadEncoder("llama3-tokenizer")
  1078  
  1079  	// This test is to check if the decoder is able to decode the tokens correctly
  1080  	outputString := "<|begin_of_text|>The fox jumped over the hare.\nThe turtle is faster than the hare.<|end_of_text|>"
  1081  	llamaTokens := Tokens{128000, 791, 39935, 27096, 927, 279, 96018, 627, 791, 37189, 374, 10819, 1109, 279, 96018, 13, 128001}
  1082  	output := llama3Encoder.Decode(&llamaTokens)
  1083  	assert.Equal(t, outputString, output)
  1084  }
  1085  
  1086  func TestLlama3EncodeDecode(t *testing.T) {
  1087  	llama3Encoder = *CacheLoadEncoder("llama3-tokenizer")
  1088  
  1089  	// This test is to check if the encoder is able to encode and decode a basic string
  1090  	testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
  1091  	outputString := "<|begin_of_text|>The fox jumped over the hare.\nThe turtle is faster than the hare.<|end_of_text|>"
  1092  	llamaTokens := llama3Encoder.Encode(&testString)
  1093  	output := llama3Encoder.Decode(llamaTokens)
  1094  	assert.Equal(t, outputString, output)
  1095  }
  1096  
  1097  func TestLlama3EncodeDecode_Merges(t *testing.T) {
  1098  	llama3Encoder = *CacheLoadEncoder("llama3-tokenizer")
  1099  
  1100  	// This test is to check if the encoder is able to merge the tokens correctly
  1101  	// If it fails, the merge function in the streaming_encode does not check for invalid merge pairs correctly
  1102  	testString := "Ah! Cornelius Agrippa! My dear Victor, d"
  1103  	outputString := "<|begin_of_text|>Ah! Cornelius Agrippa! My dear Victor, d<|end_of_text|>"
  1104  	llamaTokens := llama3Encoder.Encode(&testString)
  1105  	fmt.Printf("Llama3 tokens: %v\n", llamaTokens)
  1106  	output := llama3Encoder.Decode(llamaTokens)
  1107  	assert.Equal(t, outputString, output)
  1108  }
  1109  func TestLlama3Merge(t *testing.T) {
  1110  	llama3Encoder = *CacheLoadEncoder("llama3-tokenizer")
  1111  
  1112  	// This test is to check if the encoder is able to merge the tokens correctly
  1113  	// If it fails, the merge function in the streaming_encode does not check for invalid merge pairs correctly
  1114  	//testString := "Description\ndescription\n Description\n description"
  1115  	testString := "1234"
  1116  	llamaTokens := llama3Encoder.Encode(&testString)
  1117  	decodedTokens := make([]string, len(*llamaTokens))
  1118  	for i := 0; i < len(*llamaTokens); i++ {
  1119  		decodedTokens[i] = string(llama3Encoder.Decoder[(*llamaTokens)[i]])
  1120  	}
  1121  
  1122  	if len(*llamaTokens) != 4 {
  1123  		t.Errorf("Expected 4 tokens, got %d", len(*llamaTokens))
  1124  	}
  1125  
  1126  	if decodedTokens[1] != "123" && decodedTokens[2] != "4" {
  1127  		t.Errorf(
  1128  			"Expected 123|4, got %s|%s",
  1129  			decodedTokens[1], decodedTokens[2],
  1130  		)
  1131  	}
  1132  }
  1133  
  1134  func TestLlama3EncodeDecode_LargeCorpus(t *testing.T) {
  1135  	llama3Encoder = *CacheLoadEncoder("llama3-tokenizer")
  1136  
  1137  	// This test is to check if the encoder is able to encode and decode a large corpus
  1138  	referenceFile := "resources/test_references/753.txt"
  1139  	referenceBin := "resources/test_references/753_llama3.bin"
  1140  	referenceText, err := os.ReadFile(referenceFile)
  1141  	if err != nil {
  1142  		t.Errorf("Error reading reference file: %v", err)
  1143  	}
  1144  	referenceString := string(referenceText)
  1145  	// Need to decode the reference bin file
  1146  	referenceBinData, err := os.ReadFile(referenceBin)
  1147  	if err != nil {
  1148  		t.Errorf("Error reading reference bin file: %v", err)
  1149  	}
  1150  	referenceTokens := types.TokensFromBin32(&referenceBinData)
  1151  	// Encode the reference string
  1152  	llamaTokens := llama3Encoder.Encode(&referenceString)
  1153  	for i := 0; i < len(*llamaTokens); i++ {
  1154  		if (*llamaTokens)[i] != (*referenceTokens)[i] {
  1155  			t.Errorf(
  1156  				"Mismatch at around index %d Expected: %v, Actual: %v", i,
  1157  				(*referenceTokens)[i], (*llamaTokens)[i],
  1158  			)
  1159  		}
  1160  	}
  1161  	// Check that the encoded tokens are the same as the reference tokens
  1162  	assert.Equal(t, llamaTokens, referenceTokens)
  1163  	// Decode the tokens
  1164  	output := llama3Encoder.Decode(llamaTokens)
  1165  	refDecoded := llama3Encoder.Decode(referenceTokens)
  1166  	// Check that the decoded string is the same as the reference string
  1167  	for i := 0; i < len(output); i++ {
  1168  		if output[i] != refDecoded[i] {
  1169  			left, right := getStringBounds(i, output, refDecoded)
  1170  			fmt.Printf("Mismatch at around index %d\n", i)
  1171  			fmt.Printf("Expected: %s\n", refDecoded[left:right])
  1172  			fmt.Printf("Actual: %s\n", output[left:right])
  1173  			break
  1174  		}
  1175  	}
  1176  }
  1177  
  1178  func TestLlama3EncodeDecodeFrankenstein(t *testing.T) {
  1179  	llama3Encoder = *CacheLoadEncoder("llama3-tokenizer")
  1180  
  1181  	// This test is to check if the encoder is able to encode and decode the Frankenstein corpus
  1182  	frankensteinCorpus := "resources/frankenstein.txt"
  1183  	frankensteinText, err := os.ReadFile(frankensteinCorpus)
  1184  	if err != nil {
  1185  		t.Errorf("Error reading Frankenstein corpus: %v", err)
  1186  	}
  1187  	frankensteinString := string(frankensteinText)
  1188  	llamaTokens := llama3Encoder.Encode(&frankensteinString)
  1189  	output := llama3Encoder.Decode(llamaTokens)
  1190  	frankensteinString = "<|begin_of_text|>" + frankensteinString + "<|end_of_text|>"
  1191  	for i := 0; i < len(output); i++ {
  1192  		if output[i] != frankensteinString[i] {
  1193  			left, right := getStringBounds(i, output, frankensteinString)
  1194  			fmt.Printf("Mismatch at around index %d\n", i)
  1195  			fmt.Printf("Expected: %s\n", frankensteinString[left:right])
  1196  			fmt.Printf("Actual: %s\n", output[left:right])
  1197  			break
  1198  		}
  1199  	}
  1200  	assert.Equal(t, frankensteinString, output)
  1201  }
  1202  
  1203  func TestReadTokenizerConfig(t *testing.T) {
  1204  	// This test is to check if the encoder is able to read the tokenizer_config.json file
  1205  	// json with eos, bos, pad as strings
  1206  	jsonStr := `{"eos_token": "TC", "bos_token": "TD", "pad_token": "TE"}` //cooresponds to 6669, 10989, 5428 in pythia vocab
  1207  
  1208  	//download filler model
  1209  	modelId := "EleutherAI/pythia-70m"
  1210  	destPath := "./TestReadTokenizerConfig"
  1211  	destPathPTR := &destPath
  1212  	defer os.RemoveAll(destPath)
  1213  	rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
  1214  	os.MkdirAll(destPath, 0755)
  1215  	_, rsrcErr := resources.ResolveResources(
  1216  		modelId, destPathPTR,
  1217  		resources.RESOURCE_MODEL, rsrcType, hfApiToken,
  1218  	)
  1219  	if rsrcErr != nil {
  1220  		t.Errorf("Error downloading model resources: %s", rsrcErr)
  1221  	}
  1222  
  1223  	// replace tokenizer_config.json with jsonStr
  1224  	tokenizerConfigPath := destPath + "/tokenizer_config.json"
  1225  	err := os.WriteFile(tokenizerConfigPath, []byte(jsonStr), 0644)
  1226  	if err != nil {
  1227  		t.Errorf("Error writing to tokenizer_config.json: %v", err)
  1228  	}
  1229  
  1230  	// read tokenizer config by encoding a string
  1231  	encoder, err := NewEncoder(destPath)
  1232  	if err != nil {
  1233  		t.Errorf("Error creating encoder: %v", err)
  1234  	}
  1235  
  1236  	// check that the tokens are correct
  1237  	assert.Equal(t, encoder.EosToken, Token(6669))
  1238  	assert.Equal(t, encoder.BosToken, Token(10989))
  1239  	assert.Equal(t, encoder.PadToken, Token(5428))
  1240  
  1241  	// Finish the test, allow defered cleanup
  1242  	fmt.Println("All Exists - Looks good.")
  1243  }
  1244  
  1245  func TestGPT2DefaultPadding(t *testing.T) {
  1246  	gpt2Encoder = *CacheLoadEncoder("gpt2-tokenizer")
  1247  
  1248  	// GPT2 defines a padding token, we test if it properly gets this token
  1249  	// corresponds to <|padding|> in the vocab
  1250  	assert.Equal(t, gpt2Encoder.PadToken, Token(50257))
  1251  	assert.Equal(t, gpt2Encoder.Encoder["<|padding|>"], Token(50257))
  1252  }
  1253  
  1254  func TestPilePadding(t *testing.T) {
  1255  	pileEncoder = *CacheLoadEncoder("pile-tokenizer")
  1256  
  1257  	// Pile defines a padding token, we test if it properly gets this token
  1258  	// corresponds to <|padding|> in the vocab
  1259  	assert.Equal(t, pileEncoder.PadToken, Token(1))
  1260  	assert.Equal(t, pileEncoder.Encoder["<|padding|>"], Token(1))
  1261  }
  1262  
  1263  func TestClipPadding(t *testing.T) {
  1264  	clipEncoder = *CacheLoadEncoder("clip-tokenizer")
  1265  
  1266  	// CLIP defines a padding token, we test if it properly gets this token
  1267  	// corresponds to <|endoftext|> in the vocab
  1268  	assert.Equal(t, clipEncoder.PadToken, Token(49407))
  1269  	assert.Equal(t, clipEncoder.Encoder["<|endoftext|>"], Token(49407))
  1270  }
  1271  
  1272  func TestNerdstashPadding(t *testing.T) {
  1273  	nerdstashV2Encoder = *CacheLoadEncoder("nerdstash_v2-tokenizer")
  1274  
  1275  	// Nerdstash defines a padding token, we test if it properly gets this token
  1276  	// corresponds to <|pad|> in the vocab
  1277  	assert.Equal(t, nerdstashV2Encoder.PadToken, Token(0))
  1278  	assert.Equal(t, nerdstashV2Encoder.Encoder["<|pad|>"], Token(0))
  1279  }
  1280  
  1281  func TestLlamaPadding(t *testing.T) {
  1282  	llama2Encoder = *CacheLoadEncoder("llama-tokenizer")
  1283  
  1284  	// Llama doesn't define a padding token, we test if it properly defaults to
  1285  	// [PAD] as 65535
  1286  	assert.Equal(t, llama2Encoder.PadToken, Token(65535))
  1287  	assert.Equal(t, llama2Encoder.Encoder["[PAD]"], Token(65535))
  1288  }
  1289  
  1290  func TestMistralPadding(t *testing.T) {
  1291  	mistralEncoder = *CacheLoadEncoder("mistral-tokenizer")
  1292  
  1293  	// Mistral doesn't define a padding token, we test if it properly defaults to
  1294  	// [PAD] as 65535
  1295  	assert.Equal(t, mistralEncoder.PadToken, Token(65535))
  1296  	assert.Equal(t, mistralEncoder.Encoder["[PAD]"], Token(65535))
  1297  }
  1298  
  1299  func TestLlama3Padding(t *testing.T) {
  1300  	llama3Encoder = *CacheLoadEncoder("llama3-tokenizer")
  1301  
  1302  	// Llama doesn't define a padding token, we test if it properly defaults to
  1303  	// [PAD] as 4294967295 due to the uint32 max value
  1304  	assert.Equal(t, llama3Encoder.PadToken, Token(4294967295))
  1305  	assert.Equal(t, llama3Encoder.Encoder["[PAD]"], Token(4294967295))
  1306  }
  1307  
// TestGPTDecoder_Decode is a placeholder; decoder coverage is still to
// be written.
func TestGPTDecoder_Decode(t *testing.T) {
	// TBD
}
  1311  
// TestRankPairs is a placeholder; pair-ranking coverage is still to be
// written.
func TestRankPairs(t *testing.T) {
}
  1314  
  1315  func downloadModel(modelId string, destPath string) error {
  1316  	// Download the model
  1317  	destPathPTR := &destPath
  1318  	rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
  1319  	os.MkdirAll(destPath, 0755)
  1320  	_, rsrcErr := resources.ResolveResources(
  1321  		modelId, destPathPTR,
  1322  		resources.RESOURCE_MODEL, rsrcType, hfApiToken,
  1323  	)
  1324  	if rsrcErr != nil {
  1325  		return rsrcErr
  1326  	}
  1327  	return nil
  1328  }
  1329  
  1330  func assertFileExists(t *testing.T, filePath string) {
  1331  	if _, err := os.Stat(filePath); err != nil && os.IsNotExist(err) {
  1332  		t.Errorf("File does not exist: %s", filePath)
  1333  	} else if err != nil {
  1334  		t.Errorf("Error checking file: %v", err)
  1335  	}
  1336  
  1337  }
  1338  
  1339  func TestModelDownload(t *testing.T) {
  1340  	// Download the model
  1341  	modelId := "gpt2"
  1342  	destPath := "./TestModelDownload"
  1343  	err := downloadModel(modelId, destPath)
  1344  	if err != nil {
  1345  		os.RemoveAll(destPath)
  1346  		t.Errorf("Error downloading model: %v", err)
  1347  	}
  1348  	defer os.RemoveAll(destPath)
  1349  
  1350  	// Check that the model files are there
  1351  	// We want to check for the presence of the following files:
  1352  	// config.json, pytorch_model.bin,
  1353  	// tokenizer.json, vocab.json
  1354  
  1355  	// Check for config.json
  1356  	configPath := destPath + "/config.json"
  1357  	assertFileExists(t, configPath)
  1358  
  1359  	// Check for pytorch_model.bin
  1360  	modelPath := destPath + "/pytorch_model.bin"
  1361  	assertFileExists(t, modelPath)
  1362  
  1363  	// Check for tokenizer.json
  1364  	tokenizerConfigPath := destPath + "/tokenizer.json"
  1365  	assertFileExists(t, tokenizerConfigPath)
  1366  
  1367  	// Check for vocab.json
  1368  	vocabPath := destPath + "/vocab.json"
  1369  	assertFileExists(t, vocabPath)
  1370  
  1371  	// Finish the test, allow defered cleanup
  1372  	fmt.Println("All Exists - Looks good.")
  1373  }
  1374  
  1375  func TestPythiaRemoteDownloadTokenizer(t *testing.T) {
  1376  	// Tests the ability to download a tokenizer from a remote model
  1377  	// and use it to encode and decode strings
  1378  	modelId := "EleutherAI/pythia-70m"
  1379  	destPath := "./TestPythiaRemoteDownloadTokenizer"
  1380  	defer os.RemoveAll(destPath)
  1381  	encoderPythia, err := NewEncoder(modelId)
  1382  	if err != nil {
  1383  		t.Errorf("Error creating encoder: %v", err)
  1384  	}
  1385  
  1386  	// Attempt to tokenize
  1387  	testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
  1388  
  1389  	// Encode the string
  1390  	encoded := encoderPythia.Encode(&testString)
  1391  	// Check that the encoded string is the same as the expected - Reference from python's transformers lib
  1392  	expected := Tokens{510, 30013, 16780, 689, 253, 419, 250, 15, 187, 510, 45993, 310, 7938, 685, 253, 419, 250, 15}
  1393  	if !assert.Equal(t, expected, *encoded) {
  1394  		t.Errorf("Expected: %v\nActual: %v", expected, *encoded)
  1395  	}
  1396  }
  1397  
  1398  func TestLlama3RemoteDownloadTokenizer(t *testing.T) {
  1399  	// Tests the ability to download a tokenizer from a remote model
  1400  	// and use it to encode and decode strings
  1401  	modelId := "Groq/Llama-3-Groq-8B-Tool-Use" // Original Llama3 model is gated
  1402  	destPath := "./TestLlama3RemoteDownloadTokenizer"
  1403  	defer os.RemoveAll(destPath)
  1404  	encoderLlama3, err := NewEncoder(modelId)
  1405  	if err != nil {
  1406  		t.Errorf("Error creating encoder: %v", err)
  1407  	}
  1408  
  1409  	// Attempt to tokenize
  1410  	testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
  1411  
  1412  	// Encode the string
  1413  	encoded := encoderLlama3.Encode(&testString)
  1414  	// Check that the encoded string is the same as the expected - 128009 is 128001 in the original Llama3 model
  1415  	expected := Tokens{128000, 791, 39935, 27096, 927, 279, 96018, 627, 791, 37189, 374, 10819, 1109, 279, 96018, 13, 128009}
  1416  	if !assert.Equal(t, expected, *encoded) {
  1417  		t.Errorf("Expected: %v\nActual: %v", expected, *encoded)
  1418  	}
  1419  }
  1420  
  1421  func TestMistralRemoteDownloadTokenizer(t *testing.T) {
  1422  	// Tests the ability to download a tokenizer from a remote model
  1423  	// and use it to encode and decode strings
  1424  	modelId := "openaccess-ai-collective/tiny-mistral"
  1425  	//destPath := "./TestMistralRemoteDownloadTokenizer"
  1426  	//defer os.RemoveAll(destPath)
  1427  	encoderMistral, err := NewEncoder(modelId)
  1428  	if err != nil {
  1429  		t.Errorf("Error creating encoder: %v", err)
  1430  	}
  1431  
  1432  	// Attempt to tokenize
  1433  	testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
  1434  
  1435  	// Encode the string
  1436  	encoded := encoderMistral.Encode(&testString)
  1437  	// Check that the encoded string is the same as the expected - Reference from python's transformers lib
  1438  	expected := Tokens{1, 1014, 285, 1142, 14949, 754, 272, 295, 492, 28723, 13, 1014, 261, 3525, 291, 349, 9556, 821, 272, 295, 492, 28723}
  1439  	if !assert.Equal(t, expected, *encoded) {
  1440  		t.Errorf("Expected: %v\nActual: %v", expected, *encoded)
  1441  	}
  1442  }
  1443  
  1444  func TestModelDownloadPythia(t *testing.T) {
  1445  	// Pythia uses a slightly different file structure, where
  1446  	// the vocab.json and merges.txt files are stored in the
  1447  	// tokenizer.json file. We want to check if we are able to
  1448  	// download the model and extract the vocab.json and merges.txt
  1449  	modelId := "EleutherAI/pythia-70m"
  1450  	destPath := "./TestModelDownloadPythia"
  1451  	err := downloadModel(modelId, destPath)
  1452  	if err != nil {
  1453  		os.RemoveAll(destPath)
  1454  		t.Errorf("Error downloading model: %v", err)
  1455  	}
  1456  
  1457  	// Check that the model files are there
  1458  	// We want to check for the presence of the following files:
  1459  	// config.json, pytorch_model.bin,
  1460  	// tokenizer.json, vocab.json
  1461  
  1462  	// Check for additional metadata files
  1463  	metaFiles := []string{"tokenizer.json", "vocab.json", "config.json", "pytorch_model.bin"}
  1464  	for _, metaFile := range metaFiles {
  1465  		metaPath := destPath + "/" + metaFile
  1466  		assertFileExists(t, metaPath)
  1467  	}
  1468  
  1469  	// Finish the test, allow defered cleanup
  1470  	fmt.Println("All Exists - Looks good.")
  1471  }
  1472  
  1473  func TestModelDownloadPythiaSharded(t *testing.T) {
  1474  	// This tests the model downloader's ability
  1475  	// to download a sharded model.
  1476  
  1477  	modelId := "EleutherAI/pythia-6.9b-deduped"
  1478  	destPath := "./TestModelDownloadPythiaSharded"
  1479  	err := downloadModel(modelId, destPath)
  1480  	if err != nil {
  1481  		os.RemoveAll(destPath)
  1482  		t.Errorf("Error downloading model: %v", err)
  1483  	}
  1484  	defer os.RemoveAll(destPath)
  1485  
  1486  	// Check that the model files are there
  1487  	// We want to check for the presence of the following files:
  1488  	// pytorch_model-00001-of-00002.bin, pytorch_model-00002-of-00002.bin,
  1489  	// pytorch_model.bin.index.json
  1490  
  1491  	// Check for pytorch_model-00001-of-00002.bin
  1492  	model1Path := destPath + "/pytorch_model-00001-of-00002.bin"
  1493  	assertFileExists(t, model1Path)
  1494  
  1495  	// Check for pytorch_model-00002-of-00002.bin
  1496  	model2Path := destPath + "/pytorch_model-00002-of-00002.bin"
  1497  	assertFileExists(t, model2Path)
  1498  
  1499  	// Check for pytorch_model.bin.index.json
  1500  	shardconfigPath := destPath + "/pytorch_model.bin.index.json"
  1501  	assertFileExists(t, shardconfigPath)
  1502  
  1503  	// Finish the test, allow defered cleanup
  1504  	fmt.Println("All Exists - Looks good.")
  1505  
  1506  }
  1507  
  1508  func TestModelDownloadLlama(t *testing.T) {
  1509  	// Pythia uses a slightly different file structure, where
  1510  	// the vocab.json and merges.txt files are stored in the
  1511  	// tokenizer.json file. We want to check if we are able to
  1512  	// download the model and extract the vocab.json and merges.txt
  1513  	modelId := "Maykeye/TinyLLama-v0"
  1514  	destPath := "./TestModelDownloadLlama"
  1515  	err := downloadModel(modelId, destPath)
  1516  	if err != nil {
  1517  		os.RemoveAll(destPath)
  1518  		t.Errorf("Error downloading model: %v", err)
  1519  	}
  1520  	defer os.RemoveAll(destPath)
  1521  
  1522  	// Check that the model files are there
  1523  	// We want to check for the presence of the following files:
  1524  	// config.json, pytorch_model.bin,
  1525  	// tokenizer.model, vocab.json
  1526  
  1527  	// Check for pytorch_model.bin
  1528  	singleModelPattern := regexp.MustCompile(`pytorch_model\.bin$`)
  1529  	re, err := regexp.Compile(`-(\d+)-of-(\d+)\.bin$`)
  1530  	if err != nil {
  1531  		t.Errorf("Error compiling regex: %s", err)
  1532  	}
  1533  
  1534  	//check all files in the directory against the pattern
  1535  	files, err := ioutil.ReadDir(destPath)
  1536  	if err != nil {
  1537  		t.Errorf("Error reading directory: %s", err)
  1538  	}
  1539  	found := false
  1540  
  1541  	for _, file := range files {
  1542  		if singleModelPattern.MatchString(file.Name()) {
  1543  			found = true
  1544  			break
  1545  		}
  1546  
  1547  		matches := re.FindStringSubmatch(file.Name())
  1548  		if len(matches) > 2 {
  1549  			if strings.Compare(matches[1], matches[2]) == 0 {
  1550  				found = true
  1551  				break
  1552  			}
  1553  		}
  1554  	}
  1555  	if !found {
  1556  		t.Errorf("pytorch_model.bin does not exist or was not found")
  1557  	}
  1558  
  1559  	// Check for additional metadata files
  1560  	metaFiles := []string{"tokenizer.model", "vocab.json", "config.json"}
  1561  	for _, metaFile := range metaFiles {
  1562  		metaPath := destPath + "/" + metaFile
  1563  		assertFileExists(t, metaPath)
  1564  	}
  1565  
  1566  	// Finish the test, allow defered cleanup
  1567  	fmt.Println("All Exists - Looks good.")
  1568  }
  1569  
  1570  func TestModelDownloadMistral(t *testing.T) {
  1571  	// Download a downstream mistral model due to mistral being gated
  1572  	modelId := "openaccess-ai-collective/tiny-mistral"
  1573  	destPath := "./TestModelDownloadMistral"
  1574  	err := downloadModel(modelId, destPath)
  1575  	if err != nil {
  1576  		os.RemoveAll(destPath)
  1577  		t.Errorf("Error downloading model: %v", err)
  1578  	}
  1579  	defer os.RemoveAll(destPath)
  1580  
  1581  	// Check that the model files are there
  1582  	// We want to check for the presence of the following files:
  1583  	// config.json, pytorch_model.bin,
  1584  	// tokenizer.model
  1585  
  1586  	// Check for additional metadata files
  1587  	metaFiles := []string{"tokenizer.model", "config.json", "pytorch_model.bin"}
  1588  	for _, metaFile := range metaFiles {
  1589  		metaPath := destPath + "/" + metaFile
  1590  		assertFileExists(t, metaPath)
  1591  	}
  1592  
  1593  	// Finish the test, allow defered cleanup
  1594  	fmt.Println("All Exists - Looks good.")
  1595  }
  1596  
  1597  func TestModelDownloadFairseq(t *testing.T) {
  1598  	// Koboldai's fairseq models are stored in a different format
  1599  	// it has merges and vocab but no tokenizer.json
  1600  	modelId := "KoboldAI/fairseq-dense-355M"
  1601  	destPath := "./TestModelDownloadFairseq"
  1602  
  1603  	// Download the model
  1604  	err := downloadModel(modelId, destPath)
  1605  	if err != nil {
  1606  		os.RemoveAll(destPath)
  1607  		t.Errorf("Error downloading model: %v", err)
  1608  	}
  1609  	defer os.RemoveAll(destPath)
  1610  
  1611  	// Check that the model files are there
  1612  	// We want to check for the presence of the following files:
  1613  	// vocab, config. merges, pytorch_model
  1614  
  1615  	// Check for additional metadata files
  1616  	metaFiles := []string{"vocab.json", "config.json", "pytorch_model.bin", "merges.txt"}
  1617  	for _, metaFile := range metaFiles {
  1618  		metaPath := destPath + "/" + metaFile
  1619  		assertFileExists(t, metaPath)
  1620  	}
  1621  
  1622  	// Finish the test, allow defered cleanup
  1623  	fmt.Println("All Exists - Looks good (Fairseq Download).")
  1624  }