github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/cmd/sentencepiece_converter/sentencepiece_converter.go (about)

     1  package main
     2  
import (
	"encoding/hex"
	"fmt"
	"io/ioutil"
	"os"
	"sort"
	"strconv"
	"strings"
	"time"
	"unicode"

	"github.com/vikesh-raj/go-sentencepiece-encoder/sentencepiece"
	"github.com/wbrown/gpt_bpe"
	"google.golang.org/protobuf/proto"
)
    16  
// escaper is the lazily-initialized replacer shared by EscapeString.
// It is built on first call; the initialization is not guarded for
// concurrent first use.
var escaper *strings.Replacer

// DuplicateEntry records a vocabulary piece whose decoded
// representation appears more than once in the model: the index of
// the first occurrence, the index of the duplicate, and the shared
// representation.
type DuplicateEntry struct {
	OldIdx int
	NewIdx int
	Repr   string
}

// VocabEntry describes one vocabulary piece. A regular piece fills
// Token/TokenId; a raw-byte piece fills Byte/ByteId. A piece that
// occurs both ways (a duplicate) carries all four fields; unused
// fields stay nil.
type VocabEntry struct {
	TokenId *gpt_bpe.Token
	Token   *string
	ByteId  *gpt_bpe.Token
	Byte    *string
}

// SentencePieceVocab holds the vocabulary in both directions:
// TokenToPiece is indexed by token id, PieceToToken is keyed by the
// decoded piece string.
type SentencePieceVocab struct {
	TokenToPiece []VocabEntry
	PieceToToken map[string]VocabEntry
}
    36  
    37  func EscapeString(
    38  	s string,
    39  ) (escaped string) {
    40  	if escaper == nil {
    41  		escaper = strings.NewReplacer(
    42  			"\"", "\\\"",
    43  			"\\", "\\\\",
    44  			"\n", "\\n",
    45  			"\r", "\\r",
    46  			"\b", "\\b",
    47  			"\t", "\\t")
    48  	}
    49  	escaped = escaper.Replace(s)
    50  	asRunes := []rune(escaped)
    51  	if len(asRunes) == 1 && (unicode.IsControl(asRunes[0]) ||
    52  		!unicode.IsPrint(asRunes[0])) {
    53  		escaped = fmt.Sprintf("\\u%04x", asRunes[0])
    54  	}
    55  	return escaped
    56  }
    57  
    58  func UnescapeString(
    59  	s string,
    60  ) (unescaped string) {
    61  	if strings.HasPrefix(s, "\\u") {
    62  		// Unescape unicode
    63  		code, _ := hex.DecodeString(s[2:6])
    64  		unescaped = string(code)
    65  		print(fmt.Sprintf("Unescaped unicode: %v -> %v", s, unescaped))
    66  	} else {
    67  		unescaped = s
    68  	}
    69  	return unescaped
    70  }
    71  
