github.com/wbrown/gpt_bpe@v0.0.0-20250709161131-1571a6e8ad2d/resources/converter.go (about) 1 package resources 2 3 import ( 4 "encoding/hex" 5 "fmt" 6 "io/ioutil" 7 "os" 8 "path" 9 "sort" 10 "strings" 11 "time" 12 "unicode" 13 14 "github.com/vikesh-raj/go-sentencepiece-encoder/sentencepiece" 15 "google.golang.org/protobuf/proto" 16 ) 17 18 var escaper *strings.Replacer 19 20 type DuplicateEntry struct { 21 OldIdx int 22 NewIdx int 23 Repr string 24 } 25 26 type GPTPair struct { 27 Left string 28 Right string 29 } 30 31 type VocabEntry struct { 32 TokenId *uint32 33 Token *string 34 ByteId *uint32 35 Byte *string 36 } 37 38 type SentencePieceVocab struct { 39 TokenToPiece []VocabEntry 40 PieceToToken map[string]VocabEntry 41 } 42 43 func EscapeString( 44 s string, 45 ) (escaped string) { 46 if escaper == nil { 47 escaper = strings.NewReplacer( 48 "\"", "\\\"", 49 "\\", "\\\\", 50 "\n", "\\n", 51 "\r", "\\r", 52 "\b", "\\b", 53 "\t", "\\t") 54 } 55 escaped = escaper.Replace(s) 56 asRunes := []rune(escaped) 57 if len(asRunes) == 1 && (unicode.IsControl(asRunes[0]) || 58 !unicode.IsPrint(asRunes[0])) { 59 escaped = fmt.Sprintf("\\u%04x", asRunes[0]) 60 } 61 return escaped 62 } 63 64 func UnescapeString( 65 s string, 66 verbose bool, 67 ) (unescaped string) { 68 if strings.HasPrefix(s, "\\u") { 69 // Unescape unicode 70 code, _ := hex.DecodeString(s[2:6]) 71 unescaped = string(code) 72 if verbose { 73 print(fmt.Sprintf("Unescaped unicode: %v -> %v", s, unescaped)) 74 } 75 } else { 76 unescaped = s 77 } 78 return unescaped 79 } 80 81 func GenerateVocab( 82 model *sentencepiece.ModelProto, 83 ) ( 84 vocab *SentencePieceVocab, 85 duplicates *[]DuplicateEntry, 86 specials *[]string, 87 ) { 88 vocab = &SentencePieceVocab{ 89 TokenToPiece: make([]VocabEntry, len(model.GetPieces())+1), 90 PieceToToken: make(map[string]VocabEntry), 91 } 92 specials = &[]string{} 93 duplicateEntries := make([]DuplicateEntry, 0) 94 duplicates = &duplicateEntries 95 spaceReplacer := strings.NewReplacer( 96 "▁", " ") 97 // Build the vocab 98 for pieceIdx, piece := range model.GetPieces() { 99 repr := piece.GetPiece() 100 pieceIsByte := piece.GetType() == 101 sentencepiece.ModelProto_SentencePiece_BYTE 102 pieceIsControl := piece.GetType() == 103 sentencepiece.ModelProto_SentencePiece_CONTROL 104 pieceIsUser := piece.GetType() == 105 sentencepiece.ModelProto_SentencePiece_USER_DEFINED 106 if pieceIsByte { 107 hexRepr := piece.GetPiece()[3:5] 108 encodedRepr, _ := hex.DecodeString(hexRepr) 109 repr = string(encodedRepr) 110 } else { 111 repr = spaceReplacer.Replace(repr) 112 if pieceIsControl || pieceIsUser { 113 *specials = append(*specials, repr) 114 } 115 } 116 if dupeEntry, ok := vocab.PieceToToken[repr]; ok { 117 var dupeIdx uint32 118 if dupeEntry.TokenId != nil { 119 dupeIdx = *dupeEntry.TokenId 120 } else { 121 dupeIdx = *dupeEntry.ByteId 122 } 123 if pieceIsByte { 124 byteToken := uint32(pieceIdx) 125 dupeEntry.Byte = &repr 126 dupeEntry.ByteId = &byteToken 127 } else { 128 tokenToken := uint32(pieceIdx) 129 dupeEntry.Token = &repr 130 dupeEntry.TokenId = &tokenToken 131 } 132 vocab.PieceToToken[repr] = dupeEntry 133 vocab.TokenToPiece[dupeIdx] = dupeEntry 134 vocab.TokenToPiece[uint32(pieceIdx)] = dupeEntry 135 print(fmt.Sprintf("Duplicate piece: old (%v): %v, dupe ("+ 136 "%v): %v\n", 137 dupeIdx, model.GetPieces()[dupeIdx], pieceIdx, piece)) 138 *duplicates = append(*duplicates, DuplicateEntry{ 139 OldIdx: int(dupeIdx), 140 NewIdx: pieceIdx, 141 Repr: repr, 142 }) 143 } else { 144 if pieceIsByte { 145 byteToken := uint32(pieceIdx) 146 vocab.PieceToToken[repr] = VocabEntry{ 147 Byte: &repr, 148 ByteId: &byteToken, 149 } 150 } else { 151 tokenToken := uint32(pieceIdx) 152 vocab.PieceToToken[repr] = VocabEntry{ 153 Token: &repr, 154 TokenId: &tokenToken, 155 } 156 } 157 vocab.TokenToPiece[pieceIdx] = vocab.PieceToToken[repr] 158 } 159 } 160 return vocab, duplicates, specials 161 } 162 163 func GenerateMergeTable( 164 vocab *SentencePieceVocab, 165 verbose bool, 166 ) map[GPTPair]uint32 { 167 // Build the merge table 168 mergeTable := make(map[GPTPair]uint32, 0) 169 170 // Loop over the model and print out the pieces 171 currPair := GPTPair{"", ""} 172 for _, token := range vocab.TokenToPiece { 173 if token.Token == nil || *token.Token == "" || len(*token.Token) < 2 { 174 continue 175 } 176 for splitIdx := 1; splitIdx < len(*token.Token); splitIdx++ { 177 currPair.Left = (*token.Token)[:splitIdx] 178 currPair.Right = (*token.Token)[splitIdx:] 179 // Check if both pieces exist in the vocab 180 leftTokenEntry, leftOk := vocab.PieceToToken[currPair.Left] 181 rightTokenEntry, rightOk := vocab.PieceToToken[currPair.Right] 182 if !leftOk || !rightOk { 183 continue 184 } 185 if _, ok := mergeTable[currPair]; !ok { 186 mergedToken := fmt.Sprintf("%v%v", 187 currPair.Left, 188 currPair.Right) 189 190 if tokenEntry, ok := vocab.PieceToToken[mergedToken]; ok { 191 leftTokenId := leftTokenEntry.TokenId 192 rightTokenId := rightTokenEntry.TokenId 193 tokenId := *tokenEntry.TokenId 194 if verbose { 195 print(fmt.Sprintf("%v (%v) %v (%v) -> %v (%v)\n", 196 currPair.Left, leftTokenId, 197 currPair.Right, rightTokenId, 198 mergedToken, tokenId)) 199 } 200 mergeTable[currPair] = tokenId 201 } 202 } 203 } 204 } 205 return mergeTable 206 } 207 208 // Our struct for the merge array 209 type MergeEntry struct { 210 Left string `json:"left"` 211 LeftToken uint32 `json:"-"` 212 Right string `json:"right"` 213 RightToken uint32 `json:"-"` 214 Merged string `json:"-"` 215 MergedToken uint32 `json:"-"` 216 } 217 218 func GenerateMergeEntries( 219 vocab *SentencePieceVocab, 220 mergeTable map[GPTPair]uint32, 221 ) []MergeEntry { 222 // Turn the merge table into an array of entries 223 mergeEntries := make([]MergeEntry, 0) 224 for pair := range mergeTable { 225 mergedToken := fmt.Sprintf("%v%v", pair.Left, pair.Right) 226 // Skip single rune tokens 227 if len([]rune(mergedToken)) == 1 { 228 continue 229 } 230 mergeEntries = append(mergeEntries, 231 MergeEntry{pair.Left, 232 *vocab.PieceToToken[pair.Left].TokenId, 233 pair.Right, 234 *vocab.PieceToToken[pair.Right].TokenId, 235 mergedToken, 236 *vocab.PieceToToken[mergedToken].TokenId}) 237 } 238 // Sort the merge array by token id 239 sort.Slice(mergeEntries, func(i, j int) bool { 240 return mergeEntries[i].MergedToken < mergeEntries[j].MergedToken 241 }) 242 return mergeEntries 243 } 244 245 func WriteDuplicates( 246 name string, 247 duplicates *[]DuplicateEntry, 248 ) { 249 duplicatesFile, err := os.Create(fmt.Sprintf("%s.json", name)) 250 if err != nil { 251 panic(err) 252 } 253 duplicatesFile.WriteString("[\n") 254 for idx, dupe := range *duplicates { 255 escaped := EscapeString(dupe.Repr) 256 duplicatesFile.WriteString(fmt.Sprintf(" {\"old_id\": %v, "+ 257 "\"new_id\": %v, \"repr\": \"%v\"}", 258 dupe.OldIdx, dupe.NewIdx, escaped)) 259 if idx != len(*duplicates)-1 { 260 duplicatesFile.WriteString(",\n") 261 } else { 262 duplicatesFile.WriteString("\n") 263 } 264 } 265 duplicatesFile.WriteString("]\n") 266 } 267 268 func WriteMergeFiles( 269 name string, 270 mergeEntries []MergeEntry, 271 verbose bool, 272 ) { 273 mergesFile, err := os.Create(fmt.Sprintf("%s.json", name)) 274 if err != nil { 275 panic(err) 276 } 277 278 if verbose { 279 mergesFile.WriteString("[\n") 280 } else { 281 mergesFile.WriteString("[") 282 } 283 284 // Write the merge table to a text file and json file 285 for idx, pair := range mergeEntries { 286 leftRepr := EscapeString(pair.Left) 287 rightRepr := EscapeString(pair.Right) 288 mergedRepr := EscapeString(pair.Merged) 289 290 if idx != 0 && verbose { 291 mergesFile.WriteString(",\n ") 292 } else if idx != 0 { 293 mergesFile.WriteString(",") 294 } 295 296 if verbose { 297 mergesFile.WriteString(fmt.Sprintf( 298 "{\"left\": \"%v\", \", left_token\": %v, "+ 299 "\"right\": \"%v\", \"right_token\": %v, "+ 300 "\"merged\": \"%v\", \"merged_token\": %v}", 301 leftRepr, pair.LeftToken, 302 rightRepr, pair.RightToken, 303 mergedRepr, pair.MergedToken)) 304 } else { 305 mergesFile.WriteString(fmt.Sprintf( 306 "[\"%v\",\"%v\"]", 307 leftRepr, rightRepr)) 308 } 309 } 310 if verbose { 311 mergesFile.WriteString("]") 312 } else { 313 mergesFile.WriteString("\n]\n") 314 } 315 mergesFile.Close() 316 } 317 318 func WriteVocabFile( 319 name string, 320 vocab *SentencePieceVocab, 321 verbose bool, 322 ) { 323 // Serialize vocab to a JSON file 324 vocabFile, _ := os.Create(fmt.Sprintf("%s.json", name)) 325 vocabSize := len(vocab.TokenToPiece) 326 327 var entryPrefix string 328 if verbose { 329 entryPrefix = " " 330 vocabFile.WriteString("{\n") 331 } else { 332 entryPrefix = "" 333 vocabFile.WriteString("{") 334 } 335 336 for tokenId := 0; tokenId < vocabSize; tokenId++ { 337 tokenEntry := vocab.TokenToPiece[tokenId] 338 var repr string 339 if tokenEntry.TokenId != nil && 340 *tokenEntry.TokenId == uint32(tokenId) { 341 repr = EscapeString(*tokenEntry.Token) 342 } else if tokenEntry.Byte != nil { 343 // Convert our repr string to a byte 344 reprByte := []byte(*tokenEntry.Byte) 345 // Convert the byte to a hexstring 346 repr = fmt.Sprintf("0x%02x", reprByte) 347 } 348 if tokenId != 0 && verbose { 349 vocabFile.WriteString(",\n") 350 } else if tokenId != 0 { 351 vocabFile.WriteString(",") 352 } 353 354 vocabFile.WriteString(fmt.Sprintf("%s\"%v\":%s%d", 355 entryPrefix, repr, entryPrefix, tokenId)) 356 } 357 if verbose { 358 vocabFile.WriteString("\n}\n") 359 } else { 360 vocabFile.WriteString("}") 361 } 362 vocabFile.Close() 363 } 364 365 func WriteSpecials( 366 name string, 367 specials *[]string, 368 ) { 369 // Sort the specials by length 370 sort.Slice(*specials, func(i, j int) bool { 371 return len((*specials)[i]) < len((*specials)[j]) 372 }) 373 374 specialsFile, err := os.Create(fmt.Sprintf("%s.txt", name)) 375 if err != nil { 376 panic(err) 377 } 378 for idx, special := range *specials { 379 if idx != 0 { 380 specialsFile.WriteString("\n") 381 } 382 specialsFile.WriteString(special) 383 } 384 specialsFile.Close() 385 } 386 387 func ConvertSentencepieceFiles(modelPath string, verbose bool) { 388 bytes, err := ioutil.ReadFile(modelPath) 389 if err != nil { 390 print(fmt.Errorf("unable to read file err %v", err)) 391 } 392 var model sentencepiece.ModelProto 393 err = proto.Unmarshal(bytes, &model) 394 if err != nil { 395 print(fmt.Errorf("unable to unmarshal proto err %v", err)) 396 } 397 //get parent of modelPath 398 outputPath := path.Dir(modelPath) 399 400 vocab, duplicates, specials := GenerateVocab(&model) 401 WriteVocabFile(path.Join(outputPath, "vocab"), vocab, verbose) 402 WriteSpecials(path.Join(outputPath, "specials"), specials) 403 WriteDuplicates(path.Join(outputPath, "duplicates"), duplicates) 404 mergeTable := GenerateMergeTable(vocab, verbose) 405 mergeEntries := GenerateMergeEntries(vocab, mergeTable) 406 WriteMergeFiles(path.Join(outputPath, "merges"), mergeEntries, verbose) 407 } 408 409 func main() { 410 start := time.Now() 411 ConvertSentencepieceFiles("./tokenizer.model", true) 412 elapsed := time.Since(start) 413 fmt.Printf("Conversion took %s\n", elapsed) 414 }