package resources

import (
	"bufio"
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"io/ioutil"
	"log"
	"math"
	"net/url"
	"os"
	"path"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/wbrown/gpt_bpe/types"

	"github.com/dustin/go-humanize"
)

// Package-local names for the shared token types, so the rest of this
// package can refer to them without the `types.` qualifier.
type Token types.Token
type Tokens types.Tokens
type TokenMap types.TokenMap

// JsonMap is a generic decoded-JSON object.
type JsonMap map[string]interface{}

// ResourceFlag marks how the resolver should treat a resource (required,
// optional, derived, model, ...); ResourceType selects a model layout.
type ResourceFlag uint8
type ResourceType uint8

// WriteCounter counts the number of bytes written to it, and every 10
// seconds, it prints a message reporting the number of bytes written so far.
type WriteCounter struct {
	Total    uint64    // bytes written so far
	Last     time.Time // time of the last progress report
	Reported bool      // set once at least one report has been printed
	Path     string    // label used in the progress message
	Size     uint64    // expected total size, for the "x / y" report
}

// Write writes p to the WriteCounter and updates the total number of bytes
// written. It never returns an error, and throttles progress logging so a
// report is printed at most once every ten seconds.
func (wc *WriteCounter) Write(p []byte) (int, error) {
	n := len(p)
	wc.Total += uint64(n)
	if time.Since(wc.Last).Seconds() > 10 {
		wc.Reported = true
		wc.Last = time.Now()
		log.Printf(
			"Downloading %s... %s / %s completed.",
			wc.Path, humanize.Bytes(wc.Total), humanize.Bytes(wc.Size),
		)
	}
	return n, nil
}

// Enumeration of resource flags that indicate what the resolver should do
// with the resource.
const (
	RESOURCE_REQUIRED ResourceFlag = 1 << iota
	RESOURCE_OPTIONAL
	RESOURCE_DERIVED
	RESOURCE_MODEL
	RESOURCE_ONEOF
)

// Enumeration of different types of models
const (
	RESOURCETYPE_TRANSFORMERS ResourceType = 1 << iota
	RESOURCETYPE_DIFFUSERS
)

// TransformerResources lists the files making up a HuggingFace
// transformers-style model directory, with flags telling the resolver how
// to treat each one.
var TransformerResources = ResourceEntryDefs{
	"config.json":                  RESOURCE_REQUIRED,
	"vocab.json":                   RESOURCE_OPTIONAL,
	"merges.txt":                   RESOURCE_OPTIONAL,
	"special_tokens_map.json":      RESOURCE_OPTIONAL,
	"wordtokens.json":              RESOURCE_OPTIONAL,
	"specials.txt":                 RESOURCE_OPTIONAL | RESOURCE_DERIVED,
	"tokenizer_config.json":        RESOURCE_OPTIONAL,
	"pytorch_model.bin.index.json": RESOURCE_OPTIONAL,
	"tokenizer.json":               RESOURCE_OPTIONAL,
	"tokenizer.model":              RESOURCE_OPTIONAL,
	"pytorch_model.bin":            RESOURCE_MODEL,
}

// DiffuserResources lists the files making up a diffusers-style model
// directory, with the same flag semantics as TransformerResources.
var DiffuserResources = ResourceEntryDefs{
	"feature_extractor/preprocessor_config.json": RESOURCE_OPTIONAL,
	"safety_checker/config.json":                 RESOURCE_OPTIONAL,
	"safety_checker/pytorch_model.bin":           RESOURCE_OPTIONAL,
	"scheduler/scheduler_config.json":            RESOURCE_REQUIRED,
	"text_encoder/config.json":                   RESOURCE_REQUIRED,
	"text_encoder/pytorch_model.bin":             RESOURCE_MODEL,
	"tokenizer/merges.txt":                       RESOURCE_REQUIRED,
	"tokenizer/special_tokens_map.json":          RESOURCE_REQUIRED,
	"tokenizer/tokenizer_config.json":            RESOURCE_REQUIRED,
	"tokenizer/vocab.json":                       RESOURCE_REQUIRED,
	"unet/config.json":                           RESOURCE_REQUIRED,
	"unet/diffusion_pytorch_model.bin":           RESOURCE_MODEL,
	"vae/config.json":                            RESOURCE_REQUIRED,
	"vae/diffusion_pytorch_model.bin":            RESOURCE_MODEL,
	"model_index.json":                           RESOURCE_REQUIRED,
}

// ResourceEntryDefs maps a resource filename to the flags controlling how
// the resolver treats it.
type ResourceEntryDefs map[string]ResourceFlag

// ResourceEntry is a resolved resource: an open file handle (stored as
// interface{}; in practice an os.File or fs.File) plus the mmap'd/loaded
// contents.
type ResourceEntry struct {
	file interface{}
	Data *[]byte
}

// Resources maps a resource filename to its resolved entry.
type Resources map[string]ResourceEntry