// GenerateVocab builds a bidirectional vocabulary from the model's
// pieces. Byte pieces ("<0xNN>") are decoded to the raw byte they
// name; all other pieces have the SentencePiece space marker "▁"
// replaced with a plain space. Control and user-defined pieces are
// additionally collected into specials. When two pieces decode to the
// same string, both token ids share one merged VocabEntry and the
// collision is recorded in duplicates.
func GenerateVocab(
	model *sentencepiece.ModelProto,
) (
	vocab *SentencePieceVocab,
	duplicates *[]DuplicateEntry,
	specials *[]string,
) {
	vocab = &SentencePieceVocab{
		// NOTE(review): sized len+1, which leaves one zero-valued
		// trailing slot — confirm downstream writers expect it.
		TokenToPiece: make([]VocabEntry, len(model.GetPieces())+1),
		PieceToToken: make(map[string]VocabEntry),
	}
	specials = &[]string{}
	duplicateEntries := make([]DuplicateEntry, 0)
	duplicates = &duplicateEntries
	spaceReplacer := strings.NewReplacer(
		"▁", " ")
	// Build the vocab
	for pieceIdx, piece := range model.GetPieces() {
		repr := piece.GetPiece()
		pieceIsByte := piece.GetType() ==
			sentencepiece.ModelProto_SentencePiece_BYTE
		pieceIsControl := piece.GetType() ==
			sentencepiece.ModelProto_SentencePiece_CONTROL
		pieceIsUser := piece.GetType() ==
			sentencepiece.ModelProto_SentencePiece_USER_DEFINED
		if pieceIsByte {
			// Byte pieces look like "<0xNN>"; characters 3..4 are
			// the two hex digits of the byte value.
			hexRepr := piece.GetPiece()[3:5]
			encodedRepr, _ := hex.DecodeString(hexRepr)
			repr = string(encodedRepr)
		} else {
			repr = spaceReplacer.Replace(repr)
			if pieceIsControl || pieceIsUser {
				*specials = append(*specials, repr)
			}
		}
		if dupeEntry, ok := vocab.PieceToToken[repr]; ok {
			// Collision: fold this piece into the existing entry so a
			// single VocabEntry carries both the token id and the
			// byte id for this repr.
			var dupeIdx gpt_bpe.Token
			if dupeEntry.TokenId != nil {
				dupeIdx = *dupeEntry.TokenId
			} else {
				dupeIdx = *dupeEntry.ByteId
			}
			if pieceIsByte {
				byteToken := gpt_bpe.Token(pieceIdx)
				dupeEntry.Byte = &repr
				dupeEntry.ByteId = &byteToken
			} else {
				tokenToken := gpt_bpe.Token(pieceIdx)
				dupeEntry.Token = &repr
				dupeEntry.TokenId = &tokenToken
			}
			// Both the old and the new index now resolve to the same
			// merged entry.
			vocab.PieceToToken[repr] = dupeEntry
			vocab.TokenToPiece[dupeIdx] = dupeEntry
			vocab.TokenToPiece[gpt_bpe.Token(pieceIdx)] = dupeEntry
			print(fmt.Sprintf("Duplicate piece: old (%v): %v, dupe ("+
				"%v): %v\n",
				dupeIdx, model.GetPieces()[dupeIdx], pieceIdx, piece))
			*duplicates = append(*duplicates, DuplicateEntry{
				OldIdx: int(dupeIdx),
				NewIdx: pieceIdx,
				Repr:   repr,
			})
		} else {
			// First sighting of this repr: record it under the id
			// field matching the piece type.
			if pieceIsByte {
				byteToken := gpt_bpe.Token(pieceIdx)
				vocab.PieceToToken[repr] = VocabEntry{
					Byte:   &repr,
					ByteId: &byteToken,
				}
			} else {
				tokenToken := gpt_bpe.Token(pieceIdx)
				vocab.PieceToToken[repr] = VocabEntry{
					Token:   &repr,
					TokenId: &tokenToken,
				}
			}
			vocab.TokenToPiece[pieceIdx] = vocab.PieceToToken[repr]
		}
	}
	return vocab, duplicates, specials
}
   153  
   154  func GenerateMergeTable(
   155  	vocab *SentencePieceVocab,
   156  ) map[gpt_bpe.GPTPair]gpt_bpe.Token {
   157  	// Build the merge table
   158  	mergeTable := make(map[gpt_bpe.GPTPair]gpt_bpe.Token, 0)
   159  
   160  	// Loop over the model and print out the pieces
   161  	currPair := gpt_bpe.GPTPair{"", ""}
   162  	for _, token := range vocab.TokenToPiece {
   163  		if token.Token == nil || *token.Token == "" || len(*token.Token) < 2 {
   164  			continue
   165  		}
   166  		for splitIdx := 1; splitIdx < len(*token.Token); splitIdx++ {
   167  			currPair.Left = (*token.Token)[:splitIdx]
   168  			currPair.Right = (*token.Token)[splitIdx:]
   169  			// Check if both pieces exist in the vocab
   170  			leftTokenEntry, leftOk := vocab.PieceToToken[currPair.Left]
   171  			rightTokenEntry, rightOk := vocab.PieceToToken[currPair.Right]
   172  			if !leftOk || !rightOk {
   173  				continue
   174  			}
   175  			if _, ok := mergeTable[currPair]; !ok {
   176  				mergedToken := fmt.Sprintf("%v%v",
   177  					currPair.Left,
   178  					currPair.Right)
   179  
   180  				if tokenEntry, ok := vocab.PieceToToken[mergedToken]; ok {
   181  					leftTokenId := leftTokenEntry.TokenId
   182  					rightTokenId := rightTokenEntry.TokenId
   183  					tokenId := *tokenEntry.TokenId
   184  					print(fmt.Sprintf("%v (%v) %v (%v) -> %v (%v)\n",
   185  						currPair.Left, leftTokenId,
   186  						currPair.Right, rightTokenId,
   187  						mergedToken, tokenId))
   188  					mergeTable[currPair] = tokenId
   189  				}
   190  			}
   191  		}
   192  	}
   193  	return mergeTable
   194  }
   195  
