github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/resources/resolver.go (about)

     1  package resources
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"io/fs"
    11  	"io/ioutil"
    12  	"log"
    13  	"math"
    14  	"net/url"
    15  	"os"
    16  	"path"
    17  	"regexp"
    18  	"strconv"
    19  	"strings"
    20  	"time"
    21  
    22  	"github.com/wbrown/gpt_bpe/types"
    23  
    24  	"github.com/dustin/go-humanize"
    25  )
    26  
// Token, Tokens, and TokenMap are locally defined equivalents of the shared
// definitions in the gpt_bpe/types package, re-declared here so this package
// can attach its own behavior to them.
type Token types.Token
type Tokens types.Tokens
type TokenMap types.TokenMap

// JsonMap is a generic decoded-JSON object: string keys, arbitrary values.
type JsonMap map[string]interface{}

// ResourceFlag describes how the resolver should treat a resource (see the
// RESOURCE_* constants). ResourceType selects a model file layout (see the
// RESOURCETYPE_* constants).
type ResourceFlag uint8
type ResourceType uint8
    35  
// WriteCounter counts the number of bytes written to it, and every 10
// seconds, it prints a message reporting the number of bytes written so far.
type WriteCounter struct {
	Total    uint64    // running count of bytes written so far
	Last     time.Time // time of the last progress report
	Reported bool      // set once at least one progress line has been logged
	Path     string    // human-readable label ("uri/file") used in log output
	Size     uint64    // expected total size in bytes, shown in the progress line
}
    45  
    46  // Write writes p to the WriteCounter and updates the total number of bytes
    47  // written.
    48  func (wc *WriteCounter) Write(p []byte) (int, error) {
    49  	n := len(p)
    50  	wc.Total += uint64(n)
    51  	if time.Since(wc.Last).Seconds() > 10 {
    52  		wc.Reported = true
    53  		wc.Last = time.Now()
    54  		log.Printf(
    55  			"Downloading %s... %s / %s completed.",
    56  			wc.Path, humanize.Bytes(wc.Total), humanize.Bytes(wc.Size),
    57  		)
    58  	}
    59  	return n, nil
    60  }
    61  
// Enumeration of resource flags that indicate what the resolver should do
// with the resource.
//
// NOTE(review): although declared as bit flags (1 << iota), ResolveResources
// also compares them as ordered levels (`flag <= rsrcLvl`), so their numeric
// order is significant.
const (
	RESOURCE_REQUIRED ResourceFlag = 1 << iota // resolution fails if the resource cannot be fetched
	RESOURCE_OPTIONAL                          // fetched when available; absence is not an error
	RESOURCE_DERIVED                           // combined with OPTIONAL for files generated from others (e.g. specials.txt) — presumably derived locally; confirm
	RESOURCE_MODEL                             // model weight files
	RESOURCE_ONEOF                             // NOTE(review): not referenced in this file — semantics unconfirmed
)
    71  
// Enumeration of different types of models
const (
	RESOURCETYPE_TRANSFORMERS ResourceType = 1 << iota // flat HF transformers layout (config.json, vocab.json, ...)
	RESOURCETYPE_DIFFUSERS                             // multi-directory HF diffusers layout (unet/, vae/, ...)
)
    77  
// TransformerResources lists the files making up a HuggingFace transformers
// model/tokenizer and how each one is treated by the resolver: only
// config.json is strictly required; the weights (pytorch_model.bin) are
// flagged as a model resource.
var TransformerResources = ResourceEntryDefs{
	"config.json":                  RESOURCE_REQUIRED,
	"vocab.json":                   RESOURCE_OPTIONAL,
	"merges.txt":                   RESOURCE_OPTIONAL,
	"special_tokens_map.json":      RESOURCE_OPTIONAL,
	"wordtokens.json":              RESOURCE_OPTIONAL,
	"specials.txt":                 RESOURCE_OPTIONAL | RESOURCE_DERIVED,
	"tokenizer_config.json":        RESOURCE_OPTIONAL,
	"pytorch_model.bin.index.json": RESOURCE_OPTIONAL,
	"tokenizer.json":               RESOURCE_OPTIONAL,
	"tokenizer.model":              RESOURCE_OPTIONAL,
	"pytorch_model.bin":            RESOURCE_MODEL,
}
    91  
// DiffuserResources lists the files making up a HuggingFace diffusers
// pipeline (scheduler, text encoder, tokenizer, UNet, VAE, safety checker)
// and how each one is treated by the resolver.
var DiffuserResources = ResourceEntryDefs{
	"feature_extractor/preprocessor_config.json": RESOURCE_OPTIONAL,
	"safety_checker/config.json":                 RESOURCE_OPTIONAL,
	"safety_checker/pytorch_model.bin":           RESOURCE_OPTIONAL,
	"scheduler/scheduler_config.json":            RESOURCE_REQUIRED,
	"text_encoder/config.json":                   RESOURCE_REQUIRED,
	"text_encoder/pytorch_model.bin":             RESOURCE_MODEL,
	"tokenizer/merges.txt":                       RESOURCE_REQUIRED,
	"tokenizer/special_tokens_map.json":          RESOURCE_REQUIRED,
	"tokenizer/tokenizer_config.json":            RESOURCE_REQUIRED,
	"tokenizer/vocab.json":                       RESOURCE_REQUIRED,
	"unet/config.json":                           RESOURCE_REQUIRED,
	"unet/diffusion_pytorch_model.bin":           RESOURCE_MODEL,
	"vae/config.json":                            RESOURCE_REQUIRED,
	"vae/diffusion_pytorch_model.bin":            RESOURCE_MODEL,
	"model_index.json":                           RESOURCE_REQUIRED,
}
   109  
// ResourceEntryDefs maps a resource filename to the flags describing how the
// resolver must treat it.
type ResourceEntryDefs map[string]ResourceFlag

// ResourceEntry is a resolved resource: the open file handle (stored as
// interface{}; AddEntry stores a *os.File) plus its contents as produced by
// readMmap in AddEntry.
type ResourceEntry struct {
	file interface{}
	Data *[]byte
}

// Resources maps resource filenames to their resolved entries.
type Resources map[string]ResourceEntry
   117  