// Cleanup closes all open file handles in the Resources map.
119 func (rsrcs *Resources) Cleanup() { 120 for _, rsrc := range *rsrcs { 121 file := rsrc.file 122 switch t := file.(type) { 123 case os.File: 124 _ = t.Close() 125 case fs.File: 126 _ = t.Close() 127 } 128 } 129 } 130 131 // GetFile 132 // Returns the file handle for a given resource name. 133 func (rsrcs *Resources) GetFile(name string) (interface{}, error) { 134 if rsrcEntry, ok := (*rsrcs)[name]; ok { 135 return rsrcEntry.file, nil 136 } else { 137 return nil, fmt.Errorf("file %s not found", name) 138 } 139 } 140 141 // GetResourceEntries 142 // Returns a default map of resource entries that express what files are 143 // required, optional, derived, and/or model resources. Requires a 144 // ResourceType. 145 func GetResourceEntries(typ ResourceType) ResourceEntryDefs { 146 switch typ { 147 case RESOURCETYPE_TRANSFORMERS: 148 return TransformerResources 149 case RESOURCETYPE_DIFFUSERS: 150 151 return DiffuserResources 152 default: 153 return ResourceEntryDefs{} 154 } 155 } 156 157 // getResourceEntryAliases 158 // Returns a map of defined resources to known alternative filenames 159 // for each resource of a given ResourceType. 160 func getResourceEntryAliases(typ ResourceType) map[string][]string { 161 switch typ { 162 case RESOURCETYPE_TRANSFORMERS: 163 return map[string][]string{ 164 "vocab.json": {"encoder.json"}, 165 } 166 default: 167 return map[string][]string{} 168 } 169 } 170 171 // FetchHuggingFace 172 // Wrapper around FetchHTTP that fetches a resource from huggingface.co. 173 func FetchHuggingFace(id string, rsrc string) (io.ReadCloser, error) { 174 token := os.Getenv("HF_API_TOKEN") 175 return FetchHTTP( 176 "https://huggingface.co/"+id+"/resolve/main", rsrc, token, 177 ) 178 } 179 180 // SizeHuggingFace 181 // Wrapper around SizeHTTP that gets the size of a resource from huggingface.co. 
182 func SizeHuggingFace(id string, rsrc string) (uint, error) { 183 token := os.Getenv("HF_API_TOKEN") 184 return SizeHTTP("https://huggingface.co/"+id+"/resolve/main", rsrc, token) 185 } 186 187 // isValidUrl 188 // Checks if a given string is a valid URL, returns true if it is, false 189 // otherwise. 190 func isValidUrl(toTest string) bool { 191 _, err := url.ParseRequestURI(toTest) 192 if err != nil { 193 return false 194 } 195 196 u, err := url.Parse(toTest) 197 if err != nil || u.Scheme == "" || u.Host == "" { 198 return false 199 } 200 201 return true 202 } 203 204 // Fetch 205 // Given a base URI and a resource name, determines if the resource is local, 206 // remote, or from huggingface.co. If the resource is local, it returns a 207 // file handle to the resource. If the resource is remote, or from 208 // huggingface.co, it fetches the resource and returns a ReadCloser to the 209 // fetched or cached resource. 210 func Fetch(uri string, rsrc string, token string) (io.ReadCloser, error) { 211 if isValidUrl(uri) { 212 return FetchHTTP(uri, rsrc, token) 213 } else if _, err := os.Stat(path.Join(uri, rsrc)); !os.IsNotExist(err) { 214 if handle, fileErr := os.Open(path.Join(uri, rsrc)); fileErr != nil { 215 return nil, fmt.Errorf( 216 "error opening %s/%s: %v", 217 uri, rsrc, fileErr, 218 ) 219 } else { 220 return handle, fileErr 221 } 222 } else { 223 return FetchHuggingFace(uri, rsrc) 224 } 225 } 226 227 // Size 228 // Given a base URI and a resource name, determine the size of the resource. 229 func Size(uri string, rsrc string, token string) (uint, error) { 230 if isValidUrl(uri) { 231 return SizeHTTP(uri, rsrc, token) 232 } else if fsz, err := os.Stat(path.Join(uri, rsrc)); !os.IsNotExist(err) { 233 return uint(fsz.Size()), nil 234 } else { 235 return SizeHuggingFace(uri, rsrc) 236 } 237 } 238 239 // AddEntry 240 // Add a resource to the Resources map, opening it as a mmap.Map. 
241 func (rsrcs *Resources) AddEntry(name string, file *os.File) error { 242 fileMmap, mmapErr := readMmap(file) 243 if mmapErr != nil { 244 return fmt.Errorf("error trying to mmap file: %s", mmapErr) 245 } else { 246 (*rsrcs)[name] = ResourceEntry{file, fileMmap} 247 } 248 return nil 249 } 250 251 // Specials 252 // Map of special tokens such as <|pad|>, <|endoftext|>, etc. 253 type Specials map[string]string 254 255 // ResolveSpecialTokens 256 // If specials.json does not exist in dir, create it from the 257 // special_tokens_map.json file. 258 func (rsrcs *Resources) ResolveSpecialTokens(dir string) ( 259 realizedSpecials Specials, err error, 260 ) { 261 realizedSpecials = make(Specials) 262 // If we already have specials.json, we don't need to generate it. 263 if _, ok := (*rsrcs)["specials.json"]; ok { 264 if specErr := json.Unmarshal( 265 *(*rsrcs)["specials.json"].Data, 266 &realizedSpecials, 267 ); specErr != nil { 268 return nil, fmt.Errorf( 269 "cannot unmarshal specials.json: %s", specErr, 270 ) 271 } 272 return realizedSpecials, nil 273 } 274 275 // We can only generate specials.json if we have special_tokens_map 276 specialsJson, ok := (*rsrcs)["special_tokens_map.json"] 277 if !ok { 278 return nil, nil 279 } 280 281 specialTokens := make(JsonMap) 282 if specialErr := json.Unmarshal( 283 *specialsJson.Data, 284 &specialTokens, 285 ); specialErr != nil { 286 return nil, specialErr 287 } 288 289 for k, v := range specialTokens { 290 var specialToken string 291 switch t := v.(type) { 292 case string: 293 specialToken = t 294 case JsonMap: 295 mv := t["content"] 296 switch mvt := mv.(type) { 297 case string: 298 specialToken = mvt 299 default: 300 log.Fatalf( 301 "unknown format for `special_tokens_map."+ 302 "json`: %v", t, 303 ) 304 } 305 default: 306 log.Fatalf( 307 "unknown format for `special_tokens_map."+ 308 "json`: %v", t, 309 ) 310 } 311 realizedSpecials[k] = specialToken 312 } 313 if len(realizedSpecials) > 0 { 314 specialsFile, specialFileErr := 
os.OpenFile( 315 path.Join(dir, "specials.json"), 316 os.O_TRUNC|os.O_RDWR|os.O_CREATE, 0755, 317 ) 318 if specialFileErr != nil { 319 return nil, fmt.Errorf( 320 "cannot generate specials.json: %s", 321 specialFileErr, 322 ) 323 } 324 specialsJsonBytes, specialsErr := json.Marshal(realizedSpecials) 325 if specialsErr != nil { 326 _ = specialsFile.Close() 327 return nil, fmt.Errorf( 328 "cannot marshal specials.json: %s", specialsErr, 329 ) 330 } 331 if _, writeErr := specialsFile.Write( 332 specialsJsonBytes, 333 ); writeErr != nil { 334 _ = specialsFile.Close() 335 return nil, fmt.Errorf( 336 "cannot write specials.json: %s", specialsErr, 337 ) 338 } 339 if _, seekErr := specialsFile.Seek(0, 0); seekErr != nil { 340 return nil, 341 fmt.Errorf("cannot seek specials.json: %s", seekErr) 342 } 343 if mmapErr := rsrcs.AddEntry( 344 "specials.json", 345 specialsFile, 346 ); mmapErr != nil { 347 return nil, mmapErr 348 } 349 } 350 return realizedSpecials, nil 351 } 352 353 // ResolveResources resolves all resources at a given uri, and checks if they 354 // exist in the given directory. If they don't exist, they are downloaded. 355 func ResolveResources( 356 uri string, 357 dir *string, 358 rsrcLvl ResourceFlag, 359 rsrcType ResourceType, 360 token string, 361 ) ( 362 *Resources, 363 error, 364 ) { 365 foundResources := make(Resources) 366 resources := GetResourceEntries(rsrcType) 367 aliases := getResourceEntryAliases(rsrcType) 368 369 for file, flag := range resources { 370 var rsrcFile os.File 371 372 // Resolve the resource 373 if flag <= rsrcLvl { 374 log.Printf("Resolving %s/%s... ", uri, file) 375 targetPath := path.Join(*dir, file) 376 rsrcSize, rsrcSizeErr := Size(uri, file, token) 377 alias := file 378 if rsrcSizeErr != nil { 379 // If the resource isn't found under its normal filename, 380 // check under any known aliases. 
381 if aliasesList, ok := aliases[file]; ok { 382 for _, alias = range aliasesList { 383 rsrcSize, rsrcSizeErr = Size(uri, alias, token) 384 if rsrcSizeErr == nil { 385 log.Printf( 386 "Resolving %s/%s as alias %s/%s...", 387 uri, file, uri, alias, 388 ) 389 break 390 } 391 } 392 } 393 } 394 if rsrcSizeErr != nil { 395 // If the resource is required, we cannot continue. 396 if flag&RESOURCE_REQUIRED != 0 { 397 log.Printf( 398 "%s/%s not found, required!", 399 uri, file, 400 ) 401 return &foundResources, fmt.Errorf( 402 "cannot retrieve required `%s from %s`: %s", 403 uri, file, rsrcSizeErr, 404 ) 405 } else { 406 // Otherwise, we can skip it. 407 continue 408 } 409 // If the resource exists, and is the correct size, skip it. 410 } else if targetStat, targetStatErr := os.Stat(targetPath); !os.IsNotExist( 411 targetStatErr, 412 ) && uint(targetStat.Size()) == rsrcSize { 413 log.Printf( 414 "Skipping %s/%s... already exists, "+ 415 "and of the correct size.", uri, file, 416 ) 417 openFile, skipFileErr := os.OpenFile( 418 path.Join(*dir, file), 419 os.O_RDONLY, 0755, 420 ) 421 if skipFileErr != nil { 422 return &foundResources, fmt.Errorf( 423 "error opening '%s' for read: %s", 424 file, skipFileErr, 425 ) 426 427 } else { 428 // If the resource exists, but is the wrong size, we need 429 // to download it. 
430 rsrcFile = *openFile 431 } 432 } else if rsrcReader, rsrcErr := Fetch( 433 uri, alias, token, 434 ); rsrcErr != nil { 435 return &foundResources, fmt.Errorf( 436 "cannot retrieve `%s from %s`: %s", 437 uri, alias, rsrcErr, 438 ) 439 } else { 440 if dirErr := os.MkdirAll( 441 path.Dir(path.Join(*dir, file)), 0755, 442 ); dirErr != nil { 443 return &foundResources, fmt.Errorf( 444 "cannot create directory for '%s': %s", 445 file, dirErr, 446 ) 447 } 448 openFile, rsrcFileErr := os.OpenFile( 449 path.Join(*dir, file), 450 os.O_TRUNC|os.O_RDWR|os.O_CREATE, 0755, 451 ) 452 if rsrcFileErr != nil { 453 return &foundResources, fmt.Errorf( 454 "error opening '%s' for write: %s", 455 file, rsrcFileErr, 456 ) 457 } 458 rsrcFile = *openFile 459 460 counter := &WriteCounter{ 461 Last: time.Now(), 462 Path: fmt.Sprintf("%s/%s", uri, file), 463 Size: uint64(rsrcSize), 464 } 465 bytesDownloaded, ioErr := io.Copy( 466 &rsrcFile, 467 io.TeeReader(rsrcReader, counter), 468 ) 469 _ = rsrcReader.Close() 470 if ioErr != nil { 471 return &foundResources, fmt.Errorf( 472 "error downloading '%s': %s", 473 alias, ioErr, 474 ) 475 } else { 476 log.Printf( 477 "Downloaded %s/%s... 
"+ 478 "%s completed.", uri, alias, 479 humanize.Bytes(uint64(bytesDownloaded)), 480 ) 481 } 482 } 483 484 if mmapErr := foundResources.AddEntry( 485 file, 486 &rsrcFile, 487 ); mmapErr != nil { 488 return &foundResources, fmt.Errorf( 489 "error trying to mmap file: %s", 490 mmapErr, 491 ) 492 } 493 } 494 } 495 496 // check if tokenizer.model exists, if so, expand to files 497 flagTokenizerModelExist := CheckFileExist( 498 path.Join( 499 *dir, "tokenizer.model", 500 ), 501 ) 502 if flagTokenizerModelExist { 503 // check size of tokenizer.model 504 targetStat, targetStatErr := os.Stat( 505 path.Join( 506 *dir, "tokenizer.model", 507 ), 508 ) 509 if targetStatErr != nil { 510 return &foundResources, fmt.Errorf( 511 "cannot stat tokenizer.model: %s", 512 targetStatErr, 513 ) 514 } 515 if targetStat.Size() == 0 { 516 flagTokenizerModelExist = false 517 } 518 } 519 520 if flagTokenizerModelExist { 521 log.Printf( 522 "Directory %s contains tokenizer.model, extracting to files", 523 path.Join(*dir, "tokenizer.model"), 524 ) 525 ConvertSentencepieceFiles( 526 path.Join(*dir, "tokenizer.model"), 527 false, 528 ) 529 530 // Add the new files to the resources 531 files, _ := os.ReadDir(*dir) 532 for _, f := range files { 533 // If not already in the resources, add it 534 if _, ok := foundResources[f.Name()]; !ok { 535 openFile, rsrcFileErr := os.OpenFile( 536 path.Join(*dir, f.Name()), 537 os.O_RDONLY, 0755, 538 ) 539 if rsrcFileErr != nil { 540 return &foundResources, fmt.Errorf( 541 "error opening '%s' for read: %s", 542 f.Name(), rsrcFileErr, 543 ) 544 } 545 rsrcFile := *openFile 546 if mmapErr := foundResources.AddEntry( 547 f.Name(), 548 &rsrcFile, 549 ); mmapErr != nil { 550 return &foundResources, fmt.Errorf( 551 "error trying to mmap file: %s", 552 mmapErr, 553 ) 554 } 555 log.Printf( 556 "Added %s to resources via sentencepiece conversion\n", 557 f.Name(), 558 ) 559 } 560 } 561 562 } else { 563 // check if tokenizer exists by checking if tokenizer.json exists 564 // 
and has data in it 565 flagTokenizerExist := CheckFileExist( 566 path.Join( 567 *dir, "tokenizer.json", 568 ), 569 ) 570 if flagTokenizerExist { 571 // check size of tokenizer.json 572 targetStat, targetStatErr := os.Stat( 573 path.Join( 574 *dir, "tokenizer.json", 575 ), 576 ) 577 if targetStatErr != nil { 578 return &foundResources, fmt.Errorf( 579 "cannot stat tokenizer.json: %s", 580 targetStatErr, 581 ) 582 } 583 if targetStat.Size() == 0 { 584 flagTokenizerExist = false 585 } 586 } 587 588 // if tokenizer exists, but vocab and merges do not exist, extract 589 // from tokenizer, else if vocab and merges exist, do nothing; if 590 // both do not exist, fail 591 flagVocabExist := CheckFileExist(path.Join(*dir, "vocab.json")) 592 flagMergesExists := CheckFileExist(path.Join(*dir, "merges.txt")) 593 594 if flagTokenizerExist { 595 // if vocab does not exist, extract it from tokenizer 596 if !flagVocabExist { 597 model, err := ExtractModelFromTokenizer(dir) 598 if err != nil { 599 return &foundResources, fmt.Errorf( 600 "could not extract model from tokenizer %s", 601 err, 602 ) 603 } 604 605 err = ExtractVocabFromTokenizer(model, dir, &foundResources) 606 if err != nil { 607 return &foundResources, fmt.Errorf( 608 "could not extract vocab from tokenizer %s", 609 err, 610 ) 611 } 612 } 613 614 // if merges does not exist, extract it from tokenizer 615 if !flagMergesExists { 616 model, err := ExtractModelFromTokenizer(dir) 617 if err != nil { 618 return &foundResources, 619 fmt.Errorf( 620 "could not extract model from tokenizer %s", 621 err, 622 ) 623 } 624 625 err = ExtractMergesFromTokenizer(model, dir, &foundResources) 626 if err != nil { 627 return &foundResources, fmt.Errorf( 628 "could not extract merges from tokenizer %s", 629 err, 630 ) 631 } 632 } 633 } else { 634 // if tokenizer does not exist, check if vocab and merges exist 635 if flagVocabExist && flagMergesExists { 636 // if both exist, do nothing 637 log.Println( 638 "Vocab and merges exist, but 
tokenizer does not. OK", 639 ) 640 } else { 641 // if either does not exist, fail 642 return &foundResources, fmt.Errorf( 643 "tokenizer, vocab, and merges do not exist", 644 ) 645 } 646 } 647 648 } 649 650 // Check if we already got the pytorch model file 651 flagModelExists := CheckFileExist(path.Join(*dir, "pytorch_model.bin")) 652 log.Printf("Pytorch Model File exists: %t\n", flagModelExists) 653 654 // if model does not exist, check if we have the sharded config 655 if !flagModelExists { 656 flagShardConfigExists := CheckFileExist( 657 path.Join( 658 *dir, "pytorch_model.bin.index.json", 659 ), 660 ) 661 log.Printf("Shard config exists: %t", flagShardConfigExists) 662 //if sharded config exists, attempt to download the shards 663 if flagShardConfigExists { 664 numShards, err := FindNumberOfShardsFromConfig( 665 path.Join( 666 *dir, "pytorch_model.bin.index.json", 667 ), 668 ) 669 if err != nil { 670 log.Printf( 671 "Could not find number of shards from config: %s\n", 672 err, 673 ) 674 return &foundResources, errors.New( 675 "could not find number of shards from config", 676 ) 677 } 678 679 // pad the number of shards to 5 digits 680 log.Printf("Found %d shards\n", numShards) 681 paddedTotalShards := fmt.Sprintf("%05d", numShards) 682 683 // loop through shards and download them 684 for i := 1; i <= numShards; i++ { 685 var rsrcFile os.File 686 687 paddedShardString := fmt.Sprintf("%05d", i) 688 // Construct the shard path 689 shardPath := fmt.Sprintf( 690 "pytorch_model-%s-of-%s.bin", paddedShardString, 691 paddedTotalShards, 692 ) 693 log.Printf("Resolving shard %s\n", shardPath) 694 695 targetPath := path.Join(*dir, shardPath) 696 rsrcSize, rsrcSizeErr := Size(uri, shardPath, token) 697 if rsrcSizeErr != nil { 698 fmt.Printf( 699 "could not get size of shard %s: %s\n", 700 shardPath, 701 rsrcSizeErr, 702 ) 703 return &foundResources, 704 errors.New("could not get size of shard") 705 } 706 // Print size of shard 707 log.Printf( 708 "Remote size of shard %s 
is %s\n", shardPath, 709 humanize.Bytes(uint64(rsrcSize)), 710 ) 711 712 // Check if shard exists locally 713 flagShardExists := CheckFileExist(targetPath) 714 if flagShardExists { 715 // Check if shard local size is correct compared to 716 // remote size 717 localShardInfo, err := os.Stat(targetPath) 718 if err != nil { 719 fmt.Printf( 720 "Could not get size of local shard %s: %s\n", 721 shardPath, err, 722 ) 723 return &foundResources, 724 errors.New("could not get size of local shard") 725 } 726 if (rsrcSize > 0) && (rsrcSize == uint(localShardInfo.Size())) { 727 log.Printf( 728 "Skipping shard %s, exists and correct size\n", 729 shardPath, 730 ) 731 continue 732 } 733 } 734 735 // fetch shard 736 var rsrcReader io.ReadCloser 737 rsrcReader, err = Fetch(uri, shardPath, token) 738 if err != nil { 739 return &foundResources, fmt.Errorf( 740 "error trying to fetch file: %s", err, 741 ) 742 } 743 744 // create shard file 745 var rsrcFilePtr *os.File 746 rsrcFilePtr, err = os.Create(targetPath) 747 if err != nil { 748 return &foundResources, fmt.Errorf( 749 "error trying to create file: %s", err, 750 ) 751 } 752 rsrcFile = *rsrcFilePtr 753 754 // copy shard to file 755 counter := &WriteCounter{ 756 Last: time.Now(), 757 Path: fmt.Sprintf("%s/%s", uri, shardPath), 758 Size: uint64(rsrcSize), 759 } 760 bytesDownloaded, ioErr := io.Copy( 761 &rsrcFile, 762 io.TeeReader(rsrcReader, counter), 763 ) 764 //close shard reader 765 err = rsrcReader.Close() 766 767 if err != nil { 768 return &foundResources, fmt.Errorf( 769 "error trying to close reader: %s", err, 770 ) 771 } 772 if ioErr != nil { 773 return &foundResources, fmt.Errorf( 774 "error downloading '%s': %s", 775 shardPath, 776 ioErr, 777 ) 778 } else { 779 log.Printf( 780 "Downloaded %s/%s... 
"+ 781 "%s completed.", uri, shardPath, 782 humanize.Bytes(uint64(bytesDownloaded)), 783 ) 784 } 785 786 //close shard file 787 err = rsrcFile.Close() 788 789 if err != nil { 790 return &foundResources, fmt.Errorf( 791 "error trying to close reader: %s", err, 792 ) 793 } else { 794 log.Printf( 795 "Downloaded %s/%s... "+ 796 "%s completed.", uri, shardPath, 797 humanize.Bytes(uint64(bytesDownloaded)), 798 ) 799 } 800 801 //check if shard local size is correct compared to remote size 802 var localShardInfo os.FileInfo 803 localShardInfo, err = os.Stat(targetPath) 804 if err != nil { 805 fmt.Printf( 806 "Could not get size of local shard %s: %s\n", 807 shardPath, err, 808 ) 809 return &foundResources, 810 errors.New("could not get size of local shard") 811 } 812 if (rsrcSize > 0) && 813 (rsrcSize != uint(localShardInfo.Size())) { 814 return &foundResources, 815 errors.New("shard was not downloaded correctly") 816 } 817 log.Printf( 818 "Shard %s downloaded correctly, size is %s\n", 819 shardPath, 820 humanize.Bytes(uint64(localShardInfo.Size())), 821 ) 822 823 } 824 log.Printf("Downloaded %d shards\n", numShards) 825 826 } 827 828 } 829 return &foundResources, nil 830 } 831 832 // ResolveSplitRegex 833 // Resolves the split regex from the tokenizer.json file, if it exists. 
func (rsrcs *Resources) ResolveSplitRegex() *string {
	var splitRegex *string
	if tokenizerData, ok := (*rsrcs)["tokenizer.json"]; ok {
		var tokenizerMap JsonMap
		if tokenizerData.Data != nil {
			// NOTE(review): log.Fatal aborts the whole process on malformed
			// JSON; returning an error would be friendlier to callers.
			if json.Unmarshal(*tokenizerData.Data, &tokenizerMap) != nil {
				log.Fatal("Error unmarshalling tokenizer.json")
			}
		}
		// Walk /pre_tokenizer/pretokenizers looking for a "Split" entry and
		// pull out its pattern's "Regex" value, if present. The bare type
		// assertions below assume the usual tokenizer.json shape and will
		// panic on unexpected shapes — TODO confirm that is acceptable here.
		if preTok, ok := tokenizerMap["pre_tokenizer"]; ok && preTok != nil {
			preTokMap := preTok.(map[string]interface{})
			if pretokenizers, ok := preTokMap["pretokenizers"]; ok &&
				pretokenizers != nil {
				pretokenizersList := pretokenizers.([]interface{})
				for _, v := range pretokenizersList {
					vIntf := v.(map[string]interface{})
					if vIntf["type"] == "Split" {
						pattern := vIntf["pattern"].(map[string]interface{})
						if pattern["Regex"] != nil {
							splitRegexVal := pattern["Regex"].(string)
							// Fix lookbacks: rewrite the negative lookahead
							// "(?!\S)" into a zero-width form, presumably
							// because Go's regexp lacks lookahead — TODO
							// confirm against the consuming regex engine.
							splitRegexVal = strings.ReplaceAll(
								splitRegexVal,
								"(?!\\S)",
								"(\\S){0}",
							)
							splitRegex = &splitRegexVal
						}
					}
				}
			}
		}
	}
	return splitRegex
}

// CheckFileExist checks if a file exists at a given path.
// Note: it returns true for any stat error other than fs.ErrNotExist
// (e.g. permission errors), so "exists" really means "not known absent".
func CheckFileExist(path string) bool {
	_, err := os.Stat(path)

	if errors.Is(err, os.ErrNotExist) {
		return false
	} else {
		return true
	}
}

// HFConfig contains the tokenizer configuration that gpt_bpe uses.
type HFConfig struct {
	// NOTE(review): this tag makes the JSON key literally "omitempty"; it
	// was probably meant to be `json:"model_id,omitempty"`. Confirm before
	// changing, since fixing it alters the serialized form.
	ModelId             *string   `json:"omitempty"`
	ModelType           *string   `json:"model_type,omitempty"`
	EosTokenId          *Token    `json:"eos_token_id,omitempty"`
	BosTokenId          *Token    `json:"bos_token_id,omitempty"`
	PadTokenId          *Token    `json:"pad_token_id,omitempty"`
	BosTokenStr         *string   `json:"bos_token,omitempty"`
	EosTokenStr         *string   `json:"eos_token,omitempty"`
	PadTokenStr         *string   `json:"pad_token,omitempty"`
	VocabSize           *uint32   `json:"vocab_size,omitempty"`
	NewLineMode         *string   `json:"newlinemode,omitempty"`
	TokenizerClass      *string   `json:"tokenizer_class"`
	AddBosToken         *bool     `json:"add_bos_token,omitempty"`
	AddEosToken         *bool     `json:"add_eos_token,omitempty"`
	AddedSpecialsTokens *TokenMap `json:"added_specials_tokens,omitempty"`
	IgnoreMerges        *bool     `json:"ignore_merges,omitempty"`
}

// SpecialConfig contains the special tokens and special token configuration
// that gpt_bpe uses.
type SpecialConfig struct {
	PuncRunes     []*string          `json:"punc_runes"`
	Normalizer    *map[string]string `json:"normalizer"`
	EncloseEosBos bool               `json:"enclose_eos_bos"`
	PrefixSpace   bool               `json:"prefix_space"`
	LowerCase     bool               `json:"lower_case"`
	EndOfWord     string             `json:"end_of_word"`
	DecodeExtra   *map[string]string `json:"decode_extra"`
	SplitRegex    *string            `json:"split_regex"`
}

// NewHFConfig creates a new HFConfig object with default values.
914 func NewHFConfig() *HFConfig { 915 defaultModelId := "" 916 defaultModelType := "gpt2" 917 defaultEosTokenId := Token(0) 918 defaultBosTokenId := Token(0) 919 defaultPadTokenId := Token(0) 920 defaultBosTokenStr := "<|startoftext|>" 921 defaultEosTokenStr := "<|endoftext|>" 922 defaultPadTokenStr := "" 923 defaultVocabSize := uint32(50257) 924 defaultNewLineMode := "prefix" 925 defaultTokenizerClass := "GPT2BPETokenizer" 926 defaultAddBosToken := false 927 defaultAddEosToken := false 928 defaultAddedSpecialsTokens := make(TokenMap) 929 HFConfig := &HFConfig{ 930 ModelId: &defaultModelId, 931 ModelType: &defaultModelType, 932 EosTokenId: &defaultEosTokenId, 933 BosTokenId: &defaultBosTokenId, 934 PadTokenId: &defaultPadTokenId, 935 BosTokenStr: &defaultBosTokenStr, 936 EosTokenStr: &defaultEosTokenStr, 937 PadTokenStr: &defaultPadTokenStr, 938 VocabSize: &defaultVocabSize, 939 NewLineMode: &defaultNewLineMode, 940 TokenizerClass: &defaultTokenizerClass, 941 AddBosToken: &defaultAddBosToken, 942 AddEosToken: &defaultAddEosToken, 943 AddedSpecialsTokens: &defaultAddedSpecialsTokens, 944 } 945 return HFConfig 946 } 947 948 // Processor stores config to process one step of the pipeline 949 type Processor struct { 950 ProcessorType string 951 ProcessorArgs JsonMap 952 } 953 954 // Process the input with the processor 955 func (p *Processor) Process(input interface{}) (interface{}, error) { 956 switch p.ProcessorType { 957 case "prepend": 958 return nil, errors.New("prepend not implemented") 959 default: 960 return nil, errors.New("unknown processor type") 961 } 962 } 963 964 // LoadExternalResources 965 // Resolves a given vocabulary id, and returns the corresponding HuggingFace 966 // configuration, and the resources for the tokenizer. 
967 func LoadExternalResources( 968 vocabId string, 969 token string, 970 ) (resources *Resources, err error) { 971 dir, dirErr := ioutil.TempDir("", "resources") 972 if dirErr != nil { 973 return nil, dirErr 974 } 975 defer func(path string) { 976 _ = os.RemoveAll(path) 977 }(dir) 978 rslvdResources, rsrcErr := ResolveResources( 979 vocabId, 980 &dir, 981 RESOURCE_DERIVED, 982 RESOURCETYPE_TRANSFORMERS, 983 token, 984 ) 985 if rsrcErr != nil { 986 return nil, rsrcErr 987 } else { 988 resources = rslvdResources 989 } 990 fmt.Printf("Resources: %v\n", resources) 991 return resources, nil 992 993 } 994 995 // ResolveHF 996 // Given a set of resources, resolve the HuggingFace configuration. 997 // Used to be able to resolve both embedded and local resources. 998 func (rsrcs *Resources) ResolveHF(hfConfig *HFConfig) (err error) { 999 // Resolve config and tokenizer config from resources 1000 // config.json and tokenizer_config.json 1001 if err = rsrcs.resolveConfigAndTokenizer(hfConfig); err != nil { 1002 return err 1003 } 1004 1005 // Resolve special tokens and special tokens config from resources 1006 // special_tokens_map.json and specials.txt 1007 if err = rsrcs.resolveSpecials(hfConfig); err != nil { 1008 return err 1009 } 1010 1011 // Resolve Vocab size from vocab.json or encoder.json 1012 if err = rsrcs.resolveVocabSize(hfConfig); err != nil { 1013 return err 1014 } 1015 1016 // Sometimes TokenIDs are not properly resolved, so we need to check 1017 if hfConfig != nil { 1018 if *hfConfig.EosTokenId == 0 || *hfConfig.BosTokenId == 0 || 1019 *hfConfig.PadTokenId == 0 { 1020 if err = rsrcs.resolveTokenIds(hfConfig); err != nil { 1021 return err 1022 } 1023 } 1024 } else { 1025 return errors.New("could not resolve HFConfig") 1026 } 1027 1028 // Llama 3 and other larger models will enclose eos and bos by default 1029 if *hfConfig.VocabSize > math.MaxUint16+1 { 1030 var addEosToken = true 1031 var addBosToken = true 1032 1033 hfConfig.AddEosToken = &addEosToken 1034 
hfConfig.AddBosToken = &addBosToken 1035 } 1036 1037 return nil 1038 } 1039 1040 func GetMergesAsBpeRank(resources *Resources) (map[GPTPair]float64, error) { 1041 bpeRanks := make(map[GPTPair]float64) 1042 // Try to get from merges.txt 1043 if mergesTxt, ok := (*resources)["merges.txt"]; ok { 1044 scanner := bufio.NewScanner(bytes.NewBuffer(*mergesTxt.Data)) 1045 idx := uint32(0) 1046 firstLine := true 1047 for scanner.Scan() { 1048 if firstLine { 1049 firstLine = false 1050 continue 1051 } 1052 leftRight := strings.SplitN(scanner.Text(), " ", 2) 1053 bpeRanks[GPTPair{ 1054 Left: leftRight[0], 1055 Right: leftRight[1], 1056 }] = float64(idx) 1057 idx += 1 1058 } 1059 } else if mergesJson, ok := (*resources)["merges.json"]; ok { 1060 var mergesTable [][]string 1061 err := json.Unmarshal(*mergesJson.Data, &mergesTable) 1062 if err != nil { 1063 return nil, err 1064 } 1065 // Iterate over the merges and add them to the BPE ranks 1066 for rank, merge := range mergesTable { 1067 bpeRanks[GPTPair{merge[0], merge[1]}] = 1068 float64(rank) 1069 } 1070 } else if tokenizerJson, ok := (*resources)["tokenizer.json"]; ok { 1071 // Finally try to get from tokenizer.json, merges entry 1072 var tokenizerJsonMap JsonMap 1073 err := json.Unmarshal(*tokenizerJson.Data, &tokenizerJsonMap) 1074 if err != nil { 1075 return nil, err 1076 } 1077 1078 model, ok := tokenizerJsonMap["model"] 1079 if !ok { 1080 return nil, errors.New("could not get model from tokenizer.json") 1081 } 1082 merges, ok := model.(map[string]interface{})["merges"] 1083 if !ok { 1084 return nil, errors.New("could not get merges from tokenizer.json") 1085 } 1086 // Iterate over the merges and add them to the BPE ranks, in form of string[] 1087 for rank, merge := range merges.([]interface{}) { 1088 mergeStr := merge.(string) 1089 mergeSplit := strings.Split(mergeStr, " ") 1090 bpeRanks[GPTPair{mergeSplit[0], mergeSplit[1]}] = 1091 float64(rank) 1092 } 1093 } else { 1094 return nil, errors.New("could not find merges") 
1095 } 1096 return bpeRanks, nil 1097 } 1098 1099 func (rsrcs *Resources) UnmarshalData( 1100 name string, 1101 ) (data *JsonMap, err error) { 1102 if _, err = (*rsrcs).GetFile(name); err == nil { 1103 rawData := (*rsrcs)[name].Data 1104 if err = json.Unmarshal(*rawData, &data); err != nil { 1105 return nil, 1106 fmt.Errorf("error unmarshalling %s: %s", name, err) 1107 } 1108 } else { 1109 return nil, nil 1110 } 1111 return data, nil 1112 } 1113 1114 func (rsrcs *Resources) UnmarshalUntilData( 1115 names []string, 1116 ) ( 1117 name string, 1118 data *JsonMap, 1119 err error, 1120 ) { 1121 for _, name = range names { 1122 if _, ok := (*rsrcs)[name]; !ok { 1123 continue 1124 } 1125 if data, err = rsrcs.UnmarshalData(name); err == nil && data != nil { 1126 return name, data, nil 1127 } else if err != nil { 1128 return name, nil, fmt.Errorf( 1129 "error unmarshalling %s: %s", name, err, 1130 ) 1131 } 1132 } 1133 return "", nil, nil 1134 } 1135 1136 // GetVocab 1137 // Get the vocab from the resources. 1138 // Adapter function to get the vocab from either vocab.json or encoder.json. 
1139 func (rsrcs *Resources) GetVocab( 1140 hfConfig *HFConfig, 1141 ) (TokenMap, error) { 1142 // Vocab is stored in either the vocab.json or encoder.json file 1143 // We want to unmarshal the vocab file into an interface to work with 1144 // We attempt to unmarshal under the vocab.json key first, then 1145 // encoder.json if it fails 1146 filesToAttempt := []string{"vocab.json", "encoder.json", "tokenizer.json"} 1147 1148 // Get the vocab from the resources 1149 name, vocabData, err := rsrcs.UnmarshalUntilData(filesToAttempt) 1150 if err != nil { 1151 return nil, err 1152 } else if vocabData == nil { 1153 return nil, errors.New("vocab file not found") 1154 } 1155 1156 tokens := make(TokenMap) 1157 if name == "tokenizer.json" { 1158 // We get the vocab stored in the /model/vocab key 1159 if modelInterface, ok := (*vocabData)["model"]; ok { 1160 model := modelInterface.(map[string]interface{}) 1161 if vocabInterface, ok := model["vocab"]; ok { 1162 vocabMap := vocabInterface.(map[string]interface{}) 1163 for k, v := range vocabMap { 1164 tokens[k] = types.Token(v.(float64)) 1165 } 1166 } 1167 // Check for "ignore_merges", and set it 1168 if ignoreMerges, ok := model["ignore_merges"].(bool); ok { 1169 hfConfig.IgnoreMerges = &ignoreMerges 1170 } 1171 } 1172 } else { 1173 for k, v := range *vocabData { 1174 tokens[k] = types.Token(v.(float64)) 1175 } 1176 } 1177 1178 if hfConfig.IgnoreMerges == nil { 1179 disableIgnoreMerges := false 1180 hfConfig.IgnoreMerges = &disableIgnoreMerges 1181 } 1182 1183 // Add the special tokens to the vocab 1184 if len(*hfConfig.AddedSpecialsTokens) > 0 { 1185 for k, v := range *hfConfig.AddedSpecialsTokens { 1186 tokens[k] = v 1187 } 1188 } 1189 1190 return tokens, nil 1191 } 1192 1193 // resolveTokenIds 1194 // Resolve token ids for eos, bos, and pad tokens from resources. 
1195 func (rsrcs *Resources) resolveTokenIds(hfConfig *HFConfig) error { 1196 // Get the vocab from the resources 1197 vocab, err := rsrcs.GetVocab(hfConfig) 1198 if err != nil { 1199 return err 1200 } 1201 1202 // Get the token ids for eos, bos, and pad tokens 1203 var eosTokenId, bosTokenId, padTokenId *Token 1204 if eosToken, ok := vocab[*hfConfig.EosTokenStr]; ok { 1205 eosTokenId = new(Token) 1206 *eosTokenId = Token(eosToken) 1207 hfConfig.EosTokenId = eosTokenId 1208 } 1209 if bosToken, ok := vocab[*hfConfig.BosTokenStr]; ok { 1210 bosTokenId = new(Token) 1211 *bosTokenId = Token(bosToken) 1212 hfConfig.BosTokenId = bosTokenId 1213 } 1214 if padToken, ok := vocab[*hfConfig.PadTokenStr]; ok { 1215 padTokenId = new(Token) 1216 *padTokenId = Token(padToken) 1217 hfConfig.PadTokenId = padTokenId 1218 } 1219 1220 return nil 1221 } 1222 1223 // resolveVocabSize 1224 // Resolve vocab size from resources. 1225 // Used to be able to resolve both embedded and local resources. 1226 // Continuation of ResolveHFFromResources. 1227 func (rsrcs *Resources) resolveVocabSize(hfConfig *HFConfig) (err error) { 1228 // Get the vocab from the resources 1229 var vocab TokenMap 1230 if vocab, err = rsrcs.GetVocab(hfConfig); err != nil { 1231 return err 1232 } 1233 1234 // Get length of vocab 1235 vocabLen := new(uint32) 1236 *vocabLen = uint32(len(vocab)) 1237 1238 hfConfig.VocabSize = vocabLen 1239 return nil 1240 } 1241 1242 // resolveConfigAndTokenizer 1243 // Resolve config and tokenizer config from resources. 1244 // Used to be able to resolve both embedded and local resources. 1245 // Continuation of ResolveHFFromResources. 
func (rsrcs *Resources) resolveConfigAndTokenizer(
	hfConfig *HFConfig,
) (err error) {
	// Use interfaces to unmarshal the config file and tokenizer config file
	var config *JsonMap
	var tokenizerConfig *JsonMap

	// Get the config and tokenizer config from the resources

	// If exists, unmarshal config.json and tokenizer_config.json, else
	// use GetFile to get the file, then unmarshal it

	if config, err = rsrcs.UnmarshalData("config.json"); err != nil {
		return err
	}
	if tokenizerConfig, err =
		rsrcs.UnmarshalData("tokenizer_config.json"); err != nil {
		return err
	}

	// Check if bos_token is in string, this is the old format Pythia has.
	// If not, try to unmarshal to the tokenizerSpecials
	// that llama 2 has, else try mistral format.
	// The order of the checks below is significant: config.json takes
	// precedence, then string-form tokenizer_config.json, then the
	// llama2 object form, then the mistral string form.
	if config != nil || tokenizerConfig != nil {
		// Tracks whether eos/bos/pad strings were already resolved by an
		// earlier, higher-precedence source.
		hasReadForEosBos := false

		// Read config.json
		if config != nil {
			configMap := *config
			// Using interfaces, first check if bos_token is in string format
			if bosToken, ok := configMap["bos_token"].(string); ok {
				hfConfig.BosTokenStr = &bosToken
				if eosToken, ok := configMap["eos_token"].(string); ok {
					hfConfig.EosTokenStr = &eosToken
				}
				if padToken, ok := configMap["pad_token"].(string); ok {
					hfConfig.PadTokenStr = &padToken
				}
				hasReadForEosBos = true
			}

			// Read for EOS BOS token ID (JSON numbers decode as float64)
			if eosTokenId, ok := configMap["eos_token_id"].(float64); ok {
				eosTokenIdInt := Token(eosTokenId)
				hfConfig.EosTokenId = &eosTokenIdInt
			}
			if bosTokenId, ok := configMap["bos_token_id"].(float64); ok {
				bosTokenIdInt := Token(bosTokenId)
				hfConfig.BosTokenId = &bosTokenIdInt
			}

			// Read for vocab size
			if vocabSize, ok := configMap["vocab_size"].(float64); ok {
				vocabSizeInt := uint32(vocabSize)
				hfConfig.VocabSize = &vocabSizeInt
			}

			// Read for newLineMode
			if newLineMode, ok := configMap["newlinemode"].(string); ok {
				hfConfig.NewLineMode = &newLineMode
			}
		}

		// Read tokenizer_config.json
		if tokenizerConfig != nil {
			configMap := *tokenizerConfig
			if !hasReadForEosBos {
				// Using interfaces, first check if bos_token is in string format
				if bosToken, ok := configMap["bos_token"].(string); ok {
					hfConfig.BosTokenStr = &bosToken
					if eosToken, ok := configMap["eos_token"].(string); ok {
						hfConfig.EosTokenStr = &eosToken
					}
					if padToken, ok := configMap["pad_token"].(string); ok {
						hfConfig.PadTokenStr = &padToken
					}
					hasReadForEosBos = true

				}
			}
			// If not, assume llama2 format and try to unmarshal
			// (bos/eos are objects with a "content" key).
			// NOTE(review): this branch does not set hasReadForEosBos,
			// so the mistral branch below also runs; its string-form
			// assertions fail when the values are objects, so the net
			// behavior is unchanged, but the flag logic is fragile.
			if !hasReadForEosBos {
				if bosToken, ok :=
					configMap["bos_token"].(map[string]interface{}); ok {
					if content, ok := bosToken["content"].(string); ok {
						hfConfig.BosTokenStr = &content
					}
				}
				if eosToken, ok :=
					configMap["eos_token"].(map[string]interface{}); ok {
					if content, ok := eosToken["content"].(string); ok {
						hfConfig.EosTokenStr = &content
					}
				}
				if padToken, ok := configMap["pad_token"].(string); ok {
					hfConfig.PadTokenStr = &padToken
				}
			}
			// If that doesn't work, assume mistral format
			if !hasReadForEosBos {
				if bosToken, ok := configMap["bos_token"].(string); ok {
					hfConfig.BosTokenStr = &bosToken
				}
				if eosToken, ok := configMap["eos_token"].(string); ok {
					hfConfig.EosTokenStr = &eosToken
				}
				if padToken, ok := configMap["pad_token"].(string); ok {
					hfConfig.PadTokenStr = &padToken
				}
			}

			// Read for enclose eos bos.
			// NOTE(review): the local names below are swapped relative to
			// the keys they read (add_bos_token -> encloseEos), but the
			// assignments to AddBosToken/AddEosToken are correct.
			if encloseEos, ok := configMap["add_bos_token"].(bool); ok {
				hfConfig.AddBosToken = &encloseEos
			}

			if encloseBos, ok := configMap["add_eos_token"].(bool); ok {
				hfConfig.AddEosToken = &encloseBos
			}

			// Read for added_specials_tokens
			// Will later be used to readd into vocab if needed
			if addedTokensDecoder, ok :=
				configMap["added_tokens_decoder"].(map[string]interface{}); ok {
				addedSpecialsTokens := make(TokenMap)
				for k, v := range addedTokensDecoder {
					// Get under content key, key is float64
					// (the decoder maps token-id strings to objects
					// whose "content" is the token text).
					keyToken, _ := strconv.ParseFloat(k, 64)
					valStr := v.(map[string]interface{})["content"].(string)
					addedSpecialsTokens[valStr] = types.Token(keyToken)
				}
				hfConfig.AddedSpecialsTokens = &addedSpecialsTokens
			}

			// Read for tokenizer Class
			if tClass, ok := configMap["tokenizer_class"].(string); ok {
				hfConfig.TokenizerClass = &tClass
			}
		}
	}
	return nil
}

// resolveSpecials
// Resolve special tokens and special config from resources.
// Used to be able to resolve both embedded and local resources.
// Continuation of ResolveHFFromResources.
func (rsrcs *Resources) resolveSpecials(hfConfig *HFConfig) error {
	// Get specials config from resources
	// We can only generate specials.json if we have special_tokens_map
	specialsJson, ok := (*rsrcs)["special_tokens_map.json"]
	if ok {
		specialTokens := make(JsonMap)
		if specialErr := json.Unmarshal(
			*specialsJson.Data,
			&specialTokens,
		); specialErr != nil {
			return specialErr
		}

		// Try to get pad token from specials if not already set
		if hfConfig.PadTokenStr == nil || *hfConfig.PadTokenStr == "" {
			if padToken, pOk := specialTokens["pad_token"].(string); pOk {
				hfConfig.PadTokenStr = &padToken
			}
		}
	}

	// Fall back to specials.txt (one special token per line).
	specialsTxt, ok := (*rsrcs)["specials.txt"]
	if ok {
		// Treat specials.txt as an array of strings and try to match
		specials := strings.Split(string(*specialsTxt.Data), "\n")
		if hfConfig.PadTokenStr == nil {
			for _, special := range specials {
				// First special containing "pad" (case-insensitive) wins.
				if strings.Contains(strings.ToLower(special), "pad") {
					hfConfig.PadTokenStr = &special
					break
				}
			}
		}
	}
	return nil
}

// LoadEmbeddedResource copies the embedded resource at vocabId/path into
// the resources map under resourceId, if it exists.
func (rsrcs *Resources) LoadEmbeddedResource(
	vocabId string,
	resourceId string,
	path string,
) {
	if r := GetEmbeddedResource(vocabId + "/" + path); r != nil {
		(*rsrcs)[resourceId] = *r
	}
}

// ResolveResourcesList
// Resolves a list of resources, and checks if they exist in the given
// directory. If they don't exist, they are downloaded.
func ResolveResourcesList(vocabId string, token string) (*Resources, error) {
	// Resolve the vocab id - Embedded resources
	if _, vocabErr := EmbeddedDirExists(vocabId); vocabErr == nil {
		resources := make(Resources)

		// Mapping of resource id in the map -> filename inside the
		// embedded directory.
		possibleEmbeddedResources := []struct {
			resourceId string
			path       string
		}{
			{"vocab.json", "encoder.json"},
			{"config.json", "config.json"},
			{"merges.txt", "vocab.bpe"},
			{"merges.json", "merges.json"},
			{"specials.txt", "specials.txt"},
			{"special_tokens_map.json", "special_tokens_map.json"},
			{"special_config.json", "special_config.json"},
			{"tokenizer.json", "tokenizer.json"},
			{"tokenizer_config.json", "tokenizer_config.json"},
		}

		for _, resource := range possibleEmbeddedResources {
			resources.LoadEmbeddedResource(
				vocabId, resource.resourceId, resource.path,
			)
		}
		return &resources, nil
	}
	// Non-embedded resources
	resources, err := LoadExternalResources(vocabId, token)
	if err != nil {
		return nil, err
	}
	return resources, nil

}

// ResolveVocabId
// Resolves a vocabulary id to a set of resources, from embedded,
// local filesystem, or remote, and applies processing to the resources.
1483 func ResolveVocabId(vocabId string, token string) ( 1484 *HFConfig, 1485 *Resources, 1486 error, 1487 ) { 1488 rsrcs, err := ResolveResourcesList(vocabId, token) 1489 if err != nil { 1490 return nil, nil, err 1491 } 1492 1493 hf := NewHFConfig() 1494 hf.ModelId = &vocabId 1495 if err = rsrcs.ResolveHF(hf); err != nil { 1496 return nil, nil, err 1497 } 1498 return hf, rsrcs, nil 1499 } 1500 1501 func ExtractModelFromTokenizer(dir *string) (JsonMap, error) { 1502 tokenizerPath := path.Join(*dir, "tokenizer.json") 1503 // Open the file 1504 tokenizerFile, err := os.Open(tokenizerPath) 1505 if err != nil { 1506 log.Println("Error opening tokenizer:", err) 1507 // return an empty map and the error 1508 return nil, err 1509 } 1510 defer func(tokenizerFile *os.File) { 1511 _ = tokenizerFile.Close() 1512 }(tokenizerFile) 1513 1514 // Decode the JSON data into a map 1515 var data JsonMap 1516 err = json.NewDecoder(tokenizerFile).Decode(&data) 1517 if err != nil { 1518 log.Println("Error decoding JSON from tokenizer:", err) 1519 return nil, err 1520 } 1521 1522 // Access the data at the specified path 1523 model, ok := (data["model"]).(map[string]interface{}) 1524 model = ToJsonMap(model) 1525 if ok { 1526 return model, nil 1527 } else { 1528 log.Println("Error: Could not convert model in tokenizer to map") 1529 return nil, errors.New("could not convert model in tokenizer to map") 1530 } 1531 } 1532 1533 func ExtractVocabFromTokenizer( 1534 model JsonMap, 1535 dir *string, 1536 resources *Resources, 1537 ) error { 1538 vocab, ok := model["vocab"].(map[string]interface{}) 1539 vocab = ToJsonMap(vocab) 1540 if !ok { 1541 log.Println("Error: Could not convert vocab in model to map") 1542 return errors.New("could not convert vocab in model to map") 1543 } 1544 1545 vocabPath := path.Join(*dir, "vocab.json") 1546 1547 // Create the file 1548 vocabFile, err := os.Create(vocabPath) 1549 if err != nil { 1550 log.Println("Error creating vocab.json:", err) 1551 return err 1552 } 
1553 defer func(vocabFile *os.File) { 1554 _ = vocabFile.Close() 1555 }(vocabFile) 1556 1557 // Marshal the vocab map into a JSON string with indentation 1558 vocabJsonString, err := json.MarshalIndent(vocab, "", " ") 1559 if err != nil { 1560 fmt.Println("Error marshaling JSON:", err) 1561 return err 1562 } 1563 1564 // Write the JSON string to the file 1565 _, err = vocabFile.Write(vocabJsonString) 1566 if err != nil { 1567 log.Println("Error writing to vocab.json:", err) 1568 return err 1569 } 1570 1571 log.Println("Vocab written to vocab.json from tokenizer.json") 1572 1573 if mmapErr := resources.AddEntry( 1574 "vocab.json", vocabFile, 1575 ); mmapErr != nil { 1576 return fmt.Errorf("error trying to mmap file: %s", mmapErr) 1577 } 1578 1579 return nil 1580 } 1581 1582 func ExtractMergesFromTokenizer( 1583 model JsonMap, 1584 dir *string, 1585 resources *Resources, 1586 ) error { 1587 merges, ok := model["merges"].([]interface{}) 1588 if !ok { 1589 log.Println("Error: Could not convert merges in model to map") 1590 return errors.New("could not convert merges in model to map") 1591 } 1592 1593 // Convert the slice of interfaces to a slice of strings 1594 var mergesStr []string 1595 for _, v := range merges { 1596 mergesStr = append(mergesStr, v.(string)) 1597 } 1598 1599 mergesPath := path.Join(*dir, "merges.txt") 1600 1601 // Create the file 1602 mergesFile, err := os.Create(mergesPath) 1603 if err != nil { 1604 log.Println("Error creating file:", err) 1605 return err 1606 } 1607 defer func(mergesFile *os.File) { 1608 _ = mergesFile.Close() 1609 }(mergesFile) 1610 1611 // Write each merge string to a new line in the file 1612 for _, v := range merges { 1613 _, err = mergesFile.WriteString(v.(string) + "\n") 1614 if err != nil { 1615 log.Println("Error writing to file:", err) 1616 return err 1617 } 1618 } 1619 1620 log.Println("Merges written to merges.txt from tokenizer.json") 1621 1622 if mmapErr := resources.AddEntry( 1623 "merges.txt", mergesFile, 1624 ); 
mmapErr != nil { 1625 return fmt.Errorf("error trying to mmap file: %s", mmapErr) 1626 } 1627 1628 return nil 1629 } 1630 1631 func FindNumberOfShardsFromConfig(configPath string) (int, error) { 1632 // Open the file 1633 configFile, err := os.Open(configPath) 1634 if err != nil { 1635 log.Println("Error opening config:", err) 1636 return -1, err 1637 } 1638 defer func(configFile *os.File) { 1639 _ = configFile.Close() 1640 }(configFile) 1641 1642 // Decode the JSON data into a map 1643 var data JsonMap 1644 err = json.NewDecoder(configFile).Decode(&data) 1645 if err != nil { 1646 log.Println("Error decoding JSON from config:", err) 1647 return -1, err 1648 } 1649 1650 // Access the data at the specified path 1651 weightMap, ok := data["weight_map"].(map[string]interface{}) 1652 weightMap = ToJsonMap(weightMap) 1653 if !ok { 1654 fmt.Println("Error: Could not convert data to weight_map") 1655 return -1, errors.New("could not convert data to weight_map") 1656 } 1657 // Try embed out, if not, try lm_head.weight 1658 nameOfLast, ok := weightMap["embed_out.weight"] 1659 if !ok { 1660 nameOfLast, ok = weightMap["lm_head.weight"] 1661 if !ok { 1662 fmt.Println("Error: Could not convert weight_map to embed_out or lm_head") 1663 return -1, errors.New("could not convert weight_map to embed_out or lm_head") 1664 } 1665 } 1666 1667 r, _ := regexp.Compile(`\D*\d+\D+(\d+)`) 1668 // convert to interface -> string -> int 1669 nameOfLastInt, err := strconv.Atoi( 1670 r.FindStringSubmatch(fmt.Sprintf("%v", nameOfLast))[1], 1671 ) 1672 1673 if err != nil { 1674 fmt.Println("Error: Could not convert embed_out to int") 1675 return -1, errors.New("could not convert embed_out to int") 1676 } 1677 1678 return nameOfLastInt, nil 1679 } 1680 1681 func FindProcessingStepsFromTokenizer(model ResourceEntry) ( 1682 []Processor, 1683 error, 1684 ) { 1685 // convert the data to a map 1686 var data JsonMap 1687 err := json.Unmarshal(*model.Data, &data) 1688 if err != nil { 1689 return nil, err 
1690 } 1691 1692 // create array of processors 1693 var processors []Processor 1694 // check if normalizer is present 1695 normalizer, ok := data["normalizer"].(map[string]interface{}) 1696 normalizer = ToJsonMap(normalizer) 1697 if normalizer != nil && ok { 1698 // add normalizer to processors 1699 processor := Processor{ 1700 ProcessorType: "normalizer", 1701 ProcessorArgs: normalizer, 1702 } 1703 processors = append(processors, processor) 1704 } 1705 // check if pre_tokenizer is present 1706 preTokenizer, ok := data["pre_tokenizer"].(map[string]interface{}) 1707 preTokenizer = ToJsonMap(preTokenizer) 1708 if preTokenizer != nil && ok { 1709 // add pre_tokenizer to processors 1710 processor := Processor{ 1711 ProcessorType: "pre_tokenizer", 1712 ProcessorArgs: preTokenizer, 1713 } 1714 processors = append(processors, processor) 1715 } 1716 // check if post_processor is present 1717 post_processor, ok := data["post_processor"].(map[string]interface{}) 1718 post_processor = ToJsonMap(post_processor) 1719 if post_processor != nil && ok { 1720 // add post_processor to processors 1721 processor := Processor{ 1722 ProcessorType: "post_processor", 1723 ProcessorArgs: post_processor, 1724 } 1725 processors = append(processors, processor) 1726 } 1727 // check if decoder is present 1728 decoder, ok := data["decoder"].(map[string]interface{}) 1729 decoder = ToJsonMap(decoder) 1730 if decoder != nil && ok { 1731 // add decoder to processors 1732 processor := Processor{ 1733 ProcessorType: "decoder", 1734 ProcessorArgs: decoder, 1735 } 1736 processors = append(processors, processor) 1737 } 1738 1739 return processors, nil 1740 } 1741 1742 func ToJsonMap(sim map[string]interface{}) JsonMap { 1743 jm := make(JsonMap) 1744 for k, v := range sim { 1745 jm[k] = v 1746 } 1747 return jm 1748 } 1749 1750 func (jm JsonMap) ToMapInterface() map[string]interface{} { 1751 m := make(map[string]interface{}) 1752 for k, v := range jm { 1753 m[k] = v 1754 } 1755 return m 1756 }