// MergeEntry is one row of the merge array: a left/right piece pair,
// their token ids, and the token produced by concatenating them. Only
// the left and right strings are serialized to JSON; the remaining
// fields drive sorting and the verbose output path.
type MergeEntry struct {
	Left        string        `json:"left"`
	LeftToken   gpt_bpe.Token `json:"-"`
	Right       string        `json:"right"`
	RightToken  gpt_bpe.Token `json:"-"`
	Merged      string        `json:"-"`
	MergedToken gpt_bpe.Token `json:"-"`
}
   205  
   206  func GenerateMergeEntries(
   207  	vocab *SentencePieceVocab,
   208  	mergeTable map[gpt_bpe.GPTPair]gpt_bpe.Token,
   209  ) []MergeEntry {
   210  	// Turn the merge table into an array of entries
   211  	mergeEntries := make([]MergeEntry, 0)
   212  	for pair := range mergeTable {
   213  		mergedToken := fmt.Sprintf("%v%v", pair.Left, pair.Right)
   214  		// Skip single rune tokens
   215  		if len([]rune(mergedToken)) == 1 {
   216  			continue
   217  		}
   218  		mergeEntries = append(mergeEntries,
   219  			MergeEntry{pair.Left,
   220  				*vocab.PieceToToken[pair.Left].TokenId,
   221  				pair.Right,
   222  				*vocab.PieceToToken[pair.Right].TokenId,
   223  				mergedToken,
   224  				*vocab.PieceToToken[mergedToken].TokenId})
   225  	}
   226  	// Sort the merge array by token id
   227  	sort.Slice(mergeEntries, func(i, j int) bool {
   228  		return mergeEntries[i].MergedToken < mergeEntries[j].MergedToken
   229  	})
   230  	return mergeEntries
   231  }
   232  
   233  func WriteDuplicates(
   234  	name string,
   235  	duplicates *[]DuplicateEntry,
   236  ) {
   237  	duplicatesFile, err := os.Create(fmt.Sprintf("%s.json", name))
   238  	if err != nil {
   239  		panic(err)
   240  	}
   241  	duplicatesFile.WriteString("[\n")
   242  	for idx, dupe := range *duplicates {
   243  		escaped := EscapeString(dupe.Repr)
   244  		duplicatesFile.WriteString(fmt.Sprintf("  {\"old_id\": %v, "+
   245  			"\"new_id\": %v, \"repr\": \"%v\"}",
   246  			dupe.OldIdx, dupe.NewIdx, escaped))
   247  		if idx != len(*duplicates)-1 {
   248  			duplicatesFile.WriteString(",\n")
   249  		} else {
   250  			duplicatesFile.WriteString("\n")
   251  		}
   252  	}
   253  	duplicatesFile.WriteString("]\n")
   254  }
   255  
   256  func WriteMergeFiles(
   257  	name string,
   258  	mergeEntries []MergeEntry,
   259  	verbose bool,
   260  ) {
   261  	mergesFile, err := os.Create(fmt.Sprintf("%s.json", name))
   262  	if err != nil {
   263  		panic(err)
   264  	}
   265  
   266  	if verbose {
   267  		mergesFile.WriteString("[\n")
   268  	} else {
   269  		mergesFile.WriteString("[")
   270  	}
   271  
   272  	// Write the merge table to a text file and json file
   273  	for idx, pair := range mergeEntries {
   274  		leftRepr := EscapeString(pair.Left)
   275  		rightRepr := EscapeString(pair.Right)
   276  		mergedRepr := EscapeString(pair.Merged)
   277  
   278  		if idx != 0 && verbose {
   279  			mergesFile.WriteString(",\n  ")
   280  		} else if idx != 0 {
   281  			mergesFile.WriteString(",")
   282  		}
   283  
   284  		if verbose {
   285  			mergesFile.WriteString(fmt.Sprintf(
   286  				"{\"left\": \"%v\", \", left_token\": %v, "+
   287  					"\"right\": \"%v\", \"right_token\": %v, "+
   288  					"\"merged\": \"%v\", \"merged_token\": %v}",
   289  				leftRepr, pair.LeftToken,
   290  				rightRepr, pair.RightToken,
   291  				mergedRepr, pair.MergedToken))
   292  		} else {
   293  			mergesFile.WriteString(fmt.Sprintf(
   294  				"[\"%v\",\"%v\"]",
   295  				leftRepr, rightRepr))
   296  		}
   297  	}
   298  	if verbose {
   299  		mergesFile.WriteString("]")
   300  	} else {
   301  		mergesFile.WriteString("\n]\n")
   302  	}
   303  	mergesFile.Close()
   304  }
   305  
   306  func WriteVocabFile(
   307  	name string,
   308  	vocab *SentencePieceVocab,
   309  	verbose bool,
   310  ) {
   311  	// Serialize vocab to a JSON file
   312  	vocabFile, _ := os.Create(fmt.Sprintf("%s.json", name))
   313  	vocabSize := len(vocab.TokenToPiece)
   314  
   315  	var entryPrefix string
   316  	if verbose {
   317  		entryPrefix = " "
   318  		vocabFile.WriteString("{\n")
   319  	} else {
   320  		entryPrefix = ""
   321  		vocabFile.WriteString("{")
   322  	}
   323  
   324  	for tokenId := 0; tokenId < vocabSize; tokenId++ {
   325  		tokenEntry := vocab.TokenToPiece[tokenId]
   326  		var repr string
   327  		if tokenEntry.TokenId != nil &&
   328  			*tokenEntry.TokenId == gpt_bpe.Token(tokenId) {
   329  			repr = EscapeString(*tokenEntry.Token)
   330  		} else if tokenEntry.Byte != nil {
   331  			// Convert our repr string to a byte
   332  			reprByte := []byte(*tokenEntry.Byte)
   333  			// Convert the byte to a hexstring
   334  			repr = fmt.Sprintf("0x%02x", reprByte)
   335  		}
   336  		if tokenId != 0 && verbose {
   337  			vocabFile.WriteString(",\n")
   338  		} else if tokenId != 0 {
   339  			vocabFile.WriteString(",")
   340  		}
   341  
   342  		vocabFile.WriteString(fmt.Sprintf("%s\"%v\":%s%d",
   343  			entryPrefix, repr, entryPrefix, tokenId))
   344  	}
   345  	if verbose {
   346  		vocabFile.WriteString("\n}\n")
   347  	} else {
   348  		vocabFile.WriteString("}")
   349  	}
   350  	vocabFile.Close()
   351  }
   352  
   353  func WriteSpecials(
   354  	name string,
   355  	specials *[]string,
   356  ) {
   357  	// Sort the specials by length
   358  	sort.Slice(*specials, func(i, j int) bool {
   359  		return len((*specials)[i]) < len((*specials)[j])
   360  	})
   361  
   362  	specialsFile, err := os.Create(fmt.Sprintf("%s.txt", name))
   363  	if err != nil {
   364  		panic(err)
   365  	}
   366  	for idx, special := range *specials {
   367  		if idx != 0 {
   368  			specialsFile.WriteString("\n")
   369  		}
   370  		specialsFile.WriteString(fmt.Sprintf("%s", special))
   371  	}
   372  	specialsFile.Close()
   373  }
   374  
   375  func ConvertSentencepieceFiles(modelPath string) {
   376  	bytes, err := ioutil.ReadFile(modelPath)
   377  	if err != nil {
   378  		print(fmt.Errorf("Unable to read file err %v", err))
   379  	}
   380  	var model sentencepiece.ModelProto
   381  	err = proto.Unmarshal(bytes, &model)
   382  	if err != nil {
   383  		print(fmt.Errorf("Unable to unmarshal proto err %v", err))
   384  	}
   385  
   386  	vocab, duplicates, specials := GenerateVocab(&model)
   387  	WriteVocabFile("vocab", vocab, false)
   388  	WriteSpecials("specials", specials)
   389  	WriteDuplicates("duplicates", duplicates)
   390  	mergeTable := GenerateMergeTable(vocab)
   391  	mergeEntries := GenerateMergeEntries(vocab, mergeTable)
   392  	WriteMergeFiles("merges", mergeEntries, false)
   393  }
   394  
   395  func main() {
   396  	start := time.Now()
   397  	ConvertSentencepieceFiles("nerdstash_v1.model")
   398  	elapsed := time.Since(start)
   399  	fmt.Printf("Conversion took %s\n", elapsed)
   400  }