// Cleanup closes all open file handles in the Resources map.
func (rsrcs *Resources) Cleanup() {
	for _, rsrc := range *rsrcs {
		file := rsrc.file
		// rsrc.file is stored as interface{}. AddEntry stores *os.File,
		// which is matched by the fs.File case below (os.File's methods
		// have pointer receivers). NOTE(review): the os.File value case
		// looks unreachable given AddEntry always stores pointers — confirm
		// before removing. Close errors are deliberately ignored; this is
		// best-effort teardown.
		switch t := file.(type) {
		case os.File:
			_ = t.Close()
		case fs.File:
			_ = t.Close()
		}
	}
}
   130  
   131  // GetFile
   132  // Returns the file handle for a given resource name.
   133  func (rsrcs *Resources) GetFile(name string) (interface{}, error) {
   134  	if rsrcEntry, ok := (*rsrcs)[name]; ok {
   135  		return rsrcEntry.file, nil
   136  	} else {
   137  		return nil, fmt.Errorf("file %s not found", name)
   138  	}
   139  }
   140  
   141  // GetResourceEntries
   142  // Returns a default map of resource entries that express what files are
   143  // required, optional, derived, and/or model resources. Requires a
   144  // ResourceType.
   145  func GetResourceEntries(typ ResourceType) ResourceEntryDefs {
   146  	switch typ {
   147  	case RESOURCETYPE_TRANSFORMERS:
   148  		return TransformerResources
   149  	case RESOURCETYPE_DIFFUSERS:
   150  
   151  		return DiffuserResources
   152  	default:
   153  		return ResourceEntryDefs{}
   154  	}
   155  }
   156  
   157  // getResourceEntryAliases
   158  // Returns a map of defined resources to known alternative filenames
   159  // for each resource of a given ResourceType.
   160  func getResourceEntryAliases(typ ResourceType) map[string][]string {
   161  	switch typ {
   162  	case RESOURCETYPE_TRANSFORMERS:
   163  		return map[string][]string{
   164  			"vocab.json": {"encoder.json"},
   165  		}
   166  	default:
   167  		return map[string][]string{}
   168  	}
   169  }
   170  
   171  // FetchHuggingFace
   172  // Wrapper around FetchHTTP that fetches a resource from huggingface.co.
   173  func FetchHuggingFace(id string, rsrc string) (io.ReadCloser, error) {
   174  	token := os.Getenv("HF_API_TOKEN")
   175  	return FetchHTTP(
   176  		"https://huggingface.co/"+id+"/resolve/main", rsrc, token,
   177  	)
   178  }
   179  
   180  // SizeHuggingFace
   181  // Wrapper around SizeHTTP that gets the size of a resource from huggingface.co.
   182  func SizeHuggingFace(id string, rsrc string) (uint, error) {
   183  	token := os.Getenv("HF_API_TOKEN")
   184  	return SizeHTTP("https://huggingface.co/"+id+"/resolve/main", rsrc, token)
   185  }
   186  
   187  // isValidUrl
   188  // Checks if a given string is a valid URL, returns true if it is, false
   189  // otherwise.
   190  func isValidUrl(toTest string) bool {
   191  	_, err := url.ParseRequestURI(toTest)
   192  	if err != nil {
   193  		return false
   194  	}
   195  
   196  	u, err := url.Parse(toTest)
   197  	if err != nil || u.Scheme == "" || u.Host == "" {
   198  		return false
   199  	}
   200  
   201  	return true
   202  }
   203  
   204  // Fetch
   205  // Given a base URI and a resource name, determines if the resource is local,
   206  // remote, or from huggingface.co. If the resource is local, it returns a
   207  // file handle to the resource. If the resource is remote, or from
   208  // huggingface.co, it fetches the resource and returns a ReadCloser to the
   209  // fetched or cached resource.
   210  func Fetch(uri string, rsrc string, token string) (io.ReadCloser, error) {
   211  	if isValidUrl(uri) {
   212  		return FetchHTTP(uri, rsrc, token)
   213  	} else if _, err := os.Stat(path.Join(uri, rsrc)); !os.IsNotExist(err) {
   214  		if handle, fileErr := os.Open(path.Join(uri, rsrc)); fileErr != nil {
   215  			return nil, fmt.Errorf(
   216  				"error opening %s/%s: %v",
   217  				uri, rsrc, fileErr,
   218  			)
   219  		} else {
   220  			return handle, fileErr
   221  		}
   222  	} else {
   223  		return FetchHuggingFace(uri, rsrc)
   224  	}
   225  }
   226  
   227  // Size
   228  // Given a base URI and a resource name, determine the size of the resource.
   229  func Size(uri string, rsrc string, token string) (uint, error) {
   230  	if isValidUrl(uri) {
   231  		return SizeHTTP(uri, rsrc, token)
   232  	} else if fsz, err := os.Stat(path.Join(uri, rsrc)); !os.IsNotExist(err) {
   233  		return uint(fsz.Size()), nil
   234  	} else {
   235  		return SizeHuggingFace(uri, rsrc)
   236  	}
   237  }
   238  
   239  // AddEntry
   240  // Add a resource to the Resources map, opening it as a mmap.Map.
   241  func (rsrcs *Resources) AddEntry(name string, file *os.File) error {
   242  	fileMmap, mmapErr := readMmap(file)
   243  	if mmapErr != nil {
   244  		return fmt.Errorf("error trying to mmap file: %s", mmapErr)
   245  	} else {
   246  		(*rsrcs)[name] = ResourceEntry{file, fileMmap}
   247  	}
   248  	return nil
   249  }
   250  
// Specials
// Map of special tokens such as <|pad|>, <|endoftext|>, etc. Keys are the
// role names found in special_tokens_map.json (e.g. "eos_token"); values are
// the corresponding token strings.
type Specials map[string]string
   254  
   255  // ResolveSpecialTokens
   256  // If specials.json does not exist in dir, create it from the
   257  // special_tokens_map.json file.
   258  func (rsrcs *Resources) ResolveSpecialTokens(dir string) (
   259  	realizedSpecials Specials, err error,
   260  ) {
   261  	realizedSpecials = make(Specials)
   262  	// If we already have specials.json, we don't need to generate it.
   263  	if _, ok := (*rsrcs)["specials.json"]; ok {
   264  		if specErr := json.Unmarshal(
   265  			*(*rsrcs)["specials.json"].Data,
   266  			&realizedSpecials,
   267  		); specErr != nil {
   268  			return nil, fmt.Errorf(
   269  				"cannot unmarshal specials.json: %s", specErr,
   270  			)
   271  		}
   272  		return realizedSpecials, nil
   273  	}
   274  
   275  	// We can only generate specials.json if we have special_tokens_map
   276  	specialsJson, ok := (*rsrcs)["special_tokens_map.json"]
   277  	if !ok {
   278  		return nil, nil
   279  	}
   280  
   281  	specialTokens := make(JsonMap)
   282  	if specialErr := json.Unmarshal(
   283  		*specialsJson.Data,
   284  		&specialTokens,
   285  	); specialErr != nil {
   286  		return nil, specialErr
   287  	}
   288  
   289  	for k, v := range specialTokens {
   290  		var specialToken string
   291  		switch t := v.(type) {
   292  		case string:
   293  			specialToken = t
   294  		case JsonMap:
   295  			mv := t["content"]
   296  			switch mvt := mv.(type) {
   297  			case string:
   298  				specialToken = mvt
   299  			default:
   300  				log.Fatalf(
   301  					"unknown format for `special_tokens_map."+
   302  						"json`: %v", t,
   303  				)
   304  			}
   305  		default:
   306  			log.Fatalf(
   307  				"unknown format for `special_tokens_map."+
   308  					"json`: %v", t,
   309  			)
   310  		}
   311  		realizedSpecials[k] = specialToken
   312  	}
   313  	if len(realizedSpecials) > 0 {
   314  		specialsFile, specialFileErr := os.OpenFile(
   315  			path.Join(dir, "specials.json"),
   316  			os.O_TRUNC|os.O_RDWR|os.O_CREATE, 0755,
   317  		)
   318  		if specialFileErr != nil {
   319  			return nil, fmt.Errorf(
   320  				"cannot generate specials.json: %s",
   321  				specialFileErr,
   322  			)
   323  		}
   324  		specialsJsonBytes, specialsErr := json.Marshal(realizedSpecials)
   325  		if specialsErr != nil {
   326  			_ = specialsFile.Close()
   327  			return nil, fmt.Errorf(
   328  				"cannot marshal specials.json: %s", specialsErr,
   329  			)
   330  		}
   331  		if _, writeErr := specialsFile.Write(
   332  			specialsJsonBytes,
   333  		); writeErr != nil {
   334  			_ = specialsFile.Close()
   335  			return nil, fmt.Errorf(
   336  				"cannot write specials.json: %s", specialsErr,
   337  			)
   338  		}
   339  		if _, seekErr := specialsFile.Seek(0, 0); seekErr != nil {
   340  			return nil,
   341  				fmt.Errorf("cannot seek specials.json: %s", seekErr)
   342  		}
   343  		if mmapErr := rsrcs.AddEntry(
   344  			"specials.json",
   345  			specialsFile,
   346  		); mmapErr != nil {
   347  			return nil, mmapErr
   348  		}
   349  	}
   350  	return realizedSpecials, nil
   351  }
   352  
