github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/cmd/sentencepiece_converter/sentencepiece_converter.go (about) 1 package main 2 3 import ( 4 "encoding/hex" 5 "fmt" 6 "github.com/vikesh-raj/go-sentencepiece-encoder/sentencepiece" 7 "github.com/wbrown/gpt_bpe" 8 "google.golang.org/protobuf/proto" 9 "io/ioutil" 10 "os" 11 "sort" 12 "strings" 13 "time" 14 "unicode" 15 ) 16 17 var escaper *strings.Replacer 18 19 type DuplicateEntry struct { 20 OldIdx int 21 NewIdx int 22 Repr string 23 } 24 25 type VocabEntry struct { 26 TokenId *gpt_bpe.Token 27 Token *string 28 ByteId *gpt_bpe.Token 29 Byte *string 30 } 31 32 type SentencePieceVocab struct { 33 TokenToPiece []VocabEntry 34 PieceToToken map[string]VocabEntry 35 } 36 37 func EscapeString( 38 s string, 39 ) (escaped string) { 40 if escaper == nil { 41 escaper = strings.NewReplacer( 42 "\"", "\\\"", 43 "\\", "\\\\", 44 "\n", "\\n", 45 "\r", "\\r", 46 "\b", "\\b", 47 "\t", "\\t") 48 } 49 escaped = escaper.Replace(s) 50 asRunes := []rune(escaped) 51 if len(asRunes) == 1 && (unicode.IsControl(asRunes[0]) || 52 !unicode.IsPrint(asRunes[0])) { 53 escaped = fmt.Sprintf("\\u%04x", asRunes[0]) 54 } 55 return escaped 56 } 57 58 func UnescapeString( 59 s string, 60 ) (unescaped string) { 61 if strings.HasPrefix(s, "\\u") { 62 // Unescape unicode 63 code, _ := hex.DecodeString(s[2:6]) 64 unescaped = string(code) 65 print(fmt.Sprintf("Unescaped unicode: %v -> %v", s, unescaped)) 66 } else { 67 unescaped = s 68 } 69 return unescaped 70 } 71 72 func GenerateVocab( 73 model *sentencepiece.ModelProto, 74 ) ( 75 vocab *SentencePieceVocab, 76 duplicates *[]DuplicateEntry, 77 specials *[]string, 78 ) { 79 vocab = &SentencePieceVocab{ 80 TokenToPiece: make([]VocabEntry, len(model.GetPieces())+1), 81 PieceToToken: make(map[string]VocabEntry), 82 } 83 specials = &[]string{} 84 duplicateEntries := make([]DuplicateEntry, 0) 85 duplicates = &duplicateEntries 86 spaceReplacer := strings.NewReplacer( 87 "▁", " ") 88 // Build the vocab 89 for pieceIdx, piece := range model.GetPieces() { 90 repr := piece.GetPiece() 91 pieceIsByte := piece.GetType() == 92 sentencepiece.ModelProto_SentencePiece_BYTE 93 pieceIsControl := piece.GetType() == 94 sentencepiece.ModelProto_SentencePiece_CONTROL 95 pieceIsUser := piece.GetType() == 96 sentencepiece.ModelProto_SentencePiece_USER_DEFINED 97 if pieceIsByte { 98 hexRepr := piece.GetPiece()[3:5] 99 encodedRepr, _ := hex.DecodeString(hexRepr) 100 repr = string(encodedRepr) 101 } else { 102 repr = spaceReplacer.Replace(repr) 103 if pieceIsControl || pieceIsUser { 104 *specials = append(*specials, repr) 105 } 106 } 107 if dupeEntry, ok := vocab.PieceToToken[repr]; ok { 108 var dupeIdx gpt_bpe.Token 109 if dupeEntry.TokenId != nil { 110 dupeIdx = *dupeEntry.TokenId 111 } else { 112 dupeIdx = *dupeEntry.ByteId 113 } 114 if pieceIsByte { 115 byteToken := gpt_bpe.Token(pieceIdx) 116 dupeEntry.Byte = &repr 117 dupeEntry.ByteId = &byteToken 118 } else { 119 tokenToken := gpt_bpe.Token(pieceIdx) 120 dupeEntry.Token = &repr 121 dupeEntry.TokenId = &tokenToken 122 } 123 vocab.PieceToToken[repr] = dupeEntry 124 vocab.TokenToPiece[dupeIdx] = dupeEntry 125 vocab.TokenToPiece[gpt_bpe.Token(pieceIdx)] = dupeEntry 126 print(fmt.Sprintf("Duplicate piece: old (%v): %v, dupe ("+ 127 "%v): %v\n", 128 dupeIdx, model.GetPieces()[dupeIdx], pieceIdx, piece)) 129 *duplicates = append(*duplicates, DuplicateEntry{ 130 OldIdx: int(dupeIdx), 131 NewIdx: pieceIdx, 132 Repr: repr, 133 }) 134 } else { 135 if pieceIsByte { 136 byteToken := gpt_bpe.Token(pieceIdx) 137 vocab.PieceToToken[repr] = VocabEntry{ 138 Byte: &repr, 139 ByteId: &byteToken, 140 } 141 } else { 142 tokenToken := gpt_bpe.Token(pieceIdx) 143 vocab.PieceToToken[repr] = VocabEntry{ 144 Token: &repr, 145 TokenId: &tokenToken, 146 } 147 } 148 vocab.TokenToPiece[pieceIdx] = vocab.PieceToToken[repr] 149 } 150 } 151 return vocab, duplicates, specials 152 } 153 154 func GenerateMergeTable( 155 vocab *SentencePieceVocab, 156 ) map[gpt_bpe.GPTPair]gpt_bpe.Token { 157 // Build the merge table 158 mergeTable := make(map[gpt_bpe.GPTPair]gpt_bpe.Token, 0) 159 160 // Loop over the model and print out the pieces 161 currPair := gpt_bpe.GPTPair{"", ""} 162 for _, token := range vocab.TokenToPiece { 163 if token.Token == nil || *token.Token == "" || len(*token.Token) < 2 { 164 continue 165 } 166 for splitIdx := 1; splitIdx < len(*token.Token); splitIdx++ { 167 currPair.Left = (*token.Token)[:splitIdx] 168 currPair.Right = (*token.Token)[splitIdx:] 169 // Check if both pieces exist in the vocab 170 leftTokenEntry, leftOk := vocab.PieceToToken[currPair.Left] 171 rightTokenEntry, rightOk := vocab.PieceToToken[currPair.Right] 172 if !leftOk || !rightOk { 173 continue 174 } 175 if _, ok := mergeTable[currPair]; !ok { 176 mergedToken := fmt.Sprintf("%v%v", 177 currPair.Left, 178 currPair.Right) 179 180 if tokenEntry, ok := vocab.PieceToToken[mergedToken]; ok { 181 leftTokenId := leftTokenEntry.TokenId 182 rightTokenId := rightTokenEntry.TokenId 183 tokenId := *tokenEntry.TokenId 184 print(fmt.Sprintf("%v (%v) %v (%v) -> %v (%v)\n", 185 currPair.Left, leftTokenId, 186 currPair.Right, rightTokenId, 187 mergedToken, tokenId)) 188 mergeTable[currPair] = tokenId 189 } 190 } 191 } 192 } 193 return mergeTable 194 } 195 196 // Our struct for the merge array 197 type MergeEntry struct { 198 Left string `json:"left"` 199 LeftToken gpt_bpe.Token `json:"-"` 200 Right string `json:"right"` 201 RightToken gpt_bpe.Token `json:"-"` 202 Merged string `json:"-"` 203 MergedToken gpt_bpe.Token `json:"-"` 204 } 205 206 func GenerateMergeEntries( 207 vocab *SentencePieceVocab, 208 mergeTable map[gpt_bpe.GPTPair]gpt_bpe.Token, 209 ) []MergeEntry { 210 // Turn the merge table into an array of entries 211 mergeEntries := make([]MergeEntry, 0) 212 for pair := range mergeTable { 213 mergedToken := fmt.Sprintf("%v%v", pair.Left, pair.Right) 214 // Skip single rune tokens 215 if len([]rune(mergedToken)) == 1 { 216 continue 217 } 218 mergeEntries = append(mergeEntries, 219 MergeEntry{pair.Left, 220 *vocab.PieceToToken[pair.Left].TokenId, 221 pair.Right, 222 *vocab.PieceToToken[pair.Right].TokenId, 223 mergedToken, 224 *vocab.PieceToToken[mergedToken].TokenId}) 225 } 226 // Sort the merge array by token id 227 sort.Slice(mergeEntries, func(i, j int) bool { 228 return mergeEntries[i].MergedToken < mergeEntries[j].MergedToken 229 }) 230 return mergeEntries 231 } 232 233 func WriteDuplicates( 234 name string, 235 duplicates *[]DuplicateEntry, 236 ) { 237 duplicatesFile, err := os.Create(fmt.Sprintf("%s.json", name)) 238 if err != nil { 239 panic(err) 240 } 241 duplicatesFile.WriteString("[\n") 242 for idx, dupe := range *duplicates { 243 escaped := EscapeString(dupe.Repr) 244 duplicatesFile.WriteString(fmt.Sprintf(" {\"old_id\": %v, "+ 245 "\"new_id\": %v, \"repr\": \"%v\"}", 246 dupe.OldIdx, dupe.NewIdx, escaped)) 247 if idx != len(*duplicates)-1 { 248 duplicatesFile.WriteString(",\n") 249 } else { 250 duplicatesFile.WriteString("\n") 251 } 252 } 253 duplicatesFile.WriteString("]\n") 254 } 255 256 func WriteMergeFiles( 257 name string, 258 mergeEntries []MergeEntry, 259 verbose bool, 260 ) { 261 mergesFile, err := os.Create(fmt.Sprintf("%s.json", name)) 262 if err != nil { 263 panic(err) 264 } 265 266 if verbose { 267 mergesFile.WriteString("[\n") 268 } else { 269 mergesFile.WriteString("[") 270 } 271 272 // Write the merge table to a text file and json file 273 for idx, pair := range mergeEntries { 274 leftRepr := EscapeString(pair.Left) 275 rightRepr := EscapeString(pair.Right) 276 mergedRepr := EscapeString(pair.Merged) 277 278 if idx != 0 && verbose { 279 mergesFile.WriteString(",\n ") 280 } else if idx != 0 { 281 mergesFile.WriteString(",") 282 } 283 284 if verbose { 285 mergesFile.WriteString(fmt.Sprintf( 286 "{\"left\": \"%v\", \", left_token\": %v, "+ 287 "\"right\": \"%v\", \"right_token\": %v, "+ 288 "\"merged\": \"%v\", \"merged_token\": %v}", 289 leftRepr, pair.LeftToken, 290 rightRepr, pair.RightToken, 291 mergedRepr, pair.MergedToken)) 292 } else { 293 mergesFile.WriteString(fmt.Sprintf( 294 "[\"%v\",\"%v\"]", 295 leftRepr, rightRepr)) 296 } 297 } 298 if verbose { 299 mergesFile.WriteString("]") 300 } else { 301 mergesFile.WriteString("\n]\n") 302 } 303 mergesFile.Close() 304 } 305 306 func WriteVocabFile( 307 name string, 308 vocab *SentencePieceVocab, 309 verbose bool, 310 ) { 311 // Serialize vocab to a JSON file 312 vocabFile, _ := os.Create(fmt.Sprintf("%s.json", name)) 313 vocabSize := len(vocab.TokenToPiece) 314 315 var entryPrefix string 316 if verbose { 317 entryPrefix = " " 318 vocabFile.WriteString("{\n") 319 } else { 320 entryPrefix = "" 321 vocabFile.WriteString("{") 322 } 323 324 for tokenId := 0; tokenId < vocabSize; tokenId++ { 325 tokenEntry := vocab.TokenToPiece[tokenId] 326 var repr string 327 if tokenEntry.TokenId != nil && 328 *tokenEntry.TokenId == gpt_bpe.Token(tokenId) { 329 repr = EscapeString(*tokenEntry.Token) 330 } else if tokenEntry.Byte != nil { 331 // Convert our repr string to a byte 332 reprByte := []byte(*tokenEntry.Byte) 333 // Convert the byte to a hexstring 334 repr = fmt.Sprintf("0x%02x", reprByte) 335 } 336 if tokenId != 0 && verbose { 337 vocabFile.WriteString(",\n") 338 } else if tokenId != 0 { 339 vocabFile.WriteString(",") 340 } 341 342 vocabFile.WriteString(fmt.Sprintf("%s\"%v\":%s%d", 343 entryPrefix, repr, entryPrefix, tokenId)) 344 } 345 if verbose { 346 vocabFile.WriteString("\n}\n") 347 } else { 348 vocabFile.WriteString("}") 349 } 350 vocabFile.Close() 351 } 352 353 func WriteSpecials( 354 name string, 355 specials *[]string, 356 ) { 357 // Sort the specials by length 358 sort.Slice(*specials, func(i, j int) bool { 359 return len((*specials)[i]) < len((*specials)[j]) 360 }) 361 362 specialsFile, err := os.Create(fmt.Sprintf("%s.txt", name)) 363 if err != nil { 364 panic(err) 365 } 366 for idx, special := range *specials { 367 if idx != 0 { 368 specialsFile.WriteString("\n") 369 } 370 specialsFile.WriteString(fmt.Sprintf("%s", special)) 371 } 372 specialsFile.Close() 373 } 374 375 func ConvertSentencepieceFiles(modelPath string) { 376 bytes, err := ioutil.ReadFile(modelPath) 377 if err != nil { 378 print(fmt.Errorf("Unable to read file err %v", err)) 379 } 380 var model sentencepiece.ModelProto 381 err = proto.Unmarshal(bytes, &model) 382 if err != nil { 383 print(fmt.Errorf("Unable to unmarshal proto err %v", err)) 384 } 385 386 vocab, duplicates, specials := GenerateVocab(&model) 387 WriteVocabFile("vocab", vocab, false) 388 WriteSpecials("specials", specials) 389 WriteDuplicates("duplicates", duplicates) 390 mergeTable := GenerateMergeTable(vocab) 391 mergeEntries := GenerateMergeEntries(vocab, mergeTable) 392 WriteMergeFiles("merges", mergeEntries, false) 393 } 394 395 func main() { 396 start := time.Now() 397 ConvertSentencepieceFiles("nerdstash_v1.model") 398 elapsed := time.Since(start) 399 fmt.Printf("Conversion took %s\n", elapsed) 400 }