github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/resources/converter.go (about)

     1  package resources
     2  
import (
	"encoding/hex"
	"fmt"
	"io/ioutil"
	"os"
	"path"
	"sort"
	"strconv"
	"strings"
	"time"
	"unicode"

	"github.com/vikesh-raj/go-sentencepiece-encoder/sentencepiece"
	"google.golang.org/protobuf/proto"
)
    17  
    18  var escaper *strings.Replacer
    19  
// DuplicateEntry records a piece whose decoded string representation
// collided with an earlier piece in the SentencePiece model.
type DuplicateEntry struct {
	OldIdx int    // model index of the first piece seen with this repr
	NewIdx int    // model index of the later, duplicate piece
	Repr   string // decoded string representation shared by both
}
    25  
// GPTPair is an adjacent (left, right) pair of token strings, used as
// the key of the BPE merge table built by GenerateMergeTable.
type GPTPair struct {
	Left  string
	Right string
}
    30  
// VocabEntry describes one vocabulary slot. A decoded piece may exist as
// a regular token, as a raw BYTE piece, or (when GenerateVocab detects a
// duplicate) as both at once; the fields for the missing side stay nil.
type VocabEntry struct {
	TokenId *uint32 // model index when the piece is a normal/control/user token
	Token   *string // decoded token string (nil for byte-only entries)
	ByteId  *uint32 // model index when the piece is a BYTE piece
	Byte    *string // single-byte string decoded from the "<0xNN>" form
}
    37  
// SentencePieceVocab holds both directions of the vocabulary mapping:
// token id -> entry (slice index) and decoded piece string -> entry.
type SentencePieceVocab struct {
	TokenToPiece []VocabEntry
	PieceToToken map[string]VocabEntry
}
    42  
    43  func EscapeString(
    44  	s string,
    45  ) (escaped string) {
    46  	if escaper == nil {
    47  		escaper = strings.NewReplacer(
    48  			"\"", "\\\"",
    49  			"\\", "\\\\",
    50  			"\n", "\\n",
    51  			"\r", "\\r",
    52  			"\b", "\\b",
    53  			"\t", "\\t")
    54  	}
    55  	escaped = escaper.Replace(s)
    56  	asRunes := []rune(escaped)
    57  	if len(asRunes) == 1 && (unicode.IsControl(asRunes[0]) ||
    58  		!unicode.IsPrint(asRunes[0])) {
    59  		escaped = fmt.Sprintf("\\u%04x", asRunes[0])
    60  	}
    61  	return escaped
    62  }
    63  
    64  func UnescapeString(
    65  	s string,
    66  	verbose bool,
    67  ) (unescaped string) {
    68  	if strings.HasPrefix(s, "\\u") {
    69  		// Unescape unicode
    70  		code, _ := hex.DecodeString(s[2:6])
    71  		unescaped = string(code)
    72  		if verbose {
    73  			print(fmt.Sprintf("Unescaped unicode: %v -> %v", s, unescaped))
    74  		}
    75  	} else {
    76  		unescaped = s
    77  	}
    78  	return unescaped
    79  }
    80  
