# Source: github.com/wbrown/gpt_bpe / lib/gpt_bpe.py
# Python ctypes bindings for the gpt_bpe shared library (Go, c-shared build).
import ctypes
import typing
from typing import Union, Sequence, Type

import numpy

# NOTE(review): hard-coded macOS extension (.dylib) and relative path; a
# sys.platform dispatch (.so / .dll) would be needed for portability.
gpt_bpe = ctypes.cdll.LoadLibrary("./gpt_bpe.dylib")


class Tokens(ctypes.Structure):
    """Mirror of the library-side token buffer: raw pointer + element count.

    Attributes (C fields):
        tokens: void* to a contiguous uint32 array.
        len:    number of uint32 elements in that array.
    """

    _fields_ = [("tokens", ctypes.c_void_p), ("len", ctypes.c_uint64)]

    # True only for structs whose buffer the library allocated (i.e. values
    # returned from gpt_bpe.tokenize). Structs we build around Python-owned
    # memory (see BPETokenizer.decode) must never be passed to freeTokens —
    # doing so would free memory the library does not own.
    owned = False

    def __del__(self):
        # Only release buffers the library actually allocated.
        if self.owned:
            gpt_bpe.freeTokens(self)


class BackedArray(numpy.ndarray):
    """ndarray view over foreign memory that keeps its owner alive.

    The `backed` attribute holds a reference to the object that owns the
    underlying buffer (here: a Tokens struct), so the buffer is not freed
    while any array view of it is still reachable.
    """

    def __new__(
        cls,
        shape,
        dtype: Type = float,
        buffer=None,
        offset=0,
        strides=None,
        order=None,
        backed=None,
    ):
        obj = super().__new__(cls, shape, dtype, buffer, offset, strides, order)
        # Pin the buffer owner to the array's lifetime.
        obj.backed = backed
        return obj

    def __array_finalize__(self, obj):
        # Called for views/slices too; propagate the owner reference so
        # derived arrays also keep the backing memory alive.
        if obj is None:
            return
        self.backed = getattr(obj, "backed", None)


class BPETokenizer:
    """Thin wrapper around the gpt_bpe shared library for one vocabulary."""

    def __init__(self, vocab_id: str):
        """Initialize the library-side tokenizer for `vocab_id`."""
        self.vocab_id = vocab_id.encode("utf8")
        gpt_bpe.initTokenizer(self.vocab_id)
        # Declare foreign return types once; default would be c_int.
        gpt_bpe.tokenize.restype = Tokens
        gpt_bpe.decode.restype = ctypes.c_char_p

    def encode(self, text: str) -> numpy.ndarray:
        """Tokenize `text`, returning a zero-copy uint32 array of token ids.

        The returned BackedArray holds the library-allocated Tokens struct,
        so the foreign buffer stays valid for the array's lifetime and is
        freed (via Tokens.__del__) once the array is garbage collected.
        """
        encoded = text.encode("utf8")
        tokens_struct = gpt_bpe.tokenize(self.vocab_id, encoded)
        # This buffer came from the library; mark it so __del__ frees it.
        tokens_struct.owned = True
        tokens_arr_type = ctypes.c_uint32 * tokens_struct.len
        tokens_buf = tokens_arr_type.from_address(tokens_struct.tokens)
        return BackedArray(
            [len(tokens_buf)],
            dtype=ctypes.c_uint32,
            buffer=tokens_buf,
            backed=tokens_struct,
        )

    def decode(self, arr: Union[numpy.ndarray, Sequence[int]]) -> str:
        """Decode a sequence of token ids back into a string."""
        # Normalize input to a uint32 ndarray (isinstance also accepts
        # BackedArray; its dtype is already uint32 so astype is a no-op
        # there and the old special case is unnecessary).
        if not isinstance(arr, numpy.ndarray):
            arr = numpy.array(arr, dtype=ctypes.c_uint32)
        elif arr.dtype != numpy.uint32:
            arr = arr.astype(ctypes.c_uint32)
        # Transient struct over Python-owned memory: `owned` stays False,
        # so __del__ will NOT hand this buffer to the library's free.
        tokens = Tokens()
        tokens.len = len(arr)
        tokens.tokens = ctypes.c_void_p(arr.ctypes.data)
        # c_char_p restype yields bytes; decode to honor the -> str contract.
        return gpt_bpe.decode(self.vocab_id, tokens).decode("utf8")
if __name__ == "__main__":
    # Smoke test: round-trip a sample string through the tokenizer.
    # Guarded so importing this module does not load a vocabulary or
    # print as a side effect.
    encoder = BPETokenizer("gpt2-tokenizer")

    test_str = "This is a test."
    tokens = encoder.encode(test_str)

    print(tokens)

    print(encoder.decode(tokens))
    print(encoder.decode([1212, 318, 257, 1332, 13]))