// ResolveResources resolves all resources at a given uri, and checks if they
// exist in the given directory. If they don't exist, they are downloaded.
//
// uri may be a remote URL, a local directory, or a huggingface.co model id
// (see Fetch/Size). dir is the local cache directory; rsrcLvl limits which
// flagged resources are resolved; rsrcType selects the file layout. The
// returned Resources map holds open, mmapped entries — callers should
// eventually invoke Cleanup on it. On error the partially-populated map is
// returned alongside the error.
func ResolveResources(
	uri string,
	dir *string,
	rsrcLvl ResourceFlag,
	rsrcType ResourceType,
	token string,
) (
	*Resources,
	error,
) {
	foundResources := make(Resources)
	resources := GetResourceEntries(rsrcType)
	aliases := getResourceEntryAliases(rsrcType)

	for file, flag := range resources {
		var rsrcFile os.File

		// Resolve the resource. NOTE(review): the flags are declared as
		// bit flags but compared here as an ordered level — confirm this
		// ordering is intentional.
		if flag <= rsrcLvl {
			log.Printf("Resolving %s/%s... ", uri, file)
			targetPath := path.Join(*dir, file)
			rsrcSize, rsrcSizeErr := Size(uri, file, token)
			alias := file
			if rsrcSizeErr != nil {
				// If the resource isn't found under its normal filename,
				// check under any known aliases.
				if aliasesList, ok := aliases[file]; ok {
					for _, alias = range aliasesList {
						rsrcSize, rsrcSizeErr = Size(uri, alias, token)
						if rsrcSizeErr == nil {
							log.Printf(
								"Resolving %s/%s as alias %s/%s...",
								uri, file, uri, alias,
							)
							break
						}
					}
				}
			}
			if rsrcSizeErr != nil {
				// If the resource is required, we cannot continue.
				if flag&RESOURCE_REQUIRED != 0 {
					log.Printf(
						"%s/%s not found, required!",
						uri, file,
					)
					// NOTE(review): the message reads "`%s from %s`" but the
					// arguments are (uri, file) — they appear reversed.
					return &foundResources, fmt.Errorf(
						"cannot retrieve required `%s from %s`: %s",
						uri, file, rsrcSizeErr,
					)
				} else {
					// Otherwise, we can skip it.
					continue
				}
				// If the resource exists, and is the correct size, skip it.
				// NOTE(review): if Stat fails with an error other than
				// NotExist (e.g. permission denied), targetStat is nil and
				// targetStat.Size() below panics — confirm and guard.
			} else if targetStat, targetStatErr := os.Stat(targetPath); !os.IsNotExist(
				targetStatErr,
			) && uint(targetStat.Size()) == rsrcSize {
				log.Printf(
					"Skipping %s/%s... already exists, "+
						"and of the correct size.", uri, file,
				)
				openFile, skipFileErr := os.OpenFile(
					path.Join(*dir, file),
					os.O_RDONLY, 0755,
				)
				if skipFileErr != nil {
					return &foundResources, fmt.Errorf(
						"error opening '%s' for read: %s",
						file, skipFileErr,
					)

				} else {
					// Cached copy is valid: reuse the local file handle and
					// skip the download entirely.
					rsrcFile = *openFile
				}
				// Missing locally (or wrong size): fetch it from the uri.
			} else if rsrcReader, rsrcErr := Fetch(
				uri, alias, token,
			); rsrcErr != nil {
				return &foundResources, fmt.Errorf(
					"cannot retrieve `%s from %s`: %s",
					uri, alias, rsrcErr,
				)
			} else {
				// Download into dir, creating subdirectories as needed
				// (diffusers resources live in nested paths).
				if dirErr := os.MkdirAll(
					path.Dir(path.Join(*dir, file)), 0755,
				); dirErr != nil {
					return &foundResources, fmt.Errorf(
						"cannot create directory for '%s': %s",
						file, dirErr,
					)
				}
				openFile, rsrcFileErr := os.OpenFile(
					path.Join(*dir, file),
					os.O_TRUNC|os.O_RDWR|os.O_CREATE, 0755,
				)
				if rsrcFileErr != nil {
					return &foundResources, fmt.Errorf(
						"error opening '%s' for write: %s",
						file, rsrcFileErr,
					)
				}
				rsrcFile = *openFile

				// Stream the download through a WriteCounter so progress is
				// logged every ~10 seconds.
				counter := &WriteCounter{
					Last: time.Now(),
					Path: fmt.Sprintf("%s/%s", uri, file),
					Size: uint64(rsrcSize),
				}
				bytesDownloaded, ioErr := io.Copy(
					&rsrcFile,
					io.TeeReader(rsrcReader, counter),
				)
				_ = rsrcReader.Close()
				if ioErr != nil {
					return &foundResources, fmt.Errorf(
						"error downloading '%s': %s",
						alias, ioErr,
					)
				} else {
					log.Printf(
						"Downloaded %s/%s... "+
							"%s completed.", uri, alias,
						humanize.Bytes(uint64(bytesDownloaded)),
					)
				}
			}

			// Register the (reused or freshly downloaded) file; AddEntry
			// mmaps it so Data is available to callers.
			if mmapErr := foundResources.AddEntry(
				file,
				&rsrcFile,
			); mmapErr != nil {
				return &foundResources, fmt.Errorf(
					"error trying to mmap file: %s",
					mmapErr,
				)
			}
		}
	}

	// check if tokenizer.model exists, if so, expand to files
	flagTokenizerModelExist := CheckFileExist(
		path.Join(
			*dir, "tokenizer.model",
		),
	)
	if flagTokenizerModelExist {
		// check size of tokenizer.model; a zero-byte file is treated as
		// absent.
		targetStat, targetStatErr := os.Stat(
			path.Join(
				*dir, "tokenizer.model",
			),
		)
		if targetStatErr != nil {
			return &foundResources, fmt.Errorf(
				"cannot stat tokenizer.model: %s",
				targetStatErr,
			)
		}
		if targetStat.Size() == 0 {
			flagTokenizerModelExist = false
		}
	}

	if flagTokenizerModelExist {
		// Sentencepiece path: convert tokenizer.model into the individual
		// tokenizer files, then register whatever new files appeared in dir.
		log.Printf(
			"Directory %s contains tokenizer.model, extracting to files",
			path.Join(*dir, "tokenizer.model"),
		)
		ConvertSentencepieceFiles(
			path.Join(*dir, "tokenizer.model"),
			false,
		)

		// Add the new files to the resources.
		// NOTE(review): the ReadDir error is silently ignored here.
		files, _ := os.ReadDir(*dir)
		for _, f := range files {
			// If not already in the resources, add it
			if _, ok := foundResources[f.Name()]; !ok {
				openFile, rsrcFileErr := os.OpenFile(
					path.Join(*dir, f.Name()),
					os.O_RDONLY, 0755,
				)
				if rsrcFileErr != nil {
					return &foundResources, fmt.Errorf(
						"error opening '%s' for read: %s",
						f.Name(), rsrcFileErr,
					)
				}
				rsrcFile := *openFile
				if mmapErr := foundResources.AddEntry(
					f.Name(),
					&rsrcFile,
				); mmapErr != nil {
					return &foundResources, fmt.Errorf(
						"error trying to mmap file: %s",
						mmapErr,
					)
				}
				log.Printf(
					"Added %s to resources via sentencepiece conversion\n",
					f.Name(),
				)
			}
		}

	} else {
		// check if tokenizer exists by checking if tokenizer.json exists
		// and has data in it
		flagTokenizerExist := CheckFileExist(
			path.Join(
				*dir, "tokenizer.json",
			),
		)
		if flagTokenizerExist {
			// check size of tokenizer.json; zero bytes counts as absent
			targetStat, targetStatErr := os.Stat(
				path.Join(
					*dir, "tokenizer.json",
				),
			)
			if targetStatErr != nil {
				return &foundResources, fmt.Errorf(
					"cannot stat tokenizer.json: %s",
					targetStatErr,
				)
			}
			if targetStat.Size() == 0 {
				flagTokenizerExist = false
			}
		}

		// if tokenizer exists, but vocab and merges do not exist, extract
		// from tokenizer, else if vocab and merges exist, do nothing; if
		// both do not exist, fail
		flagVocabExist := CheckFileExist(path.Join(*dir, "vocab.json"))
		flagMergesExists := CheckFileExist(path.Join(*dir, "merges.txt"))

		if flagTokenizerExist {
			// if vocab does not exist, extract it from tokenizer
			if !flagVocabExist {
				model, err := ExtractModelFromTokenizer(dir)
				if err != nil {
					return &foundResources, fmt.Errorf(
						"could not extract model from tokenizer %s",
						err,
					)
				}

				err = ExtractVocabFromTokenizer(model, dir, &foundResources)
				if err != nil {
					return &foundResources, fmt.Errorf(
						"could not extract vocab from tokenizer %s",
						err,
					)
				}
			}

			// if merges does not exist, extract it from tokenizer.
			// NOTE(review): ExtractModelFromTokenizer is re-run here even if
			// the vocab branch above already parsed the tokenizer — the
			// result could presumably be reused; confirm.
			if !flagMergesExists {
				model, err := ExtractModelFromTokenizer(dir)
				if err != nil {
					return &foundResources,
						fmt.Errorf(
							"could not extract model from tokenizer %s",
							err,
						)
				}

				err = ExtractMergesFromTokenizer(model, dir, &foundResources)
				if err != nil {
					return &foundResources, fmt.Errorf(
						"could not extract merges from tokenizer %s",
						err,
					)
				}
			}
		} else {
			// if tokenizer does not exist, check if vocab and merges exist
			if flagVocabExist && flagMergesExists {
				// if both exist, do nothing
				log.Println(
					"Vocab and merges exist, but tokenizer does not. OK",
				)
			} else {
				// if either does not exist, fail
				return &foundResources, fmt.Errorf(
					"tokenizer, vocab, and merges do not exist",
				)
			}
		}

	}

	// Check if we already got the pytorch model file
	flagModelExists := CheckFileExist(path.Join(*dir, "pytorch_model.bin"))
	log.Printf("Pytorch Model File exists: %t\n", flagModelExists)

	// if model does not exist, check if we have the sharded config
	if !flagModelExists {
		flagShardConfigExists := CheckFileExist(
			path.Join(
				*dir, "pytorch_model.bin.index.json",
			),
		)
		log.Printf("Shard config exists: %t", flagShardConfigExists)
		// if sharded config exists, attempt to download the shards.
		// NOTE(review): downloaded shards are written to disk but never
		// added to foundResources — confirm consumers read them from disk.
		if flagShardConfigExists {
			numShards, err := FindNumberOfShardsFromConfig(
				path.Join(
					*dir, "pytorch_model.bin.index.json",
				),
			)
			if err != nil {
				log.Printf(
					"Could not find number of shards from config: %s\n",
					err,
				)
				return &foundResources, errors.New(
					"could not find number of shards from config",
				)
			}

			// pad the number of shards to 5 digits, matching the HF
			// "pytorch_model-00001-of-00002.bin" naming scheme
			log.Printf("Found %d shards\n", numShards)
			paddedTotalShards := fmt.Sprintf("%05d", numShards)

			// loop through shards and download them
			for i := 1; i <= numShards; i++ {
				var rsrcFile os.File

				paddedShardString := fmt.Sprintf("%05d", i)
				// Construct the shard path
				shardPath := fmt.Sprintf(
					"pytorch_model-%s-of-%s.bin", paddedShardString,
					paddedTotalShards,
				)
				log.Printf("Resolving shard %s\n", shardPath)

				targetPath := path.Join(*dir, shardPath)
				rsrcSize, rsrcSizeErr := Size(uri, shardPath, token)
				if rsrcSizeErr != nil {
					fmt.Printf(
						"could not get size of shard %s: %s\n",
						shardPath,
						rsrcSizeErr,
					)
					return &foundResources,
						errors.New("could not get size of shard")
				}
				// Print size of shard
				log.Printf(
					"Remote size of shard %s is %s\n", shardPath,
					humanize.Bytes(uint64(rsrcSize)),
				)

				// Check if shard exists locally
				flagShardExists := CheckFileExist(targetPath)
				if flagShardExists {
					// Check if shard local size is correct compared to
					// remote size
					localShardInfo, err := os.Stat(targetPath)
					if err != nil {
						fmt.Printf(
							"Could not get size of local shard %s: %s\n",
							shardPath, err,
						)
						return &foundResources,
							errors.New("could not get size of local shard")
					}
					if (rsrcSize > 0) && (rsrcSize == uint(localShardInfo.Size())) {
						log.Printf(
							"Skipping shard %s, exists and correct size\n",
							shardPath,
						)
						continue
					}
				}

				// fetch shard
				var rsrcReader io.ReadCloser
				rsrcReader, err = Fetch(uri, shardPath, token)
				if err != nil {
					return &foundResources, fmt.Errorf(
						"error trying to fetch file: %s", err,
					)
				}

				// create shard file
				var rsrcFilePtr *os.File
				rsrcFilePtr, err = os.Create(targetPath)
				if err != nil {
					return &foundResources, fmt.Errorf(
						"error trying to create file: %s", err,
					)
				}
				rsrcFile = *rsrcFilePtr

				// copy shard to file, logging progress via WriteCounter
				counter := &WriteCounter{
					Last: time.Now(),
					Path: fmt.Sprintf("%s/%s", uri, shardPath),
					Size: uint64(rsrcSize),
				}
				bytesDownloaded, ioErr := io.Copy(
					&rsrcFile,
					io.TeeReader(rsrcReader, counter),
				)
				// close shard reader
				err = rsrcReader.Close()

				if err != nil {
					return &foundResources, fmt.Errorf(
						"error trying to close reader: %s", err,
					)
				}
				if ioErr != nil {
					return &foundResources, fmt.Errorf(
						"error downloading '%s': %s",
						shardPath,
						ioErr,
					)
				} else {
					log.Printf(
						"Downloaded %s/%s... "+
							"%s completed.", uri, shardPath,
						humanize.Bytes(uint64(bytesDownloaded)),
					)
				}

				// close shard file.
				// NOTE(review): the error message below says "close reader"
				// but this is the file close, and the success branch repeats
				// the "Downloaded" log line — both look unintended.
				err = rsrcFile.Close()

				if err != nil {
					return &foundResources, fmt.Errorf(
						"error trying to close reader: %s", err,
					)
				} else {
					log.Printf(
						"Downloaded %s/%s... "+
							"%s completed.", uri, shardPath,
						humanize.Bytes(uint64(bytesDownloaded)),
					)
				}

				// check if shard local size is correct compared to remote size
				var localShardInfo os.FileInfo
				localShardInfo, err = os.Stat(targetPath)
				if err != nil {
					fmt.Printf(
						"Could not get size of local shard %s: %s\n",
						shardPath, err,
					)
					return &foundResources,
						errors.New("could not get size of local shard")
				}
				if (rsrcSize > 0) &&
					(rsrcSize != uint(localShardInfo.Size())) {
					return &foundResources,
						errors.New("shard was not downloaded correctly")
				}
				log.Printf(
					"Shard %s downloaded correctly, size is %s\n",
					shardPath,
					humanize.Bytes(uint64(localShardInfo.Size())),
				)

			}
			log.Printf("Downloaded %d shards\n", numShards)

		}

	}
	return &foundResources, nil
}
   831  
   832  // ResolveSplitRegex
   833  // Resolves the split regex from the tokenizer.json file, if it exists.
   834  func (rsrcs *Resources) ResolveSplitRegex() *string {
   835  	var splitRegex *string
   836  	if tokenizerData, ok := (*rsrcs)["tokenizer.json"]; ok {
   837  		var tokenizerMap JsonMap
   838  		if tokenizerData.Data != nil {
   839  			if json.Unmarshal(*tokenizerData.Data, &tokenizerMap) != nil {
   840  				log.Fatal("Error unmarshalling tokenizer.json")
   841  			}
   842  		}
   843  		if preTok, ok := tokenizerMap["pre_tokenizer"]; ok && preTok != nil {
   844  			preTokMap := preTok.(map[string]interface{})
   845  			if pretokenizers, ok := preTokMap["pretokenizers"]; ok &&
   846  				pretokenizers != nil {
   847  				pretokenizersList := pretokenizers.([]interface{})
   848  				for _, v := range pretokenizersList {
   849  					vIntf := v.(map[string]interface{})
   850  					if vIntf["type"] == "Split" {
   851  						pattern := vIntf["pattern"].(map[string]interface{})
   852  						if pattern["Regex"] != nil {
   853  							splitRegexVal := pattern["Regex"].(string)
   854  							// Fix lookbacks
   855  							splitRegexVal = strings.ReplaceAll(
   856  								splitRegexVal,
   857  								"(?!\\S)",
   858  								"(\\S){0}",
   859  							)
   860  							splitRegex = &splitRegexVal
   861  						}
   862  					}
   863  				}
   864  			}
   865  		}
   866  	}
   867  	return splitRegex
   868  }
   869  
   870  // CheckFileExist checks if a file exists at a given path.
   871  func CheckFileExist(path string) bool {
   872  	_, err := os.Stat(path)
   873  
   874  	if errors.Is(err, os.ErrNotExist) {
   875  		return false
   876  	} else {
   877  		return true
   878  	}
   879  }
   880  