// GenerateVocab converts the SentencePiece model's piece list into a
// bidirectional vocabulary. It returns:
//   - vocab: token id -> entry (slice) and decoded piece -> entry (map);
//   - duplicates: pieces whose decoded representation collided with an
//     earlier piece, with both model indices;
//   - specials: decoded CONTROL and USER_DEFINED pieces, in model order.
func GenerateVocab(
	model *sentencepiece.ModelProto,
) (
	vocab *SentencePieceVocab,
	duplicates *[]DuplicateEntry,
	specials *[]string,
) {
	vocab = &SentencePieceVocab{
		// NOTE(review): sized with one extra slot beyond the piece
		// count — presumably deliberate headroom; confirm intent.
		TokenToPiece: make([]VocabEntry, len(model.GetPieces())+1),
		PieceToToken: make(map[string]VocabEntry),
	}
	specials = &[]string{}
	duplicateEntries := make([]DuplicateEntry, 0)
	duplicates = &duplicateEntries
	// SentencePiece marks word boundaries with U+2581 (lower one-eighth
	// block); map it back to a plain space for the decoded repr.
	spaceReplacer := strings.NewReplacer(
		"▁", " ")
	// Build the vocab
	for pieceIdx, piece := range model.GetPieces() {
		repr := piece.GetPiece()
		pieceIsByte := piece.GetType() ==
			sentencepiece.ModelProto_SentencePiece_BYTE
		pieceIsControl := piece.GetType() ==
			sentencepiece.ModelProto_SentencePiece_CONTROL
		pieceIsUser := piece.GetType() ==
			sentencepiece.ModelProto_SentencePiece_USER_DEFINED
		if pieceIsByte {
			// BYTE pieces look like "<0xNN>": take the two hex digits
			// at [3:5] and decode them into the raw byte they encode.
			hexRepr := piece.GetPiece()[3:5]
			encodedRepr, _ := hex.DecodeString(hexRepr)
			repr = string(encodedRepr)
		} else {
			repr = spaceReplacer.Replace(repr)
			if pieceIsControl || pieceIsUser {
				*specials = append(*specials, repr)
			}
		}
		if dupeEntry, ok := vocab.PieceToToken[repr]; ok {
			// This decoded repr was already claimed by an earlier piece:
			// merge the two into one entry, the byte-typed piece filling
			// Byte/ByteId and the token-typed piece filling Token/TokenId.
			var dupeIdx uint32
			if dupeEntry.TokenId != nil {
				dupeIdx = *dupeEntry.TokenId
			} else {
				dupeIdx = *dupeEntry.ByteId
			}
			if pieceIsByte {
				byteToken := uint32(pieceIdx)
				dupeEntry.Byte = &repr
				dupeEntry.ByteId = &byteToken
			} else {
				tokenToken := uint32(pieceIdx)
				dupeEntry.Token = &repr
				dupeEntry.TokenId = &tokenToken
			}
			// Both the old index and the new one point at the merged entry.
			vocab.PieceToToken[repr] = dupeEntry
			vocab.TokenToPiece[dupeIdx] = dupeEntry
			vocab.TokenToPiece[uint32(pieceIdx)] = dupeEntry
			print(fmt.Sprintf("Duplicate piece: old (%v): %v, dupe ("+
				"%v): %v\n",
				dupeIdx, model.GetPieces()[dupeIdx], pieceIdx, piece))
			*duplicates = append(*duplicates, DuplicateEntry{
				OldIdx: int(dupeIdx),
				NewIdx: pieceIdx,
				Repr:   repr,
			})
		} else {
			// First occurrence of this repr: record it under either the
			// byte fields or the token fields depending on piece type.
			if pieceIsByte {
				byteToken := uint32(pieceIdx)
				vocab.PieceToToken[repr] = VocabEntry{
					Byte:   &repr,
					ByteId: &byteToken,
				}
			} else {
				tokenToken := uint32(pieceIdx)
				vocab.PieceToToken[repr] = VocabEntry{
					Token:   &repr,
					TokenId: &tokenToken,
				}
			}
			vocab.TokenToPiece[pieceIdx] = vocab.PieceToToken[repr]
		}
	}
	return vocab, duplicates, specials
}
   162  
   163  func GenerateMergeTable(
   164  	vocab *SentencePieceVocab,
   165  	verbose bool,
   166  ) map[GPTPair]uint32 {
   167  	// Build the merge table
   168  	mergeTable := make(map[GPTPair]uint32, 0)
   169  
   170  	// Loop over the model and print out the pieces
   171  	currPair := GPTPair{"", ""}
   172  	for _, token := range vocab.TokenToPiece {
   173  		if token.Token == nil || *token.Token == "" || len(*token.Token) < 2 {
   174  			continue
   175  		}
   176  		for splitIdx := 1; splitIdx < len(*token.Token); splitIdx++ {
   177  			currPair.Left = (*token.Token)[:splitIdx]
   178  			currPair.Right = (*token.Token)[splitIdx:]
   179  			// Check if both pieces exist in the vocab
   180  			leftTokenEntry, leftOk := vocab.PieceToToken[currPair.Left]
   181  			rightTokenEntry, rightOk := vocab.PieceToToken[currPair.Right]
   182  			if !leftOk || !rightOk {
   183  				continue
   184  			}
   185  			if _, ok := mergeTable[currPair]; !ok {
   186  				mergedToken := fmt.Sprintf("%v%v",
   187  					currPair.Left,
   188  					currPair.Right)
   189  
   190  				if tokenEntry, ok := vocab.PieceToToken[mergedToken]; ok {
   191  					leftTokenId := leftTokenEntry.TokenId
   192  					rightTokenId := rightTokenEntry.TokenId
   193  					tokenId := *tokenEntry.TokenId
   194  					if verbose {
   195  						print(fmt.Sprintf("%v (%v) %v (%v) -> %v (%v)\n",
   196  							currPair.Left, leftTokenId,
   197  							currPair.Right, rightTokenId,
   198  							mergedToken, tokenId))
   199  					}
   200  					mergeTable[currPair] = tokenId
   201  				}
   202  			}
   203  		}
   204  	}
   205  	return mergeTable
   206  }
   207  