// HFConfig contains the tokenizer configuration that gpt_bpe uses.
// Fields are pointers so that absent JSON keys remain distinguishable from
// zero values when merging config.json / tokenizer_config.json data.
type HFConfig struct {
	// NOTE(review): this tag serializes ModelId under the literal JSON key
	// "omitempty" (no field name was given before the comma). It looks
	// unintended, but changing it would alter the on-disk format — confirm
	// before fixing.
	ModelId             *string   `json:"omitempty"`
	ModelType           *string   `json:"model_type,omitempty"`
	EosTokenId          *Token    `json:"eos_token_id,omitempty"`
	BosTokenId          *Token    `json:"bos_token_id,omitempty"`
	PadTokenId          *Token    `json:"pad_token_id,omitempty"`
	BosTokenStr         *string   `json:"bos_token,omitempty"`
	EosTokenStr         *string   `json:"eos_token,omitempty"`
	PadTokenStr         *string   `json:"pad_token,omitempty"`
	VocabSize           *uint32   `json:"vocab_size,omitempty"`
	NewLineMode         *string   `json:"newlinemode,omitempty"`
	TokenizerClass      *string   `json:"tokenizer_class"`
	AddBosToken         *bool     `json:"add_bos_token,omitempty"`
	AddEosToken         *bool     `json:"add_eos_token,omitempty"`
	AddedSpecialsTokens *TokenMap `json:"added_specials_tokens,omitempty"`
	IgnoreMerges        *bool     `json:"ignore_merges,omitempty"`
}
   899  