// MergeEntry is one row of the serialized merge array. Only the left and
// right strings are exported to JSON via tags; the token-id and merged
// fields are used internally for sorting and verbose output.
type MergeEntry struct {
	Left        string `json:"left"`
	LeftToken   uint32 `json:"-"`
	Right       string `json:"right"`
	RightToken  uint32 `json:"-"`
	Merged      string `json:"-"`
	MergedToken uint32 `json:"-"`
}
   217  
   218  func GenerateMergeEntries(
   219  	vocab *SentencePieceVocab,
   220  	mergeTable map[GPTPair]uint32,
   221  ) []MergeEntry {
   222  	// Turn the merge table into an array of entries
   223  	mergeEntries := make([]MergeEntry, 0)
   224  	for pair := range mergeTable {
   225  		mergedToken := fmt.Sprintf("%v%v", pair.Left, pair.Right)
   226  		// Skip single rune tokens
   227  		if len([]rune(mergedToken)) == 1 {
   228  			continue
   229  		}
   230  		mergeEntries = append(mergeEntries,
   231  			MergeEntry{pair.Left,
   232  				*vocab.PieceToToken[pair.Left].TokenId,
   233  				pair.Right,
   234  				*vocab.PieceToToken[pair.Right].TokenId,
   235  				mergedToken,
   236  				*vocab.PieceToToken[mergedToken].TokenId})
   237  	}
   238  	// Sort the merge array by token id
   239  	sort.Slice(mergeEntries, func(i, j int) bool {
   240  		return mergeEntries[i].MergedToken < mergeEntries[j].MergedToken
   241  	})
   242  	return mergeEntries
   243  }
   244  
   245  func WriteDuplicates(
   246  	name string,
   247  	duplicates *[]DuplicateEntry,
   248  ) {
   249  	duplicatesFile, err := os.Create(fmt.Sprintf("%s.json", name))
   250  	if err != nil {
   251  		panic(err)
   252  	}
   253  	duplicatesFile.WriteString("[\n")
   254  	for idx, dupe := range *duplicates {
   255  		escaped := EscapeString(dupe.Repr)
   256  		duplicatesFile.WriteString(fmt.Sprintf("  {\"old_id\": %v, "+
   257  			"\"new_id\": %v, \"repr\": \"%v\"}",
   258  			dupe.OldIdx, dupe.NewIdx, escaped))
   259  		if idx != len(*duplicates)-1 {
   260  			duplicatesFile.WriteString(",\n")
   261  		} else {
   262  			duplicatesFile.WriteString("\n")
   263  		}
   264  	}
   265  	duplicatesFile.WriteString("]\n")
   266  }
   267  
   268  func WriteMergeFiles(
   269  	name string,
   270  	mergeEntries []MergeEntry,
   271  	verbose bool,
   272  ) {
   273  	mergesFile, err := os.Create(fmt.Sprintf("%s.json", name))
   274  	if err != nil {
   275  		panic(err)
   276  	}
   277  
   278  	if verbose {
   279  		mergesFile.WriteString("[\n")
   280  	} else {
   281  		mergesFile.WriteString("[")
   282  	}
   283  
   284  	// Write the merge table to a text file and json file
   285  	for idx, pair := range mergeEntries {
   286  		leftRepr := EscapeString(pair.Left)
   287  		rightRepr := EscapeString(pair.Right)
   288  		mergedRepr := EscapeString(pair.Merged)
   289  
   290  		if idx != 0 && verbose {
   291  			mergesFile.WriteString(",\n  ")
   292  		} else if idx != 0 {
   293  			mergesFile.WriteString(",")
   294  		}
   295  
   296  		if verbose {
   297  			mergesFile.WriteString(fmt.Sprintf(
   298  				"{\"left\": \"%v\", \", left_token\": %v, "+
   299  					"\"right\": \"%v\", \"right_token\": %v, "+
   300  					"\"merged\": \"%v\", \"merged_token\": %v}",
   301  				leftRepr, pair.LeftToken,
   302  				rightRepr, pair.RightToken,
   303  				mergedRepr, pair.MergedToken))
   304  		} else {
   305  			mergesFile.WriteString(fmt.Sprintf(
   306  				"[\"%v\",\"%v\"]",
   307  				leftRepr, rightRepr))
   308  		}
   309  	}
   310  	if verbose {
   311  		mergesFile.WriteString("]")
   312  	} else {
   313  		mergesFile.WriteString("\n]\n")
   314  	}
   315  	mergesFile.Close()
   316  }
   317  
   318  func WriteVocabFile(
   319  	name string,
   320  	vocab *SentencePieceVocab,
   321  	verbose bool,
   322  ) {
   323  	// Serialize vocab to a JSON file
   324  	vocabFile, _ := os.Create(fmt.Sprintf("%s.json", name))
   325  	vocabSize := len(vocab.TokenToPiece)
   326  
   327  	var entryPrefix string
   328  	if verbose {
   329  		entryPrefix = " "
   330  		vocabFile.WriteString("{\n")
   331  	} else {
   332  		entryPrefix = ""
   333  		vocabFile.WriteString("{")
   334  	}
   335  
   336  	for tokenId := 0; tokenId < vocabSize; tokenId++ {
   337  		tokenEntry := vocab.TokenToPiece[tokenId]
   338  		var repr string
   339  		if tokenEntry.TokenId != nil &&
   340  			*tokenEntry.TokenId == uint32(tokenId) {
   341  			repr = EscapeString(*tokenEntry.Token)
   342  		} else if tokenEntry.Byte != nil {
   343  			// Convert our repr string to a byte
   344  			reprByte := []byte(*tokenEntry.Byte)
   345  			// Convert the byte to a hexstring
   346  			repr = fmt.Sprintf("0x%02x", reprByte)
   347  		}
   348  		if tokenId != 0 && verbose {
   349  			vocabFile.WriteString(",\n")
   350  		} else if tokenId != 0 {
   351  			vocabFile.WriteString(",")
   352  		}
   353  
   354  		vocabFile.WriteString(fmt.Sprintf("%s\"%v\":%s%d",
   355  			entryPrefix, repr, entryPrefix, tokenId))
   356  	}
   357  	if verbose {
   358  		vocabFile.WriteString("\n}\n")
   359  	} else {
   360  		vocabFile.WriteString("}")
   361  	}
   362  	vocabFile.Close()
   363  }
   364  
   365  func WriteSpecials(
   366  	name string,
   367  	specials *[]string,
   368  ) {
   369  	// Sort the specials by length
   370  	sort.Slice(*specials, func(i, j int) bool {
   371  		return len((*specials)[i]) < len((*specials)[j])
   372  	})
   373  
   374  	specialsFile, err := os.Create(fmt.Sprintf("%s.txt", name))
   375  	if err != nil {
   376  		panic(err)
   377  	}
   378  	for idx, special := range *specials {
   379  		if idx != 0 {
   380  			specialsFile.WriteString("\n")
   381  		}
   382  		specialsFile.WriteString(special)
   383  	}
   384  	specialsFile.Close()
   385  }
   386  
   387  func ConvertSentencepieceFiles(modelPath string, verbose bool) {
   388  	bytes, err := ioutil.ReadFile(modelPath)
   389  	if err != nil {
   390  		print(fmt.Errorf("unable to read file err %v", err))
   391  	}
   392  	var model sentencepiece.ModelProto
   393  	err = proto.Unmarshal(bytes, &model)
   394  	if err != nil {
   395  		print(fmt.Errorf("unable to unmarshal proto err %v", err))
   396  	}
   397  	//get parent of modelPath
   398  	outputPath := path.Dir(modelPath)
   399  
   400  	vocab, duplicates, specials := GenerateVocab(&model)
   401  	WriteVocabFile(path.Join(outputPath, "vocab"), vocab, verbose)
   402  	WriteSpecials(path.Join(outputPath, "specials"), specials)
   403  	WriteDuplicates(path.Join(outputPath, "duplicates"), duplicates)
   404  	mergeTable := GenerateMergeTable(vocab, verbose)
   405  	mergeEntries := GenerateMergeEntries(vocab, mergeTable)
   406  	WriteMergeFiles(path.Join(outputPath, "merges"), mergeEntries, verbose)
   407  }
   408  
   409  func main() {
   410  	start := time.Now()
   411  	ConvertSentencepieceFiles("./tokenizer.model", true)
   412  	elapsed := time.Since(start)
   413  	fmt.Printf("Conversion took %s\n", elapsed)
   414  }