// SpecialConfig contains the special tokens and special token configuration
// that gpt_bpe uses.
type SpecialConfig struct {
	PuncRunes     []*string          `json:"punc_runes"`      // punctuation runes — presumably used for token splitting; confirm against consumer
	Normalizer    *map[string]string `json:"normalizer"`      // string replacements applied during normalization
	EncloseEosBos bool               `json:"enclose_eos_bos"` // wrap encoded text with BOS/EOS tokens
	PrefixSpace   bool               `json:"prefix_space"`    // prepend a space before encoding
	LowerCase     bool               `json:"lower_case"`      // lowercase input before encoding
	EndOfWord     string             `json:"end_of_word"`     // end-of-word marker suffix
	DecodeExtra   *map[string]string `json:"decode_extra"`    // extra replacements applied on decode
	SplitRegex    *string            `json:"split_regex"`     // pretokenizer split pattern (see ResolveSplitRegex)
}
   912  
   913  // NewHFConfig creates a new HFConfig object with default values.
   914  func NewHFConfig() *HFConfig {
   915  	defaultModelId := ""
   916  	defaultModelType := "gpt2"
   917  	defaultEosTokenId := Token(0)
   918  	defaultBosTokenId := Token(0)
   919  	defaultPadTokenId := Token(0)
   920  	defaultBosTokenStr := "<|startoftext|>"
   921  	defaultEosTokenStr := "<|endoftext|>"
   922  	defaultPadTokenStr := ""
   923  	defaultVocabSize := uint32(50257)
   924  	defaultNewLineMode := "prefix"
   925  	defaultTokenizerClass := "GPT2BPETokenizer"
   926  	defaultAddBosToken := false
   927  	defaultAddEosToken := false
   928  	defaultAddedSpecialsTokens := make(TokenMap)
   929  	HFConfig := &HFConfig{
   930  		ModelId:             &defaultModelId,
   931  		ModelType:           &defaultModelType,
   932  		EosTokenId:          &defaultEosTokenId,
   933  		BosTokenId:          &defaultBosTokenId,
   934  		PadTokenId:          &defaultPadTokenId,
   935  		BosTokenStr:         &defaultBosTokenStr,
   936  		EosTokenStr:         &defaultEosTokenStr,
   937  		PadTokenStr:         &defaultPadTokenStr,
   938  		VocabSize:           &defaultVocabSize,
   939  		NewLineMode:         &defaultNewLineMode,
   940  		TokenizerClass:      &defaultTokenizerClass,
   941  		AddBosToken:         &defaultAddBosToken,
   942  		AddEosToken:         &defaultAddEosToken,
   943  		AddedSpecialsTokens: &defaultAddedSpecialsTokens,
   944  	}
   945  	return HFConfig
   946  }
   947  
// Processor stores config to process one step of the pipeline
// (e.g. a normalizer, pre_tokenizer, post_processor, or decoder section
// extracted from tokenizer.json by FindProcessingStepsFromTokenizer).
type Processor struct {
	// ProcessorType names the pipeline stage this entry configures.
	ProcessorType string
	// ProcessorArgs holds the stage's raw JSON configuration.
	ProcessorArgs JsonMap
}
   953  
   954  // Process the input with the processor
   955  func (p *Processor) Process(input interface{}) (interface{}, error) {
   956  	switch p.ProcessorType {
   957  	case "prepend":
   958  		return nil, errors.New("prepend not implemented")
   959  	default:
   960  		return nil, errors.New("unknown processor type")
   961  	}
   962  }
   963  
   964  // LoadExternalResources
   965  // Resolves a given vocabulary id, and returns the corresponding HuggingFace
   966  // configuration, and the resources for the tokenizer.
   967  func LoadExternalResources(
   968  	vocabId string,
   969  	token string,
   970  ) (resources *Resources, err error) {
   971  	dir, dirErr := ioutil.TempDir("", "resources")
   972  	if dirErr != nil {
   973  		return nil, dirErr
   974  	}
   975  	defer func(path string) {
   976  		_ = os.RemoveAll(path)
   977  	}(dir)
   978  	rslvdResources, rsrcErr := ResolveResources(
   979  		vocabId,
   980  		&dir,
   981  		RESOURCE_DERIVED,
   982  		RESOURCETYPE_TRANSFORMERS,
   983  		token,
   984  	)
   985  	if rsrcErr != nil {
   986  		return nil, rsrcErr
   987  	} else {
   988  		resources = rslvdResources
   989  	}
   990  	fmt.Printf("Resources: %v\n", resources)
   991  	return resources, nil
   992  
   993  }
   994  
   995  // ResolveHF
   996  // Given a set of resources, resolve the HuggingFace configuration.
   997  // Used to be able to resolve both embedded and local resources.
   998  func (rsrcs *Resources) ResolveHF(hfConfig *HFConfig) (err error) {
   999  	// Resolve config and tokenizer config from resources
  1000  	// config.json and tokenizer_config.json
  1001  	if err = rsrcs.resolveConfigAndTokenizer(hfConfig); err != nil {
  1002  		return err
  1003  	}
  1004  
  1005  	// Resolve special tokens and special tokens config from resources
  1006  	// special_tokens_map.json and specials.txt
  1007  	if err = rsrcs.resolveSpecials(hfConfig); err != nil {
  1008  		return err
  1009  	}
  1010  
  1011  	// Resolve Vocab size from vocab.json or encoder.json
  1012  	if err = rsrcs.resolveVocabSize(hfConfig); err != nil {
  1013  		return err
  1014  	}
  1015  
  1016  	// Sometimes TokenIDs are not properly resolved, so we need to check
  1017  	if hfConfig != nil {
  1018  		if *hfConfig.EosTokenId == 0 || *hfConfig.BosTokenId == 0 ||
  1019  			*hfConfig.PadTokenId == 0 {
  1020  			if err = rsrcs.resolveTokenIds(hfConfig); err != nil {
  1021  				return err
  1022  			}
  1023  		}
  1024  	} else {
  1025  		return errors.New("could not resolve HFConfig")
  1026  	}
  1027  
  1028  	// Llama 3 and other larger models will enclose eos and bos by default
  1029  	if *hfConfig.VocabSize > math.MaxUint16+1 {
  1030  		var addEosToken = true
  1031  		var addBosToken = true
  1032  
  1033  		hfConfig.AddEosToken = &addEosToken
  1034  		hfConfig.AddBosToken = &addBosToken
  1035  	}
  1036  
  1037  	return nil
  1038  }
  1039  
  1040  func GetMergesAsBpeRank(resources *Resources) (map[GPTPair]float64, error) {
  1041  	bpeRanks := make(map[GPTPair]float64)
  1042  	// Try to get from merges.txt
  1043  	if mergesTxt, ok := (*resources)["merges.txt"]; ok {
  1044  		scanner := bufio.NewScanner(bytes.NewBuffer(*mergesTxt.Data))
  1045  		idx := uint32(0)
  1046  		firstLine := true
  1047  		for scanner.Scan() {
  1048  			if firstLine {
  1049  				firstLine = false
  1050  				continue
  1051  			}
  1052  			leftRight := strings.SplitN(scanner.Text(), " ", 2)
  1053  			bpeRanks[GPTPair{
  1054  				Left:  leftRight[0],
  1055  				Right: leftRight[1],
  1056  			}] = float64(idx)
  1057  			idx += 1
  1058  		}
  1059  	} else if mergesJson, ok := (*resources)["merges.json"]; ok {
  1060  		var mergesTable [][]string
  1061  		err := json.Unmarshal(*mergesJson.Data, &mergesTable)
  1062  		if err != nil {
  1063  			return nil, err
  1064  		}
  1065  		// Iterate over the merges and add them to the BPE ranks
  1066  		for rank, merge := range mergesTable {
  1067  			bpeRanks[GPTPair{merge[0], merge[1]}] =
  1068  				float64(rank)
  1069  		}
  1070  	} else if tokenizerJson, ok := (*resources)["tokenizer.json"]; ok {
  1071  		// Finally try to get from tokenizer.json, merges entry
  1072  		var tokenizerJsonMap JsonMap
  1073  		err := json.Unmarshal(*tokenizerJson.Data, &tokenizerJsonMap)
  1074  		if err != nil {
  1075  			return nil, err
  1076  		}
  1077  
  1078  		model, ok := tokenizerJsonMap["model"]
  1079  		if !ok {
  1080  			return nil, errors.New("could not get model from tokenizer.json")
  1081  		}
  1082  		merges, ok := model.(map[string]interface{})["merges"]
  1083  		if !ok {
  1084  			return nil, errors.New("could not get merges from tokenizer.json")
  1085  		}
  1086  		// Iterate over the merges and add them to the BPE ranks, in form of string[]
  1087  		for rank, merge := range merges.([]interface{}) {
  1088  			mergeStr := merge.(string)
  1089  			mergeSplit := strings.Split(mergeStr, " ")
  1090  			bpeRanks[GPTPair{mergeSplit[0], mergeSplit[1]}] =
  1091  				float64(rank)
  1092  		}
  1093  	} else {
  1094  		return nil, errors.New("could not find merges")
  1095  	}
  1096  	return bpeRanks, nil
  1097  }
  1098  
  1099  func (rsrcs *Resources) UnmarshalData(
  1100  	name string,
  1101  ) (data *JsonMap, err error) {
  1102  	if _, err = (*rsrcs).GetFile(name); err == nil {
  1103  		rawData := (*rsrcs)[name].Data
  1104  		if err = json.Unmarshal(*rawData, &data); err != nil {
  1105  			return nil,
  1106  				fmt.Errorf("error unmarshalling %s: %s", name, err)
  1107  		}
  1108  	} else {
  1109  		return nil, nil
  1110  	}
  1111  	return data, nil
  1112  }
  1113  
  1114  func (rsrcs *Resources) UnmarshalUntilData(
  1115  	names []string,
  1116  ) (
  1117  	name string,
  1118  	data *JsonMap,
  1119  	err error,
  1120  ) {
  1121  	for _, name = range names {
  1122  		if _, ok := (*rsrcs)[name]; !ok {
  1123  			continue
  1124  		}
  1125  		if data, err = rsrcs.UnmarshalData(name); err == nil && data != nil {
  1126  			return name, data, nil
  1127  		} else if err != nil {
  1128  			return name, nil, fmt.Errorf(
  1129  				"error unmarshalling %s: %s", name, err,
  1130  			)
  1131  		}
  1132  	}
  1133  	return "", nil, nil
  1134  }
  1135  
// GetVocab
// Get the vocab from the resources.
// Adapter function to get the vocab from either vocab.json or encoder.json.
// tokenizer.json is also accepted as a last resort, in which case the vocab
// is read from its model/vocab section and model/ignore_merges (if present)
// is recorded on hfConfig. Any AddedSpecialsTokens already collected on
// hfConfig are merged into the returned map.
func (rsrcs *Resources) GetVocab(
	hfConfig *HFConfig,
) (TokenMap, error) {
	// Vocab is stored in either the vocab.json or encoder.json file
	// We want to unmarshal the vocab file into an interface to work with
	// We attempt to unmarshal under the vocab.json key first, then
	// encoder.json if it fails
	filesToAttempt := []string{"vocab.json", "encoder.json", "tokenizer.json"}

	// Get the vocab from the resources
	name, vocabData, err := rsrcs.UnmarshalUntilData(filesToAttempt)
	if err != nil {
		return nil, err
	} else if vocabData == nil {
		return nil, errors.New("vocab file not found")
	}

	tokens := make(TokenMap)
	if name == "tokenizer.json" {
		// We get the vocab stored in the /model/vocab key
		// NOTE(review): the .(map[string]interface{}) and .(float64)
		// assertions below panic on unexpected JSON shapes — assumes a
		// well-formed tokenizer.json.
		if modelInterface, ok := (*vocabData)["model"]; ok {
			model := modelInterface.(map[string]interface{})
			if vocabInterface, ok := model["vocab"]; ok {
				vocabMap := vocabInterface.(map[string]interface{})
				for k, v := range vocabMap {
					tokens[k] = types.Token(v.(float64))
				}
			}
			// Check for "ignore_merges", and set it
			if ignoreMerges, ok := model["ignore_merges"].(bool); ok {
				hfConfig.IgnoreMerges = &ignoreMerges
			}
		}
	} else {
		// vocab.json / encoder.json map token strings directly to ids.
		for k, v := range *vocabData {
			tokens[k] = types.Token(v.(float64))
		}
	}

	// Default IgnoreMerges to false when the file did not specify it.
	if hfConfig.IgnoreMerges == nil {
		disableIgnoreMerges := false
		hfConfig.IgnoreMerges = &disableIgnoreMerges
	}

	// Add the special tokens to the vocab.
	// Assumes AddedSpecialsTokens is non-nil (set by NewHFConfig or
	// resolveConfigAndTokenizer).
	if len(*hfConfig.AddedSpecialsTokens) > 0 {
		for k, v := range *hfConfig.AddedSpecialsTokens {
			tokens[k] = v
		}
	}

	return tokens, nil
}
  1192  
  1193  // resolveTokenIds
  1194  // Resolve token ids for eos, bos, and pad tokens from resources.
  1195  func (rsrcs *Resources) resolveTokenIds(hfConfig *HFConfig) error {
  1196  	// Get the vocab from the resources
  1197  	vocab, err := rsrcs.GetVocab(hfConfig)
  1198  	if err != nil {
  1199  		return err
  1200  	}
  1201  
  1202  	// Get the token ids for eos, bos, and pad tokens
  1203  	var eosTokenId, bosTokenId, padTokenId *Token
  1204  	if eosToken, ok := vocab[*hfConfig.EosTokenStr]; ok {
  1205  		eosTokenId = new(Token)
  1206  		*eosTokenId = Token(eosToken)
  1207  		hfConfig.EosTokenId = eosTokenId
  1208  	}
  1209  	if bosToken, ok := vocab[*hfConfig.BosTokenStr]; ok {
  1210  		bosTokenId = new(Token)
  1211  		*bosTokenId = Token(bosToken)
  1212  		hfConfig.BosTokenId = bosTokenId
  1213  	}
  1214  	if padToken, ok := vocab[*hfConfig.PadTokenStr]; ok {
  1215  		padTokenId = new(Token)
  1216  		*padTokenId = Token(padToken)
  1217  		hfConfig.PadTokenId = padTokenId
  1218  	}
  1219  
  1220  	return nil
  1221  }
  1222  
  1223  // resolveVocabSize
  1224  // Resolve vocab size from resources.
  1225  // Used to be able to resolve both embedded and local resources.
  1226  // Continuation of ResolveHFFromResources.
  1227  func (rsrcs *Resources) resolveVocabSize(hfConfig *HFConfig) (err error) {
  1228  	// Get the vocab from the resources
  1229  	var vocab TokenMap
  1230  	if vocab, err = rsrcs.GetVocab(hfConfig); err != nil {
  1231  		return err
  1232  	}
  1233  
  1234  	// Get length of vocab
  1235  	vocabLen := new(uint32)
  1236  	*vocabLen = uint32(len(vocab))
  1237  
  1238  	hfConfig.VocabSize = vocabLen
  1239  	return nil
  1240  }
  1241  
// resolveConfigAndTokenizer
// Resolve config and tokenizer config from resources.
// Used to be able to resolve both embedded and local resources.
// Continuation of ResolveHFFromResources.
//
// Precedence: string-form special tokens found in config.json win; only
// when those are absent is tokenizer_config.json consulted, first for the
// same string format, then for llama2-style {"content": ...} objects, then
// for a final string-format fallback. Token ids, vocab size, newline mode,
// add_bos/add_eos flags, added_tokens_decoder, and tokenizer_class are
// read unconditionally from whichever file carries them.
func (rsrcs *Resources) resolveConfigAndTokenizer(
	hfConfig *HFConfig,
) (err error) {
	// Use interfaces to unmarshal the config file and tokenizer config file
	var config *JsonMap
	var tokenizerConfig *JsonMap

	// Get the config and tokenizer config from the resources

	// If exists, unmarshal config.json and tokenizer_config.json, else
	// use GetFile to get the file, then unmarshal it
	// (UnmarshalData returns nil, nil when a file is simply absent.)

	if config, err = rsrcs.UnmarshalData("config.json"); err != nil {
		return err
	}
	if tokenizerConfig, err =
		rsrcs.UnmarshalData("tokenizer_config.json"); err != nil {
		return err
	}

	// Check if bos_token is in string, this is the old format Pythia has.
	// If not, try to unmarshal to the tokenizerSpecials
	// that llama 2 has, else try mistral format
	if config != nil || tokenizerConfig != nil {
		// Tracks whether the string-form bos/eos/pad tokens have already
		// been captured, so later fallbacks don't overwrite them.
		hasReadForEosBos := false

		// Read config.json
		if config != nil {
			configMap := *config
			// Using interfaces, first check if bos_token is in string format
			if bosToken, ok := configMap["bos_token"].(string); ok {
				hfConfig.BosTokenStr = &bosToken
				if eosToken, ok := configMap["eos_token"].(string); ok {
					hfConfig.EosTokenStr = &eosToken
				}
				if padToken, ok := configMap["pad_token"].(string); ok {
					hfConfig.PadTokenStr = &padToken
				}
				hasReadForEosBos = true
			}

			// Read for EOS BOS token ID (JSON numbers decode as float64)
			if eosTokenId, ok := configMap["eos_token_id"].(float64); ok {
				eosTokenIdInt := Token(eosTokenId)
				hfConfig.EosTokenId = &eosTokenIdInt
			}
			if bosTokenId, ok := configMap["bos_token_id"].(float64); ok {
				bosTokenIdInt := Token(bosTokenId)
				hfConfig.BosTokenId = &bosTokenIdInt
			}

			// Read for vocab size
			if vocabSize, ok := configMap["vocab_size"].(float64); ok {
				vocabSizeInt := uint32(vocabSize)
				hfConfig.VocabSize = &vocabSizeInt
			}

			// Read for newLineMode
			if newLineMode, ok := configMap["newlinemode"].(string); ok {
				hfConfig.NewLineMode = &newLineMode
			}
		}

		// Read tokenizer_config.json
		if tokenizerConfig != nil {
			configMap := *tokenizerConfig
			if !hasReadForEosBos {
				// Using interfaces, first check if bos_token is in string format
				if bosToken, ok := configMap["bos_token"].(string); ok {
					hfConfig.BosTokenStr = &bosToken
					if eosToken, ok := configMap["eos_token"].(string); ok {
						hfConfig.EosTokenStr = &eosToken
					}
					if padToken, ok := configMap["pad_token"].(string); ok {
						hfConfig.PadTokenStr = &padToken
					}
					hasReadForEosBos = true

				}
			}
			// If not, assume llama2 format and try to unmarshal
			// (bos/eos are objects with a "content" key)
			if !hasReadForEosBos {
				if bosToken, ok :=
					configMap["bos_token"].(map[string]interface{}); ok {
					if content, ok := bosToken["content"].(string); ok {
						hfConfig.BosTokenStr = &content
					}
				}
				if eosToken, ok :=
					configMap["eos_token"].(map[string]interface{}); ok {
					if content, ok := eosToken["content"].(string); ok {
						hfConfig.EosTokenStr = &content
					}
				}
				if padToken, ok := configMap["pad_token"].(string); ok {
					hfConfig.PadTokenStr = &padToken
				}
			}
			// If that doesn't work, assume mistral format
			// NOTE(review): this branch repeats the same string-form
			// checks as the first tokenizer_config branch (which never
			// set hasReadForEosBos when bos_token wasn't a string), so it
			// is largely redundant — confirm before simplifying.
			if !hasReadForEosBos {
				if bosToken, ok := configMap["bos_token"].(string); ok {
					hfConfig.BosTokenStr = &bosToken
				}
				if eosToken, ok := configMap["eos_token"].(string); ok {
					hfConfig.EosTokenStr = &eosToken
				}
				if padToken, ok := configMap["pad_token"].(string); ok {
					hfConfig.PadTokenStr = &padToken
				}
			}

			// Read for enclose eos bos
			// NOTE(review): the local names are swapped relative to the
			// keys being read (encloseEos holds add_bos_token and vice
			// versa); the assignments are nevertheless correct.
			if encloseEos, ok := configMap["add_bos_token"].(bool); ok {
				hfConfig.AddBosToken = &encloseEos
			}

			if encloseBos, ok := configMap["add_eos_token"].(bool); ok {
				hfConfig.AddEosToken = &encloseBos
			}

			// Read for added_specials_tokens
			// Will later be used to readd into vocab if needed
			if addedTokensDecoder, ok :=
				configMap["added_tokens_decoder"].(map[string]interface{}); ok {
				addedSpecialsTokens := make(TokenMap)
				for k, v := range addedTokensDecoder {
					// Get under content key, key is float64
					// (map keys are the token ids as decimal strings;
					// parse errors leave keyToken at 0)
					keyToken, _ := strconv.ParseFloat(k, 64)
					valStr := v.(map[string]interface{})["content"].(string)
					addedSpecialsTokens[valStr] = types.Token(keyToken)
				}
				hfConfig.AddedSpecialsTokens = &addedSpecialsTokens
			}

			// Read for tokenizer Class
			if tClass, ok := configMap["tokenizer_class"].(string); ok {
				hfConfig.TokenizerClass = &tClass
			}
		}
	}
	return nil
}
  1388  
  1389  // resolveSpecials
  1390  // Resolve special tokens and special config from resources.
  1391  // Used to be able to resolve both embedded and local resources.
  1392  // Continuation of ResolveHFFromResources.
  1393  func (rsrcs *Resources) resolveSpecials(hfConfig *HFConfig) error {
  1394  	// Get specials config from resources
  1395  	// We can only generate specials.json if we have special_tokens_map
  1396  	specialsJson, ok := (*rsrcs)["special_tokens_map.json"]
  1397  	if ok {
  1398  		specialTokens := make(JsonMap)
  1399  		if specialErr := json.Unmarshal(
  1400  			*specialsJson.Data,
  1401  			&specialTokens,
  1402  		); specialErr != nil {
  1403  			return specialErr
  1404  		}
  1405  
  1406  		// Try to get pad token from specials if not already set
  1407  		if hfConfig.PadTokenStr == nil || *hfConfig.PadTokenStr == "" {
  1408  			if padToken, pOk := specialTokens["pad_token"].(string); pOk {
  1409  				hfConfig.PadTokenStr = &padToken
  1410  			}
  1411  		}
  1412  	}
  1413  
  1414  	// Get from specials.json
  1415  	specialsTxt, ok := (*rsrcs)["specials.txt"]
  1416  	if ok {
  1417  		// Treat specials.txt as an array of strings and try to match
  1418  		specials := strings.Split(string(*specialsTxt.Data), "\n")
  1419  		if hfConfig.PadTokenStr == nil {
  1420  			for _, special := range specials {
  1421  				if strings.Contains(strings.ToLower(special), "pad") {
  1422  					hfConfig.PadTokenStr = &special
  1423  					break
  1424  				}
  1425  			}
  1426  		}
  1427  	}
  1428  	return nil
  1429  }
  1430  
  1431  func (rsrcs *Resources) LoadEmbeddedResource(
  1432  	vocabId string,
  1433  	resourceId string,
  1434  	path string,
  1435  ) {
  1436  	if r := GetEmbeddedResource(vocabId + "/" + path); r != nil {
  1437  		(*rsrcs)[resourceId] = *r
  1438  	}
  1439  }
  1440  
  1441  // ResolveResourcesList
  1442  // Resolves a list of resources, and checks if they exist in the given
  1443  // directory. If they don't exist, they are downloaded.
  1444  func ResolveResourcesList(vocabId string, token string) (*Resources, error) {
  1445  	// Resolve the vocab id - Embedded resources
  1446  	if _, vocabErr := EmbeddedDirExists(vocabId); vocabErr == nil {
  1447  		resources := make(Resources)
  1448  
  1449  		possibleEmbeddedResources := []struct {
  1450  			resourceId string
  1451  			path       string
  1452  		}{
  1453  			{"vocab.json", "encoder.json"},
  1454  			{"config.json", "config.json"},
  1455  			{"merges.txt", "vocab.bpe"},
  1456  			{"merges.json", "merges.json"},
  1457  			{"specials.txt", "specials.txt"},
  1458  			{"special_tokens_map.json", "special_tokens_map.json"},
  1459  			{"special_config.json", "special_config.json"},
  1460  			{"tokenizer.json", "tokenizer.json"},
  1461  			{"tokenizer_config.json", "tokenizer_config.json"},
  1462  		}
  1463  
  1464  		for _, resource := range possibleEmbeddedResources {
  1465  			resources.LoadEmbeddedResource(
  1466  				vocabId, resource.resourceId, resource.path,
  1467  			)
  1468  		}
  1469  		return &resources, nil
  1470  	}
  1471  	// Non-embedded resources
  1472  	resources, err := LoadExternalResources(vocabId, token)
  1473  	if err != nil {
  1474  		return nil, err
  1475  	}
  1476  	return resources, nil
  1477  
  1478  }
  1479  
  1480  // ResolveVocabId
  1481  // Resolves a vocabulary id to a set of resources, from embedded,
  1482  // local filesystem, or remote, and applies processing to the resources.
  1483  func ResolveVocabId(vocabId string, token string) (
  1484  	*HFConfig,
  1485  	*Resources,
  1486  	error,
  1487  ) {
  1488  	rsrcs, err := ResolveResourcesList(vocabId, token)
  1489  	if err != nil {
  1490  		return nil, nil, err
  1491  	}
  1492  
  1493  	hf := NewHFConfig()
  1494  	hf.ModelId = &vocabId
  1495  	if err = rsrcs.ResolveHF(hf); err != nil {
  1496  		return nil, nil, err
  1497  	}
  1498  	return hf, rsrcs, nil
  1499  }
  1500  
  1501  func ExtractModelFromTokenizer(dir *string) (JsonMap, error) {
  1502  	tokenizerPath := path.Join(*dir, "tokenizer.json")
  1503  	// Open the file
  1504  	tokenizerFile, err := os.Open(tokenizerPath)
  1505  	if err != nil {
  1506  		log.Println("Error opening tokenizer:", err)
  1507  		// return an empty map and the error
  1508  		return nil, err
  1509  	}
  1510  	defer func(tokenizerFile *os.File) {
  1511  		_ = tokenizerFile.Close()
  1512  	}(tokenizerFile)
  1513  
  1514  	// Decode the JSON data into a map
  1515  	var data JsonMap
  1516  	err = json.NewDecoder(tokenizerFile).Decode(&data)
  1517  	if err != nil {
  1518  		log.Println("Error decoding JSON from tokenizer:", err)
  1519  		return nil, err
  1520  	}
  1521  
  1522  	// Access the data at the specified path
  1523  	model, ok := (data["model"]).(map[string]interface{})
  1524  	model = ToJsonMap(model)
  1525  	if ok {
  1526  		return model, nil
  1527  	} else {
  1528  		log.Println("Error: Could not convert model in tokenizer to map")
  1529  		return nil, errors.New("could not convert model in tokenizer to map")
  1530  	}
  1531  }
  1532  
  1533  func ExtractVocabFromTokenizer(
  1534  	model JsonMap,
  1535  	dir *string,
  1536  	resources *Resources,
  1537  ) error {
  1538  	vocab, ok := model["vocab"].(map[string]interface{})
  1539  	vocab = ToJsonMap(vocab)
  1540  	if !ok {
  1541  		log.Println("Error: Could not convert vocab in model to map")
  1542  		return errors.New("could not convert vocab in model to map")
  1543  	}
  1544  
  1545  	vocabPath := path.Join(*dir, "vocab.json")
  1546  
  1547  	// Create the file
  1548  	vocabFile, err := os.Create(vocabPath)
  1549  	if err != nil {
  1550  		log.Println("Error creating vocab.json:", err)
  1551  		return err
  1552  	}
  1553  	defer func(vocabFile *os.File) {
  1554  		_ = vocabFile.Close()
  1555  	}(vocabFile)
  1556  
  1557  	// Marshal the vocab map into a JSON string with indentation
  1558  	vocabJsonString, err := json.MarshalIndent(vocab, "", " ")
  1559  	if err != nil {
  1560  		fmt.Println("Error marshaling JSON:", err)
  1561  		return err
  1562  	}
  1563  
  1564  	// Write the JSON string to the file
  1565  	_, err = vocabFile.Write(vocabJsonString)
  1566  	if err != nil {
  1567  		log.Println("Error writing to vocab.json:", err)
  1568  		return err
  1569  	}
  1570  
  1571  	log.Println("Vocab written to vocab.json from tokenizer.json")
  1572  
  1573  	if mmapErr := resources.AddEntry(
  1574  		"vocab.json", vocabFile,
  1575  	); mmapErr != nil {
  1576  		return fmt.Errorf("error trying to mmap file: %s", mmapErr)
  1577  	}
  1578  
  1579  	return nil
  1580  }
  1581  
  1582  func ExtractMergesFromTokenizer(
  1583  	model JsonMap,
  1584  	dir *string,
  1585  	resources *Resources,
  1586  ) error {
  1587  	merges, ok := model["merges"].([]interface{})
  1588  	if !ok {
  1589  		log.Println("Error: Could not convert merges in model to map")
  1590  		return errors.New("could not convert merges in model to map")
  1591  	}
  1592  
  1593  	// Convert the slice of interfaces to a slice of strings
  1594  	var mergesStr []string
  1595  	for _, v := range merges {
  1596  		mergesStr = append(mergesStr, v.(string))
  1597  	}
  1598  
  1599  	mergesPath := path.Join(*dir, "merges.txt")
  1600  
  1601  	// Create the file
  1602  	mergesFile, err := os.Create(mergesPath)
  1603  	if err != nil {
  1604  		log.Println("Error creating file:", err)
  1605  		return err
  1606  	}
  1607  	defer func(mergesFile *os.File) {
  1608  		_ = mergesFile.Close()
  1609  	}(mergesFile)
  1610  
  1611  	// Write each merge string to a new line in the file
  1612  	for _, v := range merges {
  1613  		_, err = mergesFile.WriteString(v.(string) + "\n")
  1614  		if err != nil {
  1615  			log.Println("Error writing to file:", err)
  1616  			return err
  1617  		}
  1618  	}
  1619  
  1620  	log.Println("Merges written to merges.txt from tokenizer.json")
  1621  
  1622  	if mmapErr := resources.AddEntry(
  1623  		"merges.txt", mergesFile,
  1624  	); mmapErr != nil {
  1625  		return fmt.Errorf("error trying to mmap file: %s", mmapErr)
  1626  	}
  1627  
  1628  	return nil
  1629  }
  1630  
  1631  func FindNumberOfShardsFromConfig(configPath string) (int, error) {
  1632  	// Open the file
  1633  	configFile, err := os.Open(configPath)
  1634  	if err != nil {
  1635  		log.Println("Error opening config:", err)
  1636  		return -1, err
  1637  	}
  1638  	defer func(configFile *os.File) {
  1639  		_ = configFile.Close()
  1640  	}(configFile)
  1641  
  1642  	// Decode the JSON data into a map
  1643  	var data JsonMap
  1644  	err = json.NewDecoder(configFile).Decode(&data)
  1645  	if err != nil {
  1646  		log.Println("Error decoding JSON from config:", err)
  1647  		return -1, err
  1648  	}
  1649  
  1650  	// Access the data at the specified path
  1651  	weightMap, ok := data["weight_map"].(map[string]interface{})
  1652  	weightMap = ToJsonMap(weightMap)
  1653  	if !ok {
  1654  		fmt.Println("Error: Could not convert data to weight_map")
  1655  		return -1, errors.New("could not convert data to weight_map")
  1656  	}
  1657  	// Try embed out, if not, try lm_head.weight
  1658  	nameOfLast, ok := weightMap["embed_out.weight"]
  1659  	if !ok {
  1660  		nameOfLast, ok = weightMap["lm_head.weight"]
  1661  		if !ok {
  1662  			fmt.Println("Error: Could not convert weight_map to embed_out or lm_head")
  1663  			return -1, errors.New("could not convert weight_map to embed_out or lm_head")
  1664  		}
  1665  	}
  1666  
  1667  	r, _ := regexp.Compile(`\D*\d+\D+(\d+)`)
  1668  	// convert to interface -> string -> int
  1669  	nameOfLastInt, err := strconv.Atoi(
  1670  		r.FindStringSubmatch(fmt.Sprintf("%v", nameOfLast))[1],
  1671  	)
  1672  
  1673  	if err != nil {
  1674  		fmt.Println("Error: Could not convert embed_out to int")
  1675  		return -1, errors.New("could not convert embed_out to int")
  1676  	}
  1677  
  1678  	return nameOfLastInt, nil
  1679  }
  1680  
  1681  func FindProcessingStepsFromTokenizer(model ResourceEntry) (
  1682  	[]Processor,
  1683  	error,
  1684  ) {
  1685  	// convert the data to a map
  1686  	var data JsonMap
  1687  	err := json.Unmarshal(*model.Data, &data)
  1688  	if err != nil {
  1689  		return nil, err
  1690  	}
  1691  
  1692  	// create array of processors
  1693  	var processors []Processor
  1694  	// check if normalizer is present
  1695  	normalizer, ok := data["normalizer"].(map[string]interface{})
  1696  	normalizer = ToJsonMap(normalizer)
  1697  	if normalizer != nil && ok {
  1698  		// add normalizer to processors
  1699  		processor := Processor{
  1700  			ProcessorType: "normalizer",
  1701  			ProcessorArgs: normalizer,
  1702  		}
  1703  		processors = append(processors, processor)
  1704  	}
  1705  	// check if pre_tokenizer is present
  1706  	preTokenizer, ok := data["pre_tokenizer"].(map[string]interface{})
  1707  	preTokenizer = ToJsonMap(preTokenizer)
  1708  	if preTokenizer != nil && ok {
  1709  		// add pre_tokenizer to processors
  1710  		processor := Processor{
  1711  			ProcessorType: "pre_tokenizer",
  1712  			ProcessorArgs: preTokenizer,
  1713  		}
  1714  		processors = append(processors, processor)
  1715  	}
  1716  	// check if post_processor is present
  1717  	post_processor, ok := data["post_processor"].(map[string]interface{})
  1718  	post_processor = ToJsonMap(post_processor)
  1719  	if post_processor != nil && ok {
  1720  		// add post_processor to processors
  1721  		processor := Processor{
  1722  			ProcessorType: "post_processor",
  1723  			ProcessorArgs: post_processor,
  1724  		}
  1725  		processors = append(processors, processor)
  1726  	}
  1727  	// check if decoder is present
  1728  	decoder, ok := data["decoder"].(map[string]interface{})
  1729  	decoder = ToJsonMap(decoder)
  1730  	if decoder != nil && ok {
  1731  		// add decoder to processors
  1732  		processor := Processor{
  1733  			ProcessorType: "decoder",
  1734  			ProcessorArgs: decoder,
  1735  		}
  1736  		processors = append(processors, processor)
  1737  	}
  1738  
  1739  	return processors, nil
  1740  }
  1741  
  1742  func ToJsonMap(sim map[string]interface{}) JsonMap {
  1743  	jm := make(JsonMap)
  1744  	for k, v := range sim {
  1745  		jm[k] = v
  1746  	}
  1747  	return jm
  1748  }
  1749  
  1750  func (jm JsonMap) ToMapInterface() map[string]interface{} {
  1751  	m := make(map[string]interface{})
  1752  	for k, v := range jm {
  1753  		m[k] = v
  1754  	}
  1755  	return m
  1756  }